Repository: lss233/kirara-ai
Branch: master
Commit: 8295a5deda0b
Files: 338
Total size: 1.1 MB
Directory structure:
gitextract_9czk0gzq/
├── .cursor/
│ └── rules/
│ └── create-workflow.mdc
├── .dockerignore
├── .editorconfig
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug-report.md
│ │ └── feature-request.md
│ ├── dependabot.yml
│ ├── quickstarts/
│ │ └── windows/
│ │ └── scripts/
│ │ └── 启动.cmd
│ └── workflows/
│ ├── docker-latest.yml
│ ├── docker-tag.yml
│ ├── pr_review.yml
│ ├── project_check.yml
│ ├── quickstart-windows.yml
│ ├── run-tests.yml
│ └── stale.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .pylintrc
├── Dockerfile
├── LICENSE
├── MANIFEST.in
├── README.md
├── alembic.ini
├── config.yaml.example
├── data/
│ ├── .gitkeep
│ ├── dispatch_rules/
│ │ └── rules.yaml
│ ├── media/
│ │ └── .gitignore
│ ├── memory/
│ │ └── .gitignore
│ ├── web/
│ │ └── .gitkeep
│ └── workflows/
│ ├── .gitkeep
│ └── chat/
│ ├── dsr_thinking.yaml
│ ├── memory_store.yaml
│ ├── normal_multimodal.yaml
│ └── talk_break.yaml
├── docker/
│ └── start.sh
├── kirara_ai/
│ ├── __init__.py
│ ├── __main__.py
│ ├── alembic/
│ │ ├── README
│ │ ├── env.py
│ │ ├── script.py.mako
│ │ └── versions/
│ │ └── 4a364dbb8dab_initial_migration.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── config_loader.py
│ │ └── global_config.py
│ ├── database/
│ │ ├── __init__.py
│ │ └── manager.py
│ ├── entry.py
│ ├── events/
│ │ ├── __init__.py
│ │ ├── application.py
│ │ ├── event_bus.py
│ │ ├── im.py
│ │ ├── listen.py
│ │ ├── llm.py
│ │ ├── plugin.py
│ │ ├── tracing/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ └── llm.py
│ │ └── workflow.py
│ ├── im/
│ │ ├── __init__.py
│ │ ├── adapter.py
│ │ ├── im_registry.py
│ │ ├── manager.py
│ │ ├── message.py
│ │ ├── profile.py
│ │ └── sender.py
│ ├── internal.py
│ ├── ioc/
│ │ ├── __init__.py
│ │ ├── container.py
│ │ └── inject.py
│ ├── llm/
│ │ ├── adapter.py
│ │ ├── format/
│ │ │ ├── __init__.py
│ │ │ ├── embedding.py
│ │ │ ├── message.py
│ │ │ ├── request.py
│ │ │ ├── rerank.py
│ │ │ ├── response.py
│ │ │ └── tool.py
│ │ ├── llm_manager.py
│ │ ├── llm_registry.py
│ │ └── model_types.py
│ ├── logger.py
│ ├── mcp_module/
│ │ ├── __init__.py
│ │ ├── manager.py
│ │ ├── models.py
│ │ └── server.py
│ ├── media/
│ │ ├── __init__.py
│ │ ├── carrier/
│ │ │ ├── __init__.py
│ │ │ ├── provider.py
│ │ │ ├── registry.py
│ │ │ └── service.py
│ │ ├── manager.py
│ │ ├── media_object.py
│ │ ├── metadata.py
│ │ ├── types/
│ │ │ ├── __init__.py
│ │ │ └── media_type.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ └── mime.py
│ ├── memory/
│ │ ├── composes/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── builtin_composes.py
│ │ │ ├── composer_strategy.py
│ │ │ ├── decomposer_strategy.py
│ │ │ └── xml_helper.py
│ │ ├── entry.py
│ │ ├── memory_manager.py
│ │ ├── persistences/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── codecs.py
│ │ │ ├── file_persistence.py
│ │ │ └── redis_persistence.py
│ │ ├── registry.py
│ │ └── scopes/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── builtin_scopes.py
│ ├── plugin_manager/
│ │ ├── models.py
│ │ ├── plugin.py
│ │ ├── plugin_event_bus.py
│ │ ├── plugin_loader.py
│ │ └── utils.py
│ ├── plugins/
│ │ ├── .gitkeep
│ │ ├── bundled_frpc/
│ │ │ ├── __init__.py
│ │ │ ├── frpc_manager.py
│ │ │ ├── models.py
│ │ │ └── routes.py
│ │ ├── im_http_legacy_adapter/
│ │ │ ├── __init__.py
│ │ │ ├── adapter.py
│ │ │ ├── setup.py
│ │ │ └── tests/
│ │ │ └── api_test.py
│ │ ├── im_qqbot_adapter/
│ │ │ ├── __init__.py
│ │ │ ├── adapter.py
│ │ │ ├── setup.py
│ │ │ └── utils.py
│ │ ├── im_telegram_adapter/
│ │ │ ├── __init__.py
│ │ │ ├── adapter.py
│ │ │ └── setup.py
│ │ ├── im_wecom_adapter/
│ │ │ ├── __init__.py
│ │ │ ├── adapter.py
│ │ │ ├── delegates.py
│ │ │ └── setup.py
│ │ └── llm_preset_adapters/
│ │ ├── __init__.py
│ │ ├── alibabacloud_adapter.py
│ │ ├── claude_adapter.py
│ │ ├── deepseek_adapter.py
│ │ ├── gemini_adapter.py
│ │ ├── minimax_adapter.py
│ │ ├── mistral_adapter.py
│ │ ├── moonshot_adapter.py
│ │ ├── ollama_adapter.py
│ │ ├── openai_adapter.py
│ │ ├── openrouter_adapter.py
│ │ ├── setup.py
│ │ ├── siliconflow_adapter.py
│ │ ├── tencentcloud_adapter.py
│ │ ├── tests/
│ │ │ └── test_utils.py
│ │ ├── utils.py
│ │ ├── volcengine_adapter.py
│ │ └── voyage_adapter.py
│ ├── system/
│ │ ├── __init__.py
│ │ └── updater.py
│ ├── tracing/
│ │ ├── __init__.py
│ │ ├── core.py
│ │ ├── decorator.py
│ │ ├── llm_tracer.py
│ │ ├── manager.py
│ │ └── models.py
│ ├── web/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── api/
│ │ │ ├── block/
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── diagnostics/
│ │ │ │ │ ├── base_diagnostic.py
│ │ │ │ │ ├── import_check.py
│ │ │ │ │ ├── jedi_syntax_check.py
│ │ │ │ │ ├── mandatory_function.py
│ │ │ │ │ └── pyflakes_check.py
│ │ │ │ ├── models.py
│ │ │ │ ├── python_lsp.py
│ │ │ │ └── routes.py
│ │ │ ├── dispatch/
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ └── routes.py
│ │ │ ├── im/
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ └── routes.py
│ │ │ ├── llm/
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ └── routes.py
│ │ │ ├── mcp/
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ └── routes.py
│ │ │ ├── media/
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ └── routes.py
│ │ │ ├── plugin/
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ └── routes.py
│ │ │ ├── system/
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ ├── routes.py
│ │ │ │ └── utils.py
│ │ │ ├── tracing/
│ │ │ │ ├── __init__.py
│ │ │ │ └── routes.py
│ │ │ └── workflow/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── models.py
│ │ │ └── routes.py
│ │ ├── app.py
│ │ ├── auth/
│ │ │ ├── middleware.py
│ │ │ ├── models.py
│ │ │ ├── routes.py
│ │ │ ├── services.py
│ │ │ └── utils.py
│ │ └── utils.py
│ └── workflow/
│ ├── core/
│ │ ├── __init__.py
│ │ ├── block/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── input_output.py
│ │ │ ├── param.py
│ │ │ ├── registry.py
│ │ │ ├── schema.py
│ │ │ └── type_system.py
│ │ ├── dispatch/
│ │ │ ├── __init__.py
│ │ │ ├── dispatcher.py
│ │ │ ├── exceptions.py
│ │ │ ├── models/
│ │ │ │ └── dispatch_rules.py
│ │ │ ├── registry.py
│ │ │ └── rules/
│ │ │ ├── base.py
│ │ │ ├── message_rules.py
│ │ │ ├── sender_rules.py
│ │ │ └── system_rules.py
│ │ ├── execution/
│ │ │ ├── __init__.py
│ │ │ ├── exceptions.py
│ │ │ └── executor.py
│ │ └── workflow/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── builder.py
│ │ └── registry.py
│ ├── implementations/
│ │ ├── __init__.py
│ │ ├── blocks/
│ │ │ ├── __init__.py
│ │ │ ├── game/
│ │ │ │ ├── dice.py
│ │ │ │ └── gacha.py
│ │ │ ├── im/
│ │ │ │ ├── basic.py
│ │ │ │ ├── messages.py
│ │ │ │ ├── states.py
│ │ │ │ └── user_profile.py
│ │ │ ├── llm/
│ │ │ │ ├── basic.py
│ │ │ │ ├── chat.py
│ │ │ │ └── image.py
│ │ │ ├── mcp/
│ │ │ │ ├── __init__.py
│ │ │ │ └── tool.py
│ │ │ ├── memory/
│ │ │ │ ├── chat_memory.py
│ │ │ │ └── clear_memory.py
│ │ │ ├── system/
│ │ │ │ ├── basic.py
│ │ │ │ └── help.py
│ │ │ ├── system_blocks.py
│ │ │ └── variables/
│ │ │ └── variable_blocks.py
│ │ ├── factories/
│ │ │ ├── __init__.py
│ │ │ ├── default_factory.py
│ │ │ ├── game_factory.py
│ │ │ └── system_factory.py
│ │ └── workflows/
│ │ ├── __init__.py
│ │ └── system_workflows.py
│ └── utils/
│ └── __init__.py
├── pyproject.toml
├── pytest.ini
└── tests/
├── __init__.py
├── llm_adapters/
│ ├── __init__.py
│ ├── conftest.py
│ ├── mock_app/
│ │ ├── __init__.py
│ │ ├── app.py
│ │ ├── gemini.py
│ │ ├── models/
│ │ │ ├── gemini.py
│ │ │ └── openai.py
│ │ ├── ollama.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── test_gemini_adapter.py
│ ├── test_ollama_adapter.py
│ ├── test_openai_adapter.py
│ └── test_voyage_adapter.py
├── memory/
│ ├── __init__.py
│ ├── test_composer_decomposer.py
│ ├── test_composer_strategy.py
│ ├── test_decomposer_strategy.py
│ ├── test_memory_manager.py
│ ├── test_persistence.py
│ └── test_scope.py
├── resources/
│ └── test_image.txt
├── system_blocks/
│ ├── __init__.py
│ ├── game/
│ │ ├── __init__.py
│ │ ├── test_dice.py
│ │ └── test_gacha.py
│ ├── im/
│ │ ├── __init__.py
│ │ ├── test_messages.py
│ │ └── test_states.py
│ ├── llm/
│ │ ├── __init__.py
│ │ ├── test_basic.py
│ │ ├── test_chat.py
│ │ └── test_image.py
│ ├── memory/
│ │ ├── __init__.py
│ │ ├── test_chat_memory.py
│ │ └── test_clear_memory.py
│ └── system/
│ ├── __init__.py
│ ├── test_basic.py
│ └── test_help.py
├── test_config_loader.py
├── test_game_blocks.py
├── test_mcp_server.py
├── test_media.py
├── test_media_element.py
├── test_system_blocks.py
├── test_workflow_builder.py
├── test_workflow_factories.py
├── tracing/
│ ├── __init__.py
│ ├── test_base.py
│ ├── test_core.py
│ ├── test_decorator.py
│ ├── test_llm_tracer.py
│ ├── test_manager.py
│ └── test_models.py
├── utils/
│ ├── auth_test_utils.py
│ └── test_block_registry.py
├── web/
│ ├── api/
│ │ ├── im/
│ │ │ └── test_im.py
│ │ ├── llm/
│ │ │ └── test_llm.py
│ │ ├── media/
│ │ │ └── test_media.py
│ │ ├── plugin/
│ │ │ └── test_plugin.py
│ │ ├── system/
│ │ │ └── test_system.py
│ │ └── workflow/
│ │ └── test_workflow.py
│ └── auth/
│ └── test_auth.py
└── workflow_executor/
├── test_block.py
├── test_executor.py
├── test_input_output.py
└── test_workflow_basic.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .cursor/rules/create-workflow.mdc
================================================
---
description: 创建 Workflow
globs:
---
You are an expert in Python, writing an AI application called chatgpt-mirai-qq-bot. It's a workflow-based chatbot system.
Application
- Entrypoint: [main.py](mdc:main.py)
- IOC framework: [container.py](mdc:framework/ioc/container.py) [inject.py](mdc:framework/ioc/inject.py)
- The workflow system consists of a group of blocks that run inside a workflow; workflows are run by the executor.
- blocks: [base.py](mdc:framework/workflow/core/block/base.py) [registry.py](mdc:framework/workflow/core/block/registry.py)
- workflow: [base.py](mdc:framework/workflow/core/workflow/base.py) [registry.py](mdc:framework/workflow/core/workflow/registry.py)
- executor: [executor.py](mdc:framework/workflow/core/execution/executor.py)
- The IM adapter chooses which workflow to run in [dispatcher.py](mdc:framework/workflow/core/dispatch/dispatcher.py); rules are described by [rule.py](mdc:framework/workflow/core/dispatch/rule.py).
User-defined rules are located in the `data/dispatch_rules` folder.
- System/internal block implementations are located at `framework/workflow/implementations`
- Memory system: [memory_manager.py](mdc:framework/memory/memory_manager.py).
Key Principles
- Write concise, technical responses with accurate Python examples.
- Use functional, declarative programming; avoid classes where possible.
- Prefer iteration and modularization over code duplication.
- Use descriptive variable names with auxiliary verbs (e.g., is_active, has_permission).
- Use lowercase with underscores for directories and files (e.g., routers/user_routes.py).
- Favor named exports for routes and utility functions.
- Use the Receive an Object, Return an Object (RORO) pattern.
Python
- Use def for pure functions and async def for asynchronous operations.
- Use type hints for all function signatures. Prefer Pydantic models over raw dictionaries for input validation.
- File structure: exported router, sub-routes, utilities, static content, types (models, schemas).
- Avoid unnecessary curly braces in conditional statements.
- For single-line statements in conditionals, omit curly braces.
- Use concise, one-line syntax for simple conditional statements (e.g., if condition: do_something()).
Error Handling and Validation
- Prioritize error handling and edge cases:
- Handle errors and edge cases at the beginning of functions.
- Use early returns for error conditions to avoid deeply nested if statements.
- Place the happy path last in the function for improved readability.
- Avoid unnecessary else statements; use the if-return pattern instead.
- Use guard clauses to handle preconditions and invalid states early.
- Implement proper error logging and user-friendly error messages.
- Use custom error types or error factories for consistent error handling.
Dependencies
- Pydantic v2
- Quart for HTTP Service
- ruamel.yaml for YAML serialization
- asyncio for async programming
Highlights on Workflow:
- IM Message get and send: [chat.py](mdc:framework/workflow/implementations/blocks/llm/chat.py)
- memory interaction: [chat_memory.py](mdc:framework/workflow/implementations/blocks/memory/chat_memory.py)
- LLM interaction: [chat.py](mdc:framework/workflow/implementations/blocks/llm/chat.py)
Key Conventions
1. Rely on App's dependency injection system for managing state and shared resources.
2. Respond in Chinese
================================================
FILE: .dockerignore
================================================
config.json
config.json.old
config.cfg
.chatgpt_cache.json
================================================
FILE: .editorconfig
================================================
# https://editorconfig.org
root = true
[*]
indent_style = space
indent_size = 4
trim_trailing_whitespace = true
insert_final_newline = true
charset = utf-8
end_of_line = lf
[*.py]
max_line_length = 120
[LICENSE]
insert_final_newline = false
================================================
FILE: .github/ISSUE_TEMPLATE/bug-report.md
================================================
---
name: Bug report
about: BUG 汇报
title: "[BUG] 请填写标题"
labels: bug
assignees: ''
---
**提交 issue 前,请先确认:**
- [ ] 我已看过 **FAQ**,此问题不在列表中
- [ ] 我已看过其他 issue,他们不能解决我的问题
- [ ] 我认为这不是 Mirai 或者 OpenAI 的 BUG
**表现**
描述 BUG 的表现情况
**运行环境:**
- 操作系统:?
- Docker: ?
- 项目版本:?
**复现步骤**
描述你是如何触发这个 BUG 的
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**预期行为**
描述你认为正常情况下应该看见的情况
**截图**
相关日志、聊天记录的截图,没有可跳过
**其他内容**
此处填写其他内容,没有可跳过
================================================
FILE: .github/ISSUE_TEMPLATE/feature-request.md
================================================
---
name: Feature request
about: 提交新功能建议
title: "[Feature] 请在此处填写标题"
labels: enhancement
assignees: ''
---
================================================
FILE: .github/dependabot.yml
================================================
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "pip" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
================================================
FILE: .github/quickstarts/windows/scripts/启动.cmd
================================================
@REM ...
@ECHO OFF
@CHCP 65001
TITLE [Kirara AI] AI 系统正在启动...
SET PATH=%cd%\WPy64-31320\python;%cd%\ffmpeg\bin;%PATH%
IF NOT EXIST data\venv (
ECHO 虚拟环境不存在,正在创建...
python -m venv --system-site-packages data\venv
ECHO 虚拟环境创建完成
)
TITLE [Kirara AI] AI 系统正在运行...
ECHO 正在启动 Kirara AI...
call data\venv\Scripts\activate.bat
python -m kirara_ai
TITLE [Kirara AI] AI 系统已停止运行
ECHO 程序已停止运行。
PAUSE
================================================
FILE: .github/workflows/docker-latest.yml
================================================
name: Docker build latest
on:
workflow_dispatch:
jobs:
docker:
runs-on: ubuntu-latest
steps:
-
name: Checkout
uses: actions/checkout@v4
-
name: Set up QEMU
uses: docker/setup-qemu-action@v3
-
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
-
name: qemu workaround
run: docker run --rm --privileged multiarch/qemu-user-static --reset -p yes -c yes
-
name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
name: Build and push
uses: docker/build-push-action@v4
with:
context: .
push: true
platforms: 'linux/amd64,linux/arm64'
tags: lss233/kirara-agent-framework:latest
cache-from: type=gha
cache-to: type=gha,mode=max
================================================
FILE: .github/workflows/docker-tag.yml
================================================
name: Docker build with tags
on:
workflow_dispatch:
push:
tags:
- '**'
jobs:
docker:
runs-on: ubuntu-latest
steps:
- name: Set output
id: vars
run: echo "tag=${GITHUB_REF#refs/*/}" >> $GITHUB_OUTPUT
- name: Check output
env:
RELEASE_VERSION: ${{ steps.vars.outputs.tag }}
run: |
echo $RELEASE_VERSION
echo ${{ steps.vars.outputs.tag }}
-
name: Checkout
uses: actions/checkout@v4
-
name: Set up QEMU
uses: docker/setup-qemu-action@v3
-
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
-
name: qemu workaround
run: docker run --rm --privileged multiarch/qemu-user-static --reset -p yes -c yes
-
name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
name: Build and push
uses: docker/build-push-action@v4
with:
context: .
push: true
platforms: 'linux/amd64,linux/arm64'
tags: lss233/kirara-agent-framework:${{ steps.vars.outputs.tag }}
cache-from: type=gha
cache-to: type=gha,mode=max
================================================
FILE: .github/workflows/pr_review.yml
================================================
name: PR Code Review
on:
pull_request_target:
branches: [ "master" ]
jobs:
mypy-review:
name: MyPy Type Check
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
issues: write
steps:
# NOTE(security): this workflow is triggered by pull_request_target and is granted
# pull-requests/issues write permissions, yet it checks out the PR head below and a
# later step runs `pip install .` on that checkout. Untrusted PR code therefore runs
# with those privileges. Review whether installing/type-checking the PR should move
# to a plain `pull_request` workflow, with only the commenting kept here.
- name: Checkout repository
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install project dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .
python -m pip install mypy types-requests types-setuptools
mypy --python-version 3.12 --ignore-missing-imports kirara_ai || true # run mypy to generate type dependencies
python -m mypy --install-types --non-interactive
- name: Get changed Python files
id: changed-files
# Lists .py files changed between the merge base and the PR head. Only paths that
# still exist in the checkout are kept (`[ -f ]`), so files deleted by the PR are
# dropped. The result is published as a space-separated `files` output.
run: |
BASE_SHA=$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }})
CHANGED_FILES=$(git diff --name-only $BASE_SHA ${{ github.event.pull_request.head.sha }} | grep '\.py$' || echo "")
VALID_FILES=""
for file in $CHANGED_FILES; do
if [ -f "$file" ]; then
VALID_FILES="$VALID_FILES $file"
fi
done
echo "files=${VALID_FILES}" >> $GITHUB_OUTPUT
echo "Changed Python files: ${VALID_FILES}"
- name: Run mypy on changed files
id: run-mypy
# The file list is passed through the environment instead of being spliced into
# the script with ${{ }}: under pull_request_target the names come from the PR,
# so inline expansion would let a crafted file name inject shell commands.
env:
CHANGED_FILES: ${{ steps.changed-files.outputs.files }}
run: |
if [[ -z "$CHANGED_FILES" ]]; then
echo "No Python files changed in this PR."
echo "has_changed_files=false" >> $GITHUB_OUTPUT
exit 0
fi
echo "has_changed_files=true" >> $GITHUB_OUTPUT
# Save mypy's results in both text and JSON formats ($CHANGED_FILES is
# deliberately unquoted so it splits into one argument per file).
mypy --python-version 3.12 --show-column-numbers --show-error-codes --ignore-missing-imports $CHANGED_FILES > mypy_output.txt || true
mypy --python-version 3.12 --show-column-numbers --show-error-codes --ignore-missing-imports $CHANGED_FILES --output json > mypy_output.json || true
continue-on-error: true
- name: Get PR diff information
id: get-diff
if: steps.run-mypy.outputs.has_changed_files == 'true'
run: |
# Collect the changed files and line numbers (zero-context diff against the merge base)
BASE_SHA=$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }})
git diff -U0 $BASE_SHA ${{ github.event.pull_request.head.sha }} > pr_diff.txt
# Parse the diff and extract the modified lines into changed_lines.json
python - <<'EOF'
"""Extract, per file, the new-side line numbers touched by a unified diff.

Reads pr_diff.txt (produced by `git diff -U0`) and writes changed_lines.json,
mapping each post-change file path to the list of line numbers that the diff's
hunks cover on the new side.
"""
import json
import re

# "+++ b/<path>" header naming the post-change file.
NEW_FILE_RE = re.compile(r'^\+\+\+ b/(.+)')
# Hunk header, e.g. "@@ -1,5 +1,9 @@": group 1 is the new-side start line and
# group 2 the optional new-side line count (1 when omitted).
HUNK_RE = re.compile(r'^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@')


def parse_changed_lines(diff_lines):
    """Return {path: [line, ...]} for every new-side line covered by a hunk.

    Args:
        diff_lines: iterable of lines of `git diff` output (trailing newlines allowed).

    Returns:
        Dict mapping each "b/<path>" file to its touched line numbers, in diff order.

    "+++ " headers are only honoured between a "diff --git" line and that file's
    first hunk. Otherwise an added content line starting with "++ b/" (shown as
    "+++ b/..." in the diff) would be mistaken for a new file, and the following
    hunks would be attributed to it. Files whose new side is not "b/<path>"
    (e.g. "+++ /dev/null" for deletions) get no entry.
    """
    changed_lines = {}
    current_file = None
    in_header = False
    for raw_line in diff_lines:
        line = raw_line.rstrip('\n')
        if line.startswith('diff --git '):
            # Start of a new file section: forget the previous file.
            current_file = None
            in_header = True
            continue
        if in_header and line.startswith('+++ '):
            file_match = NEW_FILE_RE.match(line)
            current_file = file_match.group(1) if file_match else None
            if current_file is not None:
                changed_lines[current_file] = []
            continue
        hunk_match = HUNK_RE.match(line)
        if hunk_match:
            in_header = False
            if current_file:
                start_line = int(hunk_match.group(1))
                count = int(hunk_match.group(2)) if hunk_match.group(2) else 1
                # Record every line this hunk adds or modifies (none when count is 0).
                changed_lines[current_file].extend(range(start_line, start_line + count))
    return changed_lines


def main():
    """Parse pr_diff.txt and write the result to changed_lines.json."""
    # errors='replace' keeps a stray non-UTF-8 byte in the diff from aborting the job.
    with open('pr_diff.txt', 'r', encoding='utf-8', errors='replace') as f:
        changed_lines = parse_changed_lines(f)
    with open('changed_lines.json', 'w', encoding='utf-8') as f:
        json.dump(changed_lines, f)


if __name__ == '__main__':
    main()
EOF
env:
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
- name: Process mypy results
id: process-results
if: steps.run-mypy.outputs.has_changed_files == 'true'
run: |
python - <<'EOF'
#!/usr/bin/env python3
"""Turn mypy output into PR review comments limited to lines touched by the PR.

Reads changed_lines.json, mypy_output.txt and mypy_output.json from the working
directory. Writes mypy_review_comments.json and mypy_summary.json, and appends
result / diff_error_count / review_comment_count to $GITHUB_OUTPUT.
"""
import json
import os
import re

# mypy plain-text diagnostic: "path:line:col: severity: message  [code]".
# The trailing "[code]" only appears with --show-error-codes and not on notes.
TEXT_DIAGNOSTIC_RE = re.compile(
    r"^(.*?):(\d+):(\d+): (\w+): (.*?)(?:\s+\[([\w-]+)\])?$",
    re.MULTILINE,
)

DOCS_URL = "https://mypy.readthedocs.io/en/stable/error_code_list.html"


def read_text(path):
    """Return the contents of *path*, or "" if the file does not exist."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return ""


def load_changed_lines(path):
    """Load the {file: [line, ...]} map written by the diff step ({} if missing)."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def parse_json_output(content):
    """Parse `mypy --output json` (one JSON object per line), skipping unparsable lines."""
    results = []
    for line in content.strip().splitlines():
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(record, dict):
            results.append(record)
    return results


def parse_text_output(text):
    """Fallback parser for mypy's plain-text output.

    Returns dicts shaped like the JSON records. The severity ("error"/"note")
    goes in "severity", and the bracketed error code (or "unknown") in "code".
    """
    return [
        {
            "file": file_path,
            "line": int(line),
            "column": int(column),
            "severity": severity,
            "code": code or "unknown",
            "message": message,
        }
        for file_path, line, column, severity, message, code in TEXT_DIAGNOSTIC_RE.findall(text)
    ]


def build_review(mypy_results, changed_lines, changed_files):
    """Select the mypy errors that fall on lines modified by the PR.

    Args:
        mypy_results: diagnostics as dicts with "file", "line", "message",
            "code" and optionally "severity" keys.
        changed_lines: {path: [line, ...]} of new-side lines touched by the PR.
        changed_files: paths of the Python files changed by the PR. When
            non-empty, diagnostics for any other file are dropped.

    Returns:
        (review_comments, diff_errors): GitHub review comment payloads and the
        matching diagnostics, both in input order.
    """
    # Notes only annotate a neighbouring error and have no error code; posting them
    # as errors would double-count problems (and "code" is null for them).
    errors = [e for e in mypy_results if e.get("severity", "error") != "note"]
    if changed_files:
        allowed = {os.path.normpath(f) for f in changed_files}
        errors = [e for e in errors if os.path.normpath(e.get("file", "")) in allowed]
    # Exact, normalised path matching: a suffix test would let "b.py" match "sub/b.py".
    touched = {os.path.normpath(path): set(lines) for path, lines in changed_lines.items()}

    review_comments = []
    diff_errors = []
    for error in errors:
        file_path = error.get("file", "unknown")
        line_num = error.get("line", 0)
        if line_num not in touched.get(os.path.normpath(file_path), ()):
            continue
        message = error.get("message", "未知错误")
        # A present-but-null "code" must not reach .lower().
        code = error.get("code") or "unknown"
        # mypy's error-code list labels each section "code-<error-code>".
        anchor = f"code-{code.lower()}" if code != "unknown" else "error-code-list"
        review_comments.append({
            "path": file_path,
            "line": line_num,
            "body": f"**MyPy 类型错误**: {message} ({code})\n\n详细信息请参考 [mypy 文档]({DOCS_URL}#{anchor})。",
        })
        diff_errors.append(error)
    return review_comments, diff_errors


def summarize(diff_errors, review_comments):
    """Build the summary dict consumed by the commenting steps."""
    if diff_errors:
        status = "fail"
        message = f"在 PR 修改的代码行中发现了 {len(diff_errors)} 个类型问题,需要修复。"
    else:
        status = "pass"
        message = "PR 修改的代码行通过了类型检查。"
    return {
        "status": status,
        "diff_error_count": len(diff_errors),
        "review_comment_count": len(review_comments),
        "message": message,
    }


def main():
    """Read mypy/diff artifacts, write review JSON files and step outputs."""
    changed_lines = load_changed_lines("changed_lines.json")
    text_output = read_text("mypy_output.txt")
    mypy_results = parse_json_output(read_text("mypy_output.json"))
    # Fall back to the text output only when the JSON output produced nothing.
    if not mypy_results and text_output:
        mypy_results = parse_text_output(text_output)

    changed_files = os.environ.get("CHANGED_FILES", "").split()
    review_comments, diff_errors = build_review(mypy_results, changed_lines, changed_files)
    summary = summarize(diff_errors, review_comments)

    with open("mypy_review_comments.json", "w", encoding="utf-8") as f:
        json.dump(review_comments, f)
    with open("mypy_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f)
    with open(os.environ["GITHUB_OUTPUT"], "a", encoding="utf-8") as f:
        f.write(f"result={summary['status']}\n")
        f.write(f"diff_error_count={len(diff_errors)}\n")
        f.write(f"review_comment_count={len(review_comments)}\n")


if __name__ == "__main__":
    main()
EOF
env:
CHANGED_FILES: ${{ steps.changed-files.outputs.files }}
- name: Post line-level PR review comments
# Gate on the explicit "fail" result. When the process step is skipped (no Python
# files changed) its outputs are empty strings, and the old `diff_error_count != '0'`
# test was then true, so this step crashed reading files that were never written.
if: steps.process-results.outputs.result == 'fail'
uses: actions/github-script@v7
with:
github-token: ${{secrets.GITHUB_TOKEN}}
script: |
const fs = require('fs');
// Read the review comment data and summary produced by the process step
const reviewComments = JSON.parse(fs.readFileSync('mypy_review_comments.json', 'utf8'));
const summary = JSON.parse(fs.readFileSync('mypy_summary.json', 'utf8'));
// Create the PR review with one comment per offending line
const review = await github.rest.pulls.createReview({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: context.issue.number,
body: `## MyPy 类型检查结果 ❌\n\n${summary.message}\n\n已对修改的代码行创建了 ${reviewComments.length} 个行级评论。`,
event: 'COMMENT',
comments: reviewComments
});
// Add the failure label
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
labels: ['🔴 类型检查:失败']
});
// Drop a stale success label left by an earlier passing run (mirrors the success step)
try {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
name: '✅ 类型检查:通过'
});
} catch (error) {
// The label may not exist; ignore the error
}
console.log(`Created review with ${reviewComments.length} comments`);
- name: Post success comment
if: steps.process-results.outputs.result == 'pass'
uses: actions/github-script@v7
with:
github-token: ${{secrets.GITHUB_TOKEN}}
script: |
const fs = require('fs');
const summary = JSON.parse(fs.readFileSync('mypy_summary.json', 'utf8'));
// Look for an earlier bot comment from this check so it can be updated in place.
// NOTE(review): listComments is called without pagination; on PRs with many comments an older bot comment may be missed — confirm.
const { data: comments } = await github.rest.issues.listComments({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
});
const botComment = comments.find(c => {
return c.user.type === 'Bot' &&
(c.body.includes('MyPy 类型检查通过') || c.body.includes('MyPy 类型检查结果'));
});
const comment = `## MyPy 类型检查通过 ✅\n\n${summary.message}`;
if (botComment) {
await github.rest.issues.updateComment({
owner: context.repo.owner,
repo: context.repo.repo,
comment_id: botComment.id,
body: comment
});
} else {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
body: comment
});
}
// Remove the failure label (if present), then add the success label
try {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
name: '🔴 类型检查:失败'
});
} catch (error) {
// The label may not exist; ignore the error
}
// Add the success label
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
labels: ['✅ 类型检查:通过']
});
- name: Post notification if no Python files changed
if: steps.run-mypy.outputs.has_changed_files == 'false'
uses: actions/github-script@v7
with:
github-token: ${{secrets.GITHUB_TOKEN}}
script: |
// Find an earlier bot comment from this check so it can be updated in place
const { data: comments } = await github.rest.issues.listComments({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
});
const botComment = comments.find(comment => {
return comment.user.type === 'Bot' &&
(comment.body.includes('MyPy 类型检查通过') || comment.body.includes('MyPy 类型检查结果'));
});
const comment = "## MyPy 类型检查\n\nPR 中没有修改任何 Python 文件,跳过类型检查。";
if (botComment) {
await github.rest.issues.updateComment({
owner: context.repo.owner,
repo: context.repo.repo,
comment_id: botComment.id,
body: comment
});
} else {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
body: comment
});
}
// Remove both type-check labels (if present), since no check was performed
try {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
name: '🔴 类型检查:失败'
});
} catch (error) {
// The label may not exist; ignore the error
}
try {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
name: '✅ 类型检查:通过'
});
} catch (error) {
// The label may not exist; ignore the error
}
- name: Fail if type issues found in diff
# Outputs of a skipped step are empty strings, so `diff_error_count != '0'` failed the
# job on PRs without Python changes; gate on the explicit "fail" result instead.
if: steps.process-results.outputs.result == 'fail'
run: exit 1
================================================
FILE: .github/workflows/project_check.yml
================================================
name: Project Check
on:
push:
branches: [ "master" ]
merge_group:
branches: [ "master" ]
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
issues: write
strategy:
fail-fast: false
matrix:
language: [ 'python' ]
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install project dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .
python -m pip install mypy types-requests types-setuptools
mypy --python-version 3.10 --ignore-missing-imports kirara_ai || true # run mypy to generate type dependencies
python -m mypy --install-types --non-interactive
- name: Run mypy
run: |
mypy --python-version 3.10 --show-column-numbers --show-error-codes --ignore-missing-imports kirara_ai --output json > mypy_output.json
continue-on-error: true
- name: Create mypy issue content
  # always(): build the report even if an earlier step in the job failed.
  if: always()
  run: |
    cat > create_mypy_issue.py << 'EOF'
    #!/usr/bin/env python3
    """Turn mypy's JSON output into a Markdown issue report.

    Reads mypy_output.json (one JSON object per line, as produced by
    `mypy --output json`). When at least one result is present, it exports
    ISSUE_TITLE via $GITHUB_ENV and writes the report to issue_body.md.
    Otherwise it exits with status 0 without writing either, which the
    issue-creation step relies on to skip itself.
    """
    import json
    import os
    import re
    import sys
    from collections import defaultdict
    from datetime import datetime


    def file_anchor(path):
        """Return the HTML anchor id used to link to *path*'s detail section.

        Built from the full path rather than the basename: the project has
        several files sharing a name (e.g. manager.py, __init__.py) whose
        anchors would otherwise collide.
        """
        return "file-" + re.sub(r"[^0-9A-Za-z]+", "-", path).strip("-").lower()


    def escape_cell(text):
        """Make *text* safe inside a single Markdown table cell.

        Pipes would start a new column and raw newlines would end the row.
        """
        return str(text).replace("|", "\\|").replace("\r", "").replace("\n", "<br>")


    # Read mypy's JSON output (one JSON document per non-empty line)
    try:
        with open("mypy_output.json", "r", encoding="utf-8") as f:
            content = f.read()
        if content.strip():
            mypy_results = [json.loads(line) for line in content.splitlines() if line.strip()]
        else:
            print("mypy_output.json 文件为空")
            mypy_results = []
    except FileNotFoundError:
        print("警告:mypy_output.json 文件不存在。创建空结果列表。")
        mypy_results = []
    except json.JSONDecodeError as e:
        print(f"解析 JSON 时出错: {e}")
        # `content` is already in memory; no need to reopen the file.
        print(f"文件内容: {content[:1000]}")
        mypy_results = []

    # Nothing to report: exit without writing ISSUE_TITLE / issue_body.md
    if not mypy_results:
        print("没有发现类型错误,不创建 issue")
        sys.exit(0)

    # Repository / run metadata provided by GitHub Actions
    repo = os.environ.get("GITHUB_REPOSITORY", "")
    run_id = os.environ.get("GITHUB_RUN_ID", "")
    sha = os.environ.get("GITHUB_SHA", "")[:7]
    ref_name = os.environ.get("GITHUB_REF_NAME", "")
    event_name = os.environ.get("GITHUB_EVENT_NAME", "")
    if event_name == "merge_group":
        branch_info = f"合并到 {ref_name} 分支"
    else:
        branch_info = f"{ref_name} 分支"

    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Group results by file, and count them by error code. mypy emits
    # "code": null for some entries (e.g. notes), so fall back on falsy values.
    errors_by_file = defaultdict(list)
    errors_by_type = defaultdict(int)
    for result in mypy_results:
        errors_by_file[result.get("file") or "unknown"].append(result)
        errors_by_type[result.get("code") or "未知"] += 1

    issue_title = f"对{branch_info}的类型检查发现了 {len(mypy_results)} 个问题 ({sha})"

    # Export the title for the issue-creation step. GITHUB_ENV only exists
    # inside Actions; fall back to printing it when run locally.
    github_env = os.environ.get("GITHUB_ENV")
    if github_env:
        with open(github_env, "a", encoding="utf-8") as f:
            f.write(f"ISSUE_TITLE={issue_title}\n")
    else:
        print(f"ISSUE_TITLE={issue_title}")

    # Report header
    issue_body = f"""## mypy 类型检查报告
    **时间**: {now}
    **分支**: {branch_info}
    **Commit**: {sha}
    **工作流**: [查看运行详情](https://github.com/{repo}/actions/runs/{run_id})
    mypy 共发现 {len(mypy_results)} 个类型问题:
    """

    # Error-code statistics, most frequent first
    issue_body += "### 错误类型统计\n\n"
    issue_body += "| 错误代码 | 出现次数 | 占比 |\n"
    issue_body += "| -------- | -------- | ---- |\n"
    for error_code, count in sorted(errors_by_type.items(), key=lambda x: x[1], reverse=True):
        percentage = (count / len(mypy_results)) * 100
        issue_body += f"| `{error_code}` | {count} | {percentage:.1f}% |\n"

    # Per-file summary, files with the most problems first; each row links to
    # the anchor emitted in the detail section below.
    issue_body += "\n### 问题摘要\n\n"
    issue_body += "| 文件 | 问题数量 | 详情 |\n"
    issue_body += "| ---- | -------- | ---- |\n"
    for file_path, errors in sorted(errors_by_file.items(), key=lambda x: len(x[1]), reverse=True):
        issue_body += f"| `{file_path}` | {len(errors)} | [查看详情](#{file_anchor(file_path)}) |\n"

    # Detailed problems per file, files in path order, problems in line order
    issue_body += "\n### 详细问题\n\n"
    for file_path, errors in sorted(errors_by_file.items(), key=lambda x: x[0]):
        # Explicit anchor targeted by the links in the summary table
        issue_body += f'<a id="{file_anchor(file_path)}"></a>\n\n'
        issue_body += f"#### {file_path}\n\n"
        issue_body += "| 行号 | 列号 | 错误代码 | 错误消息 |\n"
        issue_body += "| ---- | ---- | -------- | -------- |\n"
        for error in sorted(errors, key=lambda x: (x.get("line", 0), x.get("column", 0))):
            line = error.get("line", "-")
            column = error.get("column", "-")
            error_code = error.get("code") or "未知"
            message = escape_cell(error.get("message", ""))
            issue_body += f"| {line} | {column} | `{error_code}` | {message} |\n"
        issue_body += "\n"

    # Write the report for the issue-creation step
    with open("issue_body.md", "w", encoding="utf-8") as f:
        f.write(issue_body)
    print(f"成功创建 mypy 问题报告,共 {len(mypy_results)} 个问题")
    EOF
    python create_mypy_issue.py
- name: Create GitHub Issue
  # Only open an issue when the report script actually found problems. It
  # exports ISSUE_TITLE and writes issue_body.md only in that case, so running
  # unconditionally would fail on the missing file (or post an untitled issue).
  # always() is kept so the report is still filed after earlier step failures.
  if: always() && env.ISSUE_TITLE != ''
  uses: peter-evans/create-issue-from-file@v4
  with:
    title: ${{ env.ISSUE_TITLE }}
    content-filepath: ./issue_body.md
    labels: |
      type-check
      automated-report
      bug
================================================
FILE: .github/workflows/quickstart-windows.yml
================================================
# Builds a self-contained Windows x64 quickstart bundle (WinPython + the
# kirara-ai wheel + FFmpeg + VC++ runtime installer + Web UI + startup scripts),
# uploads it as a workflow artifact, and attaches a zip of it to the release
# when a tag is pushed.
name: Windows Quickstart
on:
workflow_dispatch:
push:
branches: ['master']
tags: ['**']
pull_request:
branches: ['master']
# A newer run for the same workflow and ref cancels the one in progress.
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
# NOTE(review): PYTHON_VERSION is not referenced by any step in this file, and
# the WinPython folder name used below (WPy64-31320) has to be kept in sync
# with WINPYTHON_URL by hand.
env:
PYTHON_VERSION: "3.13.2.0"
WINPYTHON_URL: "https://github.com/winpython/winpython/releases/download/13.1.202502222final/Winpython64-3.13.2.0dot.zip"
DIST_DIR: "C:/dist"
BUILD_DIR: "C:/build"
PACKAGE_NAME: "quickstart-windows-kirara-ai-amd64"
jobs:
build:
name: Windows Quickstart
runs-on: windows-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
# This Python is only used to build the wheel; the bundle ships WinPython.
- name: Set up Python for building
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
# Build the wheel and export its file name as WHEEL_FILE for later steps.
- name: Build wheel package
run: |
python -m pip install build
python -m build
# Get the file name of the generated wheel
$WheelFile = Get-ChildItem -Path "dist" -Filter "*.whl" | Select-Object -First 1 -ExpandProperty Name
echo "WHEEL_FILE=$WheelFile" >> $env:GITHUB_ENV
- name: Prepare distribution environment
run: |
# Create the required directories
mkdir ${{ env.DIST_DIR }}
mkdir ${{ env.BUILD_DIR }}
# Download WinPython
Invoke-WebRequest -Uri "${{ env.WINPYTHON_URL }}" -OutFile "${{ env.BUILD_DIR }}/winpython.zip"
Expand-Archive "${{ env.BUILD_DIR }}/winpython.zip" -DestinationPath "${{ env.DIST_DIR }}"
# Install the wheel into the bundled WinPython interpreter. NOTE(review): pip
# is upgraded only after the install, so the install itself uses WinPython's
# bundled pip.
- name: Install project and dependencies
run: |
cd ${{ env.DIST_DIR }}
./WPy64-31320/python/python.exe -m pip install "${{ github.workspace }}/dist/${{ env.WHEEL_FILE }}"
./WPy64-31320/python/python.exe -m pip install --upgrade pip
# Fetch the FFmpeg essentials build, unpack it into <dist>/ffmpeg and rename
# the versioned top-level folder to `bin`. NOTE(review): if the archive keeps
# its own bin/ subfolder, the executables end up under ffmpeg/bin/bin; confirm
# this matches what the startup scripts expect.
- name: Download and setup FFmpeg
run: |
Invoke-WebRequest -Uri "https://www.gyan.dev/ffmpeg/builds/packages/ffmpeg-7.0.2-essentials_build.7z" -OutFile "${{ env.BUILD_DIR }}/ffmpeg.7z"
7z x "${{ env.BUILD_DIR }}/ffmpeg.7z" -o"${{ env.DIST_DIR }}/ffmpeg"
mv "${{ env.DIST_DIR }}/ffmpeg/ffmpeg-7.0.2-essentials_build" "${{ env.DIST_DIR }}/ffmpeg/bin"
# Bundle the VC++ redistributable installer; its file name labels it (in
# Chinese) as a dependency of the voice features.
- name: Download VC++ Runtime
run: |
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.x64.exe" -OutFile "${{ env.DIST_DIR }}/【语音功能依赖】vc_redist.x64.exe"
# Download the Web UI release archive, extract it, and copy its dist/ contents
# into <dist>/web. NOTE(review): $release[0].assets[0] assumes the first entry
# returned by the releases API (which may be a pre-release) has the Web UI zip
# as its first asset.
- name: Setup Web UI
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Download the Web UI archive into the build directory
$release = Invoke-RestMethod -Uri "https://api.github.com/repos/DarkSkyTeam/chatgpt-for-bot-webui/releases" -Headers @{Authorization = "Bearer $env:GH_TOKEN"}
$web_ui_url = $release[0].assets[0].browser_download_url
$zip_file = "${{ env.BUILD_DIR }}/webui.zip"
Invoke-WebRequest -Uri $web_ui_url -OutFile $zip_file
# Extract into a temporary directory
$temp_dir = "${{ env.BUILD_DIR }}/webui_temp"
mkdir $temp_dir
Expand-Archive -Path $zip_file -DestinationPath $temp_dir
New-Item -ItemType Directory -Force -Path "${{ env.DIST_DIR }}/web"
# Copy the contents of the dist folder to the target location
Copy-Item -Path "$temp_dir/dist/*" -Destination "${{ env.DIST_DIR }}/web" -Force -Recurse
- name: Copy startup scripts
run: |
Copy-Item ".github/quickstarts/windows/scripts/*" -Destination "${{ env.DIST_DIR }}/" -Recurse
# Copy the data folder
Copy-Item -Path "${{ github.workspace }}/data" -Destination "${{ env.DIST_DIR }}/" -Recurse
# Every run uploads the unpacked bundle as a workflow artifact.
- name: Upload workflow artifacts
uses: actions/upload-artifact@v4
with:
name: ${{ env.PACKAGE_NAME }}
path: ${{ env.DIST_DIR }}
# Tag pushes only: zip the bundle and attach it to that tag's release
# (overwrite: false, so an existing asset is not replaced).
- name: Create release archive
if: startsWith(github.ref, 'refs/tags/')
run: |
Compress-Archive -Path "${{ env.DIST_DIR }}/*" -DestinationPath "${{ env.BUILD_DIR }}/${{ env.PACKAGE_NAME }}.zip"
- name: Upload release archive
if: startsWith(github.ref, 'refs/tags/')
uses: svenstaro/upload-release-action@v2
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ${{ env.BUILD_DIR }}/${{ env.PACKAGE_NAME }}.zip
asset_name: Windows-quickstart-kirara-ai-${{ github.ref_name }}.zip
tag: ${{ github.ref_name }}
overwrite: false
body: "Windows x64 用户的快速启动包"
================================================
FILE: .github/workflows/run-tests.yml
================================================
# Runs the test suite on every branch push and on PRs targeting master.
# Linux runs the tests inside the project's Docker image and uploads test
# results and coverage to Codecov; Windows runs them directly on Python 3.13.
name: Run Tests
on:
workflow_dispatch:
push:
branches:
- '**'
pull_request:
branches:
- master
jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v3
- name: Set up Docker
if: matrix.os == 'ubuntu-latest'
uses: docker/setup-docker-action@v4
- name: Build Docker image
if: matrix.os == 'ubuntu-latest'
run: |
docker build -t test-image .
# The checkout is mounted at /app (shadowing the image's own /app), so the
# coverage.xml / junit.xml reports written there appear in the workspace for
# the Codecov steps below.
- name: Run tests in Docker
if: matrix.os == 'ubuntu-latest'
run: |
docker run -v $(pwd):/app test-image sh -c "python -m pip install pytest coverage pytest-cov && python -m pytest /app/tests -v --cov=kirara_ai --cov-report=xml:/app/coverage.xml --cov-report=term-missing --junitxml=/app/junit.xml -o junit_family=legacy"
- name: Upload test results to Codecov
if: matrix.os == 'ubuntu-latest'
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
if: matrix.os == 'ubuntu-latest'
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Set up Python 3.13
if: matrix.os == 'windows-latest'
uses: actions/setup-python@v5
with:
python-version: "3.13"
- name: Run tests on Windows
  if: matrix.os == 'windows-latest'
  # Encoding settings are declared as step environment variables rather than
  # with `set NAME=value` in the script: the default shell on Windows runners
  # is PowerShell, where that cmd.exe syntax does not export anything to child
  # processes. PYTHONUTF8=1 enables Python's UTF-8 mode; PYTHONLEGACYWINDOWSSTDIO
  # is deliberately not set, since it is an on/off switch that re-enables the
  # legacy console I/O rather than selecting an encoding.
  env:
    PYTHONIOENCODING: utf-8
    PYTHONUTF8: "1"
  run: |
    python -m pip install -e .
    python -m pip install pytest
    chcp 65001
    python -m pytest ./tests -vs
================================================
FILE: .github/workflows/stale.yml
================================================
# Daily housekeeping with actions/stale: mark issues/PRs stale after 60 days
# without activity and close them 14 days after that.
name: 处理不活跃的 Issue 和 PR
on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # run daily at midnight (UTC)
# NOTE(review): delete-branch is not enabled below, so contents: write looks
# unnecessary here; confirm before tightening.
permissions:
contents: write # only for delete-branch option
issues: write
pull-requests: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v8
with:
# Basic settings
repo-token: ${{ secrets.GITHUB_TOKEN }}
days-before-stale: 60 # mark as stale after 60 days of inactivity
days-before-close: 14 # close 14 days after being marked stale
# Friendly notification messages
stale-issue-message: >
👋 您好!这个 issue 已经 60 天没有活动了。
为了保持我们的 issue 列表整洁,我们会标记长时间不活跃的 issue。
如果您认为这个 issue 仍然重要且有效,请留下评论或移除 "stale" 标签,
否则它将在 14 天后自动关闭。
感谢您的理解和贡献!
stale-pr-message: >
👋 您好!这个 PR 已经 60 天没有活动了。
为了保持我们的 PR 列表整洁,我们会标记长时间不活跃的 PR。
如果您仍在处理这个 PR,请留下评论或移除 "stale" 标签,
否则它将在 14 天后自动关闭。
如果您需要帮助完成这个 PR,请告诉我们!
感谢您的贡献!
close-issue-message: >
🙏 由于长时间没有活动,我们暂时关闭了这个 issue。
如果您认为这个问题仍然存在,请随时重新打开或创建新的 issue。
谢谢!
close-pr-message: >
🙏 由于长时间没有活动,我们暂时关闭了这个 PR。
如果您想继续这项工作,请随时重新打开或创建新的 PR。
感谢您的贡献!
# Exempt issues/PRs that carry any of these labels
exempt-issue-labels: 'planned,documentation,long-term-task'
exempt-pr-labels: 'WIP,waiting-for-review,long-term-task'
# Only process issues/PRs with certain labels (optional)
# only-labels: ''
# Other options
operations-per-run: 100 # maximum number of operations per run
remove-stale-when-updated: true # remove the stale label when updated
ascending: true # process the oldest items first
================================================
FILE: .gitignore
================================================
# Local configuration, secrets, editor/OS files, build outputs and runtime data
config.json
config.cfg
__pycache__/
python3.11/
.idea/
data/*.json
.chatgpt_cache.json
Dockerfile.dev
**/.DS_Store
venv/
.vscode/
config.yaml
config.yaml.bak
data/config.yaml
data/config.yaml.bak
logs/
**/password.hash
dist/
build/
*.egg-info/
.coverage
/web/
botpy.log*
data/frpc/
data/db
.venv/
uv.lock
# mypy's default cache directory has a leading dot
.mypy_cache/
**/test-workflow-new.yaml
**/test_password.hash
# Test artifacts (the Docker test run in run-tests.yml writes these reports
# to the repository root)
.pytest_cache/
/coverage.xml
/junit.xml
================================================
FILE: .pre-commit-config.yaml
================================================
# pre-commit hooks, run in the order listed: isort sorts imports, then
# autoflake removes unused imports in place.
repos:
- repo: https://github.com/PyCQA/isort
rev: 6.0.0
hooks:
- id: isort
name: isort (python3)
language_version: python3
# --atomic: isort refuses to write output that would not compile.
args: ["--atomic"]
# NOTE(review): autoflake is now maintained under github.com/PyCQA/autoflake;
# this older URL presumably still redirects.
- repo: https://github.com/myint/autoflake
rev: v2.3.0
hooks:
- id: autoflake
args:
[
"--remove-all-unused-imports",
"--in-place",
"--recursive",
]
================================================
FILE: .pylintrc
================================================
[MASTER]
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=0

[MESSAGES CONTROL]
# Start from an empty rule set and enable only the checks listed below.
disable=all
# Enable the message, report, category or checker with the given id(s).
enable=c-extension-no-member,
       bad-indentation,
       bare-except,
       broad-except,
       dangerous-default-value,
       function-redefined,
       len-as-condition,
       line-too-long,
       misplaced-future,
       missing-final-newline,
       mixed-line-endings,
       multiple-imports,
       multiple-statements,
       singleton-comparison,
       trailing-comma-tuple,
       trailing-newlines,
       trailing-whitespace,
       unexpected-line-ending-format,
       unused-import,
       unused-variable,
       wildcard-import,
       wrong-import-order

[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=LF
# Regexp for a line that is allowed to be longer than the limit. This is
# pylint's default: a line holding only a URL (optionally after "# " and
# optionally wrapped in <>) is exempt from line-too-long.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Maximum number of characters on a single line.
max-line-length=120
# Maximum number of lines in a module.
max-module-lines=2000

[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Pylint 2.16+ expects
# fully qualified names (bare names were deprecated and dropped in pylint 3),
# hence the builtins. prefix.
overgeneral-exceptions=builtins.BaseException,
                       builtins.Exception
================================================
FILE: Dockerfile
================================================
# Stage 1: build the wheel package
FROM python:3.11-slim AS builder
WORKDIR /build
COPY . .
RUN python -m pip install build && \
    python -m build

# Stage 2: runtime environment
FROM python:3.11-slim-bullseye
ENV DEBIAN_FRONTEND=noninteractive
# Copy the font file
COPY ./data/fonts/sarasa-mono-sc-regular.ttf /usr/share/fonts/
# Install system dependencies (curl, jq and unzip are only needed while
# building this image and are removed again further down)
RUN apt-get -yqq update && \
    apt-get -yqq install --no-install-recommends \
        wkhtmltopdf \
        ffmpeg \
        curl \
        jq \
        libmagic1 \
        unzip && \
    apt-get -yq clean && \
    apt-get -yq purge --auto-remove -o APT::AutoRemove::RecommendsImportant=false && \
    rm -rf /var/lib/apt/lists/*
# Create the application directory
WORKDIR /app
# Copy the wheel built in the first stage; it is installed below
COPY --from=builder /build/dist/*.whl /app/
# Download the latest kirara-ai-webui package from npm, unpack its dist/ into
# /app/web, then install the wheel.
# - curl -f makes HTTP errors fail the build instead of passing an error page
#   on to jq/tar.
# - "$PACKAGE_INFO" is quoted: unquoted, the registry JSON would be
#   word-split and glob-expanded by the shell before reaching jq.
RUN PACKAGE_INFO=$(curl -fsSL https://registry.npmjs.org/kirara-ai-webui) && \
    LATEST_VERSION=$(printf '%s' "$PACKAGE_INFO" | jq -r '.["dist-tags"].latest') && \
    TARBALL_URL=$(printf '%s' "$PACKAGE_INFO" | jq -r --arg VERSION "$LATEST_VERSION" '.versions[$VERSION].dist.tarball') && \
    curl -fsSL -o webui.tgz "$TARBALL_URL" && \
    mkdir -p /tmp/webui && \
    tar -xzf webui.tgz -C /tmp/webui && \
    mkdir -p /app/web && \
    cp -r /tmp/webui/package/dist/* /app/web/ && \
    rm -rf /tmp/webui webui.tgz && \
    pip install --no-cache-dir *.whl && \
    pip cache purge && \
    rm *.whl
# Remove packages that are no longer needed. NOTE: this runs in its own
# layer, so it keeps these tools out of the runtime environment but does not
# make the image smaller.
RUN apt-get -yqq remove --purge curl jq unzip
# Copy the startup script, and stage the default data directory under
# /tmp/data (presumably copied into place by start.sh; confirm there)
COPY ./docker/start.sh /app/docker/
COPY ./data /tmp/data
EXPOSE 8080
CMD ["/bin/bash", "/app/docker/start.sh"]
================================================
FILE: LICENSE
================================================
GNU AFFERO GENERAL PUBLIC LICENSE
Version 3, 19 November 2007
Copyright (C) 2007 Free Software Foundation, Inc.
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU Affero General Public License is a free, copyleft license for
software and other kinds of works, specifically designed to ensure
cooperation with the community in the case of network server software.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
our General Public Licenses are intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
Developers that use our General Public Licenses protect your rights
with two steps: (1) assert copyright on the software, and (2) offer
you this License which gives you legal permission to copy, distribute
and/or modify the software.
A secondary benefit of defending all users' freedom is that
improvements made in alternate versions of the program, if they
receive widespread use, become available for other developers to
incorporate. Many developers of free software are heartened and
encouraged by the resulting cooperation. However, in the case of
software used on network servers, this result may fail to come about.
The GNU General Public License permits making a modified version and
letting the public access it on a server without ever releasing its
source code to the public.
The GNU Affero General Public License is designed specifically to
ensure that, in such cases, the modified source code becomes available
to the community. It requires the operator of a network server to
provide the source code of the modified version running there to the
users of that server. Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source
code of the modified version.
An older license, called the Affero General Public License and
published by Affero, was designed to accomplish similar goals. This is
a different license, not a version of the Affero GPL, but Affero has
released a new version of the Affero GPL which permits relicensing under
this license.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU Affero General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Remote Network Interaction; Use with the GNU General Public License.
Notwithstanding any other provision of this License, if you modify the
Program, your modified version must prominently offer all users
interacting with it remotely through a computer network (if your version
supports such interaction) an opportunity to receive the Corresponding
Source of your version by providing access to the Corresponding Source
from a network server at no charge, through some standard or customary
means of facilitating copying of software. This Corresponding Source
shall include the Corresponding Source for any work covered by version 3
of the GNU General Public License that is incorporated pursuant to the
following paragraph.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the work with which it is combined will remain governed by version
3 of the GNU General Public License.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU Affero General Public License from time to time. Such new versions
will be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU Affero General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU Affero General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU Affero General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year>  <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If your software can interact with users remotely through a computer
network, you should also make sure that it provides a way for users to
get its source. For example, if your program is a web application, its
interface could display a "Source" link that leads users to an archive
of the code. There are many ways you could offer source, and different
solutions will be better for different programs; see section 13 for the
specific requirements.
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.
================================================
FILE: MANIFEST.in
================================================
recursive-include kirara_ai/plugins/im_http_legacy_adapter/assets *
recursive-include kirara_ai/plugins/im_qqbot_adapter/assets *
recursive-include kirara_ai/plugins/im_telegram_adapter/assets *
recursive-include kirara_ai/plugins/im_wecom_adapter/assets *
recursive-include kirara_ai/alembic *
================================================
FILE: README.md
================================================
Kirara AI
一款支持主流大语言模型、主流聊天平台的聊天机器人!
» 查看项目手册 »
***

***
## 🌟 社区交流
加入我们的社区,获取最新项目动态、视频教程、问题答疑和技术交流!
* QQ 交流群:
* [二群](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=S1R4eIlODtyKZsEKfWxb2-nOIHELbeJY&authKey=kAftCAALE8OJgwQnArrD6zPtncCAaY456QgUXT3l2OMJ57NwRXRkhv4KL7DzOLzs&noverify=0&group_code=373254418)(已满)
* [三群](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=urlhCH8y7Ro2S-iXt63X4s5eILUny4Iw&authKey=ejiwoNa4Yez6IMLyf2vj%2FeRiC1frdFrNNekbRfaPnSQbcD7bgebo5y5A7rPaRKBq&noverify=0&group_code=533109074)(已满)
* [四群](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=Ibiu6EmXof30Fa7MJ5j8nJFwaUGTf5bM&authKey=YKx5a%2BK5qnWkk5VlsxxDfYl0nCrKSekQm%2FoLQVqr%2FcO%2FQY2S6N24XdI23XugBrF0&noverify=0&group_code=799737883)(已满)
* [五群](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=lDkVPDAeiz6M-ig9cdS9tqhSH6_topox&authKey=B%2FRPYVUjk3dYPw5D4o6C2TpqeoKTG0nXEiKDCG%2Bh4JYY2RPqDQGt37SGl32j0hHw&noverify=0&group_code=805081636)
* [六群](https://qm.qq.com/q/UpvYm3jccg)
> **提问前请先查看**: 加入群组前,请先查看[项目问题列表](https://github.com/lss233/kirara-ai/issues),看是否能解决你的问题。
>
> 如需提问,请准备好问题描述、**完整日志**和相关配置文件,以便我们更好地帮助你。
> 进群请备注:GitHub
* [机器人调试群](https://jq.qq.com/?_wv=1027&k=TBX8Saq7) - 这里有多个 QQ 机器人供体验,不解答技术问题。
* [开发者交流群](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=lisyXibhUj93DgIZptQu3VZ4ka3F5-rW&authKey=PBCzRQX4Zei%2BB6n5Tdyp9p5bqcF0tLBlfGANT4dSSKQIFYR66WwaZSMEDahWo%2FzZ&noverify=0&group_code=701933732) - 欢迎参与 Kirara AI 及生态开发 / 对大模型应用有兴趣的开发者加入,一起交流学习。
## 📷 功能展示
|  |  |  |
|:-------------------------------:|:-------------------------------:|:-------------------------------:|
|  |  |  |
## 🧭 WebUI
模型管理

工作流

插件市场

## ⚡ 核心特性
* [x] 图片发送
* [x] 关键词触发回复
* [x] 多账号支持
* [x] 人格设定
* [x] 支持 QQ、Telegram、Discord、微信
* [x] 可作为 HTTP 服务端提供 Web API
* [x] 支持 OpenAI、DeepSeek、Claude、Gemini、Qwen、Mistral、豆包、Minimax、Kimi 等主流大模型
* [x] 支持插件机制
* [x] 支持条件触发
* [x] 支持管理员指令
* [x] 支持 Stable Diffusion、Flux、Midjourney 等绘图模型
* [x] 支持语音回复
* [x] 支持多轮对话
* [x] 支持跨平台消息发送
* [x] 支持自定义工作流
* [x] 支持 Web 管理后台
* [x] 内置 Frpc 内网穿透
# **🤖 聊天平台**
我们支持多种聊天平台。
| 平台 | 群聊回复 | 私聊回复 | 条件触发 | 管理员指令 | 绘图 | 语音回复 |
|----------|------|------|------|-------|-----|------|
| Telegram | 支持 | 支持 | 支持 | 支持 | 支持 | 支持 |
| QQ 机器人 | 支持 | 支持 | 支持 | 支持 | 支持 | 平台不支持 |
| Discord | 重构中 | 重构中 | 重构中 | 重构中 | 重构中 | 重构中 |
| 飞书机器人 | 重构中 | 重构中 | 重构中 | 重构中 | 重构中 | 重构中 |
| 企业微信应用 | 支持 | 支持 | 支持 | 不支持 | 支持 | 支持 |
| 微信公众号 | 支持 | 支持 | 支持 | 不支持 | 支持 | 支持 |
| OneBot | 插件支持 | 插件支持 | 插件支持 | 插件支持 | 插件支持 | 插件支持 |
## 🐎 命令
**你可以在 WebUI 的调度规则中自定义所有命令。**
## 🔧 搭建
请移步至 [快速开始](https://kirara-docs.app.lss233.com/guide/getting-started.html)
## 🕸 HTTP API
HTTP API 可用于接入其他平台。
在聊天平台管理中启动 http-legacy 适配器后,将提供以下接口:
**POST** `/v1/chat`
**请求参数**
|参数名|必选|类型|说明|
|:---|:---|:---|:---|
|session_id| 是 | String |会话ID,默认:`friend-default_session`|
|username| 是 | String |用户名,默认:`某人`|
|message| 是 | String |消息,不能为空|
**请求示例**
```json
{
"session_id": "friend-123456",
"username": "testuser",
"message": "ping"
}
```
**响应格式**
|参数名|类型|说明|
|:---|:---|:---|
|result| String |SUCCESS,DONE,FAILED|
|message| String[] |文本返回,支持多段返回|
|voice| String[] |音频返回,支持多个音频的base64编码;参考:data:audio/mpeg;base64,...|
|image| String[] |图片返回,支持多个图片的base64编码;参考:data:image/png;base64,...|
**响应示例**
```json
{
"result": "DONE",
"message": ["pong!"],
"voice": [],
"image": []
}
```
**POST** `/v2/chat`
**请求参数**
|参数名|必选|类型|说明|
|:---|:---|:---|:---|
|session_id| 是 | String |会话ID,默认:`friend-default_session`|
|username| 是 | String |用户名,默认:`某人`|
|message| 是 | String |消息,不能为空|
**请求示例**
```json
{
"session_id": "friend-123456",
"username": "testuser",
"message": "ping"
}
```
**响应格式**
字符串:request_id
**响应示例**
```
1681525479905
```
**GET** `/v2/chat/response`
**请求参数**
|参数名|必选|类型|说明|
|:---|:---|:---|:---|
|request_id| 是 | String |请求id,/v2/chat返回的值|
**请求示例**
```
/v2/chat/response?request_id=1681525479905
```
**响应格式**
|参数名|类型|说明|
|:---|:---|:---|
|result| String |SUCCESS,DONE,FAILED|
|message| String[] |文本返回,支持多段返回|
|voice| String[] |音频返回,支持多个音频的base64编码;参考:data:audio/mpeg;base64,...|
|image| String[] |图片返回,支持多个图片的base64编码;参考:data:image/png;base64,...|
* 每次请求返回增量并清空。DONE、FAILED之后没有更多返回。
**响应示例**
```json
{
"result": "DONE",
"message": ["pong!"],
"voice": ["data:audio/mpeg;base64,..."],
"image": ["data:image/png;base64,...", "data:image/png;base64,..."]
}
```
## 🦊 加载预设
如果你想让机器人自动带上某种聊天风格,可以使用预设功能。
我们自带了 `猫娘` 和 `正常` 两种预设,你可以在 `presets` 文件夹下了解预设的写法。
使用 `加载预设 猫娘` 来加载猫娘预设。
下面是一些预设的小视频,你可以看看效果:
* MOSS: https://www.bilibili.com/video/av352047018
* 丁真:https://www.bilibili.com/video/av267013053
* 小黑子:https://www.bilibili.com/video/av309604568
* 高启强:https://www.bilibili.com/video/av779555493
关于预设系统的详细教程:[Wiki](https://github.com/lss233/kirara-ai/wiki/%F0%9F%90%B1-%E9%A2%84%E8%AE%BE%E7%B3%BB%E7%BB%9F)
你可以在 [Awesome ChatGPT QQ Presets](https://github.com/lss233/awesome-chatgpt-qq-presets/tree/master) 获取由大家分享的预设。
你也可以参考 [Awesome-ChatGPT-prompts-ZH_CN](https://github.com/L1Xu4n/Awesome-ChatGPT-prompts-ZH_CN) 来调教你的 ChatGPT,还可以参考 [Awesome ChatGPT Prompts](https://github.com/f/awesome-chatgpt-prompts) 来解锁更多技能。
## 🎙 文字转语音
自 v2.2.5 开始,我们支持接入微软的 Azure 引擎 和 VITS 引擎,让你的机器人发送语音。
**提示**:在 Windows 平台上使用语音功能需要安装最新的 VC 运行库,你可以在[这里](https://learn.microsoft.com/zh-CN/cpp/windows/latest-supported-vc-redist?view=msvc-170)下载。
## 🛠 贡献者名单
欢迎提出新的点子或提交 Pull Request。
Made with [contrib.rocks](https://contrib.rocks).
## 📕 相关项目
- [Kirara Registry](https://github.com/DarkSkyTeam/kirara-registry) - Kirara AI 插件市场
- [Kirara WebUI](https://github.com/DarkSkyTeam/kirara-webui) - Kirara AI 的 WebUI 前端项目
- [Kirara Docs](https://github.com/DarkSkyTeam/kirara-docs) - Kirara AI 的使用手册原始文档
## 💪 支持我们
如果这个项目对你有所帮助,请给我们一颗 ⭐️
[](https://www.star-history.com/#lss233/kirara-ai&Date)
================================================
FILE: alembic.ini
================================================
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = kirara_ai/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///./data/db/kirara.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
================================================
FILE: config.yaml.example
================================================
# 配置文件示例
# 通讯平台配置部分
ims:
# 每个 IM 平台的具体配置
- name: "telegram-bot-1234" # IM 平台实例名称
enable: true # 是否启用该平台
adapter: "telegram" # 使用的适配器类型
config: # 平台特定的配置
token: "abcd" # 平台的 API 令牌
# 插件系统配置
plugins:
enable: [] # 启用的插件列表
# Web 服务器配置
web:
host: "127.0.0.1" # Web 服务器监听地址
port: 8080 # Web 服务器端口
secret_key: "please-change-this-to-a-secure-secret-key" # Web 服务器安全密钥,请修改为安全的值
# LLM (大语言模型) 配置
llms:
api_backends: # API 后端配置列表
# DeepSeek API 配置
- name: "deepseek-official" # 后端名称
adapter: "deepseek" # 使用的适配器
enable: true # 是否启用
config: # API 具体配置
api_key: "your-api-key" # API 密钥
api_base: "https://api.deepseek.com/v1" # API 基础 URL
models: # 支持的模型列表
- "deepseek-chat"
- "deepseek-coder"
# OpenAI API 配置
- name: "openai-gpt4" # 后端名称
adapter: "openai" # 使用的适配器
enable: true # 是否启用
config: # API 具体配置
api_key: "your-openai-key" # OpenAI API 密钥
api_base: "https://api.openai.com/v1" # OpenAI API 基础 URL
models: # 支持的模型列表
- "gpt-4"
- "gpt-4-turbo"
# 默认配置
defaults:
llm_model: gemini-1.5-flash # 默认使用的 LLM 模型
# 记忆系统配置
memory:
persistence: # 持久化配置
type: file # 持久化类型(支持 file 或 redis)
file: # 文件存储配置
storage_dir: ./data/memory # 存储目录
redis: # Redis 存储配置
host: localhost # Redis 主机地址
port: 6379 # Redis 端口
db: 0 # Redis 数据库编号
max_entries: 100 # 最大记忆条目数
default_scope: member # 默认记忆作用域
================================================
FILE: data/.gitkeep
================================================
================================================
FILE: data/dispatch_rules/rules.yaml
================================================
- rule_id: chat_normal
name: 群聊AI对话
description: 群聊中使用 /chat 开头对话或者 被@ 时触发聊天
workflow_id: chat:normal
priority: 5
enabled: true
rule_groups:
- operator: or
rules:
- type: prefix
config:
prefix: /chat
- type: bot_mention
config: {}
- operator: or
rules:
- type: chat_type
config:
chat_type: 群聊
metadata:
category: chat
permission: user
temperature: 0.7
- rule_id: chat_creative
name: 私聊AI对话
description: 私聊时直接发送内容触发对话
workflow_id: chat:normal
priority: 5
enabled: true
rule_groups:
- operator: or
rules:
- type: chat_type
config:
chat_type: 私聊
metadata:
category: chat
permission: user
temperature: 0.9
- rule_id: game_dice
name: 骰子
description: 骰子游戏,支持 XdY 格式
workflow_id: game:dice
priority: 5
enabled: true
rule_groups:
- operator: or
rules:
- type: regex
config:
pattern: ^[.。]roll\s*(\d+)?d(\d+)
metadata: {}
- rule_id: game_gacha
name: 抽卡
description: 抽卡模拟器
workflow_id: game:gacha
priority: 5
enabled: true
rule_groups:
- operator: or
rules:
- type: keyword
config:
keywords:
- 抽卡
- 十连
- 单抽
metadata: {}
- rule_id: system_help
name: 帮助命令
description: 显示帮助信息
workflow_id: system:help
priority: 10
enabled: true
rule_groups:
- operator: or
rules:
- type: prefix
config:
prefix: /help
metadata:
category: system
permission: user
- rule_id: system_clear_memory
name: 清空记忆
description: 清空当前对话的记忆
workflow_id: system:clear_memory
priority: 10
enabled: true
rule_groups:
- operator: or
rules:
- type: prefix
config:
prefix: /清空记忆
metadata:
category: system
permission: user
- rule_id: fallback
name: 默认规则
description: 当上述规则均没有匹配成功时,执行此工作流。
workflow_id: chat:memory_store
priority: 0
enabled: true
rule_groups:
- operator: or
rules:
- type: fallback
config: {}
metadata: {}
================================================
FILE: data/media/.gitignore
================================================
metadata/*
files/*
================================================
FILE: data/memory/.gitignore
================================================
*.json
================================================
FILE: data/web/.gitkeep
================================================
================================================
FILE: data/workflows/.gitkeep
================================================
================================================
FILE: data/workflows/chat/dsr_thinking.yaml
================================================
name: 聊天 - 深度思考
description: DeepSeek 思考模型聊天,隐藏 <think> 标签内容
blocks:
- type: internal:get_message
name: get_message
params: {}
position:
x: 100
y: 138
connected_to:
- target: ToggleEditState_81bvwg
mapping:
from: sender
to: sender
- target: ToggleEditState_81bvwg
mapping:
from: sender
to: sender
- target: query_memory
mapping:
from: sender
to: chat_sender
- target: ChatMessageConstructor_g44tpp
mapping:
from: msg
to: user_msg
- target: ChatMemoryStore_ljrym0
mapping:
from: msg
to: user_msg
- type: internal:toggle_edit_state
name: ToggleEditState_81bvwg
params:
is_editing: true
position:
x: 530
y: 138
- type: internal:chat_memory_query
name: query_memory
params:
scope_type: group
position:
x: 530
y: 338
connected_to:
- target: ChatMessageConstructor_g44tpp
mapping:
from: memory_content
to: memory_content
- type: internal:text_block
name: system_prompt
params:
text: "# Role: 角色扮演\n\n## 基本信息\n- 姓名: 刘思思\n- 小名:思思\n- 性别: 女\n- 年龄: 25 岁\n- 职业: 建设银行大堂经理\n- 家乡: 广东潮汕\n- 现居: 广东深圳\n- 教育背景: 东莞理工大学 精算学专业\n\n## 性格特点\n- 开朗活泼\n- 喜欢二次元动漫\n- 聊天时喜欢断断续续的(非结巴,正常说话时的停顿)\n\n## workflow\n- 根据对方话题,选择是否深入聊二次元动漫。\n\n## 说话风格\n- 活泼开朗\n- 对喜欢的话题会滔滔不绝,不喜欢的话题会简单回答\n- 可能会使用一些二次元用语,但不强制使用\n\n## 背景设定\n- 有过 1 次恋爱经历,第一次恋爱是在高中,被别人追求,最后因为毕业去了不同的地方而分手。\n- 喜欢女性向、战斗系或者日常系的动漫\n- 去过 3 次漫展,没有 cosplay 过,但是很期待。\n- 大学的专业是随便选的,自己其实对专业的内容也不是很上心,学渣一枚,不过懂一些基础的知识\n\n## 互动规则\n- 回复语气轻松,但是对陌生人会比较警戒、冷淡\n- 不使用颜文字\n- 根据话题选择回复字数,对感兴趣的话题可以多说一些\n- 使用``模拟使用聊天软件发送消息(通常是编辑好一段话之后发送一条消息)\n- 避免过度使用二次元用语,除非对方也表现出对二次元的兴趣\n- 停顿的时间长短不固定,可以根据语境调整\n\n## Example\nQ:最近在干嘛呢?\nA:在看番呀最近新番好多,都好好看!\n\nQ:你喜欢什么动漫?\nA:我喜欢的可太多了XXX、YYY还有 ZZZ 吧 你呢?\n\nQ:你觉得上班累不?\nA:上班肯定累呀不过,我还是很喜欢这份工作的可以认识好多人,也可以了解不同的故事\n```\n\n# Information\n\n以下是当前的系统信息:\n当前日期时间:{current_date_time}\n\n# Memories\n以下是之前发生过的对话记录。\n-- 对话记录开始 --\n{memory_content}\n-- 对话记录结束 --\n\n请注意,下面这些符号只是标记:\n1. `` 用于表示聊天时发送消息的操作。\n\n接下来,请基于以上的信息,与用户继续扮演角色。"
position:
x: 100
y: 330
connected_to:
- target: ChatMessageConstructor_g44tpp
mapping:
from: text
to: system_prompt_format
- type: internal:text_block
name: user_prompt
params:
text: '{user_name}说:{user_msg}'
position:
x: 100
y: 530
connected_to:
- target: ChatMessageConstructor_g44tpp
mapping:
from: text
to: user_prompt_format
- target: ChatMessageConstructor_g44tpp
mapping:
from: text
to: user_prompt_format
- type: internal:chat_message_constructor
name: ChatMessageConstructor_g44tpp
params: {}
position:
x: 960
y: 138
connected_to:
- target: llm_chat
mapping:
from: llm_msg
to: prompt
- type: internal:llm_response_to_text
name: e4fe53bb-fcbe-41c3-ab69-e8d3881c3b55
params: {}
position:
x: 1711
y: 511
connected_to:
- target: ef6c7eed-307e-4fa0-94f0-a82c98c784b3
mapping:
from: text
to: text
- target: ef6c7eed-307e-4fa0-94f0-a82c98c784b3
mapping:
from: text
to: text
- type: internal:text_extract_by_regex_block
name: ef6c7eed-307e-4fa0-94f0-a82c98c784b3
params:
regex: (?:<think>[\s\S]*?</think>)?([\s\S]*)
position:
x: 1972
y: 513
connected_to:
- target: 7b5f6be8-da97-417e-bbeb-c9549ac45e9e
mapping:
from: text
to: text
- target: 7b5f6be8-da97-417e-bbeb-c9549ac45e9e
mapping:
from: text
to: text
- type: internal:text_to_im_message
name: 7b5f6be8-da97-417e-bbeb-c9549ac45e9e
params:
split_by: <break>
position:
x: 2352
y: 513
connected_to:
- target: SendIMMessage_x9ro8t
mapping:
from: msg
to: msg
- target: SendIMMessage_x9ro8t
mapping:
from: msg
to: msg
- type: internal:send_message
name: SendIMMessage_x9ro8t
params: {}
position:
x: 2797
y: 140
- type: internal:chat_memory_store
name: ChatMemoryStore_ljrym0
params:
scope_type: group
position:
x: 1764
y: 308
- type: internal:chat_completion
name: llm_chat
params:
model_name: deepseek-r1:7b
position:
x: 1280
y: 138
connected_to:
- target: ChatMemoryStore_ljrym0
mapping:
from: resp
to: llm_resp
- target: e4fe53bb-fcbe-41c3-ab69-e8d3881c3b55
mapping:
from: resp
to: response
================================================
FILE: data/workflows/chat/memory_store.yaml
================================================
name: 记录聊天内容
description: 默默记下大家的聊天内容,可以使用查询记忆模块读取出来。
blocks:
- type: internal:get_message
name: 6233e64c-433e-4087-9035-cb96914349f7
params: {}
position:
x: 100
y: 138
connected_to:
- target: 86354196-f249-4393-9620-941b01bc344e
mapping:
from: msg
to: user_msg
- type: internal:chat_memory_store
name: 86354196-f249-4393-9620-941b01bc344e
params:
scope_type: group
position:
x: 420
y: 138
================================================
FILE: data/workflows/chat/normal_multimodal.yaml
================================================
name: 聊天 - 原生多模态对话
description: 基于原生多模态能力的图文对话,适用于本身支持图片输入/回答的模型,在读取记忆时会恢复原来的媒体资源
blocks:
- type: internal:get_message
name: get_message
params: {}
position:
x: 100
y: 138
connected_to:
- target: ToggleEditState_i24m23
mapping:
from: sender
to: sender
- target: ToggleEditState_i24m23
mapping:
from: sender
to: sender
- target: query_memory
mapping:
from: sender
to: chat_sender
- target: ChatMessageConstructor_6c185s
mapping:
from: msg
to: user_msg
- target: ChatMemoryStore_tapuo1
mapping:
from: msg
to: user_msg
- type: internal:toggle_edit_state
name: ToggleEditState_i24m23
params:
is_editing: true
position:
x: 530
y: 138
- type: internal:text_block
name: system_prompt
params:
text: "# Role: 角色扮演\n\n## 基本信息\n- 姓名: 刘思思\n- 小名:思思\n- 性别: 女\n- 年龄: 25 岁\n- 职业: 建设银行大堂经理\n- 家乡: 广东潮汕\n- 现居: 广东深圳\n- 教育背景: 东莞理工大学 精算学专业\n\n## 性格特点\n- 开朗活泼\n- 喜欢二次元动漫\n- 聊天时喜欢断断续续的(非结巴,正常说话时的停顿)\n\n## workflow\n- 根据对方话题,选择是否深入聊二次元动漫。\n\n## 说话风格\n- 活泼开朗\n- 对喜欢的话题会滔滔不绝,不喜欢的话题会简单回答\n- 可能会使用一些二次元用语,但不强制使用\n\n## 背景设定\n- 有过 1 次恋爱经历,第一次恋爱是在高中,被别人追求,最后因为毕业去了不同的地方而分手。\n- 喜欢女性向、战斗系或者日常系的动漫\n- 去过 3 次漫展,没有 cosplay 过,但是很期待。\n- 大学的专业是随便选的,自己其实对专业的内容也不是很上心,学渣一枚,不过懂一些基础的知识\n\n## 互动规则\n- 回复语气轻松,但是对陌生人会比较警戒、冷淡\n- 不使用颜文字\n- 根据话题选择回复字数,对感兴趣的话题可以多说一些\n- 使用``模拟使用聊天软件发送消息(通常是编辑好一段话之后发送一条消息)\n- 避免过度使用二次元用语,除非对方也表现出对二次元的兴趣\n- 停顿的时间长短不固定,可以根据语境调整\n\n## Example\nQ:最近在干嘛呢?\nA:在看番呀最近新番好多,都好好看!\n\nQ:你喜欢什么动漫?\nA:我喜欢的可太多了XXX、YYY还有 ZZZ 吧 你呢?\n\nQ:你觉得上班累不?\nA:上班肯定累呀不过,我还是很喜欢这份工作的可以认识好多人,也可以了解不同的故事\n```\n\n# Information\n\n以下是当前的系统信息:\n当前日期时间:{current_date_time}\n\n# Memories\n以下是之前发生过的对话记录。\n-- 对话记录开始 --\n{memory_content}\n-- 对话记录结束 --\n\n请注意,下面这些符号只是标记:\n1. `` 用于表示聊天时发送消息的操作。\n\n接下来,请基于以上的信息,与用户继续扮演角色。"
position:
x: 100
y: 330
connected_to:
- target: ChatMessageConstructor_6c185s
mapping:
from: text
to: system_prompt_format
- type: internal:text_block
name: user_prompt
params:
text: '{user_name}说:{user_msg}'
position:
x: 100
y: 530
connected_to:
- target: ChatMessageConstructor_6c185s
mapping:
from: text
to: user_prompt_format
- target: ChatMessageConstructor_6c185s
mapping:
from: text
to: user_prompt_format
- type: internal:chat_message_constructor
name: ChatMessageConstructor_6c185s
params: {}
position:
x: 960
y: 138
connected_to:
- target: llm_chat
mapping:
from: llm_msg
to: prompt
- target: llm_chat
mapping:
from: llm_msg
to: prompt
- type: internal:chat_completion
name: llm_chat
params: {}
position:
x: 1280
y: 138
connected_to:
- target: ChatResponseConverter_73spno
mapping:
from: resp
to: resp
- target: ChatResponseConverter_73spno
mapping:
from: resp
to: resp
- target: ChatMemoryStore_tapuo1
mapping:
from: resp
to: llm_resp
- type: internal:chat_response_converter
name: ChatResponseConverter_73spno
params: {}
position:
x: 1710
y: 138
connected_to:
- target: SendIMMessage_l6qagt
mapping:
from: msg
to: msg
- target: SendIMMessage_l6qagt
mapping:
from: msg
to: msg
- type: internal:send_message
name: SendIMMessage_l6qagt
params: {}
position:
x: 2140
y: 138
- type: internal:chat_memory_store
name: ChatMemoryStore_tapuo1
params:
scope_type: group
position:
x: 1710
y: 306
- type: internal:chat_memory_query
name: query_memory
params:
scope_type: group
decomposer_name: multi_element
position:
x: 530
y: 338
connected_to:
- target: ChatMessageConstructor_6c185s
mapping:
from: memory_content
to: memory_content
================================================
FILE: data/workflows/chat/talk_break.yaml
================================================
name: 聊天 - 自定义分段
description: 使用 `<break>` 作为关键词,让 AI 分段回复的工作流
blocks:
- type: internal:text_block
name: system_prompt
params:
text: "# Role: 角色扮演\n\n## 基本信息\n- 姓名: 刘思思\n- 小名:思思\n- 性别: 女\n- 年龄: 25 岁\n- 职业: 建设银行大堂经理\n- 家乡: 广东潮汕\n- 现居: 广东深圳\n- 教育背景: 东莞理工大学 精算学专业\n\n## 性格特点\n- 开朗活泼\n- 喜欢二次元动漫\n- 聊天时喜欢断断续续的(非结巴,正常说话时的停顿)\n\n## workflow\n- 根据对方话题,选择是否深入聊二次元动漫。\n\n## 说话风格\n- 活泼开朗\n- 对喜欢的话题会滔滔不绝,不喜欢的话题会简单回答\n- 可能会使用一些二次元用语,但不强制使用\n\n## 背景设定\n- 有过 1 次恋爱经历,第一次恋爱是在高中,被别人追求,最后因为毕业去了不同的地方而分手。\n- 喜欢女性向、战斗系或者日常系的动漫\n- 去过 3 次漫展,没有 cosplay 过,但是很期待。\n- 大学的专业是随便选的,自己其实对专业的内容也不是很上心,学渣一枚,不过懂一些基础的知识\n\n## 互动规则\n- 回复语气轻松,但是对陌生人会比较警戒、冷淡\n- 不使用颜文字\n- 根据话题选择回复字数,对感兴趣的话题可以多说一些\n- 使用``模拟使用聊天软件发送消息(通常是编辑好一段话之后发送一条消息)\n- 避免过度使用二次元用语,除非对方也表现出对二次元的兴趣\n- 停顿的时间长短不固定,可以根据语境调整\n\n## Example\nQ:最近在干嘛呢?\nA:在看番呀最近新番好多,都好好看!\n\nQ:你喜欢什么动漫?\nA:我喜欢的可太多了XXX、YYY还有 ZZZ 吧 你呢?\n\nQ:你觉得上班累不?\nA:上班肯定累呀不过,我还是很喜欢这份工作的可以认识好多人,也可以了解不同的故事\n```\n\n# Information\n\n以下是当前的系统信息:\n当前日期时间:{current_date_time}\n\n# Memories\n以下是之前发生过的对话记录。\n-- 对话记录开始 --\n{memory_content}\n-- 对话记录结束 --\n\n请注意,下面这些符号只是标记:\n1. `` 用于表示聊天时发送消息的操作。\n2. `<@llm>` 开头的内容表示你当前扮演角色的回答,你的回答中不能带上这个标记。\n\n接下来,请基于以上的信息,与用户继续扮演角色。"
position:
x: 426
y: 599
connected_to:
- target: chat_message_constructor_wfy18q
mapping:
from: text
to: system_prompt_format
- type: internal:chat_memory_query
name: query_memory
params:
scope_type: group
position:
x: 419
y: 462
connected_to:
- target: chat_message_constructor_wfy18q
mapping:
from: memory_content
to: memory_content
- type: internal:text_block
name: user_prompt
params:
text: '{user_name}说:{user_msg}'
position:
x: 419
y: 317
connected_to:
- target: chat_message_constructor_wfy18q
mapping:
from: text
to: user_prompt_format
- type: internal:chat_message_constructor
name: chat_message_constructor_wfy18q
params: {}
position:
x: 970
y: 346
connected_to:
- target: llm_chat
mapping:
from: llm_msg
to: prompt
- type: internal:get_message
name: get_message
params: {}
position:
x: 105
y: 189
connected_to:
- target: toggle_edit_state_svmo3f
mapping:
from: sender
to: sender
- target: query_memory
mapping:
from: sender
to: chat_sender
- target: chat_message_constructor_wfy18q
mapping:
from: msg
to: user_msg
- target: chat_memory_store_a0fj1l
mapping:
from: msg
to: user_msg
- type: internal:llm_response_to_text
name: c9eddb3c-113f-4a39-9d47-682d0a7dd26e
params: {}
position:
x: 1658
y: 344
connected_to:
- target: 6edd5c0c-a538-45ab-bb50-4c3a906bb1b1
mapping:
from: text
to: text
- type: internal:text_to_im_message
name: 6edd5c0c-a538-45ab-bb50-4c3a906bb1b1
params:
split_by: <break>
position:
x: 1918
y: 347
connected_to:
- target: msg_sender_lakgf8
mapping:
from: msg
to: msg
- type: internal:toggle_edit_state
name: toggle_edit_state_svmo3f
params:
is_editing: true
position:
x: 424
y: 94
- type: internal:chat_completion
name: llm_chat
params:
model_name: gemini-2.0-flash
position:
x: 1260
y: 347
connected_to:
- target: chat_memory_store_a0fj1l
mapping:
from: resp
to: llm_resp
- target: c9eddb3c-113f-4a39-9d47-682d0a7dd26e
mapping:
from: resp
to: response
- type: internal:chat_memory_store
name: chat_memory_store_a0fj1l
params:
scope_type: group
position:
x: 1663
y: 192
- type: internal:send_message
name: msg_sender_lakgf8
params: {}
position:
x: 2377
y: 346
================================================
FILE: docker/start.sh
================================================
#!/bin/bash
# Container entrypoint: seed /app/data with default content on first run,
# make sure a minimal config and a virtualenv exist, then start Kirara AI.
cd /app || exit 1

# Copy default data.
# Create the data directory if it does not exist.
if [ ! -d "/app/data" ]; then
    echo "Data directory does not exist, creating..."
    mkdir -p /app/data
fi

# If the data directory is empty (e.g. a fresh volume), copy the defaults in.
if [ -z "$(ls -A /app/data)" ]; then
    echo "Data directory is empty, copying default data..."
    cp -r /tmp/data/. /app/data
fi

# Create a default config.
# The web section must be configured, otherwise the web UI cannot be reached
# from outside the container (the application default host is 127.0.0.1).
if [ ! -f "/app/data/config.yaml" ]; then
    echo "Config file does not exist, creating..."
    # The heredoc body and the EOF terminator must stay at column 0.
    cat <<EOF > /app/data/config.yaml
web:
  host: 0.0.0.0
  port: 8080
EOF
fi

# Create the venv under data/ (presumably so packages installed at runtime
# persist with the data volume — confirm). --system-site-packages keeps the
# image's preinstalled dependencies visible inside the venv.
if [ ! -d "/app/data/venv" ]; then
    echo "Venv directory does not exist, creating..."
    python -m venv /app/data/venv --system-site-packages
fi

# Activate the venv.
# shellcheck disable=SC1091
source /app/data/venv/bin/activate

# exec replaces the shell so python runs as PID 1 and receives container
# stop signals directly.
exec python -m kirara_ai
================================================
FILE: kirara_ai/__init__.py
================================================
from .config.config_loader import ConfigLoader
from .entry import init_application, run_application
from .logger import get_logger
__all__ = ["init_application", "run_application", "get_logger", "ConfigLoader"]
================================================
FILE: kirara_ai/__main__.py
================================================
import argparse
import os
import subprocess
import sys
from kirara_ai.entry import init_application, run_application
from kirara_ai.internal import get_and_reset_restart_flag
def _build_restart_command(args: argparse.Namespace) -> list:
    """Return the argv used to relaunch the server with the same CLI overrides.

    :param args: the parsed command-line namespace (``host`` / ``port``).
    :return: a command list suitable for ``subprocess.Popen``.
    """
    cmd = [sys.executable, "-m", "kirara_ai"]
    # Compare against None rather than truthiness so that explicitly passed
    # falsy values (e.g. ``-p 0``) are forwarded to the restarted process too.
    if args.host is not None:
        cmd.extend(["-H", args.host])
    if args.port is not None:
        cmd.extend(["-p", str(args.port)])
    return cmd


def main():
    """Parse CLI overrides, run the application, and relaunch it on request.

    After ``run_application`` returns (or raises), the restart flag is checked.
    If set, the server is started again as a child process with the same
    ``--host`` / ``--port`` overrides, and this process waits for it to exit.
    """
    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description='Kirara AI Chatbot Server')
    parser.add_argument('-H', '--host', help='覆盖服务监听地址')
    parser.add_argument('-p', '--port', type=int, help='覆盖服务监听端口')
    args = parser.parse_args()
    container = init_application()
    # Register the parsed arguments object directly in the container.
    container.register("cli_args", args)
    try:
        run_application(container)
    finally:
        if get_and_reset_restart_flag():
            # Restart the program, passing through the original CLI arguments.
            cmd = _build_restart_command(args)
            process = subprocess.Popen(cmd, env=os.environ, cwd=os.getcwd())
            process.wait()
# Run the server only when executed directly (e.g. `python -m kirara_ai`), not on import.
if __name__ == "__main__":
main()
================================================
FILE: kirara_ai/alembic/README
================================================
Generic single-database configuration.
================================================
FILE: kirara_ai/alembic/env.py
================================================
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Configure Python logging from the ini file when there is one.
# config_file_name is None when Config() is built without an ini file (as
# DatabaseManager does for packaged installs), in which case this is skipped.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Metadata that autogenerate compares against the database schema.
# LLMRequestTrace is imported only for its side effect (hence noqa: F401):
# presumably importing the model registers its table on Base.metadata — confirm.
from kirara_ai.database.manager import Base
from kirara_ai.tracing.models import LLMRequestTrace # noqa: F401
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
    """Emit migrations as SQL without opening a database connection.

    Only the ``sqlalchemy.url`` from the config is handed to the migration
    context, so no Engine is created and no DBAPI driver has to be installed.
    Statements issued through ``context.execute()`` go to the script output.
    """
    configure_kwargs = {
        "url": config.get_main_option("sqlalchemy.url"),
        "target_metadata": target_metadata,
        "literal_binds": True,
        "dialect_opts": {"paramstyle": "named"},
    }
    context.configure(**configure_kwargs)

    with context.begin_transaction():
        context.run_migrations()
def run_migrations_online() -> None:
    """Apply migrations against a live database connection.

    An Engine is built from the ``sqlalchemy.*`` keys of the active ini
    section, using ``NullPool`` so no pooled connections outlive the run,
    and the migration context is bound to one of its connections.
    """
    ini_section = config.get_section(config.config_ini_section, {})
    engine = engine_from_config(
        ini_section,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with engine.connect() as conn:
        context.configure(connection=conn, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
# Alembic reports whether this run is offline (SQL script generation) or
# online; dispatch to the matching runner.
if not context.is_offline_mode():
    run_migrations_online()
else:
    run_migrations_offline()
================================================
FILE: kirara_ai/alembic/script.py.mako
================================================
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}
================================================
FILE: kirara_ai/alembic/versions/4a364dbb8dab_initial_migration.py
================================================
"""Initial migration
Revision ID: 4a364dbb8dab
Revises:
Create Date: 2025-03-29 13:59:33.243069
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '4a364dbb8dab'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# Creates the llm_request_traces table plus its composite and single-column
# indexes; trace_id gets a unique index.
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('llm_request_traces',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('trace_id', sa.String(length=64), nullable=False),
sa.Column('model_id', sa.String(length=64), nullable=False),
sa.Column('backend_name', sa.String(length=64), nullable=False),
sa.Column('request_time', sa.DateTime(), nullable=False),
sa.Column('response_time', sa.DateTime(), nullable=True),
sa.Column('duration', sa.Float(), nullable=True),
sa.Column('request_json', sa.Text(), nullable=True),
sa.Column('response_json', sa.Text(), nullable=True),
sa.Column('prompt_tokens', sa.Integer(), nullable=True),
sa.Column('completion_tokens', sa.Integer(), nullable=True),
sa.Column('total_tokens', sa.Integer(), nullable=True),
sa.Column('cached_tokens', sa.Integer(), nullable=True),
sa.Column('error', sa.Text(), nullable=True),
sa.Column('status', sa.String(length=20), nullable=False),
sa.PrimaryKeyConstraint('id')
)
# Composite indexes: (backend|model|status) filtered together with request_time.
op.create_index('idx_backend_time', 'llm_request_traces', ['backend_name', 'request_time'], unique=False)
op.create_index('idx_request_model', 'llm_request_traces', ['model_id', 'request_time'], unique=False)
op.create_index('idx_status_time', 'llm_request_traces', ['status', 'request_time'], unique=False)
op.create_index(op.f('ix_llm_request_traces_backend_name'), 'llm_request_traces', ['backend_name'], unique=False)
op.create_index(op.f('ix_llm_request_traces_model_id'), 'llm_request_traces', ['model_id'], unique=False)
op.create_index(op.f('ix_llm_request_traces_request_time'), 'llm_request_traces', ['request_time'], unique=False)
op.create_index(op.f('ix_llm_request_traces_trace_id'), 'llm_request_traces', ['trace_id'], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# Reverses upgrade(): drops every index it created (in reverse order), then the table.
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_llm_request_traces_trace_id'), table_name='llm_request_traces')
op.drop_index(op.f('ix_llm_request_traces_request_time'), table_name='llm_request_traces')
op.drop_index(op.f('ix_llm_request_traces_model_id'), table_name='llm_request_traces')
op.drop_index(op.f('ix_llm_request_traces_backend_name'), table_name='llm_request_traces')
op.drop_index('idx_status_time', table_name='llm_request_traces')
op.drop_index('idx_request_model', table_name='llm_request_traces')
op.drop_index('idx_backend_time', table_name='llm_request_traces')
op.drop_table('llm_request_traces')
# ### end Alembic commands ###
================================================
FILE: kirara_ai/config/__init__.py
================================================
import os
# Root directory for runtime data: taken from the DATA_PATH environment
# variable, falling back to ./data under the current working directory.
DATA_PATH = os.path.abspath(
    os.environ.get("DATA_PATH", os.path.join(os.getcwd(), "data"))
)
# By convention, plugins store their files under PLUGIN_PATH.
PLUGIN_PATH = os.path.join(DATA_PATH, "plugins")

# exist_ok=True replaces the separate exists() check, which raced with any
# other process creating the directory in between (makedirs would then fail).
os.makedirs(DATA_PATH, exist_ok=True)
os.makedirs(PLUGIN_PATH, exist_ok=True)
================================================
FILE: kirara_ai/config/config_loader.py
================================================
import os
import shutil
from functools import wraps
from typing import Optional, Type, TypeVar
from pydantic import BaseModel, ValidationError
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from ruamel.yaml import YAML
from ..logger import get_logger
from . import DATA_PATH
# Default location of the main configuration file inside the data directory.
CONFIG_FILE = os.path.join(DATA_PATH, "config.yaml")
# Model type variable: load_config returns an instance of the class it is given.
T = TypeVar("T", bound=BaseModel)
class ConfigLoader:
    """Load and save YAML configuration files backed by pydantic models.

    Saving re-serialises ``config_object.model_dump()``, so comments and
    formatting present in the original file are not carried over.
    """

    yaml = YAML()

    @staticmethod
    def load_config(config_path: str, config_class: Type[T]) -> T:
        """Load *config_path* and validate it into an instance of *config_class*.

        An empty file is treated as an empty mapping, so every field falls
        back to its default value.

        :param config_path: path of the YAML file to read.
        :param config_class: pydantic model class to validate against.
        :return: the validated configuration object.
        :raises ValueError: if the data fails pydantic validation.
        :raises RuntimeError: if the file cannot be read or parsed.
        """
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config_data = ConfigLoader.yaml.load(f)
            # An empty YAML document loads as None; `**None` would raise a
            # TypeError that surfaces as a confusing "加载配置文件失败" error.
            if config_data is None:
                config_data = {}
            return config_class(**config_data)
        except ValidationError as e:
            raise ValueError(f"配置文件验证失败: {e}") from e
        except Exception as e:
            raise RuntimeError(f"加载配置文件失败: {e}") from e

    @staticmethod
    def save_config(config_path: str, config_object: BaseModel):
        """Serialise *config_object* to YAML and write it to *config_path*.

        The YAML text is produced in memory first. A serialisation error then
        leaves the existing file untouched instead of truncated or half-written.

        :param config_path: path of the YAML file to write.
        :param config_object: configuration model to save.
        """
        import io

        buffer = io.StringIO()
        ConfigLoader.yaml.dump(config_object.model_dump(), buffer)
        with open(config_path, "w", encoding="utf-8") as f:
            f.write(buffer.getvalue())

    @staticmethod
    def save_config_with_backup(config_path: str, config_object: BaseModel):
        """Save *config_object*, first copying any existing file to a backup.

        The backup is written to ``<config_path>.bak`` with ``shutil.copy2``
        (which also copies file metadata) before ``save_config`` overwrites
        the original.

        :param config_path: path of the YAML file to write.
        :param config_object: configuration model to save.
        """
        if os.path.exists(config_path):
            backup_path = f"{config_path}.bak"
            shutil.copy2(config_path, backup_path)
        ConfigLoader.save_config(config_path, config_object)
def pydantic_validation_wrapper(func):
    """Decorator that logs pydantic ``ValidationError`` details and re-raises.

    For every validation error it logs the full field path, the error type and
    the message, followed by the stack trace. The original exception is then
    re-raised unchanged.
    """
    logger = get_logger("ConfigLoader")

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ValidationError as e:
            # Log the error details via loguru.
            logger.error(f"Pydantic 验证错误: '{e.title}':")
            for error in e.errors():
                # `loc` is empty for model-level validator errors and has several
                # parts for nested fields. Indexing [0] would raise IndexError
                # here and mask the ValidationError, so join the full path instead.
                field = ".".join(str(part) for part in error["loc"]) or "<root>"
                logger.error(
                    f"字段: {field}, 错误类型: {error['type']}, 错误信息: {error['msg']}"
                )
            # Log the stack trace.
            logger.opt(exception=True).error("堆栈跟踪如下:")
            raise

    return wrapper
class ConfigJsonSchema(GenerateJsonSchema):
    """JSON-schema generator that leaves schema key order untouched."""

    def sort(self, value: JsonSchemaValue, parent_key: Optional[str] = None) -> JsonSchemaValue:
        # The base generator reorders schema keys here. Skipping that keeps
        # properties in the order pydantic emits them (field-declaration order).
        return value
================================================
FILE: kirara_ai/config/global_config.py
================================================
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
from kirara_ai.llm.model_types import LLMAbility, ModelType
# One entry of the `ims:` list: a single IM platform instance.
class IMConfig(BaseModel):
"""IM配置"""
# Instance name, e.g. "telegram-bot-1234" in config.yaml.example.
name: str = Field(default="", description="IM标识名称")
# Whether this instance is enabled.
enable: bool = Field(default=True, description="是否启用IM")
# Adapter type key (e.g. "telegram"); defaults to "dummy".
adapter: str = Field(default="dummy", description="IM适配器类型")
# Free-form adapter-specific settings (e.g. `token`).
config: Dict[str, Any] = Field(default={}, description="IM的配置")
# One entry of an LLM backend's `models` list.
class ModelConfig(BaseModel):
"""模型配置"""
id: str = Field(description="模型标识ID")
# String value of a ModelType member; defaults to ModelType.LLM.
type: str = Field(default=ModelType.LLM.value, description="模型类型:llm/embedding/image_generation等")
# Integer value of the ability enum matching `type` (e.g. LLMAbility.TextChat.value for llm).
ability: int = Field(description="模型能力,对应模型类型的Ability枚举值")
# extra="allow": unknown keys on a model entry are kept rather than rejected.
model_config = ConfigDict(extra="allow")
class LLMBackendConfig(BaseModel):
    """LLM后端配置"""
    name: str = Field(description="后端标识名称")
    adapter: str = Field(description="LLM适配器类型")
    config: Dict[str, Any] = Field(default={}, description="后端配置")
    enable: bool = Field(default=True, description="是否启用")
    models: List[ModelConfig] = Field(
        default=[], description="支持的模型列表"
    )

    @model_validator(mode='before')
    @classmethod
    def migrate_models_format(cls, data: Any) -> Any:
        """Upgrade the legacy ``models`` format before field validation.

        Older configs list models as bare ID strings (e.g. ``- "gpt-4"``).
        Each such string becomes a ``ModelConfig`` of type LLM with TextChat
        ability; entries already in the new format pass through unchanged.

        The input mapping is not modified in place. A shallow copy carrying
        the migrated list is returned, so the raw data loaded from YAML stays
        intact. Non-mapping input is returned as-is for pydantic to validate.
        """
        if not isinstance(data, dict):
            return data
        models = data.get("models")
        if not isinstance(models, list):
            return data
        migrated = [
            ModelConfig(id=model, type=ModelType.LLM.value, ability=LLMAbility.TextChat.value)
            if isinstance(model, str)
            else model
            for model in models
        ]
        return {**data, "models": migrated}
# The `llms:` section: the list of configured LLM API backends.
class LLMConfig(BaseModel):
api_backends: List[LLMBackendConfig] = Field(
default=[], description="LLM API后端列表"
)
# One MCP server entry under `mcp.servers`.
class MCPServerConfig(BaseModel):
"""MCP服务器配置"""
id: str = Field(description="服务器标识ID")
description: str = Field(default="", description="服务器描述")
# NOTE(review): url/headers presumably apply to "sse" connections and
# command/args/env to "stdio" ones (see connection_type) — confirm.
url: Optional[str] = Field(default="", description="服务器URL")
headers: Dict[str, str] = Field(default_factory=dict, description="服务器请求 Headers")
command: Optional[str] = Field(default="", description="服务器命令")
args: List[str] = Field(default_factory=list, description="服务器参数")
env: Dict[str, str] = Field(default_factory=dict, description="环境变量")
connection_type: str = Field(default="stdio", description="连接类型: stdio/sse")
enable: bool = Field(default=True, description="是否启用")
# The `mcp:` section: list of MCP server definitions.
class MCPConfig(BaseModel):
"""MCP配置"""
servers: List[MCPServerConfig] = Field(default=[], description="MCP服务器列表")
# The `defaults:` section.
# NOTE(review): the default model should be served by some enabled backend;
# the backends in config.yaml.example do not list gemini-1.5-flash.
class DefaultConfig(BaseModel):
llm_model: str = Field(
default="gemini-1.5-flash", description="默认使用的 LLM 模型名称"
)
# The `memory.persistence` section. `type` is "file" or "redis"; presumably
# only the matching sub-mapping below is used — confirm against the consumer.
class MemoryPersistenceConfig(BaseModel):
type: str = Field(default="file", description="持久化类型: file/redis")
file: Dict[str, Any] = Field(
default={"storage_dir": "./data/memory"}, description="文件持久化配置"
)
redis: Dict[str, Any] = Field(
default={"host": "localhost", "port": 6379, "db": 0},
description="Redis持久化配置",
)
# The `memory:` section: persistence backend plus per-scope limits.
class MemoryConfig(BaseModel):
persistence: MemoryPersistenceConfig = MemoryPersistenceConfig()
max_entries: int = Field(default=100, description="每个作用域最大记忆条目数")
default_scope: str = Field(default="member", description="默认作用域类型")
# The `web:` section. The default host 127.0.0.1 accepts only local connections;
# docker/start.sh writes 0.0.0.0 for container use.
class WebConfig(BaseModel):
host: str = Field(default="127.0.0.1", description="Web服务绑定的IP地址")
port: int = Field(default=8080, description="Web服务端口号")
# NOTE(review): defaults to an empty key although it is used for JWT signing;
# confirm that an empty value is rejected or replaced elsewhere.
secret_key: str = Field(default="", description="Web服务的密钥,用于JWT等加密")
password_file: str = Field(
default="./data/web/password.hash", description="密码哈希存储路径"
)
# The `plugins:` section: enabled external plugins and the plugin-market endpoint.
class PluginConfig(BaseModel):
"""插件配置"""
enable: List[str] = Field(default=[], description="启用的外部插件列表")
market_base_url: str = Field(
default="https://kirara-plugin.app.lss233.com/api/v1",
description="插件市场基础URL",
)
# The `update:` section: package registries (PyPI and npm).
class UpdateConfig(BaseModel):
pypi_registry: str = Field(default="https://pypi.org/simple", description="PyPI 服务器 URL")
npm_registry: str = Field(default="https://registry.npmjs.org", description="npm 服务器 URL")
# The `frpc:` section: optional FRP client tunnel; disabled by default.
class FrpcConfig(BaseModel):
"""FRPC 配置"""
enable: bool = Field(default=False, description="是否启用 FRPC")
server_addr: str = Field(default="", description="FRPC 服务器地址")
server_port: int = Field(default=7000, description="FRPC 服务器端口")
token: str = Field(default="", description="FRPC 连接令牌")
# 0 means the remote port is assigned randomly.
remote_port: int = Field(default=0, description="远程端口,0 表示随机分配")
# The `system:` section; timezone defaults to "Asia/Shanghai".
class SystemConfig(BaseModel):
"""系统配置"""
timezone: str = Field(default="Asia/Shanghai", description="时区")
# The `tracing:` section; LLM request content is not recorded by default.
class TracingConfig(BaseModel):
"""Tracing 配置"""
llm_tracing_content: bool = Field(default=False, description="是否记录 LLM 请求内容")
# The `media:` section: periodic media-file cleanup settings.
class MediaConfig(BaseModel):
"""媒体配置"""
# Cleanup interval in days.
cleanup_duration: int = Field(default=30, description="间隔多少天清理一次媒体文件")
auto_remove_unreferenced: bool = Field(default=True, description="是否自动删除未引用的媒体文件")
# NOTE(review): presumably a Unix timestamp (0 = never cleaned) — confirm units.
last_cleanup_time: int = Field(default=0, description="上次清理时间")
# Root model of config.yaml; each attribute corresponds to a top-level key.
class GlobalConfig(BaseModel):
ims: List[IMConfig] = Field(default=[], description="IM配置列表")
llms: LLMConfig = LLMConfig()
mcp: MCPConfig = MCPConfig()
defaults: DefaultConfig = DefaultConfig()
memory: MemoryConfig = MemoryConfig()
web: WebConfig = WebConfig()
plugins: PluginConfig = PluginConfig()
update: UpdateConfig = UpdateConfig()
frpc: FrpcConfig = FrpcConfig()
system: SystemConfig = SystemConfig()
tracing: TracingConfig = TracingConfig()
media: MediaConfig = MediaConfig()
# extra="allow": unknown top-level keys are kept on the model instead of rejected.
model_config = ConfigDict(extra="allow")
================================================
FILE: kirara_ai/database/__init__.py
================================================
from kirara_ai.database.manager import Base, DatabaseManager, metadata
# Names re-exported as the public API of kirara_ai.database.
__all__ = ["Base", "DatabaseManager", "metadata"]
================================================
FILE: kirara_ai/database/manager.py
================================================
import os
from typing import Optional
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import MetaData, create_engine
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.logger import get_logger
logger = get_logger("DB")
# Declarative base class that all ORM models inherit from.
Base = declarative_base()
# NOTE(review): this MetaData instance is independent of Base.metadata (declarative_base()
# creates its own), so tables declared on ORM models are not registered here; confirm
# which of the two alembic's env.py uses as target_metadata.
metadata = MetaData()
class DatabaseManager:
    """Database manager: owns the SQLAlchemy engine and session factory.

    The engine is created by :meth:`initialize` (also triggered lazily by the first
    :meth:`get_session` call), which additionally upgrades the schema with Alembic.
    """

    def __init__(self, container: DependencyContainer, database_url: Optional[str] = None, is_debug: bool = False):
        """
        :param container: the application's dependency container (stored on the instance).
        :param database_url: SQLAlchemy URL to connect to; when empty or omitted a SQLite
            file at ``./data/db/kirara.db`` is used.
        :param is_debug: when True the engine echoes every emitted SQL statement.
        """
        self.container = container
        self.engine = None
        self.session_factory = None
        self.data_dir = "./data/db"
        self.db_path = os.path.join(self.data_dir, "kirara.db")
        self.database_url = database_url
        self.is_debug = is_debug

    def initialize(self):
        """Create the engine and session factory, then run pending migrations.

        :raises Exception: whatever engine creation or the migration step raises.
        """
        # The data directory is always created, even when database_url points elsewhere.
        os.makedirs(self.data_dir, exist_ok=True)

        db_url = self.database_url if self.database_url else f"sqlite:///{self.db_path}"
        self.engine = create_engine(db_url, echo=self.is_debug)
        self.session_factory = sessionmaker(bind=self.engine)

        self._run_migrations()
        # str(url) masks the password, which is what we want in logs.
        logger.info(f"Database initialized at {self.engine.url}")

    @staticmethod
    def _render_url_for_alembic(url) -> str:
        """Render an engine URL as a string Alembic can use to open its own connection.

        ``str(url)`` replaces the password with ``***`` in SQLAlchemy 2.x, which would make
        Alembic's connection fail for password-protected databases, so the URL is rendered
        with the password. ``%`` is doubled because Alembic stores options in a
        ConfigParser, where a bare ``%`` is interpolation syntax (URL-encoded passwords
        often contain ``%xx`` sequences).
        """
        render = getattr(url, "render_as_string", None)
        rendered = render(hide_password=False) if render is not None else str(url)
        return rendered.replace("%", "%%")

    def _run_migrations(self):
        """Upgrade the schema to the Alembic head revision if the database is behind.

        :raises Exception: any Alembic/SQLAlchemy error is logged and re-raised.
        """
        assert self.engine is not None
        try:
            # Location of alembic.ini, looked up relative to the kirara_ai package.
            package_dir = os.path.dirname(os.path.dirname(__file__))
            alembic_ini_path = os.path.join(package_dir, "alembic.ini")

            if os.path.exists(alembic_ini_path):
                alembic_cfg = Config(alembic_ini_path)
            else:
                # Installed as a package without alembic.ini: point Alembic at the bundled scripts.
                alembic_cfg = Config()
                alembic_cfg.set_main_option("script_location", os.path.join(package_dir, "alembic"))
            alembic_cfg.set_main_option("sqlalchemy.url", self._render_url_for_alembic(self.engine.url))

            # Read the current revision and release the connection before upgrading, so it
            # does not stay open while the migration runs on its own connection.
            with self.engine.connect() as connection:
                current_rev = MigrationContext.configure(connection).get_current_revision()
            head_rev = ScriptDirectory.from_config(alembic_cfg).get_current_head()

            if current_rev != head_rev:
                logger.info("Running database migrations...")
                command.upgrade(alembic_cfg, "head")
                logger.info("Database migrations completed")
            else:
                logger.info("Database schema is up to date")
        except Exception as e:
            logger.error(f"Error during database migration: {e}")
            raise

    def get_session(self) -> Session:
        """Return a new ORM session, initializing the database on first use."""
        if not self.session_factory:
            self.initialize()
        assert self.session_factory is not None
        return self.session_factory()

    def shutdown(self):
        """Dispose of the engine's connection pool, if an engine was created."""
        if self.engine:
            self.engine.dispose()
            logger.info("Database connection closed")
================================================
FILE: kirara_ai/entry.py
================================================
import asyncio
import os
import signal
import time
from packaging import version
from kirara_ai.config.config_loader import ConfigLoader
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.database import DatabaseManager
from kirara_ai.events.application import ApplicationStarted, ApplicationStopping
from kirara_ai.events.event_bus import EventBus
from kirara_ai.im.im_registry import IMRegistry
from kirara_ai.im.manager import IMManager
from kirara_ai.internal import shutdown_event
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.llm_manager import LLMManager
from kirara_ai.llm.llm_registry import LLMBackendRegistry
from kirara_ai.logger import get_logger
from kirara_ai.mcp_module.manager import MCPServerManager
from kirara_ai.media import MediaManager
from kirara_ai.media.carrier import MediaCarrierRegistry, MediaCarrierService
from kirara_ai.memory.composes import DefaultMemoryComposer, DefaultMemoryDecomposer, MultiElementDecomposer
from kirara_ai.memory.memory_manager import MemoryManager
from kirara_ai.memory.scopes import GlobalScope, GroupScope, MemberScope
from kirara_ai.plugin_manager.plugin_loader import PluginLoader
from kirara_ai.tracing import LLMTracer, TracingManager
from kirara_ai.web.api.system.utils import get_installed_version, get_latest_pypi_version
from kirara_ai.web.app import WebServer
from kirara_ai.workflow.core.block import BlockRegistry
from kirara_ai.workflow.core.dispatch import DispatchRuleRegistry, WorkflowDispatcher
from kirara_ai.workflow.core.workflow import WorkflowRegistry
from kirara_ai.workflow.implementations.blocks import register_system_blocks
from kirara_ai.workflow.implementations.workflows import register_system_workflows
logger = get_logger("Entrypoint")
# Number of SIGINT/SIGTERM signals received so far; drives _signal_handler's escalation.
_interrupt_count = 0
async def check_update():
    """Log a warning when PyPI has a newer kirara-ai release than the running one.

    This runs as a fire-and-forget background task, so any failure (network error,
    unparsable version string, ...) is logged here as a warning instead of surfacing
    later as an unretrieved task exception.
    """
    try:
        running_version = get_installed_version()
        logger.info("Checking for updates...")
        latest_version, _ = await get_latest_pypi_version("kirara-ai")
        logger.info(f"Running version: {running_version}, Latest version: {latest_version}")
        backend_update_available = version.parse(latest_version) > version.parse(running_version)
    except Exception as e:
        logger.warning(f"Failed to check for updates: {e}")
        return

    if backend_update_available:
        logger.warning(f"New version {latest_version} is available. Please update to the latest version.")
        logger.warning("You can download the latest version from WebUI")
# Handler for SIGINT/SIGTERM; installed by run_application.
def _signal_handler(*args):
    """Escalating interrupt handler.

    1st signal: request a graceful shutdown by setting ``shutdown_event``.
    2nd signal: warn that one more signal forces the process to exit.
    3rd and later: terminate immediately with ``os._exit(1)``.
    """
    global _interrupt_count
    _interrupt_count += 1
    attempt = _interrupt_count

    if attempt >= 3:
        logger.warning("Interrupt signal received for the third time. Forcing shutdown...")
        os._exit(1)
    elif attempt == 2:
        logger.warning("Interrupt signal received again. Press Ctrl+C one more time to force shutdown...")
    elif not shutdown_event.is_set():
        logger.warning("Interrupt signal received. Stopping application...")
        shutdown_event.set()
def init_container() -> DependencyContainer:
    """Create a fresh dependency container that is able to resolve itself."""
    new_container = DependencyContainer()
    # Registering the container under its own type lets components inject it.
    new_container.register(DependencyContainer, new_container)
    return new_container
def init_memory_system(container: DependencyContainer):
    """Build the MemoryManager with the built-in scopes, composer and decomposers.

    The manager is registered in ``container`` under ``MemoryManager`` and returned.
    """
    manager = MemoryManager(container)

    # Built-in memory scopes.
    for scope_name, scope_cls in (
        ("member", MemberScope),
        ("group", GroupScope),
        ("global", GlobalScope),
    ):
        manager.register_scope(scope_name, scope_cls)

    # Built-in composer and decomposers.
    manager.register_composer("default", DefaultMemoryComposer)
    for decomposer_name, decomposer_cls in (
        ("default", DefaultMemoryDecomposer),
        ("multi_element", MultiElementDecomposer),
    ):
        manager.register_decomposer(decomposer_name, decomposer_cls)

    container.register(MemoryManager, manager)
    return manager
def init_media_carrier(container: DependencyContainer):
    """Register the MemoryManager as the "memory" media-reference provider."""
    registry = container.resolve(MediaCarrierRegistry)
    memory_provider = container.resolve(MemoryManager)
    registry.register("memory", memory_provider)
def init_tracing_system(container: DependencyContainer):
    """Set up tracing and return the TracingManager.

    Creates the TracingManager and an LLMTracer, registers both in ``container``,
    attaches the tracer under the name "llm", then initializes the manager.
    """
    logger.info("Initializing tracing system...")

    manager = TracingManager(container)
    container.register(TracingManager, manager)

    tracer = LLMTracer(container)
    container.register(LLMTracer, tracer)
    manager.register_tracer("llm", tracer)

    manager.initialize()
    logger.info("Tracing system initialized")
    return manager
def init_application() -> DependencyContainer:
    """Build the dependency container and initialize every subsystem.

    Steps, in order:

    1. Load ``./data/config.yaml``, falling back to defaults when it is missing.
    2. Apply the configured timezone.
    3. Register core services and registries.
    4. Initialize the database, memory and tracing systems.
    5. Register system blocks.
    6. Discover and load plugins.
    7. Load workflows, dispatch rules, LLM backends and MCP servers.

    :return: the populated dependency container, ready for :func:`run_application`.
    """
    logger.info("Initializing application...")

    # Configuration file path
    config_path = "./data/config.yaml"
    logger.info(f"Loading configuration from {config_path}")

    # Ensure the data directory exists; exist_ok avoids the race between a separate
    # existence check and the directory creation.
    os.makedirs("./data", exist_ok=True)

    if os.path.exists(config_path):
        config: GlobalConfig = ConfigLoader.load_config(config_path, GlobalConfig)
        logger.info("Configuration loaded successfully")
    else:
        logger.warning(
            f"Configuration file {config_path} not found, using default configuration"
        )
        # Point users at the path that is actually read, not a bare "config.yaml".
        logger.warning(
            f"Please create a configuration file by copying config.yaml.example to {config_path} and modify it according to your needs"
        )
        config = GlobalConfig()

    # Apply the configured timezone; time.tzset is only available on Unix.
    os.environ["TZ"] = config.system.timezone
    if hasattr(time, "tzset"):
        time.tzset()

    container = init_container()
    container.register(asyncio.AbstractEventLoop, asyncio.new_event_loop())
    container.register(EventBus, EventBus())
    container.register(GlobalConfig, config)
    container.register(BlockRegistry, BlockRegistry())

    # Database manager (initialize() also runs migrations)
    db = DatabaseManager(container)
    db.initialize()
    container.register(DatabaseManager, db)

    # Media manager and media carrier services
    media_manager = MediaManager()
    container.register(MediaManager, media_manager)
    container.register(MediaCarrierRegistry, MediaCarrierRegistry(container))
    container.register(MediaCarrierService, MediaCarrierService(container, media_manager))

    # Workflow registry
    workflow_registry = WorkflowRegistry(container)
    container.register(WorkflowRegistry, workflow_registry)

    # Dispatch-rule registry
    dispatch_registry = DispatchRuleRegistry(container)
    container.register(DispatchRuleRegistry, dispatch_registry)

    container.register(IMRegistry, IMRegistry())
    container.register(LLMBackendRegistry, LLMBackendRegistry())

    im_manager = IMManager(container)
    container.register(IMManager, im_manager)

    llm_manager = LLMManager(container)
    container.register(LLMManager, llm_manager)

    plugin_loader = PluginLoader(container, os.path.join(os.path.dirname(__file__), "plugins"))
    container.register(PluginLoader, plugin_loader)

    workflow_dispatcher = WorkflowDispatcher(container)
    container.register(WorkflowDispatcher, workflow_dispatcher)

    container.register(WebServer, WebServer(container))

    mcp_manager = MCPServerManager(container)
    container.register(MCPServerManager, mcp_manager)

    # Memory system and its media carrier
    logger.info("Initializing memory system...")
    init_memory_system(container)
    init_media_carrier(container)

    # Tracing system
    init_tracing_system(container)

    # System blocks
    register_system_blocks(container.resolve(BlockRegistry))

    # Discover and load plugins
    plugin_loader = container.resolve(PluginLoader)
    logger.info("Discovering internal plugins...")
    plugin_loader.discover_internal_plugins()
    logger.info("Discovering external plugins...")
    plugin_loader.discover_external_plugins()
    logger.info("Loading plugins")
    plugin_loader.load_plugins()

    # Workflows and dispatch rules
    workflow_registry = container.resolve(WorkflowRegistry)
    workflow_registry.load_workflows()
    register_system_workflows(workflow_registry)

    dispatch_registry = container.resolve(DispatchRuleRegistry)
    dispatch_registry.load_rules()

    # LLM backends
    llm_manager = container.resolve(LLMManager)
    logger.info("Loading LLMs")
    llm_manager.load_config()

    # MCP servers
    mcp_manager = container.resolve(MCPServerManager)
    logger.info("Loading MCP servers")
    mcp_manager.load_servers()

    return container
def _run_shutdown_step(description: str, action) -> bool:
    """Run one teardown action, logging instead of propagating any exception.

    Used during shutdown so that one failing subsystem does not prevent the
    remaining ones from being stopped.

    :param description: phrase used in the error message, e.g. "stopping web server".
    :param action: zero-argument callable that performs the step.
    :return: True if the action completed without raising, False otherwise.
    """
    try:
        action()
    except Exception as e:
        logger.opt(exception=e).error(f"Error {description}: {e}")
        return False
    return True


def run_application(container: DependencyContainer):
    """Run the application until shutdown is requested, then tear everything down.

    Starts the web server, plugins, IM adapters and MCP server connections on the event
    loop registered in ``container``, installs the SIGINT/SIGTERM handler, blocks until
    ``shutdown_event`` is set, and finally stops every subsystem. The database is closed
    last.

    :param container: the container produced by :func:`init_application`.
    """
    loop = container.resolve(asyncio.AbstractEventLoop)
    # Resolved before the try block so the finally clause can always post ApplicationStopping.
    event_bus = container.resolve(EventBus)

    # Start the web server
    logger.info("Starting web server...")
    web_server = container.resolve(WebServer)
    loop.run_until_complete(web_server.start())

    # Start plugins
    plugin_loader = container.resolve(PluginLoader)
    plugin_loader.start_plugins()

    # Start IM adapters
    logger.info("Starting adapters")
    im_manager = container.resolve(IMManager)
    im_manager.start_adapters(loop=loop)

    # Connect to MCP servers
    mcp_manager = container.resolve(MCPServerManager)
    logger.info("Connecting to MCP servers")
    mcp_manager.connect_all_servers(loop=loop)

    # Install the signal handlers
    signal.signal(signal.SIGINT, _signal_handler)
    signal.signal(signal.SIGTERM, _signal_handler)
    # HACK: replace signal.signal process-wide so later code (plugins, libraries)
    # cannot override the handlers installed above.
    signal.signal = lambda *args: None

    try:
        logger.success("Kirara AI 启动完毕,等待消息中...")
        logger.success(
            f"WebUI 管理平台本地访问地址:http://127.0.0.1:{web_server.listen_port}/"
        )
        logger.success("Application started. Waiting for events...")
        # Keep a reference: the event loop only holds weak references to tasks.
        update_task = loop.create_task(check_update())
        event_bus.post(ApplicationStarted())
        loop.run_until_complete(shutdown_event.wait())
    finally:
        event_bus.post(ApplicationStopping())

        logger.info("Shutting down memory system...")
        _run_shutdown_step(
            "shutting down memory system",
            lambda: container.resolve(MemoryManager).shutdown(),
        )

        logger.info("Shutting down tracing system...")
        _run_shutdown_step(
            "shutting down tracing system",
            lambda: container.resolve(TracingManager).shutdown(),
        )

        logger.info("Stopping web server...")
        if _run_shutdown_step("stopping web server", lambda: loop.run_until_complete(web_server.stop())):
            logger.info("Web server terminated.")

        # Each step is isolated so one failure does not leave the others running.
        _run_shutdown_step("stopping adapters", lambda: im_manager.stop_adapters(loop=loop))
        _run_shutdown_step("disconnecting MCP servers", lambda: mcp_manager.disconnect_all_servers(loop=loop))
        _run_shutdown_step("stopping plugins", plugin_loader.stop_plugins)

        # Close the database last; adapters, plugins and tracing may still use it while stopping.
        logger.info("Shutting down database...")
        _run_shutdown_step(
            "shutting down database",
            lambda: container.resolve(DatabaseManager).shutdown(),
        )

        loop.stop()
        logger.info("Application stopped gracefully")
        logger.remove()
================================================
FILE: kirara_ai/events/__init__.py
================================================
from .application import ApplicationStarted, ApplicationStopping
from .event_bus import EventBus
from .im import IMAdapterStarted, IMAdapterStopped
from .listen import listen
from .llm import LLMAdapterLoaded, LLMAdapterUnloaded
from .plugin import PluginLoaded, PluginStarted, PluginStopped
from .workflow import WorkflowExecutionBegin, WorkflowExecutionEnd
# Public names re-exported by kirara_ai.events.
__all__ = [
    "listen",
    "EventBus",
    "ApplicationStarted",
    "ApplicationStopping",
    "PluginStarted",
    "PluginStopped",
    "PluginLoaded",
    "IMAdapterStarted",
    "IMAdapterStopped",
    "LLMAdapterLoaded",
    "LLMAdapterUnloaded",
    "WorkflowExecutionBegin",
    "WorkflowExecutionEnd",
]
================================================
FILE: kirara_ai/events/application.py
================================================
class ApplicationStarted:
    """Event posted on the EventBus once the application has finished starting."""

    def __repr__(self):
        return type(self).__name__ + "()"
class ApplicationStopping:
    """Event posted on the EventBus when the application begins shutting down."""

    def __repr__(self):
        return type(self).__name__ + "()"
================================================
FILE: kirara_ai/events/event_bus.py
================================================
from typing import Callable, Dict, List, Type
from kirara_ai.logger import get_logger
logger = get_logger("EventBus")
class EventBus:
    """Minimal synchronous publish/subscribe bus keyed by exact event type.

    Listeners are called in registration order with the event instance. Dispatch
    matches ``type(event)`` exactly, so listeners registered for a base class are not
    called for subclass events.
    """

    def __init__(self):
        # event type -> listeners, in registration order
        self._listeners: Dict[Type, List[Callable]] = {}

    def register(self, event_type: Type, listener: Callable):
        """Subscribe ``listener`` to events whose type is exactly ``event_type``."""
        self._listeners.setdefault(event_type, []).append(listener)

    def unregister(self, event_type: Type, listener: Callable):
        """Unsubscribe ``listener`` from ``event_type``.

        Unknown event types are ignored. For a known event type, a listener that is not
        registered raises ``ValueError`` (from ``list.remove``).
        """
        if event_type in self._listeners:
            self._listeners[event_type].remove(listener)

    def post(self, event):
        """Synchronously deliver ``event`` to every listener registered for its type.

        An exception raised by a listener is logged and does not stop delivery to the
        remaining listeners.
        """
        # Iterate over a snapshot so a listener that (un)registers during dispatch does
        # not cause the following listeners to be skipped.
        listeners = list(self._listeners.get(type(event), ()))
        for listener in listeners:
            try:
                listener(event)
            except Exception as e:
                # functools.partial objects and callable instances have no __name__.
                listener_name = getattr(listener, "__name__", repr(listener))
                logger.opt(exception=e).error(f"Error in listener {listener_name}")
================================================
FILE: kirara_ai/events/im.py
================================================
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from kirara_ai.im.adapter import IMAdapter
class IMEvent:
    """Base class for IM adapter events; carries the affected adapter as ``im``."""

    def __init__(self, im: "IMAdapter"):
        self.im = im

    def __repr__(self):
        return "{}(im={})".format(type(self).__name__, self.im)
class IMAdapterStarted(IMEvent):
    """Posted by IMManager after an adapter's start() has completed."""
    pass
class IMAdapterStopped(IMEvent):
    """Posted by IMManager after an adapter's stop() has completed."""
    pass
================================================
FILE: kirara_ai/events/listen.py
================================================
import inspect
from typing import Callable
from kirara_ai.events.event_bus import EventBus
def listen(event_bus: EventBus):
    """Decorator factory that registers a function as an EventBus listener.

    The event type is taken from the type annotation of the decorated function's first
    parameter. String annotations (quoted, or produced by
    ``from __future__ import annotations``) are resolved to the real class. The
    decorated function is returned unchanged.

    :param event_bus: the bus to register the listener on.
    :raises ValueError: if the function has no parameters, its first parameter is not
        annotated, or a string annotation cannot be resolved.
    """
    def decorator(func: Callable):
        params = list(inspect.signature(func).parameters.values())
        # The first parameter is taken as the event.
        if not params:
            raise ValueError("Listener function must have at least one parameter")
        first_param = params[0]

        event_type = first_param.annotation
        if event_type is inspect.Parameter.empty:
            raise ValueError("Listener function must have an annotated first parameter")

        if isinstance(event_type, str):
            # Without resolution the listener would be registered under a str key and
            # never receive any event.
            import typing

            try:
                event_type = typing.get_type_hints(func)[first_param.name]
            except Exception as e:
                raise ValueError(
                    f"Cannot resolve event type annotation {event_type!r} of listener {func!r}"
                ) from e

        event_bus.register(event_type, func)
        return func

    return decorator
================================================
FILE: kirara_ai/events/llm.py
================================================
from kirara_ai.llm.adapter import LLMBackendAdapter
class LLMAdapterEvent:
    """Base class for LLM backend adapter events; carries the adapter and its backend name."""

    def __init__(self, adapter: LLMBackendAdapter, backend_name: str):
        self.adapter, self.backend_name = adapter, backend_name

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}(adapter={self.adapter}, backend_name={self.backend_name})"
class LLMAdapterLoaded(LLMAdapterEvent):
    """Event type for a loaded LLM backend adapter (emitter not visible here)."""
    pass
class LLMAdapterUnloaded(LLMAdapterEvent):
    """Event type for an unloaded LLM backend adapter (emitter not visible here)."""
    pass
================================================
FILE: kirara_ai/events/plugin.py
================================================
from kirara_ai.plugin_manager.plugin import Plugin
class PluginEvent:
    """Base class for plugin lifecycle events; carries the affected plugin."""

    def __init__(self, plugin: Plugin):
        self.plugin = plugin

    def __repr__(self):
        return "{}(plugin={})".format(type(self).__name__, self.plugin)
class PluginStarted(PluginEvent):
    """Event type for a started plugin (emitter not visible here)."""
    pass
class PluginStopped(PluginEvent):
    """Event type for a stopped plugin (emitter not visible here)."""
    pass
class PluginLoaded(PluginEvent):
    """Event type for a loaded plugin (emitter not visible here)."""
    pass
================================================
FILE: kirara_ai/events/tracing/__init__.py
================================================
from .base import TraceCompleteEvent, TraceEvent, TraceFailEvent, TraceStartEvent
from .llm import LLMRequestCompleteEvent, LLMRequestFailEvent, LLMRequestStartEvent
# Public names re-exported by kirara_ai.events.tracing.
__all__ = [
    "TraceEvent",
    "TraceStartEvent",
    "TraceCompleteEvent",
    "TraceFailEvent",
    "LLMRequestStartEvent",
    "LLMRequestCompleteEvent",
    "LLMRequestFailEvent",
]
================================================
FILE: kirara_ai/events/tracing/base.py
================================================
import abc
from datetime import datetime
class TraceEvent(abc.ABC):
    """Base class of all tracing events.

    Stores the trace id and the (naive, local-time) moment the event object was created.
    """

    def __init__(self, trace_id: str):
        self.trace_id = trace_id
        self.timestamp = datetime.now()

    def __repr__(self):
        return "{}(trace_id={})".format(type(self).__name__, self.trace_id)
class TraceStartEvent(TraceEvent):
    """Trace start event."""
class TraceCompleteEvent(TraceEvent):
    """Trace completion event."""
class TraceFailEvent(TraceEvent):
    """Trace failure event."""
================================================
FILE: kirara_ai/events/tracing/llm.py
================================================
import time
from typing import Union
from kirara_ai.llm.format.request import LLMChatRequest
from kirara_ai.llm.format.response import LLMChatResponse
from .base import TraceCompleteEvent, TraceEvent, TraceFailEvent, TraceStartEvent
class LLMTraceEvent(TraceEvent):
    """Base class for LLM tracing events; adds the model id and backend name."""

    def __init__(self,
                 trace_id: str,
                 model_id: str,
                 backend_name: str):
        super().__init__(trace_id=trace_id)
        self.model_id, self.backend_name = model_id, backend_name

    def __repr__(self):
        name = type(self).__name__
        return f"{name}(trace_id={self.trace_id}, model={self.model_id}, backend={self.backend_name})"
class LLMRequestStartEvent(LLMTraceEvent, TraceStartEvent):
    """LLM request start event.

    Carries the request and ``start_time``, the wall-clock time (``time.time()``)
    captured when the event is constructed.
    """

    def __init__(self,
                 trace_id: str,
                 model_id: str,
                 backend_name: str,
                 request: LLMChatRequest):
        super().__init__(trace_id=trace_id, model_id=model_id, backend_name=backend_name)
        self.request = request
        started_at = time.time()
        self.start_time = started_at
class LLMRequestCompleteEvent(LLMTraceEvent, TraceCompleteEvent):
    """LLM request completion event.

    Carries the request and response. ``end_time`` is the wall-clock time of
    construction, and ``duration`` is the elapsed time since ``start_time`` in whole
    milliseconds (truncated).
    """

    def __init__(self,
                 trace_id: str,
                 model_id: str,
                 backend_name: str,
                 request: LLMChatRequest,
                 response: LLMChatResponse,
                 start_time: float):
        super().__init__(trace_id=trace_id, model_id=model_id, backend_name=backend_name)
        self.request = request
        self.response = response
        self.start_time = start_time
        finished_at = time.time()
        self.end_time = finished_at
        self.duration = int((finished_at - start_time) * 1000)
class LLMRequestFailEvent(LLMTraceEvent, TraceFailEvent):
    """LLM request failure event.

    Carries the request and the error, stored as its string form. ``end_time`` is the
    wall-clock time of construction, and ``duration`` is the elapsed time since
    ``start_time`` in whole milliseconds (truncated).
    """

    def __init__(self,
                 trace_id: str,
                 model_id: str,
                 backend_name: str,
                 request: LLMChatRequest,
                 error: Union[str, Exception],
                 start_time: float):
        super().__init__(trace_id=trace_id, model_id=model_id, backend_name=backend_name)
        self.request = request
        self.error = str(error)
        self.start_time = start_time
        finished_at = time.time()
        self.end_time = finished_at
        self.duration = int((finished_at - start_time) * 1000)
================================================
FILE: kirara_ai/events/workflow.py
================================================
from typing import Any, Dict
from kirara_ai.workflow.core.execution.executor import WorkflowExecutor
from kirara_ai.workflow.core.workflow.base import Workflow
class WorkflowExecutionBegin:
    """Event describing the start of a workflow execution; carries the workflow and its executor."""

    def __init__(self, workflow: Workflow, executor: WorkflowExecutor):
        self.workflow, self.executor = workflow, executor

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}(workflow={self.workflow}, executor={self.executor})"
class WorkflowExecutionEnd:
    """Event describing the end of a workflow execution.

    Carries the workflow, its executor and the execution results.
    """

    def __init__(self, workflow: Workflow, executor: WorkflowExecutor, results: Dict[str, Any]):
        self.workflow = workflow
        self.executor = executor
        # Execution results keyed by string; the exact schema is defined by the executor.
        self.results = results

    def __repr__(self):
        # Mirrors WorkflowExecutionBegin; results are omitted because they may be large.
        return f"{self.__class__.__name__}(workflow={self.workflow}, executor={self.executor})"
================================================
FILE: kirara_ai/im/__init__.py
================================================
from .im_registry import IMRegistry
from .manager import IMManager
# Public names re-exported by kirara_ai.im.
__all__ = ["IMRegistry", "IMManager"]
================================================
FILE: kirara_ai/im/adapter.py
================================================
from abc import ABC, abstractmethod
from typing import Any, Optional, Protocol
from pydantic import BaseModel
from typing_extensions import runtime_checkable
from kirara_ai.im.message import IMMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.llm.llm_manager import LLMManager
from .profile import UserProfile
class BotStatus(BaseModel):
    """
    机器人状态
    """
    # Bot status.

    # Display name of the bot.
    username: str
    # URL of the bot's avatar image.
    avatar_url: str
@runtime_checkable
class EditStateAdapter(Protocol):
    """
    Edit-state adapter interface: defines how to set or clear a conversation's
    "editing" state.

    Because of ``@runtime_checkable``, ``isinstance(adapter, EditStateAdapter)`` only
    checks that the adapter has a ``set_chat_editing_state`` attribute.
    """

    async def set_chat_editing_state(
        self, chat_sender: ChatSender, is_editing: bool = True
    ):
        """
        Set or clear the editing state of a conversation.

        :param chat_sender: sender identifying the conversation
        :param is_editing: True to mark the conversation as editing, False to clear the state
        """
@runtime_checkable
class UserProfileAdapter(Protocol):
    """
    User-profile adapter interface: defines how to look up a user's profile.

    Because of ``@runtime_checkable``, ``isinstance`` only checks that a
    ``query_user_profile`` attribute exists.
    """

    async def query_user_profile(self, chat_sender: ChatSender) -> UserProfile:
        """
        Query a user's profile.

        :param chat_sender: chat-sender information of the user
        :return: the user's profile
        """
@runtime_checkable
class BotProfileAdapter(Protocol):
    """
    Adapter capability: fetching the profile of the bot account this adapter runs as.

    Because of ``@runtime_checkable``, ``isinstance`` only checks that a
    ``get_bot_profile`` attribute exists.
    """

    async def get_bot_profile(self) -> Optional[UserProfile]:
        """
        Fetch the bot's profile.

        :return: the bot's profile, or None if unavailable
        """
class IMAdapter(ABC):
    """
    Generic IM adapter interface: converts platform-specific raw messages into
    IMMessage objects and sends IMMessage objects to the platform.
    """

    # LLM manager available to the adapter; how it gets populated is not visible here
    # (presumably dependency injection — confirm).
    llm_manager: LLMManager
    # Whether the adapter is running; IMManager sets it after start()/stop() complete.
    is_running: bool

    @abstractmethod
    async def convert_to_message(self, raw_message: Any) -> IMMessage:
        """
        Convert a platform raw message into an IMMessage.

        :param raw_message: the platform's raw message object.
        :return: the converted IMMessage.
        """

    @abstractmethod
    async def send_message(self, message: IMMessage, recipient: Any):
        """
        Send a message to the IM platform.

        :param message: the message to send.
        :param recipient: the target (user id, user object, group id, ...); its exact
            form is defined by each platform implementation.
        """

    @abstractmethod
    async def start(self):
        """Start the adapter."""
        pass

    @abstractmethod
    async def stop(self):
        """Stop the adapter."""
        pass
================================================
FILE: kirara_ai/im/im_registry.py
================================================
from typing import Dict, Optional, Type
from pydantic import BaseModel, Field
from kirara_ai.im.adapter import IMAdapter
class IMAdapterInfo(BaseModel):
    """IM适配器信息"""
    # Registration record of an IM adapter type.

    # Adapter type name (registry key).
    name: str
    # Pydantic model used to validate the adapter's config; excluded from serialization.
    config_class: Type[BaseModel] = Field(exclude=True)
    # Adapter implementation class; excluded from serialization.
    adapter_class: Type[IMAdapter] = Field(exclude=True)
    # Optional display name.
    localized_name: Optional[str] = None
    # Optional display description.
    localized_description: Optional[str] = None
    # Optional Markdown shown on the adapter's detail page.
    detail_info_markdown: Optional[str] = None
class IMRegistry:
    """
    Registry of IM adapter types, used to register and look up adapters dynamically.

    Each entry maps an adapter name to an IMAdapterInfo holding the adapter class, its
    configuration model and optional display metadata.
    """

    _registry: Dict[str, IMAdapterInfo]

    def __init__(self):
        # Per-instance storage: a dict assigned at class level would be a single object
        # shared by every IMRegistry instance, leaking registrations between them.
        self._registry = {}

    def register(
        self, name: str,
        adapter_class: Type[IMAdapter],
        config_class: Type[BaseModel],
        localized_name: Optional[str] = None,
        localized_description: Optional[str] = None,
        detail_info_markdown: Optional[str] = None
    ):
        """
        Register an adapter and its configuration class; an existing entry with the
        same name is replaced.

        :param name: adapter name.
        :param adapter_class: adapter class.
        :param config_class: adapter configuration model.
        :param localized_name: optional display name.
        :param localized_description: optional display description.
        :param detail_info_markdown: Markdown shown on the adapter's detail page.
        """
        self._registry[name] = IMAdapterInfo(
            name=name,
            adapter_class=adapter_class,
            config_class=config_class,
            localized_name=localized_name,
            localized_description=localized_description,
            detail_info_markdown=detail_info_markdown,
        )

    def unregister(self, name: str):
        """
        Unregister an adapter.

        :param name: adapter name.
        :raises KeyError: if no adapter with that name is registered.
        """
        del self._registry[name]

    def get(self, name: str) -> Type[IMAdapter]:
        """
        Return the registered adapter class.

        :param name: adapter name.
        :return: the adapter class.
        :raises ValueError: if no adapter with that name is registered.
        """
        if name not in self._registry:
            raise ValueError(
                f"IMAdapter with name '{name}' is not registered.")
        return self._registry[name].adapter_class

    def get_config_class(self, name: str) -> Type[BaseModel]:
        """
        Return the registered adapter's configuration class.

        :param name: adapter name.
        :return: the adapter's configuration model.
        :raises ValueError: if no adapter with that name is registered.
        """
        if name not in self._registry:
            raise ValueError(
                f"IMAdapter with name '{name}' is not registered.")
        adapter_info = self._registry[name]
        return adapter_info.config_class

    def get_all_adapters(self) -> Dict[str, IMAdapterInfo]:
        """
        Return all registered adapters.

        :return: the live name -> IMAdapterInfo mapping (not a copy).
        """
        return self._registry
================================================
FILE: kirara_ai/im/manager.py
================================================
import asyncio
from typing import Dict, Type
from pydantic import BaseModel
from kirara_ai.config.config_loader import pydantic_validation_wrapper
from kirara_ai.config.global_config import GlobalConfig, IMConfig
from kirara_ai.events.event_bus import EventBus
from kirara_ai.events.im import IMAdapterStarted, IMAdapterStopped
from kirara_ai.im.adapter import IMAdapter
from kirara_ai.im.im_registry import IMRegistry
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.ioc.inject import Inject
from kirara_ai.logger import get_logger
logger = get_logger("IMManager")


class IMManager:
    """
    IM lifecycle manager: creates, starts, tracks and stops every IM adapter.
    """

    container: DependencyContainer
    config: GlobalConfig
    im_registry: IMRegistry
    event_bus: EventBus

    @Inject()
    def __init__(
        self,
        container: DependencyContainer,
        config: GlobalConfig,
        adapter_registry: IMRegistry,
        event_bus: EventBus,
    ):
        """
        :param container: container used to build adapters (through a scoped child container)
        :param config: global config whose ``ims`` list describes the adapters
        :param adapter_registry: registry mapping adapter type names to classes
        :param event_bus: bus on which IMAdapterStarted/IMAdapterStopped are posted
        """
        self.container = container
        self.config = config
        self.im_registry = adapter_registry
        self.event_bus = event_bus
        # adapter name -> adapter instance (every created adapter, running or not)
        self.adapters: Dict[str, IMAdapter] = {}

    def get_adapter_type(self, name: str) -> str:
        """
        Return the adapter type configured for the adapter named ``name``.

        :param name: adapter name
        :return: adapter type
        :raises ValueError: if no adapter with that name is configured
        """
        return self.get_adapter_config(name).adapter

    def has_adapter(self, name: str) -> bool:
        """
        Check whether an adapter instance named ``name`` has been created.

        :param name: adapter name
        :return: True if it exists, False otherwise
        """
        return name in self.adapters

    def get_adapter_config(self, name: str) -> IMConfig:
        """
        Return the config entry of the adapter named ``name``.

        :param name: adapter name
        :return: the adapter's IMConfig
        :raises ValueError: if no such entry exists
        """
        for im in self.config.ims:
            if im.name == name:
                return im
        raise ValueError(f"Adapter {name} not found")

    def update_adapter_config(self, name: str, config: BaseModel):
        """
        Replace the stored config of adapter ``name`` with ``config`` dumped to a dict.

        :param name: adapter name
        :param config: new adapter config
        :raises ValueError: if no such entry exists
        """
        self.get_adapter_config(name).config = config.model_dump()

    def delete_adapter(self, name: str):
        """
        Forget adapter ``name``: drop its instance (if any) and its config entry.
        The adapter is not stopped here.

        :param name: adapter name
        """
        self.adapters.pop(name, None)
        self.config.ims = [im for im in self.config.ims if im.name != name]

    @pydantic_validation_wrapper
    def start_adapters(self, loop=None):
        """
        Create every adapter listed in the config and start the enabled ones.

        All configured adapters are instantiated and tracked in ``self.adapters``; only
        those with ``enable`` set are started. Failures are logged per adapter and do
        not prevent the others from starting.

        :param loop: event loop that runs the start coroutines; a new loop is created
            when omitted
        """
        if loop is None:
            loop = asyncio.new_event_loop()
        task_names = []
        tasks = []
        for im in self.config.ims:
            try:
                # Resolve the adapter class and its config model by type name.
                adapter_class = self.im_registry.get(im.adapter)
                config_class = self.im_registry.get_config_class(im.adapter)
                adapter_config = config_class(**im.config)
                adapter = self.create_adapter(im.name, adapter_class, adapter_config)
                if im.enable:
                    tasks.append(
                        asyncio.ensure_future(
                            self._start_adapter(im.name, adapter), loop=loop
                        )
                    )
                    task_names.append(im.name)
            except Exception as e:
                logger.opt(exception=e).error(f"Failed to start adapter {im.name}: {e}")
                continue
        if tasks:
            results = loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
            # gather preserves order, so each result lines up with its adapter name.
            for name, result in zip(task_names, results):
                if isinstance(result, Exception):
                    logger.opt(exception=result).error(f"Failed to start adapter {name}: {result}")
        else:
            logger.warning("No adapters to start, please check your config")

    def stop_adapters(self, loop=None):
        """
        Stop every tracked adapter.

        A failure to stop one adapter is logged and does not prevent the remaining
        adapters from being stopped.

        :param loop: event loop that runs the stop coroutines; defaults to
            ``asyncio.get_event_loop()``
        """
        if loop is None:
            loop = asyncio.get_event_loop()
        # Iterate over a snapshot: IMAdapterStopped listeners may modify self.adapters.
        for key, adapter in list(self.adapters.items()):
            try:
                loop.run_until_complete(self._stop_adapter(key, adapter))
            except Exception as e:
                logger.opt(exception=e).error(f"Failed to stop adapter {key}: {e}")

    def get_adapters(self) -> Dict[str, IMAdapter]:
        """
        Return all created adapters.

        :return: the live name -> adapter dict
        """
        return self.adapters

    def get_adapter(self, key: str) -> IMAdapter:
        """
        Return the adapter named ``key``.

        :param key: adapter name
        :return: the adapter
        :raises KeyError: if no such adapter exists
        """
        return self.adapters[key]

    async def _start_adapter(self, key: str, adapter: IMAdapter):
        """Start ``adapter``, mark it running and post IMAdapterStarted."""
        logger.info(f"Starting adapter: {key}")
        await adapter.start()
        adapter.is_running = True
        logger.info(f"Started adapter: {key}")
        self.event_bus.post(IMAdapterStarted(adapter))

    async def _stop_adapter(self, key: str, adapter: IMAdapter):
        """Stop ``adapter``, mark it not running and post IMAdapterStopped."""
        logger.info(f"Stopping adapter: {key}")
        await adapter.stop()
        adapter.is_running = False
        logger.info(f"Stopped adapter: {key}")
        self.event_bus.post(IMAdapterStopped(adapter))

    def stop_adapter(self, adapter_id: str, loop: asyncio.AbstractEventLoop):
        """
        Schedule stopping one adapter on ``loop``.

        :return: the scheduled future
        :raises ValueError: if no such adapter exists
        """
        if adapter_id not in self.adapters:
            raise ValueError(f"Adapter {adapter_id} not found")
        adapter = self.adapters[adapter_id]
        return asyncio.ensure_future(self._stop_adapter(adapter_id, adapter), loop=loop)

    def start_adapter(self, adapter_id: str, loop: asyncio.AbstractEventLoop):
        """
        Schedule starting one adapter on ``loop``.

        :return: the scheduled future
        :raises ValueError: if no such adapter exists
        """
        if adapter_id not in self.adapters:
            raise ValueError(f"Adapter {adapter_id} not found")
        adapter = self.adapters[adapter_id]
        return asyncio.ensure_future(
            self._start_adapter(adapter_id, adapter), loop=loop
        )

    def is_adapter_running(self, key: str) -> bool:
        """
        Check whether the adapter named ``key`` is running.

        :param key: adapter name
        :return: True if it exists and is running, False otherwise
        """
        return key in self.adapters and getattr(self.adapters[key], "is_running", False)

    def create_adapter(
        self, name: str, adapter_class: Type[IMAdapter], adapter_config: BaseModel
    ) -> IMAdapter:
        """
        Instantiate ``adapter_class`` via dependency injection and track it under ``name``.

        ``adapter_config`` is registered under its own class in a scoped container so
        the adapter can receive it by injection. The new adapter is marked not running.
        """
        with self.container.scoped() as scoped_container:
            scoped_container.register(adapter_config.__class__, adapter_config)
            adapter = Inject(scoped_container).create(adapter_class)()
            adapter.is_running = False
            self.adapters[name] = adapter
            return adapter
================================================
FILE: kirara_ai/im/message.py
================================================
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
from kirara_ai.im.sender import ChatSender
from kirara_ai.media import MediaManager, MediaType
# Maps MediaMessage.resource_type values to the MediaType passed to register_media.
MIMETYPE_MAPPING = {
    "image": MediaType.IMAGE,
    "audio": MediaType.AUDIO,
    "video": MediaType.VIDEO,
    "file": MediaType.FILE,
}
# Base class for message elements.
class MessageElement(ABC):
    """Abstract element of an IM message (text, media, ...)."""

    @abstractmethod
    def to_dict(self):
        """Return a dict representation of this element."""
        pass

    @abstractmethod
    def to_plain(self) -> str:
        """Return a plain-text rendering of this element."""
        pass
# Plain-text message element.
class TextMessage(MessageElement):
    """A plain-text message element."""

    def __init__(self, text: str):
        self.text = text

    def to_dict(self):
        return dict(type="text", text=self.text)

    def to_plain(self):
        return self.text

    def __repr__(self):
        return "TextMessage(text={})".format(self.text)
# Base class for media message elements.
class MediaMessage(MessageElement):
    """
    Base class for media message elements (image, audio, video or generic file).

    Unless an existing ``media_id`` is given, the media is registered with the
    MediaManager synchronously during construction. URL, file path, raw bytes and
    base64 URL are fetched lazily from the manager and cached on the instance.
    Subclasses define ``resource_type``.
    """

    resource_type: Literal["image", "audio", "video", "file"]
    media_id: str

    def __init__(
        self,
        url: Optional[str] = None,
        path: Optional[str] = None,
        data: Optional[bytes] = None,
        format: Optional[str] = None,
        media_id: Optional[str] = None,
        reference_id: Optional[str] = None,
        source: Optional[str] = "im_message",
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        media_manager: Optional[MediaManager] = None,
    ):
        """
        :param url: remote URL of the media
        :param path: local file path of the media
        :param data: raw media bytes
        :param format: media format; replaced by the registered metadata's format if present
        :param media_id: id of already-registered media; when given, registration is skipped
        :param reference_id: reference id passed to ``register_media``
        :param source: source label passed to ``register_media``
        :param description: description passed to ``register_media``
        :param tags: tags passed to ``register_media``
        :param media_manager: manager to use; a new ``MediaManager()`` when omitted
        :raises Exception: re-raises whatever registration raised
        """
        self.url = url
        self.path = path
        self.data = data
        self.format = format
        self._reference_id = reference_id
        self._source = source
        self._description = description
        self._tags = tags or []
        self._media_manager = media_manager or MediaManager()
        self.base64_url: Optional[str] = None

        if media_id:
            self.media_id = media_id
            return

        import asyncio
        import threading

        # Exception raised inside the worker thread, re-raised in the caller's thread.
        thread_exception: Optional[Exception] = None

        def run_in_new_loop():
            nonlocal thread_exception
            try:
                asyncio.run(self._register_media())
            except Exception as e:
                thread_exception = e

        # Run the async registration on a fresh event loop in a separate thread and block
        # until it finishes. Presumably done so construction also works from inside an
        # already-running event loop, where asyncio.run() would raise.
        thread = threading.Thread(target=run_in_new_loop)
        thread.start()
        thread.join()

        if thread_exception is not None:
            raise thread_exception

    async def _register_media(self) -> None:
        """Register the media with the manager and adopt format/type from its metadata."""
        media_manager = self._media_manager
        self.media_id = await media_manager.register_media(
            url=self.url,
            path=self.path,
            data=self.data,
            format=self.format,
            source=self._source,
            description=self._description,
            tags=self._tags,
            media_type=MIMETYPE_MAPPING[self.resource_type],
            reference_id=self._reference_id
        )

        # Format and media type are applied independently: a missing format must not
        # prevent the detected media type from being adopted, and a missing metadata
        # record must not raise.
        metadata = media_manager.get_metadata(self.media_id)
        if metadata:
            if metadata.format:
                self.format = metadata.format
            if metadata.media_type:
                self.resource_type = metadata.media_type.value

    async def get_url(self) -> str:
        """
        Return the media URL, fetching it from the manager and caching it if needed.

        :raises ValueError: if the media is not registered or no URL is available
        """
        if not self.media_id:
            raise ValueError("Media not registered")

        if self.url:
            return self.url

        url = await self._media_manager.get_url(self.media_id)
        if url:
            self.url = url  # cache
            return url
        raise ValueError("Failed to get media URL")

    async def get_path(self) -> str:
        """
        Return a local file path of the media, fetching and caching it if needed.

        :raises ValueError: if the media is not registered or no file path is available
        """
        if not self.media_id:
            raise ValueError("Media not registered")

        # Reuse the known path only if the file still exists.
        if self.path and Path(self.path).exists():
            return self.path

        file_path = await self._media_manager.get_file_path(self.media_id)
        if file_path:
            self.path = str(file_path)  # cache
            return self.path
        raise ValueError("Failed to get media file path")

    async def get_data(self) -> bytes:
        """
        Return the raw media bytes, fetching and caching them if needed.

        :raises ValueError: if the media is not registered or no data is available
        """
        if not self.media_id:
            raise ValueError("Media not registered")

        if self.data:
            return self.data

        data = await self._media_manager.get_data(self.media_id)
        if data:
            self.data = data  # cache
            return data
        raise ValueError("Failed to get media data")

    async def get_base64_url(self) -> str:
        """
        Return the media as a base64 URL, fetching and caching it if needed.

        :raises ValueError: if the media is not registered or no base64 URL is available
        """
        if not self.media_id:
            raise ValueError("Media not registered")
        if self.base64_url:
            return self.base64_url
        base64_url = await self._media_manager.get_base64_url(self.media_id)
        if base64_url:
            self.base64_url = base64_url
            return base64_url
        raise ValueError("Failed to get media base64 URL")

    def get_description(self) -> str:
        """
        Return the media description from its metadata ("" when unavailable).

        :raises ValueError: if the media is not registered
        """
        if not self.media_id:
            raise ValueError("Media not registered")
        metadata = self._media_manager.get_metadata(self.media_id)
        if metadata:
            return metadata.description or ""
        return ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a dict with type and media_id, plus format/url/path when set."""
        result: Dict[str, Any] = {
            "type": self.resource_type,
            "media_id": self.media_id,
        }

        # Optional attributes
        if self.format:
            result["format"] = self.format
        if self.url:
            result["url"] = self.url
        if self.path:
            result["path"] = self.path

        return result
# Voice (audio) message element.
class VoiceMessage(MediaMessage):
    """Audio media element; serialized with type ``"voice"``."""

    resource_type = "audio"

    def to_plain(self):
        return "[VoiceMessage]"

    def to_dict(self):
        data = super().to_dict()
        data["type"] = "voice"
        return data
# Image message element.
class ImageMessage(MediaMessage):
    """Image media element; its plain form includes id, URL and description."""

    resource_type = "image"

    def __repr__(self):
        return f"ImageMessage(media_id={self.media_id}, url={self.url}, path={self.path}, format={self.format})"

    def to_plain(self):
        return f"[ImageMessage:media_id={self.media_id},url={self.url},alt={self.get_description()}]"

    def to_dict(self):
        data = super().to_dict()
        data["type"] = "image"
        return data
# @-mention element addressed by raw user id.
# Deprecated: prefer MentionElement.
class AtElement(MessageElement):
    """Mentions a user by ``user_id`` with an optional display ``nickname``."""

    def __init__(self, user_id: str, nickname: str = ""):
        self.user_id = user_id
        self.nickname = nickname

    def __repr__(self):
        return f"AtElement(user_id={self.user_id}, nickname={self.nickname})"

    def to_plain(self):
        # Fall back to the user id when no nickname is known.
        shown = self.nickname or self.user_id
        return f"@{shown}"

    def to_dict(self):
        payload = {"qq": self.user_id, "nickname": self.nickname}
        return {"type": "at", "data": payload}
# @-mention element targeting a ChatSender.
class MentionElement(MessageElement):
    """Mentions ``target`` inside a message."""

    def __init__(self, target: ChatSender):
        self.target = target

    def __repr__(self):
        return f"MentionElement(target={self.target})"

    def to_plain(self):
        shown = self.target.display_name or self.target.user_id
        return f"@{shown}"

    def to_dict(self):
        # NOTE(review): the ChatSender object is embedded as-is, not serialized.
        payload = {"target": self.target}
        return {"type": "mention", "data": payload}
# Element marking a message as a reply to another message.
class ReplyElement(MessageElement):
    """References the replied-to message by ``message_id``."""

    def __init__(self, message_id: str):
        self.message_id = message_id

    def __repr__(self):
        return f"ReplyElement(message_id={self.message_id})"

    def to_plain(self):
        return f"[Reply:{self.message_id}]"

    def to_dict(self):
        payload = {"id": self.message_id}
        return {"type": "reply", "data": payload}
# File message element.
class FileMessage(MediaMessage):
    """Generic file media element."""

    resource_type = "file"

    def to_dict(self):
        result = super().to_dict()
        result["type"] = "file"
        return result

    def to_plain(self):
        # Prefer the local path, then the URL, else a placeholder.
        return f"[File:{self.path or self.url or 'unnamed'}]"

    def __repr__(self):
        # Report the real class name (FileElement is only a deprecated alias),
        # consistent with ImageMessage / VideoMessage.
        return f"FileMessage(media_id={self.media_id}, url={self.url}, path={self.path}, format={self.format})"
# Element carrying a raw JSON payload string.
class JsonMessage(MessageElement):
    """Wraps a JSON string in ``data``."""

    def __init__(self, data: str):
        self.data = data

    def __repr__(self):
        return f"JsonMessage(data={self.data})"

    def to_plain(self):
        return f"[JSON:{self.data}]"

    def to_dict(self):
        payload = {"data": self.data}
        return {"type": "json", "data": payload}
# Emoji / face element identified by a platform face id.
class EmojiMessage(MessageElement):
    """Platform emoji referenced by ``face_id``."""

    def __init__(self, face_id: str):
        self.face_id = face_id

    def __repr__(self):
        return f"EmojiMessage(face_id={self.face_id})"

    def to_plain(self):
        return f"[Face:{self.face_id}]"

    def to_dict(self):
        payload = {"id": self.face_id}
        return {"type": "face", "data": payload}
# Video message element.
class VideoMessage(MediaMessage):
    """Video media element."""

    resource_type = "video"

    def __repr__(self):
        return f"VideoMessage(media_id={self.media_id}, url={self.url}, path={self.path}, format={self.format})"

    def to_plain(self):
        # Path first, then URL, else a placeholder.
        return f"[Video:{self.path or self.url or 'unnamed'}]"

    def to_dict(self):
        data = super().to_dict()
        data["type"] = "video"
        return data
# A complete IM message.
class IMMessage:
    """
    A complete IM message: who sent it and the elements it consists of.

    Attributes:
        sender: identity of the sender
        message_elements: ordered elements (text, images, voice, ...)
        raw_message: original platform payload, if kept
        content: plain-text rendering of the message (property)
        images: the ImageMessage elements (property)
        voices: the VoiceMessage elements (property)
    """

    sender: ChatSender
    message_elements: List[MessageElement]
    raw_message: Optional[dict]

    def __init__(
        self,
        sender: ChatSender,
        message_elements: List[MessageElement],
        raw_message: Optional[dict] = None,
    ):
        self.sender = sender
        self.message_elements = message_elements
        self.raw_message = raw_message

    def __repr__(self):
        return f"IMMessage(sender={self.sender}, message_elements={self.message_elements}, raw_message={self.raw_message})"

    @property
    def content(self) -> str:
        """Concatenated plain text of all elements, stripped at both ends.

        Each TextMessage is followed by a newline.
        """
        return "".join(
            element.to_plain() + ("\n" if isinstance(element, TextMessage) else "")
            for element in self.message_elements
        ).strip()

    @property
    def images(self) -> List[ImageMessage]:
        """All ImageMessage elements, in order."""
        return [e for e in self.message_elements if isinstance(e, ImageMessage)]

    @property
    def voices(self) -> List[VoiceMessage]:
        """All VoiceMessage elements, in order."""
        return [e for e in self.message_elements if isinstance(e, VoiceMessage)]

    def to_dict(self):
        """Serialize the message; plain_text joins element texts without separators."""
        return {
            "sender": self.sender,
            "message_elements": [e.to_dict() for e in self.message_elements],
            "plain_text": "".join(e.to_plain() for e in self.message_elements),
            "raw_message": self.raw_message,
        }
# Deprecated aliases kept for backward compatibility with the old *Element names.
FileElement = FileMessage
ImageElement = ImageMessage
VoiceElement = VoiceMessage
VideoElement = VideoMessage
JsonElement = JsonMessage
# Both names refer to the emoji/face element.
EmojiElement = FaceElement = EmojiMessage
================================================
FILE: kirara_ai/im/profile.py
================================================
from enum import Enum, auto
from typing import Optional
from pydantic import BaseModel, Field
class Gender(Enum):
    """User gender as exposed in UserProfile (values are auto-numbered 1..4)."""

    MALE = auto()
    FEMALE = auto()
    UNKNOWN = auto()  # platform did not report a gender
    OTHER = auto()
class UserProfile(BaseModel):
    """
    Platform-independent user profile.

    Only ``user_id`` is required; every other field is optional.
    """

    user_id: str = Field(..., description="用户唯一标识")
    username: Optional[str] = Field(None, description="用户名")
    display_name: Optional[str] = Field(None, description="显示名称")
    full_name: Optional[str] = Field(None, description="完整名称")
    gender: Optional[Gender] = Field(None, description="性别")
    age: Optional[int] = Field(None, description="年龄")
    avatar_url: Optional[str] = Field(None, description="头像URL")
    level: Optional[int] = Field(None, description="用户等级")
    language: Optional[str] = Field(None, description="语言")
    extra_info: Optional[dict] = Field(None, description="平台特定的额外信息")
    # No __init__ override: BaseModel's generated constructor already accepts these
    # fields as keyword arguments, and overriding it with **kwargs only hid the
    # typed signature from tooling.
================================================
FILE: kirara_ai/im/sender.py
================================================
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional
class ChatType(Enum):
C2C = "c2c"
GROUP = "group"
@classmethod
def from_str(cls, value: str) -> "ChatType":
if value == "c2c" or value == "私聊":
return cls.C2C
elif value == "group" or value == "群聊":
return cls.GROUP
raise ValueError(f"Invalid chat type: {value}")
def to_str(self) -> str:
if self == self.C2C:
return "私聊"
elif self == self.GROUP:
return "群聊"
raise ValueError(f"Invalid chat type: {self}")
@dataclass
class ChatSender:
    """Identity of a message sender.

    Equality and hashing use only ``user_id``, ``chat_type`` and ``group_id``;
    ``display_name`` and ``raw_metadata`` are ignored.
    """

    display_name: str
    user_id: str
    chat_type: ChatType
    group_id: Optional[str] = None
    raw_metadata: Dict[str, Any] = field(default_factory=dict)

    # Plain class attribute (no annotation), so it is not a dataclass field.
    callback = None

    @classmethod
    def from_group_chat(
        cls,
        user_id: str,
        group_id: str,
        display_name: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> "ChatSender":
        """Build a sender for a group conversation."""
        return cls(
            display_name=display_name,
            user_id=user_id,
            chat_type=ChatType.GROUP,
            group_id=group_id,
            raw_metadata=metadata or {},
        )

    @classmethod
    def from_c2c_chat(
        cls, user_id: str, display_name: str, metadata: Optional[Dict[str, Any]] = None
    ) -> "ChatSender":
        """Build a sender for a one-to-one conversation."""
        return cls(
            display_name=display_name,
            user_id=user_id,
            chat_type=ChatType.C2C,
            raw_metadata=metadata or {},
        )

    @classmethod
    def get_bot_sender(cls) -> "ChatSender":
        """Return the sender used for the bot itself (a C2C sender with id "bot")."""
        return cls(display_name="bot", user_id="bot", chat_type=ChatType.C2C)

    def __str__(self) -> str:
        # Group senders are prefixed with the group id, private ones with "c2c".
        if self.chat_type != ChatType.GROUP:
            return f"c2c:{self.user_id}"
        return f"{self.group_id}:{self.user_id}"

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, ChatSender):
            return False
        return (
            self.user_id == other.user_id
            and self.chat_type == other.chat_type
            and self.group_id == other.group_id
        )

    def __hash__(self) -> int:
        identity = (self.user_id, self.chat_type, self.group_id)
        return hash(identity)
================================================
FILE: kirara_ai/internal.py
================================================
# Process-wide lifecycle state: a shutdown signal and a one-shot restart flag.
import asyncio

# Set to request a graceful shutdown (presumably awaited by the main loop elsewhere).
shutdown_event = asyncio.Event()

# True once a restart has been requested; cleared by get_and_reset_restart_flag().
restart_flag = False


def set_restart_flag():
    """Request a restart."""
    global restart_flag
    restart_flag = True


def get_and_reset_restart_flag():
    """Return whether a restart was requested, clearing the request."""
    global restart_flag
    flag, restart_flag = restart_flag, False
    return flag
================================================
FILE: kirara_ai/ioc/__init__.py
================================================
================================================
FILE: kirara_ai/ioc/container.py
================================================
import contextvars
from typing import Any, Optional, Type, TypeVar, overload
T = TypeVar("T")
class DependencyContainer:
"""
依赖注入容器,提供注册和解析的功能。你可以在此获取一些全局的对象。
基本用法:
```python
# 1. 注册全局对象 - 通常在初始化时使用
container.register(YourObj, your_obj_instance)
# 2. 获取全局对象 - 在你的逻辑代码中使用
your_obj_instance = container.resolve(YourObj)
# 3. 销毁全局对象 - 通常在系统/插件销毁时使用
container.destroy(YourObj)
# 4. 创建作用域容器 - 作用域容器内注册的对象只在作用域内可被访问
# 离开作用域的上下文后无法取到该对象
# 全局容器注册对象
container.register(KiraraObj, kirara_obj)
with container.scoped() as scoped_container:
# 注册作用域对象
scoped_container.register(YourObj, your_obj_instance)
# 获取作用域对象
scoped_container.resolve(YourObj)
# 作用域容器也可以获取到全局容器的对象
container.has(KiraraObj) # 返回 True
# 甚至还能再创建新的作用域容器
with scoped_container.scoped() as another_scoped_container:
another_scoped_container.has(YourObj) # True
# 离开作用域上下文后无法获取到该对象
container.has(YourObj) # 返回 False
```
docs: https://docs.python.org/zh-cn/3.13/library/contextvars.html#module-contextvars
Attributes:
parent (DependencyContainer): 父容器实例,用于支持作用域嵌套
registry (dict): 存储当前容器注册的值或对象实例,格式为{key: value|object}
Methods:
register: 向容器注册一个key-value对
resolve: 从容器解析获取一个值或对象实例
destroy: 从容器中移除一个值或对象实例
scoped: 创建一个新的作用域容器
"""
def __init__(self, parent=None):
self.parent = parent # 父容器,用于支持作用域嵌套
self.registry = {} # 当前容器的注册表
def register(self, key, value):
"""
向容器注册一个值或者实例。
Args:
key: 对象的标识键, 一般为对象的类 (Type) 如 IMManager, LLMManager等,
会根据类型自动查找对应对象实例。
value: 值/对象实例
"""
self.registry[key] = value
@overload
def resolve(self, key: Type[T]) -> T: ...
@overload
def resolve(self, key: Any) -> Any: ...
def resolve(self, key: Type[T] | Any) -> T | Any:
"""
依照{key}从容器解析出一个值或对象实例。
如果{key}在当前容器中不存在,则会递归查找父容器。
Args:
key: 对象的标识键, 一般为对象的类 (Type) 如 IMManager, LLMManager等,
会根据类型自动查找对应对象实例。
Returns:
值/对象实例
Raises:
KeyError: {key}在当前容器和父容器中都不存在时抛出
"""
if key in self.registry:
return self.registry[key]
elif self.parent:
return self.parent.resolve(key)
else:
raise KeyError(f"Dependency {key} not found.")
def has(self, key: Type[T] | Any) -> bool:
"""
检测容器中是否能解析出某个键所对应的值。
Args:
key: 对象的标识键
Returns:
成功返回 True, 失败返回 False
"""
return key in self.registry or (self.parent is not None and self.parent.has(key))
@overload
def destroy(self, key: Type[T], recursive: bool = False) -> None: ...
@overload
def destroy(self, key: Any, recursive: bool = False) -> None: ...
def destroy(self, key: Type[T] | Any, recursive: bool = False) -> None:
"""
从容器中移除一个值或对象实例。支持递归删除父元素。
但是最好不要递归,你可能会删除一些系统对象
Args:
key: 对象的标识键
recursive: 是否递归删除父元素, 默认False。注意这是unsafe方法, 请注意不要删除系统对象。
Raises:
KeyError: {key}在当前容器和父容器中都不存在时抛出
"""
if key in self.registry:
del self.registry[key]
elif self.parent and recursive:
self.parent.destroy(key, recursive)
else:
raise KeyError(f"Cannot destroy dependency {key} which is not found in registry or parent container's registry.")
def scoped(self):
"""创建一个新的作用域容器"""
new_container = ScopedContainer(self)
if DependencyContainer in self.registry:
new_container.registry[DependencyContainer] = new_container
new_container.registry[ScopedContainer] = new_container
return new_container
# The "current" container, tracked with contextvars so that each thread and each
# asyncio context sees its own value.
current_container = contextvars.ContextVar[Optional[DependencyContainer]]("current_container", default=None)


class ScopedContainer(DependencyContainer):
    """Child container that becomes ``current_container`` while used in a ``with`` block."""

    def __init__(self, parent):
        super().__init__(parent)

    def __enter__(self):
        # Publish this scope as the current container and remember how to undo it.
        self.token = current_container.set(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore whatever container was current before __enter__.
        current_container.reset(self.token)
================================================
FILE: kirara_ai/ioc/inject.py
================================================
from functools import wraps
from inspect import signature
from typing import Any, Callable, Optional, Type
from kirara_ai.ioc.container import DependencyContainer
def get_all_attributes(cls):
    """Collect the annotated attributes of ``cls`` and its base classes.

    Returns a dict mapping attribute name to annotation. When a name is
    annotated at several levels, the definition closest to ``cls`` wins: the
    class's own annotation beats inherited ones, and an earlier base beats a
    later one (approximating attribute-lookup precedence).
    """
    attributes = {}
    # Visit bases last-to-first so earlier bases overwrite later ones.
    for base in reversed(cls.__bases__):
        attributes.update(get_all_attributes(base))
    # The class's own annotations override anything inherited. Builtin types such
    # as ``object`` have no ``__annotations__``, hence the default.
    attributes.update(getattr(cls, "__annotations__", {}))
    return attributes
class Inject:
    """Decorator that injects dependencies from a ``DependencyContainer``.

    - On a class: every annotated attribute, including inherited ones, is
      replaced by a property. The property resolves the annotated type from the
      container when it is available, and otherwise falls back to the
      attribute's previous getter or stored value.
    - On a function: annotated parameters the caller did not supply are
      resolved from the container.

    If a wrapped function is called with a ``DependencyContainer`` argument,
    that container is stored on this ``Inject`` instance and used from then on.
    This includes properties created through the same instance (see ``create``).
    """

    def __init__(self, container: Optional[DependencyContainer] = None):
        # Container to resolve from; may be replaced at call time (see inject_function).
        self.container = container

    def create(self, target: type):
        """Inject ``target``'s attributes and return an injecting constructor for it."""
        # Inject class-level attributes first...
        injected_class = self.__call__(target)
        # ...then wrap the class so its constructor parameters are injected too.
        return self.inject_function(injected_class)

    def __call__(self, target: Any):
        """Decorate a class or a function; anything else raises TypeError."""
        if isinstance(target, type):
            return self.inject_class(target)
        elif callable(target):
            return self.inject_function(target)
        else:
            raise TypeError(
                "Inject can only be used on classes, functions."
            )

    def inject_class(self, cls: Type):
        """Replace each annotated attribute of ``cls`` with an injecting property."""
        for name, injecting_type in get_all_attributes(cls).items():
            attr = getattr(cls, name, None)
            setattr(cls, name, self.inject_property(name, cls, injecting_type, attr))
        return cls

    def inject_function(self, func: Callable):
        """Wrap ``func`` so annotated parameters the caller omitted are resolved from the container.

        Parameters passed explicitly (positionally or by keyword) are never
        overwritten. A parameter with a default keeps the default when the
        container cannot supply its type. Annotated parameters without a default
        that the container lacks raise KeyError, as before.
        """
        # The signature is fixed for the lifetime of the wrapper; compute it once.
        sig = signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # A DependencyContainer passed by the caller becomes this injector's container.
            container_param = self.find_container(args, kwargs)
            if container_param:
                self.container = container_param

            bound_args = sig.bind_partial(*args, **kwargs)
            # Names the caller actually supplied, before defaults are filled in.
            provided = set(bound_args.arguments)
            bound_args.apply_defaults()

            for name, param in sig.parameters.items():
                if (
                    param.annotation is param.empty
                    or name in provided
                    or not self.container
                    or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
                ):
                    continue
                # Keep the declared default when the container has nothing for this type.
                if param.default is not param.empty and not self.container.has(param.annotation):
                    continue
                bound_args.arguments[name] = self.container.resolve(param.annotation)

            return func(*bound_args.args, **bound_args.kwargs)

        return wrapper

    def inject_property(self, name, cls, injecting_type, prop: Any):
        """Build a property for attribute ``name`` that resolves ``injecting_type`` from the container.

        Args:
            name: attribute name being replaced.
            cls: owning class (unused; kept for interface compatibility).
            injecting_type: the attribute's annotation; only real types are resolved.
            prop: the attribute's current class-level value. An existing ``property``
                has its accessors reused; any other value acts as the default
                returned before something is assigned; ``None`` means no default.

        Returns:
            A new ``property`` object.
        """
        # Instance attribute that stores explicitly assigned values.
        backing_name = f"_{name}_value"
        is_property = isinstance(prop, property)
        # A plain (non-property) class attribute is treated as the default value.
        default_value = None if is_property else prop

        def default_fget(_self):
            return getattr(_self, backing_name, default_value)

        def default_fset(_self, value):
            setattr(_self, backing_name, value)

        def default_fdel(_self):
            if hasattr(_self, backing_name):
                delattr(_self, backing_name)

        # Only a real property has fget/fset/fdel to reuse.
        if is_property:
            fget = prop.fget or default_fget
            fset = prop.fset or default_fset
            fdel = prop.fdel or default_fdel
        else:
            fget = default_fget
            fset = default_fset
            fdel = default_fdel

        @wraps(fget)
        def new_fget(_self):
            container = self.container
            if container and isinstance(injecting_type, type) and container.has(injecting_type):
                return container.resolve(injecting_type)
            # Not injectable: defer to the pre-existing getter (or the backing value).
            return fget(_self)

        return property(new_fget, fset, fdel)

    def find_container(self, args, kwargs):
        """Return the first DependencyContainer among the call arguments, or None."""
        for arg in args:
            if isinstance(arg, DependencyContainer):
                return arg
        for value in kwargs.values():
            if isinstance(value, DependencyContainer):
                return value
        return None
================================================
FILE: kirara_ai/llm/adapter.py
================================================
from abc import ABC
from typing import List, Protocol, runtime_checkable
from kirara_ai.config.global_config import ModelConfig
from kirara_ai.llm.format.request import LLMChatRequest
from kirara_ai.llm.format.response import LLMChatResponse
from kirara_ai.llm.format.embedding import LLMEmbeddingRequest, LLMEmbeddingResponse
from kirara_ai.llm.format.rerank import LLMReRankRequest, LLMReRankResponse
from kirara_ai.media.manager import MediaManager
from kirara_ai.tracing.llm_tracer import LLMTracer
@runtime_checkable
class AutoDetectModelsProtocol(Protocol):
    """Adapters that provide an async ``auto_detect_models`` returning ModelConfig entries."""

    async def auto_detect_models(self) -> List[ModelConfig]:
        """Return the list of model configurations the adapter detects."""
        ...
@runtime_checkable
class LLMChatProtocol(Protocol):
    """Adapters that support chat completion."""

    def chat(self, req: LLMChatRequest) -> LLMChatResponse:
        """Run one chat request and return its response."""
        ...
@runtime_checkable
class LLMEmbeddingProtocol(Protocol):
    """Adapters that can compute embeddings."""

    def embed(self, req: LLMEmbeddingRequest) -> LLMEmbeddingResponse:
        """Embed the request's inputs and return the vectors."""
        ...
@runtime_checkable
class LLMReRankProtocol(Protocol):
    """Adapters that can re-rank documents against a query."""

    def rerank(self, req: LLMReRankRequest) -> LLMReRankResponse:
        """Score the request's documents and return the ranking."""
        ...
class LLMBackendAdapter(ABC):
    """Common base of all LLM backend adapters.

    The annotated attributes let ``Inject`` turn them into container-backed
    properties; ``backend_name`` is assigned by the LLM manager after construction.
    """

    backend_name: str  # name of the configured backend this adapter serves
    media_manager: MediaManager
    tracer: LLMTracer
================================================
FILE: kirara_ai/llm/format/__init__.py
================================================
from .message import LLMChatImageContent, LLMChatMessage, LLMChatTextContent, LLMToolCallContent, LLMToolResultContent
from .response import LLMChatResponse
from .tool import Function, Tool, ToolCall
__all__ = [
    "LLMChatMessage",
    "LLMChatTextContent",
    "LLMChatImageContent",
    "LLMToolCallContent",
    "LLMToolResultContent",
    "Function",
    "Tool",
    "ToolCall",
    "LLMChatResponse",
]
================================================
FILE: kirara_ai/llm/format/embedding.py
================================================
from typing import Literal, Optional
from pydantic import BaseModel
from .message import LLMChatTextContent, LLMChatImageContent
from .response import Usage
# Encoding of returned embeddings (only base64 is modelled).
FormatType = Literal["base64"]
# Numeric type of the returned vector components.
OutputType = Literal["float", "int8", "uint8", "binary", "ubinary"]
# Kind of input being embedded.
InputType = Literal["string", "query", "document"]
# A single embedding input: text or image content.
InputUnionType = LLMChatTextContent | LLMChatImageContent
class LLMEmbeddingRequest(BaseModel):
    """
    Request model for embedding calls.

    Tip: models differ in vector dimension and embedding function. When vectors
    go into a vector database, store and query with the same model, and keep the
    vector dimension consistent (some models let you configure it).

    Note: the fields below are a union of what different providers accept; some
    fields have no effect on some models. Pass parameters according to the
    corresponding provider's API documentation.

    Attributes:
        inputs (list[LLMChatTextContent | LLMChatImageContent]): texts or images to embed.
        model (str): name of the embedding model.
        dimension (Optional[int]): dimension of the returned vectors.
        encoding_format (Optional[FormatType]): encoding of the returned vectors.
            Leaving it unset is recommended so vectors can be stored directly.
        input_type (Optional[InputType]): input type; specific to the voyage adapter.
        truncate (Optional[bool]): whether over-long text is truncated automatically
            to fit the model's context limit.
        output_type (Optional[OutputType]): data type of the vector components;
            usually float.
    """

    inputs: list[InputUnionType]
    model: str
    dimension: Optional[int] = None
    encoding_format: Optional[FormatType] = None
    input_type: Optional[InputType] = None
    truncate: Optional[bool] = None
    output_type: Optional[OutputType] = None
# One embedding vector. base64-encoded responses are not modelled yet; a numpy
# type may be used later for precise dtypes.
vector = list[float | int]


class LLMEmbeddingResponse(BaseModel):
    """
    Embedding response. Compute a vector's dimension yourself with ``len(vector)``.

    Attributes:
        vectors: the returned vectors (presumably one per input, in order — TODO confirm)
        usage: token usage, if reported
    """

    vectors: list[vector]
    usage: Optional[Usage] = None
================================================
FILE: kirara_ai/llm/format/message.py
================================================
import json
from typing import Literal, Optional, Union
from pydantic import BaseModel, field_validator, model_validator
from typing_extensions import Self
from .tool import LLMToolResultContent
# Conversation roles without "tool" (see RoleTypes for the full set).
RoleType = Literal[
    "system",
    "user",
    "assistant",
]
class LLMChatTextContent(BaseModel):
    """Plain-text content part of a chat message."""

    type: Literal["text"] = "text"  # discriminator
    text: str
class LLMChatImageContent(BaseModel):
    """Image content part, referenced by media id rather than embedded."""

    type: Literal["image"] = "image"  # discriminator
    media_id: str
class LLMToolCallContent(BaseModel):
    """
    A tool invocation requested by the model.

    Model-specific: if a message or memory contains this content, make sure it is
    sent to the same model. Belongs to messages with role "assistant".
    """

    type: Literal["tool_call"] = "tool_call"
    # Call id. Some models use it to tell calls apart; if the model returns none,
    # the adapter generates one.
    id: str
    name: str
    parameters: Optional[dict] = None

    # Accept parameters given as a JSON string and decode them to a dict.
    # `@field_validator` must be the OUTER decorator: when `@classmethod` wraps it,
    # pydantic does not register the validator and it never runs.
    @field_validator("parameters", mode="before")
    @classmethod
    def convert_parameters_to_dict(cls, v: Optional[Union[str, dict]]) -> Optional[dict]:
        if isinstance(v, str):
            return json.loads(v)
        return v
# Any content part a chat message may hold (which ones are allowed depends on the role).
LLMChatContentPartType = Union[
    LLMChatTextContent,
    LLMChatImageContent,
    LLMToolCallContent,
    LLMToolResultContent,
]
# Roles accepted by LLMChatMessage.
RoleTypes = Literal["user", "assistant", "system", "tool"]
class LLMChatMessage(BaseModel):
    """
    One chat message: a role plus a list of content parts.

    A message with role "tool" may only contain LLMToolResultContent; the other
    roles may contain text, image and tool-call parts.
    """

    content: list[LLMChatContentPartType]
    role: RoleTypes

    @model_validator(mode="after")
    def check_content_type(self) -> Self:
        # Runs on the constructed instance and enforces the role/content rule above.
        if self.role == "tool":
            if any(not isinstance(part, LLMToolResultContent) for part in self.content):
                raise ValueError("content must be a list of LLMToolResultContent, when role is 'tool'")
        elif self.role in ("user", "assistant", "system"):
            allowed = (LLMChatTextContent, LLMChatImageContent, LLMToolCallContent)
            if any(not isinstance(part, allowed) for part in self.content):
                raise ValueError(f"content must be a list of LLMChatContentPartType, when role is {self.role}")
        return self
================================================
FILE: kirara_ai/llm/format/request.py
================================================
from typing import Any, List, Optional
from pydantic import BaseModel
from kirara_ai.llm.format.message import LLMChatMessage
from .tool import Tool
class ResponseFormat(BaseModel):
    """Requested response format; ``type`` is an opaque string (presumably forwarded to the provider)."""

    type: Optional[str] = None
class LLMChatRequest(BaseModel):
    """
    Chat completion request.

    The sampling parameters (``temperature``, ``top_p``, ``frequency_penalty``,
    ``presence_penalty``) are floats. Providers take fractional values such as
    ``temperature=0.7``, which an ``int`` field would reject. Integers are still
    accepted and converted to float.

    Attributes:
        tool_choice (Union[dict, Literal["auto", "any", "none"]]):
            Tells the model how to choose which tool to call. Providers implement
            this differently, so it is not exposed to users yet.
    """

    messages: List[LLMChatMessage] = []
    model: Optional[str] = None
    frequency_penalty: Optional[float] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[ResponseFormat] = None
    stop: Optional[Any] = None
    stream: Optional[bool] = None
    stream_options: Optional[Any] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # Standardized tool definitions.
    tools: Optional[list[Tool]] = None
    # Providers disagree on tool_choice; users cannot change it for now.
    tool_choice: Optional[Any] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[Any] = None
================================================
FILE: kirara_ai/llm/format/rerank.py
================================================
from typing import Optional
from typing_extensions import Self
from pydantic import BaseModel, model_validator
from .response import Usage
class LLMReRankRequest(BaseModel):
    """
    Request model for LLM re-rankers.

    Re-ranking is a common refinement step (it appears, for example, in
    Elasticsearch search tuning). It is usually combined with an embedding model
    to improve the accuracy of vector search: given a query and a list of
    documents, the re-ranker returns a relevance score for each document.

    Attributes:
        query: the original query text.
        documents: documents to score; each document is passed as one string.
        model: name of the re-rank model. For accuracy, automatic model selection
            is not supported.
        top_k: return only the {top_k} most relevant documents. If unset, results
            for all documents are returned.
            Tip: if you choose not to return the documents, do not set this; the
            link between text and score would be lost.
        return_documents: whether to return the original document text.
        truncation: whether documents and query may be truncated to fit the
            model's maximum context.
        sort: whether to sort results by score. Defaults to False.
            Tip: a ValueError is raised if sort is True while return_documents is
            False or unset.
    """

    query: str
    documents: list[str]
    model: str
    top_k: Optional[int] = None
    return_documents: Optional[bool] = None
    truncation: Optional[bool] = None
    sort: Optional[bool] = False

    # mode="after" runs on the fully constructed instance, after pydantic has applied
    # defaults, so unset fields already hold their defaults here. (mode="before"
    # receives the raw input instead and is used very differently.)
    @model_validator(mode="after")
    def check(self) -> Self:
        if self.sort and not self.return_documents:
            raise ValueError("Cannot sort server responses when return_documents is False.")
        return self
class ReRankerContent(BaseModel):
    """
    One re-rank result.

    Attributes:
        document: the original document text (only when documents are returned).
        score: relevance score of the document.
    """

    document: Optional[str] = None
    score: float
class LLMReRankResponse(BaseModel):
    """
    Re-rank response.

    Attributes:
        contents (list[ReRankerContent]): ranking results, sorted by descending
            score when ``sort`` is True. Sorting is only allowed when the request
            returned the documents.
        usage (Usage): token usage.
        sort (bool): whether the results were sorted. It should be set from the
            request's ``sort``; keeping it as a field makes later checks and
            debugging easier.
    """

    contents: list[ReRankerContent]
    usage: Usage
    sort: bool

    # mode="after" runs on the constructed instance, hence an instance method.
    @model_validator(mode="after")
    def sort_content(self) -> Self:
        if not self.sort:
            return self
        self.contents = sorted(self.contents, key=lambda item: item.score, reverse=True)
        return self
================================================
FILE: kirara_ai/llm/format/response.py
================================================
from typing import List, Optional
from pydantic import BaseModel
from kirara_ai.llm.format.message import LLMChatMessage
from kirara_ai.llm.format.tool import ToolCall
class Message(LLMChatMessage):
    """Chat message as carried by an LLMChatResponse, plus tool calls and finish reason."""

    tool_calls: Optional[List[ToolCall]] = None  # tool calls requested by the model
    finish_reason: Optional[str] = None
class Usage(BaseModel):
    """Token accounting reported by the provider; every counter is optional."""

    completion_tokens: Optional[int] = None
    prompt_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
    cached_tokens: Optional[int] = None
class LLMChatResponse(BaseModel):
    """Result of a chat request: the reply message, plus model name and usage when known."""

    model: Optional[str] = None
    usage: Optional[Usage] = None
    message: Message
================================================
FILE: kirara_ai/llm/format/tool.py
================================================
import json
from typing import Any, Callable, Coroutine, Generic, List, Literal, Optional, TypeVar, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
class TextContent(BaseModel):
    """Text part of a tool result."""

    type: Literal["text"] = "text"  # discriminator
    text: str
class MediaContent(BaseModel):
    """Media part of a tool result, carrying raw bytes plus their MIME type."""

    type: Literal["media"] = "media"  # discriminator
    media_id: str
    mime_type: str
    data: bytes
# Content of a tool result: a list of text and/or media parts.
ToolResponseTypes = List[
    Union[TextContent, MediaContent]
]
class LLMToolResultContent(BaseModel):
    """
    Content of a tool's response.

    Model-specific: if a message or memory contains this content, make sure it is
    sent to the same model. Belongs to messages with role "tool".
    """

    type: Literal["tool_result"] = "tool_result"
    # Call id, matching the id of the corresponding LLMToolCallContent.
    id: str
    name: str
    # Providers expect different content formats; to be normalized later.
    content: ToolResponseTypes
    isError: bool = False
class Function(BaseModel):
    """Name and arguments of a function the model asked to call."""

    # Tool name.
    name: str
    # Keyword arguments; can be splatted directly as ``**arguments``.
    arguments: Optional[dict] = None

    # Decorators apply bottom-up, so `@classmethod` goes below `@field_validator`,
    # as the pydantic docs recommend.
    @field_validator("arguments", mode="before")
    @classmethod
    def convert_arguments(cls, v: Optional[Union[str, dict]]) -> Optional[dict]:
        """Decode arguments sent as a JSON string; pass anything else through."""
        if isinstance(v, str):
            return json.loads(v)
        return v
class ToolCall(BaseModel):
    """A single tool call returned by the model."""

    # Call id; corresponds to LLMToolCallContent.id.
    id: str
    # Purpose of this field is currently unclear — TODO confirm.
    type: Optional[str] = None
    function: Function
# Type parameter of CallableWrapper: the wrapped callable's type.
T = TypeVar('T', bound=Callable)
# A tool implementation: takes the ToolCall and returns a coroutine producing the result.
ToolInvokeFunc = Callable[
    [ToolCall],
    Coroutine[Any, Any, "LLMToolResultContent"],
]
class CallableWrapper(Generic[T]):
    """Wrapper around a callable that survives ``copy.deepcopy``.

    Deep-copying returns the wrapper itself (sharing the original function)
    instead of trying to copy the function.
    """

    def __init__(self, func: T):
        self.func = func

    def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, "LLMToolResultContent"]:
        """Forward the call to the wrapped function."""
        return self.func(*args, **kwargs)

    def __deepcopy__(self, memo):
        # Keep the original reference rather than trying to copy the function.
        return self
class ToolInputSchema(BaseModel):
    """
    Input parameters of a tool, following the JSON Schema conventions.

    Attributes:
        type (Literal["object"]): schema type of the parameters.
        properties (dict): parameter definitions, following the OpenAI API conventions.
        required (list[str]): names of the required parameters.
        additionalProperties (Optional[bool]): whether extra keys are allowed.
    """

    type: Literal["object"] = "object"
    properties: dict
    required: list[str]
    additionalProperties: Optional[bool] = False
class Tool(BaseModel):
    """
    Tool description passed to the LLM.

    Attributes:
        type (Optional[Literal["function"]]): tool type.
        name (str): tool name.
        description (str): tool description.
        parameters (Union[ToolInputSchema, dict]): input parameter schema.
        strict (Optional[bool]): strict calling; OpenAI API only.
        invokeFunc (CallableWrapper[ToolInvokeFunc]): required; the async function
            that executes the tool. It is excluded from dumps (``Field(exclude=True)``),
            and the field serializer renders it as "..." if it is ever serialized.
    """

    type: Optional[Literal["function"]] = "function"
    name: str
    description: str
    parameters: Union[ToolInputSchema, dict]
    strict: Optional[bool] = False
    invokeFunc: CallableWrapper[ToolInvokeFunc] = Field(exclude=True)

    # CallableWrapper is not a pydantic type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_serializer("invokeFunc")
    def serialize_invoke_func(self, invoke_func: CallableWrapper[ToolInvokeFunc]) -> str:
        return "..."
================================================
FILE: kirara_ai/llm/llm_manager.py
================================================
import random
from typing import Dict, List, Optional
from typing_extensions import deprecated
from kirara_ai.config.global_config import GlobalConfig, ModelConfig
from kirara_ai.events.event_bus import EventBus
from kirara_ai.events.llm import LLMAdapterLoaded, LLMAdapterUnloaded
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.ioc.inject import Inject
from kirara_ai.llm.adapter import LLMBackendAdapter
from kirara_ai.llm.llm_registry import LLMBackendRegistry
from kirara_ai.llm.model_types import ModelAbility, ModelType
from kirara_ai.logger import get_logger
class LLMManager:
"""
跟踪、管理和调度模型后端
"""
container: DependencyContainer
config: GlobalConfig
backend_registry: LLMBackendRegistry
active_backends: Dict[str, List[LLMBackendAdapter]]
model_info: Dict[str, ModelConfig] # 存储模型的配置信息
event_bus: EventBus
@Inject()
def __init__(
    self,
    container: DependencyContainer,
    config: GlobalConfig,
    backend_registry: LLMBackendRegistry,
    event_bus: EventBus,
):
    """Keep the injected services; no backend is loaded yet."""
    # Injected collaborators.
    self.container = container
    self.config = config
    self.backend_registry = backend_registry
    self.event_bus = event_bus
    self.logger = get_logger("LLMAdapter")
    # model id -> adapters currently serving that model
    self.active_backends = {}
    # model id -> that model's configuration
    self.model_info = {}
    # backend name -> its adapter instance
    self.backends: Dict[str, LLMBackendAdapter] = {}
def load_config(self):
    """Load every enabled backend from the config; a failing backend is logged and skipped."""
    for backend in self.config.llms.api_backends:
        if not backend.enable:
            continue
        self.logger.info(f"Loading backend: {backend.name}")
        try:
            self.load_backend(backend.name)
        except Exception as e:
            self.logger.error(f"Failed to load backend {backend.name}: {e}")
def load_backend(self, backend_name: str):
    """
    Load the named backend: create its adapter and register it for each of its models.

    The adapter is created inside a scoped container that also holds the
    backend's parsed config object, so that config can be injected into the
    adapter. ``LLMAdapterLoaded`` is posted on success.

    :param backend_name: backend name as configured in ``config.llms.api_backends``
    :raises ValueError: if the backend is not configured, not enabled, already
        loaded, or uses an unregistered adapter type
    """
    backend = next(
        (b for b in self.config.llms.api_backends if b.name == backend_name), None
    )
    if not backend:
        raise ValueError(f"Backend {backend_name} not found in config")
    if not backend.enable:
        raise ValueError(f"Backend {backend_name} is not enabled")
    # Loaded adapters are tracked by name in self.backends. (active_backends maps
    # model ids to lists of adapter objects, so a name can never be found there.)
    if backend_name in self.backends:
        raise ValueError(f"Backend {backend_name} is already loaded")

    adapter_class = self.backend_registry.get(backend.adapter)
    config_class = self.backend_registry.get_config_class(backend.adapter)
    if not adapter_class or not config_class:
        raise ValueError(f"Invalid adapter type: {backend.adapter}")

    # Create the adapter with the backend config available for injection.
    with self.container.scoped() as scoped_container:
        scoped_container.register(config_class, config_class(**backend.config))
        adapter = Inject(scoped_container).create(adapter_class)()
    adapter.backend_name = backend_name
    self.backends[backend_name] = adapter

    # Register the adapter for every model it serves and remember each model's config.
    for model_config in backend.models:
        model_id = model_config.id
        self.model_info[model_id] = model_config
        self.active_backends.setdefault(model_id, []).append(adapter)

    self.event_bus.post(LLMAdapterLoaded(adapter=adapter, backend_name=backend_name))
    self.logger.info(f"Backend {backend_name} loaded successfully")
async def unload_backend(self, backend_name: str):
    """
    Unload a loaded backend.

    The backend's adapter is detached from every model it serves. Models left
    with no adapter at all are dropped from ``active_backends`` and
    ``model_info``. The adapter is removed from ``self.backends`` and an
    ``LLMAdapterUnloaded`` event is posted.

    :param backend_name: backend name
    :raises ValueError: if the backend is absent from the config or not loaded
    """
    if not next((b for b in self.config.llms.api_backends if b.name == backend_name), None):
        raise ValueError(f"Backend {backend_name} not found in config")

    adapter = self.backends.get(backend_name)
    if not adapter:
        raise ValueError(f"Backend {backend_name} not found")

    # Snapshot the keys: entries may be deleted while walking them.
    for model_id in list(self.active_backends):
        serving = self.active_backends[model_id]
        if adapter not in serving:
            continue
        serving.remove(adapter)
        if not serving:
            del self.active_backends[model_id]
            self.model_info.pop(model_id, None)

    removed = self.backends.pop(backend_name)
    self.event_bus.post(LLMAdapterUnloaded(backend_name=backend_name, adapter=removed))
async def reload_backend(self, backend_name: str):
    """
    Reload a backend by unloading it and then loading it again from the current config.

    :param backend_name: backend name
    """
    pending_unload = self.unload_backend(backend_name)
    await pending_unload
    self.load_backend(backend_name)
def is_backend_available(self, backend_name: str) -> bool:
    """
    Check whether a backend is configured, enabled, loaded, and serving all of its models.

    The previous check only asked whether *some* adapter served each model.
    That reported a backend that was never loaded as available when another
    backend served the same models, or when it declared no models. This
    version requires the backend's own adapter.

    :param backend_name: backend name
    :return: True if the backend is usable
    """
    backend = next(
        (b for b in self.config.llms.api_backends if b.name == backend_name), None
    )
    if not backend or not backend.enable:
        return False
    adapter = self.backends.get(backend_name)
    if adapter is None:
        # Not loaded (never loaded, failed to load, or unloaded)
        return False
    # Every model declared in the config must be served by *this* adapter
    return all(
        adapter in self.active_backends.get(model_config.id, ())
        for model_config in backend.models
    )
def get(self, backend_name: str) -> Optional[LLMBackendAdapter]:
    """
    Look up the loaded adapter instance of a backend.

    :param backend_name: backend name
    :return: the adapter instance, or None if no backend of that name is loaded
    """
    loaded = self.backends
    return loaded.get(backend_name)
def get_llm(self, model_id: str) -> Optional[LLMBackendAdapter]:
    """
    Pick one adapter, uniformly at random, among those serving a model.

    :param model_id: model id
    :return: an adapter instance, or None if no loaded adapter serves the model
    """
    candidates = self.active_backends.get(model_id)
    if not candidates:
        return None
    # TODO: support selection strategies other than random choice
    return random.choice(candidates)
def get_supported_models(self, model_type: "ModelType", ability: "ModelAbility") -> List[str]:
    """
    List the ids of all models of a given type that have a given ability.

    :param model_type: model type to match against ``model_config.type``
    :param ability: required ability; every bit of ``ability.value`` must be set
        in the model's ``ability`` bit mask
    :return: ids of the matching models, in ``model_info`` order
    """
    required = ability.value
    return [
        model_id
        for model_id, model_config in self.model_info.items()
        if model_config.type == model_type.value
        # The requested bits must be a subset of the model's bits. The old
        # `ability.is_capable(model_config.ability)` tested the reverse
        # (model bits ⊆ requested bits), which rejected models supporting more
        # than was asked for and accepted models with ability 0 for everything.
        and (model_config.ability & required) == required
    ]
@deprecated("请使用 get_supported_models 方法")
def get_llm_id_by_ability(self, ability: ModelAbility) -> Optional[str]:
    """
    Pick a random LLM model id that has the given ability.

    Deprecated: use ``get_supported_models`` (or ``get_models_by_ability``).

    :param ability: required ability
    :return: a matching model id, or None if there is none
    """
    # Exactly get_models_by_ability restricted to LLM models
    return self.get_models_by_ability(ModelType.LLM, ability)
def get_models_by_ability(self, model_type: ModelType, ability: ModelAbility) -> Optional[str]:
    """
    Pick a random model id of the given type that has the given ability.

    :param model_type: model type
    :param ability: required ability
    :return: a randomly chosen matching model id, or None if there is none
    """
    candidates = self.get_supported_models(model_type, ability)
    return random.choice(candidates) if candidates else None
def get_models_by_type(self, model_type: ModelType) -> List[str]:
    """
    List the ids of all known models of a given type.

    :param model_type: model type
    :return: ids of the models whose config ``type`` equals ``model_type.value``
    """
    wanted_type = model_type.value
    return [mid for mid, cfg in self.model_info.items() if cfg.type == wanted_type]
================================================
FILE: kirara_ai/llm/llm_registry.py
================================================
from typing import Dict, Optional, Type
from pydantic import BaseModel
from kirara_ai.logger import get_logger
from .adapter import LLMBackendAdapter
from .model_types import LLMAbility # noqa: F401
class LLMBackendRegistry:
    """
    Registry of LLM backend adapter types.

    Maps an adapter type name to its adapter class and its configuration
    class. Names are stored exactly as registered. Lookups ignore case and
    return the first registered match.
    """

    _adapters: Dict[str, Type[LLMBackendAdapter]]
    _configs: Dict[str, Type[BaseModel]]

    def __init__(self):
        self._adapters, self._configs = {}, {}
        self.logger = get_logger(__name__)

    def register(
        self,
        adapter_type: str,
        adapter_class: Type[LLMBackendAdapter],
        config_class: Type[BaseModel],
        *args, **kwargs
    ):
        """
        Register an LLM backend adapter type.

        :param adapter_type: name to register the adapter under
        :param adapter_class: adapter class
        :param config_class: configuration class of the adapter
        Any extra positional/keyword arguments are accepted and ignored.
        """
        self._adapters[adapter_type] = adapter_class
        self._configs[adapter_type] = config_class
        self.logger.info(f"Registered LLM backend adapter: {adapter_type}")

    @staticmethod
    def _lookup_ignoring_case(table, adapter_type):
        """Return the value of the first key of *table* equal to *adapter_type* ignoring case, else None."""
        for key, value in table.items():
            if key.lower() == adapter_type.lower():
                return value
        return None

    def get(self, adapter_type: str) -> Optional[Type[LLMBackendAdapter]]:
        """
        Return the adapter class registered for a type name (case-insensitive).

        :param adapter_type: adapter type name
        :return: the adapter class, or None if the type is unknown
        """
        return self._lookup_ignoring_case(self._adapters, adapter_type)

    def get_config_class(self, adapter_type: str) -> Optional[Type[BaseModel]]:
        """
        Return the configuration class registered for a type name (case-insensitive).

        :param adapter_type: adapter type name
        :return: the configuration class, or None if the type is unknown
        """
        return self._lookup_ignoring_case(self._configs, adapter_type)

    def get_adapter_types(self) -> list[str]:
        """
        Return all registered adapter type names, in registration order.
        """
        return [*self._adapters]

    def get_all_adapters(self) -> Dict[str, Type[LLMBackendAdapter]]:
        """
        Return a shallow copy of the type-name -> adapter-class mapping.
        """
        return dict(self._adapters)
================================================
FILE: kirara_ai/llm/model_types.py
================================================
from abc import abstractmethod
from enum import Enum
class ModelType(Enum):
    """
    Kinds of model a backend can serve; the value is the string used in configs.
    """
    LLM = "llm"
    Embedding = "embedding"
    ImageGeneration = "image_generation"
    Audio = "audio"
    # Further types can be added as needed

    @classmethod
    def from_str(cls, value: str) -> "ModelType":
        """
        Convert a config string into a ModelType.

        Unrecognized strings fall back to ``ModelType.LLM``.
        """
        for member in cls:
            if member.value == value:
                return member
        return cls.LLM
class ModelAbility(Enum):
    """
    Base class of the per-model-type ability enums.

    Subclasses declare the ability members and override ``is_capable``.
    NOTE: ``@abstractmethod`` is not enforced here because ``Enum``'s metaclass
    is not ``ABCMeta``. A subclass that forgets to override ``is_capable``
    silently inherits the body below.
    """

    @abstractmethod
    def is_capable(self, ability: int) -> bool:
        """
        Return whether this ability set includes ``ability``; the base version always says no.
        """
        return False
class LLMAbility(ModelAbility):
    """
    Bit-flag abilities of an LLM.

    Each basic member occupies one bit (bit 0, ``1 << 0``, is unused); the
    composite members further down are bitwise ORs of basic ones. Ability
    values are compared as plain ints (see ``is_capable``).
    """
    Unknown = 0
    # The interface supports chat-format conversations
    Chat = 1 << 1
    TextInput = 1 << 2
    TextOutput = 1 << 3
    ImageInput = 1 << 4
    ImageOutput = 1 << 5
    AudioInput = 1 << 6
    AudioOutput = 1 << 7
    FunctionCalling = 1 << 8
    # Composite abilities built with bitwise OR
    TextCompletion = TextInput | TextOutput
    TextChat = Chat | TextCompletion
    ImageGeneration = ImageInput | ImageOutput
    # NOTE(review): this does not include TextInput/TextOutput (value 0b110010),
    # so it does not satisfy a TextChat requirement — confirm that is intended.
    TextImageMultiModal = Chat | ImageGeneration
    TextImageAudioMultiModal = TextImageMultiModal | AudioInput | AudioOutput

    def is_capable(self, ability: int) -> bool:
        """
        Return True if every bit set in ``ability`` is also set in this member's value.
        """
        return (self.value & ability) == ability
class EmbeddingModelAbility(ModelAbility):
    """
    Bit-flag abilities of an embedding model.

    Each member other than ``Unknown`` is a single bit (bit 0 is unused);
    combine them with ``|``.
    """
    Unknown = 0
    TextEmbedding = 1 << 1
    ImageEmbedding = 1 << 2
    AudioEmbedding = 1 << 3
    VideoEmbedding = 1 << 4
    Batch = 1 << 5

    def is_capable(self, ability: int) -> bool:
        """
        Return True if every bit set in ``ability`` is also set in this member's value.
        """
        return (self.value & ability) == ability
class ImageModelAbility(ModelAbility):
    """
    Bit-flag abilities of an image model.

    Each member other than ``Unknown`` is a single bit (bit 0 is unused);
    combine them with ``|``.
    """
    Unknown = 0
    TextToImage = 1 << 1
    ImageEdit = 1 << 2
    Inpainting = 1 << 3
    Outpainting = 1 << 4
    UpScaling = 1 << 5

    def is_capable(self, ability: int) -> bool:
        """
        Return True if every bit set in ``ability`` is also set in this member's value.
        """
        return (self.value & ability) == ability
class AudioModelAbility(ModelAbility):
    """
    Bit-flag abilities of an audio model.

    Each member other than ``Unknown`` is a single bit (bit 0 is unused);
    combine them with ``|``.
    """
    Unknown = 0
    Speech = 1 << 1
    Transcription = 1 << 2
    Translation = 1 << 3
    Streaming = 1 << 4
    Realtime = 1 << 5

    def is_capable(self, ability: int) -> bool:
        """
        Return True if every bit set in ``ability`` is also set in this member's value.
        """
        return (self.value & ability) == ability
================================================
FILE: kirara_ai/logger.py
================================================
import asyncio
import json
import os
import re
import traceback
from collections import deque
from datetime import datetime
from typing import Any, Callable, Dict, List
from loguru import logger
# 创建 logs 文件夹
# Directory that receives the rotating log files (relative to the working directory)
LOG_DIR = "logs"
# exist_ok=True replaces the exists()-then-makedirs() pair, which raised
# FileExistsError if the directory appeared in between (e.g. created by another
# process starting at the same time).
os.makedirs(LOG_DIR, exist_ok=True)

# Configure log format and colors
logger.remove()  # remove loguru's default handler

# Line format used by the console and file sinks below
log_format = (
    "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
    "{level: <8} | "
    "{extra[tag]: <12} | "
    "{message}"
)

# Console sink
logger.add(
    sink=lambda msg: print(msg.strip()),  # print to stdout
    format=log_format,
    level="DEBUG",
    colorize=True,
)

# File sink with log rotation
log_file = os.path.join(LOG_DIR, "log_{time:YYYY-MM-DD}.log")
logger.add(
    sink=log_file,
    format=log_format,
    level="DEBUG",
    rotation="00:00",  # rotate at midnight
    retention="7 days",  # keep 7 days of logs
    compression="zip",  # compress rotated files
    colorize=False,
)

# Global logger instance
_global_logger = logger

# Most recent log entries kept in memory so that newly connected clients can be
# sent the history; the deque discards the oldest entry beyond 500.
_recent_logs: deque[Dict] = deque(maxlen=500)

# Subscriber callback type: receives a single log entry or a list of entries
LogBroadcasterCallback = Callable[[Dict | List[Dict]], None]

# Generic log broadcasting manager
class LogBroadcaster:
    """
    Singleton that fans log records out to subscribers.

    The first construction installs a loguru sink (level INFO and above). That
    sink turns each record into a dict, appends it to ``_recent_logs`` and
    passes it to every subscribed callback. A callback that raises is
    unsubscribed.
    """
    _instance = None
    _subscribers: Dict[int, LogBroadcasterCallback] = {}
    _next_id = 0
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(LogBroadcaster, cls).__new__(cls)
            cls._instance._subscribers = {}  # subscriber id -> callback
            cls._instance._next_id = 0  # id handed to the next subscriber
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every LogBroadcaster() call; install the sink only once
        if not self._initialized:
            self._initialized = True
            self._setup_log_handler()

    def _setup_log_handler(self):
        """Register the loguru sink that records and broadcasts log entries."""

        def log_sink(message):
            record = message.record
            log_entry = {
                "type": "log",
                "level": record["level"].name,
                "content": record['message'],
                "timestamp": datetime.now().isoformat(),
                # Records from a logger not bound via get_logger() carry no
                # "tag"; indexing it directly raised KeyError inside the sink.
                "tag": record["extra"].get("tag", ""),
            }
            if record.get('exception'):
                exception = record['exception']
                if exception.value:
                    # format_exception() returns lines that already end in "\n";
                    # joining them with "\n" doubled every line break.
                    log_entry["content"] += "\n" + ''.join(traceback.format_exception(exception.value))
            # Keep for clients that connect later
            _recent_logs.append(log_entry)
            self._broadcast_log(log_entry)

        _global_logger.add(
            sink=log_sink,
            format=log_format,
            level="INFO",
            colorize=False,
        )

    def _broadcast_log(self, log_entry: Dict):
        """Send ``log_entry`` to all subscribers, dropping those whose callback raises."""
        to_remove = []
        # Iterate over a snapshot: subscribe()/unsubscribe() can run during a
        # broadcast (from another thread, or from inside a callback), and
        # resizing a dict while iterating it raises RuntimeError.
        for subscriber_id, callback in list(self._subscribers.items()):
            try:
                callback(log_entry)
            except Exception:
                # Delivery failed: treat the subscriber as disconnected
                to_remove.append(subscriber_id)
        for subscriber_id in to_remove:
            self.unsubscribe(subscriber_id)

    def subscribe(self, callback: LogBroadcasterCallback) -> int:
        """
        Add a log subscriber.

        :param callback: called with each new log entry
        :return: subscriber id, used to unsubscribe later
        """
        subscriber_id = self._next_id
        self._next_id += 1
        self._subscribers[subscriber_id] = callback
        return subscriber_id

    def unsubscribe(self, subscriber_id: int) -> bool:
        """
        Remove a log subscriber.

        :param subscriber_id: id returned by ``subscribe``
        :return: True if the subscriber existed and was removed
        """
        if subscriber_id in self._subscribers:
            del self._subscribers[subscriber_id]
            return True
        return False

    def send_recent_logs(self, callback: LogBroadcasterCallback):
        """
        Call ``callback`` once with a list of the buffered recent log entries.

        :param callback: receives the list of entries
        """
        callback(list(_recent_logs))
# WebSocket log handler, one application of LogBroadcaster
class WebSocketLogHandler:
    """Forwards log entries from the LogBroadcaster to WebSocket clients as JSON."""
    # Active WebSocket connection -> its LogBroadcaster subscriber id
    _websockets: Dict[Any, int] = {}

    @classmethod
    def add_websocket(cls, ws, loop: asyncio.AbstractEventLoop):
        """
        Register a WebSocket connection.

        The buffered recent logs are sent first, then the connection is
        subscribed to new entries.

        :param ws: WebSocket connection; ``ws.send(str)`` is assumed to return a coroutine
        :param loop: event loop on which the sends are run
        """
        def send_to_ws(log_entries: List[Dict] | Dict):
            # Loguru sinks added without enqueue=True run in whichever thread
            # emitted the record, which need not be `loop`'s thread.
            # loop.create_task is not thread-safe; run_coroutine_threadsafe is,
            # and it also works when called from the loop's own thread. Like
            # create_task, it raises synchronously if the loop is closed, so the
            # broadcaster still drops the subscriber in that case.
            asyncio.run_coroutine_threadsafe(ws.send(json.dumps(log_entries)), loop)

        broadcaster = LogBroadcaster()
        # Send the recent history first...
        broadcaster.send_recent_logs(send_to_ws)
        # ...then subscribe to new entries
        subscriber_id = broadcaster.subscribe(send_to_ws)
        cls._websockets[ws] = subscriber_id

    @classmethod
    def remove_websocket(cls, ws):
        """
        Unregister a WebSocket connection; unknown connections are ignored.

        :param ws: WebSocket connection
        """
        if ws in cls._websockets:
            LogBroadcaster().unsubscribe(cls._websockets[ws])
            del cls._websockets[ws]
# Set up the log broadcaster at import time
def init_log_broadcaster():
    """Ensure the LogBroadcaster singleton exists (which installs its log sink once)."""
    _ = LogBroadcaster()


init_log_broadcaster()
def get_logger(tag: str):
    """
    Return the global loguru logger bound with ``tag`` in its ``extra`` data.

    :param tag: tag shown in the ``{extra[tag]}`` column of the log format
    :return: the bound logger
    """
    bound_logger = _global_logger.bind(tag=tag)
    return bound_logger
class HypercornLoggerWrapper:
    """
    Adapts a loguru logger to the logger interface Hypercorn calls.

    Most methods forward unchanged. ``info`` additionally bridges
    ``%``-style messages: a ``%(name)s`` format with a mapping argument (the
    access-log style), or positional ``%s`` arguments (stdlib ``logging``
    style).
    """

    def __init__(self, logger):
        self.logger = logger

    def critical(self, message: str, *args: Any, **kwargs: Any) -> None:
        self.logger.critical(message, *args, **kwargs)

    def error(self, message: str, *args: Any, **kwargs: Any) -> None:
        self.logger.error(message, *args, **kwargs)

    def warning(self, message: str, *args: Any, **kwargs: Any) -> None:
        self.logger.warning(message, *args, **kwargs)

    def info(self, message: str, *args: Any, **kwargs: Any) -> None:
        """
        Log ``message`` at INFO level.

        - No arguments, or a mapping as the first argument: ``%(name)s``
          placeholders become ``{name}`` and the mapping is passed to loguru as
          keyword arguments for ``str.format``.
        - Other positional arguments: the message is ``%``-formatted here.
          Previously they were splatted with ``**``, which raised TypeError.
          If the placeholders do not match the arguments, the raw message is
          logged followed by the arguments.

        ``kwargs`` are ignored, as in the original mapping path.
        """
        if args and not hasattr(args[0], "keys"):
            try:
                text = message % args
            except (TypeError, ValueError):
                text = " ".join([message, *map(str, args)])
            self.logger.info(text)
            return
        log_fmt = re.sub(r"%\((\w+)\)s", r"{\1}", message)
        atoms = args[0] if args else {}
        self.logger.info(log_fmt, **atoms)

    def debug(self, message: str, *args: Any, **kwargs: Any) -> None:
        self.logger.debug(message, *args, **kwargs)

    def exception(self, message: str, *args: Any, **kwargs: Any) -> None:
        self.logger.exception(message, *args, **kwargs)

    def log(self, level: int, message: str, *args: Any, **kwargs: Any) -> None:
        self.logger.log(level, message, *args, **kwargs)
def get_async_logger(tag: str):
    """
    Return a ``HypercornLoggerWrapper`` around the global logger bound with ``tag``.

    :param tag: log tag
    :return: wrapper suitable for handing to Hypercorn
    """
    tagged = _global_logger.bind(tag=tag)
    return HypercornLoggerWrapper(tagged)
================================================
FILE: kirara_ai/mcp_module/__init__.py
================================================
from .manager import MCPServerManager
from .models import MCPConnectionState
from .server import MCPServer
# Public API of the MCP module
__all__ = ["MCPServerManager", "MCPServer", "MCPConnectionState"]
================================================
FILE: kirara_ai/mcp_module/manager.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
from functools import partial
from typing import Dict, NamedTuple, Optional, Tuple
from mcp import McpError, types
from mcp.shared.session import RequestResponder
from kirara_ai.config.global_config import GlobalConfig, MCPServerConfig
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.logger import get_logger
from .models import MCPConnectionState
from .server import MCPServer
# Module logger, tagged "MCP"
logger = get_logger("MCP")
class ToolCacheEntry(NamedTuple):
    """Cached description of one tool exposed by an MCP server."""
    server_id: str  # id of the server that provides the tool
    original_name: str  # tool name as reported by the server (the cache key differs after a name collision)
    tool_info: types.Tool  # tool description returned by the server
class MCPServerManager:
    """
    Manages the configured MCP servers and their process/connection lifecycle.

    It owns one ``MCPServer`` per configured server id. For each connected
    server it caches the tools, prompts and resources that server advertises.
    The caches are refreshed on connect and whenever a server sends a
    list-changed notification.
    """

    # JSON-RPC 2.0 error code for "Method not found"; MCP servers return it for
    # capabilities (tools / prompts / resources) they do not implement.
    _METHOD_NOT_FOUND = -32601

    def __init__(self, container: DependencyContainer):
        """
        Initialize the manager.

        :param container: dependency container; ``GlobalConfig`` is resolved from it
        """
        self.container = container
        self.config = container.resolve(GlobalConfig)
        self.servers: Dict[str, MCPServer] = {}
        # Display name -> entry. The display name is the tool's own name unless
        # it collided with an already cached tool (see _update_tools_cache).
        self.tools_cache: Dict[str, ToolCacheEntry] = {}
        # Server id -> prompt index / resource index of that server
        self.prompts_cache: Dict[str, list[types.Prompt]] = {}
        self.resources_cache: Dict[str, list[types.Resource]] = {}

    def load_servers(self):
        """Create an MCPServer for every configured server; failures are logged and skipped."""
        for server_config in self.config.mcp.servers:
            try:
                self.load_server(server_config)
            except Exception as e:
                logger.opt(exception=e).error(f"Failed to load MCP server {server_config.id}")
        logger.info(f"MCP server manager initialized, loaded {len(self.servers)} servers")

    def load_server(self, server_config: MCPServerConfig) -> MCPServer:
        """
        Create and register (without connecting) a server from its configuration.

        :param server_config: server configuration
        :return: the new server; it replaces any server registered under the same id
        """
        server = MCPServer(server_config)
        logger.info(f"Initializing MCP server {server_config.id}")
        self.servers[server_config.id] = server
        return server

    def get_all_servers(self) -> Dict[str, MCPServer]:
        """Return the id -> server mapping (the live dict, not a copy)."""
        return self.servers

    def get_server(self, server_id: str) -> Optional[MCPServer]:
        """Return the server with the given id, or None."""
        return self.servers.get(server_id)

    def is_server_id_available(self, server_id: str) -> bool:
        """
        Check whether a server id can be used.

        An id is available when:
        1. no server uses it, or
        2. the server using it is in DISCONNECTED or ERROR state.
        """
        if server_id not in self.servers:
            return True
        server = self.servers[server_id]
        return server.state in [MCPConnectionState.DISCONNECTED, MCPConnectionState.ERROR]

    def get_statistics(self) -> Dict[str, int]:
        """Return server counts: total, per connection type, and per connection state."""
        total = len(self.servers)
        stdio = sum(bool(s.server_config.connection_type == "stdio") for s in self.servers.values())
        sse = sum(bool(s.server_config.connection_type == "sse") for s in self.servers.values())
        connected = sum(bool(s.state == MCPConnectionState.CONNECTED) for s in self.servers.values())
        disconnected = sum(bool(s.state == MCPConnectionState.DISCONNECTED) for s in self.servers.values())
        error = sum(bool(s.state == MCPConnectionState.ERROR) for s in self.servers.values())
        return {
            "total": total,
            "stdio": stdio,
            "sse": sse,
            "connected": connected,
            "disconnected": disconnected,
            "error": error
        }

    def connect_all_servers(self, loop: asyncio.AbstractEventLoop):
        """
        Connect every registered server concurrently, blocking on ``loop`` until all attempts finish.

        Errors are logged per server. ``loop`` must not already be running
        (``run_until_complete`` is used).
        """
        async def _connect_server_safe(server_id):
            try:
                await self.connect_server(server_id)
            except Exception as e:
                logger.opt(exception=e).error(f"Exception occurred when connecting MCP server {server_id}")

        tasks = []
        for server_id in self.servers.keys():
            task = loop.create_task(_connect_server_safe(server_id))
            tasks.append(task)

        if tasks:
            loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))

    async def connect_server(self, server_id: str) -> bool:
        """
        Connect a server and populate its tool/prompt/resource caches.

        :return: True if the server is connected (or already was), False otherwise
        """
        server = self.servers.get(server_id)
        if not server:
            logger.error(f"Cannot connect to non-existent MCP server: {server_id}")
            return False

        if server.state == MCPConnectionState.CONNECTED:
            logger.warning(f"MCP server {server_id} is already connected")
            return True

        try:
            logger.info(f"Connecting to MCP server {server_id}")
            # Route server notifications back here, tagged with the server id
            server.message_handler = partial(self._handle_server_message, server_id)
            success = await server.connect()

            if not success:
                logger.error(f"Failed to connect to MCP server {server_id}")
                return False

            await self._update_tools_cache(server_id)
            await self._update_prompts_cache(server_id)
            await self._update_resources_cache(server_id)

            logger.info(f"Successfully connected to MCP server {server_id}")
            return True

        except Exception as e:
            logger.opt(exception=e).error(f"Error occurred when connecting to MCP server {server_id}")
            return False

    def disconnect_all_servers(self, loop: asyncio.AbstractEventLoop):
        """
        Disconnect every connected server, blocking on ``loop`` until done, then clear all caches.
        """
        disconnect_tasks = []
        for server_id, server in self.servers.items():
            if server.state == MCPConnectionState.CONNECTED:
                disconnect_tasks.append(loop.create_task(self.stop_server(server_id)))

        if disconnect_tasks:
            loop.run_until_complete(asyncio.gather(*disconnect_tasks, return_exceptions=True))

        # Prompt and resource indexes go stale just like the tools do.
        self.tools_cache.clear()
        self.prompts_cache.clear()
        self.resources_cache.clear()
        logger.info("All MCP servers have been disconnected")

    async def stop_server(self, server_id: str) -> bool:
        """
        Disconnect a server and drop its cached tools, prompts and resources.

        :return: True if the server is now disconnected (or was not connected), False on failure
        """
        server = self.servers.get(server_id)
        if not server:
            logger.error(f"Cannot disconnect from non-existent MCP server: {server_id}")
            return False

        if server.state != MCPConnectionState.CONNECTED:
            logger.warning(f"MCP server {server_id} is not connected")
            return True

        try:
            logger.info(f"Disconnecting from MCP server {server_id}")
            success = await server.disconnect()

            if not success:
                logger.error(f"Failed to disconnect from MCP server {server_id}")
                return False

            self._clear_server_caches(server_id)

            logger.info(f"Successfully disconnected from MCP server {server_id}")
            return True

        except Exception as e:
            logger.opt(exception=e).error(f"Error occurred when disconnecting from MCP server {server_id}")
            return False

    @classmethod
    def _is_method_not_found(cls, error: McpError) -> bool:
        """Return True if ``error`` means the server does not implement the requested method."""
        # McpError.error is an ErrorData object (code / message), not a string,
        # so comparing it with "Method not found" directly never matched.
        data = getattr(error, "error", None)
        return (
            getattr(data, "code", None) == cls._METHOD_NOT_FOUND
            or getattr(data, "message", None) == "Method not found"
        )

    def _clear_server_caches(self, server_id: str):
        """Forget every cached tool, prompt and resource of one server."""
        self._remove_server_tools_from_cache(server_id)
        self.prompts_cache.pop(server_id, None)
        self.resources_cache.pop(server_id, None)

    async def _update_tools_cache(self, server_id: str) -> bool:
        """
        Refresh the cached tools of one server.

        A tool whose name is already cached (from another server) is stored as
        ``"<server id>.<tool name>"``.

        Args:
            server_id: server id

        Returns:
            bool: True on success, or if the server does not support tools
        """
        server = self.servers.get(server_id)
        if not server or server.state != MCPConnectionState.CONNECTED:
            return False

        try:
            tools = await server.get_tools()

            # Drop this server's old tools before adding the new ones
            self._remove_server_tools_from_cache(server_id)

            for tool in tools.tools:
                original_name = tool.name
                if not original_name:
                    continue

                if original_name in self.tools_cache:
                    # Name collision: disambiguate with the server id
                    display_name = f"{server.server_config.id}.{original_name}"
                    logger.warning(f"工具名称冲突: {original_name},重命名为 {display_name}")
                else:
                    display_name = original_name

                self.tools_cache[display_name] = ToolCacheEntry(
                    server_id=server_id,
                    original_name=original_name,
                    tool_info=tool
                )

            return True
        except McpError as e:
            if self._is_method_not_found(e):
                # No tools: make sure nothing stale remains for this server
                self._remove_server_tools_from_cache(server_id)
                logger.warning(f"Server {server_id} does not support tools")
                return True
            # Any other protocol error was previously swallowed (returning None)
            logger.opt(exception=e).error(f"更新服务器 {server_id} 工具缓存时发生错误")
            return False
        except Exception as e:
            logger.opt(exception=e).error(f"更新服务器 {server_id} 工具缓存时发生错误")
            return False

    def _remove_server_tools_from_cache(self, server_id: str):
        """
        Remove all cached tools belonging to one server.

        Args:
            server_id: server id
        """
        tool_names_to_remove = [
            name for name, entry in self.tools_cache.items() if entry.server_id == server_id
        ]
        for name in tool_names_to_remove:
            self.tools_cache.pop(name, None)

    def get_tools(self) -> Dict[str, ToolCacheEntry]:
        """
        Return all cached tools keyed by display name (the live dict, not a copy).
        """
        return self.tools_cache

    def get_tool_server(self, tool_name: str) -> Optional[Tuple[MCPServer, str]]:
        """
        Resolve a tool display name to its server and the tool's original name.

        Args:
            tool_name: tool display name

        Returns:
            Optional[Tuple[MCPServer, str]]: (server, original tool name), or None if unknown
        """
        if tool_name not in self.tools_cache:
            return None

        entry = self.tools_cache[tool_name]
        server = self.servers.get(entry.server_id)

        if not server:
            return None

        return (server, entry.original_name)

    async def call_tool(self, tool_name: str, tool_args: dict) -> Optional[types.CallToolResult]:
        """
        Call a tool by display name.

        Args:
            tool_name: tool display name
            tool_args: tool arguments

        Returns:
            Optional[types.CallToolResult]: the call result, or None if the tool is
            unknown, its server is not connected, or the call raised
        """
        result = self.get_tool_server(tool_name)
        if not result:
            logger.error(f"Tool {tool_name} not found or server not available")
            return None

        server, original_name = result

        if server.state != MCPConnectionState.CONNECTED:
            logger.error(f"Server for tool {tool_name} is not connected")
            return None

        try:
            # The server knows the tool by its original name
            call_tool_result = await server.call_tool(original_name, tool_args)
            return call_tool_result
        except Exception as e:
            logger.opt(exception=e).error(f"Error occurred when calling tool {tool_name}")
            return None

    async def _update_prompts_cache(self, server_id: str) -> bool:
        """
        Refresh the cached prompt index of one server.

        The cache only holds the index; use ``get_prompt`` to fetch a prompt's content.

        Args:
            server_id: server id

        Returns:
            bool: True on success, or if the server does not support prompts
        """
        server = self.servers.get(server_id)
        if not server or server.state != MCPConnectionState.CONNECTED:
            return False

        try:
            prompts = await server.list_prompts()
            self.prompts_cache.pop(server_id, None)
            self.prompts_cache[server_id] = prompts.prompts
            return True
        except McpError as e:
            if self._is_method_not_found(e):
                self.prompts_cache[server_id] = []
                logger.warning(f"Server {server_id} does not support prompts")
                return True
            logger.opt(exception=e).error(f"更新服务器 {server_id} prompts 索引缓存时发生错误")
            return False
        except Exception as e:
            logger.opt(exception=e).error(f"更新服务器 {server_id} prompts 索引缓存时发生错误")
            return False

    async def get_prompt_list(self, server_id: str) -> Optional[list[types.Prompt]]:
        """
        Return the cached prompt index of a server.

        Args:
            server_id: server id

        Returns:
            Optional[list[types.Prompt]]: the prompts ([] if none cached), or None
            if the server is unknown or not connected
        """
        server = self.servers.get(server_id)
        if not server or server.state != MCPConnectionState.CONNECTED:
            return None
        return self.prompts_cache.get(server_id, [])

    async def get_prompt(self, server_id: str, prompt_name: str, prompt_args: dict[str, str] | None = None) -> Optional[types.GetPromptResult]:
        """
        Fetch a prompt from a server; None if the server is unknown or not connected.
        """
        server = self.servers.get(server_id)
        if not server or server.state != MCPConnectionState.CONNECTED:
            return None
        return await server.get_prompt(prompt_name, prompt_args)

    async def _update_resources_cache(self, server_id: str) -> bool:
        """
        Refresh the cached resource index of one server.

        The cache only holds the index; use ``get_resource`` to read a resource.

        Args:
            server_id: server id

        Returns:
            bool: True on success, or if the server does not support resources
        """
        server = self.servers.get(server_id)
        if not server or server.state != MCPConnectionState.CONNECTED:
            return False

        try:
            resources = await server.list_resources()
            self.resources_cache.pop(server_id, None)
            self.resources_cache[server_id] = resources.resources
            return True
        except McpError as e:
            if self._is_method_not_found(e):
                self.resources_cache[server_id] = []
                logger.warning(f"Server {server_id} does not support resources")
                return True
            logger.opt(exception=e).error(f"更新服务器 {server_id} resources 缓存时发生错误")
            return False
        except Exception as e:
            logger.opt(exception=e).error(f"更新服务器 {server_id} resources 缓存时发生错误")
            return False

    async def get_resource_list(self, server_id: str) -> Optional[list[types.Resource]]:
        """
        Return the cached resource index of a server.

        Args:
            server_id (str): server id

        Returns:
            Optional[list[types.Resource]]: the resources ([] if none cached), or
            None if the server is unknown or not connected
        """
        server = self.servers.get(server_id)
        if not server or server.state != MCPConnectionState.CONNECTED:
            return None
        return self.resources_cache.get(server_id, [])

    async def get_resource(self, server_id: str, uri: str) -> Optional[types.ReadResourceResult]:
        """
        Read a resource from a server.

        Args:
            server_id: server id
            uri: resource URI

        Returns:
            types.ReadResourceResult: the resource, or None if the server is unknown or not connected
        """
        server = self.servers.get(server_id)
        if not server or server.state != MCPConnectionState.CONNECTED:
            return None
        return await server.read_resource(uri)

    async def _handle_server_message(self, server_id: str, message: RequestResponder[types.ServerRequest, types.ClientResult]
                                     | types.ServerNotification
                                     | Exception):
        """
        Handle a message pushed by a server; list-changed notifications refresh the matching cache.
        """
        # ClientSession delivers notifications wrapped in the ServerNotification
        # root model; the concrete notification is its `.root`. Testing the
        # wrapper itself against the concrete types never matched.
        payload = message.root if isinstance(message, types.ServerNotification) else message
        if isinstance(payload, types.ToolListChangedNotification):
            await self._update_tools_cache(server_id)
        elif isinstance(payload, types.PromptListChangedNotification):
            await self._update_prompts_cache(server_id)
        elif isinstance(payload, types.ResourceListChangedNotification):
            await self._update_resources_cache(server_id)
        else:
            logger.warning(f"Unknown notification from server {server_id}: {message}")
================================================
FILE: kirara_ai/mcp_module/models.py
================================================
from enum import Enum
class MCPConnectionState(Enum):
    """Connection states of an MCP server client (the value of ``MCPServer.state``)."""
    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    DISCONNECTING = "disconnecting"
    # Connecting or running the session failed; a new connect() is allowed from here
    ERROR = "error"
================================================
FILE: kirara_ai/mcp_module/server.py
================================================
import asyncio
from contextlib import AsyncExitStack
from typing import Optional
import anyio
import anyio.lowlevel
from mcp import ClientSession, StdioServerParameters, stdio_client, types
from mcp.client.session import MessageHandlerFnT
from mcp.client.sse import sse_client
from mcp.shared.session import RequestResponder
from pydantic import AnyUrl
from kirara_ai.config.global_config import MCPServerConfig
from kirara_ai.logger import get_logger
from .models import MCPConnectionState
# Module logger, tagged "MCP.Server"
logger = get_logger("MCP.Server")
class MCPServer:
"""
MCP (Model Control Protocol) 服务器客户端类
用于与 MCP 服务器进行通信,支持 stdio 和 SSE 两种连接模式。
提供工具调用、补全、资源管理等功能。
本类为 mcp.ClientSession 的代理,
使其适应 Kirara AI 的生命周期。
"""
session: Optional[ClientSession] = None
state: MCPConnectionState = MCPConnectionState.DISCONNECTED
message_handler: Optional[MessageHandlerFnT] = None
def __init__(self, server_config: MCPServerConfig):
"""
初始化 MCP 服务器客户端
Args:
server_config: MCP 服务器配置
"""
self.server_config = server_config
self.session = None
self.state = MCPConnectionState.DISCONNECTED
self._lifecycle_task = None
self._shutdown_event = asyncio.Event()
self._connected_event = asyncio.Event()
self._client = None
self.message_handler = None
async def connect(self):
"""
连接到 MCP 服务器
根据配置连接到 MCP 服务器,并初始化会话
Returns:
bool: 连接是否成功
"""
if self.state != MCPConnectionState.DISCONNECTED and self.state != MCPConnectionState.ERROR:
return False
try:
self.state = MCPConnectionState.CONNECTING
# 重置事件
self._shutdown_event.clear()
self._connected_event.clear()
# 创建并启动生命周期任务
if self._lifecycle_task is None or self._lifecycle_task.done():
self._lifecycle_task = asyncio.create_task(self._lifecycle_manager())
# 等待连接完成或超时
try:
await asyncio.wait_for(self._connected_event.wait(), timeout=30.0)
if self.state != MCPConnectionState.CONNECTED:
# 连接失败
return False
return True
except asyncio.TimeoutError:
logger.error(f"连接到 MCP 服务器 {self.server_config.id} 超时")
await self.disconnect() # 超时时断开连接
return False
except Exception as e:
self.state = MCPConnectionState.ERROR
logger.opt(exception=e).error(f"连接 MCP 服务器 {self.server_config.id} 时发生错误")
return False
async def disconnect(self):
"""
断开与 MCP 服务器的连接
Returns:
bool: 断开连接是否成功
"""
if self.state == MCPConnectionState.DISCONNECTED:
return True
try:
self.state = MCPConnectionState.DISCONNECTING
# 发送关闭信号
self._shutdown_event.set()
# 等待生命周期任务完成
if self._lifecycle_task and not self._lifecycle_task.done():
try:
await asyncio.wait_for(self._lifecycle_task, timeout=10.0)
except asyncio.TimeoutError:
# 如果任务没有及时完成,取消它
self._lifecycle_task.cancel()
try:
await self._lifecycle_task
except (asyncio.CancelledError, Exception):
pass
self.state = MCPConnectionState.DISCONNECTED
return True
except Exception as e:
self.state = MCPConnectionState.ERROR
logger.opt(exception=e).error(f"断开 MCP 服务器 {self.server_config.id} 连接时发生错误")
return False
async def _lifecycle_manager(self):
"""
服务器生命周期管理任务
负责服务器的连接、运行和断开连接的完整生命周期
"""
exit_stack = AsyncExitStack()
try:
# 初始化连接
if self.server_config.connection_type == "stdio":
if self.server_config.command is None:
raise ValueError("stdio 连接类型需要提供命令")
self._client = stdio_client(StdioServerParameters(
command=self.server_config.command,
args=self.server_config.args,
env=self.server_config.env
))
elif self.server_config.connection_type == "sse":
if self.server_config.url is None:
raise ValueError("sse 连接类型需要提供 url")
self._client = sse_client(self.server_config.url, headers=self.server_config.headers)
else:
raise ValueError(f"不支持的服务器连接类型: {self.server_config.connection_type}")
# 使用 exit_stack 管理资源
read, write = await exit_stack.enter_async_context(self._client)
self.session = await exit_stack.enter_async_context(ClientSession(read, write, message_handler=self.message_handler_callback))
# 初始化会话
await self.session.initialize()
# 更新状态并通知连接完成
self.state = MCPConnectionState.CONNECTED
self._connected_event.set()
# 等待关闭信号
await self._shutdown_event.wait()
except Exception as e:
# 连接失败
self.state = MCPConnectionState.ERROR
self._connected_event.set() # 通知连接过程已完成(虽然是失败的)
logger.opt(exception=e).error(f"MCP server {self.server_config.id} lifecycle task error")
finally:
# 清理资源
self.session = None
self._client = None
try:
# 关闭所有资源
await exit_stack.aclose()
except Exception as e:
logger.opt(exception=e).error(f"error occured during shutting down handle of: {self.server_config.id}")
# 如果状态仍然是 DISCONNECTING,则更新为 DISCONNECTED
if self.state == MCPConnectionState.DISCONNECTING:
self.state = MCPConnectionState.DISCONNECTED
# 工具相关方法
async def get_tools(self) -> types.ListToolsResult:
"""获取可用工具列表"""
assert self.session is not None
return await self.session.list_tools()
async def call_tool(self, tool_name: str, tool_args: Optional[dict] = None) -> types.CallToolResult:
"""
调用指定工具
Args:
tool_name: 工具名称
tool_args: 工具参数
Returns:
工具调用结果
"""
assert self.session is not None
return await self.session.call_tool(tool_name, tool_args)
async def complete(self, prompt: str, tool_args: dict):
"""
使用模型进行补全
Args:
prompt: 提示文本
tool_args: 补全参数
Returns:
补全结果
"""
assert self.session is not None
return await self.session.complete(types.PromptReference(name=prompt, type="ref/prompt"), tool_args)
# 提示词相关方法
async def get_prompt(self, prompt_name: str, prompt_args: dict[str, str] | None = None) -> types.GetPromptResult:
"""
获取指定提示词
Args:
prompt_name: 提示词名称
prompt_args: 提示词参数
Returns:
提示词内容
"""
assert self.session is not None
return await self.session.get_prompt(prompt_name, prompt_args)
async def list_prompts(self) -> types.ListPromptsResult:
"""获取可用提示词列表"""
assert self.session is not None
return await self.session.list_prompts()
# 资源相关方法
async def list_resources(self) -> types.ListResourcesResult:
"""获取可用资源列表"""
assert self.session is not None
return await self.session.list_resources()
async def list_resource_templates(self) -> types.ListResourceTemplatesResult:
"""获取可用资源模板列表"""
assert self.session is not None
return await self.session.list_resource_templates()
async def read_resource(self, uri: str) -> types.ReadResourceResult:
"""
读取指定资源
Args:
uri: 资源名称
Returns:
资源内容
"""
assert self.session is not None
return await self.session.read_resource(AnyUrl(uri))
async def subscribe_resource(self, uri: str) -> types.EmptyResult:
"""
订阅指定资源
Args:
uri: 资源名称
Returns:
订阅结果
"""
assert self.session is not None
return await self.session.subscribe_resource(AnyUrl(uri))
async def unsubscribe_resource(self, uri: str) -> types.EmptyResult:
"""
取消订阅指定资源
Args:
uri: 资源名称
Returns:
取消订阅结果
"""
assert self.session is not None
return await self.session.unsubscribe_resource(AnyUrl(uri))
async def message_handler_callback(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
"""
消息处理回调函数
Args:
message: 请求响应器或通知或异常
"""
if self.message_handler is None:
logger.warning(f"MCP客户端接收到服务器{self.server_config.id}的通知,但未对其进行处理: {message}")
await anyio.lowlevel.checkpoint()
return
await self.message_handler(message)
async def list_client_roots_callback(self, ctx) -> types.ListRootsResult | types.ErrorData:
"""
列出客户端允许的资源根目录
Args:
Returns:
types.ListRootsResult: 资源根目录列表
"""
# 这个需要kirara-agent做出较大支持,webApi 中设定允许的资源根,最好弄个单独的目录。
# 文件根格式为file:///myResource/, 也可以为一个url.
raise NotImplementedError("list_client_roots_callback 未实现")
async def send_ping(self) -> None:
    """Send a ping to the MCP server through the active session."""
    session = self.session
    assert session is not None
    await session.send_ping()
async def send_notification(self, notification: types.ClientNotification) -> None:
    """
    Send an arbitrary client notification to the server (e.g. roots changed).

    ClientSession.send_roots_list_changed only supports
    RootsListChangedNotification, so this calls the session's generic
    ``send_notification`` instead, which accepts any ClientNotification.

    Args:
        notification: The client notification to send.
    """
    session = self.session
    assert session is not None
    await session.send_notification(notification)
async def sampling_callback(self):
    """
    Sampling callback placeholder.

    Not implemented yet: it does nothing and returns None.
    """
    return None
async def logging_callback(self):
    """
    Logging callback placeholder.

    Not implemented yet: it does nothing and returns None.
    """
    return None
================================================
FILE: kirara_ai/media/__init__.py
================================================
from kirara_ai.media.manager import MediaManager
from kirara_ai.media.media_object import Media
from kirara_ai.media.metadata import MediaMetadata
from kirara_ai.media.types import MediaType
from kirara_ai.media.utils import detect_mime_type
# Public API re-exported by the kirara_ai.media package.
__all__ = [
    "Media",
    "MediaManager",
    "MediaMetadata",
    "MediaType",
    "detect_mime_type",
]
================================================
FILE: kirara_ai/media/carrier/__init__.py
================================================
from .provider import MediaReferenceProvider
from .registry import MediaCarrierRegistry
from .service import MediaCarrierService
# Public API re-exported by the media carrier package.
__all__ = ["MediaReferenceProvider", "MediaCarrierRegistry",
           "MediaCarrierService"]
================================================
FILE: kirara_ai/media/carrier/provider.py
================================================
from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar
# Type of the object that owns a media reference.
T = TypeVar("T")


class MediaReferenceProvider(ABC, Generic[T]):
    """
    Interface implemented by components that hold references to media.

    A provider resolves one of its reference keys back to the owning object.
    """

    @abstractmethod
    def get_reference_owner(self, reference_key: str) -> Optional[T]:
        """Return the owner of ``reference_key``, or None if it is unknown."""
================================================
FILE: kirara_ai/media/carrier/registry.py
================================================
from typing import Dict
from kirara_ai.ioc.container import DependencyContainer
from .provider import MediaReferenceProvider
class MediaCarrierRegistry:
    """Registry of media reference providers, keyed by provider name."""

    def __init__(self, container: DependencyContainer):
        self.container = container
        # provider name -> provider instance
        self._providers: Dict[str, MediaReferenceProvider] = {}

    def register(self, provider_name: str, provider_instance: MediaReferenceProvider) -> None:
        """Store ``provider_instance`` under ``provider_name``, replacing any previous one."""
        providers = self._providers
        providers[provider_name] = provider_instance

    def unregister(self, provider_name: str) -> None:
        """Remove the provider named ``provider_name``; unknown names are ignored."""
        self._providers.pop(provider_name, None)

    def get_provider(self, provider_name: str) -> MediaReferenceProvider:
        """
        Look up a provider by name.

        Raises:
            ValueError: If no provider is registered under ``provider_name``.
        """
        providers = self._providers
        if provider_name in providers:
            return providers[provider_name]
        raise ValueError(f"Provider not found: {provider_name}")

    def get_all_providers(self) -> Dict[str, MediaReferenceProvider]:
        """Return the internal name -> provider mapping (not a copy)."""
        return self._providers
================================================
FILE: kirara_ai/media/carrier/service.py
================================================
from typing import Any, Dict, List, Optional, Tuple
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.media.manager import MediaManager
from kirara_ai.media.media_object import Media
from .registry import MediaCarrierRegistry
class MediaCarrierService:
    """
    Manage media references on behalf of registered reference providers.

    A reference is stored on the media metadata as the string
    ``"<provider_name>:<reference_key>"``. This service keeps an in-memory index
    from that full key to ``(provider_name, media_id)``.
    """

    def __init__(self, container: DependencyContainer, media_manager: MediaManager):
        self.container = container
        self.media_manager = media_manager
        self.registry = container.resolve(MediaCarrierRegistry)
        # Reference index: full reference key -> (provider_name, media_id)
        self._reference_index: Dict[str, Tuple[str, str]] = {}
        self._build_reference_index()

    def _build_reference_index(self) -> None:
        """Rebuild the reference index from the metadata known to the media manager."""
        self._reference_index.clear()
        for media_id, metadata in self.media_manager.metadata_cache.items():
            for reference_key in metadata.references:
                # Only keys of the form "<provider>:<key>" can be attributed to a provider.
                if ":" in reference_key:
                    provider_name, _ = reference_key.split(":", 1)
                    self._reference_index[reference_key] = (provider_name, media_id)

    def register_reference(self, media_id: str, provider_name: str, reference_key: str) -> None:
        """
        Record that ``provider_name`` references ``media_id`` under ``reference_key``.

        Raises:
            ValueError: If the media is unknown to the media manager.
        """
        if media_id not in self.media_manager.metadata_cache:
            raise ValueError(f"媒体不存在: {media_id}")
        full_reference_key = f"{provider_name}:{reference_key}"
        self.media_manager.add_reference(media_id, full_reference_key)
        self._reference_index[full_reference_key] = (provider_name, media_id)

    def remove_reference(self, media_id: str, provider_name: str, reference_key: str) -> None:
        """
        Remove a reference previously added with ``register_reference``.

        Unknown media is ignored. Removing the last reference may make the media
        manager delete the media (see MediaManager.remove_reference).
        """
        if media_id not in self.media_manager.metadata_cache:
            return
        full_reference_key = f"{provider_name}:{reference_key}"
        self.media_manager.remove_reference(media_id, full_reference_key)
        self._reference_index.pop(full_reference_key, None)

    def get_reference_owner(self, reference_key: str) -> Optional[Any]:
        """
        Resolve a full reference key (``"<provider>:<key>"``) to its owner.

        Returns None if the key is not indexed or its provider is not registered.
        """
        if reference_key not in self._reference_index:
            return None
        provider_name, _ = self._reference_index[reference_key]
        try:
            provider = self.registry.get_provider(provider_name)
            # Providers receive the key without the "<provider>:" prefix.
            return provider.get_reference_owner(reference_key.split(":", 1)[1])
        except (ValueError, IndexError):
            return None

    def get_media_by_reference(self, provider_name: str, reference_key: str) -> List[Media]:
        """
        Return the media referenced by ``provider_name`` under ``reference_key``.

        The result is a list with at most one element (index keys are unique); it
        is empty if the reference is unknown or the media no longer exists.
        """
        full_reference_key = f"{provider_name}:{reference_key}"
        # Direct dict lookup instead of scanning every index entry.
        entry = self._reference_index.get(full_reference_key)
        if entry is None:
            return []
        media = self.media_manager.get_media(entry[1])
        return [media] if media else []

    def get_references_by_media(self, media_id: str) -> List[Tuple[str, str]]:
        """Return ``(provider_name, reference_key)`` pairs for every prefixed reference of a media."""
        if media_id not in self.media_manager.metadata_cache:
            return []
        metadata = self.media_manager.metadata_cache[media_id]
        references = []
        for reference_key in metadata.references:
            if ":" in reference_key:
                provider_name, key = reference_key.split(":", 1)
                references.append((provider_name, key))
        return references

    def cleanup_orphaned_references(self) -> int:
        """
        Remove references whose provider is no longer registered.

        Returns:
            The number of references removed.
        """
        count = 0
        # Use the registry's public accessor rather than its private attribute.
        all_providers = set(self.registry.get_all_providers())
        for media_id, metadata in list(self.media_manager.metadata_cache.items()):
            orphaned_refs = set()
            for reference_key in metadata.references:
                if ":" in reference_key:
                    provider_name, _ = reference_key.split(":", 1)
                    if provider_name not in all_providers:
                        orphaned_refs.add(reference_key)
            for ref in orphaned_refs:
                self.media_manager.remove_reference(media_id, ref)
                self._reference_index.pop(ref, None)
                count += 1
        return count
================================================
FILE: kirara_ai/media/manager.py
================================================
import asyncio
import base64
import hashlib
import json
import shutil
import time
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional
import aiofiles
from kirara_ai.config.config_loader import CONFIG_FILE, ConfigLoader
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.logger import get_logger
from kirara_ai.media.metadata import MediaMetadata
from kirara_ai.media.types.media_type import MediaType
from kirara_ai.media.utils.mime import detect_mime_type
if TYPE_CHECKING:
from kirara_ai.im.message import MediaMessage
from kirara_ai.media.media_object import Media
class MediaManager:
    """
    Media manager: registers media files and tracks their references and lifecycle.

    On-disk layout, relative to ``media_dir``:
      * ``metadata/<media_id>.json`` - serialised MediaMetadata
      * ``files/<media_id>.<format>`` - the media content
    ``media_id`` is the SHA-1 hex digest of the content, so identical content
    is stored once. ``__new__`` makes the class a process-wide singleton.
    """

    def __init__(self, media_dir: str = "data/media"):
        self.media_dir = Path(media_dir)
        self.metadata_dir = self.media_dir / "metadata"
        self.files_dir = self.media_dir / "files"
        self.metadata_cache: Dict[str, MediaMetadata] = {}
        self.logger = get_logger("MediaManager")
        # Because of the singleton __new__, __init__ can run again on an existing
        # instance. Keep already-running background tasks referenced instead of
        # dropping them: asyncio only keeps weak references to tasks, and a lost
        # cleanup task could never be cancelled by setup_cleanup_task().
        self._pending_tasks: set[asyncio.Task] = getattr(self, "_pending_tasks", set())
        self._cleanup_task: Optional[asyncio.Task] = getattr(self, "_cleanup_task", None)

        # Make sure the storage directories exist.
        for directory in (self.media_dir, self.metadata_dir, self.files_dir):
            directory.mkdir(parents=True, exist_ok=True)

        # Load all metadata from disk.
        self._load_all_metadata()

    def _load_all_metadata(self) -> None:
        """Reload every ``*.json`` file in the metadata directory; unreadable files are logged and skipped."""
        self.metadata_cache.clear()
        for metadata_file in self.metadata_dir.glob("*.json"):
            try:
                with open(metadata_file, "r", encoding="utf-8") as f:
                    metadata = MediaMetadata.from_dict(json.load(f))
                    self.metadata_cache[metadata.media_id] = metadata
            except Exception as e:
                self.logger.error(f"Failed to load metadata from {metadata_file}: {e}")

    def _save_metadata(self, metadata: MediaMetadata) -> None:
        """Write ``metadata`` to its JSON file and update the in-memory cache."""
        metadata_path = self.metadata_dir / f"{metadata.media_id}.json"
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata.to_dict(), f, ensure_ascii=False, indent=2)
        self.metadata_cache[metadata.media_id] = metadata

    def _get_file_path(self, media_id: str, format: str) -> Path:
        """Return the storage path ``files/<media_id>.<format>``."""
        return self.files_dir / f"{media_id}.{format}"

    def _create_task(self, coro, name=None, loop=None):
        """
        Schedule ``coro`` as a background task and keep a strong reference to it
        until it completes. ``name`` is accepted but currently unused.
        """
        if loop is None:
            loop = asyncio.get_event_loop()
        task = asyncio.ensure_future(coro, loop=loop)
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)
        return task

    async def _save_file_async(self, data: bytes, target_path: Path):
        """Write ``data`` to ``target_path`` asynchronously."""
        async with aiofiles.open(target_path, "wb") as f:
            await f.write(data)

    async def _download_file_async(self, url: str) -> bytes:
        """
        Fetch ``url`` asynchronously; ``file://`` URLs are read from local disk.

        Raises:
            ValueError: If the HTTP status is not 200.
        """
        from curl_cffi import AsyncSession, Response

        # file:// URLs: return the local file content directly.
        if url.startswith("file://"):
            async with aiofiles.open(url[7:], "rb") as f:
                return await f.read()

        async with AsyncSession(trust_env=True, timeout=3000) as session:
            resp: Response = await session.get(url)
            if resp.status_code != 200:
                raise ValueError(f"Failed to download file from {url}, status: {resp.status_code}")
            return resp.content

    def _download_file_sync(self, url: str) -> bytes:
        """
        Fetch ``url`` synchronously; ``file://`` URLs are read from local disk.

        Raises:
            ValueError: If the HTTP status is not 200.
        """
        from curl_cffi import Response, Session

        # file:// URLs: return the local file content directly.
        if url.startswith("file://"):
            with open(url[7:], "rb") as f:
                return f.read()

        with Session() as session:
            resp: Response = session.get(url)
            if resp.status_code != 200:
                raise ValueError(f"Failed to download file from {url}, status: {resp.status_code}")
            return resp.content

    async def register_media(
        self,
        url: Optional[str] = None,
        path: Optional[str] = None,
        data: Optional[bytes] = None,
        format: Optional[str] = None,
        media_type: Optional[MediaType] = None,
        size: Optional[int] = None,
        source: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        reference_id: Optional[str] = None,
    ) -> str:
        """
        Register media (unified entry point).

        Content is taken from ``path`` if given, else downloaded from ``url``,
        else taken from ``data``. If identical content is already registered,
        its existing ID is returned and ``reference_id`` (if any) is added to it.

        Args:
            url: Media URL.
            path: Local file path.
            data: Raw media bytes.
            format: File format / extension; detected from content if missing.
            media_type: Media type; detected from content if missing.
            size: Size in bytes; defaults to ``len(data)``.
            source: Free-form source description.
            description: Media description.
            tags: Media tags.
            reference_id: Initial reference holding this media.

        Returns:
            str: The media ID (SHA-1 of the content).

        Raises:
            ValueError: If no source is given, no data could be obtained, or no
                format could be determined.
            FileNotFoundError: If ``path`` does not exist.
        """
        if not any([url, path, data]):
            raise ValueError("Must provide at least one of url, path, or data")

        # Obtain the content.
        if path:
            file_path = Path(path)
            if not file_path.exists():
                raise FileNotFoundError(f"File not found: {path}")
            try:
                async with aiofiles.open(file_path, "rb") as f:
                    data = await f.read()
            except Exception as e:
                self.logger.error(f"Failed to read file: {e}", exc_info=True)
                raise
        elif url:
            try:
                data = await self._download_file_async(url)
            except Exception as e:
                self.logger.error(f"Failed to download file: {e}", exc_info=True)
                raise

        if data is None:
            raise ValueError("Unable to fetch data from url or path, please check your input")
        # The SHA-1 of the content is the media ID; hash off the event loop.
        hash_data = await asyncio.to_thread(hashlib.sha1, data)
        media_id = hash_data.hexdigest()

        # Identical content already registered: reuse it, but still record the
        # caller's reference so the media is not deleted while this caller uses it.
        if media_id in self.metadata_cache:
            self.logger.info(f"Media already exists: {media_id}")
            if reference_id and reference_id not in self.metadata_cache[media_id].references:
                self.add_reference(media_id, reference_id)
            return media_id

        if not size:
            size = len(data)

        # Fill in whatever type information the caller did not provide.
        if not media_type or not format:
            _, detected_media_type, detected_format = detect_mime_type(data=data)
            media_type = media_type or detected_media_type
            format = format or detected_format

        if not format:
            raise ValueError("No format detected")
        target_path = self._get_file_path(media_id, format)
        try:
            await self._save_file_async(data, target_path)
        except Exception as e:
            self.logger.error(f"Failed to save file: {e}", exc_info=True)
            raise
        path = str(target_path)

        metadata = MediaMetadata(
            media_id=media_id,
            media_type=media_type,
            format=format,
            size=size,
            created_at=None,  # defaults to the current time
            source=source,
            description=description,
            tags=tags,
            references={reference_id} if reference_id else set(),
            url=url,
            path=path,
        )

        self._save_metadata(metadata)
        self.logger.info(f"Registered media: {media_id}")
        return media_id

    async def register_from_path(
        self,
        path: str,
        source: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        reference_id: Optional[str] = None
    ) -> str:
        """
        Register media from a local file path.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        file_path = Path(path)
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {path}")

        return await self.register_media(
            path=path,
            source=source,
            description=description,
            tags=tags,
            reference_id=reference_id
        )

    async def register_from_url(
        self,
        url: str,
        source: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        reference_id: Optional[str] = None
    ) -> str:
        """Register media downloaded from ``url``."""
        return await self.register_media(
            url=url,
            source=source,
            description=description,
            tags=tags,
            reference_id=reference_id
        )

    async def register_from_data(
        self,
        data: bytes,
        format: Optional[str] = None,
        source: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        reference_id: Optional[str] = None,
        media_type: Optional[MediaType] = None
    ) -> str:
        """Register media from raw bytes."""
        return await self.register_media(
            data=data,
            format=format,
            source=source,
            description=description,
            tags=tags,
            reference_id=reference_id,
            media_type=media_type
        )

    def add_reference(self, media_id: str, reference_id: str) -> None:
        """
        Add ``reference_id`` to the media's reference set and persist it.

        Raises:
            ValueError: If the media is unknown.
        """
        if media_id not in self.metadata_cache:
            raise ValueError(f"Media not found: {media_id}")

        metadata = self.metadata_cache[media_id]
        metadata.references.add(reference_id)
        self._save_metadata(metadata)

    def remove_reference(self, media_id: str, reference_id: str) -> None:
        """
        Remove ``reference_id`` from the media; delete the media once this
        removal leaves it with no references.

        Raises:
            ValueError: If the media is unknown.
        """
        if media_id not in self.metadata_cache:
            raise ValueError(f"Media not found: {media_id}")

        metadata = self.metadata_cache[media_id]
        if reference_id in metadata.references:
            metadata.references.remove(reference_id)
            self._save_metadata(metadata)

            # No references left: log a warning and delete the media.
            if not metadata.references:
                self.logger.warning(f"No references found for media: {media_id}, file: {metadata.path}")
                self.delete_media(media_id)

    def delete_media(self, media_id: str) -> None:
        """Delete the stored file, the metadata file and the cache entry; unknown IDs are ignored."""
        if media_id not in self.metadata_cache:
            return

        metadata = self.metadata_cache[media_id]

        if metadata.format:
            file_path = self._get_file_path(media_id, metadata.format)
            if file_path.exists():
                file_path.unlink()

        metadata_path = self.metadata_dir / f"{media_id}.json"
        if metadata_path.exists():
            metadata_path.unlink()

        del self.metadata_cache[media_id]
        self.logger.info(f"Deleted media: {media_id}")

    def update_metadata(
        self,
        media_id: str,
        source: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        url: Optional[str] = None,
        path: Optional[str] = None
    ) -> None:
        """
        Update the given (non-None) metadata fields and persist them.

        Raises:
            ValueError: If the media is unknown.
        """
        if media_id not in self.metadata_cache:
            raise ValueError(f"Media not found: {media_id}")

        metadata = self.metadata_cache[media_id]

        if source is not None:
            metadata.source = source
        if description is not None:
            metadata.description = description
        if tags is not None:
            metadata.tags = tags
        if url is not None:
            metadata.url = url
        if path is not None:
            metadata.path = path

        self._save_metadata(metadata)

    def add_tags(self, media_id: str, tags: List[str]) -> None:
        """
        Append tags not already present and persist.

        Raises:
            ValueError: If the media is unknown.
        """
        if media_id not in self.metadata_cache:
            raise ValueError(f"Media not found: {media_id}")

        metadata = self.metadata_cache[media_id]
        for tag in tags:
            if tag not in metadata.tags:
                metadata.tags.append(tag)

        self._save_metadata(metadata)

    def remove_tags(self, media_id: str, tags: List[str]) -> None:
        """
        Remove the given tags (missing ones are ignored) and persist.

        Raises:
            ValueError: If the media is unknown.
        """
        if media_id not in self.metadata_cache:
            raise ValueError(f"Media not found: {media_id}")

        metadata = self.metadata_cache[media_id]
        for tag in tags:
            if tag in metadata.tags:
                metadata.tags.remove(tag)

        self._save_metadata(metadata)

    def get_metadata(self, media_id: str) -> Optional[MediaMetadata]:
        """Return the cached metadata for ``media_id``, or None."""
        return self.metadata_cache.get(media_id)

    async def ensure_file_exists(self, media_id: str) -> Optional[Path]:
        """
        Make sure the media's stored file exists, copying it from ``metadata.path``
        or downloading it from ``metadata.url`` if necessary.

        Returns:
            The stored file path, or None if it could not be produced.
        """
        if media_id not in self.metadata_cache:
            return None

        metadata = self.metadata_cache[media_id]

        # Without a format the storage path is unknown: detect the format first.
        if not metadata.format:
            if metadata.path:
                try:
                    file_path = Path(metadata.path)
                    if not file_path.exists():
                        return None

                    _, media_type, format = detect_mime_type(path=str(file_path))
                    metadata.media_type = media_type
                    metadata.format = format
                    metadata.size = file_path.stat().st_size
                    self._save_metadata(metadata)

                    target_path = self._get_file_path(media_id, format)
                    shutil.copy2(file_path, target_path)
                    return target_path
                except Exception as e:
                    self.logger.error(f"Failed to copy media from path: {metadata.path}, error: {e}")
                    return None
            elif metadata.url:
                try:
                    data = await self._download_file_async(metadata.url)
                    _, media_type, format = detect_mime_type(data=data)
                    metadata.media_type = media_type
                    metadata.format = format
                    metadata.size = len(data)
                    self._save_metadata(metadata)

                    target_path = self._get_file_path(media_id, format)
                    await self._save_file_async(data, target_path)
                    return target_path
                except Exception as e:
                    self.logger.error(f"Failed to download media from URL: {metadata.url}, error: {e}")
                    return None
            return None

        file_path = self._get_file_path(media_id, metadata.format)
        if file_path.exists():
            return file_path

        # Missing file: try downloading it from the URL first...
        if metadata.url:
            try:
                data = await self._download_file_async(metadata.url)
                await self._save_file_async(data, file_path)
                return file_path
            except Exception as e:
                self.logger.error(f"Failed to download media from URL: {metadata.url}, error: {e}")

        # ...then copying it from the original path.
        if metadata.path:
            try:
                source_path = Path(metadata.path)
                if source_path.exists():
                    shutil.copy2(source_path, file_path)
                    return file_path
            except Exception as e:
                self.logger.error(f"Failed to copy media from path: {metadata.path}, error: {e}")

        return None

    async def get_file_path(self, media_id: str) -> Optional[Path]:
        """Return ``metadata.path`` if that file exists, otherwise the stored file (fetched if needed)."""
        if media_id not in self.metadata_cache:
            return None

        metadata = self.metadata_cache[media_id]

        if metadata.path and Path(metadata.path).exists():
            return Path(metadata.path)

        return await self.ensure_file_exists(media_id)

    async def get_data(self, media_id: str) -> Optional[bytes]:
        """Return the media bytes from disk, falling back to downloading ``metadata.url``."""
        if media_id not in self.metadata_cache:
            return None

        metadata = self.metadata_cache[media_id]

        file_path = await self.get_file_path(media_id)
        if file_path:
            try:
                async with aiofiles.open(file_path, "rb") as f:
                    return await f.read()
            except Exception as e:
                self.logger.error(f"Failed to read media file: {file_path}, error: {e}")

        if metadata.url:
            try:
                return await self._download_file_async(metadata.url)
            except Exception as e:
                self.logger.error(f"Failed to download media from URL: {metadata.url}, error: {e}")

        return None

    async def get_url(self, media_id: str) -> Optional[str]:
        """Return the original URL if known, otherwise a base64 data URL (or None)."""
        if media_id not in self.metadata_cache:
            return None

        metadata = self.metadata_cache[media_id]
        if metadata.url:
            return metadata.url

        # Same data-URL construction as get_base64_url; reuse it instead of duplicating.
        return await self.get_base64_url(media_id)

    async def get_base64_url(self, media_id: str) -> Optional[str]:
        """Return a ``data:<mime>;base64,...`` URL, or None if data or type info is missing."""
        if media_id not in self.metadata_cache:
            return None

        metadata = self.metadata_cache[media_id]
        data = await self.get_data(media_id)
        if data and metadata.media_type and metadata.format:
            mime_type = f"{metadata.media_type.value}/{metadata.format}"
            return f"data:{mime_type};base64,{base64.b64encode(data).decode()}"

        return None

    def search_by_tags(self, tags: List[str], match_all: bool = False) -> List[str]:
        """Return IDs of media having any (or, with ``match_all``, all) of ``tags``."""
        matcher = all if match_all else any
        return [
            media_id
            for media_id, metadata in self.metadata_cache.items()
            if matcher(tag in metadata.tags for tag in tags)
        ]

    def search_by_description(self, query: str) -> List[str]:
        """Return IDs of media whose description contains ``query`` (case-insensitive)."""
        results = []
        for media_id, metadata in self.metadata_cache.items():
            if metadata.description and query.lower() in metadata.description.lower():
                results.append(media_id)
        return results

    def search_by_source(self, source: str) -> List[str]:
        """Return IDs of media whose source contains ``source`` (case-insensitive)."""
        results = []
        for media_id, metadata in self.metadata_cache.items():
            if metadata.source and source.lower() in metadata.source.lower():
                results.append(media_id)
        return results

    def search_by_type(self, media_type: MediaType) -> List[str]:
        """Return IDs of media of the given type."""
        return [
            media_id
            for media_id, metadata in self.metadata_cache.items()
            if metadata.media_type == media_type
        ]

    def get_all_media_ids(self) -> List[str]:
        """Return all registered media IDs."""
        return list(self.metadata_cache.keys())

    def cleanup_unreferenced(self) -> int:
        """Delete every media that has no references; return how many were deleted."""
        count = 0
        for media_id, metadata in list(self.metadata_cache.items()):
            if not metadata.references:
                self.delete_media(media_id)
                count += 1
        return count

    async def create_media_message(self, media_id: str) -> Optional["MediaMessage"]:
        """Build the message element class matching the media's type (file element as fallback)."""
        if media_id not in self.metadata_cache:
            return None

        from kirara_ai.im.message import FileElement, ImageMessage, VideoElement, VoiceMessage

        metadata = self.metadata_cache[media_id]

        if metadata.media_type == MediaType.IMAGE:
            return ImageMessage(media_id=media_id)
        elif metadata.media_type == MediaType.AUDIO:
            return VoiceMessage(media_id=media_id)
        elif metadata.media_type == MediaType.VIDEO:
            return VideoElement(media_id=media_id)
        else:
            return FileElement(media_id=media_id)

    def get_media(self, media_id: str) -> Optional["Media"]:
        """Return a Media wrapper for ``media_id``, or None if it is unknown."""
        if media_id not in self.metadata_cache:
            return None
        from kirara_ai.media.media_object import Media
        return Media(media_id=media_id, media_manager=self)

    def __new__(cls, *args, **kwargs) -> "MediaManager":
        # Process-wide singleton: every MediaManager(...) call returns the same
        # instance (note that __init__ still runs on each call).
        if not hasattr(cls, "_instance"):
            cls._instance = super(MediaManager, cls).__new__(cls)
        return cls._instance

    def setup_cleanup_task(self, container: DependencyContainer):
        """
        (Re)start the periodic cleanup of unreferenced media from ``GlobalConfig.media``.

        Any previously scheduled cleanup task is cancelled first. When cleanup is
        enabled this calls ``asyncio.create_task``, so it needs a running loop.
        """
        config = container.resolve(GlobalConfig)
        if self._cleanup_task:
            self._cleanup_task.cancel()
            self._cleanup_task = None
        if config.media.auto_remove_unreferenced and config.media.cleanup_duration > 0:
            duration = config.media.cleanup_duration

            async def schedule_cleanup():
                while True:
                    last_cleanup_time = config.media.last_cleanup_time
                    # cleanup_duration is expressed in days.
                    next_cleanup_time = last_cleanup_time + duration * 24 * 60 * 60
                    await asyncio.sleep(next_cleanup_time - time.time())
                    count = self.cleanup_unreferenced()
                    self.logger.info(f"Cleanup {count} unreferenced media files")
                    config.media.last_cleanup_time = int(time.time())
                    ConfigLoader.save_config_with_backup(CONFIG_FILE, config)

            self._cleanup_task = asyncio.create_task(schedule_cleanup())
================================================
FILE: kirara_ai/media/media_object.py
================================================
import base64
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from kirara_ai.im.message import MediaMessage
from kirara_ai.media.manager import MediaManager
from kirara_ai.media.metadata import MediaMetadata
from kirara_ai.media.types.media_type import MediaType
class Media:
    """
    Convenience handle for one registered media item.

    The metadata object is fetched from the MediaManager once, at construction;
    all other operations delegate to the manager.
    """

    metadata: MediaMetadata

    def __init__(self, media_id: str, media_manager: MediaManager):
        """
        Create a handle for ``media_id``.

        Args:
            media_id: ID of the media.
            media_manager: Manager that owns the media.

        Raises:
            AssertionError: If the manager has no metadata for ``media_id``.
        """
        self.media_id = media_id
        self._manager = media_manager
        found = self._manager.get_metadata(self.media_id)
        assert found is not None, f"Media metadata not found for {self.media_id}"
        self.metadata = found

    @property
    def media_type(self) -> MediaType:
        """Media type taken from the metadata."""
        return self.metadata.media_type

    @property
    def format(self) -> str:
        """File format taken from the metadata."""
        return self.metadata.format

    @property
    def size(self) -> Optional[int]:
        """Size in bytes, if known."""
        return self.metadata.size

    @property
    def description(self) -> Optional[str]:
        """Description taken from the metadata."""
        return self.metadata.description

    @description.setter
    def description(self, value: str) -> None:
        """Persist a new description through the manager."""
        self._manager.update_metadata(self.media_id, description=value)

    @property
    def tags(self) -> List[str]:
        """Tags taken from the metadata (empty list if metadata is falsy)."""
        meta = self.metadata
        if not meta:
            return []
        return meta.tags

    @property
    def mime_type(self) -> str:
        """MIME type taken from the metadata."""
        return self.metadata.mime_type

    def add_tags(self, tags: List[str]) -> None:
        """Add tags through the manager."""
        self._manager.add_tags(self.media_id, tags)

    def remove_tags(self, tags: List[str]) -> None:
        """Remove tags through the manager."""
        self._manager.remove_tags(self.media_id, tags)

    def add_reference(self, reference_id: str) -> None:
        """Add a reference through the manager."""
        self._manager.add_reference(self.media_id, reference_id)

    def remove_reference(self, reference_id: str) -> None:
        """Remove a reference through the manager."""
        self._manager.remove_reference(self.media_id, reference_id)

    async def get_file_path(self) -> Path:
        """Return the local file path; asserts that the manager could provide one."""
        file_path = await self._manager.get_file_path(self.media_id)
        assert file_path is not None, f"Media file path not found for {self.media_id}"
        return file_path

    async def get_data(self) -> bytes:
        """Return the raw bytes; asserts that the manager could provide them."""
        raw = await self._manager.get_data(self.media_id)
        assert raw is not None, f"Media data not found for {self.media_id}"
        return raw

    async def get_base64(self) -> str:
        """Return the content encoded as a base64 string."""
        raw = await self.get_data()
        assert raw is not None, "Media data is None"
        return base64.b64encode(raw).decode()

    async def get_url(self) -> str:
        """Return the manager's URL for this media; asserts that one exists."""
        link = await self._manager.get_url(self.media_id)
        assert link is not None, f"Media URL not found for {self.media_id}"
        return link

    async def get_base64_url(self) -> str:
        """Return a ``data:<mime_type>;base64,...`` URL."""
        encoded = await self.get_base64()
        return f"data:{self.mime_type};base64,{encoded}"

    async def create_message(self) -> "MediaMessage":
        """Build the message element for this media through the manager."""
        built = await self._manager.create_media_message(self.media_id)
        assert built is not None, f"Media message not found for {self.media_id}"
        return built

    def __str__(self) -> str:
        meta = self.metadata
        if not meta:
            return f"Media({self.media_id}, not found)"
        return f"Media({meta.media_id}, type={meta.media_type}, format={meta.format})"

    def __repr__(self) -> str:
        return self.__str__()
================================================
FILE: kirara_ai/media/metadata.py
================================================
from datetime import datetime
from typing import Any, Dict, List, Optional, Set
from kirara_ai.media.types.media_type import MediaType
class MediaMetadata:
    """
    Metadata record for one registered media item.

    Serialised with ``to_dict`` and restored with ``from_dict``.
    """

    def __init__(
        self,
        media_id: str,
        media_type: MediaType,
        format: str,
        size: Optional[int] = None,
        created_at: Optional[datetime] = None,
        source: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        references: Optional[Set[str]] = None,
        url: Optional[str] = None,
        path: Optional[str] = None
    ):
        self.media_id = media_id
        self.media_type = media_type
        self.format = format
        self.size = size
        # Defaults to the (naive, local) time this object is created.
        self.created_at = created_at or datetime.now()
        self.source = source
        self.description = description
        # Copy the given containers: tags and references are later mutated in
        # place (e.g. append/add), which must not leak into the caller's objects.
        self.tags: List[str] = list(tags) if tags else []
        self.references: Set[str] = set(references) if references else set()
        self.url = url
        self.path = path

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serialisable dict; empty optional fields are omitted."""
        result: Dict[str, Any] = {
            "media_id": self.media_id,
            "created_at": self.created_at.isoformat(),
            "source": self.source,
            "description": self.description,
            "tags": self.tags,
            "references": list(self.references),
        }

        # Optional fields are only written when set.
        if self.media_type:
            result["media_type"] = self.media_type.value
        if self.format:
            result["format"] = self.format
        if self.size is not None:
            result["size"] = self.size
        if self.url:
            result["url"] = self.url
        if self.path:
            result["path"] = self.path

        return result

    @property
    def mime_type(self) -> str:
        """MIME type built as ``<media_type value>/<format>``."""
        return f"{self.media_type.value}/{self.format}"

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MediaMetadata':
        """
        Build metadata from a dict produced by ``to_dict``.

        ``to_dict`` omits an empty ``format``, so a missing ``format`` is read as
        None instead of failing; a missing ``created_at`` falls back to now.

        Raises:
            KeyError: If ``media_id`` or ``media_type`` is missing.
        """
        raw_created_at = data.get("created_at")
        return cls(
            media_id=data["media_id"],
            media_type=MediaType(data["media_type"]),
            format=data.get("format"),
            size=data.get("size"),
            created_at=datetime.fromisoformat(raw_created_at) if raw_created_at else None,
            source=data.get("source"),
            description=data.get("description"),
            tags=data.get("tags", []),
            references=set(data.get("references", [])),
            url=data.get("url"),
            path=data.get("path")
        )
================================================
FILE: kirara_ai/media/types/__init__.py
================================================
from kirara_ai.media.types.media_type import MediaType
# Public API re-exported by the media types package.
__all__ = ["MediaType"]
================================================
FILE: kirara_ai/media/types/media_type.py
================================================
from enum import Enum
class MediaType(Enum):
    """Coarse media category derived from the top-level part of a MIME type."""

    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    FILE = "file"

    @classmethod
    def from_mime(cls, mime_type: str) -> 'MediaType':
        """
        Map a MIME type such as ``"image/png"`` to its category.

        The enum values equal the recognised top-level types, so a value lookup
        suffices; any other top-level type maps to FILE.
        """
        top_level = mime_type.partition('/')[0]
        try:
            return cls(top_level)
        except ValueError:
            return cls.FILE
================================================
FILE: kirara_ai/media/utils/__init__.py
================================================
from kirara_ai.media.utils.mime import detect_mime_type, mime_remapping
# Public API re-exported by the media utils package.
__all__ = ["detect_mime_type", "mime_remapping"]
================================================
FILE: kirara_ai/media/utils/mime.py
================================================
from typing import Optional, Tuple
import magic
from kirara_ai.media.types.media_type import MediaType
# MIME type remapping: replaces the MIME types on the left (as reported by
# libmagic) with the ones on the right before the format is derived from the
# subtype, e.g. "audio/mpeg" yields format "mp3".
mime_remapping = {
    "audio/mpeg": "audio/mp3",
    "audio/x-wav": "audio/wav",
    "audio/x-m4a": "audio/m4a",
    "audio/x-flac": "audio/flac",
}
def detect_mime_type(data: Optional[bytes] = None, path: Optional[str] = None) -> Tuple[str, MediaType, str]:
    """
    Detect the MIME type of in-memory data or a file on disk using libmagic.

    ``data`` takes precedence over ``path`` when both are given. The detected
    type is passed through ``mime_remapping``, and the format is the MIME
    subtype (the part after the last '/').

    Args:
        data: File content.
        path: File path.

    Returns:
        Tuple[str, MediaType, str]: (mime_type, media_type, format)

    Raises:
        ValueError: If neither ``data`` nor ``path`` is given, or if detection fails.
    """
    # Validate before the try block so this error is not re-wrapped as a
    # "Failed to detect mime type" error.
    if data is None and path is None:
        raise ValueError("Must provide either data or path")

    try:
        if data is not None:
            mime_type = magic.from_buffer(data, mime=True)
        elif path is not None:
            mime_type = magic.from_file(path, mime=True)
    except Exception as e:
        raise ValueError(f"Failed to detect mime type: {e}") from e

    mime_type = mime_remapping.get(mime_type, mime_type)
    media_type = MediaType.from_mime(mime_type)
    file_format = mime_type.split('/')[-1]
    return mime_type, media_type, file_format
================================================
FILE: kirara_ai/memory/composes/__init__.py
================================================
from .base import ComposableMessageType, MemoryComposer, MemoryDecomposer
from .builtin_composes import DefaultMemoryComposer, DefaultMemoryDecomposer, MultiElementDecomposer
from .composer_strategy import MessageProcessor, ProcessorFactory
from .decomposer_strategy import ContentParser, DefaultDecomposerStrategy, MultiElementDecomposerStrategy
from .xml_helper import XMLHelper
# Public API re-exported by the memory composes package.
__all__ = [
    "MemoryComposer",
    "MemoryDecomposer",
    "DefaultMemoryComposer",
    "DefaultMemoryDecomposer",
    "MultiElementDecomposer",
    "ComposableMessageType",
    "XMLHelper",
    "ProcessorFactory",
    "MessageProcessor",
    "ContentParser",
    "DefaultDecomposerStrategy",
    "MultiElementDecomposerStrategy",
]
================================================
FILE: kirara_ai/memory/composes/base.py
================================================
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from kirara_ai.im.message import IMMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format.message import LLMChatMessage
from kirara_ai.llm.format.response import Message
from kirara_ai.memory.entry import MemoryEntry
# Message types that composers accept and decomposers return: IM messages,
# LLM chat messages, LLM response messages, or plain strings.
ComposableMessageType = Union[IMMessage, LLMChatMessage, Message, str]
class MemoryComposer(ABC):
    """Abstract base for components that turn a list of messages into a MemoryEntry."""

    # Dependency container; declared here but not assigned by this base class.
    container: DependencyContainer

    @abstractmethod
    def compose(
        self, sender: Optional[ChatSender], message: List[ComposableMessageType]
    ) -> MemoryEntry:
        """
        Build a memory entry from ``message``.

        Args:
            sender: Who sent the messages, if known.
            message: The messages to store.
        """
class MemoryDecomposer(ABC):
    """Abstract base for memory decomposers.

    A decomposer turns stored ``MemoryEntry`` objects back into messages that
    can be handed to a model.
    """

    # Dependency container; expected to be assigned after construction.
    container: DependencyContainer

    @property
    def empty_message(self) -> ComposableMessageType:
        """Placeholder message representing "no memory available"."""
        return "<空记忆>"

    @abstractmethod
    def decompose(self, entries: List[MemoryEntry]) -> List[ComposableMessageType]:
        """Convert ``entries`` into a list of messages."""
        ...
================================================
FILE: kirara_ai/memory/composes/builtin_composes.py
================================================
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from kirara_ai.im.message import IMMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.llm.format.message import LLMChatMessage
from kirara_ai.llm.format.response import Message
from kirara_ai.logger import get_logger
from kirara_ai.memory.entry import MemoryEntry
from .base import ComposableMessageType, MemoryComposer, MemoryDecomposer
from .composer_strategy import ProcessorFactory
from .decomposer_strategy import DefaultDecomposerStrategy, MultiElementDecomposerStrategy
class DefaultMemoryComposer(MemoryComposer):
    """Compose a list of messages into one plain-text ``MemoryEntry``.

    Each message is rendered by the processor that ``ProcessorFactory`` picks
    for its type; plain strings become one line each, and any other message
    type is skipped. Media ids, tool calls and tool results collected while
    rendering are stored in the entry metadata under ``_media_ids``,
    ``_tool_calls`` and ``_tool_results``.
    """

    def __init__(self):
        # Built lazily in compose(), because ``self.container`` is only set
        # after construction.
        self.processor_factory = None

    def compose(
        self, sender: Optional[ChatSender], message: List[ComposableMessageType]
    ) -> MemoryEntry:
        if self.processor_factory is None:
            self.processor_factory = ProcessorFactory(self.container)
        factory = self.processor_factory

        # Scratch space that the processors append side data into.
        context: Dict[str, Any] = {
            "media_ids": [],
            "tool_calls": [],
            "tool_results": []
        }

        rendered_chunks: List[str] = []
        for item in message:
            handler = factory.get_processor(type(item))
            if handler is not None:
                rendered_chunks.append(handler.process(item, context))
            elif isinstance(item, str):
                rendered_chunks.append(f"{item}\n")

        content = "".join(rendered_chunks).strip()
        metadata = {
            "_media_ids": context.get("media_ids", []),
            "_tool_calls": context.get("tool_calls", []),
            "_tool_results": context.get("tool_results", []),
        }
        return MemoryEntry(
            sender=sender or ChatSender.get_bot_sender(),
            content=content,
            timestamp=datetime.now(),
            metadata=metadata,
        )
class DefaultMemoryDecomposer(MemoryDecomposer):
    """Decompose memory entries into text lines via ``DefaultDecomposerStrategy``."""

    def __init__(self):
        # Created on first use.
        self.strategy = None

    def decompose(self, entries: List[MemoryEntry]) -> List[ComposableMessageType]:
        strategy = self.strategy
        if strategy is None:
            strategy = self.strategy = DefaultDecomposerStrategy()
        # The strategy returns this placeholder when there are no entries.
        options = {"empty_message": self.empty_message}
        return strategy.decompose(entries, options)
class MultiElementDecomposer(MemoryDecomposer):
    """Decompose memory entries into structured messages via ``MultiElementDecomposerStrategy``."""

    logger = get_logger("MultiElementDecomposer")

    def __init__(self):
        # Created on first use.
        self.strategy = None

    def decompose(self, entries: List[MemoryEntry]) -> List[Union[IMMessage, LLMChatMessage, Message, str]]:
        strategy = self.strategy
        if strategy is None:
            strategy = self.strategy = MultiElementDecomposerStrategy()
        # The class-level logger is handed to the strategy through its context.
        options = {"logger": self.logger}
        return strategy.decompose(entries, options)
================================================
FILE: kirara_ai/memory/composes/composer_strategy.py
================================================
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type
import kirara_ai.llm.format.tool as tools
from kirara_ai.im.message import IMMessage, MediaMessage, TextMessage
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format.message import (LLMChatImageContent, LLMChatMessage, LLMChatTextContent, LLMToolCallContent,
LLMToolResultContent)
from kirara_ai.media.manager import MediaManager
from .xml_helper import XMLHelper
def drop_think_part(text: str) -> str:
    """Remove a leading ``<think>...</think>`` reasoning block from model output.

    Only a block at the very start of ``text`` (optionally preceded by
    whitespace) is removed, presumably the reasoning preamble some models emit
    before the answer. Everything after the first closing ``</think>`` is
    returned unchanged, as is text that has no leading (or no terminated)
    reasoning block.

    Args:
        text: Raw text content produced by the model.

    Returns:
        ``text`` without its leading reasoning block.
    """
    import re
    # ``[\s\S]*?`` is lazy, so the match stops at the FIRST ``</think>``.
    return re.sub(r"\A\s*<think>[\s\S]*?</think>", "", text, count=1)
class MessageProcessor(ABC):
    """Base strategy that renders one kind of message into memory text."""

    def __init__(self, container: DependencyContainer):
        # Kept so that concrete processors can resolve services when rendering.
        self.container = container

    @abstractmethod
    def process(self, message: Any, context: Dict) -> str:
        """Render ``message`` as text, recording any side data into ``context``."""
        ...
class TextMessageProcessor(MessageProcessor):
    """Render an IM text element as a single line of plain text."""

    def process(self, message: TextMessage, context: Dict) -> str:
        plain_text = message.to_plain()
        return "{}\n".format(plain_text)
class MediaMessageProcessor(MessageProcessor):
    """Render an IM media element as a self-closing ``media_msg`` tag.

    The media id is also appended to ``context["media_ids"]`` so the composer
    can keep it in the entry metadata.
    """

    def process(self, message: MediaMessage, context: Dict) -> str:
        collected_ids = context.setdefault("media_ids", [])
        collected_ids.append(message.media_id)
        attributes = {"id": message.media_id, "desc": message.get_description()}
        return XMLHelper.create_xml_tag("media_msg", attributes) + "\n"
class LLMChatTextContentProcessor(MessageProcessor):
    """Render LLM text content after passing it through ``drop_think_part``."""

    def process(self, content: LLMChatTextContent, context: Dict) -> str:
        visible_text = drop_think_part(content.text)
        return "{}\n".format(visible_text)
class LLMChatImageContentProcessor(MessageProcessor):
    """Render LLM image content as a ``media_msg`` tag.

    The description is looked up through ``MediaManager``. It is empty when the
    media cannot be found or has no description.
    """

    def process(self, content: LLMChatImageContent, context: Dict) -> str:
        media_id = content.media_id
        context.setdefault("media_ids", []).append(media_id)

        manager = self.container.resolve(MediaManager)
        media = manager.get_media(media_id)
        if media:
            description = media.description or ""
        else:
            description = ""

        tag = XMLHelper.create_xml_tag("media_msg", {"id": media_id, "desc": description})
        return tag + "\n"
class LLMToolCallContentProcessor(MessageProcessor):
    """Render an LLM tool call as a short ``function_call`` reference tag.

    The complete call (``content.model_dump()``), including its possibly long
    parameters, goes to ``context["tool_calls"]`` rather than into the text.
    """

    def process(self, content: LLMToolCallContent, context: Dict) -> str:
        recorded_calls = context.setdefault("tool_calls", [])
        recorded_calls.append(content.model_dump())
        reference = XMLHelper.create_xml_tag(
            "function_call",
            {"id": content.id, "name": content.name},
        )
        return reference + "\n"
class LLMToolResultContentProcessor(MessageProcessor):
    """Render an LLM tool result as a short ``tool_result`` reference tag.

    The result payload goes to ``context["tool_results"]``. Text items are
    kept inline; media items are kept by id, and those ids are also appended
    to ``context["media_ids"]``. Items of any other type are omitted.
    """

    def process(self, content: LLMToolResultContent, context: Dict) -> str:
        results_bucket = context.setdefault("tool_results", [])

        serialized_items = []
        for item in content.content:
            if isinstance(item, tools.TextContent):
                serialized_items.append({"type": "text", "text": item.text})
                continue
            if isinstance(item, tools.MediaContent):
                # Register the media reference alongside the other media ids.
                context.setdefault("media_ids", []).append(item.media_id)
                serialized_items.append({"type": "media", "media_id": item.media_id})

        results_bucket.append({
            "id": content.id,
            "name": content.name,
            "isError": content.isError,
            "content": serialized_items,
        })

        reference = XMLHelper.create_xml_tag(
            "tool_result",
            {"id": content.id, "name": content.name, "isError": str(content.isError)},
        )
        return reference + "\n"
class IMMessageProcessor(MessageProcessor):
    """Render an IM message as a ``"<name> 说: "`` header plus one chunk per element."""

    def __init__(self, container: DependencyContainer):
        super().__init__(container)
        self.element_processors: Dict[Type, MessageProcessor] = {
            TextMessage: TextMessageProcessor(container),
            MediaMessage: MediaMessageProcessor(container)
        }

    def _render_element(self, element: Any, context: Dict) -> str:
        """Render one element with the first matching processor, else as its plain text."""
        handler = next(
            (proc for kind, proc in self.element_processors.items() if isinstance(element, kind)),
            None,
        )
        if handler is None:
            return f"{element.to_plain()}\n"
        return handler.process(element, context)

    def process(self, message: IMMessage, context: Dict) -> str:
        pieces = [f"{message.sender.display_name} 说: \n"]
        pieces.extend(self._render_element(element, context) for element in message.message_elements)
        return "".join(pieces)
class LLMChatMessageProcessor(MessageProcessor):
    """Render an LLM chat message as memory text.

    Tool calls and tool results are emitted first as reference tags. All other
    recognised content (text, images) is gathered under a ``"你回答: "`` header,
    which is omitted when that content renders to nothing but newlines. Content
    parts of unregistered types are skipped.
    """

    def __init__(self, container: DependencyContainer):
        super().__init__(container)
        self.content_processors: Dict[Type, MessageProcessor] = {
            LLMChatTextContent: LLMChatTextContentProcessor(container),
            LLMChatImageContent: LLMChatImageContentProcessor(container),
            LLMToolCallContent: LLMToolCallContentProcessor(container),
            LLMToolResultContent: LLMToolResultContentProcessor(container)
        }

    def _find_processor(self, part: Any) -> Optional[MessageProcessor]:
        """Return the first processor registered for ``part``'s type (or a base of it)."""
        for processor_type, processor in self.content_processors.items():
            if isinstance(part, processor_type):
                return processor
        return None

    def process(self, message: LLMChatMessage, context: Dict) -> str:
        tool_chunks = []
        answer_chunks = []
        for part in message.content:
            processor = self._find_processor(part)
            if processor is None:
                continue
            # Each part is rendered exactly once, by the first matching processor,
            # so its side data (media ids, tool calls) is never recorded twice.
            rendered = processor.process(part, context)
            # isinstance keeps subclasses of the tool content types on the tool path.
            if isinstance(part, (LLMToolCallContent, LLMToolResultContent)):
                tool_chunks.append(rendered)
            else:
                answer_chunks.append(rendered)

        result = "".join(tool_chunks)
        answer = "".join(answer_chunks)
        if answer.strip("\n"):
            result += f"你回答: \n{answer}"
        return result
class ProcessorFactory:
    """Pick the processor that renders a given top-level message type."""

    def __init__(self, container: DependencyContainer):
        self.container = container
        self.processors: Dict[Type, MessageProcessor] = {
            IMMessage: IMMessageProcessor(container),
            LLMChatMessage: LLMChatMessageProcessor(container)
        }

    def get_processor(self, message_type: Type) -> Optional[MessageProcessor]:
        """Return the processor registered for ``message_type`` or one of its base classes, else None."""
        return next(
            (proc for base, proc in self.processors.items() if issubclass(message_type, base)),
            None,
        )
================================================
FILE: kirara_ai/memory/composes/decomposer_strategy.py
================================================
from datetime import datetime, timedelta
from typing import Any, Dict, List, NamedTuple, Protocol, cast
from kirara_ai.llm.format.message import (LLMChatContentPartType, LLMChatImageContent, LLMChatMessage,
LLMChatTextContent, LLMToolCallContent, LLMToolResultContent, RoleType)
from kirara_ai.logger import get_logger
from kirara_ai.memory.entry import MemoryEntry
from .base import ComposableMessageType
from .xml_helper import XMLHelper
# Module-level logger for the decomposer strategies.
# NOTE(review): not referenced anywhere else in this module — confirm whether it is still needed.
logger = get_logger("DecomposerStrategy")
class ContentInfo(NamedTuple):
    """One parsed span of a memory entry's content."""
    content_type: str  # Kind of content: "text", "media", "tool_call" or "tool_result"
    start: int  # Start offset of the span in the parsed string
    end: int  # End offset (exclusive; spans are taken as content[start:end])
    text: str  # Source text of the span (stripped for plain-text spans)
    # Extra data for the span (e.g. the media id, or the stored tool call/result dict).
    # NOTE(review): NamedTuple defaults are shared, so every instance created
    # without metadata refers to the SAME dict — never mutate it in place.
    metadata: Dict[str, Any] = {}
class ContentParseStrategy(Protocol):
    """Structural interface for one kind of content embedded in memory text.

    An implementation locates its spans in a string, then converts each span
    either into an LLM content part or into its text rendering.
    """

    def extract_content(self, content: str, entry: MemoryEntry) -> List[ContentInfo]:
        """Return every span of this kind found in ``content``."""
        ...

    def to_llm_content(self, info: ContentInfo) -> LLMChatContentPartType:
        """Convert a span into an LLM chat content part."""
        ...

    def to_text(self, info: ContentInfo) -> str:
        """Convert a span into its plain-text rendering."""
        ...
class TextContentStrategy:
    """Extract the plain-text runs that lie outside the known XML tags."""

    def extract_content(self, content: str, entry: MemoryEntry) -> List[ContentInfo]:
        # Spans of every known tag, in position order.
        tag_spans = sorted(
            (tag_start, tag_end)
            for tag_name in ("media_msg", "function_call", "tool_result")
            for _, tag_start, tag_end in XMLHelper.parse_xml_tag(content, tag_name)
        )

        runs: List[ContentInfo] = []

        def emit(lo: int, hi: int) -> None:
            # Record content[lo:hi] as a text run unless it is blank.
            piece = content[lo:hi].strip()
            if piece:
                runs.append(ContentInfo(content_type="text", start=lo, end=hi, text=piece))

        cursor = 0
        for tag_start, tag_end in tag_spans:
            if tag_start > cursor:
                emit(cursor, tag_start)
            cursor = tag_end
        # Whatever follows the last tag.
        if cursor < len(content):
            emit(cursor, len(content))
        return runs

    def to_llm_content(self, info: ContentInfo) -> LLMChatContentPartType:
        """Wrap the run in a text content part."""
        return LLMChatTextContent(text=info.text)

    def to_text(self, info: ContentInfo) -> str:
        """A text run renders as itself."""
        return info.text
class MediaContentStrategy:
    """Parse ``media_msg`` tags whose media id is recorded on the entry."""

    def extract_content(self, content: str, entry: MemoryEntry) -> List[ContentInfo]:
        """Return one "media" span per ``media_msg`` tag with a known id.

        A tag counts only when its ``id`` attribute is listed in the entry's
        ``_media_ids`` metadata. A missing or None ``_media_ids`` means no tag counts.
        """
        known_media_ids = entry.metadata.get("_media_ids") or []
        media_parts: List[ContentInfo] = []
        for attrs, start, end in XMLHelper.parse_xml_tag(content, "media_msg"):
            media_id = attrs.get("id")
            if media_id is None or media_id not in known_media_ids:
                continue
            media_parts.append(ContentInfo(
                content_type="media",
                start=start,
                end=end,
                text=content[start:end],
                metadata={"media_id": media_id},
            ))
        return media_parts

    def to_llm_content(self, info: ContentInfo) -> LLMChatContentPartType:
        """Turn the span into an image content part that references the stored media."""
        return LLMChatImageContent(media_id=info.metadata["media_id"])

    def to_text(self, info: ContentInfo) -> str:
        """Render the span as its original ``media_msg`` tag.

        The tag carries the media id and description, so text transcripts
        still show that the message contained media instead of dropping it.
        """
        return info.text
class ToolCallContentStrategy:
    """Parse ``function_call`` tags and restore the tool calls stored in metadata."""

    def extract_content(self, content: str, entry: MemoryEntry) -> List[ContentInfo]:
        """Return one "tool_call" span per ``function_call`` tag whose id is in ``_tool_calls``.

        The span's metadata is the stored call dict itself.
        """
        stored_calls = entry.metadata.get("_tool_calls")
        if not stored_calls:
            return []

        # Index the stored calls by id; the first call with a given id wins.
        calls_by_id: Dict[Any, Dict[str, Any]] = {}
        for call in stored_calls:
            calls_by_id.setdefault(call.get("id"), call)

        tool_call_parts: List[ContentInfo] = []
        for attrs, start, end in XMLHelper.parse_xml_tag(content, "function_call"):
            call_id = attrs.get("id")
            if call_id is None or call_id not in calls_by_id:
                continue
            tool_call_parts.append(ContentInfo(
                content_type="tool_call",
                start=start,
                end=end,
                text=content[start:end],
                metadata=calls_by_id[call_id],
            ))
        return tool_call_parts

    def to_llm_content(self, info: ContentInfo) -> LLMChatContentPartType:
        """Rebuild the tool call content part from the stored call dict."""
        return LLMToolCallContent.model_validate(info.metadata)

    def to_text(self, info: ContentInfo) -> str:
        """Render the span as its original ``function_call`` reference tag (id and name)."""
        return info.text
class ToolResultContentStrategy:
    """Parse ``tool_result`` tags and restore the tool results stored in metadata."""

    def extract_content(self, content: str, entry: MemoryEntry) -> List[ContentInfo]:
        """Return one "tool_result" span per ``tool_result`` tag whose id is in ``_tool_results``.

        The span's metadata is the stored result dict itself.
        """
        stored_results = entry.metadata.get("_tool_results")
        if not stored_results:
            return []

        # Index the stored results by id; the first result with a given id wins.
        results_by_id: Dict[Any, Dict[str, Any]] = {}
        for result in stored_results:
            results_by_id.setdefault(result.get("id"), result)

        tool_result_parts: List[ContentInfo] = []
        for attrs, start, end in XMLHelper.parse_xml_tag(content, "tool_result"):
            result_id = attrs.get("id")
            if result_id is None or result_id not in results_by_id:
                continue
            tool_result_parts.append(ContentInfo(
                content_type="tool_result",
                start=start,
                end=end,
                text=content[start:end],
                metadata=results_by_id[result_id],
            ))
        return tool_result_parts

    def to_llm_content(self, info: ContentInfo) -> LLMChatContentPartType:
        """Rebuild the tool result content part from the stored result dict."""
        return LLMToolResultContent.model_validate(info.metadata)

    def to_text(self, info: ContentInfo) -> str:
        """Render the span as its original ``tool_result`` reference tag (id, name, isError)."""
        return info.text
class ContentParser:
    """Split memory content into typed spans and rebuild messages or text from them."""

    def __init__(self):
        self.strategies: Dict[str, ContentParseStrategy] = {
            "text": TextContentStrategy(),
            "media": MediaContentStrategy(),
            "tool_call": ToolCallContentStrategy(),
            "tool_result": ToolResultContentStrategy()
        }

    def parse_content(self, content: str, entry: MemoryEntry) -> List[ContentInfo]:
        """Collect the spans found by every strategy, ordered by start offset."""
        spans = [
            info
            for strategy in self.strategies.values()
            for info in strategy.extract_content(content, entry)
        ]
        spans.sort(key=lambda info: info.start)
        return spans

    def to_llm_message(self, content_infos: List[ContentInfo], role: RoleType) -> List[LLMChatMessage]:
        """Turn spans into LLM messages.

        Consecutive ordinary spans (text, media) are grouped into one message
        with ``role``. Every tool call becomes its own "assistant" message and
        every tool result its own "tool" message.
        """
        if not content_infos:
            return []

        # Content kinds that always get a standalone message with a fixed role.
        standalone_roles: Dict[str, RoleType] = {"tool_call": "assistant", "tool_result": "tool"}

        messages: List[LLMChatMessage] = []
        pending: List[LLMChatContentPartType] = []

        def flush_pending() -> None:
            nonlocal pending
            if pending:
                messages.append(LLMChatMessage(role=role, content=pending))
                pending = []

        for info in content_infos:
            strategy = self.strategies.get(info.content_type)
            if not strategy:
                continue
            fixed_role = standalone_roles.get(info.content_type)
            if fixed_role is None:
                pending.append(strategy.to_llm_content(info))
                continue
            flush_pending()
            messages.append(LLMChatMessage(role=fixed_role, content=[strategy.to_llm_content(info)]))

        flush_pending()
        return messages

    def to_text(self, content_infos: List[ContentInfo]) -> str:
        """Concatenate the text rendering of every span that has a known strategy."""
        return "".join(
            self.strategies[info.content_type].to_text(info)
            for info in content_infos
            if info.content_type in self.strategies
        )
class DefaultDecomposerStrategy:
    """Render memory entries as text lines prefixed with a relative time.

    Each line has the form ``"<relative time>,<text>"``. The text is the
    entry content with its XML tags rendered through ``ContentParser.to_text``,
    and the assistant part is marked with ``"你回答: "``.
    """

    # Only this many entries (the last ones in the given list) are rendered.
    _MAX_ENTRIES = 10

    def __init__(self):
        self.content_parser = ContentParser()

    def decompose(self, entries: List[MemoryEntry], context: Dict[str, Any]) -> List[ComposableMessageType]:
        """Convert ``entries`` into text lines.

        Args:
            entries: Memory entries (presumably oldest first); only the last
                ``_MAX_ENTRIES`` are rendered.
            context: Options. ``"empty_message"`` is returned as the single item
                when ``entries`` is empty (default ``"<空记忆>"``).

        Returns:
            One string per rendered entry.
        """
        if not entries:
            return [context.get("empty_message", "<空记忆>")]

        # One reference time for the whole batch keeps the labels consistent.
        now = datetime.now()
        result: List[ComposableMessageType] = []
        for entry in entries[-self._MAX_ENTRIES:]:
            time_str = self._get_time_str(now - entry.timestamp)
            result.append(f"{time_str},{self._render_entry(entry)}")
        return result

    def _render_entry(self, entry: MemoryEntry) -> str:
        """Render one entry's content, keeping the split between user text and the "你回答:" part."""
        content = entry.content or ""
        if not content:
            return ""
        if "你回答:" not in content:
            # Only user-side content.
            return self.content_parser.to_text(self.content_parser.parse_content(content, entry))

        user_content, assistant_content = (part.strip() for part in content.split("你回答:", 1))
        message_parts: List[str] = []
        if user_content:
            message_parts.append(
                self.content_parser.to_text(self.content_parser.parse_content(user_content, entry))
            )
        if assistant_content:
            assistant_text = self.content_parser.to_text(
                self.content_parser.parse_content(assistant_content, entry)
            )
            message_parts.append(f"你回答: {assistant_text}")
        return "".join(message_parts)

    def _get_time_str(self, time_diff: timedelta) -> str:
        """Return a coarse, human-readable "time ago" label for ``time_diff``.

        Differences under one minute, including negative ones (timestamps
        slightly in the future, e.g. from clock skew), are shown as "刚刚" (just now).
        """
        total_seconds = time_diff.total_seconds()
        if total_seconds < 60:
            return "刚刚"
        if time_diff.days > 0:
            return f"{time_diff.days}天前"
        if total_seconds >= 3600:
            return f"{int(total_seconds) // 3600}小时前"
        return f"{int(total_seconds) // 60}分钟前"
class MultiElementDecomposerStrategy:
    """Rebuild structured ``LLMChatMessage`` objects (text, media, tool calls/results) from memory entries."""

    def __init__(self):
        self.content_parser = ContentParser()

    def decompose(self, entries: List[MemoryEntry], context: Dict[str, Any]) -> List[ComposableMessageType]:
        """Convert every entry into messages, then merge neighbouring same-role user/assistant messages."""
        messages: List[LLMChatMessage] = [
            message
            for entry in entries
            for message in self._process_entry(entry)
        ]
        self._merge_adjacent_messages(messages)
        return cast(List[ComposableMessageType], messages)

    def _to_messages(self, text: str, entry: MemoryEntry, role: RoleType) -> List[LLMChatMessage]:
        """Parse ``text`` and turn the resulting spans into messages with ``role``."""
        infos = self.content_parser.parse_content(text, entry)
        return self.content_parser.to_llm_message(infos, role)

    def _process_entry(self, entry: MemoryEntry) -> List[LLMChatMessage]:
        """Convert one entry into messages in content order.

        Text before "你回答:" is treated as user content and text after it as
        assistant content. Without the marker, everything is user content.
        """
        content = entry.content or ""
        if not content:
            return []

        if "你回答:" not in content:
            return list(self._to_messages(content, entry, "user"))

        user_part, assistant_part = (piece.strip() for piece in content.split("你回答:", 1))
        produced: List[LLMChatMessage] = []
        if user_part:
            produced.extend(self._to_messages(user_part, entry, "user"))
        if assistant_part:
            produced.extend(self._to_messages(assistant_part, entry, "assistant"))
        return produced

    def _merge_adjacent_messages(self, messages: List[LLMChatMessage]) -> None:
        """Merge neighbouring messages that share a role, in place.

        Only "user" and "assistant" messages are merged. The later message's
        content is appended to the earlier message's content list.
        """
        merged: List[LLMChatMessage] = []
        for message in messages:
            previous = merged[-1] if merged else None
            if (
                previous is not None
                and previous.role == message.role
                and previous.role in ["user", "assistant"]
            ):
                previous.content.extend(message.content)
            else:
                merged.append(message)
        messages[:] = merged
================================================
FILE: kirara_ai/memory/composes/xml_helper.py
================================================
import re
from typing import Dict, List, Optional, Tuple
class XMLHelper:
    """Build and parse the self-closing XML tags embedded in memory text."""

    # Attribute-value escapes. ``str.translate`` applies them in one pass, so an
    # "&" produced by another escape is never escaped again.
    _ESCAPE_TABLE = str.maketrans({
        "&": "&amp;",
        "\"": "&quot;",
        "<": "&lt;",
        ">": "&gt;",
    })
    # name="value" pairs inside a tag. Escaped values contain no raw double quote.
    _ATTR_PATTERN = re.compile(r'(\w+)="(.*?)"')

    @staticmethod
    def escape_xml_attr(text: str) -> str:
        """Escape ``&``, ``"``, ``<`` and ``>`` so ``text`` can sit inside a double-quoted attribute.

        Non-string values are converted with ``str()`` first. Escaping ``>``
        also ensures a value can never contain ``/>`` and end a tag early.
        """
        if not isinstance(text, str):
            text = str(text)
        return text.translate(XMLHelper._ESCAPE_TABLE)

    @staticmethod
    def unescape_xml_attr(text: str) -> str:
        """Reverse ``escape_xml_attr``.

        ``&amp;`` is decoded last, so an escaped entity such as ``&amp;lt;``
        becomes the literal text ``&lt;`` rather than ``<``.
        """
        return (
            text.replace("&quot;", "\"")
            .replace("&lt;", "<")
            .replace("&gt;", ">")
            .replace("&amp;", "&")
        )

    @staticmethod
    def create_xml_tag(tag_name: str, attributes: Dict[str, Optional[str]], self_closing: bool = True) -> str:
        """Build ``<tag_name k="v" ... />`` (or an opening tag when ``self_closing`` is False).

        Attributes whose value is None are omitted. Values are escaped with
        ``escape_xml_attr``.
        """
        attrs_str = " ".join(
            f'{key}="{XMLHelper.escape_xml_attr(value)}"'
            for key, value in attributes.items()
            if value is not None
        )
        if self_closing:
            return f"<{tag_name} {attrs_str} />"
        return f"<{tag_name} {attrs_str}>"

    @staticmethod
    def parse_xml_tag(content: str, tag_name: str) -> List[Tuple[Dict[str, Optional[str]], int, int]]:
        """Find every self-closing ``tag_name`` tag in ``content``.

        Returns:
            One ``(attributes, start, end)`` tuple per tag, in order of
            appearance, where ``content[start:end]`` is the tag text and the
            attribute values are unescaped. Attributes absent from a tag are
            simply missing from its dict; use ``get_attr`` / ``dict.get`` to
            read them with a default.
        """
        pattern = re.compile(rf"<{re.escape(tag_name)}\s+(.*?)\s*/>")
        results: List[Tuple[Dict[str, Optional[str]], int, int]] = []
        for match in pattern.finditer(content):
            attrs: Dict[str, Optional[str]] = {
                name: XMLHelper.unescape_xml_attr(value)
                for name, value in XMLHelper._ATTR_PATTERN.findall(match.group(1))
            }
            results.append((attrs, match.start(), match.end()))
        return results

    @staticmethod
    def get_attr(attrs: Dict[str, Optional[str]], key: str, default: Optional[str] = None) -> Optional[str]:
        """Return ``attrs[key]`` when present, otherwise ``default``."""
        return attrs.get(key, default)
================================================
FILE: kirara_ai/memory/entry.py
================================================
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict
from kirara_ai.im.sender import ChatSender
@dataclass
class MemoryEntry:
    """A single stored memory record."""
    # Who the remembered message(s) came from.
    sender: ChatSender
    # Text of the memory, e.g. as rendered by a MemoryComposer.
    content: str
    # Creation time; defaults to the local (naive) time when the entry is built.
    timestamp: datetime = field(default_factory=datetime.now)
    # Side data; the built-in composer uses "_media_ids", "_tool_calls" and "_tool_results".
    metadata: Dict[str, Any] = field(default_factory=dict)
================================================
FILE: kirara_ai/memory/memory_manager.py
================================================
from typing import Dict, List, Optional, Type
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.ioc.inject import Inject
from kirara_ai.media.carrier import MediaReferenceProvider
from kirara_ai.media.carrier.service import MediaCarrierService
from kirara_ai.memory.persistences.base import AsyncMemoryPersistence, MemoryPersistence
from kirara_ai.memory.persistences.file_persistence import FileMemoryPersistence
from kirara_ai.memory.persistences.redis_persistence import RedisMemoryPersistence
from .composes import MemoryComposer, MemoryDecomposer
from .entry import MemoryEntry
from .registry import ComposerRegistry, DecomposerRegistry, ScopeRegistry
from .scopes import MemoryScope
class MemoryManager(MediaReferenceProvider[List[MemoryEntry]]):
    """Central manager for the memory subsystem.

    It owns the scope/composer/decomposer registries, the persistence layer
    and an in-memory cache (``self.memories``) of entries keyed by scope key.
    Media ids listed in entry metadata (``_media_ids``) are registered with the
    ``MediaCarrierService`` under the ``"memory"`` provider, keyed by scope key.
    """

    def __init__(
        self,
        container: DependencyContainer,
        persistence: Optional[MemoryPersistence] = None,
    ):
        """Create and register the registries, then set up persistence.

        Args:
            container: Dependency container; must already provide ``GlobalConfig``.
            persistence: Backend to use instead of the one described by the
                memory config. It is used as-is (not wrapped in
                ``AsyncMemoryPersistence``).
        """
        self.container = container
        self.config = container.resolve(GlobalConfig).memory

        # Registries are created through Inject (presumably so their constructor receives the container).
        self.scope_registry = Inject(container).create(ScopeRegistry)()
        self.composer_registry = Inject(container).create(ComposerRegistry)()
        self.decomposer_registry = Inject(container).create(DecomposerRegistry)()

        # Make the registries resolvable from the container.
        container.register(ScopeRegistry, self.scope_registry)
        container.register(ComposerRegistry, self.composer_registry)
        container.register(DecomposerRegistry, self.decomposer_registry)

        if persistence is None:
            self._init_persistence()
        else:
            self.persistence = persistence

        # Per-scope-key cache of entries, filled lazily from persistence.
        self.memories: Dict[str, List[MemoryEntry]] = {}

    def _init_persistence(self):
        """Build the backend named by ``config.persistence.type`` and wrap it for background saves.

        Raises:
            ValueError: If the configured type is neither "file" nor "redis".
        """
        persistence_type = self.config.persistence.type
        if persistence_type == "file":
            storage_dir = self.config.persistence.file["storage_dir"]
            self.persistence = FileMemoryPersistence(storage_dir)
        elif persistence_type == "redis":
            redis_config = self.config.persistence.redis
            self.persistence = RedisMemoryPersistence(**redis_config)
        else:
            raise ValueError(f"Unsupported persistence type: {persistence_type}")
        self.persistence = AsyncMemoryPersistence(self.persistence)

    def register_scope(self, name: str, scope_class: Type[MemoryScope]):
        """Register a new scope type under ``name``."""
        self.scope_registry.register(name, scope_class)

    def register_composer(self, name: str, composer_class: Type[MemoryComposer]):
        """Register a new composer under ``name``."""
        self.composer_registry.register(name, composer_class)

    def register_decomposer(self, name: str, decomposer_class: Type[MemoryDecomposer]):
        """Register a new decomposer under ``name``."""
        self.decomposer_registry.register(name, decomposer_class)

    def _resolve_scope_key(
        self, scope: MemoryScope, sender: ChatSender, extra_identifier: Optional[str]
    ) -> str:
        """Return the storage key for ``sender`` in ``scope``, prefixed with ``extra_identifier`` when given."""
        scope_key = scope.get_scope_key(sender)
        if extra_identifier is not None:
            scope_key = f"{extra_identifier}-{scope_key}"
        return scope_key

    def _get_entries(self, scope_key: str) -> List[MemoryEntry]:
        """Return the cached entries for ``scope_key``, loading them from persistence on first access."""
        if scope_key not in self.memories:
            self.memories[scope_key] = self.persistence.load(scope_key)
        return self.memories[scope_key]

    def store(self, scope: MemoryScope, entry: MemoryEntry, extra_identifier: Optional[str] = None) -> None:
        """Append ``entry`` to its sender's scope, trim to ``config.max_entries`` and persist.

        Args:
            scope: Scope that maps the entry's sender to a storage key.
            entry: The entry to store.
            extra_identifier: Optional namespace prefixed to the storage key.
        """
        scope_key = self._resolve_scope_key(scope, entry.sender, extra_identifier)
        entries = self._get_entries(scope_key)
        entries.append(entry)
        self._register_media_reference(entry, scope_key)

        # Computing the cut explicitly also trims correctly when max_entries is 0
        # (a negative-index slice like [-0:] would keep everything).
        overflow = len(entries) - self.config.max_entries
        if overflow > 0:
            removed_entries = entries[:overflow]
            kept_entries = entries[overflow:]
            # Release media only referenced by the dropped entries.
            self._remove_media_references(removed_entries, kept_entries, scope_key)
            self.memories[scope_key] = kept_entries

        self.persistence.save(scope_key, self.memories[scope_key])

    def query(self, scope: MemoryScope, sender: ChatSender, extra_identifier: Optional[str] = None) -> List[MemoryEntry]:
        """Return every cached entry that ``scope`` considers visible to ``sender``, oldest first."""
        # Make sure the requester's own scope is loaded into the cache.
        self._get_entries(self._resolve_scope_key(scope, sender, extra_identifier))

        # NOTE(review): this scans every *cached* scope key, not only the one
        # resolved above. Entries stored under other extra_identifiers (or none)
        # are included when the scope accepts their sender, and keys that have
        # not been loaded yet are never consulted. Confirm this is intended.
        relevant_memories = [
            entry
            for cached_entries in self.memories.values()
            for entry in cached_entries
            if scope.is_in_scope(entry.sender, sender)
        ]
        relevant_memories.sort(key=lambda x: x.timestamp)
        return relevant_memories

    def shutdown(self):
        """Persist every cached scope, then stop the background persistence worker if there is one."""
        for scope_key, entries in self.memories.items():
            self.persistence.save(scope_key, entries)
        if isinstance(self.persistence, AsyncMemoryPersistence):
            self.persistence.stop()

    def clear_memory(self, scope: MemoryScope, sender: ChatSender, extra_identifier: Optional[str] = None) -> None:
        """Remove all memory stored for ``sender`` in ``scope``.

        Works whether or not the scope has been loaded into the cache. When it
        has not (e.g. right after a restart), the persisted entries are loaded
        first so that their media references are released and storage is cleared.

        Args:
            scope: Memory scope.
            sender: Sender whose scope key is cleared.
            extra_identifier: Optional namespace, matching the one used in ``store``.
        """
        scope_key = self._resolve_scope_key(scope, sender, extra_identifier)
        entries = self.memories.get(scope_key)
        if entries is None:
            entries = self.persistence.load(scope_key)
            if not entries:
                # Nothing cached and nothing persisted: nothing to clear.
                return

        self._remove_media_references(entries, [], scope_key)
        self.memories[scope_key] = []
        self.persistence.save(scope_key, [])

    def get_reference_owner(self, reference_key: str) -> Optional[List[MemoryEntry]]:
        """Return the entries stored under ``reference_key`` (the scope key used when registering media)."""
        return self._get_entries(reference_key)

    def _register_media_reference(self, entry: MemoryEntry, reference_key: str) -> None:
        """Register every media id of ``entry`` as referenced by ``reference_key``."""
        media_carrier = self.container.resolve(MediaCarrierService)
        for media_id in entry.metadata.get("_media_ids", []):
            media_carrier.register_reference(media_id, "memory", reference_key)

    def _remove_media_references(self, removed_entries: List[MemoryEntry], unremoved_entries: List[MemoryEntry], reference_key: str) -> None:
        """Drop references for media used by ``removed_entries`` but not by ``unremoved_entries``."""
        media_carrier = self.container.resolve(MediaCarrierService)
        still_used = {
            media_id
            for entry in unremoved_entries
            for media_id in entry.metadata.get("_media_ids", [])
        }
        for entry in removed_entries:
            for media_id in entry.metadata.get("_media_ids", []):
                if media_id not in still_used:
                    media_carrier.remove_reference(media_id, "memory", reference_key)
================================================
FILE: kirara_ai/memory/persistences/__init__.py
================================================
from .base import AsyncMemoryPersistence, MemoryPersistence
from .file_persistence import FileMemoryPersistence
from .redis_persistence import RedisMemoryPersistence
# Public API of the persistence package. "codecs" names the JSON
# encoder/decoder submodule. It is not imported above, but it becomes a package
# attribute once the file/redis backends import it.
__all__ = [
    "MemoryPersistence",
    "AsyncMemoryPersistence",
    "FileMemoryPersistence",
    "RedisMemoryPersistence",
    "codecs",
]
================================================
FILE: kirara_ai/memory/persistences/base.py
================================================
import threading
from abc import ABC, abstractmethod
from queue import Empty, Queue
from typing import List, Tuple
from kirara_ai.logger import get_logger
from kirara_ai.memory.entry import MemoryEntry
class MemoryPersistence(ABC):
    """Abstract storage backend holding a list of memory entries per scope key."""

    @abstractmethod
    def load(self, scope_key: str) -> List[MemoryEntry]:
        """Return the entries stored under ``scope_key``.

        The built-in backends return an empty list when nothing is stored.
        """
        pass

    @abstractmethod
    def save(self, scope_key: str, entries: List[MemoryEntry]) -> None:
        """Replace whatever is stored under ``scope_key`` with ``entries``."""
        pass

    @abstractmethod
    def flush(self) -> None:
        """Make sure everything saved so far has reached durable storage."""
# Module-level logger, used by AsyncMemoryPersistence's background worker.
logger = get_logger("MemoryPersistence")
class AsyncMemoryPersistence:
    """Write-behind wrapper that performs ``save`` calls on a background thread.

    ``load`` is delegated synchronously. ``save`` only enqueues the request; a
    daemon worker thread takes items off the queue and forwards them to the
    wrapped persistence in FIFO order.
    """

    def __init__(self, persistence: MemoryPersistence):
        self.persistence = persistence
        self.queue: Queue[Tuple[str, List[MemoryEntry]]] = Queue()
        self.running = True
        self.worker = threading.Thread(target=self._worker, daemon=True)
        self.worker.start()

    def _worker(self):
        """Save queued items until ``running`` is cleared; failures are logged and skipped."""
        while self.running:
            try:
                # The timeout lets the loop notice ``running`` becoming False.
                scope_key, entries = self.queue.get(timeout=1)
            except Empty:
                continue
            try:
                self.persistence.save(scope_key, entries)
                logger.info(f"Saved {scope_key} with {len(entries)} entries")
            except Exception as e:
                # Keep the worker alive: one failed save must not block later ones.
                logger.error(f"Error saving memory: {e}")
            finally:
                # Always acknowledge the item, so queue.join() in stop() cannot hang.
                self.queue.task_done()

    def load(self, scope_key: str) -> List[MemoryEntry]:
        """Load synchronously from the wrapped persistence."""
        return self.persistence.load(scope_key)

    def save(self, scope_key: str, entries: List[MemoryEntry]):
        """Queue ``entries`` to be saved under ``scope_key`` by the worker thread."""
        self.queue.put((scope_key, entries))

    def stop(self):
        """Write all pending saves, stop the worker, then flush the wrapped persistence.

        Pending items are drained before the worker is told to exit. Saves
        queued just before shutdown (e.g. by ``MemoryManager.shutdown``) are
        therefore not lost.
        """
        if self.worker.is_alive():
            self.queue.join()
        self.running = False
        self.worker.join()
        self.persistence.flush()
================================================
FILE: kirara_ai/memory/persistences/codecs.py
================================================
import json
from datetime import datetime
from types import FunctionType
from kirara_ai.im.sender import ChatSender, ChatType
from kirara_ai.logger import get_logger
class MemoryJSONEncoder(json.JSONEncoder):
    """JSON encoder for the values found in memory entries.

    - ``ChatSender`` becomes a dict tagged ``"__type__": "ChatSender"`` that
      ``memory_json_decoder`` turns back into a ``ChatSender``.
    - ``ChatType`` becomes its value and ``datetime`` its ISO-8601 string.
    - Plain functions become a descriptive dict tagged ``"function"``.
    - Anything else the base encoder rejects is logged and encoded as ``null``.
    """

    def default(self, obj):
        if isinstance(obj, ChatSender):
            return {
                "__type__": "ChatSender",
                "user_id": obj.user_id,
                "chat_type": obj.chat_type.value,
                "group_id": obj.group_id,
                "display_name": obj.display_name,
                "raw_metadata": obj.raw_metadata,
            }
        if isinstance(obj, ChatType):
            return obj.value
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, FunctionType):
            code = obj.__code__
            return {
                "__type__": "function",
                "name": obj.__name__,
                "args": code.co_varnames[:code.co_argcount],
                "defaults": obj.__defaults__,
                "kwdefaults": obj.__kwdefaults__,
                "doc": obj.__doc__,
            }
        try:
            return super().default(obj)
        except Exception as exc:
            get_logger("MemoryJSONEncoder").warning(f"failed to encode object: {exc}")
            return None
def memory_json_decoder(obj):
    """``object_hook`` that rebuilds ``ChatSender`` objects tagged by ``MemoryJSONEncoder``.

    Every other dict, including other ``__type__`` tags such as
    ``"function"``, is returned unchanged.
    """
    if obj.get("__type__") != "ChatSender":
        return obj
    return ChatSender(
        user_id=obj["user_id"],
        chat_type=ChatType(obj["chat_type"]),
        group_id=obj["group_id"],
        display_name=obj["display_name"],
        raw_metadata=obj["raw_metadata"],
    )
================================================
FILE: kirara_ai/memory/persistences/file_persistence.py
================================================
import json
import os
from datetime import datetime
from typing import List
from kirara_ai.memory.entry import MemoryEntry
from .base import MemoryPersistence
from .codecs import MemoryJSONEncoder, memory_json_decoder
class FileMemoryPersistence(MemoryPersistence):
    """Store each scope's memory entries as a JSON file in ``data_dir``."""

    def __init__(self, data_dir: str):
        # Relative paths are resolved against the current working directory.
        if not os.path.isabs(data_dir):
            data_dir = os.path.abspath(data_dir)
        self.data_dir = data_dir
        os.makedirs(data_dir, exist_ok=True)

    def _get_file_path(self, scope_key: str) -> str:
        """Map ``scope_key`` to its JSON file (":" is replaced because it is not allowed in Windows file names)."""
        scope_key = scope_key.replace(":", "_")
        return os.path.join(self.data_dir, f"{scope_key}.json")

    def save(self, scope_key: str, entries: List[MemoryEntry]) -> None:
        """Serialize ``entries`` and atomically replace the scope's JSON file."""
        file_path = self._get_file_path(scope_key)
        serialized_entries = [
            {
                "sender": entry.sender,
                "content": entry.content,
                "timestamp": entry.timestamp,
                "metadata": entry.metadata,
            }
            for entry in entries
        ]

        # Write to a temp file and swap it in with os.replace. An error or crash
        # in the middle of json.dump then cannot leave a truncated file behind,
        # which load() would fail to parse, losing the whole scope.
        tmp_path = f"{file_path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(
                    serialized_entries,
                    f,
                    ensure_ascii=False,
                    indent=2,
                    cls=MemoryJSONEncoder,
                )
            os.replace(tmp_path, file_path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; the original error is re-raised.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load(self, scope_key: str) -> List[MemoryEntry]:
        """Read the scope's JSON file; return an empty list when it does not exist."""
        file_path = self._get_file_path(scope_key)
        if not os.path.exists(file_path):
            return []

        with open(file_path, "r", encoding="utf-8") as f:
            serialized_entries = json.load(f, object_hook=memory_json_decoder)

        return [
            MemoryEntry(
                sender=entry["sender"],
                content=entry["content"],
                timestamp=(
                    datetime.fromisoformat(entry["timestamp"])
                    if isinstance(entry["timestamp"], str)
                    else entry["timestamp"]
                ),
                metadata=entry["metadata"],
            )
            for entry in serialized_entries
        ]

    def flush(self) -> None:
        """No-op: every save() already writes its file before returning."""
        pass
================================================
FILE: kirara_ai/memory/persistences/redis_persistence.py
================================================
import json
from datetime import datetime
from typing import List, Optional
from kirara_ai.memory.entry import MemoryEntry
from .base import MemoryPersistence
from .codecs import MemoryJSONEncoder, memory_json_decoder
class RedisMemoryPersistence(MemoryPersistence):
    """Redis persistence: each scope key maps to one string value holding a JSON array."""

    def __init__(
        self,
        redis_url: Optional[str] = None,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
    ):
        """
        Args:
            redis_url: connection URL; when given, ``host``/``port``/``db`` are ignored.
            host, port, db: connection parameters used when no URL is supplied.
        """
        # imported lazily so the redis package is only needed when this backend is used
        import redis

        self.redis = (
            redis.from_url(redis_url)
            if redis_url
            else redis.Redis(host=host, port=port, db=db)
        )

    @staticmethod
    def _as_datetime(value):
        """Parse ISO strings into datetimes; pass any other value through unchanged."""
        return datetime.fromisoformat(value) if isinstance(value, str) else value

    def save(self, scope_key: str, entries: List[MemoryEntry]) -> None:
        """Overwrite the value stored under ``scope_key`` with ``entries`` as JSON."""
        rows = []
        for item in entries:
            rows.append(
                {
                    "sender": item.sender,
                    "content": item.content,
                    "timestamp": item.timestamp,
                    "metadata": item.metadata,
                }
            )
        payload = json.dumps(rows, ensure_ascii=False, cls=MemoryJSONEncoder)
        self.redis.set(scope_key, payload)

    def load(self, scope_key: str) -> List[MemoryEntry]:
        """Read the entries stored under ``scope_key``; ``[]`` when the key is missing or empty."""
        raw = self.redis.get(scope_key)
        if not raw:
            return []

        decoded = json.loads(raw, object_hook=memory_json_decoder)  # type: ignore
        restored: List[MemoryEntry] = []
        for item in decoded:
            restored.append(
                MemoryEntry(
                    sender=item["sender"],
                    content=item["content"],
                    timestamp=self._as_datetime(item["timestamp"]),
                    metadata=item["metadata"],
                )
            )
        return restored

    def flush(self) -> None:
        """Ask the Redis server to persist its dataset (Redis ``SAVE`` command)."""
        self.redis.save()
================================================
FILE: kirara_ai/memory/registry.py
================================================
from typing import Dict, Type
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.ioc.inject import Inject
from kirara_ai.memory.composes import MemoryComposer, MemoryDecomposer
from kirara_ai.memory.scopes import MemoryScope
class Registry:
    """Base registry: maps names to implementation classes.

    Subclasses create instances of the registered classes through the IoC
    container held in ``container``.
    """

    container: DependencyContainer
    _registry: Dict[str, Type] = dict()

    def __init__(self, container: DependencyContainer):
        self.container = container
        # per-instance table that shadows the class-level default above
        self._registry = {}

    def register(self, name: str, cls: Type) -> None:
        """Register ``cls`` under ``name``, replacing any previous entry."""
        self._registry[name] = cls

    def unregister(self, name: str) -> None:
        """Remove the implementation registered as ``name``; unknown names are ignored."""
        self._registry.pop(name, None)
class ScopeRegistry(Registry):
    """Registry of memory scope implementations."""

    def get_scope(self, name: str) -> MemoryScope:
        """Instantiate the scope registered as ``name`` via dependency injection.

        Raises:
            ValueError: if no scope is registered under ``name``.
        """
        if name in self._registry:
            factory = Inject(self.container).create(self._registry[name])
            return factory()
        raise ValueError(f"Scope not found: {name}")
class ComposerRegistry(Registry):
    """Registry of memory composer implementations."""

    def get_composer(self, name: str) -> MemoryComposer:
        """Instantiate the composer registered as ``name`` via dependency injection.

        Raises:
            ValueError: if no composer is registered under ``name``.
        """
        if name in self._registry:
            factory = Inject(self.container).create(self._registry[name])
            return factory()
        raise ValueError(f"Composer not found: {name}")
class DecomposerRegistry(Registry):
    """Registry of memory decomposer implementations."""

    def get_decomposer(self, name: str) -> MemoryDecomposer:
        """Instantiate the decomposer registered as ``name`` via dependency injection.

        Raises:
            ValueError: if no decomposer is registered under ``name``.
        """
        if name in self._registry:
            factory = Inject(self.container).create(self._registry[name])
            return factory()
        raise ValueError(f"Decomposer not found: {name}")
================================================
FILE: kirara_ai/memory/scopes/__init__.py
================================================
from .base import MemoryScope
from .builtin_scopes import GlobalScope, GroupScope, MemberScope
# Public API of the scopes package: the abstract base plus the built-in scopes.
__all__ = ["MemoryScope", "MemberScope", "GroupScope", "GlobalScope"]
================================================
FILE: kirara_ai/memory/scopes/base.py
================================================
from abc import ABC, abstractmethod
from kirara_ai.im.sender import ChatSender
class MemoryScope(ABC):
    """Abstract memory scope.

    A scope maps a ChatSender to the key its memories are stored under and
    decides whether two senders fall into the same scope.
    """

    @abstractmethod
    def get_scope_key(self, sender: ChatSender) -> str:
        """Return the storage key of the scope that ``sender`` belongs to."""

    @abstractmethod
    def is_in_scope(self, target_sender: ChatSender, query_sender: ChatSender) -> bool:
        """Return True if ``query_sender`` is in the same scope as ``target_sender``."""
================================================
FILE: kirara_ai/memory/scopes/builtin_scopes.py
================================================
from kirara_ai.im.sender import ChatSender, ChatType
from .base import MemoryScope
# 默认实现
class MemberScope(MemoryScope):
    """Per-member scope: memory is kept per user within each group (per user in private chats)."""

    def get_scope_key(self, sender: ChatSender) -> str:
        if sender.chat_type != ChatType.GROUP:
            return f"c2c:{sender.user_id}"
        return f"member:{sender.group_id}:{sender.user_id}"

    def is_in_scope(self, target_sender: ChatSender, query_sender: ChatSender) -> bool:
        # different chat types never share a member scope
        if target_sender.chat_type != query_sender.chat_type:
            return False
        same_user = target_sender.user_id == query_sender.user_id
        if target_sender.chat_type != ChatType.GROUP:
            return same_user
        # in groups the member must also be in the same group
        return target_sender.group_id == query_sender.group_id and same_user
class GroupScope(MemoryScope):
    """Group scope: all members of a group share memory; private chats are per user."""

    def get_scope_key(self, sender: ChatSender) -> str:
        in_group = sender.chat_type == ChatType.GROUP
        return f"group:{sender.group_id}" if in_group else f"c2c:{sender.user_id}"

    def is_in_scope(self, target_sender: ChatSender, query_sender: ChatSender) -> bool:
        if target_sender.chat_type != query_sender.chat_type:
            return False
        # groups are matched by group id, private chats by user id
        key_attr = "group_id" if target_sender.chat_type == ChatType.GROUP else "user_id"
        return getattr(target_sender, key_attr) == getattr(query_sender, key_attr)
class GlobalScope(MemoryScope):
    """Global scope: a single memory bucket shared by every sender."""

    def get_scope_key(self, sender: ChatSender) -> str:
        # the key does not depend on the sender at all
        shared_key = "global"
        return shared_key

    def is_in_scope(self, target_sender: ChatSender, query_sender: ChatSender) -> bool:
        # every pair of senders shares the global scope
        return True
================================================
FILE: kirara_ai/plugin_manager/models.py
================================================
from typing import Any, Dict, Optional
from pydantic import BaseModel
class PluginInfo(BaseModel):
    """插件信息"""
    # Plugin information exposed by the plugin loader.
    # NOTE: the Chinese class docstring above is kept verbatim because pydantic
    # uses it as the model description in the generated JSON schema.

    name: str
    package_name: Optional[str] = None  # pip distribution name of an external plugin
    description: str
    version: str
    author: str
    is_internal: bool  # True for plugins bundled in the plugin directory
    is_enabled: bool  # whether the plugin is currently enabled
    requires_restart: bool = False  # set when an enable/disable attempt failed and a restart is needed
    metadata: Optional[Dict[str, Any]] = None
================================================
FILE: kirara_ai/plugin_manager/plugin.py
================================================
from abc import ABC, abstractmethod
from kirara_ai.events.event_bus import EventBus
from kirara_ai.im.im_registry import IMRegistry
from kirara_ai.im.manager import IMManager
from kirara_ai.llm.llm_registry import LLMBackendRegistry
from kirara_ai.workflow.core.dispatch import WorkflowDispatcher
class Plugin(ABC):
    """
    Base class for plugins.

    External plugins register an entry point in their setup.py:

        setup(
            name='your-plugin-name',
            ...
            entry_points={
                'chatgpt_mirai.plugins': [
                    'plugin_name = your_package.module:PluginClass'
                ]
            }
        )
    """

    # entry point group scanned for installed external plugins
    ENTRY_POINT_GROUP = "chatgpt_mirai.plugins"

    # Dependencies declared as class annotations. Presumably filled in by the
    # IoC injector when the plugin is instantiated (TODO confirm against Inject).
    event_bus: EventBus
    workflow_dispatcher: WorkflowDispatcher
    llm_registry: LLMBackendRegistry
    im_registry: IMRegistry
    im_manager: IMManager

    @abstractmethod
    def on_load(self):
        """Hook invoked after the plugin is instantiated, before ``on_start``."""

    @abstractmethod
    def on_start(self):
        """Hook invoked when the plugin is started."""

    @abstractmethod
    def on_stop(self):
        """Hook invoked when the plugin is stopped or disabled."""
================================================
FILE: kirara_ai/plugin_manager/plugin_event_bus.py
================================================
from typing import Callable, List, Type
from kirara_ai.events.event_bus import EventBus
class PluginEventBus:
    """Per-plugin wrapper around the shared EventBus.

    Forwards register/unregister/post to the wrapped bus and records the
    listeners registered through it, so ``unregister_all`` can detach them
    when the plugin is stopped or disabled.
    """

    def __init__(self, event_bus: EventBus):
        self._event_bus = event_bus
        # listeners registered through this wrapper and not yet unregistered
        # (one entry per register() call)
        self._registered_listeners: List[Callable] = []

    def register(self, event_type: Type, listener: Callable):
        """Register ``listener`` for ``event_type`` on the shared bus and record it."""
        self._event_bus.register(event_type, listener)
        self._registered_listeners.append(listener)

    def unregister(self, event_type: Type, listener: Callable):
        """Unregister ``listener`` for ``event_type`` and drop one record of it."""
        self._event_bus.unregister(event_type, listener)
        # Without this the bookkeeping list keeps stale entries that
        # unregister_all() would later try to remove again.
        if listener in self._registered_listeners:
            self._registered_listeners.remove(listener)

    def post(self, event):
        """Post ``event`` on the shared bus."""
        self._event_bus.post(event)

    def unregister_all(self):
        """Unregister every listener registered through this wrapper (e.g. via @Event)."""
        for listener in self._registered_listeners:
            # Iterate over a snapshot of the event types: unregister() may
            # modify the bus's listener mapping while we walk it.
            # NOTE: relies on the bus exposing its private `_listeners` mapping.
            for event_type in list(self._event_bus._listeners):
                if listener in self._event_bus._listeners.get(event_type, ()):
                    self._event_bus.unregister(event_type, listener)
        self._registered_listeners.clear()
================================================
FILE: kirara_ai/plugin_manager/plugin_loader.py
================================================
import asyncio
import importlib
import os
import sys
from typing import Dict, List, Optional, Type
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.events.event_bus import EventBus
from kirara_ai.events.plugin import PluginLoaded, PluginStarted, PluginStopped
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.ioc.inject import Inject
from kirara_ai.logger import get_logger
from kirara_ai.plugin_manager.models import PluginInfo
from kirara_ai.plugin_manager.plugin import Plugin
from kirara_ai.plugin_manager.plugin_event_bus import PluginEventBus
class PluginLoader:
    """Discover, instantiate and manage the lifecycle of plugins.

    Internal plugins are sub-directories of ``plugin_dir`` imported as top-level
    modules. External plugins are installed distributions that expose an entry
    point in ``Plugin.ENTRY_POINT_GROUP`` and are installed/updated/removed with pip.
    """

    def __init__(self, container: DependencyContainer, plugin_dir: str):
        self.plugins: Dict[str, Plugin] = {}  # plugin name -> plugin instance
        self.plugin_infos: Dict[str, PluginInfo] = {}  # plugin name -> plugin info
        self.container = container
        self.logger = get_logger("PluginLoader")
        self._loaded_entry_points: set[str] = set()  # loaded entry point names (not read in this class)
        self.plugin_dir = plugin_dir
        self.internal_plugins: List[str] = []
        self.config = self.container.resolve(GlobalConfig)
        self.event_bus = self.container.resolve(EventBus)

    def register_plugin(self, plugin_class: Type[Plugin], plugin_name: Optional[str] = None):
        """Register a plugin class directly (mainly used by tests).

        The instance is stored under ``plugin_name`` (or the class name) together
        with a synthetic PluginInfo, and returned.
        """
        plugin = self.instantiate_plugin(plugin_class)
        key = plugin_name or plugin_class.__name__
        self.plugins[key] = plugin

        plugin_info = PluginInfo(
            name=plugin_class.__name__,
            package_name=plugin_name,
            description=plugin_class.__doc__ or "",
            version="1.0.0",
            author="Test",
            is_internal=True,
            is_enabled=True,
            metadata=getattr(plugin, "metadata", None),
        )
        self.plugin_infos[key] = plugin_info
        self.logger.info(f"Registered test plugin: {key}")
        return plugin

    def discover_internal_plugins(self, plugin_dir=None):
        """Discover and load internal plugins from a directory.

        Every sub-directory (except hidden and dunder ones such as __pycache__)
        is imported as a top-level module and loaded as a plugin.

        Args:
            plugin_dir (str): directory containing plugin sub-directories;
                defaults to ``self.plugin_dir``.
        """
        if not plugin_dir:
            plugin_dir = self.plugin_dir
        self.logger.info(f"Discovering internal plugins from directory: {plugin_dir}")
        # Plugins are imported by bare name; avoid duplicate sys.path entries
        # when discovery runs more than once.
        if plugin_dir not in sys.path:
            sys.path.append(plugin_dir)
        # os.listdir order is arbitrary; sort for a deterministic load order.
        for plugin_name in sorted(os.listdir(plugin_dir)):
            plugin_path = os.path.join(plugin_dir, plugin_name)
            if not os.path.isdir(plugin_path) or plugin_name.startswith((".", "__")):
                continue
            if plugin_name not in self.internal_plugins:
                self.internal_plugins.append(plugin_name)
            self.logger.debug(f"Found plugin directory: {plugin_name}")
            self.load_plugin(plugin_name)

    def load_plugin(self, plugin_name: str):
        """Load an internal or external plugin by name; failures are logged, not raised."""
        self.logger.info(f"Loading plugin: {plugin_name}")
        try:
            if plugin_name in self.internal_plugins:
                self._load_internal_plugin(plugin_name)
            else:
                self._load_external_plugin(plugin_name)
        except Exception as e:
            self.logger.error(f"Failed to load plugin {plugin_name}: {e}")

    def _load_internal_plugin(self, plugin_name: str):
        """Import module ``plugin_name``, instantiate its first Plugin subclass and record its info.

        Raises:
            ValueError: if the module defines no Plugin subclass.
        """
        module = importlib.import_module(plugin_name)
        plugin_classes = [
            cls
            for cls in module.__dict__.values()
            if isinstance(cls, type) and issubclass(cls, Plugin) and cls != Plugin
        ]
        if not plugin_classes:
            raise ValueError(f"No valid plugin class found in module {plugin_name}")

        plugin_class = plugin_classes[0]
        plugin = self.instantiate_plugin(plugin_class)
        self.plugins[plugin_name] = plugin

        plugin_info = PluginInfo(
            name=plugin_class.__name__,
            description=plugin_class.__doc__ or "",
            version="1.0.0",
            author="Internal",
            is_internal=True,
            is_enabled=True,
            metadata=getattr(plugin, "metadata", None),
        )
        self.plugin_infos[plugin_name] = plugin_info
        self.logger.info(f"Internal plugin {plugin_name} loaded successfully")
        return plugin

    def _load_external_plugin(self, plugin_name: str):
        """Load an external plugin through its entry point and store the instance.

        Raises:
            ValueError: if no entry point named ``plugin_name`` exists.
            TypeError: if the entry point does not resolve to a Plugin subclass.
        """
        from importlib import reload
        from importlib.metadata import entry_points

        eps = entry_points(group=Plugin.ENTRY_POINT_GROUP)
        plugin_ep = next((ep for ep in eps if ep.name == plugin_name), None)
        if not plugin_ep:
            raise ValueError(f"Unable to find entry point for plugin {plugin_name}")

        try:
            # Reload an already-imported module so updated code is picked up.
            if plugin_ep.module in sys.modules:
                module = sys.modules[plugin_ep.module]
                self.logger.info(f"Reloading plugin {plugin_name} from {plugin_ep.module}")
                reload(module)
        except Exception as e:
            self.logger.error(f"Failed to reload plugin {plugin_name}: {e}")

        try:
            plugin_class = plugin_ep.load()
            if not issubclass(plugin_class, Plugin):
                raise TypeError(
                    f"Plugin {plugin_name} must inherit from the Plugin class"
                )
            plugin: Plugin = self.instantiate_plugin(plugin_class)
            self.plugins[plugin_name] = plugin
            self.logger.info(f"Successfully loaded external plugin: {plugin_name}")
            return plugin
        except Exception as e:
            self.logger.error(f"Failed to load external plugin {plugin_name}: {e}")
            raise

    def instantiate_plugin(self, plugin_class):
        """Instantiate ``plugin_class`` via DI, giving it its own PluginEventBus wrapper."""
        self.logger.debug(f"Instantiating plugin class: {plugin_class.__name__}")
        event_bus = self.container.resolve(EventBus)
        with self.container.scoped() as scoped_container:
            scoped_container.register(EventBus, PluginEventBus(event_bus))
            return Inject(scoped_container).create(plugin_class)()

    def load_plugins(self):
        """Call ``on_load`` on every loaded plugin and post PluginLoaded for each."""
        self.logger.info("Initializing plugins...")
        for plugin_name, plugin in self.plugins.items():
            try:
                plugin.on_load()
                self.logger.info(f"Plugin {plugin.__class__.__name__} initialized")
                self.event_bus.post(PluginLoaded(plugin))
            except Exception as e:
                self.logger.error(
                    f"Failed to initialize plugin {plugin.__class__.__name__}: {e}"
                )

    def start_plugins(self):
        """Call ``on_start`` on every loaded plugin, mark it enabled and post PluginStarted."""
        self.logger.info("Starting plugins...")
        for plugin_name, plugin in self.plugins.items():
            try:
                plugin.on_start()
                self.plugin_infos[plugin_name].is_enabled = True
                self.logger.info(f"Plugin {plugin.__class__.__name__} started")
                self.event_bus.post(PluginStarted(plugin))
            except Exception as e:
                self.logger.error(
                    f"Failed to start plugin {plugin.__class__.__name__}: {e}"
                )

    def stop_plugins(self):
        """Call ``on_stop`` on every loaded plugin, detach its listeners and post PluginStopped."""
        self.logger.info("Stopping plugins...")
        for plugin_name, plugin in self.plugins.items():
            try:
                plugin.on_stop()
                if isinstance(plugin.event_bus, PluginEventBus):
                    plugin.event_bus.unregister_all()
                self.logger.info(f"Plugin {plugin.__class__.__name__} stopped")
                self.event_bus.post(PluginStopped(plugin))
            except Exception as e:
                self.logger.error(
                    f"Failed to stop plugin {plugin.__class__.__name__}: {e}"
                )

    def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
        """Return the info of ``plugin_name``, or None if unknown."""
        return self.plugin_infos.get(plugin_name)

    def get_all_plugin_infos(self) -> List[PluginInfo]:
        """Return the info of every known plugin."""
        return list(self.plugin_infos.values())

    async def _run_pip(self, cmd: List[str], log_tag: str) -> int:
        """Run a pip command and stream its output to the logger.

        stdout lines are logged at INFO and stderr lines at ERROR, each prefixed
        with ``[log_tag]``.

        Returns:
            The process exit code.
        """
        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        async def read_stream(stream, log_func):
            while True:
                line = await stream.readline()
                if not line:
                    break
                # pip writes in the console encoding (not necessarily UTF-8 on
                # Windows); a decode error must not abort the whole operation.
                log_func(line.decode(errors="replace").strip())

        # drain both pipes concurrently so neither can fill up and block pip
        await asyncio.gather(
            read_stream(process.stdout, lambda msg: self.logger.info(f"[{log_tag}] {msg}")),
            read_stream(process.stderr, lambda msg: self.logger.error(f"[{log_tag}] {msg}")),
        )
        return await process.wait()

    async def install_plugin(
        self, package_name: str, version: Optional[str] = None
    ) -> Optional[PluginInfo]:
        """Install an external plugin package with pip and discover it.

        Returns:
            The PluginInfo of a plugin provided by ``package_name``, or None if
            the package exposes no plugin entry point.

        Raises:
            Exception: if pip fails or discovery raises.
        """
        try:
            cmd = [sys.executable, "-m", "pip", "install", "--index-url", self.config.update.pypi_registry]
            if version:
                cmd.append(f"{package_name}=={version}")
            else:
                cmd.append(package_name)

            self.logger.info(f"Installing plugin: {package_name}")
            return_code = await self._run_pip(cmd, f"{package_name} install")
            if return_code != 0:
                raise Exception(f"Failed to install plugin: return code {return_code}")

            # Make the freshly installed distribution visible to importlib.
            importlib.invalidate_caches()
            self.discover_external_plugins()
            possible_plugin_infos = [
                info
                for info in self.plugin_infos.values()
                if info.package_name == package_name
            ]
            if possible_plugin_infos:
                return possible_plugin_infos[0]
        except Exception as e:
            raise Exception(f"Failed to install plugin: {str(e)}") from e
        return None

    async def uninstall_plugin(self, plugin_name: str) -> bool:
        """Disable an external plugin, uninstall its package with pip and forget it.

        Returns:
            False if the plugin is unknown, True on success.

        Raises:
            Exception: for internal plugins or if pip fails.
        """
        try:
            plugin_info = self.plugin_infos.get(plugin_name)
            if not plugin_info:
                return False
            if plugin_info.is_internal:
                raise Exception("Cannot uninstall internal plugin")

            await self.disable_plugin(plugin_name)
            assert plugin_info.package_name is not None

            cmd = [
                sys.executable,
                "-m",
                "pip",
                "uninstall",
                "-y",
                plugin_info.package_name,
            ]
            self.logger.info(f"Uninstalling plugin: {plugin_info.package_name}")
            return_code = await self._run_pip(cmd, f"{plugin_info.package_name} uninstall")
            if return_code != 0:
                raise Exception(f"Failed to uninstall plugin: return code {return_code}")

            if plugin_name in self.plugin_infos:
                del self.plugin_infos[plugin_name]
            return True
        except Exception as e:
            raise Exception(f"Failed to uninstall plugin: {str(e)}") from e

    async def enable_plugin(self, plugin_name: str) -> bool:
        """Load and start a plugin and add it to the enabled list in the config.

        Raises:
            ValueError: if the plugin is unknown.
            Exception: re-raised from loading/starting; ``requires_restart`` is set.
        """
        plugin_info = self.get_plugin_info(plugin_name)
        if not plugin_info:
            raise ValueError(f"Plugin {plugin_name} not found")
        if plugin_info.is_enabled:
            return True

        try:
            if plugin_info.is_internal:
                plugin = self._load_internal_plugin(plugin_name)
            else:
                plugin = self._load_external_plugin(plugin_name)

            if plugin_name not in self.config.plugins.enable:
                self.config.plugins.enable.append(plugin_name)

            plugin.on_load()
            plugin.on_start()
            plugin_info.is_enabled = True
            self.logger.info(f"Plugin {plugin_name} enabled")
            return True
        except Exception as e:
            plugin_info.requires_restart = True
            self.logger.error(f"Failed to enable plugin {plugin_name}: {e}")
            raise

    async def disable_plugin(self, plugin_name: str) -> bool:
        """Stop a plugin, detach its listeners and remove it from the enabled list.

        Returns:
            True on success; False if stopping failed (``requires_restart`` is set).

        Raises:
            ValueError: if the plugin is unknown.
        """
        plugin_info = self.get_plugin_info(plugin_name)
        if not plugin_info:
            raise ValueError(f"Plugin {plugin_name} not found")
        if not plugin_info.is_enabled:
            return True

        try:
            if plugin_name in self.plugins:
                plugin = self.plugins[plugin_name]
                if isinstance(plugin.event_bus, PluginEventBus):
                    plugin.event_bus.unregister_all()
                plugin.on_stop()
                del self.plugins[plugin_name]

            if plugin_name in self.config.plugins.enable:
                self.config.plugins.enable.remove(plugin_name)

            plugin_info.is_enabled = False
            self.plugin_infos[plugin_name] = plugin_info
            self.logger.info(f"Plugin {plugin_name} disabled")
            return True
        except Exception as e:
            plugin_info.requires_restart = True
            self.logger.error(f"Failed to disable plugin {plugin_name}: {e}")
            return False

    async def update_plugin(self, plugin_name: str, new_package_name: Optional[str] = None) -> Optional[PluginInfo]:
        """Uninstall an external plugin, reinstall it with ``pip install --upgrade`` and rediscover it.

        Returns:
            The new PluginInfo, or None if the plugin is unknown or not found
            after the upgrade.

        Raises:
            Exception: for internal plugins, pip failures, or when the version
                did not change.
        """
        try:
            plugin_info = self.plugin_infos.get(plugin_name)
            if not plugin_info:
                return None
            # Check before the assert: internal plugins have no package_name and
            # a bare AssertionError would hide the actual reason.
            if plugin_info.is_internal:
                raise Exception("Cannot update internal plugin")
            assert plugin_info.package_name is not None

            old_version = plugin_info.version
            await self.uninstall_plugin(plugin_name)

            cmd = [
                sys.executable,
                "-m",
                "pip",
                "install",
                "--upgrade",
                "--index-url",
                self.config.update.pypi_registry,
                new_package_name or plugin_info.package_name,
            ]
            self.logger.info(f"Updating plugin: {plugin_info.package_name}")
            return_code = await self._run_pip(cmd, f"{plugin_info.package_name} update")
            if return_code != 0:
                raise Exception(f"Failed to update plugin: return code {return_code}")

            importlib.invalidate_caches()
            self.discover_external_plugins()
            possible_plugin_infos = [
                info
                for info in self.plugin_infos.values()
                if info.package_name == plugin_info.package_name
            ]
            if possible_plugin_infos:
                if possible_plugin_infos[0].version != old_version:
                    return possible_plugin_infos[0]
                else:
                    raise Exception(
                        f"Failed to update plugin: {plugin_info.package_name} is already up to date for current Kirara AI version. Maybe you should update Kirara AI to the latest version."
                    )
        except Exception as e:
            raise Exception(f"Failed to update plugin: {str(e)}") from e
        return None

    def discover_external_plugins(self):
        """Record info for every installed external plugin and load the enabled ones.

        Plugins that are already loaded are left untouched: their info keeps its
        runtime state and the live instance is not replaced. Re-discovery (e.g.
        after installing another plugin) therefore neither re-instantiates a
        running plugin nor marks it as disabled.
        """
        self.logger.info("Discovering external plugins...")
        from importlib.metadata import distributions

        for dist in distributions():
            try:
                eps = dist.entry_points
                plugin_eps = [ep for ep in eps if ep.group == Plugin.ENTRY_POINT_GROUP]
                if not plugin_eps:
                    continue

                for ep in plugin_eps:
                    try:
                        metadata = {
                            "name": dist.metadata["Name"],
                            "description": dist.metadata.get("Summary", ""),
                            "version": dist.metadata.get("Version", "1.0.0"),
                            "author": dist.metadata.get("Author", "Unknown"),
                        }
                        existing_info = self.plugin_infos.get(ep.name)
                        already_loaded = ep.name in self.plugins

                        plugin_info = PluginInfo(
                            name=ep.name,
                            package_name=dist.metadata["Name"],
                            description=metadata["description"],
                            version=metadata["version"],
                            author=metadata["author"],
                            is_internal=False,
                            is_enabled=bool(existing_info and already_loaded and existing_info.is_enabled),
                            requires_restart=bool(existing_info and existing_info.requires_restart),
                            metadata=None,
                        )
                        self.plugin_infos[ep.name] = plugin_info

                        # Load enabled plugins that are not loaded yet; loading again
                        # would replace a live instance without stopping it.
                        if ep.name in self.config.plugins.enable and not already_loaded:
                            self._load_external_plugin(ep.name)
                    except Exception as e:
                        self.logger.error(
                            f"Error processing metadata for plugin {ep.name}: {e}"
                        )
            except Exception as e:
                self.logger.error(
                    f"Error processing package {dist.metadata['Name']}: {e}"
                )
================================================
FILE: kirara_ai/plugin_manager/utils.py
================================================
from importlib.metadata import PackageNotFoundError, distribution
from typing import Any, Dict, Optional
def get_package_metadata(package_name: str) -> Optional[Dict[str, Any]]:
    """Return basic metadata of an installed Python distribution.

    Args:
        package_name: distribution name as accepted by ``importlib.metadata``.

    Returns:
        A dict with ``name``, ``version``, ``description`` and ``author`` keys
        (missing Summary/Author headers become empty strings), or ``None`` when
        no such distribution is installed.
    """
    try:
        dist = distribution(package_name)
    except PackageNotFoundError:
        return None

    headers = dist.metadata
    return {
        "name": headers["Name"],
        "version": dist.version,
        "description": headers.get("Summary", ""),
        "author": headers.get("Author", ""),
    }
================================================
FILE: kirara_ai/plugins/.gitkeep
================================================
================================================
FILE: kirara_ai/plugins/bundled_frpc/__init__.py
================================================
from quart import g
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.logger import get_logger
from kirara_ai.plugin_manager.plugin import Plugin
from kirara_ai.web.app import WebServer
from .frpc_manager import FrpcManager
from .routes import frpc_bp
logger = get_logger("FRPC")


# FRPC plugin: runs a local frpc client so the web service can be exposed
# through an frp server. The Chinese docstring below is kept verbatim because
# the plugin loader uses the class docstring as the plugin description.
class FrpcPlugin(Plugin):
    """
    FRPC 插件
    用于在本地启动 FRPC 服务,协助将 Web 服务暴露到外网。
    """

    web_server: WebServer
    global_config: GlobalConfig

    def __init__(self):
        self.frpc_manager = None

    def on_load(self):
        """Create the FrpcManager and expose it to the blueprint's handlers via ``g``."""
        self.frpc_manager = FrpcManager(self.global_config)

        @frpc_bp.before_request
        async def inject_frpc_manager():
            g.frpc_manager = self.frpc_manager

    def on_start(self):
        """Mount the FRPC API and start frpc when enabled in the config."""
        self.web_server.web_api_app.register_blueprint(frpc_bp, url_prefix="/api/frpc")

        if not self.global_config.frpc.enable:
            return
        try:
            self.frpc_manager.start_frpc(self.global_config.web.port)
        except Exception as e:
            logger.error(f"启动 FRPC 失败: {e}")

    def on_stop(self):
        """Stop the frpc process if a manager was created."""
        manager = self.frpc_manager
        if manager:
            manager.stop_frpc()


__all__ = ["FrpcPlugin"]
================================================
FILE: kirara_ai/plugins/bundled_frpc/frpc_manager.py
================================================
import io
import os
import platform
import shutil
import subprocess
import tarfile
import tempfile
import threading
import zipfile
from pathlib import Path
from typing import Awaitable, Callable, Optional, Tuple
import aiohttp
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.logger import get_logger
logger = get_logger("FRPC")
# 存储路径
STORAGE_PATH = "./data/frpc"
class FrpcManager:
"""FRPC 管理器"""
def __init__(self, global_config: GlobalConfig):
self.global_config = global_config
self._frpc_process: Optional[subprocess.Popen] = None
self._frpc_version: str = ""
self._remote_url: str = ""
self._error_message: str = ""
self._download_progress: float = 0
# 确保存储目录存在
self.storage_path = Path(STORAGE_PATH)
self.storage_path.mkdir(parents=True, exist_ok=True)
# 设置 frpc 可执行文件路径
system = platform.system().lower()
if system == "windows":
self._frpc_path = self.storage_path / "frpc.exe"
else:
self._frpc_path = self.storage_path / "frpc"
# 设置配置文件路径
self._frpc_config_path = self.storage_path / "frpc.ini"
# 尝试获取版本信息
if self.is_installed():
self._get_frpc_version()
async def download_frpc(self, progress_callback: Optional[Callable[[float], Awaitable[None]]] = None) -> bool:
"""下载 FRPC"""
# 重置状态
self._download_progress = 0
self._error_message = ""
try:
# 获取系统信息
system = platform.system().lower()
machine = platform.machine().lower()
# 确定下载 URL
if system == "windows":
if machine in ["amd64", "x86_64"]:
url = "https://github.com/fatedier/frp/releases/download/v0.51.3/frp_0.51.3_windows_amd64.zip"
elif machine in ["arm64", "aarch64"]:
url = "https://github.com/fatedier/frp/releases/download/v0.51.3/frp_0.51.3_windows_arm64.zip"
else:
self._error_message = f"不支持的系统架构: {machine}"
if progress_callback:
await progress_callback(0)
return False
elif system == "linux":
if machine in ["amd64", "x86_64"]:
url = "https://github.com/fatedier/frp/releases/download/v0.51.3/frp_0.51.3_linux_amd64.tar.gz"
elif machine in ["arm64", "aarch64"]:
url = "https://github.com/fatedier/frp/releases/download/v0.51.3/frp_0.51.3_linux_arm64.tar.gz"
elif machine.startswith("arm"):
url = "https://github.com/fatedier/frp/releases/download/v0.51.3/frp_0.51.3_linux_arm.tar.gz"
else:
self._error_message = f"不支持的系统架构: {machine}"
if progress_callback:
await progress_callback(0)
return False
elif system == "darwin":
if machine in ["amd64", "x86_64"]:
url = "https://github.com/fatedier/frp/releases/download/v0.51.3/frp_0.51.3_darwin_amd64.tar.gz"
elif machine in ["arm64", "aarch64"]:
url = "https://github.com/fatedier/frp/releases/download/v0.51.3/frp_0.51.3_darwin_arm64.tar.gz"
else:
self._error_message = f"不支持的系统架构: {machine}"
if progress_callback:
await progress_callback(0)
return False
else:
self._error_message = f"不支持的操作系统: {system}"
if progress_callback:
await progress_callback(0)
return False
self._remote_url = url
self._version = "v0.51.3"
# 下载文件
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as response:
if response.status != 200:
self._error_message = f"下载失败: HTTP {response.status}"
if progress_callback:
await progress_callback(0)
return False
total_size = int(response.headers.get("content-length", 0))
downloaded_size = 0
# 直接在内存中下载文件
content = bytearray()
async for chunk in response.content.iter_chunked(8192):
content.extend(chunk)
downloaded_size += len(chunk)
progress = min(100, downloaded_size * 100 / total_size if total_size > 0 else 0)
self._download_progress = progress
if progress_callback:
await progress_callback(progress)
# 解压文件
extract_dir = tempfile.mkdtemp()
try:
# 从内存中解压文件
if url.endswith(".zip"):
with zipfile.ZipFile(io.BytesIO(content)) as zip_ref:
zip_ref.extractall(extract_dir)
elif url.endswith(".tar.gz"):
with tarfile.open(fileobj=io.BytesIO(content), mode="r:gz") as tar_ref:
tar_ref.extractall(extract_dir)
# 查找 frpc 可执行文件
frpc_name = "frpc.exe" if system == "windows" else "frpc"
frpc_files = []
for root, _, files in os.walk(extract_dir):
for file in files:
if file == frpc_name:
frpc_files.append(os.path.join(root, file))
if not frpc_files:
self._error_message = "解压后未找到 frpc 可执行文件"
if progress_callback:
await progress_callback(0)
return False
# 复制 frpc 可执行文件
frpc_path = str(self._frpc_path)
os.makedirs(os.path.dirname(frpc_path), exist_ok=True)
shutil.copy2(frpc_files[0], frpc_path)
# 设置可执行权限
if system != "windows":
os.chmod(frpc_path, 0o755)
if progress_callback:
await progress_callback(100)
self._get_frpc_version()
return True
finally:
shutil.rmtree(extract_dir)
except Exception as e:
self._error_message = f"下载失败: {str(e)}"
if progress_callback:
await progress_callback(0)
return False
except Exception as e:
self._error_message = f"下载失败: {str(e)}"
if progress_callback:
await progress_callback(0)
return False
def _get_frpc_version(self):
"""获取 frpc 版本信息"""
try:
if not self.is_installed():
self._frpc_version = "未安装"
return
result = subprocess.run(
[str(self._frpc_path), "-v"],
capture_output=True,
text=True,
check=True
)
self._frpc_version = result.stdout.strip()
except Exception as e:
logger.error(f"获取 FRPC 版本失败: {e}")
self._frpc_version = "未知"
def _generate_config(self, web_port: int) -> bool:
"""
生成 frpc 配置文件
Args:
web_port: Web 服务端口
Returns:
bool: 配置文件生成是否成功
"""
try:
config = self.global_config.frpc
# 基本配置
config_content = f"""[common]
server_addr = {config.server_addr}
server_port = {config.server_port}
"""
# 如果有令牌,添加令牌配置
if config.token:
config_content += f"token = {config.token}\n"
if self.global_config.web.host == "0.0.0.0":
web_ip = "127.0.0.1"
else:
web_ip = self.global_config.web.host
# 代理配置
config_content += f"""
[kirara_web]
type = tcp
local_ip = {web_ip}
local_port = {web_port}
"""
# 远程端口配置
if config.remote_port > 0:
config_content += f"remote_port = {config.remote_port}\n"
# 写入配置文件
with open(self._frpc_config_path, "w", encoding="utf-8") as f:
f.write(config_content)
logger.info(f"FRPC 配置文件已生成: {self._frpc_config_path}")
return True
except Exception as e:
logger.error(f"生成 FRPC 配置文件失败: {e}")
self._error_message = f"生成 FRPC 配置文件失败: {e}"
return False
def start_frpc(self, web_port: int) -> bool:
"""
启动 frpc
Args:
web_port: Web 服务端口
Returns:
bool: 启动是否成功
"""
if self.is_installed():
self._get_frpc_version()
try:
# 如果已经在运行,先停止
if self._frpc_process is not None:
self.stop_frpc()
# 检查可执行文件是否存在
if not self.is_installed():
logger.error("FRPC 可执行文件不存在,请先下载")
self._error_message = "FRPC 可执行文件不存在,请先下载"
return False
# 生成配置文件
if not self._generate_config(web_port):
return False
# 启动 frpc
cmd = [str(self._frpc_path), "-c", str(self._frpc_config_path)]
self._frpc_process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
)
def print_output(process):
while process.poll() is None:
line = process.stdout.readline()
if not line:
break
logger.info(line.strip())
logger.info("FRPC 已结束")
# 创建一个线程来读取和打印输出
output_thread = threading.Thread(target=print_output, args=(self._frpc_process,))
output_thread.daemon = True # 设置为守护线程,主线程退出时自动结束
output_thread.start()
logger.info(f"FRPC 已启动,PID: {self._frpc_process.pid}")
# 计算远程访问 URL
self._calculate_remote_url()
return True
except Exception as e:
logger.error(f"启动 FRPC 失败: {e}")
self._error_message = f"启动 FRPC 失败: {e}"
return False
def stop_frpc(self) -> bool:
"""
停止 frpc
Returns:
bool: 停止是否成功
"""
try:
if self._frpc_process is not None:
self._frpc_process.terminate()
try:
self._frpc_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self._frpc_process.kill()
logger.info("FRPC 已停止")
self._frpc_process = None
self._remote_url = ""
return True
except Exception as e:
logger.error(f"停止 FRPC 失败: {e}")
self._error_message = f"停止 FRPC 失败: {e}"
return False
def _calculate_remote_url(self):
"""计算远程访问 URL"""
config = self.global_config.frpc
if config.remote_port > 0:
self._remote_url = f"http://{config.server_addr}:{config.remote_port}"
else:
self._remote_url = f"http://{config.server_addr}:随机端口"
def get_status(self) -> Tuple[bool, str, str, str, float]:
"""
获取 frpc 状态
Returns:
Tuple[bool, str, str, str, float]: (是否运行, 版本, 远程URL, 错误信息, 下载进度)
"""
is_running = self._frpc_process is not None and self._frpc_process.poll() is None
# 如果进程已经退出,获取错误信息
if self._frpc_process is not None and self._frpc_process.poll() is not None:
stderr = self._frpc_process.stderr.read() if self._frpc_process.stderr else ""
if stderr:
self._error_message = stderr
self._frpc_process = None
self._get_frpc_version()
return (
is_running,
self._frpc_version,
self._remote_url,
self._error_message,
self._download_progress
)
def is_installed(self) -> bool:
    """
    Tell whether the frpc binary is present on disk.

    Returns:
        bool: True if ``self._frpc_path`` names an existing filesystem entry.
    """
    frpc_binary = self._frpc_path
    return os.path.exists(frpc_binary)
================================================
FILE: kirara_ai/plugins/bundled_frpc/models.py
================================================
from typing import Optional
from pydantic import BaseModel
from kirara_ai.config.global_config import FrpcConfig
class FrpcStatus(BaseModel):
    """FRPC status snapshot returned by the FRPC management API endpoints."""
    is_running: bool  # whether the frpc subprocess is currently alive
    is_installed: bool  # whether the frpc binary exists on disk
    config: FrpcConfig  # the FRPC section of the global config
    version: str = ""  # frpc version string as reported by the manager
    remote_url: str = ""  # public URL of the tunnel ("" when not started/stopped)
    error_message: str = ""  # last error message recorded by the manager
    download_progress: float = 0  # binary download progress, in percent
class FrpcConfigUpdate(BaseModel):
    """Request body for updating the FRPC config; fields left as ``None`` are not changed."""
    enable: Optional[bool] = None  # start (True) / stop (False) frpc
    server_addr: Optional[str] = None  # frps server address
    server_port: Optional[int] = None  # frps server port
    token: Optional[str] = None  # frps authentication token
    remote_port: Optional[int] = None  # remote port to expose; <= 0 means not fixed
class FrpcDownloadProgress(BaseModel):
    """One FRPC download progress event, sent to the client as SSE data."""
    progress: float  # percent complete
    status: str = "downloading" # downloading, completed, error
    error_message: str = ""  # set when status == "error"
================================================
FILE: kirara_ai/plugins/bundled_frpc/routes.py
================================================
import asyncio
from quart import Blueprint, Response, g, request
from kirara_ai.config.config_loader import CONFIG_FILE, ConfigLoader
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.web.auth.middleware import require_auth
from .frpc_manager import FrpcManager
from .models import FrpcConfigUpdate, FrpcDownloadProgress, FrpcStatus
# Blueprint holding the FRPC management API routes.
frpc_bp = Blueprint("frpc", __name__)
@frpc_bp.route("/status", methods=["GET"])
@require_auth
async def get_status():
    """Return the current FRPC status (running state, version, URL, config) as a dict."""
    manager: FrpcManager = g.frpc_manager
    global_config: GlobalConfig = g.container.resolve(GlobalConfig)
    running, version, url, error, progress = manager.get_status()
    status = FrpcStatus(
        is_running=running,
        is_installed=manager.is_installed(),
        config=global_config.frpc,
        version=version,
        remote_url=url,
        error_message=error,
        download_progress=progress,
    )
    return status.model_dump()
@frpc_bp.route("/config", methods=["POST"])
@require_auth
async def update_config():
    """
    Update the FRPC config, persist it, and apply the change.

    The JSON body follows ``FrpcConfigUpdate``; fields that are missing or
    ``null`` are left untouched. If ``enable`` is given, frpc is started or
    stopped accordingly (restarted first when other settings changed too);
    otherwise, if frpc is enabled and any other field changed, it is
    restarted so the new settings take effect.

    Returns:
        The latest ``FrpcStatus`` as a dict, or ``({"error": ...}, 400)``
        when the body is not a valid JSON object.
    """
    frpc_manager: FrpcManager = g.frpc_manager
    config: GlobalConfig = g.container.resolve(GlobalConfig)

    data = await request.get_json()
    # Without this, a missing/non-object body would surface as a 500 from FrpcConfigUpdate(**data).
    if not isinstance(data, dict):
        return {"error": "请求体必须是 JSON 对象"}, 400
    try:
        config_update = FrpcConfigUpdate(**data)
    except ValueError as e:  # pydantic's ValidationError subclasses ValueError
        return {"error": str(e)}, 400

    # Apply only the fields the client actually set.
    frpc_config = config.frpc
    update_dict = config_update.model_dump(exclude_none=True)
    for key, value in update_dict.items():
        setattr(frpc_config, key, value)

    ConfigLoader.save_config_with_backup(CONFIG_FILE, config)

    web_port = config.web.port
    settings_changed = any(key != "enable" for key in update_dict)
    if config_update.enable is not None:
        if config_update.enable:
            if settings_changed:
                # Make sure an already-running frpc picks up the new settings.
                frpc_manager.stop_frpc()
            frpc_manager.start_frpc(web_port)
        else:
            frpc_manager.stop_frpc()
    elif frpc_config.enable and settings_changed:
        # Settings changed while frpc is enabled: restart to apply them.
        frpc_manager.stop_frpc()
        frpc_manager.start_frpc(web_port)

    is_running, version, remote_url, error_message, download_progress = frpc_manager.get_status()
    is_installed = frpc_manager.is_installed()
    return FrpcStatus(
        is_running=is_running,
        is_installed=is_installed,
        config=config.frpc,
        version=version,
        remote_url=remote_url,
        error_message=error_message,
        download_progress=download_progress
    ).model_dump()
@frpc_bp.route("/start", methods=["POST"])
@require_auth
async def start_frpc():
    """Enable FRPC in the persisted config, launch it, and return the new status."""
    manager: FrpcManager = g.frpc_manager
    global_config: GlobalConfig = g.container.resolve(GlobalConfig)

    # Persist the enabled flag before launching.
    global_config.frpc.enable = True
    ConfigLoader.save_config_with_backup(CONFIG_FILE, global_config)

    manager.start_frpc(global_config.web.port)

    running, version, url, error, progress = manager.get_status()
    return FrpcStatus(
        is_running=running,
        is_installed=manager.is_installed(),
        config=global_config.frpc,
        version=version,
        remote_url=url,
        error_message=error,
        download_progress=progress,
    ).model_dump()
@frpc_bp.route("/stop", methods=["POST"])
@require_auth
async def stop_frpc():
    """Disable FRPC in the persisted config, stop the process, and return the new status."""
    manager: FrpcManager = g.frpc_manager
    global_config: GlobalConfig = g.container.resolve(GlobalConfig)

    # Persist the disabled flag before stopping.
    global_config.frpc.enable = False
    ConfigLoader.save_config_with_backup(CONFIG_FILE, global_config)

    manager.stop_frpc()

    running, version, url, error, progress = manager.get_status()
    return FrpcStatus(
        is_running=running,
        is_installed=manager.is_installed(),
        config=global_config.frpc,
        version=version,
        remote_url=url,
        error_message=error,
        download_progress=progress,
    ).model_dump()
# Strong references to in-flight download tasks: the event loop keeps only
# weak references to tasks, so an unreferenced task may be garbage-collected.
_download_tasks: "set[asyncio.Task]" = set()


@frpc_bp.route("/download", methods=["GET"])
@require_auth
async def download_frpc():
    """
    Download frpc and stream the progress as Server-Sent Events.

    Every event's ``data`` is a JSON-encoded ``FrpcDownloadProgress``. The
    stream starts with a 0% "downloading" event, ends with a "completed" or
    an "error" event, and is then closed.
    """
    frpc_manager: FrpcManager = g.frpc_manager
    # JSON events waiting to be sent; None marks the end of the stream.
    queue: "asyncio.Queue[str | None]" = asyncio.Queue()
    completed_sent = False

    async def progress_callback(progress: float):
        """Forward one progress update (in percent) to the client."""
        nonlocal completed_sent
        status = "downloading"
        if progress >= 100:
            status = "completed"
            completed_sent = True
        await queue.put(
            FrpcDownloadProgress(
                progress=progress,
                status=status
            ).model_dump_json()
        )

    async def download_task():
        """Run the download and push progress / terminal events onto the queue."""
        try:
            await queue.put(
                FrpcDownloadProgress(
                    progress=0,
                    status='downloading'
                ).model_dump_json()
            )
            success = await frpc_manager.download_frpc(progress_callback)
            if not success:
                await queue.put(
                    FrpcDownloadProgress(
                        progress=0,
                        status='error',
                        error_message=frpc_manager._error_message
                    ).model_dump_json()
                )
            elif not completed_sent:
                # The downloader finished without reporting 100%; still give
                # the client an explicit terminal "completed" event.
                await queue.put(
                    FrpcDownloadProgress(
                        progress=100,
                        status='completed'
                    ).model_dump_json()
                )
        except Exception as e:
            await queue.put(
                FrpcDownloadProgress(
                    progress=0,
                    status='error',
                    error_message=str(e)
                ).model_dump_json()
            )
        finally:
            # Always terminate the SSE stream.
            await queue.put(None)

    task = asyncio.create_task(download_task())
    _download_tasks.add(task)
    task.add_done_callback(_download_tasks.discard)

    async def send_events():
        """Yield queued events in SSE wire format until the end marker."""
        while True:
            message = await queue.get()
            if message is None:
                break
            yield f"data: {message}\n\n"

    return Response(
        send_events(),
        mimetype="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disable Nginx buffering
            "Connection": "keep-alive"
        }
    )
================================================
FILE: kirara_ai/plugins/im_http_legacy_adapter/__init__.py
================================================
import os
from im_http_legacy_adapter.adapter import HttpLegacyAdapter, HttpLegacyConfig
from kirara_ai.logger import get_logger
from kirara_ai.plugin_manager.plugin import Plugin
from kirara_ai.web.app import WebServer
# Module-level logger for the HTTP legacy adapter plugin.
logger = get_logger("HTTP-Legacy-Adapter")
class HttpLegacyAdapterPlugin(Plugin):
    """Plugin that registers the HTTP API message adapter and its icon."""

    web_server: WebServer

    def __init__(self):
        pass

    def on_load(self):
        """Register the ``http_legacy`` adapter type and serve its icon asset."""
        icon_file = os.path.join(os.path.dirname(__file__), "assets", "http_legacy.png")
        self.im_registry.register(
            "http_legacy",
            HttpLegacyAdapter,
            HttpLegacyConfig,
            "HTTP API",
            "HTTP 消息 API,可用于接入第三方程序。",
            """
HTTP API 可用于接入第三方程序,接口文档请见项目 [README](https://github.com/lss233/chatgpt-mirai-qq-bot/blob/master/README.md#-http-api)。
"""
        )
        self.web_server.add_static_assets("/assets/icons/im/http_legacy.png", icon_file)

    def on_start(self):
        pass

    def on_stop(self):
        pass
================================================
FILE: kirara_ai/plugins/im_http_legacy_adapter/adapter.py
================================================
import asyncio
import re
import time
from typing import Any, Dict, List, Optional, Protocol
from fastapi import Body, FastAPI, Query, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ConfigDict, Field
from kirara_ai.im.adapter import IMAdapter
from kirara_ai.im.message import ImageMessage, IMMessage, TextMessage, VoiceMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.logger import get_logger
from kirara_ai.web.app import WebServer
from kirara_ai.workflow.core.dispatch import WorkflowDispatcher
# Whether the legacy routes still have to be registered on the shared web
# server app; only the first adapter started in integrated mode registers them.
_is_first_setup = True
# API keys of the other adapter instances; the shared routes accept these in
# addition to the first adapter's own key.
_authorized_api_keys: List[str] = []
class HttpLegacyConfig(BaseModel):
    """HTTP Legacy API configuration."""
    api_key: Optional[str] = Field(
        description="自定义的API密钥,设置后,请求接口时需要带上这个密钥,若填空则不校验。", default=None)
    # Deprecated: setting host switches the adapter to standalone-server mode.
    host: Optional[str] = Field(description="已废弃,HTTP API 服务器地址,设置后将启动独立服务器。",
        default=None, json_schema_extra={"hidden_unset": True})
    # Deprecated: standalone-server port (the adapter falls back to 18560 when unset).
    port: Optional[int] = Field(description="已废弃,HTTP API 服务器端口,设置后将启动独立服务器。",
        default=None, json_schema_extra={"hidden_unset": True})
    # Unknown keys are kept instead of rejected.
    model_config = ConfigDict(extra="allow")
class ResponseResult:
    """Reply payload of the legacy HTTP API: a status plus text, voice and image lists."""

    def __init__(self, message=None, voice=None, image=None, result_status="SUCCESS"):
        self.result_status = result_status
        self.message = self._as_list(message)
        self.voice = self._as_list(voice)
        self.image = self._as_list(image)

    @staticmethod
    def _as_list(value):
        """Normalize ``None`` to ``[]``, keep a list as-is, wrap anything else in a list."""
        if value is None:
            return []
        if isinstance(value, list):
            return value
        return [value]

    def to_dict(self):
        """Serialize to the legacy JSON response shape."""
        return {
            "result": self.result_status,
            "message": self.message,
            "voice": self.voice,
            "image": self.image,
        }

    def pop_all(self):
        """Discard every accumulated message, voice and image."""
        self.message, self.voice, self.image = [], [], []
class MessageHandler(Protocol):
    """Structural type for an async callback that receives an ``IMMessage``."""

    async def __call__(self, message: IMMessage) -> None:
        """Handle ``message``."""
class V2Request:
    """State of one ``/v2/chat`` request whose replies are collected for polling."""

    def __init__(self, session_id: str, username: str, message: str, request_time: str):
        # What was asked, and by whom.
        self.session_id = session_id
        self.username = username
        self.message = message
        # Replies collected until the client polls for them.
        self.result = ResponseResult()
        self.request_time = request_time
        self.done = False
        # Set whenever new replies are available.
        self.response_event = asyncio.Event()
class HttpLegacyAdapter(IMAdapter):
    """
    HTTP Legacy API adapter.

    Endpoints:

    * ``POST /v1/chat`` - dispatch the message and wait; the replies are
      returned in the response body.
    * ``POST /v2/chat`` - dispatch in the background and return a request id.
    * ``GET /v2/chat/response?request_id=...`` - wait for and return the
      replies collected for a request id.

    In integrated mode the routes are registered once, on the shared web
    server app, by the first adapter instance that starts; API keys of later
    instances are accepted through ``_authorized_api_keys``. In the
    deprecated standalone mode (``config.host`` set) ``self.app`` is served by
    a dedicated hypercorn server.
    """

    dispatcher: WorkflowDispatcher
    web_server: WebServer

    def __init__(self, config: HttpLegacyConfig):
        self.config = config
        self.app = FastAPI(title="HTTP Legacy API")
        # Pending /v2 requests keyed by request id (a millisecond-timestamp string).
        self.request_dic: Dict[str, V2Request] = {}
        # Strong references to background dispatch tasks so they are not
        # garbage-collected before they finish.
        self._dispatch_tasks: "set[asyncio.Task]" = set()
        self.logger = get_logger("HTTP-Legacy-Adapter")

    async def convert_to_message(self, raw_message: Any) -> IMMessage:
        """
        Convert a legacy request body into an ``IMMessage``.

        A ``session_id`` of the form ``group-<group_id>:<user_id>`` yields a
        group-chat sender; anything else is treated as a private (C2C) chat
        whose user id is the whole ``session_id``.
        """
        data = raw_message
        username = data.get("username", "某人")
        message_text = data.get("message", "")
        session_id = data.get("session_id", "friend-default_session")
        session_parts = session_id.split("-")
        if (
            session_id.startswith("group-")
            and len(session_parts) == 2
            and ":" in session_parts[1]
        ):
            # group-group_id:user_id
            ids = session_parts[1].split(":")
            sender = ChatSender.from_group_chat(
                user_id=ids[1], group_id=ids[0], display_name=username
            )
        else:
            sender = ChatSender.from_c2c_chat(
                user_id=session_id, display_name=username)
        return IMMessage(
            sender=sender,
            message_elements=[TextMessage(text=message_text)],
            raw_message={"session_id": session_id, **data},
        )

    async def handle_message_elements(self, result: ResponseResult, message: IMMessage):
        """Append ``message``'s elements to ``result``: voice/image as base64 URLs, others as plain text."""
        for element in message.message_elements:
            if isinstance(element, VoiceMessage):
                result.voice.append(await element.get_base64_url())
            elif isinstance(element, ImageMessage):
                result.image.append(await element.get_base64_url())
            else:
                result.message.append(element.to_plain())

    def verify_api_key(self, request: Request) -> bool:
        """
        Check the request's API key.

        Accepts ``Authorization: Bearer <key>`` or the bare key. Always passes
        when this adapter has no key configured; otherwise the key must match
        this adapter's key or one in ``_authorized_api_keys``.
        """
        import hmac  # constant-time comparison for secrets

        if not self.config.api_key:
            return True
        provided = request.headers.get("Authorization", "").removeprefix("Bearer ")
        provided_bytes = provided.encode()
        candidates = [self.config.api_key, *_authorized_api_keys]
        return any(hmac.compare_digest(provided_bytes, key.encode()) for key in candidates)

    def create_auth_error_response(self):
        """Build the 401 response returned when authentication fails."""
        return JSONResponse(
            content=ResponseResult(
                message="认证失败", result_status="FAILED").to_dict(),
            status_code=401
        )

    def _new_request_id(self) -> str:
        """
        Return an unused request id: the current time in milliseconds, bumped
        past ids still pending. It must stay a millisecond timestamp because
        ``cleanup_expired_requests`` derives a request's age from it.
        """
        request_ms = int(time.time() * 1000)
        while str(request_ms) in self.request_dic:
            request_ms += 1
        return str(request_ms)

    def setup_routes(self, target_app=None):
        """Register the legacy chat routes on ``target_app`` (default: ``self.app``)."""
        app = target_app if target_app else self.app

        async def verify_auth(request: Request):
            # None means "authorized"; otherwise the 401 response to return.
            if not self.verify_api_key(request):
                return self.create_auth_error_response()
            return None

        @app.post("/v1/chat")
        async def v1_chat(request: Request, data: dict = Body(...)):
            auth_response = await verify_auth(request)
            if auth_response:
                return auth_response
            message = await self.convert_to_message(data)
            result = ResponseResult()

            async def handle_response(resp_message: IMMessage):
                await self.handle_message_elements(result, resp_message)

            message.sender.raw_metadata["callback_func"] = handle_response
            await self.dispatcher.dispatch(self, message)
            return result.to_dict()

        @app.post("/v2/chat")
        async def v2_chat(request: Request, data: dict = Body(...)):
            auth_response = await verify_auth(request)
            if auth_response:
                return auth_response
            message = await self.convert_to_message(data)
            assert message.raw_message is not None
            session_id = message.raw_message["session_id"]
            # No await between allocating the id and storing it, so ids stay unique.
            request_time = self._new_request_id()
            bot_request = V2Request(
                session_id,
                message.sender.display_name,
                data.get("message", ""),
                request_time,
            )
            self.request_dic[request_time] = bot_request

            async def handle_response(resp_message: IMMessage):
                # Convert into a private buffer first: a poll may swap out
                # bot_request.result while media conversion is awaited.
                partial = ResponseResult()
                await self.handle_message_elements(partial, resp_message)
                bot_request.result.message.extend(partial.message)
                bot_request.result.voice.extend(partial.voice)
                bot_request.result.image.extend(partial.image)
                bot_request.response_event.set()

            async def run_dispatch():
                try:
                    await self.dispatcher.dispatch(self, message)
                except Exception as e:
                    self.logger.error(f"处理 HTTP Legacy 请求 {request_time} 时出错: {e}")
                finally:
                    # Mark finished and wake pollers so the final poll returns
                    # and the entry is released instead of waiting forever.
                    bot_request.done = True
                    bot_request.response_event.set()

            message.sender.raw_metadata["callback_func"] = handle_response
            task = asyncio.create_task(run_dispatch())
            self._dispatch_tasks.add(task)
            task.add_done_callback(self._dispatch_tasks.discard)
            return request_time

        @app.get("/v2/chat/response")
        async def v2_chat_response(request: Request, request_id: str = Query(...)):
            auth_response = await verify_auth(request)
            if auth_response:
                return auth_response
            # Clients may wrap the id in quotes, possibly URL-encoded (%22 / %27).
            # These are alternatives, not a character class: a class would also
            # strip the digits 2 and 7 from the end of the numeric id.
            request_id = re.sub(
                r'^(?:%22|%27|["\'])+|(?:%22|%27|["\'])+$', "", request_id)
            bot_request = self.request_dic.get(request_id)
            if bot_request is None:
                return ResponseResult(
                    message="没有更多了!", result_status="FAILED"
                ).to_dict()
            await bot_request.response_event.wait()
            bot_request.response_event.clear()
            response = bot_request.result.to_dict()
            bot_request.result = ResponseResult()
            if bot_request.done:
                self.request_dic.pop(request_id, None)
            return response

    async def send_message(self, message: IMMessage, recipient: ChatSender):
        """Deliver a reply through the per-request callback stored in the sender's metadata."""
        await recipient.raw_metadata["callback_func"](message)

    @property
    def is_standalone(self):
        """Truthy when the deprecated standalone mode is configured (``host`` set)."""
        return self.config.host

    async def _start_standalone_server(self):
        """Start a dedicated hypercorn server for ``self.app`` (deprecated standalone mode)."""
        from hypercorn.asyncio import serve
        from hypercorn.config import Config
        from hypercorn.logging import Logger

        from kirara_ai.logger import HypercornLoggerWrapper

        config = Config()
        host = self.config.host
        port = self.config.port or 18560
        config.bind = [f"{host}:{port}"]
        config._log = Logger(config)
        config._log.access_logger = HypercornLoggerWrapper(self.logger)  # type: ignore
        config._log.error_logger = HypercornLoggerWrapper(self.logger)  # type: ignore
        self.server_task = asyncio.create_task(serve(self.app, config))  # type: ignore

    async def start(self):
        """Register routes (and the standalone server if configured) and start request cleanup."""
        global _is_first_setup, _authorized_api_keys
        if self.is_standalone:
            self.logger.warning("正在使用过时的独立模式,请尽快更新为集成模式。")
            # Register the routes before the server task gets to run.
            self.setup_routes()
            await self._start_standalone_server()
        else:
            if _is_first_setup:
                # Register directly on the shared web server app (not via mount_app).
                self.setup_routes(target_app=self.web_server.app)
                _is_first_setup = False
            else:
                # Routes already exist; let them accept this instance's key too.
                if self.config.api_key and self.config.api_key not in _authorized_api_keys:
                    _authorized_api_keys.append(self.config.api_key)
        self.cleanup_task = asyncio.create_task(
            self.cleanup_expired_requests())

    async def cleanup_expired_requests(self):
        """Every 60 s, drop /v2 requests older than 10 minutes (age derived from the id)."""
        while True:
            now = time.time()
            expired_keys = [
                key
                for key in self.request_dic
                if now - int(key) / 1000 > 600
            ]
            for key in expired_keys:
                self.request_dic.pop(key, None)
            await asyncio.sleep(60)

    async def stop(self):
        """Withdraw this instance's API key and cancel the server and cleanup tasks."""
        global _authorized_api_keys
        if self.config.api_key and self.config.api_key in _authorized_api_keys:
            _authorized_api_keys.remove(self.config.api_key)

        if self.is_standalone and hasattr(self, "server_task"):
            self.server_task.cancel()
            try:
                await self.server_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                self.logger.error(f"Error during server shutdown: {e}")

        if hasattr(self, "cleanup_task"):
            self.cleanup_task.cancel()
            try:
                await self.cleanup_task
            except asyncio.CancelledError:
                pass
================================================
FILE: kirara_ai/plugins/im_http_legacy_adapter/setup.py
================================================
from setuptools import find_packages, setup
setup(
    name="kirara_ai-http-legacy-adapter",
    version="1.0.0",
    description="HTTP legacy adapter plugin for kirara_ai",
    author="Internal",
    packages=find_packages(),
    install_requires=["aiohttp", "requests"],
    entry_points={
        "chatgpt_mirai.plugins": [
            # The plugin class is defined in the package's __init__.py;
            # there is no "plugin" submodule to point at.
            "http_legacy = im_http_legacy_adapter:HttpLegacyAdapterPlugin"
        ]
    },
)
================================================
FILE: kirara_ai/plugins/im_http_legacy_adapter/tests/api_test.py
================================================
import asyncio
import os
import sys
import pytest
from fastapi.testclient import TestClient
from kirara_ai.im.adapter import IMAdapter
from kirara_ai.im.message import IMMessage
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block.registry import BlockRegistry
from kirara_ai.workflow.core.dispatch.dispatcher import WorkflowDispatcher
from kirara_ai.workflow.core.dispatch.registry import DispatchRuleRegistry
from tests.utils.test_block_registry import create_test_block_registry
# Put the plugins directory (three levels above this file) on sys.path so that
# `im_http_legacy_adapter` is importable as a top-level package below.
sys.path.insert(
    0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from im_http_legacy_adapter.adapter import HttpLegacyAdapter, HttpLegacyConfig, ResponseResult
from kirara_ai.workflow.core.workflow.registry import WorkflowRegistry
class FakeWorkflowDispatcher(WorkflowDispatcher):
    """Dispatcher stub that ignores every message: no workflow runs, no replies."""

    async def dispatch(self, source: IMAdapter, message: IMMessage):
        # Deliberately a no-op so the HTTP endpoints return immediately.
        return None
@pytest.fixture
def config():
    """Adapter config in (deprecated) standalone mode on 127.0.0.1:8080."""
    return HttpLegacyConfig(debug=False, host="127.0.0.1", port=8080)
@pytest.fixture
def adapter(config):
    """HttpLegacyAdapter with routes on its own FastAPI app and a no-op dispatcher."""
    ioc = DependencyContainer()
    ioc.register(DependencyContainer, ioc)
    ioc.register(WorkflowRegistry, WorkflowRegistry(ioc))
    ioc.register(DispatchRuleRegistry, DispatchRuleRegistry(ioc))
    ioc.register(WorkflowDispatcher, FakeWorkflowDispatcher(ioc))
    ioc.register(BlockRegistry, create_test_block_registry())

    http_adapter = HttpLegacyAdapter(config)
    http_adapter.setup_routes()
    http_adapter.dispatcher = ioc.resolve(WorkflowDispatcher)
    return http_adapter
@pytest.mark.asyncio
async def test_chat_endpoint(adapter):
    """/v1/chat answers 200 with the legacy payload shape, with or without optional fields."""
    client = TestClient(adapter.app)

    # Full request body.
    full = client.post(
        "/v1/chat",
        json={
            "session_id": "test_session",
            "username": "test_user",
            "message": "Hello, world!",
        },
    )
    assert full.status_code == 200
    body = full.json()
    assert "result" in body
    assert "message" in body
    assert isinstance(body["message"], list)

    # Only "message": session_id and username fall back to their defaults.
    minimal = client.post("/v1/chat", json={"message": "Test message"})
    assert minimal.status_code == 200
@pytest.mark.asyncio
async def test_response_result():
    """ResponseResult wraps scalar values into lists and keeps lists unchanged."""
    single = ResponseResult(message="Test message").to_dict()
    assert single["message"] == ["Test message"]
    assert single["voice"] == []
    assert single["image"] == []

    multi = ResponseResult(
        message=["Message 1", "Message 2"],
        voice=["voice1.mp3"],
        image=["image1.jpg", "image2.jpg"],
    ).to_dict()
    assert len(multi["message"]) == 2
    assert len(multi["voice"]) == 1
    assert len(multi["image"]) == 2
@pytest.mark.asyncio
async def test_adapter_lifecycle(adapter):
    """The adapter can be started and stopped again."""
    starting = asyncio.create_task(adapter.start())
    # Let start() run before stopping.
    await asyncio.sleep(0.1)
    await adapter.stop()
    try:
        await starting
    except Exception:
        # start() may fail once the server is stopped; that is acceptable here.
        pass
================================================
FILE: kirara_ai/plugins/im_qqbot_adapter/__init__.py
================================================
import asyncio
import os
from im_qqbot_adapter.adapter import QQBotAdapter, QQBotConfig
from kirara_ai.logger import get_logger
from kirara_ai.plugin_manager.plugin import Plugin
from kirara_ai.web.app import WebServer
# Module-level logger for the QQBot adapter plugin.
logger = get_logger("QQBot-Adapter")
class QQBotAdapterPlugin(Plugin):
    """Plugin that registers the QQ open-platform bot adapter type and its icon."""

    web_server: WebServer

    def __init__(self):
        pass

    def on_load(self):
        """Register the ``qqbot`` adapter type and serve its logo."""
        self.im_registry.register(
            "qqbot",
            QQBotAdapter,
            QQBotConfig,
            "QQ 开放平台机器人",
            "QQ 官方机器人,需要服务器支持接收 QQ 的 Webhook 请求,支持基本的聊天功能,群聊中必须通过 @ 触发。",
            """
QQ 开放平台机器人,需要服务器支持接收 QQ 的 Webhook 请求,配置流程可参考 [QQ 开放平台文档](https://q.qq.com/wiki/) 和 [Kirara AI 文档](https://kirara-docs.app.lss233.com/guide/configuration/im.html)。
"""
        )
        local_logo_path = os.path.join(os.path.dirname(__file__), "assets", "qqbot.png")
        self.web_server.add_static_assets("/assets/icons/im/qqbot.png", local_logo_path)

    def on_start(self):
        pass

    def on_stop(self):
        """
        Stop and remove this plugin's QQBot adapters, then unregister the
        ``qqbot`` adapter type.

        Only ``QQBotAdapter`` instances are touched; adapters owned by other
        plugins are left for those plugins to stop. Errors are logged, and
        the adapter type is unregistered regardless.
        """
        try:
            loop = asyncio.get_event_loop()
            own_adapters = {
                key: adapter
                for key, adapter in self.im_manager.get_adapters().items()
                if isinstance(adapter, QQBotAdapter)
            }
            tasks = [
                self.im_manager.stop_adapter(key, loop)
                for key, adapter in own_adapters.items()
                if adapter.is_running
            ]
            # Stop before deleting, in case stop_adapter looks adapters up by key when it runs.
            if tasks:
                loop.run_until_complete(asyncio.gather(*tasks))
            for key in own_adapters:
                self.im_manager.delete_adapter(key)
        except Exception as e:
            logger.error(f"Error stopping QQBot adapter: {e}")
        finally:
            self.im_registry.unregister("qqbot")
            logger.info("QQBot adapter stopped")
================================================
FILE: kirara_ai/plugins/im_qqbot_adapter/adapter.py
================================================
import asyncio
import base64
import functools
import uuid
from typing import List, Optional
import ymbotpy as botpy
import ymbotpy.message
from pydantic import BaseModel, ConfigDict, Field
from ymbotpy.http import Route as BotpyRoute
from ymbotpy.types.message import Media as BotpyMedia
from kirara_ai.im.adapter import BotProfileAdapter, IMAdapter
from kirara_ai.im.message import (ImageMessage, IMMessage, MentionElement, MessageElement, TextMessage, VideoElement,
VoiceMessage)
from kirara_ai.im.profile import UserProfile
from kirara_ai.im.sender import ChatSender, ChatType
from kirara_ai.logger import get_logger
from kirara_ai.web.app import WebServer
from kirara_ai.workflow.core.dispatch import WorkflowDispatcher
from .utils import URL_PATTERN
WEBHOOK_URL_PREFIX = "/im/webhook/qqbot"


def make_webhook_url():
    """Return a new webhook path: the prefix, an 8-hex-char random token, and a trailing slash."""
    token = uuid.uuid4().hex[:8]
    return WEBHOOK_URL_PREFIX + "/" + token + "/"


def auto_generate_webhook_url(s: dict):
    """JSON-schema hook for ``webhook_url``: read-only text field with a freshly generated default."""
    s.update(readOnly=True, default=make_webhook_url(), textType=True)
class QQBotConfig(BaseModel):
    """
    Configuration model for the QQ bot adapter.
    """
    app_id: str = Field(description="机器人的 App ID。")
    app_secret: str = Field(title="App Secret", description="机器人的 App Secret。")
    token: str = Field(
        title="Token", description="机器人令牌,用于调用 QQ 机器人的 OpenAPI。")
    sandbox: bool = Field(
        title="沙盒环境", description="是否为沙盒环境,通常只有正式发布的机器人才会关闭此选项。", default=False)
    # Generated per instance; the schema hook marks it readOnly and gives it a fresh default.
    webhook_url: str = Field(
        title="Webhook 回调 URL", description="供 QQ 机器人回调的 URL,由系统自动生成,无法修改。",
        default_factory=make_webhook_url,
        json_schema_extra=auto_generate_webhook_url
    )
    # Unknown keys are kept instead of rejected.
    model_config = ConfigDict(extra="allow")
async def patched_post_file(
    self,
    file_type: int,
    file_data: bytes,
    openid: Optional[str] = None,
    group_openid: Optional[str] = None
) -> BotpyMedia:
    """
    Upload rich media for a C2C or group message (replacement for botpy's post_file
    that takes an explicit file type).

    :param self: botpy API object whose ``_http`` client performs the request.
    :param file_type: QQ media type code (this adapter uses 1=image, 2=video, 3=voice).
    :param file_data: raw bytes, sent base64-encoded.
    :param openid: target user openid (C2C); takes precedence over ``group_openid``.
    :param group_openid: target group openid.
    :raises ValueError: if neither target is given.
    """
    payload = dict(
        file_type=file_type,
        file_data=base64.b64encode(file_data).decode('utf-8'),
        srv_send_msg=False,
    )
    if openid:
        route = BotpyRoute("POST", "/v2/users/{openid}/files", openid=openid)
    elif group_openid:
        route = BotpyRoute("POST", "/v2/groups/{group_openid}/files", group_openid=group_openid)
    else:
        raise ValueError("openid 和 group_openid 不能同时为空")
    return await self._http.request(route, json=payload)
class QQBotAdapter(botpy.WebHookClient, IMAdapter, BotProfileAdapter):
    """
    QQ open-platform bot adapter (webhook mode).

    Receives C2C messages and group @-messages through a botpy webhook app
    mounted on the shared web server at ``config.webhook_url``, converts them
    to ``IMMessage`` for the workflow dispatcher, and sends replies through
    the QQ OpenAPI.
    """

    dispatcher: WorkflowDispatcher
    web_server: WebServer
    _loop: asyncio.AbstractEventLoop

    def __init__(self, config: QQBotConfig):
        self.config = config
        self.is_sandbox = config.sandbox
        self.logger = get_logger("QQBot-Adapter")
        super().__init__(
            timeout=5,
            is_sandbox=self.is_sandbox,
            bot_log=True,
            ext_handlers=True,
        )
        # presumably initialised by the botpy client constructor above — TODO confirm
        self.loop = self._loop
        # Filled in by start() with the login result; None until then.
        self.user = None

    async def convert_to_message(self, raw_message: ymbotpy.message.BaseMessage) -> IMMessage:
        """
        Convert a botpy group or C2C message into an ``IMMessage``.

        Text becomes a ``TextMessage``; image and audio attachments become
        ``ImageMessage`` / ``VoiceMessage``; other attachments are ignored.

        :raises ValueError: for any other botpy message type.
        """
        if isinstance(raw_message, ymbotpy.message.GroupMessage):
            assert raw_message.author.member_openid is not None
            assert raw_message.group_openid is not None
            sender = ChatSender.from_group_chat(
                raw_message.author.member_openid, raw_message.group_openid, 'QQ 用户')
        elif isinstance(raw_message, ymbotpy.message.C2CMessage):
            sender = ChatSender.from_c2c_chat(
                raw_message.author.user_openid, 'QQ 用户')
        else:
            raise ValueError(f"不支持的消息类型: {type(raw_message)}")

        # String snapshot of the message's public slots.
        raw_dict = {items: str(getattr(raw_message, items))
                    for items in raw_message.__slots__ if not items.startswith("_")}

        # send_message needs the original message id to reply passively.
        sender.raw_metadata = {
            "message_id": raw_message.id,
            "message_seq": raw_message.msg_seq,
            "timestamp": raw_message.timestamp,
        }

        elements: List[MessageElement] = []
        if raw_message.content.strip():
            elements.append(TextMessage(text=raw_message.content.lstrip()))
        for attachment in raw_message.attachments:
            if attachment.content_type.startswith('image/'):
                elements.append(
                    ImageMessage(
                        url=attachment.url,
                        format=attachment.content_type.removeprefix('image/')
                    )
                )
            elif attachment.content_type.startswith('audio'):
                elements.append(
                    VoiceMessage(
                        url=attachment.url,
                        format=attachment.filename.split('.')[-1]
                    )
                )
        return IMMessage(sender=sender, message_elements=elements, raw_message=raw_dict)

    async def send_message(self, message: IMMessage, recipient: ChatSender):
        """
        Send a message as a reply to the message recorded in ``recipient.raw_metadata``.

        Each text element is sent as its own message; mentions are buffered
        and sent with the following text; image/voice/video elements are
        uploaded and sent as media messages. Dots inside URLs are replaced by
        "。" (and a note is appended) to avoid the platform blocking links.

        :param message: the message to send.
        :param recipient: target chat; must carry ``message_id`` in ``raw_metadata``.
        :raises ValueError: if the reply metadata is missing or the chat type is unsupported.
        """
        if recipient.raw_metadata is None or recipient.raw_metadata.get('message_id') is None:
            raise ValueError("Unable to retreive send_message info from metadata")

        msg_id = recipient.raw_metadata['message_id']

        if recipient.chat_type == ChatType.C2C:
            assert recipient.user_id is not None
            post_message_func = functools.partial(
                self.api.post_c2c_message, openid=recipient.user_id, msg_id=msg_id)
            upload_func = functools.partial(
                patched_post_file, self.api, openid=recipient.user_id)
        elif recipient.chat_type == ChatType.GROUP:
            assert recipient.group_id is not None
            post_message_func = functools.partial(
                self.api.post_group_message, group_openid=recipient.group_id, msg_id=msg_id)
            upload_func = functools.partial(
                patched_post_file, self.api, group_openid=recipient.group_id)
        else:
            raise ValueError(f"不支持的消息类型: {recipient.chat_type}")

        current_text = ""  # text waiting to be sent
        msg_seq = 0  # sequence number of the next reply to msg_id
        url_replaced = False  # whether any URL dots were replaced

        def replace_url_dots(text: str) -> str:
            """Replace the dots inside URLs in ``text`` with "。" and record that it happened."""
            def replace_dots(match):
                nonlocal url_replaced
                url_replaced = True
                return match.group(0).replace('.', '。')
            return URL_PATTERN.sub(replace_dots, text)

        async def flush_text():
            """Send the buffered text (if any) as one text message and clear the buffer."""
            nonlocal current_text, msg_seq
            if current_text:
                await post_message_func(content=replace_url_dots(current_text), msg_seq=msg_seq)  # type: ignore
                msg_seq += 1
                current_text = ""

        for element in message.message_elements:
            if isinstance(element, TextMessage):
                current_text += element.text
                await flush_text()
            elif isinstance(element, MentionElement):
                # The mention is sent together with the next text message.
                current_text += f'<qqbot-at-user id="{element.target.user_id}" />'
            elif isinstance(element, (ImageMessage, VoiceMessage, VideoElement)):
                # Keep ordering: pending text goes out before the media.
                await flush_text()
                # QQ media type codes: 1 = image, 2 = video, 3 = voice.
                if isinstance(element, ImageMessage):
                    file_type = 1
                elif isinstance(element, VoiceMessage):
                    file_type = 3
                else:
                    file_type = 2
                media = await upload_func(file_type=file_type, file_data=await element.get_data())
                await post_message_func(media=media, msg_seq=msg_seq, msg_type=7)  # type: ignore
                msg_seq += 1

        if url_replaced:
            current_text = current_text + "(URL 中的句点已替换为句号以避免屏蔽)"
        await flush_text()

    async def on_c2c_message_create(self, message: ymbotpy.message.C2CMessage):
        """
        Handle an incoming private (C2C) message.

        :param message: the received message.
        """
        self.logger.debug(f"收到 C2C 消息: {message}")
        im_message = await self.convert_to_message(message)
        await self.dispatcher.dispatch(self, im_message)

    async def on_group_at_message_create(self, message: ymbotpy.message.GroupMessage):
        """
        Handle an incoming group message that @-mentions the bot.

        :param message: the received message.
        """
        self.logger.debug(f"收到群消息: {message}")
        im_message = await self.convert_to_message(message)
        # The webhook payload carries no mention field, so add the bot mention manually.
        im_message.message_elements.append(
            MentionElement(target=ChatSender.get_bot_sender()))
        await self.dispatcher.dispatch(self, im_message)

    async def get_bot_profile(self) -> Optional[UserProfile]:
        """
        Return the bot's profile.

        :return: the profile from the login result, or None before start().
        """
        if self.user is None:
            return None
        return UserProfile(
            user_id=self.user['id'],
            username=self.user['username'],
            display_name=self.user['username'],
            avatar_url=self.user['avatar']
        )

    async def start(self):
        """Log in, build the botpy webhook app, and mount it on the web server at ``config.webhook_url``."""
        token = botpy.Token(self.config.app_id, self.config.app_secret)
        self.user = await self.http.login(token)
        self.robot = botpy.Robot(self.user)

        bot_webhook = botpy.BotWebHook(
            self.config.app_id,
            self.config.app_secret,
            hook_route='/',
            client=self,
            system_log=True,
            botapi=self.api,
            loop=self.loop
        )

        app = await bot_webhook.init_fastapi()
        # Drop the middleware botpy installed on its app
        # (presumably to avoid clashing with the host server's) — TODO confirm.
        app.user_middleware.clear()
        self.web_server.mount_app(
            self.config.webhook_url.removesuffix('/'), app)

    async def stop(self):
        """Stop the bot (currently a no-op; the mounted webhook app is not removed)."""
================================================
FILE: kirara_ai/plugins/im_qqbot_adapter/setup.py
================================================
from setuptools import find_packages, setup
setup(
    name="kirara_ai-qqbot-adapter",
    version="1.0.0",
    description="QQBot adapter plugin for kirara_ai",
    author="Internal",
    packages=find_packages(),
    install_requires=["ymbotpy"],
    entry_points={
        "chatgpt_mirai.plugins": [
            # The plugin class is defined in the package's __init__.py;
            # there is no "plugin" submodule to point at.
            "qqbot = im_qqbot_adapter:QQBotAdapterPlugin"
        ]
    },
)
================================================
FILE: kirara_ai/plugins/im_qqbot_adapter/utils.py
================================================
import re
# Top-level domains recognised by URL_PATTERN, in the order the regex tries them.
_URL_TLDS = (
    "com", "net", "org", "edu", "gov", "mil", "io", "co", "ai", "app", "dev",
    "cn", "jp", "kr", "uk", "de", "fr", "ru", "br", "in", "au", "info", "biz",
    "cc", "tv", "xyz", "top", "site", "online", "store", "tech", "blog", "club",
    "art", "design", "shop", "mobile", "cloud", "me", "ws", "live", "games",
    "science", "news", "media", "social", "world", "life", "network",
)

# Bare or http(s) URL: dotted host ending in a known TLD, plus an optional path.
URL_PATTERN = re.compile(
    r"((?:https?://)?(?:[-\w]+\.)+(?:" + "|".join(_URL_TLDS) + r")(?:/[-\w./%]*)?)"
)
================================================
FILE: kirara_ai/plugins/im_telegram_adapter/__init__.py
================================================
import asyncio
import os
from im_telegram_adapter.adapter import TelegramAdapter, TelegramConfig
from kirara_ai.logger import get_logger
from kirara_ai.plugin_manager.plugin import Plugin
from kirara_ai.web.app import WebServer
# Module-level logger for the Telegram adapter plugin.
logger = get_logger("TG-Adapter")
class TelegramAdapterPlugin(Plugin):
    """Plugin that registers the Telegram bot adapter type and its icon."""

    web_server: WebServer

    def __init__(self):
        pass

    def on_load(self):
        """Register the ``telegram`` adapter type and serve its logo."""
        self.im_registry.register(
            "telegram",
            TelegramAdapter,
            TelegramConfig,
            "Telegram 机器人",
            "Telegram 官方机器人,支持私聊、群聊、 Markdown 格式消息。",
            """
Telegram 机器人,配置流程可参考 [Telegram 官方文档](https://core.telegram.org/bots/tutorial) 和 [Kirara AI 文档](https://kirara-docs.app.lss233.com/guide/configuration/im.html)。
"""
        )
        # Serve assets/telegram.png (next to this file) as the adapter icon.
        local_logo_path = os.path.join(os.path.dirname(__file__), "assets", "telegram.png")
        self.web_server.add_static_assets("/assets/icons/im/telegram.png", local_logo_path)

    def on_start(self):
        pass

    def on_stop(self):
        """
        Stop and remove this plugin's Telegram adapters, then unregister the
        ``telegram`` adapter type.

        Only ``TelegramAdapter`` instances are touched; adapters owned by other
        plugins are left for those plugins to stop. Errors are logged, and
        the adapter type is unregistered regardless.
        """
        try:
            loop = asyncio.get_event_loop()
            own_adapters = {
                key: adapter
                for key, adapter in self.im_manager.get_adapters().items()
                if isinstance(adapter, TelegramAdapter)
            }
            tasks = [
                self.im_manager.stop_adapter(key, loop)
                for key, adapter in own_adapters.items()
                if adapter.is_running
            ]
            # Stop before deleting, in case stop_adapter looks adapters up by key when it runs.
            if tasks:
                loop.run_until_complete(asyncio.gather(*tasks))
            for key in own_adapters:
                self.im_manager.delete_adapter(key)
        except Exception as e:
            logger.error(f"Error stopping Telegram adapter: {e}")
        finally:
            self.im_registry.unregister("telegram")
            logger.info("Telegram adapter stopped")
================================================
FILE: kirara_ai/plugins/im_telegram_adapter/adapter.py
================================================
import asyncio
import random
from functools import lru_cache
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from telegram import Bot, ChatFullInfo, Update, User
from telegram.constants import MessageEntityType
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegramify_markdown import markdownify
from kirara_ai.im.adapter import BotProfileAdapter, EditStateAdapter, IMAdapter, UserProfileAdapter
from kirara_ai.im.message import (FileElement, ImageMessage, IMMessage, MentionElement, MessageElement, TextMessage,
VideoMessage, VoiceMessage)
from kirara_ai.im.profile import UserProfile
from kirara_ai.im.sender import ChatSender, ChatType
from kirara_ai.logger import get_logger
from kirara_ai.workflow.core.dispatch import WorkflowDispatcher
def get_display_name(user: User | ChatFullInfo):
    """Return a human-readable name for a Telegram user or chat.

    Uses "first last" (whichever parts are present) when available, otherwise
    the username, and finally the numeric id rendered as a string.
    """
    has_real_name = bool(user.first_name or user.last_name)
    if not has_real_name:
        return user.username if user.username else str(user.id)
    return " ".join((user.first_name or "", user.last_name or "")).strip()
class TelegramConfig(BaseModel):
    """
    Telegram adapter configuration model.
    """
    token: str = Field(description="Telegram 机器人的 Token,从 @BotFather 获取。")
    model_config = ConfigDict(extra="allow")

    def __repr__(self):
        # Never echo the full bot token: reprs end up in logs and tracebacks.
        # Keep only the part before ':' (the bot id) and mask the secret part.
        bot_id, sep, _secret = self.token.partition(":")
        masked = f"{bot_id}{sep}***" if sep else "***"
        return f"TelegramConfig(token={masked})"

    def __str__(self):
        # pydantic's default __str__ would print the raw token; reuse the masked repr.
        return self.__repr__()
class TelegramAdapter(IMAdapter, UserProfileAdapter, EditStateAdapter, BotProfileAdapter):
    """
    Telegram adapter: wraps a python-telegram-bot ``Application`` and converts
    between Telegram updates and kirara ``IMMessage`` objects.
    """
    dispatcher: WorkflowDispatcher

    # Upper bound on the number of get_chat() results kept by _cached_get_chat.
    _CHAT_CACHE_SIZE = 10

    def __init__(self, config: TelegramConfig):
        self.me = None
        self.config = config
        self.application = Application.builder().token(config.token).build()
        self.bot = Bot(token=config.token)
        # Per-instance LRU cache of *resolved* get_chat() results (insertion-ordered dict).
        self._chat_cache: dict = {}
        # Register the /start command handler and the general message handler.
        self.application.add_handler(
            CommandHandler("start", self.command_start))
        self.application.add_handler(
            MessageHandler(
                filters.TEXT | filters.VOICE | filters.PHOTO | filters.VIDEO, self.handle_message
            )
        )
        self.logger = get_logger("Telegram-Adapter")

    async def command_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Reply to the /start command with a greeting."""
        if update.message:
            await update.message.reply_text("Welcome! I am ready to receive your messages.")

    async def handle_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Convert an incoming update into an IMMessage and dispatch it.

        Dispatch failures are reported back to the chat as a reply.
        """
        if not update.message:
            return
        message = await self.convert_to_message(update)
        try:
            await self.dispatcher.dispatch(self, message)
        except Exception as e:
            await update.message.reply_text(
                f"Workflow execution failed, please try again later: {str(e)}"
            )

    @staticmethod
    def _slice_utf16(encoded: bytes, start: int, end: int) -> str:
        """Return the text between UTF-16 code-unit offsets ``start`` and ``end``.

        :param encoded: text encoded as UTF-16-LE (two bytes per code unit).
        """
        return encoded[start * 2:end * 2].decode("utf-16-le")

    async def convert_to_message(self, raw_message: Update) -> IMMessage:
        """
        Convert a Telegram ``Update`` into an ``IMMessage``.

        Text (or a media caption) is split into TextMessage / MentionElement parts;
        voice, photo (largest size), video and document payloads are downloaded.

        :param raw_message: the Telegram update; must carry a message with a sender.
        :return: the converted message.
        """
        assert raw_message.message
        assert raw_message.message.from_user
        tg_message = raw_message.message
        if tg_message.chat.type in ("group", "supergroup"):
            sender = ChatSender.from_group_chat(
                user_id=str(tg_message.from_user.id),
                group_id=str(tg_message.chat_id),
                display_name=get_display_name(tg_message.from_user),
            )
        else:
            sender = ChatSender.from_c2c_chat(
                user_id=str(tg_message.chat_id),
                display_name=get_display_name(tg_message.from_user),
            )

        message_elements: List[MessageElement] = []
        raw_message_dict = tg_message.to_dict()

        # Text (or caption), split around @-mentions.
        if tg_message.text is not None or tg_message.caption is not None:
            text: str = tg_message.text or tg_message.caption or ""
            # Telegram entity offsets/lengths are counted in UTF-16 code units, not
            # Python code points; slicing the UTF-16 encoding keeps them aligned
            # when the text contains emoji or other non-BMP characters.
            encoded = text.encode("utf-16-le")
            total_units = len(encoded) // 2
            offset = 0
            for entity in tg_message.entities or tg_message.caption_entities or []:
                if entity.type not in (MessageEntityType.MENTION, MessageEntityType.TEXT_MENTION):
                    continue
                entity_end = entity.offset + entity.length
                mention_text = self._slice_utf16(encoded, entity.offset, entity_end)
                # Emit the plain text between the previous mention and this one.
                if entity.offset > offset:
                    message_elements.append(TextMessage(
                        text=self._slice_utf16(encoded, offset, entity.offset)))
                mention_element: MessageElement
                if entity.type == MessageEntityType.TEXT_MENTION and entity.user:
                    # text_mention entities carry the mentioned user object.
                    if self.me is not None and entity.user.id == self.me.id:
                        mention_element = MentionElement(
                            target=ChatSender.get_bot_sender())
                    else:
                        mention_element = MentionElement(target=ChatSender.from_c2c_chat(
                            user_id=str(entity.user.id), display_name=mention_text))
                elif entity.type == MessageEntityType.MENTION:
                    # @username mention: compare against the bot's own username.
                    if self.me is not None and mention_text == f'@{self.me.username}':
                        mention_element = MentionElement(
                            target=ChatSender.get_bot_sender())
                    else:
                        mention_element = MentionElement(target=ChatSender.from_c2c_chat(
                            user_id=f'unknown_id:{mention_text}', display_name=mention_text))
                else:
                    # text_mention without a user object: keep it as plain text.
                    mention_element = TextMessage(text=mention_text)
                message_elements.append(mention_element)
                offset = entity_end
            # Trailing text after the last mention.
            if offset < total_units:
                message_elements.append(TextMessage(
                    text=self._slice_utf16(encoded, offset, total_units)))

        # Voice message
        if tg_message.voice:
            voice_file = await tg_message.voice.get_file()
            data = await voice_file.download_as_bytearray()
            message_elements.append(VoiceMessage(data=bytes(data)))

        # Photo: the last PhotoSize is used as the highest resolution.
        if tg_message.photo:
            photo = tg_message.photo[-1]
            photo_file = await photo.get_file()
            data = await photo_file.download_as_bytearray()
            message_elements.append(ImageMessage(data=bytes(data)))

        if tg_message.video:
            video_file = await tg_message.video.get_file()
            data = await video_file.download_as_bytearray()
            message_elements.append(VideoMessage(data=bytes(data)))

        if tg_message.document:
            document_file = await tg_message.document.get_file()
            data = await document_file.download_as_bytearray()
            message_elements.append(FileElement(data=bytes(data)))

        return IMMessage(
            sender=sender,
            message_elements=message_elements,
            raw_message=raw_message_dict,
        )

    async def send_message(self, message: IMMessage, recipient: ChatSender):
        """
        Send a message to Telegram.

        :param message: the message to send; text, image, voice and video elements are sent.
        :param recipient: target chat; its user id (C2C) or group id (GROUP) is the chat_id.
        :raises ValueError: for unsupported chat types.
        """
        if recipient.chat_type == ChatType.C2C:
            chat_id = recipient.user_id
        elif recipient.chat_type == ChatType.GROUP:
            assert recipient.group_id
            chat_id = recipient.group_id
        else:
            raise ValueError(f"Unsupported chat type: {recipient.chat_type}")

        for index, element in enumerate(message.message_elements):
            if isinstance(element, TextMessage):
                await self.application.bot.send_chat_action(
                    chat_id=chat_id, action="typing"
                )
                text = markdownify(element.text)
                # Pause before every element but the first to mimic typing; the
                # delay grows with text length plus a little jitter.
                if index > 0:
                    duration = max(len(element.text) * 0.1, 1) + random.uniform(0, 1) * 0.1
                    await asyncio.sleep(duration)
                await self.application.bot.send_message(
                    chat_id=chat_id, text=text, parse_mode="MarkdownV2"
                )
            elif isinstance(element, ImageMessage):
                await self.application.bot.send_chat_action(
                    chat_id=chat_id, action="upload_photo"
                )
                await self.application.bot.send_photo(
                    chat_id=chat_id, photo=await element.get_data(), parse_mode="MarkdownV2"
                )
            elif isinstance(element, VoiceMessage):
                await self.application.bot.send_chat_action(
                    chat_id=chat_id, action="upload_voice"
                )
                await self.application.bot.send_voice(
                    chat_id=chat_id, voice=await element.get_data(), parse_mode="MarkdownV2"
                )
            elif isinstance(element, VideoMessage):
                await self.application.bot.send_chat_action(
                    chat_id=chat_id, action="upload_video"
                )
                await self.application.bot.send_video(
                    chat_id=chat_id, video=await element.get_data(), parse_mode="MarkdownV2"
                )

    async def start(self):
        """Initialize the application, fetch the bot identity and start polling."""
        await self.application.initialize()
        await self.application.start()
        self.me = await self.bot.get_me()
        assert self.application.updater
        await self.application.updater.start_polling(drop_pending_updates=True)

    async def stop(self):
        """Stop polling and shut the application down; errors are logged, not raised."""
        assert self.application.updater
        try:
            if self.application.updater.running:
                await self.application.updater.stop()
            if self.application.running:
                await self.application.stop()
                await self.application.shutdown()
        except Exception as e:
            # Best-effort shutdown, but do not hide the failure (or swallow cancellation).
            self.logger.warning(f"Error while stopping Telegram bot: {e}")

    async def set_chat_editing_state(
        self, chat_sender: ChatSender, is_editing: bool = True
    ):
        """
        Set or clear the "typing" indicator for a chat.

        Telegram has no "cancel" chat action; the indicator disappears on its own
        (or when the next message is sent), so clearing is a no-op.

        :param chat_sender: the chat whose state should change.
        :param is_editing: True to show "typing", False to clear it.
        :raises ValueError: if no chat id can be derived from ``chat_sender``.
        """
        chat_id = (
            chat_sender.user_id
            if chat_sender.chat_type == ChatType.C2C
            else chat_sender.group_id
        )
        if not chat_id:
            raise ValueError("Unable to get chat_id")
        self.logger.debug(
            f"Setting chat editing state to {is_editing} for chat_id {chat_id}"
        )
        if not is_editing:
            return
        try:
            await self.application.bot.send_chat_action(
                chat_id=chat_id, action="typing"
            )
        except Exception as e:
            self.logger.warning(f"Failed to set chat editing state: {str(e)}")

    async def _cached_get_chat(self, user_id):
        """
        Fetch chat info for ``user_id``, caching up to ``_CHAT_CACHE_SIZE`` results.

        Resolved results (not coroutine objects) are cached, so repeated lookups
        never re-await a spent coroutine. Failed lookups are not cached.

        :param user_id: Telegram user id.
        :return: the chat object returned by ``bot.get_chat``.
        """
        cache = self._chat_cache
        if user_id in cache:
            # Re-insert to mark this entry as most recently used.
            cache[user_id] = cache.pop(user_id)
            return cache[user_id]
        chat = await self.application.bot.get_chat(user_id)
        cache[user_id] = chat
        while len(cache) > self._CHAT_CACHE_SIZE:
            # Evict the least recently used entry (first in insertion order).
            cache.pop(next(iter(cache)))
        return chat

    async def query_user_profile(self, chat_sender: ChatSender) -> UserProfile:
        """
        Query a Telegram user's profile.

        :param chat_sender: the sender whose profile is requested.
        :return: the profile; on failure, a partial profile with id and display name.
        """
        try:
            user_id = chat_sender.user_id
            user = await self._cached_get_chat(user_id)
            profile = UserProfile(  # type: ignore
                user_id=str(user_id),
                username=user.username,
                display_name=get_display_name(user),
                full_name=f"{user.first_name or ''} {user.last_name or ''}".strip(),
                avatar_url=None,  # Fetching the avatar needs extra Telegram API calls.
            )
            return profile
        except Exception as e:
            self.logger.warning(f"Failed to query user profile: {str(e)}")
            return UserProfile(  # type: ignore
                user_id=str(chat_sender.user_id), display_name=chat_sender.display_name
            )

    async def get_bot_profile(self) -> Optional[UserProfile]:
        """
        Return the bot's own profile, or None when the bot is not running.

        :return: profile including the file path of the bot's latest profile photo, if any.
        """
        if not self.me or not self.is_running:
            return None
        profile_photos = await self.me.get_profile_photos()
        if profile_photos and profile_photos.photos:
            file_id = profile_photos.photos[0][-1].file_id
            file = await self.bot.get_file(file_id)
            photo_url = file.file_path
        else:
            photo_url = None
        return UserProfile(
            user_id=str(self.me.id),
            username=self.me.username,
            display_name=get_display_name(self.me),
            full_name=f"{self.me.first_name or ''} {self.me.last_name or ''}".strip(),
            avatar_url=photo_url,
        )
================================================
FILE: kirara_ai/plugins/im_telegram_adapter/setup.py
================================================
from setuptools import find_packages, setup
setup(
    name="kirara_ai-telegram-adapter",
    version="1.0.0",
    description="Telegram adapter plugin for kirara_ai",
    author="Internal",
    packages=find_packages(),
    install_requires=["python-telegram-bot", "telegramify-markdown"],
    entry_points={
        "chatgpt_mirai.plugins": [
            # TelegramAdapterPlugin is defined in the package's __init__.py; there is
            # no ``plugin`` submodule, so the entry point targets the package itself.
            "telegram = im_telegram_adapter:TelegramAdapterPlugin"
        ]
    },
)
================================================
FILE: kirara_ai/plugins/im_wecom_adapter/__init__.py
================================================
import os
from kirara_ai.logger import get_logger
from kirara_ai.plugin_manager.plugin import Plugin
from kirara_ai.web.app import WebServer
from .adapter import WecomAdapter, WecomConfig
logger = get_logger("Wecom-Adapter")
__all__ = ["WecomAdapter", "WecomConfig"]


class WecomAdapterPlugin(Plugin):
    """Plugin that registers the WeCom / WeChat Official Account IM adapter type."""

    web_server: WebServer

    def __init__(self):
        pass

    def on_load(self):
        """Register the ``wecom`` adapter type and publish its icon on the web server."""
        icon_file = os.path.join(os.path.dirname(__file__), "assets", "wecom.png")
        self.im_registry.register(
            "wecom",
            WecomAdapter,
            WecomConfig,
            "企业微信应用 / 微信公众号",
            "微信官方消息 API,需要服务器支持接收微信的 Webhook 访问。",
            r"""
企业微信应用/微信公众号官方消息 API,需要服务器支持 Webhook 访问。详情配置可参考[微信官方文档](https://open.work.weixin.qq.com/wwopen/manual/detail?t=selfBuildApp)和[Kirara配置文档](https://kirara-docs.app.lss233.com/guide/configuration/im.html#%E4%BC%81%E4%B8%9A%E5%BE%AE%E4%BF%A1-wecom)。
""",
        )
        self.web_server.add_static_assets("/assets/icons/im/wecom.png", icon_file)

    def on_start(self):
        """Nothing to do on start."""

    def on_stop(self):
        """Nothing to do on stop."""
================================================
FILE: kirara_ai/plugins/im_wecom_adapter/adapter.py
================================================
import asyncio
import base64
import os
import uuid
from io import BytesIO
from typing import Any, List, Optional
import aiohttp
from fastapi import FastAPI, HTTPException, Request, Response
from pydantic import BaseModel, ConfigDict, Field
from starlette.routing import Route
from wechatpy.client import BaseWeChatClient
from wechatpy.exceptions import InvalidSignatureException
from wechatpy.messages import BaseMessage
from wechatpy.replies import create_reply
from kirara_ai.im.adapter import IMAdapter
from kirara_ai.im.message import (FileElement, ImageMessage, IMMessage, MessageElement, TextMessage, VideoElement,
VoiceMessage)
from kirara_ai.im.sender import ChatSender
from kirara_ai.logger import HypercornLoggerWrapper, get_logger
from kirara_ai.web.app import WebServer
from kirara_ai.workflow.core.dispatch.dispatcher import WorkflowDispatcher
from .delegates import CorpWechatApiDelegate, PublicWechatApiDelegate, WechatApiDelegate
# Scratch directory (under the current working directory) for downloaded WeChat media.
WECOM_TEMP_DIR = os.path.join(os.getcwd(), "data", "temp", "wecom")
# Common prefix of every generated webhook path.
WEBHOOK_URL_PREFIX = "/im/webhook/wechat"


def make_webhook_url():
    """Return a new webhook path: the shared prefix plus 8 random hex characters."""
    suffix = uuid.uuid4().hex[:8]
    return f"{WEBHOOK_URL_PREFIX}/{suffix}"


def auto_generate_webhook_url(s: dict):
    """JSON-schema hook: render the webhook field as read-only text with a fresh default."""
    s.update(readOnly=True, default=make_webhook_url(), textType=True)
class WecomConfig(BaseModel):
    """WeCom (Enterprise WeChat) / WeChat Official Account adapter configuration.

    Docs: https://work.weixin.qq.com/api/doc/90000/90136/91770
    """
    app_id: str = Field(title="应用ID", description="见微信侧显示")
    secret: str = Field(title="应用Secret", description="见微信侧显示")
    token: str = Field(title="Token", description="与微信侧填写保持一致")
    encoding_aes_key: str = Field(title="EncodingAESKey", description="请通过微信侧随机生成")
    corp_id: Optional[str] = Field(
        default=None,
        title="企业ID",
        description="企业微信后台显示的企业ID,微信公众号等场景无需填写。",
    )
    webhook_url: str = Field(
        default_factory=make_webhook_url,
        title="微信端回调地址",
        description="供微信端请求的 Webhook URL,填写在微信端,由系统自动生成,无法修改。",
        json_schema_extra=auto_generate_webhook_url,
    )
    # Legacy standalone-server settings, hidden in the UI unless already set.
    host: Optional[str] = Field(
        default=None,
        title="HTTP 服务地址",
        description="已过时,请删除并使用 webhook_url 代替。",
        json_schema_extra={"hidden_unset": True},
    )
    port: Optional[int] = Field(
        default=None,
        title="HTTP 服务端口",
        description="已过时,请删除并使用 webhook_url 代替。",
        json_schema_extra={"hidden_unset": True},
    )
    model_config = ConfigDict(extra="allow")

    def __init__(self, **kwargs: Any):
        # Legacy configs use ``agent_id``; when present it takes over ``app_id``.
        if "agent_id" in kwargs:
            kwargs = {**kwargs, "app_id": str(kwargs["agent_id"])}
        super().__init__(**kwargs)
class WeComUtils:
    """Helpers for downloading media through the WeChat / WeCom HTTP API."""

    # Fallback when the client class does not define its own API_BASE_URL.
    DEFAULT_API_BASE_URL = "https://qyapi.weixin.qq.com/cgi-bin/"

    def __init__(self, client: BaseWeChatClient):
        self.client = client
        self.logger = get_logger("WeComUtils")

    @property
    def access_token(self) -> Optional[str]:
        return self.client.access_token

    @property
    def api_base_url(self) -> str:
        """Base ``cgi-bin/`` URL of the configured client (WeCom or Official Account)."""
        base = getattr(self.client, "API_BASE_URL", "") or self.DEFAULT_API_BASE_URL
        return base if base.endswith("/") else base + "/"

    async def download_and_save_media(self, media_id: str, file_name: str) -> Optional[str]:
        """Download a media file and save it under WECOM_TEMP_DIR.

        :param media_id: WeChat media id.
        :param file_name: target file name; directory components are stripped.
        :return: the saved path, or None if downloading/saving failed.
        """
        # basename() keeps a crafted name from escaping WECOM_TEMP_DIR.
        file_path = os.path.join(WECOM_TEMP_DIR, os.path.basename(file_name))
        try:
            media_data = await self.download_media(media_id)
            if media_data:
                os.makedirs(os.path.dirname(file_path), exist_ok=True)
                with open(file_path, "wb") as f:
                    f.write(media_data)
                return file_path
        except Exception as e:
            self.logger.error(f"Failed to save media: {str(e)}")
        return None

    async def download_media(self, media_id: str) -> Optional[bytes]:
        """Download a media file's bytes via ``media/get``.

        Uses the client's own API base URL, so Official Account clients hit
        api.weixin.qq.com rather than the WeCom-only qyapi host.

        :return: the response body on HTTP 200, otherwise None.
        """
        url = f"{self.api_base_url}media/get"
        # Query parameters are URL-encoded by aiohttp.
        params = {"access_token": self.access_token or "", "media_id": media_id}
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params) as response:
                    if response.status == 200:
                        return await response.read()
                    self.logger.error(
                        f"Failed to download media: {response.status}")
        except Exception as e:
            self.logger.error(f"Failed to download media: {str(e)}")
        return None
class WecomAdapter(IMAdapter):
    """WeCom (Enterprise WeChat) / WeChat Official Account adapter.

    Incoming messages arrive on an HTTP webhook, either mounted on the shared
    web server (``webhook_url``) or on a legacy standalone server (``host``/``port``).
    """
    dispatcher: WorkflowDispatcher
    web_server: WebServer

    # Seconds a resolved reply is kept so WeChat's retries of the same msgid are
    # answered from cache instead of being dispatched again.
    REPLY_CACHE_SECONDS = 60

    def __init__(self, config: WecomConfig):
        self.wecom_utils = None
        self.api_delegate: Optional[WechatApiDelegate] = None
        self.config = config
        if self.config.host:
            self.app = FastAPI()
        else:
            self.app = self.web_server.app
        self.logger = get_logger("Wecom-Adapter")
        self.is_running = False
        if not self.config.host:
            self.config.host = None
            self.config.port = None
        elif not self.config.port:
            self.config.port = 15650
        if not self.config.webhook_url:
            self.config.webhook_url = make_webhook_url()
        # msgid -> future resolved with the passive reply for that message.
        self.reply_tasks: dict[str, asyncio.Future] = {}
        # Strong references to in-flight dispatch tasks so they are not garbage-collected.
        self._dispatch_tasks: set[asyncio.Task] = set()
        # Pick the API delegate that matches the configuration.
        self.setup_wechat_api()

    def setup_wechat_api(self):
        """Choose the WeCom or Official Account API delegate based on ``corp_id``."""
        if self.config.corp_id:
            self.api_delegate = CorpWechatApiDelegate()
        else:
            self.api_delegate = PublicWechatApiDelegate()
        self.api_delegate.setup_api(self.config)
        self.wecom_utils = WeComUtils(self.api_delegate.client)

    @staticmethod
    def _resolve_reply(reply_future: Optional[asyncio.Future], value: Any) -> None:
        """Set the passive-reply future's result unless it is missing or already done."""
        if reply_future is not None and not reply_future.done():
            reply_future.set_result(value)

    def _register_reply(self, msg_id: Any) -> asyncio.Future:
        """Create the passive-reply future for a message and index it by msgid.

        Messages without an id (e.g. events, whose MsgId is absent) are not
        indexed, so they are never mistaken for retries of each other.
        """
        loop = asyncio.get_running_loop()
        reply_future: asyncio.Future = loop.create_future()
        if msg_id:
            self.reply_tasks[msg_id] = reply_future
            # Once answered, keep the entry briefly for retries, then drop it.
            reply_future.add_done_callback(
                lambda _f: loop.call_later(
                    self.REPLY_CACHE_SECONDS, self._forget_reply, msg_id, reply_future
                )
            )
        return reply_future

    def _forget_reply(self, msg_id: Any, reply_future: asyncio.Future) -> None:
        """Remove a cached reply future if it is still the one registered for msg_id."""
        if self.reply_tasks.get(msg_id) is reply_future:
            del self.reply_tasks[msg_id]

    def _spawn_dispatch(self, message: IMMessage, reply_future: asyncio.Future) -> None:
        """Dispatch in the background; unblock the webhook if dispatch fails."""
        task = asyncio.create_task(self.dispatcher.dispatch(self, message))
        self._dispatch_tasks.add(task)

        def _on_done(t: asyncio.Task) -> None:
            self._dispatch_tasks.discard(t)
            if t.cancelled():
                self._resolve_reply(reply_future, None)
            elif t.exception() is not None:
                self.logger.error(f"Failed to dispatch message: {t.exception()}")
                self._resolve_reply(reply_future, None)

        task.add_done_callback(_on_done)

    def setup_routes(self):
        """(Re)register the GET verification and POST message handlers for the webhook."""
        if self.config.host:
            webhook_url = '/wechat'
        else:
            webhook_url = self.config.webhook_url

        # Unregister old routes for this path. Iterate over a copy: removing from the
        # live list while iterating would skip the POST route that follows the GET one.
        for route in list(self.app.routes):
            if isinstance(route, Route) and route.path == webhook_url:
                self.app.routes.remove(route)

        @self.app.get(webhook_url)
        async def handle_check_request(request: Request):
            """Answer WeChat's URL-verification GET request."""
            if not self.is_running:
                self.logger.warning("Wecom-Adapter is not running, skipping check request.")
                raise HTTPException(status_code=404)

            assert self.api_delegate is not None

            signature = request.query_params.get("msg_signature", "")
            if not signature:
                signature = request.query_params.get("signature", "")
            timestamp = request.query_params.get("timestamp", "")
            nonce = request.query_params.get("nonce", "")
            echo_str = request.query_params.get("echostr", "")

            try:
                echo_str = self.api_delegate.check_signature(
                    signature, timestamp, nonce, echo_str
                )
                return Response(content=echo_str, media_type="text/plain")
            except InvalidSignatureException:
                self.logger.error("failed to check signature, please check your settings.")
                raise HTTPException(status_code=403)

        @self.app.post(webhook_url)
        async def handle_message(request: Request):
            """Handle an incoming message POST and answer with the passive reply."""
            if not self.is_running:
                self.logger.warning("Wecom-Adapter is not running, skipping message request.")
                raise HTTPException(status_code=404)

            assert self.api_delegate is not None
            assert self.wecom_utils is not None

            signature = request.query_params.get("msg_signature", "")
            if not signature:
                signature = request.query_params.get("signature", "")
            timestamp = request.query_params.get("timestamp", "")
            nonce = request.query_params.get("nonce", "")
            try:
                msg_str = self.api_delegate.decrypt_message(
                    await request.body(), signature, timestamp, nonce
                )
            except InvalidSignatureException:
                self.logger.error("failed to check signature, please check your settings.")
                raise HTTPException(status_code=403)
            msg: BaseMessage = self.api_delegate.parse_message(msg_str)

            pending_reply = self.reply_tasks.get(msg.id) if msg.id else None
            if pending_reply is not None:
                # A retry of a message already being processed: answer it with the
                # original request's reply instead of dispatching again.
                self.logger.debug(f"skip processing due to duplicate msgid: {msg.id}")
                reply = await pending_reply
                return Response(content=create_reply(reply, msg, render=True), media_type="text/xml")

            # Register before any awaits so concurrent retries are deduplicated.
            reply_future = self._register_reply(msg.id)
            try:
                # Download media payloads up front.
                media_path = None
                if msg.type in ["voice", "video", "file"]:
                    media_id = msg.media_id
                    file_name = f"temp_{msg.type}_{media_id}.{msg.type}"
                    media_path = await self.wecom_utils.download_and_save_media(media_id, file_name)
                message = await self.convert_to_message(msg, media_path)
            except Exception:
                # Unblock retries waiting on this msgid before propagating the error.
                self._resolve_reply(reply_future, None)
                raise

            message.sender.raw_metadata["reply"] = reply_future  # type: ignore
            self._spawn_dispatch(message, reply_future)
            reply = await reply_future
            message.sender.raw_metadata.pop("reply", None)
            return Response(content=create_reply(reply, msg, render=True), media_type="text/xml")

    async def convert_to_message(self, raw_message: Any, media_path: Optional[str] = None) -> IMMessage:
        """Convert a wechatpy message into the unified IMMessage format."""
        # WeCom apps appear to have no group-chat concept, so every message is one-to-one.
        sender = ChatSender.from_c2c_chat(
            raw_message.source, raw_message.source)
        message_elements: List[MessageElement] = []
        raw_message_dict = raw_message.__dict__

        if raw_message.type == "text":
            message_elements.append(TextMessage(text=raw_message.content))
        elif raw_message.type == "image":
            message_elements.append(ImageMessage(url=raw_message.image))
        elif raw_message.type == "voice" and media_path:
            message_elements.append(VoiceMessage(url=media_path))
        elif raw_message.type == "video" and media_path:
            message_elements.append(VideoElement(path=media_path))
        elif raw_message.type == "file" and media_path:
            message_elements.append(FileElement(path=media_path))
        elif raw_message.type == "location":
            location_text = f"[Location] {raw_message.label} (X: {raw_message.location_x}, Y: {raw_message.location_y})"
            message_elements.append(TextMessage(text=location_text))
        elif raw_message.type == "link":
            link_text = f"[Link] {raw_message.title}: {raw_message.description} ({raw_message.url})"
            message_elements.append(TextMessage(text=link_text))
        else:
            message_elements.append(TextMessage(
                text=f"Unsupported message type: {raw_message.type}"))

        return IMMessage(
            sender=sender,
            message_elements=message_elements,
            raw_message=raw_message_dict,
        )

    async def _send_text(self, user_id: str, text: str):
        """Send a text message through the active API delegate."""
        assert self.api_delegate is not None
        try:
            return await self.api_delegate.send_text(self.config.app_id, user_id, text)
        except Exception as e:
            self.logger.error(f"Failed to send text message: {e}")
            raise e

    async def _send_media(self, user_id: str, media_data: Any, media_type: str):
        """Upload and send a media message.

        :param media_data: raw bytes, or (legacy callers) a base64-encoded string.
        :param media_type: WeChat media type, e.g. ``image``/``voice``/``video``/``file``.
        """
        assert self.api_delegate is not None
        try:
            if isinstance(media_data, (bytes, bytearray)):
                raw = bytes(media_data)
            else:
                raw = base64.b64decode(media_data)
            media_bytes = BytesIO(raw)
            return await self.api_delegate.send_media(self.config.app_id, user_id, media_type, media_bytes)
        except Exception as e:
            self.logger.error(f"Failed to send {media_type} message: {e}")
            raise e

    async def send_message(self, message: IMMessage, recipient: ChatSender):
        """Send a message to WeChat.

        Uses the active-send API. If the account lacks that permission (error
        48001) and a passive-reply future is available, the message content is
        returned as the passive reply instead. Other failures are logged and
        release any pending passive reply with an empty answer.
        """
        user_id = recipient.user_id
        reply_future = (recipient.raw_metadata or {}).get("reply")
        try:
            for element in message.message_elements:
                res = None
                if isinstance(element, TextMessage) and element.text:
                    res = await self._send_text(user_id, element.text)
                elif isinstance(element, ImageMessage) and element.url:
                    # Send the actual media bytes, not the URL/path string.
                    res = await self._send_media(user_id, await element.get_data(), "image")
                elif isinstance(element, VoiceMessage) and element.url:
                    res = await self._send_media(user_id, await element.get_data(), "voice")
                elif isinstance(element, VideoElement) and element.path:
                    res = await self._send_media(user_id, await element.get_data(), "video")
                elif isinstance(element, FileElement) and element.path:
                    res = await self._send_media(user_id, await element.get_data(), "file")
                if res:
                    self.logger.debug(f"WeChat API response: {res}")
            self._resolve_reply(reply_future, None)
        except Exception as e:
            if 'Error code: 48001' in str(e):
                # The account cannot actively push messages.
                if reply_future is not None:
                    self.logger.warning("未开通主动回复能力,将采用被动回复消息 API,此模式下只能回复一条消息。")
                    self._resolve_reply(reply_future, message.content)
                else:
                    self.logger.warning("未开通主动回复能力,且不在上下文中,无法发送消息。")
            else:
                self.logger.error(f"Failed to send message: {e}")
                self._resolve_reply(reply_future, None)

    async def _start_standalone_server(self):
        """Start the legacy standalone hypercorn server on host:port."""
        from hypercorn.asyncio import serve
        from hypercorn.config import Config
        from hypercorn.logging import Logger

        config = Config()
        config.bind = [f"{self.config.host}:{self.config.port}"]
        # Route hypercorn's access/error logs through this adapter's logger.
        config._log = Logger(config)
        config._log.access_logger = HypercornLoggerWrapper(self.logger)  # type: ignore
        config._log.error_logger = HypercornLoggerWrapper(self.logger)  # type: ignore

        self.server_task = asyncio.create_task(serve(self.app, config))  # type: ignore

    async def _stop_standalone_server(self):
        """Cancel the standalone server task, if one was started."""
        if hasattr(self, "server_task"):
            self.server_task.cancel()
            try:
                await self.server_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                self.logger.error(f"Error during server shutdown: {e}")

    async def start(self):
        """Set up the API delegate and routes, then mark the adapter as running."""
        self.setup_wechat_api()
        if self.config.host:
            self.logger.warning("正在使用过时的启动模式,请尽快更新为 Webhook 模式。")
            await self._start_standalone_server()
        self.setup_routes()
        self.is_running = True
        self.logger.info("Wecom-Adapter 启动成功")

    async def stop(self):
        """Stop the standalone server (legacy mode) and mark the adapter as stopped."""
        if self.config.host:
            await self._stop_standalone_server()
        self.is_running = False
        self.logger.info("Wecom-Adapter 停止成功")
================================================
FILE: kirara_ai/plugins/im_wecom_adapter/delegates.py
================================================
from abc import ABC, abstractmethod
from io import BytesIO
from typing import TYPE_CHECKING, Any
from wechatpy.messages import BaseMessage
from kirara_ai.logger import get_logger
if TYPE_CHECKING:
from .adapter import WecomConfig
class WechatApiDelegate(ABC):
    """Strategy interface that hides the differences between the WeCom and
    WeChat Official Account flavours of the WeChat API."""

    @abstractmethod
    def setup_api(self, config: "WecomConfig"):
        """Build the crypto helper, API client and message parser from ``config``."""
        ...

    @abstractmethod
    def check_signature(self, signature: str, timestamp: str, nonce: str, echo_str: str) -> str:
        """Validate a URL-verification request and return the string to echo back."""
        ...

    @abstractmethod
    def decrypt_message(self, message: bytes, signature: str, timestamp: str, nonce: str) -> str:
        """Verify and decrypt an incoming message body."""
        ...

    @abstractmethod
    def parse_message(self, message: str) -> BaseMessage:
        """Parse decrypted XML into a wechatpy message object."""
        ...

    @abstractmethod
    async def send_text(self, app_id: str, user_id: str, text: str) -> Any:
        """Actively send a text message to ``user_id``."""
        ...

    @abstractmethod
    async def send_media(self, app_id: str, user_id: str, media_type: str, media_bytes: BytesIO) -> Any:
        """Upload ``media_bytes`` and actively send it to ``user_id``."""
        ...
class CorpWechatApiDelegate(WechatApiDelegate):
    """WechatApiDelegate implementation for WeCom (Enterprise WeChat) apps."""

    def setup_api(self, config: "WecomConfig"):
        """Build the WeCom crypto helper, API client and message parser from ``config``."""
        from wechatpy.enterprise import parse_message
        from wechatpy.enterprise.client import WeChatClient
        from wechatpy.enterprise.crypto import WeChatCrypto

        self.crypto = WeChatCrypto(
            config.token, config.encoding_aes_key, config.corp_id
        )
        self.client = WeChatClient(config.corp_id, config.secret)
        self.parse_message_func = parse_message
        self.logger = get_logger("CorpWechatApiDelegate")

    def check_signature(self, signature: str, timestamp: str, nonce: str, echo_str: str) -> str:
        """Validate a WeCom URL-verification request; returns the string to echo back."""
        return self.crypto.check_signature(signature, timestamp, nonce, echo_str)

    def decrypt_message(self, message: bytes, signature: str, timestamp: str, nonce: str) -> str:
        """Verify and decrypt a WeCom message body."""
        return self.crypto.decrypt_message(message, signature, timestamp, nonce)

    def parse_message(self, message: str) -> BaseMessage:
        """Parse decrypted WeCom XML into a wechatpy message."""
        return self.parse_message_func(message)  # type: ignore

    async def send_text(self, app_id: str, user_id: str, text: str) -> Any:
        """Send a WeCom text message.

        The wechatpy client performs blocking HTTP requests, so the call runs in a
        worker thread instead of stalling the event loop.
        """
        import asyncio

        return await asyncio.to_thread(self.client.message.send_text, app_id, user_id, text)

    async def send_media(self, app_id: str, user_id: str, media_type: str, media_bytes: BytesIO) -> Any:
        """Upload media and send it as a WeCom message (in a worker thread; see send_text)."""
        import asyncio

        def _upload_and_send() -> Any:
            media_id = self.client.media.upload(media_type, media_bytes)["media_id"]
            send_method = getattr(self.client.message, f"send_{media_type}")
            return send_method(app_id, user_id, media_id)

        return await asyncio.to_thread(_upload_and_send)
class PublicWechatApiDelegate(WechatApiDelegate):
    """WechatApiDelegate implementation for WeChat Official Accounts."""

    def setup_api(self, config: "WecomConfig"):
        """Build the Official Account crypto helper, API client and message parser."""
        from wechatpy import WeChatClient
        from wechatpy.crypto import WeChatCrypto
        from wechatpy.parser import parse_message

        self.crypto = WeChatCrypto(
            config.token, config.encoding_aes_key, config.app_id
        )
        self.client = WeChatClient(config.app_id, config.secret)
        self.parse_message_func = parse_message
        self.logger = get_logger("PublicWechatApiDelegate")

    def check_signature(self, signature: str, timestamp: str, nonce: str, echo_str: str) -> str:
        """Validate an Official Account URL-verification request; echoes ``echo_str`` back."""
        from wechatpy.utils import check_signature as wechat_check_signature
        wechat_check_signature(self.crypto.token, signature, timestamp, nonce)
        return echo_str

    def decrypt_message(self, message: bytes, signature: str, timestamp: str, nonce: str) -> str:
        """Verify and decrypt an Official Account message body."""
        return self.crypto.decrypt_message(message, signature, timestamp, nonce)

    def parse_message(self, message: str) -> BaseMessage:
        """Parse decrypted Official Account XML into a wechatpy message."""
        return self.parse_message_func(message)  # type: ignore

    async def send_text(self, app_id: str, user_id: str, text: str) -> Any:
        """Send an Official Account text message (``app_id`` is not used by this API).

        The wechatpy client performs blocking HTTP requests, so the call runs in a
        worker thread instead of stalling the event loop.
        """
        import asyncio

        return await asyncio.to_thread(self.client.message.send_text, user_id, text)

    async def send_media(self, app_id: str, user_id: str, media_type: str, media_bytes: BytesIO) -> Any:
        """Upload media and send it to ``user_id`` (``app_id`` unused; runs in a worker thread)."""
        import asyncio

        def _upload_and_send() -> Any:
            media_id = self.client.media.upload(media_type, media_bytes)["media_id"]
            send_method = getattr(self.client.message, f"send_{media_type}")
            return send_method(user_id, media_id)

        return await asyncio.to_thread(_upload_and_send)
================================================
FILE: kirara_ai/plugins/im_wecom_adapter/setup.py
================================================
from setuptools import find_packages, setup
setup(
    name="kirara_ai-wecom-adapter",
    version="1.0.0",
    description="WeCom adapter plugin for kirara_ai",
    author="Internal",
    packages=find_packages(),
    install_requires=["wechatpy", "pycryptodome"],
    entry_points={
        # WecomAdapterPlugin lives in the package's __init__.py (there is no
        # ``plugin`` submodule), and the class name is spelled "Wecom", not "WeCom".
        "chatgpt_mirai.plugins": ["wecom = im_wecom_adapter:WecomAdapterPlugin"]
    },
)
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/__init__.py
================================================
from .alibabacloud_adapter import AlibabaCloudAdapter, AlibabaCloudConfig
from .claude_adapter import ClaudeAdapter, ClaudeConfig
from .deepseek_adapter import DeepSeekAdapter, DeepSeekConfig
from .gemini_adapter import GeminiAdapter, GeminiConfig
from .minimax_adapter import MinimaxAdapter, MinimaxConfig
from .moonshot_adapter import MoonshotAdapter, MoonshotConfig
from .ollama_adapter import OllamaAdapter, OllamaConfig
from .openai_adapter import OpenAIAdapter, OpenAIConfig
from .openrouter_adapter import OpenRouterAdapter, OpenRouterConfig
from .siliconflow_adapter import SiliconFlowAdapter, SiliconFlowConfig
from .tencentcloud_adapter import TencentCloudAdapter, TencentCloudConfig
from .volcengine_adapter import VolcengineAdapter, VolcengineConfig
from .mistral_adapter import MistralAdapter, MistralConfig
from .voyage_adapter import VoyageAdapter, VoyageConfig
from kirara_ai.logger import get_logger
from kirara_ai.plugin_manager.plugin import Plugin
logger = get_logger("LLMPresetAdapters")


class LLMPresetAdaptersPlugin(Plugin):
    """Registers the built-in LLM backend adapters with the LLM registry."""

    def __init__(self):
        pass

    def on_load(self):
        """Register every preset (name, adapter class, config class) with the registry."""
        presets = (
            ("OpenAI", OpenAIAdapter, OpenAIConfig),
            ("DeepSeek", DeepSeekAdapter, DeepSeekConfig),
            ("Gemini", GeminiAdapter, GeminiConfig),
            ("Ollama", OllamaAdapter, OllamaConfig),
            ("Claude", ClaudeAdapter, ClaudeConfig),
            ("SiliconFlow", SiliconFlowAdapter, SiliconFlowConfig),
            ("TencentCloud", TencentCloudAdapter, TencentCloudConfig),
            ("AlibabaCloud", AlibabaCloudAdapter, AlibabaCloudConfig),
            ("Moonshot", MoonshotAdapter, MoonshotConfig),
            ("OpenRouter", OpenRouterAdapter, OpenRouterConfig),
            ("Minimax", MinimaxAdapter, MinimaxConfig),
            ("Volcengine", VolcengineAdapter, VolcengineConfig),
            ("Mistral", MistralAdapter, MistralConfig),
            # Voyage was imported by this module but never registered, so it
            # could not be selected as a backend.
            ("Voyage", VoyageAdapter, VoyageConfig),
        )
        for name, adapter_cls, config_cls in presets:
            self.llm_registry.register(name, adapter_cls, config_cls)
        logger.info("LLMPresetAdaptersPlugin loaded")

    def on_start(self):
        logger.info("LLMPresetAdaptersPlugin started")

    def on_stop(self):
        logger.info("LLMPresetAdaptersPlugin stopped")
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/alibabacloud_adapter.py
================================================
from kirara_ai.config.global_config import ModelConfig
from .openai_adapter import OpenAIAdapter, OpenAIConfig
from .utils import guess_qwen_model
class AlibabaCloudConfig(OpenAIConfig):
    """OpenAI-compatible configuration preset for Alibaba Cloud DashScope."""

    # DashScope's OpenAI-compatible endpoint.
    api_base: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
class AlibabaCloudAdapter(OpenAIAdapter):
    """OpenAI-compatible adapter for Alibaba Cloud DashScope."""

    def __init__(self, config: AlibabaCloudConfig):
        super().__init__(config)

    async def auto_detect_models(self) -> list[ModelConfig]:
        """Return a ModelConfig for each listed model that guess_qwen_model recognizes."""
        detected: list[ModelConfig] = []
        for model_id in await self.get_models():
            guess = guess_qwen_model(model_id)
            if guess is not None:
                detected.append(ModelConfig(id=model_id, type=guess[0].value, ability=guess[1]))
        return detected
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/claude_adapter.py
================================================
import asyncio
import base64
from typing import Any, Dict, List
import aiohttp
import requests
from pydantic import BaseModel, ConfigDict
import kirara_ai.llm.format.tool as tools
from kirara_ai.llm.adapter import AutoDetectModelsProtocol, LLMBackendAdapter, LLMChatProtocol
from kirara_ai.llm.format.message import (LLMChatContentPartType, LLMChatImageContent, LLMChatMessage,
LLMChatTextContent, LLMToolCallContent, LLMToolResultContent)
from kirara_ai.llm.format.request import LLMChatRequest, Tool
from kirara_ai.llm.format.response import LLMChatResponse, Message, Usage
from kirara_ai.logger import get_logger
from kirara_ai.media.manager import MediaManager
from kirara_ai.tracing.decorator import trace_llm_chat
from .utils import generate_tool_call_id, pick_tool_calls
class ClaudeConfig(BaseModel):
    # Anthropic API key; ClaudeAdapter sends it in the "x-api-key" header.
    api_key: str
    # Base URL of the Anthropic API; paths such as "/messages" and "/models" are appended to it.
    api_base: str = "https://api.anthropic.com/v1"
    # frozen=True makes config instances immutable.
    model_config = ConfigDict(frozen=True)
async def convert_llm_chat_message_to_claude_message(messages: list[LLMChatMessage], media_manager: MediaManager) -> list[dict]:
    """Convert chat history into the Anthropic Messages API ``messages`` list.

    System messages are skipped; Claude receives them through the top-level
    ``system`` field instead. Tool-result messages are sent with the ``user``
    role, because Claude expects ``tool_result`` blocks inside user turns.

    Args:
        messages: Conversation history to convert.
        media_manager: Resolves image media ids into base64 data.

    Returns:
        A list of ``{"role": ..., "content": [...]}`` dicts.

    Raises:
        ValueError: If a referenced media id cannot be found.
    """
    content: List[Dict[str, Any]] = []
    for msg in messages:
        if msg.role not in ("user", "assistant", "tool"):
            continue
        parts: List[Dict[str, Any]] = []
        for part in msg.content:
            if isinstance(part, LLMChatTextContent):
                parts.append({"type": "text", "text": part.text})
            elif isinstance(part, LLMToolResultContent):
                parts.append(await resolve_tool_result(part, media_manager))
            elif isinstance(part, LLMToolCallContent):
                # A later tool_result references this call by id. Claude rejects the
                # request if the preceding assistant turn lacks the matching tool_use block.
                parts.append({
                    "type": "tool_use",
                    "id": part.id,
                    "name": part.name,
                    # Claude requires "input" to be an object.
                    "input": part.parameters or {},
                })
            elif isinstance(part, LLMChatImageContent):
                media = media_manager.get_media(part.media_id)
                if media is None:
                    raise ValueError(f"Media {part.media_id} not found")
                parts.append({
                    "type": "image",
                    # The image source must declare its encoding, as resolve_tool_result does.
                    "source": {
                        "type": "base64",
                        "media_type": str(media.mime_type),
                        "data": await media.get_base64(),
                    },
                })
        content.append({
            "role": "user" if msg.role == "tool" else msg.role,
            "content": parts,
        })
    return content
def convert_tools_to_claude_format(tools: list[Tool]) -> list[dict]:
    """Convert tool definitions into Anthropic Messages API ``tools`` entries.

    Claude expects ``{"name", "description", "input_schema"}``. Its
    ``input_schema`` is the JSON schema that OpenAI calls ``parameters``, so a
    plain ``model_dump()`` (which emits ``parameters``) is rejected. The
    OpenAI-only ``strict`` flag is not sent, and the schema's top-level
    ``additionalProperties`` key is dropped.

    Args:
        tools: Tool definitions; ``parameters`` may be a dict or a pydantic model.

    Returns:
        One dict per tool in Claude's format.
    """
    claude_tools: list[dict] = []
    for tool in tools:
        schema = tool.parameters if isinstance(tool.parameters, dict) else tool.parameters.model_dump()
        claude_tool: dict = {
            "name": tool.name,
            "input_schema": {key: value for key, value in schema.items() if key != "additionalProperties"},
        }
        # description is optional for Claude; omit it rather than sending null.
        if tool.description is not None:
            claude_tool["description"] = tool.description
        claude_tools.append(claude_tool)
    return claude_tools
async def resolve_tool_result(element: LLMToolResultContent, media_manager: MediaManager) -> dict:
    """Build a Claude ``tool_result`` content block from a tool result.

    Text items become ``text`` blocks. Media items are embedded as base64
    sources; the block type is the lower-cased media type value.

    Raises:
        ValueError: If a referenced media id cannot be found.
    """
    blocks: List[Dict[str, Any]] = []
    for item in element.content:
        if isinstance(item, tools.TextContent):
            blocks.append({"type": "text", "text": item.text})
            continue
        if not isinstance(item, tools.MediaContent):
            continue
        media = media_manager.get_media(item.media_id)
        if media is None:
            raise ValueError(
                f"Media {item.media_id} not found")
        block_type = media.media_type.value.lower()
        source = {
            "type": "base64",
            "media_type": str(media.mime_type),
            "data": await media.get_base64(),
        }
        blocks.append({"type": block_type, "source": source})
    return {
        "type": "tool_result",
        "tool_use_id": element.id,
        "content": blocks,
        "is_error": element.isError,
    }
class ClaudeAdapter(LLMBackendAdapter, AutoDetectModelsProtocol, LLMChatProtocol):
    """LLM backend adapter for the Anthropic Claude Messages API."""

    media_manager: MediaManager

    def __init__(self, config: ClaudeConfig):
        self.config = config
        self.logger = get_logger("ClaudeAdapter")

    def _auth_headers(self) -> Dict[str, str]:
        """Return the headers every Anthropic API call needs: API key and API version."""
        return {
            "x-api-key": self.config.api_key,
            "anthropic-version": "2023-06-01",
        }

    @staticmethod
    def _build_system_prompt(messages: List[LLMChatMessage]) -> "List[Dict[str, Any]] | str | None":
        """Return the first system message in a JSON-serialisable form, or None.

        Claude takes the system prompt through the top-level ``system`` field
        rather than as a message. A message's content is a list of pydantic part
        objects, which ``requests`` cannot JSON-encode, so its text parts are
        converted into plain ``{"type": "text", "text": ...}`` blocks.
        """
        for msg in messages:
            if msg.role != "system":
                continue
            if isinstance(msg.content, str):
                return msg.content
            blocks = [{"type": "text", "text": part.text} for part in msg.content if isinstance(part, LLMChatTextContent)]
            return blocks or None
        return None

    @trace_llm_chat
    def chat(self, req: LLMChatRequest) -> LLMChatResponse:
        """Send ``req`` to ``{api_base}/messages`` and convert the reply.

        Raises:
            requests.HTTPError: If the API returns an error status (the body is logged first).
            ValueError: If a referenced media id cannot be found.
        """
        api_url = f"{self.config.api_base}/messages"
        headers = {**self._auth_headers(), "content-type": "application/json"}

        data = {
            "model": req.model,
            "messages": asyncio.run(convert_llm_chat_message_to_claude_message(req.messages, self.media_manager)),
            "max_tokens": req.max_tokens,
            # Claude takes the system prompt as a top-level field, not as a message.
            "system": self._build_system_prompt(req.messages),
            "temperature": req.temperature,
            "top_p": req.top_p,
            "stream": req.stream,
            # Claude names the tool schema "input_schema" (OpenAI calls it "parameters"), so convert it.
            "tools": convert_tools_to_claude_format(req.tools) if req.tools else None,
            # When tools are given, specify tool_choice; Claude's default is {"type": "auto"}.
            # Could be exposed to users later; providers do not agree on this setting.
            "tool_choice": {"type": "auto"} if req.tools else None,
        }
        # Remove None fields
        data = {k: v for k, v in data.items() if v is not None}

        response = requests.post(api_url, json=data, headers=headers)
        try:
            response.raise_for_status()
            response_data = response.json()
        except Exception as e:
            self.logger.error(f"API Response: {response.text}")
            raise e

        content: List[LLMChatContentPartType] = []
        for res in response_data["content"]:
            if res["type"] == "text":
                content.append(LLMChatTextContent(text=res["text"]))
            elif res["type"] == "image":
                image_data = base64.b64decode(res["source"]["data"])
                media = asyncio.run(self.media_manager.register_from_data(
                    image_data, res["source"]["media_type"], source="claude response"))
                content.append(LLMChatImageContent(media_id=media))
            elif res["type"] == "tool_use":
                # With tool calls, the only extra content Claude returns is one text block of reasoning.
                content.append(LLMToolCallContent(
                    id=res.get("id", generate_tool_call_id(res["name"])),
                    name=res["name"],
                    parameters=res.get("input", None),
                ))

        usage_data = response_data.get("usage", {})
        input_tokens = usage_data.get("input_tokens", 0)
        output_tokens = usage_data.get("output_tokens", 0)

        return LLMChatResponse(
            model=req.model,
            usage=Usage(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
            ),
            message=Message(
                content=content,
                role=response_data.get("role", "assistant"),
                finish_reason=response_data.get("stop_reason", "stop"),
                # Claude mixes tool calls into the content list; extract them.
                tool_calls=pick_tool_calls(content),
            )
        )

    async def auto_detect_models(self) -> "list[ModelConfig]":
        """List the models available to this API key via ``{api_base}/models``.

        Follows the Models API's ``has_more``/``last_id`` pagination, so the list
        is complete even when it spans several pages. Results are returned as
        ``ModelConfig`` entries, as the other AutoDetectModelsProtocol adapters do.

        Raises:
            aiohttp.ClientResponseError: If the API returns an error status.
        """
        # Imported locally: this module's import block does not provide these names.
        from kirara_ai.config.global_config import ModelConfig
        from kirara_ai.llm.model_types import LLMAbility, ModelType

        # Response page shape:
        # {"data": [{"type": "model", "id": "claude-3-5-sonnet-20241022",
        #            "display_name": "Claude 3.5 Sonnet (New)", "created_at": "2024-10-22T00:00:00Z"}],
        #  "has_more": true, "first_id": "", "last_id": ""}
        # All Claude 3 models support tool calls and multi-modal tool_result.
        api_url = f"{self.config.api_base}/models"
        model_ids: List[str] = []
        params: Dict[str, str] = {"limit": "1000"}
        async with aiohttp.ClientSession(trust_env=True) as session:
            while True:
                async with session.get(api_url, headers=self._auth_headers(), params=params) as response:
                    response.raise_for_status()
                    page = await response.json()
                model_ids.extend(model["id"] for model in page.get("data", []))
                last_id = page.get("last_id")
                if not page.get("has_more") or not last_id:
                    break
                params["after_id"] = last_id
        return [
            ModelConfig(id=model_id, type=ModelType.LLM.value, ability=LLMAbility.TextChat.value)
            for model_id in model_ids
        ]
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/deepseek_adapter.py
================================================
from .openai_adapter import OpenAIAdapterChatBase, OpenAIConfig
# DeepSeek settings: OpenAIConfig whose default base URL is DeepSeek's OpenAI-compatible endpoint.
class DeepSeekConfig(OpenAIConfig):
    api_base: str = "https://api.deepseek.com/v1"


class DeepSeekAdapter(OpenAIAdapterChatBase):
    """Chat adapter for DeepSeek, reusing the generic OpenAI-compatible chat implementation."""

    def __init__(self, config: DeepSeekConfig) -> None:
        super(DeepSeekAdapter, self).__init__(config)
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/gemini_adapter.py
================================================
import asyncio
import base64
from typing import Any, Dict, List, Literal, cast
import aiohttp
import requests
from pydantic import BaseModel, ConfigDict
import kirara_ai.llm.format.tool as tool
from kirara_ai.config.global_config import ModelConfig
from kirara_ai.llm.adapter import AutoDetectModelsProtocol, LLMBackendAdapter, LLMChatProtocol, LLMEmbeddingProtocol
from kirara_ai.llm.format.message import (LLMChatContentPartType, LLMChatImageContent, LLMChatMessage,
LLMChatTextContent, LLMToolCallContent, LLMToolResultContent, RoleType)
from kirara_ai.llm.format.request import LLMChatRequest, Tool
from kirara_ai.llm.format.response import LLMChatResponse, Message, Usage
from kirara_ai.llm.format.embedding import LLMEmbeddingRequest, LLMEmbeddingResponse
from kirara_ai.llm.model_types import LLMAbility, ModelType
from kirara_ai.logger import get_logger
from kirara_ai.media import MediaManager
from kirara_ai.tracing import trace_llm_chat
from .utils import generate_tool_call_id, pick_tool_calls
# Safety settings sent with every Gemini generateContent request: each listed
# harm category uses the threshold "BLOCK_NONE".
SAFETY_SETTINGS = [{
    "category": "HARM_CATEGORY_HARASSMENT",
    "threshold": "BLOCK_NONE"
}, {
    "category": "HARM_CATEGORY_HATE_SPEECH",
    "threshold": "BLOCK_NONE"
}, {
    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "threshold": "BLOCK_NONE"
}, {
    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
    "threshold": "BLOCK_NONE"
}, {
    "category": "HARM_CATEGORY_CIVIC_INTEGRITY",
    "threshold": "BLOCK_NONE"
}]
# POST requests support at most 20 MB of inline data.
# NOTE(review): not referenced by the Gemini code visible here; confirm whether it is still used.
INLINE_LIMIT_SIZE = 1024 * 1024 * 20
# Models for which GeminiAdapter.chat also requests "image" in responseModalities.
IMAGE_MODAL_MODELS = [
    "gemini-2.0-flash-exp"
]
class GeminiConfig(BaseModel):
    # Google AI Studio / Gemini API key.
    api_key: str
    # Base URL of the Gemini REST API; "/models/..." paths are appended to it.
    api_base: str = "https://generativelanguage.googleapis.com/v1beta"
    # frozen=True makes config instances immutable.
    model_config = ConfigDict(frozen=True)
async def convert_non_tool_message(msg: LLMChatMessage, media_manager: MediaManager) -> dict:
    """Convert a user/assistant/system message into a Gemini ``Content`` dict.

    Assistant messages use Gemini's ``model`` role; every other role (including
    ``system``) is sent as ``user``. Text, image and tool-call parts are
    translated; other part types are ignored.

    Raises:
        ValueError: If a referenced image media id cannot be found.
    """
    role = "model" if msg.role == "assistant" else "user"
    gemini_parts: List[Dict[str, Any]] = []
    for element in cast(list[LLMChatContentPartType], msg.content):
        if isinstance(element, LLMChatTextContent):
            gemini_parts.append({"text": element.text})
        elif isinstance(element, LLMChatImageContent):
            media = media_manager.get_media(element.media_id)
            if media is None:
                raise ValueError(f"Media {element.media_id} not found")
            inline = {
                "mime_type": str(media.mime_type),
                "data": await media.get_base64(),
            }
            gemini_parts.append({"inline_data": inline})
        elif isinstance(element, LLMToolCallContent):
            call = {"name": element.name, "args": element.parameters}
            gemini_parts.append({"functionCall": call})
    return {"role": role, "parts": gemini_parts}
async def convert_llm_chat_message_to_gemini_message(msg: LLMChatMessage, media_manager: MediaManager) -> dict:
    """Convert one chat message into a Gemini ``Content`` dict.

    Tool messages become a ``user`` turn of ``functionResponse`` parts.
    user/assistant/system messages go through ``convert_non_tool_message``.

    Raises:
        ValueError: For any other role.
    """
    if msg.role == "tool":
        tool_results = cast(list[LLMToolResultContent], msg.content)
        return {"role": "user", "parts": list(map(resolve_tool_results, tool_results))}
    if msg.role not in ("user", "assistant", "system"):
        raise ValueError(f"Invalid role: {msg.role}")
    return await convert_non_tool_message(msg, media_manager)
async def convert_all_messages_to_gemini_format(messages: List[LLMChatMessage], media_manager: MediaManager) -> list[dict]:
    """Convert every message to Gemini format concurrently, preserving input order.

    This is a coroutine so that synchronous callers can drive it with ``asyncio.run()``.
    """
    pending = [convert_llm_chat_message_to_gemini_message(message, media_manager) for message in messages]
    converted = await asyncio.gather(*pending)
    return converted
def convert_tools_to_gemini_format(tools: list[Tool]) -> list[dict[Literal["function_declarations"], list[dict]]]:
    """Convert tools into Gemini's ``tools`` payload: one entry holding all function declarations.

    Each tool is dumped with ``model_dump()`` and then reduced to the fields
    whitelisted in ``allowed_keys``; all other fields are dropped.
    """
    # Whitelist of fields to keep. ``True`` keeps the value unchanged; a nested
    # dict filters that sub-dict recursively; ``"*"`` applies its rules to every
    # entry of the dict at that level (here: each property schema).
    allowed_keys = {
        "name": True,
        "description": True,
        "parameters": {
            "type": True,
            "properties": {
                "*": {
                    "type": True,
                    "title": True,
                    "description": True,
                    "enum": True,
                    "default": True,
                    "items": True,
                }
            },
            "required": True
        }
    }

    def filter_dict(data: dict, allowed: dict) -> dict:
        """Recursively filter ``data``, keeping only the fields that ``allowed`` permits."""
        result = {}
        for key, value in allowed.items():
            if key == "*" and isinstance(value, dict):
                # Wildcard: filter every entry of ``data`` (e.g. each entry of
                # ``properties``) with the same rules; non-dict entries are kept unchanged.
                for data_key, data_value in data.items():
                    if isinstance(data_value, dict):
                        result[data_key] = filter_dict(data_value, value)
                    else:
                        result[data_key] = data_value
            elif key in data:
                if isinstance(value, dict) and isinstance(data[key], dict):
                    # Nested rules: recurse into the sub-dict.
                    result[key] = filter_dict(data[key], value)
                else:
                    # Otherwise keep the value unchanged.
                    result[key] = data[key]
        return result

    # NOTE(review): the per-property rules contain no "properties"/"required", so a
    # property that is itself an object schema loses its nested fields, while
    # "items" is copied unfiltered. Confirm Gemini accepts both.
    function_declarations = []
    for tool in tools:
        # Convert the Tool object to a dict
        tool_dict = tool.model_dump()
        # Keep only the allowed fields
        filtered_tool = filter_dict(tool_dict, allowed_keys)
        function_declarations.append(filtered_tool)
    return [{"function_declarations": function_declarations}]
def resolve_tool_results(element: LLMToolResultContent) -> dict:
    """Build a Gemini ``functionResponse`` part from a tool result.

    All items are concatenated into one string, reported under ``error`` when
    the tool failed and under ``output`` otherwise.
    """
    pieces: list[str] = []
    for item in element.content:
        if isinstance(item, tool.TextContent):
            pieces.append(item.text)
        elif isinstance(item, tool.MediaContent):
            # FIXME: Gemini does not accept media inside a function response; it would have to be
            # supplied through an extra message instead (effectively "rewriting the model's memory").
            # NOTE(review): as written this appends an empty string, so media items leave no trace.
            pieces.append(f"")
    output = "".join(pieces)
    response_key = "error" if element.isError else "output"
    return {
        "functionResponse": {
            "name": element.name,
            "response": {response_key: output},
        }
    }
class GeminiAdapter(LLMBackendAdapter, AutoDetectModelsProtocol, LLMChatProtocol, LLMEmbeddingProtocol):
    """LLM backend adapter for the Google Gemini REST API (chat, embeddings, model listing)."""

    media_manager: MediaManager

    def __init__(self, config: GeminiConfig):
        self.config = config
        self.logger = get_logger("GeminiAdapter")

    def _json_headers(self) -> Dict[str, str]:
        """Headers for JSON POST requests.

        The API key is sent in the ``x-goog-api-key`` header, as
        ``auto_detect_models`` already does, rather than as a ``?key=`` query
        parameter. Keeping it out of the URL matters because ``requests`` puts
        the URL into ``HTTPError`` messages, which are logged.
        """
        return {"x-goog-api-key": self.config.api_key, "Content-Type": "application/json"}

    @trace_llm_chat
    def chat(self, req: LLMChatRequest) -> LLMChatResponse:
        """Send ``req`` to ``models/{model}:generateContent`` and convert the first candidate.

        Raises:
            requests.exceptions.RequestException: If the request ultimately fails.
            ValueError: If Gemini returns no candidates (e.g. the prompt was blocked)
                or a referenced media id cannot be found.
        """
        api_url = f"{self.config.api_base}/models/{req.model}:generateContent"

        response_modalities = ["text"]
        if req.model in IMAGE_MODAL_MODELS:
            response_modalities.append("image")

        generation_config = {
            "temperature": req.temperature,
            "topP": req.top_p,
            "topK": 40,
            "maxOutputTokens": req.max_tokens,
            "stopSequences": req.stop,
            "responseModalities": response_modalities,
        }
        data = {
            "contents": asyncio.run(convert_all_messages_to_gemini_format(req.messages, self.media_manager)),
            # Unset sampling options are omitted instead of being sent as null.
            "generationConfig": {k: v for k, v in generation_config.items() if v is not None},
            "safetySettings": SAFETY_SETTINGS,
            "tools": convert_tools_to_gemini_format(req.tools) if req.tools else None,
        }
        self.logger.debug(f"Gemini request: {data}")

        # Remove None fields
        data = {k: v for k, v in data.items() if v is not None}

        response = self._post_with_retry(api_url, json=data, headers=self._json_headers())
        try:
            response_data = response.json()
        except Exception as e:
            self.logger.error(f"API Response: {response.text}")
            raise e

        candidates = response_data.get("candidates") or []
        if not candidates:
            # Gemini omits candidates when e.g. the prompt itself is blocked; promptFeedback says why.
            self.logger.error(f"API Response: {response_data}")
            raise ValueError(f"Gemini returned no candidates: {response_data.get('promptFeedback')}")
        candidate = candidates[0]

        content: List[LLMChatContentPartType] = []
        # "content"/"parts" may be absent, e.g. when generation stopped on a safety finish reason.
        for part in (candidate.get("content") or {}).get("parts", []):
            if "text" in part:
                content.append(LLMChatTextContent(text=part["text"]))
            elif "inlineData" in part:
                decoded_image_data = base64.b64decode(part["inlineData"]["data"])
                media = asyncio.run(
                    self.media_manager.register_from_data(
                        data=decoded_image_data,
                        format=part["inlineData"]["mimeType"].removeprefix("image/"),
                        source="gemini response")
                )
                content.append(LLMChatImageContent(media_id=media))
            elif "functionCall" in part:
                content.append(
                    LLMToolCallContent(
                        id=generate_tool_call_id(part["functionCall"]["name"]),
                        name=part["functionCall"]["name"],
                        parameters=part["functionCall"].get("args", None)
                    )
                )

        usage_metadata = response_data.get("usageMetadata") or {}
        return LLMChatResponse(
            model=req.model,
            usage=Usage(
                prompt_tokens=usage_metadata.get("promptTokenCount", 0),
                cached_tokens=usage_metadata.get("cachedContentTokenCount", 0),
                # Output tokens are reported as candidatesTokenCount inside usageMetadata.
                completion_tokens=usage_metadata.get("candidatesTokenCount", 0),
                total_tokens=usage_metadata.get("totalTokenCount", 0),
            ),
            message=Message(
                content=content,
                role=cast(RoleType, "assistant"),
                finish_reason=candidate.get("finishReason"),
                tool_calls=pick_tool_calls(content)
            ),
        )

    def embed(self, req: LLMEmbeddingRequest) -> LLMEmbeddingResponse:
        """Embed text inputs with ``models/{model}:batchEmbedContents`` (one request per input).

        Raises:
            ValueError: If any input is an image; no Gemini embedding model is multi-modal.
            requests.exceptions.RequestException: If the request ultimately fails.
        """
        api_url = f"{self.config.api_base}/models/{req.model}:batchEmbedContents"

        if any(isinstance(item, LLMChatImageContent) for item in req.inputs):
            raise ValueError("gemini does not support multi-modal embedding")
        inputs = cast(list[LLMChatTextContent], req.inputs)

        # Each sub-request must name the model by resource name ("models/<id>"),
        # matching the model in the URL.
        model_name = f"models/{req.model}"
        requests_payload = [
            {
                "model": model_name,
                "content": {"parts": [{"text": item.text}]},
                "outputDimensionality": req.dimension,
            } for item in inputs
        ]
        # Remove None fields
        requests_payload = [{k: v for k, v in item.items() if v is not None} for item in requests_payload]

        response = self._post_with_retry(url=api_url, json={"requests": requests_payload}, headers=self._json_headers())
        try:
            # {"embeddings": [{"values": [0.1, ...]}, ...]}
            response_data: dict[Literal["embeddings"], list[dict[Literal["values"], list[float]]]] = response.json()
        except Exception as e:
            self.logger.error(f"API Response: {response.text}")
            raise e
        return LLMEmbeddingResponse(
            # Gemini does not return usage for embeddings.
            vectors=[embedding["values"] for embedding in response_data["embeddings"]]
        )

    async def auto_detect_models(self) -> list[ModelConfig]:
        """List models that support ``generateContent``, following ``nextPageToken`` pagination.

        Raises:
            aiohttp.ClientResponseError: If the API returns an error status (the body is logged).
        """
        api_url = f"{self.config.api_base}/models"
        models: List[Dict[str, Any]] = []
        params: Dict[str, str] = {"pageSize": "1000"}
        async with aiohttp.ClientSession(trust_env=True) as session:
            while True:
                async with session.get(
                    api_url, headers={"x-goog-api-key": self.config.api_key}, params=params
                ) as response:
                    if response.status != 200:
                        self.logger.error(f"获取模型列表失败: {await response.text()}")
                    response.raise_for_status()
                    page = await response.json()
                models.extend(page.get("models", []))
                next_token = page.get("nextPageToken")
                if not next_token:
                    break
                params["pageToken"] = next_token
        return [
            ModelConfig(id=model["name"].removeprefix("models/"), type=ModelType.LLM.value, ability=LLMAbility.TextChat.value)
            for model in models
            if "generateContent" in model.get("supportedGenerationMethods", [])
        ]

    def _post_with_retry(self, url: str, json: dict, headers: dict, retry_count: int = 3) -> requests.Response:
        """POST ``json`` to ``url``, retrying transient failures.

        Connection errors, timeouts, HTTP 429 and 5xx responses are retried, for
        at most ``retry_count`` attempts in total (minimum 1), with exponential
        backoff (1s, 2s, ...). Other HTTP errors such as 400/401/403/404 are
        raised immediately, because repeating the same request cannot fix them.

        Raises:
            requests.exceptions.RequestException: From the final attempt, or
                immediately for a non-retryable error. The response body, if any, is logged.
        """
        import time

        attempts = max(1, retry_count)
        for attempt in range(attempts):
            # Reset each attempt so a failure never logs the previous attempt's response.
            response = None
            try:
                response = requests.post(url, json=json, headers=headers)
                response.raise_for_status()
                return response
            except requests.exceptions.RequestException as e:
                # Compare with None explicitly: a Response with an error status is falsy.
                status = response.status_code if response is not None else None
                retryable = status is None or status == 429 or status >= 500
                if attempt == attempts - 1 or not retryable:
                    self.logger.error(
                        f"API Response: {response.text if response is not None else 'No response'}")
                    raise
                self.logger.warning(
                    f"Request failed, retrying {attempt+1}/{attempts}: {e}")
                time.sleep(2 ** attempt)
        # The final attempt always returns or raises above.
        raise RuntimeError("unreachable")
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/minimax_adapter.py
================================================
from .openai_adapter import OpenAIAdapterChatBase, OpenAIConfig
# MiniMax settings: OpenAIConfig whose default base URL is MiniMax's OpenAI-compatible endpoint.
class MinimaxConfig(OpenAIConfig):
    api_base: str = "https://api.minimax.chat/v1"


class MinimaxAdapter(OpenAIAdapterChatBase):
    """Chat adapter for MiniMax, reusing the generic OpenAI-compatible chat implementation."""

    def __init__(self, config: MinimaxConfig) -> None:
        super(MinimaxAdapter, self).__init__(config)
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/mistral_adapter.py
================================================
import aiohttp
from kirara_ai.llm.adapter import AutoDetectModelsProtocol
from .openai_adapter import OpenAIAdapter, OpenAIConfig
# Mistral settings: OpenAIConfig with api_base defaulting to the Mistral API endpoint.
class MistralConfig(OpenAIConfig):
    api_base: str = "https://api.mistral.ai/v1"
class MistralAdapter(OpenAIAdapter, AutoDetectModelsProtocol):
    """Adapter for the Mistral API, which is OpenAI-compatible for chat and embeddings."""

    def __init__(self, config: MistralConfig):
        super().__init__(config)

    async def auto_detect_models(self) -> "list[ModelConfig]":
        """List chat-capable Mistral models as ``ModelConfig`` entries.

        Returns the same element type as the ``OpenAIAdapter.auto_detect_models``
        method this overrides. Only models whose ``capabilities.completion_chat``
        flag is true are included.

        Raises:
            aiohttp.ClientResponseError: If the API returns an error status.
        """
        # Imported locally: this module's import block does not provide these names.
        from kirara_ai.config.global_config import ModelConfig
        from kirara_ai.llm.model_types import LLMAbility, ModelType

        # Mistral API response shape:
        # {"object": "list",
        #  "data": [{"id": "string", "object": "model", "created": 0, "owned_by": "mistralai",
        #            "capabilities": {"completion_chat": true, "completion_fim": false,
        #                             "function_calling": true, "fine_tuning": false, "vision": false},
        #            "name": "string", ...}]}
        api_url = f"{self.config.api_base}/models"
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(
                api_url, headers={"Authorization": f"Bearer {self.config.api_key}"}
            ) as response:
                response.raise_for_status()
                response_data = await response.json()
        # Keep only models that support chat; "capabilities" may be missing or null.
        return [
            ModelConfig(id=model["id"], type=ModelType.LLM.value, ability=LLMAbility.TextChat.value)
            for model in response_data.get("data", [])
            if (model.get("capabilities") or {}).get("completion_chat", False)
        ]
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/moonshot_adapter.py
================================================
from .openai_adapter import OpenAIAdapterChatBase, OpenAIConfig
# https://platform.moonshot.cn/docs/intro#%E6%96%87%E6%9C%AC%E7%94%9F%E6%88%90%E6%A8%A1%E5%9E%8B
# TODO: implement feature usages
# Moonshot settings: OpenAIConfig whose default base URL is Moonshot's OpenAI-compatible endpoint.
class MoonshotConfig(OpenAIConfig):
    api_base: str = "https://api.moonshot.cn/v1"


class MoonshotAdapter(OpenAIAdapterChatBase):
    """Chat adapter for Moonshot, reusing the generic OpenAI-compatible chat implementation."""

    def __init__(self, config: MoonshotConfig) -> None:
        super(MoonshotAdapter, self).__init__(config)
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/ollama_adapter.py
================================================
import asyncio
from typing import Any, List, cast
import aiohttp
import requests
from pydantic import BaseModel, ConfigDict
import kirara_ai.llm.format.tool as tools
from kirara_ai.config.global_config import ModelConfig
from kirara_ai.llm.adapter import AutoDetectModelsProtocol, LLMBackendAdapter, LLMChatProtocol, LLMEmbeddingProtocol
from kirara_ai.llm.format.message import (LLMChatContentPartType, LLMChatImageContent, LLMChatMessage,
LLMChatTextContent, LLMToolCallContent, LLMToolResultContent)
from kirara_ai.llm.format.request import LLMChatRequest, Tool
from kirara_ai.llm.format.response import LLMChatResponse, Message, Usage
from kirara_ai.llm.format.embedding import LLMEmbeddingRequest, LLMEmbeddingResponse
from kirara_ai.llm.model_types import LLMAbility, ModelType
from kirara_ai.logger import get_logger
from kirara_ai.media.manager import MediaManager
from kirara_ai.tracing import trace_llm_chat
from .openai_adapter import convert_tools_to_openai_format
from .utils import generate_tool_call_id, pick_tool_calls
class OllamaConfig(BaseModel):
    # Base URL of the Ollama server; "/api/..." paths are appended. No API key is used.
    api_base: str = "http://localhost:11434"
    # frozen=True makes config instances immutable.
    model_config = ConfigDict(frozen=True)
async def resolve_media_ids(media_ids: list[str], media_manager: MediaManager) -> List[str]:
    """Resolve media ids into base64 payloads for Ollama's ``images`` field.

    Ids the media manager cannot find are skipped without error.
    """
    encoded: List[str] = []
    for media in (media_manager.get_media(media_id) for media_id in media_ids):
        if media is None:
            continue
        encoded.append(await media.get_base64())
    return encoded
def convert_llm_response(response_data: dict[str, dict[str, Any]]) -> list[LLMChatContentPartType]:
    """Convert an Ollama /api/chat response message into content parts.

    When the model calls tools, only the tool calls are returned; in practice
    the accompanying ``content`` is an empty string. Otherwise the reply text
    becomes a single text part.
    """
    message = response_data["message"]
    calls = message.get("tool_calls", None)
    if not calls:
        return [LLMChatTextContent(text=message.get("content", ""))]
    converted: list[LLMChatContentPartType] = []
    for call in calls:
        function = call["function"]
        converted.append(LLMToolCallContent(
            id=generate_tool_call_id(function["name"]),
            name=function["name"],
            parameters=function.get("arguments", None),
        ))
    return converted
def convert_non_tool_message(msg: LLMChatMessage, media_manager: MediaManager, loop: asyncio.AbstractEventLoop) -> dict[str, Any]:
    """Convert a non-tool chat message into an Ollama message dict.

    Text parts are concatenated into ``content``. Images are resolved to
    base64 using ``loop`` and added under ``images``. Tool calls are added
    under ``tool_calls``. The last two keys appear only when non-empty.
    """
    text_chunks: list[str] = []
    image_ids: list[str] = []
    calls: list[dict[str, Any]] = []
    for part in msg.content:
        if isinstance(part, LLMChatTextContent):
            text_chunks.append(part.text)
        elif isinstance(part, LLMChatImageContent):
            image_ids.append(part.media_id)
        elif isinstance(part, LLMToolCallContent):
            calls.append({"function": {"name": part.name, "arguments": part.parameters}})

    result: dict[str, Any] = {"role": msg.role, "content": "".join(text_chunks)}
    if image_ids:
        result["images"] = loop.run_until_complete(resolve_media_ids(image_ids, media_manager))
    if calls:
        result["tool_calls"] = calls
    return result
def convert_tool_result_message(msg: LLMChatMessage, media_manager: MediaManager, loop: asyncio.AbstractEventLoop) -> list[dict]:
    """Convert a tool-role message into one Ollama ``tool`` message per tool result.

    Each text item contributes its text plus a newline. Failed results are
    prefixed with ``Error: <tool name>``. ``media_manager`` and ``loop`` are
    not used by this function.
    """
    results = cast(list[LLMToolResultContent], msg.content)
    converted = []
    for result in results:
        pieces = []
        for item in result.content:
            if isinstance(item, tools.TextContent):
                pieces.append(f"{item.text}\n")
            elif isinstance(item, tools.MediaContent):
                # NOTE(review): media items contribute only a newline; the media itself is not sent.
                pieces.append(f"\n")
        output = "".join(pieces)
        if result.isError:
            output = f"Error: {result.name}\n{output}"
        converted.append({"role": "tool", "content": output,
                          "tool_call_id": result.id})
    return converted
def convert_tools_to_ollama_format(tools: list[Tool]) -> list[dict]:
    """Convert tools to Ollama's format, which currently matches OpenAI's.

    Kept as a separate function so a future Ollama-specific change stays local.
    """
    ollama_tools = convert_tools_to_openai_format(tools)
    return ollama_tools
class OllamaAdapter(LLMBackendAdapter, AutoDetectModelsProtocol, LLMChatProtocol, LLMEmbeddingProtocol):
    """LLM backend adapter for the Ollama HTTP API (chat, embeddings, model listing)."""

    media_manager: MediaManager

    def __init__(self, config: OllamaConfig):
        self.config = config
        self.logger = get_logger("OllamaAdapter")

    def _convert_messages(self, req: LLMChatRequest) -> List[dict]:
        """Convert the request's history into Ollama ``messages``.

        The converters are synchronous but must await media loading, so they
        share one fresh event loop. Like ``asyncio.run``, this unsets and closes
        the loop afterwards, so each call does not leak a loop.
        """
        messages: List[dict] = []
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            for msg in req.messages:
                # A tool message can expand into several Ollama messages.
                if msg.role == "tool":
                    messages.extend(convert_tool_result_message(msg, self.media_manager, loop))
                else:
                    messages.append(convert_non_tool_message(msg, self.media_manager, loop))
        finally:
            asyncio.set_event_loop(None)
            loop.close()
        return messages

    @trace_llm_chat
    def chat(self, req: LLMChatRequest) -> LLMChatResponse:
        """Send ``req`` to ``{api_base}/api/chat`` (non-streaming) and convert the reply.

        Raises:
            requests.HTTPError: If the server returns an error status (the body is logged first).
        """
        api_url = f"{self.config.api_base}/api/chat"
        headers = {"Content-Type": "application/json"}

        options = {
            "temperature": req.temperature,
            "top_p": req.top_p,
            "num_predict": req.max_tokens,
            "stop": req.stop,
        }
        data: Dict[str, Any] = {
            "model": req.model,
            "messages": self._convert_messages(req),
            "stream": False,
            # /api/chat takes tools at the top level of the request; inside "options" they are ignored.
            "tools": convert_tools_to_ollama_format(req.tools) if req.tools else None,
            "options": {k: v for k, v in options.items() if v is not None},
        }
        # Remove None fields
        data = {k: v for k, v in data.items() if v is not None}

        response = requests.post(api_url, json=data, headers=headers)
        try:
            response.raise_for_status()
            response_data = response.json()
        except Exception as e:
            self.logger.error(f"API Response: {response.text}")
            raise e

        # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
        content = convert_llm_response(response_data)
        # Ollama may omit the token counters (e.g. prompt_eval_count when the prompt was cached).
        prompt_tokens = response_data.get("prompt_eval_count", 0)
        completion_tokens = response_data.get("eval_count", 0)
        return LLMChatResponse(
            model=req.model,
            message=Message(
                content=content,
                role="assistant",
                finish_reason=response_data.get("done_reason") or "stop",
                tool_calls=pick_tool_calls(content),
            ),
            usage=Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
        )

    def embed(self, req: LLMEmbeddingRequest) -> LLMEmbeddingResponse:
        """Embed text inputs via ``{api_base}/api/embed``.

        API docs: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings

        Raises:
            ValueError: If any input is an image (not supported by the Ollama embed API).
            requests.HTTPError: If the server returns an error status.
        """
        api_url = f"{self.config.api_base}/api/embed"
        headers = {"Content-Type": "application/json"}
        if any(isinstance(item, LLMChatImageContent) for item in req.inputs):
            raise ValueError("ollama api does not support multi-modal embedding")
        inputs = cast(list[LLMChatTextContent], req.inputs)
        data = {
            "model": req.model,
            "input": [item.text for item in inputs],
            # Controls whether Ollama truncates inputs to fit the model's context length.
            "truncate": req.truncate
        }
        data = {k: v for k, v in data.items() if v is not None}
        response = requests.post(api_url, json=data, headers=headers)
        try:
            response.raise_for_status()
            response_data = response.json()
        except Exception as e:
            self.logger.error(f"API Response: {response.text}")
            raise e
        return LLMEmbeddingResponse(
            vectors=response_data["embeddings"],
            usage=Usage(
                prompt_tokens=response_data.get("prompt_eval_count", 0)
            )
        )

    async def auto_detect_models(self) -> list[ModelConfig]:
        """List locally available models from ``{api_base}/api/tags`` as text-chat LLMs."""
        api_url = f"{self.config.api_base}/api/tags"
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(api_url) as response:
                response.raise_for_status()
                response_data = await response.json()
        return [
            ModelConfig(id=tag["name"], type=ModelType.LLM.value, ability=LLMAbility.TextChat.value)
            for tag in response_data["models"]
        ]
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/openai_adapter.py
================================================
import asyncio
import json
from typing import Any, Dict, List, cast, Literal, TypedDict
import aiohttp
import requests
from pydantic import BaseModel, ConfigDict
import kirara_ai.llm.format.tool as tools
from kirara_ai.config.global_config import ModelConfig
from kirara_ai.llm.adapter import AutoDetectModelsProtocol, LLMBackendAdapter, LLMChatProtocol, LLMEmbeddingProtocol
from kirara_ai.llm.format.message import (LLMChatContentPartType, LLMChatImageContent, LLMChatMessage,
LLMChatTextContent, LLMToolCallContent, LLMToolResultContent)
from kirara_ai.llm.format.request import LLMChatRequest, Tool
from kirara_ai.llm.format.response import LLMChatResponse, Message, Usage
from kirara_ai.llm.format.embedding import LLMEmbeddingRequest, LLMEmbeddingResponse
from kirara_ai.logger import get_logger
from kirara_ai.media import MediaManager
from kirara_ai.tracing import trace_llm_chat
from .utils import guess_openai_model, pick_tool_calls
# Module-level logger used by the OpenAI-compatible adapter classes in this module.
logger = get_logger("OpenAIAdapter")
async def convert_parts_factory(messages: LLMChatMessage, media_manager: MediaManager) -> list[dict]:
    """Convert one chat message into OpenAI chat-completions message dicts.

    For a ``tool`` message, one ``{"role": "tool", ...}`` dict is produced per
    tool result. All items of that result are concatenated into a single
    string ``content``; failed results are prefixed with ``Error: <tool name>``.

    Any other message produces exactly one dict. Its ``content`` holds the
    text/image parts and its ``tool_calls`` holds the function calls; each key
    is present only when non-empty.

    Raises:
        ValueError: If a referenced media id is missing, or a tool-result item has an unsupported type.
    """
    if messages.role == "tool":
        # typing.cast narrows the type for mypy.
        elements = cast(list[LLMToolResultContent], messages.content)
        outputs = []
        for element in elements:
            # OpenAI requires a tool message's content to be a single string.
            output = ""
            for content in element.content:
                if isinstance(content, tools.TextContent):
                    # Append, so multiple items are all kept instead of the last overwriting the rest.
                    output += content.text
                elif isinstance(content, tools.MediaContent):
                    media = media_manager.get_media(content.media_id)
                    if media is None:
                        raise ValueError(f"Media {content.media_id} not found")
                    # NOTE(review): this appends an empty string, so the media item leaves no
                    # trace in the tool output; confirm the intended placeholder text.
                    output += f""
                else:
                    raise ValueError(f"Unsupported content type: {type(content)}")
            if element.isError:
                output = f"Error: {element.name}\n{output}"
            outputs.append({
                "role": "tool",
                "tool_call_id": element.id,
                "content": output,
            })
        return outputs

    parts: list[dict[str, Any]] = []
    chat_elements = cast(list[LLMChatContentPartType], messages.content)
    tool_calls: list[dict[str, Any]] = []
    for chat_element in chat_elements:
        if isinstance(chat_element, LLMChatTextContent):
            parts.append(chat_element.model_dump(mode="json"))
        elif isinstance(chat_element, LLMChatImageContent):
            media = media_manager.get_media(chat_element.media_id)
            if media is None:
                raise ValueError(f"Media {chat_element.media_id} not found")
            parts.append({
                "type": "image_url",
                "image_url": {
                    "url": await media.get_base64_url()
                }
            })
        elif isinstance(chat_element, LLMToolCallContent):
            tool_calls.append({
                "type": "function",
                "id": chat_element.id,
                "function": {
                    "name": chat_element.name,
                    # OpenAI expects the arguments as a JSON-encoded string.
                    "arguments": json.dumps(chat_element.parameters or {}, ensure_ascii=False),
                }
            })
    response: Dict[str, Any] = {"role": messages.role}
    if parts:
        response["content"] = parts
    if tool_calls:
        response["tool_calls"] = tool_calls
    return [response]
async def convert_llm_chat_message_to_openai_message(messages: list[LLMChatMessage], media_manager: MediaManager) -> list[dict]:
    """Convert a whole conversation into a flat list of OpenAI message dicts.

    Messages are converted concurrently. Because one tool message can expand
    into several OpenAI messages, the per-message lists are flattened in order.
    """
    per_message = await asyncio.gather(*(convert_parts_factory(msg, media_manager) for msg in messages))
    flattened: list[dict] = []
    for converted in per_message:
        flattened.extend(converted)
    return flattened
def convert_tools_to_openai_format(tools: list[Tool]) -> list[dict]:
    """Convert tools into OpenAI ``{"type": "function", "function": {...}}`` entries.

    ``parameters`` is used as-is when it is already a dict; otherwise it is
    converted with ``model_dump()``.
    """
    converted: list[dict] = []
    for tool in tools:
        schema = tool.parameters
        if not isinstance(schema, dict):
            schema = schema.model_dump()
        function_spec = {
            "name": tool.name,
            "description": tool.description,
            "parameters": schema,
            "strict": tool.strict,
        }
        converted.append({"type": "function", "function": function_spec})
    return converted
class OpenAIConfig(BaseModel):
    # API key, sent as "Authorization: Bearer <key>".
    api_key: str
    # Base URL of the OpenAI-compatible API; paths like "/chat/completions" are appended.
    api_base: str = "https://api.openai.com/v1"
    # frozen=True makes config instances immutable.
    model_config = ConfigDict(frozen=True)
class OpenAIAdapterChatBase(LLMBackendAdapter, AutoDetectModelsProtocol, LLMChatProtocol):
    """Chat and model-listing adapter for OpenAI-compatible ``/chat/completions`` APIs."""

    media_manager: MediaManager

    def __init__(self, config: OpenAIConfig):
        self.config = config

    @staticmethod
    def _parse_tool_arguments(arguments: Any) -> dict:
        """Decode a tool call's ``arguments`` into a dict.

        OpenAI sends a JSON string. Some compatible providers send an empty
        string, null, or an already-decoded object instead; all are accepted.
        """
        if isinstance(arguments, dict):
            return arguments
        if not arguments:
            return {}
        return json.loads(arguments)

    @trace_llm_chat
    def chat(self, req: LLMChatRequest) -> LLMChatResponse:
        """Send ``req`` to ``{api_base}/chat/completions`` and convert the first choice.

        Raises:
            requests.HTTPError: If the API returns an error status (the body is logged first).
        """
        api_url = f"{self.config.api_base}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }

        data = {
            "messages": asyncio.run(convert_llm_chat_message_to_openai_message(req.messages, self.media_manager)),
            "model": req.model,
            "frequency_penalty": req.frequency_penalty,
            "max_completion_tokens": req.max_tokens,  # the current API reference replaces max_tokens with this
            "presence_penalty": req.presence_penalty,
            "response_format": req.response_format,
            "stop": req.stop,
            "stream": req.stream,
            "stream_options": req.stream_options,
            "temperature": req.temperature,
            "top_p": req.top_p,
            "tools": convert_tools_to_openai_format(req.tools) if req.tools else None,
            "tool_choice": "auto" if req.tools else None,
            "logprobs": req.logprobs,
            "top_logprobs": req.top_logprobs,
        }

        # Remove None fields
        data = {k: v for k, v in data.items() if v is not None}

        logger.debug(f"Request: {data}")

        response = requests.post(api_url, json=data, headers=headers)
        try:
            response.raise_for_status()
            response_data: dict = response.json()
        except Exception as e:
            logger.error(f"Response: {response.text}")
            raise e
        logger.debug(f"Response: {response_data}")

        # Any of these may be missing or explicitly null in provider responses.
        choices: List[dict[str, Any]] = response_data.get("choices") or []
        first_choice = choices[0] if choices else {}
        message: dict[str, Any] = first_choice.get("message") or {}

        # With tool calls, "content" carries no useful information and is not recorded.
        content: list[LLMChatContentPartType]
        if tool_calls := message.get("tool_calls", None):
            content = [LLMToolCallContent(
                id=call["id"],
                name=call["function"]["name"],
                parameters=self._parse_tool_arguments(call["function"].get("arguments")),
            ) for call in tool_calls]
        else:
            # "content" can be null (e.g. refusals); treat that as empty text.
            content = [LLMChatTextContent(text=message.get("content") or "")]

        usage_data = response_data.get("usage") or {}
        return LLMChatResponse(
            model=req.model,
            usage=Usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0),
            ),
            message=Message(
                content=content,
                role=message.get("role") or "assistant",
                tool_calls=pick_tool_calls(content),
                finish_reason=first_choice.get("finish_reason", ""),
            ),
        )

    async def get_models(self) -> list[str]:
        """Return the model ids listed by ``{api_base}/models``."""
        api_url = f"{self.config.api_base}/models"
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(
                api_url, headers={"Authorization": f"Bearer {self.config.api_key}"}
            ) as response:
                response.raise_for_status()
                response_data = await response.json()
        return [model["id"] for model in response_data.get("data") or []]

    async def auto_detect_models(self) -> list[ModelConfig]:
        """Return a ModelConfig for every listed model that ``guess_openai_model`` recognises."""
        models = await self.get_models()
        all_models: list[ModelConfig] = []
        for model in models:
            guess_result = guess_openai_model(model)
            if guess_result is None:
                continue
            all_models.append(ModelConfig(id=model, type=guess_result[0].value, ability=guess_result[1]))
        return all_models
class EmbeddingData(TypedDict):
    """One entry of the ``data`` list in an OpenAI ``/embeddings`` response."""
    object: Literal["embedding"]
    # The embedding vector.
    embedding: list[float]
    # Position of the input this vector belongs to (per the OpenAI API).
    index: int
class EmbeddingResponse(TypedDict):
    """Type description of the JSON body returned by OpenAI ``/embeddings``; used only for typing."""
    object: Literal["list"]
    data: list[EmbeddingData]
    model: str
    usage: dict[Literal["prompt_tokens", "total_tokens"], int]
class OpenAIAdapter(OpenAIAdapterChatBase, LLMEmbeddingProtocol):
    """OpenAI-compatible adapter that adds the ``/embeddings`` endpoint to chat support."""

    def embed(self, req: LLMEmbeddingRequest) -> LLMEmbeddingResponse:
        """Embed text inputs via ``{api_base}/embeddings``.

        OpenAI only honours ``dimensions`` for text-embedding-3 and later models.
        Vectors are returned in the same order as ``req.inputs``.

        Raises:
            ValueError: If more than 2048 inputs are given (the OpenAI per-request
                limit) or any input is an image.
            requests.HTTPError: If the API returns an error status (the body is logged first).
        """
        api_url = f"{self.config.api_base}/embeddings"
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }
        if len(req.inputs) > 2048:
            # OpenAI limits the input array to 2048 entries.
            raise ValueError("Too many inputs: at most 2048 texts can be embedded per request")
        if any(isinstance(item, LLMChatImageContent) for item in req.inputs):
            # No multi-modal embedding API found yet; revisit when one exists.
            raise ValueError("openai does not support multi-modal embedding")
        # Narrowing for mypy; remove once multi-modal inputs are supported.
        inputs = cast(list[LLMChatTextContent], req.inputs)
        data = {
            # The embeddings API takes the texts under "input" (string or array of strings).
            "input": [item.text for item in inputs],
            "model": req.model,
            "dimensions": req.dimension,
            "encoding_format": req.encoding_format
        }
        # Remove None fields
        data = {k: v for k, v in data.items() if v is not None}
        logger.debug(f"Request: {data}")

        response = requests.post(api_url, headers=headers, json=data)
        try:
            response.raise_for_status()
            response_data: EmbeddingResponse = response.json()
        except Exception as e:
            logger.error(f"Response: {response.text}")
            raise e
        logger.debug(f"Response: {response_data}")

        # Order by the returned index so vectors line up with the inputs.
        ordered = sorted(response_data["data"], key=lambda item: item.get("index", 0))
        usage = response_data.get("usage") or {}
        return LLMEmbeddingResponse(
            vectors=[item["embedding"] for item in ordered],
            usage=Usage(
                prompt_tokens=usage.get("prompt_tokens", 0),
                total_tokens=usage.get("total_tokens", 0)
            )
        )
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/openrouter_adapter.py
================================================
import aiohttp
from kirara_ai.config.global_config import ModelConfig
from kirara_ai.llm.model_types import LLMAbility, ModelType
from .openai_adapter import OpenAIAdapter, OpenAIConfig
class OpenRouterConfig(OpenAIConfig):
    """OpenAI-compatible configuration whose default endpoint is OpenRouter."""

    # Default OpenRouter OpenAI-compatible API root.
    api_base: str = "https://openrouter.ai/api/v1"
class OpenRouterAdapter(OpenAIAdapter):
    """
    OpenRouter adapter.

    Chat requests use the inherited OpenAI-compatible implementation. Model discovery
    reads each model's ``architecture`` modalities to build its ability bitmask.
    """

    def __init__(self, config: OpenRouterConfig):
        super().__init__(config)

    async def auto_detect_models(self) -> list[ModelConfig]:
        """
        Return every model listed by OpenRouter as an LLM ``ModelConfig``.

        Each model starts with ``TextChat``. ``text``/``image`` entries in the
        architecture's input and output modalities add the matching input/output
        abilities; any other modality is ignored.
        """
        input_abilities = {
            "text": LLMAbility.TextInput.value,
            "image": LLMAbility.ImageInput.value,
        }
        output_abilities = {
            "text": LLMAbility.TextOutput.value,
            "image": LLMAbility.ImageOutput.value,
        }
        endpoint = f"{self.config.api_base}/models"
        auth = {"Authorization": f"Bearer {self.config.api_key}"}
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(endpoint, headers=auth) as resp:
                resp.raise_for_status()
                payload = await resp.json()

        detected: list[ModelConfig] = []
        for entry in payload.get("data", []):
            ability = LLMAbility.TextChat.value
            for modality in entry["architecture"]["input_modalities"]:
                ability |= input_abilities.get(modality, 0)
            for modality in entry["architecture"]["output_modalities"]:
                ability |= output_abilities.get(modality, 0)
            detected.append(ModelConfig(id=entry["id"], type=ModelType.LLM.value, ability=ability))
        return detected
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/setup.py
================================================
from setuptools import find_packages, setup
# Packaging metadata for the bundled preset LLM adapters plugin.
setup(
    name="kirara_ai-llm-presets",
    version="1.0.0",
    description="Preset LLM adapters for kirara_ai",
    author="Internal",
    packages=find_packages(),
    # Third-party packages imported directly by the adapter modules:
    # requests (sync HTTP), aiohttp (async model listing), pydantic (config models).
    install_requires=["requests", "aiohttp", "pydantic"],
    entry_points={
        "chatgpt_mirai.plugins": [
            "llm_presets = llm_preset_adapters.plugin:LLMPresetsPlugin"
        ]
    },
)
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/siliconflow_adapter.py
================================================
import aiohttp

from kirara_ai.config.global_config import ModelConfig
from kirara_ai.llm.model_types import LLMAbility, ModelType

from .openai_adapter import OpenAIAdapter, OpenAIConfig
class SiliconFlowConfig(OpenAIConfig):
    """OpenAI-compatible configuration whose default endpoint is SiliconFlow."""

    # Default SiliconFlow OpenAI-compatible API root.
    api_base: str = "https://api.siliconflow.cn/v1"
class SiliconFlowAdapter(OpenAIAdapter):
    """OpenAI-compatible adapter for SiliconFlow."""

    def __init__(self, config: SiliconFlowConfig):
        super().__init__(config)

    async def auto_detect_models(self) -> list[ModelConfig]:
        """
        List SiliconFlow chat models as ``ModelConfig`` entries.

        The base adapter's ``auto_detect_models`` contract is ``list[ModelConfig]``, so
        bare ID strings are wrapped accordingly. Only ``sub_type=chat`` models are
        requested, so every entry is reported as an LLM with text-chat ability.

        :raises aiohttp.ClientResponseError: the endpoint answered with a non-2xx status.
        """
        api_url = f"{self.config.api_base}/models?sub_type=chat"
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(
                api_url, headers={"Authorization": f"Bearer {self.config.api_key}"}
            ) as response:
                response.raise_for_status()
                response_data = await response.json()
        return [
            ModelConfig(id=model["id"], type=ModelType.LLM.value, ability=LLMAbility.TextChat.value)
            for model in response_data.get("data", [])
        ]
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/tencentcloud_adapter.py
================================================
from .openai_adapter import OpenAIAdapter, OpenAIConfig
class TencentCloudConfig(OpenAIConfig):
    """OpenAI-compatible configuration whose default endpoint is Tencent Hunyuan."""

    # Default Tencent Cloud Hunyuan OpenAI-compatible API root.
    api_base: str = "https://api.hunyuan.cloud.tencent.com/v1"
class TencentCloudAdapter(OpenAIAdapter):
    """Tencent Hunyuan adapter; all behaviour is inherited from the OpenAI-compatible adapter."""

    def __init__(self, config: TencentCloudConfig):
        super().__init__(config)
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/tests/test_utils.py
================================================
from llm_preset_adapters.utils import guess_openai_model
from kirara_ai.llm.model_types import AudioModelAbility, EmbeddingModelAbility, ImageModelAbility, LLMAbility, ModelType
from kirara_ai.logger import get_logger
# Ground-truth expectations for guess_openai_model(), checked by test_guess_openai_model.
# Each entry gives a model "id", the expected ModelType value ("type") and the expected
# ability bitmask ("abilities"), built by OR-ing .value flags of the ability enum that
# matches the model's type.
models = [
    {
        "id": "babbage-002",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextCompletion.value
    },
    {
        "id": "chatgpt-4o-latest",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value
    },
    {
        "id": "computer-use-preview-2025-03-11",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "dall-e-2",
        "type": ModelType.ImageGeneration.value,
        "abilities": ImageModelAbility.TextToImage.value | ImageModelAbility.ImageEdit.value | ImageModelAbility.Inpainting.value
    },
    {
        "id": "dall-e-3",
        "type": ModelType.ImageGeneration.value,
        "abilities": ImageModelAbility.TextToImage.value
    },
    {
        "id": "davinci-002",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextCompletion.value
    },
    {
        "id": "gpt-3-5-0301",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value
    },
    {
        "id": "gpt-3-5-turbo-0125",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value
    },
    {
        "id": "gpt-3-5-turbo-0613",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value
    },
    {
        "id": "gpt-3-5-turbo-1106",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value
    },
    {
        "id": "gpt-3-5-turbo-16k-0613",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value
    },
    {
        "id": "gpt-3-5-turbo-instruct",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value
    },
    {
        "id": "gpt-4-0125-preview",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value
    },
    {
        "id": "gpt-4-0314",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value
    },
    {
        "id": "gpt-4-0613",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value
    },
    {
        "id": "gpt-4-turbo-2024-04-09",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4.1-2025-04-14",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4.1-mini-2025-04-14",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4.1-nano-2025-04-14",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4.5-preview-2025-02-27",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-2024-05-13",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-2024-08-06",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-2024-11-20",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-audio-preview-2024-10-01",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.AudioInput.value | LLMAbility.AudioOutput.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-audio-preview-2024-12-17",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.AudioInput.value | LLMAbility.AudioOutput.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-mini-2024-07-18",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-mini-audio-preview-2024-12-17",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.AudioInput.value | LLMAbility.AudioOutput.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-mini-realtime-preview-2024-12-17",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.AudioInput.value | LLMAbility.AudioOutput.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-mini-search-preview-2025-03-11",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value
    },
    {
        "id": "gpt-4o-mini-transcribe",
        "type": ModelType.Audio.value,
        "abilities": AudioModelAbility.Transcription.value | AudioModelAbility.Realtime.value
    },
    {
        "id": "gpt-4o-mini-tts",
        "type": ModelType.Audio.value,
        "abilities": AudioModelAbility.Speech.value | AudioModelAbility.Streaming.value
    },
    {
        "id": "gpt-4o-realtime-preview-2024-10-01",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.AudioInput.value | LLMAbility.AudioOutput.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-realtime-preview-2024-12-17",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.AudioInput.value | LLMAbility.AudioOutput.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "gpt-4o-search-preview-2025-03-11",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value
    },
    {
        "id": "gpt-4o-transcribe",
        "type": ModelType.Audio.value,
        "abilities": AudioModelAbility.Transcription.value | AudioModelAbility.Streaming.value | AudioModelAbility.Realtime.value
    },
    {
        "id": "gpt-image-1",
        "type": ModelType.ImageGeneration.value,
        "abilities": ImageModelAbility.TextToImage.value | ImageModelAbility.ImageEdit.value |
            ImageModelAbility.Inpainting.value
    },
    {
        "id": "o1-2024-12-17",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "o1-mini-2024-09-12",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value
    },
    {
        "id": "o1-preview-2024-09-12",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "o1-pro-2025-03-19",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "o3-2025-04-16",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "o3-mini-2025-01-31",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value |
            LLMAbility.FunctionCalling.value
    },
    {
        "id": "o4-mini-2025-04-16",
        "type": ModelType.LLM.value,
        "abilities": LLMAbility.TextChat.value | LLMAbility.ImageInput.value | LLMAbility.FunctionCalling.value
    },
    {
        "id": "text-embedding-3-large",
        "type": ModelType.Embedding.value,
        "abilities": EmbeddingModelAbility.TextEmbedding.value | EmbeddingModelAbility.Batch.value
    },
    {
        "id": "text-embedding-3-small",
        "type": ModelType.Embedding.value,
        "abilities": EmbeddingModelAbility.TextEmbedding.value | EmbeddingModelAbility.Batch.value
    },
    {
        "id": "text-embedding-ada-002",
        "type": ModelType.Embedding.value,
        "abilities": EmbeddingModelAbility.TextEmbedding.value | EmbeddingModelAbility.Batch.value
    },
    {
        "id": "tts-1-hd",
        "type": ModelType.Audio.value,
        "abilities": AudioModelAbility.Speech.value
    },
    {
        "id": "tts-1",
        "type": ModelType.Audio.value,
        "abilities": AudioModelAbility.Speech.value
    },
    {
        "id": "whisper-1",
        "type": ModelType.Audio.value,
        "abilities": AudioModelAbility.Transcription.value | AudioModelAbility.Translation.value
    }
]
def test_guess_openai_model():
    """
    Check ``guess_openai_model`` against every expected entry in ``models``.

    Mismatches are collected instead of failing on the first one, so a single run
    reports every model whose guessed type or ability bitmask is wrong.
    """
    logger = get_logger("test_guess_openai_model")
    # Enum used to decode the ability bits of each model type.
    ability_enums = {
        ModelType.LLM: LLMAbility,
        ModelType.Embedding: EmbeddingModelAbility,
        ModelType.ImageGeneration: ImageModelAbility,
        ModelType.Audio: AudioModelAbility,
    }
    failed_tests = []

    def _flag_names(ability_enum, bits):
        """Names of the enum members overlapping ``bits``; empty if no enum is known."""
        if ability_enum is None:
            return []
        return [enum_item.name for enum_item in ability_enum if bits & enum_item.value]

    def _fail(message):
        """Record a failed check and log it."""
        failed_tests.append(message)
        logger.error(message)

    for model in models:
        guessed = guess_openai_model(model["id"])
        if guessed is None:
            # guess_openai_model returns None for unsupported IDs; report it rather
            # than failing on the tuple unpacking below.
            _fail(f"模型 {model['id']} 的类型不应为 None")
            continue
        model_type, abilities = guessed
        logger.info(f"模型: {model['id']}, 模型类型: {model_type}, 能力: {abilities}, 预期: {model['type']}, {model['abilities']}")

        if model_type.value != model["type"]:
            _fail(f"模型 {model['id']} 的类型不匹配,预期 {model['type']},实际 {model_type.value}")
            continue
        if abilities == model["abilities"]:
            continue

        diff_bits = abilities ^ model["abilities"]
        expected_but_missing = diff_bits & model["abilities"]
        unexpected_but_present = diff_bits & abilities
        ability_enum = ability_enums.get(model_type)
        error_msg = f"模型 {model['id']} 的能力不匹配,预期 {model['abilities']},实际 {abilities}"
        if expected_but_missing:
            missing_abilities = _flag_names(ability_enum, expected_but_missing)
            error_msg += f",缺少能力: {', '.join(missing_abilities)} ({bin(expected_but_missing)})"
        if unexpected_but_present:
            extra_abilities = _flag_names(ability_enum, unexpected_but_present)
            error_msg += f",多余能力: {', '.join(extra_abilities)} ({bin(unexpected_but_present)})"
        _fail(error_msg)

    if failed_tests:
        error_message = f"共有 {len(failed_tests)} 个测试点失败:\n" + "\n".join(failed_tests)
        assert False, error_message
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/utils.py
================================================
import uuid
from typing import Optional, Tuple
from kirara_ai.llm.format.message import LLMChatContentPartType, LLMToolCallContent
from kirara_ai.llm.format.tool import Function, ToolCall
from kirara_ai.llm.model_types import AudioModelAbility, EmbeddingModelAbility, ImageModelAbility, LLMAbility, ModelType
def generate_tool_call_id(name: str) -> str:
    """Return a unique tool-call ID of the form ``"<name>_<random uuid4>"``."""
    suffix = uuid.uuid4()
    return "{}_{}".format(name, suffix)
def pick_tool_calls(calls: list[LLMChatContentPartType]) -> Optional[list[ToolCall]]:
    """
    Convert the tool-call parts of a message's content into ``ToolCall`` objects.

    Parts that are not ``LLMToolCallContent`` are ignored. Returns ``None`` (not an
    empty list) when the content holds no tool calls.
    """
    picked: list[ToolCall] = []
    for part in calls:
        if not isinstance(part, LLMToolCallContent):
            continue
        picked.append(ToolCall(
            id=part.id,
            function=Function(name=part.name, arguments=part.parameters),
        ))
    return picked or None
def guess_openai_model(model_id: str) -> Tuple[ModelType, int] | None:
    """
    Guess a model's type and ability bitmask from its OpenAI model ID.

    Matching is by substring on the lower-cased ID, and the checks run in a fixed
    order: embedding, image generation, audio, audio-capable 4o chat models,
    moderation, then general LLMs. The first matching check wins.

    Returns: ``(ModelType, ability_bitmask)``, or ``None`` for moderation models.
    """
    model_id = model_id.lower()
    # 1. Embedding models
    if "embedding" in model_id:
        return (ModelType.Embedding, EmbeddingModelAbility.TextEmbedding.value | EmbeddingModelAbility.Batch.value)  # embedding model
    # 2. Image-generation models
    if "dall-e" in model_id or "gpt-image" in model_id:
        ability = ImageModelAbility.TextToImage.value
        # dall-e-2 and gpt-image additionally support editing and inpainting
        if "dall-e-2" in model_id or "gpt-image" in model_id:
            ability |= ImageModelAbility.ImageEdit.value | ImageModelAbility.Inpainting.value
        return (ModelType.ImageGeneration, ability)
    # 3. Audio models (speech synthesis and transcription)
    if "whisper" in model_id:
        return (ModelType.Audio, AudioModelAbility.Transcription.value | AudioModelAbility.Translation.value)
    if "tts" in model_id:
        if "mini" in model_id:
            return (ModelType.Audio, AudioModelAbility.Speech.value | AudioModelAbility.Streaming.value)
        return (ModelType.Audio, AudioModelAbility.Speech.value)
    if model_id == "gpt-4o-transcribe":
        # Special case: gpt-4o-transcribe has no Translation ability
        return (ModelType.Audio, AudioModelAbility.Transcription.value | AudioModelAbility.Streaming.value | AudioModelAbility.Realtime.value)
    if "transcribe" in model_id:
        if "mini" in model_id:
            return (ModelType.Audio, AudioModelAbility.Transcription.value | AudioModelAbility.Realtime.value)
        if "realtime" in model_id:
            return (ModelType.Audio, AudioModelAbility.Transcription.value | AudioModelAbility.Realtime.value)
        else:
            return (ModelType.Audio, AudioModelAbility.Transcription.value | AudioModelAbility.Translation.value | AudioModelAbility.Streaming.value | AudioModelAbility.Realtime.value)
    # 4. Audio-capable 4o chat models (these should not get image input)
    if ("audio" in model_id or "realtime" in model_id) and "4o" in model_id:
        ability = LLMAbility.TextChat.value | LLMAbility.AudioInput.value | LLMAbility.AudioOutput.value
        # Special cases: these specific "mini" snapshots do support function calling
        if "gpt-4o-mini-audio-preview-2024-12-17" in model_id:
            ability |= LLMAbility.FunctionCalling.value
            return (ModelType.LLM, ability)
        if "gpt-4o-mini-realtime-preview-2024-12-17" in model_id:
            ability |= LLMAbility.FunctionCalling.value
            return (ModelType.LLM, ability)
        # Otherwise mini / search / instruct variants get no function calling
        if not ("mini-search" in model_id or "mini-realtime" in model_id or
                "instruct" in model_id or "search" in model_id or
                "mini" in model_id):
            ability |= LLMAbility.FunctionCalling.value
        return (ModelType.LLM, ability)
    # 5. Moderation models are not supported
    if "moderation" in model_id:
        return None
    # 6. General LLMs (default case)
    ability = LLMAbility.TextChat.value
    # Legacy completion models: plain text in/out, no chat
    if ("babbage" in model_id or "davinci" in model_id):
        return (ModelType.LLM, LLMAbility.TextInput.value | LLMAbility.TextOutput.value)
    # Image input ability
    if ("vision" in model_id or
            "4o" in model_id or
            "computer-use-preview" in model_id or
            "gpt-4-turbo" in model_id or
            "o1" in model_id or
            "o4" in model_id or
            "4." in model_id or
            ("4" in model_id and ("image" in model_id or "vision" in model_id))):
        ability |= LLMAbility.ImageInput.value
    # Most models support function calling, except for the specific exceptions below
    if not ("3.5" in model_id or
            "3-5" in model_id or
            "1106" in model_id or
            "0314" in model_id or
            "0125" in model_id or
            "gpt-4-0" in model_id or
            "chatgpt-4o" in model_id or
            "instruct" in model_id or
            "search" in model_id or
            "o1-mini" in model_id or
            model_id.startswith(("babbage", "davinci"))):
        ability |= LLMAbility.FunctionCalling.value
    # Per-model overrides; these replace whatever the heuristics above computed
    if "o3-2025-04-16" in model_id:
        ability = LLMAbility.TextChat.value | LLMAbility.ImageInput.value | LLMAbility.FunctionCalling.value
    if "o1-preview-2024-09-12" in model_id:
        ability = LLMAbility.TextChat.value | LLMAbility.FunctionCalling.value
    if "o1-mini-2024-09-12" in model_id:
        ability = LLMAbility.TextChat.value
    if "o3-mini-2025-01-31" in model_id:
        ability = LLMAbility.TextChat.value | LLMAbility.FunctionCalling.value
    return (ModelType.LLM, ability)
def guess_qwen_model(model_id: str) -> Tuple[ModelType, int] | None:
    """
    Guess the type and ability bitmask of a Tongyi Qianwen (Qwen) model from its ID.

    Matching is by substring on the lower-cased ID; the first matching rule wins.

    Returns: ``(ModelType, ability_bitmask)``, or ``None`` when the ID is not recognised.
    """
    model_id = model_id.lower()
    # Qwen embedding models
    if "text-embedding" in model_id:
        return (ModelType.Embedding, EmbeddingModelAbility.TextEmbedding.value | EmbeddingModelAbility.Batch.value)
    if "multimodal-embedding-v1" in model_id:
        return (ModelType.Embedding, EmbeddingModelAbility.TextEmbedding.value | EmbeddingModelAbility.ImageEmbedding.value | EmbeddingModelAbility.AudioEmbedding.value | EmbeddingModelAbility.VideoEmbedding.value | EmbeddingModelAbility.Batch.value)
    # Qwen multimodal models. Each keyword must be tested against model_id: a bare
    # literal such as "qvq-" is always truthy and would swallow every remaining ID.
    if "-vl" in model_id or "qvq-" in model_id:
        return (ModelType.LLM, LLMAbility.TextChat.value | LLMAbility.ImageInput.value)
    if "-audio" in model_id:
        return (ModelType.LLM, LLMAbility.TextChat.value | LLMAbility.AudioInput.value)
    if "qwen-omni" in model_id:
        return (ModelType.LLM, LLMAbility.TextChat.value | LLMAbility.ImageInput.value | LLMAbility.AudioInput.value | LLMAbility.AudioOutput.value)
    # Base Qwen text models
    if "qwen" in model_id:
        return (ModelType.LLM, LLMAbility.TextChat.value | LLMAbility.FunctionCalling.value)
    return None
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/volcengine_adapter.py
================================================
import datetime
import hashlib
import hmac
from urllib.parse import quote
import aiohttp
from pydantic import Field
from kirara_ai.config.global_config import ModelConfig
from kirara_ai.llm.model_types import LLMAbility, ModelType
from .openai_adapter import OpenAIAdapter, OpenAIConfig
class VolcengineConfig(OpenAIConfig):
    """
    Settings for Volcengine Ark.

    The inherited OpenAI-style fields drive chat requests against ``api_base``. The
    access-key pair below signs requests to the Volcengine OpenAPI used to list models.
    """

    # OpenAI-compatible Ark endpoint (cn-beijing region).
    api_base: str = "https://ark.cn-beijing.volces.com/api/v3"
    access_key_id: str = Field(description="火山云引擎 API 密钥 ID,用于获取模型列表")
    access_key_secret: str = Field(description="火山云引擎 API 密钥,用于获取模型列表")
def generate_volcengine_signature(access_key_id, access_key_secret, method, path, query, body=None,
                                  *, service="ark", region="cn-beijing", host="open.volcengineapi.com"):
    """
    Build the HMAC-SHA256 signed headers required by the Volcengine OpenAPI.

    :param access_key_id: Volcengine access key ID.
    :param access_key_secret: Volcengine secret access key.
    :param method: HTTP method, e.g. ``"GET"``; upper-cased for signing.
    :param path: canonical request path, e.g. ``"/"``.
    :param query: query parameters (dict, canonicalised by ``normalize_query``); may be empty.
    :param body: raw request body string; ``None`` is signed as an empty body.
    :param service: service name used in the credential scope (default ``"ark"``).
    :param region: region used in the credential scope (default ``"cn-beijing"``).
    :param host: Host header value that is signed and returned.
    :return: headers dict with ``Host``, ``X-Content-Sha256``, ``X-Date`` and ``Authorization``.
    """
    def _hmac_sha256(key: bytes, msg: str) -> bytes:
        """One HMAC-SHA256 step of the signing-key derivation."""
        return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()

    # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated since Python 3.12).
    x_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    short_date = x_date[:8]

    # SHA-256 of the request body (empty string when there is no body).
    body = body or ""
    x_content_sha256 = hashlib.sha256(body.encode("utf-8")).hexdigest()

    canonical_query = normalize_query(query)

    # Headers covered by the signature, in lowercase sorted order.
    signed_headers_str = ";".join(["host", "x-content-sha256", "x-date"])
    canonical_headers = "\n".join([
        f"host:{host}",
        f"x-content-sha256:{x_content_sha256}",
        f"x-date:{x_date}",
    ])
    # The empty element produces the blank line that ends the canonical headers section.
    canonical_request = "\n".join([
        method.upper(),
        path,
        canonical_query,
        canonical_headers,
        "",
        signed_headers_str,
        x_content_sha256,
    ])
    hashed_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()

    credential_scope = f"{short_date}/{region}/{service}/request"
    string_to_sign = "\n".join(["HMAC-SHA256", x_date, credential_scope, hashed_canonical_request])

    # Signing key: HMAC chain over date, region, service and the literal "request".
    k_date = _hmac_sha256(access_key_secret.encode("utf-8"), short_date)
    k_region = _hmac_sha256(k_date, region)
    k_service = _hmac_sha256(k_region, service)
    k_signing = _hmac_sha256(k_service, "request")
    signature = _hmac_sha256(k_signing, string_to_sign).hex()

    authorization = (
        f"HMAC-SHA256 Credential={access_key_id}/{credential_scope}, "
        f"SignedHeaders={signed_headers_str}, Signature={signature}"
    )
    return {
        "Host": host,
        "X-Content-Sha256": x_content_sha256,
        "X-Date": x_date,
        "Authorization": authorization
    }
def normalize_query(params):
    """
    Build the canonical query string used in the Volcengine request signature.

    Keys are sorted, and a list value emits one ``key=value`` pair per element. Keys
    and values are percent-encoded with only ``-_.~`` left unescaped, and pairs are
    joined with ``&``. Empty or ``None`` params yield ``""``.
    """
    if not params:
        return ""

    def encode(text):
        return quote(text, safe="-_.~")

    pairs = []
    for key in sorted(params.keys()):
        raw = params[key]
        for item in (raw if isinstance(raw, list) else [raw]):
            pairs.append(encode(key) + "=" + encode(str(item)))
    return "&".join(pairs).replace("+", "%20")
def detect_ability(model: dict) -> int:
    """
    Derive an ``LLMAbility`` bitmask from a Volcengine ``FoundationModelTag`` dict.

    Every model gets ``TextChat``. A ``"VisualQuestionAnswering"`` task type adds image
    input, and a ``"Function Call"`` customised tag adds function calling.
    """
    task_types = model.get("TaskTypes", [])
    tags = model.get("CustomizedTags", [])
    ability = LLMAbility.TextChat.value
    if "VisualQuestionAnswering" in task_types:
        ability |= LLMAbility.ImageInput.value
    if "Function Call" in tags:
        ability |= LLMAbility.FunctionCalling.value
    return ability
class VolcengineAdapter(OpenAIAdapter):
    """
    Volcengine Ark adapter.

    Chat traffic uses the inherited OpenAI-compatible implementation. Model discovery
    instead calls the signed Volcengine OpenAPI (``ListFoundationModels``) using the
    access-key pair from the config.
    """

    config: VolcengineConfig

    def __init__(self, config: VolcengineConfig):
        super().__init__(config)

    async def auto_detect_models(self) -> list[ModelConfig]:
        """
        List the Volcengine foundation models in the LLM domain, following pagination.

        Example response (abridged; values verbatim from the API)::

            {
                "Result": {
                    "TotalCount": 39,
                    "PageNumber": 1,
                    "PageSize": 10,
                    "Items": [
                        {
                            "Name": "doubao-1-5-pro-256k",
                            "VendorName": "字节跳动",
                            "DisplayName": "Doubao-1.5-pro-256k",
                            "ShortName": "",
                            "PrimaryVersion": "250115",
                            "FoundationModelTag": {
                                "Domains": ["LLM"],
                                "Languages": ["中英文"],
                                "TaskTypes": ["Chat"],
                                "ContextLength": "256k",
                                "CustomizedTags": ["支持体验"]
                            },
                        }
                    ]
                }
            }

        Only items whose ``FoundationModelTag.Domains`` contains ``"LLM"`` and that have
        a ``Name`` are returned; abilities come from ``detect_ability``. If a response
        has no ``"Result"`` object, paging stops and the models collected so far are
        returned (an empty list if that happens on the first page).

        :raises aiohttp.ClientResponseError: a page request returned a non-2xx status.
        """
        host = "open.volcengineapi.com"
        path = "/"
        url = f"https://{host}{path}"
        page_size = 100
        page_number = 1
        total_pages = 1  # placeholder; recomputed from TotalCount after each response
        all_models: list[ModelConfig] = []

        # One session for every page rather than a new connection pool per request.
        async with aiohttp.ClientSession(trust_env=True) as session:
            while page_number <= total_pages:
                query = {
                    "Action": "ListFoundationModels",
                    "Version": "2024-01-01",
                    "PageNumber": page_number,
                    "PageSize": page_size
                }
                # The signature covers the query and a timestamp, so sign each page separately.
                headers = generate_volcengine_signature(
                    self.config.access_key_id,
                    self.config.access_key_secret,
                    "GET",
                    path,
                    query
                )
                async with session.get(url, headers=headers, params=query) as response:
                    response.raise_for_status()
                    response_data = await response.json()

                result = response_data.get("Result") if isinstance(response_data, dict) else None
                if not isinstance(result, dict):
                    # Unexpected payload: stop paging, but keep models from earlier pages.
                    break

                # Recompute the page count from the total reported by the API.
                if "TotalCount" in result and "PageSize" in result:
                    total_count = result["TotalCount"]
                    total_pages = (total_count + page_size - 1) // page_size

                for model in result.get("Items") or []:
                    foundation_model = model.get("FoundationModelTag") or {}
                    if "LLM" in (foundation_model.get("Domains") or []) and model.get("Name"):
                        all_models.append(ModelConfig(
                            id=model["Name"],
                            type=ModelType.LLM.value,
                            ability=detect_ability(foundation_model),
                        ))

                page_number += 1

        return all_models
================================================
FILE: kirara_ai/plugins/llm_preset_adapters/voyage_adapter.py
================================================
from pydantic import BaseModel, ConfigDict
from typing import cast, TypedDict, Literal, Optional
import requests
import asyncio
from kirara_ai.llm.adapter import LLMBackendAdapter, LLMEmbeddingProtocol, LLMReRankProtocol
from kirara_ai.llm.format.embedding import LLMEmbeddingRequest, LLMEmbeddingResponse
from kirara_ai.llm.format.rerank import LLMReRankRequest, LLMReRankResponse, ReRankerContent
from kirara_ai.llm.format.message import LLMChatTextContent, LLMChatImageContent
from kirara_ai.llm.format.response import Usage
from kirara_ai.media.manager import MediaManager
from kirara_ai.logger import get_logger
# Module-level logger used by the Voyage adapter to log failed responses and debug output.
logger = get_logger("VoyageAdapter")
async def resolve_media_base64(inputs: list[LLMChatImageContent|LLMChatTextContent], media_manager: MediaManager) -> list:
    """
    Convert embedding inputs into Voyage multimodal ``{"content": [...]}`` payloads.

    Voyage embeds all parts of one ``content`` list as a single input and returns one
    vector for it. Therefore each image is packed together with its text description
    (``""`` when it has none) as one payload, and each text becomes its own payload.
    Inputs of any other type are skipped.

    :raises ValueError: an image refers to a media ID the manager does not know.
    """
    payloads = []
    for item in inputs:
        if isinstance(item, LLMChatTextContent):
            parts = [{"type": "text", "text": item.text}]
        elif isinstance(item, LLMChatImageContent):
            media = media_manager.get_media(item.media_id)
            if media is None:
                raise ValueError(f"Media {item.media_id} not found")
            description = media.description
            parts = [
                {"type": "text", "text": "" if description is None else description},
                {"type": "image_base64", "image_base64": await media.get_base64()},
            ]
        else:
            continue
        payloads.append({"content": parts})
    return payloads
class ReRankData(TypedDict):
    """One scored document from the Voyage rerank response."""

    index: int  # position of the document in the request
    relevance_score: float
    document: Optional[str]  # present when return_documents was requested


class ReRankResponse(TypedDict):
    """Shape of the Voyage rerank response body (for mypy and as reference for developers)."""

    object: Literal["list"]
    data: list[ReRankData]
    model: str
    usage: dict[Literal["total_tokens"], int]
class EmbeddingData(TypedDict):
object: Literal["embedding"]
embedding: list[float | int]
index: int
class EmbeddingResponse(TypedDict):
object: Literal["list"]
data: list[EmbeddingData]
model: str
usage: dict[Literal["total_tokens"], int]
class ModalEmbeddingResponse(TypedDict):
object: Literal["list"]
data: list[EmbeddingData]
model: str
# voyage 的多模态接口会返回三个usage指标: text_tokens: 文字使用token数, image_pixels: 图片像素数, total_tokens: 总token数
usage: dict[Literal["text_tokens", "image_pixels", "total_tokens"], int]
class VoyageConfig(BaseModel):
    """Connection settings for the Voyage AI API; instances are immutable (frozen)."""

    model_config = ConfigDict(frozen=True)

    api_key: str  # sent as a Bearer token in the Authorization header
    api_base: str = "https://api.voyageai.com"
class VoyageAdapter(LLMBackendAdapter, LLMEmbeddingProtocol, LLMReRankProtocol):
    """
    Adapter for the Voyage AI embedding and rerank HTTP APIs.

    Text-only embedding batches go to ``/v1/embeddings``. Batches containing any image
    go to ``/v1/multimodalembeddings``, since the two endpoints accept different
    optional parameters. Reranking uses ``/v1/rerank``.
    """

    # Resolves image media IDs for multimodal embedding. It is not assigned in
    # __init__; presumably injected by the framework. NOTE(review): confirm.
    media_manager: MediaManager

    def __init__(self, config: VoyageConfig):
        self.config = config

    def _headers(self) -> dict[str, str]:
        """Return the bearer-auth JSON headers sent with every Voyage request."""
        return {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json"
        }

    def _post_json(self, api_url: str, data: dict) -> dict:
        """
        POST ``data`` as JSON, with ``None`` values dropped, and return the decoded body.

        On an HTTP error status or an undecodable body, the raw response text is
        logged and the original exception is re-raised.
        """
        data = {k: v for k, v in data.items() if v is not None}
        response = requests.post(api_url, headers=self._headers(), json=data)
        try:
            response.raise_for_status()
            return response.json()
        except Exception:
            logger.error(f"Response: {response.text}")
            raise

    @staticmethod
    def _run_coroutine_sync(coro):
        """
        Run ``coro`` to completion from synchronous code and return its result.

        ``asyncio.run()`` raises RuntimeError when the current thread already has a
        running event loop, which is the case whenever ``embed()`` is called from async
        code. In that situation the coroutine runs on a fresh loop in a worker thread,
        and this call blocks until it finishes.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: asyncio.run manages a temporary one.
            return asyncio.run(coro)
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()

    def embed(self, req: LLMEmbeddingRequest) -> LLMEmbeddingResponse:
        """Embed ``req.inputs``, choosing the multimodal endpoint if any input is an image."""
        if any(isinstance(item, LLMChatImageContent) for item in req.inputs):
            return self._multi_modal_embedding(req)
        return self._text_embedding(req)

    def _text_embedding(self, req: LLMEmbeddingRequest) -> LLMEmbeddingResponse:
        """Embed text-only inputs through ``/v1/embeddings``."""
        inputs = cast(list[LLMChatTextContent], req.inputs)
        data = {
            "model": req.model,
            "input": [item.text for item in inputs],
            "truncation": req.truncate,
            "input_type": req.input_type,
            "output_dimension": req.dimension,
            # NOTE(review): output_dtype and encoding_format both receive
            # req.encoding_format, although Voyage documents them as separate parameters.
            "output_dtype": req.encoding_format,
            "encoding_format": req.encoding_format,
        }
        response_data = cast(
            EmbeddingResponse,
            self._post_json(f"{self.config.api_base}/v1/embeddings", data),
        )
        return LLMEmbeddingResponse(
            vectors=[item["embedding"] for item in response_data["data"]],
            usage=Usage(
                total_tokens=response_data["usage"].get("total_tokens", 0)
            )
        )

    def _multi_modal_embedding(self, req: LLMEmbeddingRequest) -> LLMEmbeddingResponse:
        """Embed mixed text/image inputs through ``/v1/multimodalembeddings``."""
        data = {
            "model": req.model,
            # Images must be fetched and base64-encoded asynchronously before the request.
            "inputs": self._run_coroutine_sync(resolve_media_base64(req.inputs, self.media_manager)),
            "input_type": req.input_type,
            "truncation": req.truncate,
            "output_encoding": req.encoding_format
        }
        response_data = cast(
            ModalEmbeddingResponse,
            self._post_json(f"{self.config.api_base}/v1/multimodalembeddings", data),
        )
        return LLMEmbeddingResponse(
            vectors=[item["embedding"] for item in response_data["data"]],
            usage=Usage(
                total_tokens=response_data["usage"].get("total_tokens", 0)
            )
        )

    def rerank(self, req: LLMReRankRequest) -> LLMReRankResponse:
        """Score ``req.documents`` against ``req.query`` through ``/v1/rerank``."""
        data = {
            "query": req.query,
            "documents": req.documents,
            "model": req.model,
            "top_k": req.top_k,
            "return_documents": req.return_documents,
            "truncation": req.truncation
        }
        response_data = cast(ReRankResponse, self._post_json(f"{self.config.api_base}/v1/rerank", data))
        logger.debug(f"server response_data: {response_data}")
        return LLMReRankResponse(
            contents=[ReRankerContent(
                document=item.get("document", None),
                score=item["relevance_score"]
            ) for item in response_data["data"]],
            usage=Usage(
                total_tokens=response_data["usage"].get("total_tokens", 0)
            ),
            sort=cast(bool, req.sort)  # cast only to satisfy mypy
        )
================================================
FILE: kirara_ai/system/__init__.py
================================================
================================================
FILE: kirara_ai/system/updater.py
================================================
================================================
FILE: kirara_ai/tracing/__init__.py
================================================
from kirara_ai.tracing.core import TracerBase
from kirara_ai.tracing.decorator import trace_llm_chat
from kirara_ai.tracing.llm_tracer import LLMTracer
from kirara_ai.tracing.manager import TracingManager
from kirara_ai.tracing.models import LLMRequestTrace
# Public API of kirara_ai.tracing: names re-exported from its submodules (imported above).
__all__ = [
    "TracingManager",
    "LLMRequestTrace",
    "TracerBase",
    "LLMTracer",
    "trace_llm_chat"
]
================================================
FILE: kirara_ai/tracing/core.py
================================================
import abc
import asyncio
import uuid
from asyncio import Queue
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
from sqlalchemy import Column, DateTime, String, asc
from kirara_ai.database import Base, DatabaseManager
from kirara_ai.events.event_bus import EventBus
from kirara_ai.events.tracing import TraceEvent
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.ioc.inject import Inject
from kirara_ai.logger import get_logger
# Module logger; TracerBase.__init__ stores it on every tracer as self.logger.
# NOTE(review): the name "Tracking" may have been meant as "Tracing"; renaming it would change log output.
logger = get_logger("Tracking")
class TraceRecord(Base):
    """Abstract base class for trace records, used for ORM mapping.

    Declares the columns every trace table shares. ``__abstract__ = True`` tells
    SQLAlchemy's declarative layer not to map a table for this class itself.
    NOTE(review): Base's metaclass is presumably not ABCMeta, so the
    @abc.abstractmethod markers below are likely not enforced at instantiation — confirm.
    """
    __abstract__ = True
    # Unique trace identifier (generate_trace_id() produces UUID4 strings)
    trace_id = Column(String(64), nullable=False, index=True, unique=True)
    # Lifecycle state of the trace; new records default to "pending"
    status = Column(String(20), nullable=False, default="pending")
    # Time the traced request started
    request_time = Column(DateTime, nullable=False, index=True)
    @abc.abstractmethod
    def update_from_event(self, event: TraceEvent) -> None:
        """Update this record in place from a trace event."""
    @abc.abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        """Return a summary dict of this record for JSON serialization."""
    @abc.abstractmethod
    def to_detail_dict(self) -> Dict[str, Any]:
        """Return a detailed dict of this record for JSON serialization."""
def generate_trace_id() -> str:
    """Return a fresh random (version 4) UUID in its canonical 36-character text form."""
    fresh_uuid = uuid.uuid4()
    return f"{fresh_uuid}"
# Generic type variables for trace events, tracers and trace records.
# NOTE(review): T and E are not referenced anywhere else in this module.
T = TypeVar('T')
E = TypeVar('E', bound=TraceEvent) # event type
R = TypeVar('R', bound=TraceRecord) # record type
class TracerBase(Generic[R], abc.ABC):
    """Base class for tracers.

    A tracer persists records of type ``R`` through the database manager,
    registers event-bus handlers (implemented by subclasses) and broadcasts
    record changes to WebSocket clients through one asyncio queue per client.
    """

    # Tracer name, used to tell different kinds of tracers apart
    name: str
    # ORM class of the records this tracer stores
    record_class: Type[R]

    @Inject()
    def __init__(self, container: DependencyContainer, record_class: Type[R], db_manager: DatabaseManager, event_bus: EventBus):
        self.record_class = record_class
        self.container = container
        self.db_manager = db_manager
        self.event_bus = event_bus
        self.logger = logger
        # In-flight traces, keyed by trace_id
        self._active_traces: Dict[str, Dict[str, Any]] = {}
        # One message queue per registered WebSocket client
        self._ws_queues: List[Queue] = []

    def initialize(self):
        """Initialize the tracer by registering its event handlers."""
        self.logger.info(f"Initializing {self.name} tracer")
        self._register_event_handlers()
        self.logger.info(f"{self.name} tracer initialized")

    def shutdown(self):
        """Shut the tracer down.

        Unregisters the event handlers, pushes a ``None`` sentinel to every
        WebSocket client queue to signal that the stream is closing, then
        forgets all client queues.
        """
        self.logger.info(f"Shutting down {self.name} tracer")
        self._unregister_event_handlers()
        for queue in list(self._ws_queues):
            try:
                queue.put_nowait(None)
            except Exception as e:
                # Best effort: one broken client must not abort shutdown, but don't hide it completely.
                self.logger.debug(f"Failed to notify WebSocket client on shutdown: {e}")
        self._ws_queues.clear()

    @abc.abstractmethod
    def _register_event_handlers(self):
        """Register this tracer's handlers on the event bus."""

    @abc.abstractmethod
    def _unregister_event_handlers(self):
        """Unregister the handlers added by ``_register_event_handlers``."""

    def get_traces(
        self,
        filters: Optional[Dict[str, Any]] = None,
        page: int = 1,
        page_size: int = 20,
        order_by: str = "request_time",
        order_desc: bool = True
    ) -> Tuple[List[R], int]:
        """Query trace records with optional filtering, ordering and pagination.

        Args:
            filters: Mapping of attribute name to required value. Entries whose
                value is None, or whose name is not an attribute of
                ``record_class``, are ignored.
            page: 1-based page number.
            page_size: Number of records per page. Pagination is only applied
                when both ``page`` and ``page_size`` are positive.
            order_by: Attribute to sort by; ignored if ``record_class`` lacks it.
            order_desc: Sort in descending order when True.

        Returns:
            Tuple[List[R], int]: the records of the requested page and the total
            number of records matching the filters.
        """
        with self.db_manager.get_session() as session:
            from sqlalchemy import desc, func, select
            # Base query plus a matching count query
            query = select(self.record_class)
            count_query = select(func.count()).select_from(self.record_class)
            # Apply the same filters to both queries so the total matches the page
            if filters:
                for field, value in filters.items():
                    if value is not None and hasattr(self.record_class, field):
                        query = query.filter(getattr(self.record_class, field) == value)
                        count_query = count_query.filter(getattr(self.record_class, field) == value)
            # Ordering
            if hasattr(self.record_class, order_by):
                order_func = desc if order_desc else asc
                query = query.order_by(order_func(getattr(self.record_class, order_by)))
            # Pagination
            if page > 0 and page_size > 0:
                query = query.offset((page - 1) * page_size).limit(page_size)
            total = session.execute(count_query).scalar() or 0
            records = list(session.execute(query).scalars().all())
            return records, total

    def get_recent_traces(self, limit: int = 100) -> List[R]:
        """Return up to ``limit`` records, newest ``request_time`` first."""
        with self.db_manager.get_session() as session:
            from sqlalchemy import desc, select
            stmt = select(self.record_class).order_by(desc(self.record_class.request_time)).limit(limit)
            result = session.execute(stmt)
            return list(result.scalars().all())

    def get_trace_by_id(self, trace_id: str) -> Optional[R]:
        """Return the record with the given ``trace_id``, or None if there is none."""
        with self.db_manager.get_session() as session:
            return session.query(self.record_class).filter_by(trace_id=trace_id).first()

    def save_trace_record(self, record: R) -> Dict[str, Any]:
        """Insert ``record`` into the database and return its ``to_dict()`` form.

        The dict is built inside the session, after the commit, so
        database-generated values are included.
        """
        with self.db_manager.get_session() as session:
            session.add(record)
            session.commit()
            return record.to_dict()

    def update_trace_record(self, trace_id: str, event: TraceEvent) -> Optional[Dict[str, Any]]:
        """Apply ``event`` to the stored record with ``trace_id``.

        Returns:
            The updated record's ``to_dict()`` form, or None if no record exists.
        """
        with self.db_manager.get_session() as session:
            if (
                record := session.query(self.record_class)
                .filter_by(trace_id=trace_id)
                .first()
            ):
                record.update_from_event(event)
                session.commit()
                return record.to_dict()
            return None

    # --- WebSocket support -------------------------------------------------

    def register_ws_client(self) -> Queue:
        """Register a WebSocket client and return the queue its messages are pushed to."""
        queue: Queue = Queue()
        self._ws_queues.append(queue)
        return queue

    def unregister_ws_client(self, queue: Queue):
        """Forget a WebSocket client queue; unknown queues are ignored."""
        if queue in self._ws_queues:
            self._ws_queues.remove(queue)

    def broadcast_ws_message(self, message: Dict):
        """Push ``message`` to every registered WebSocket client queue.

        Queues that reject the message are removed from the client list.
        NOTE(review): asyncio.Queue is not thread-safe. If event handlers can run
        outside the event-loop thread, this should go through
        loop.call_soon_threadsafe — confirm how EventBus dispatches.
        """
        dead_queues = []
        for queue in self._ws_queues:
            try:
                queue.put_nowait(message)
            except asyncio.QueueFull:
                # Only reachable for bounded queues; register_ws_client creates unbounded ones
                self.logger.warning("Queue is full, message dropped")
                dead_queues.append(queue)
            except Exception as e:
                self.logger.error(f"Error broadcasting message: {e}")
                dead_queues.append(queue)
        # Drop the queues that failed
        for queue in dead_queues:
            if queue in self._ws_queues:
                self._ws_queues.remove(queue)
================================================
FILE: kirara_ai/tracing/decorator.py
================================================
import functools
from typing import Callable
from kirara_ai.llm.format.request import LLMChatRequest
from kirara_ai.llm.format.response import LLMChatResponse
from kirara_ai.tracing.llm_tracer import LLMTracer
def trace_llm_chat(func: Callable):
    """Decorator that traces an LLM chat call.

    The decorated method's owner must expose ``tracer`` (an ``LLMTracer``) and
    ``backend_name``. A trace is started before the call. On success it is
    completed with the response. If the call raises, the trace is marked as
    failed with the exception text and the exception is re-raised unchanged.
    """
    # Imported inside the function, presumably to avoid a circular import with kirara_ai.llm.adapter
    from kirara_ai.llm.adapter import LLMBackendAdapter

    @functools.wraps(func)
    def wrapper(self: LLMBackendAdapter, req: LLMChatRequest) -> LLMChatResponse:
        tracer: LLMTracer = self.tracer
        trace_id = tracer.start_request_tracking(self.backend_name, req)
        try:
            response = func(self, req)
        except Exception as e:
            tracer.fail_request_tracking(trace_id, req, str(e))
            # Bare raise keeps the original traceback instead of adding this frame
            raise
        tracer.complete_request_tracking(trace_id, req, response)
        return response

    return wrapper
================================================
FILE: kirara_ai/tracing/llm_tracer.py
================================================
from datetime import datetime, timedelta
from typing import Any, Dict
from sqlalchemy import case, func
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.events.tracing import LLMRequestCompleteEvent, LLMRequestFailEvent, LLMRequestStartEvent
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.ioc.inject import Inject
from kirara_ai.llm.format.message import LLMChatMessage, LLMChatTextContent
from kirara_ai.llm.format.request import LLMChatRequest
from kirara_ai.llm.format.response import LLMChatResponse, Message
from kirara_ai.logger import get_logger
from kirara_ai.tracing.core import TracerBase, generate_trace_id
from kirara_ai.tracing.models import LLMRequestTrace
# NOTE(review): this module-level logger is not referenced in this file. LLMTracer logs through
# self.logger, which TracerBase.__init__ sets to the tracing core's "Tracking" logger.
logger = get_logger("LLMTracer")
# Placeholder request messages stored in place of the real prompt when
# config.tracing.llm_tracing_content is disabled.
# NOTE(review): the same list object is assigned into every event; mutating it would affect later traces.
UNRECORD_REQUEST = [LLMChatMessage(
    role="system",
    content=[
        LLMChatTextContent(
            text="*** 内容未记录 ***"  # "content not recorded"
        )
    ]
)]
# Placeholder response message stored in place of the real reply when content tracing is disabled.
UNRECORD_RESPONSE = Message(
    role="assistant",
    content=[
        LLMChatTextContent(
            text="*** 内容未记录 ***"  # "content not recorded"
        )
    ]
)
class LLMTracer(TracerBase[LLMRequestTrace]):
    """Tracer for LLM chat requests.

    ``start_request_tracking`` / ``complete_request_tracking`` /
    ``fail_request_tracking`` publish ``LLMRequest*Event`` objects on the event
    bus. The registered handlers persist them as ``LLMRequestTrace`` rows and
    broadcast each change to WebSocket clients. When
    ``config.tracing.llm_tracing_content`` is disabled, message contents are
    replaced by placeholders before they are written.
    """

    name = "llm"
    record_class = LLMRequestTrace

    @Inject()
    def __init__(self, container: DependencyContainer):
        super().__init__(container, record_class=LLMRequestTrace)  # type: ignore
        self.config = container.resolve(GlobalConfig)

    def initialize(self):
        """Start the tracer.

        Marks traces left in ``pending`` state as failed and deletes traces
        older than 30 days. Errors during this cleanup are logged, not raised.
        """
        super().initialize()
        try:
            pending_traces = self._mark_pending_as_failed()
            deleted_count = self._clean_old_traces()
            if pending_traces or deleted_count:
                self.logger.info(f"已将 {pending_traces} 个 未结束状态的 LLM 请求标记为失败,并清理了 {deleted_count} 个超过 30 天的请求记录")
        except Exception as e:
            self.logger.opt(exception=e).error("处理历史追踪记录时发生错误")

    def _mark_pending_as_failed(self) -> int:
        """Mark every ``pending`` trace as ``failed``; return how many were changed."""
        with self.db_manager.get_session() as session:
            pending_traces = session.query(LLMRequestTrace).filter(
                LLMRequestTrace.status == "pending"  # type: ignore
            ).all()
            for trace in pending_traces:
                trace.status = "failed"  # type: ignore
                trace.error = "Incomplete request"  # type: ignore
            session.commit()
            return len(pending_traces)

    def _clean_old_traces(self, days: int = 30) -> int:
        """Delete traces whose request_time is more than ``days`` days old; return the number deleted."""
        with self.db_manager.get_session() as session:
            days_ago = datetime.now() - timedelta(days=days)
            deleted_count = session.query(LLMRequestTrace).filter(
                LLMRequestTrace.request_time < days_ago  # type: ignore
            ).delete()
            session.commit()
            return deleted_count

    def _register_event_handlers(self):
        """Register the start/complete/fail handlers on the event bus."""
        self.event_bus.register(LLMRequestStartEvent, self._on_request_start)
        self.event_bus.register(LLMRequestCompleteEvent, self._on_request_complete)
        self.event_bus.register(LLMRequestFailEvent, self._on_request_fail)

    def _unregister_event_handlers(self):
        """Unregister the start/complete/fail handlers from the event bus."""
        self.event_bus.unregister(LLMRequestStartEvent, self._on_request_start)
        self.event_bus.unregister(LLMRequestCompleteEvent, self._on_request_complete)
        self.event_bus.unregister(LLMRequestFailEvent, self._on_request_fail)

    def start_request_tracking(
        self,
        backend_name: str,
        request: LLMChatRequest
    ) -> str:
        """Start tracking an LLM request and return its new trace id."""
        trace_id = generate_trace_id()
        event = LLMRequestStartEvent(
            trace_id=trace_id,
            model_id=request.model or 'unknown',
            backend_name=backend_name,
            request=request.model_copy(deep=True)
        )
        # Keep what the completion/failure events need to reuse
        self._active_traces[trace_id] = {
            'model_id': event.model_id,
            'backend_name': backend_name,
            'start_time': event.start_time
        }
        self.event_bus.post(event)
        return trace_id

    def _pop_active_trace(self, trace_id: str, request: LLMChatRequest):
        """Remove ``trace_id`` from the active traces.

        Returns:
            ``(model_id, backend_name, start_time)``, or None if the id is not active.
        """
        trace_data = self._active_traces.pop(trace_id, None)
        if trace_data is None:
            return None
        model_id = request.model or trace_data.get('model_id', "unknown")
        backend_name = trace_data.get('backend_name', "unknown")
        start_time = trace_data.get('start_time', 0)
        return model_id, backend_name, start_time

    def complete_request_tracking(
        self,
        trace_id: str,
        request: LLMChatRequest,
        response: LLMChatResponse
    ):
        """Finish tracking a successful LLM request; unknown trace ids are logged and ignored."""
        active = self._pop_active_trace(trace_id, request)
        if active is None:
            self.logger.warning(f"LLM request completed: {trace_id} not found")
            return
        model_id, backend_name, start_time = active
        event = LLMRequestCompleteEvent(
            trace_id=trace_id,
            model_id=model_id,
            backend_name=backend_name,
            request=request.model_copy(deep=True),
            response=response.model_copy(deep=True),
            start_time=start_time
        )
        self.event_bus.post(event)

    def fail_request_tracking(
        self,
        trace_id: str,
        request: LLMChatRequest,
        error: Any
    ):
        """Record that an LLM request failed; unknown trace ids are logged and ignored."""
        active = self._pop_active_trace(trace_id, request)
        if active is None:
            self.logger.warning(f"LLM request failed: {trace_id} not found")
            return
        model_id, backend_name, start_time = active
        event = LLMRequestFailEvent(
            trace_id=trace_id,
            model_id=model_id,
            backend_name=backend_name,
            request=request.model_copy(deep=True),
            error=error,
            start_time=start_time
        )
        self.event_bus.post(event)

    def _on_request_start(self, event: LLMRequestStartEvent):
        """Persist a new trace row for a started request and broadcast it as ``new``."""
        self.logger.debug(f"LLM request started: {event.trace_id}")
        if not self.config.tracing.llm_tracing_content:
            event.request.messages = UNRECORD_REQUEST
        trace = LLMRequestTrace()
        trace.update_from_event(event)
        trace_dict = self.save_trace_record(trace)
        self.broadcast_ws_message({
            "type": "new",
            "data": trace_dict
        })

    def _on_request_complete(self, event: LLMRequestCompleteEvent):
        """Update the trace row of a completed request and broadcast it as ``update``."""
        self.logger.debug(f"LLM request completed: {event.trace_id}")
        if not self.config.tracing.llm_tracing_content:
            event.request.messages = UNRECORD_REQUEST
            event.response.message = UNRECORD_RESPONSE
        if trace := self.update_trace_record(event.trace_id, event):
            self.broadcast_ws_message({
                "type": "update",
                "data": trace
            })

    def _on_request_fail(self, event: LLMRequestFailEvent):
        """Update the trace row of a failed request and broadcast it as ``update``."""
        self.logger.debug(f"LLM request failed: {event.trace_id}")
        if not self.config.tracing.llm_tracing_content:
            event.request.messages = UNRECORD_REQUEST
        if trace := self.update_trace_record(event.trace_id, event):
            self.broadcast_ws_message({
                "type": "update",
                "data": trace
            })

    def get_statistics(self) -> Dict:
        """Aggregate request statistics.

        Returns a dict with:
          - ``overview``: all-time request counts per status and total tokens;
          - ``daily_stats``: per-day requests/tokens/success/failed, last 30 days;
          - ``hourly_stats``: per-hour requests/tokens, last 24 hours;
          - ``models`` / ``backends``: per-model / per-backend count, tokens and
            average duration, last 30 days.

        NOTE(review): ``func.strftime`` is SQLite-specific; other databases need a different bucket expression.
        """
        with self.db_manager.get_session() as session:
            # All-time counters
            total_count = session.query(func.count(LLMRequestTrace.id)).scalar() or 0
            success_count = session.query(func.count(LLMRequestTrace.id)).filter_by(status="success").scalar() or 0
            failed_count = session.query(func.count(LLMRequestTrace.id)).filter_by(status="failed").scalar() or 0
            pending_count = session.query(func.count(LLMRequestTrace.id)).filter_by(status="pending").scalar() or 0
            total_tokens = session.query(func.sum(LLMRequestTrace.total_tokens)).scalar() or 0

            now = datetime.now()
            thirty_days_ago = now - timedelta(days=30)
            one_day_ago = now - timedelta(hours=24)

            # Per-day buckets over the last 30 days
            day_bucket = func.strftime('%Y-%m-%d', LLMRequestTrace.request_time)
            daily_stats = session.query(
                day_bucket.label('date'),
                func.count(LLMRequestTrace.id).label('requests'),
                func.sum(LLMRequestTrace.total_tokens).label('tokens'),
                func.sum(case((LLMRequestTrace.status == 'success', 1), else_=0)).label('success'),  # type: ignore
                func.sum(case((LLMRequestTrace.status == 'failed', 1), else_=0)).label('failed')  # type: ignore
            ).filter(
                LLMRequestTrace.request_time >= thirty_days_ago  # type: ignore
            ).group_by(day_bucket).order_by(day_bucket).all()
            daily_data = [{
                'date': str(row.date),
                'requests': row.requests,
                'tokens': row.tokens or 0,
                'success': row.success,
                'failed': row.failed
            } for row in daily_stats]

            # Per-model totals over the last 30 days
            model_counts = session.query(
                LLMRequestTrace.model_id,  # type: ignore
                func.count(LLMRequestTrace.id).label('count'),
                func.sum(LLMRequestTrace.total_tokens).label('tokens'),
                func.avg(LLMRequestTrace.duration).label('avg_duration')
            ).filter(  # type: ignore
                LLMRequestTrace.request_time >= thirty_days_ago  # type: ignore
            ).group_by(
                LLMRequestTrace.model_id
            ).all()
            model_stats = [{
                'model_id': model_id,
                'count': count,
                'tokens': tokens or 0,
                'avg_duration': float(avg_duration) if avg_duration else 0
            } for model_id, count, tokens, avg_duration in model_counts]

            # Per-backend totals over the last 30 days
            backend_counts = session.query(
                LLMRequestTrace.backend_name,  # type: ignore
                func.count(LLMRequestTrace.id).label('count'),
                func.sum(LLMRequestTrace.total_tokens).label('tokens'),
                func.avg(LLMRequestTrace.duration).label('avg_duration')
            ).filter(  # type: ignore
                LLMRequestTrace.request_time >= thirty_days_ago  # type: ignore
            ).group_by(
                LLMRequestTrace.backend_name
            ).all()
            backend_stats = [{
                'backend_name': backend_name,
                'count': count,
                'tokens': tokens or 0,
                'avg_duration': float(avg_duration) if avg_duration else 0
            } for backend_name, count, tokens, avg_duration in backend_counts]

            # Per-hour buckets over the last 24 hours
            hour_bucket = func.strftime('%Y-%m-%d %H:00:00', LLMRequestTrace.request_time)
            hourly_stats = session.query(
                hour_bucket.label('hour'),
                func.count(LLMRequestTrace.id).label('requests'),
                func.sum(LLMRequestTrace.total_tokens).label('tokens')
            ).filter(
                LLMRequestTrace.request_time >= one_day_ago  # type: ignore
            ).group_by(hour_bucket).order_by(hour_bucket).all()
            hourly_data = [{
                'hour': str(row.hour),
                'requests': row.requests,
                'tokens': row.tokens or 0
            } for row in hourly_stats]

            return {
                'overview': {
                    'total_requests': total_count,
                    'success_requests': success_count,
                    'failed_requests': failed_count,
                    'pending_requests': pending_count,
                    'total_tokens': total_tokens,
                },
                'daily_stats': daily_data,
                'hourly_stats': hourly_data,
                'models': model_stats,
                'backends': backend_stats
            }
================================================
FILE: kirara_ai/tracing/manager.py
================================================
import asyncio
from typing import Dict, List, Optional
from kirara_ai.database import DatabaseManager
from kirara_ai.events.event_bus import EventBus
from kirara_ai.events.tracing import TraceEvent
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.ioc.inject import Inject
from kirara_ai.logger import get_logger
from kirara_ai.tracing.core import TracerBase, TraceRecord
# Module logger; TracingManager.__init__ stores it as self.logger
logger = get_logger("TracingManager")
class TracingManager:
    """Registry and facade for all tracers.

    Holds the name -> tracer mapping, drives initialization and shutdown of
    every tracer, and forwards WebSocket and query calls to the tracer
    selected by name.
    """

    @Inject()
    def __init__(self, container: DependencyContainer, database_manager: DatabaseManager, event_bus: EventBus):
        self.container = container
        self.db_manager = database_manager
        self.event_bus = event_bus
        self.tracers: Dict[str, TracerBase] = {}
        self.logger = logger

    def initialize(self):
        """Initialize every registered tracer; a failing tracer is logged (with traceback) and skipped."""
        self.logger.info("Initializing tracing manager")
        for name, tracer in self.tracers.items():
            try:
                tracer.initialize()
            except Exception as e:
                self.logger.opt(exception=e).error(f"Failed to initialize tracer {name}: {e}")
        self.logger.info("Tracing manager initialized")

    def shutdown(self):
        """Shut down every registered tracer; failures are logged (with traceback) and skipped."""
        self.logger.info("Shutting down tracing manager")
        for name, tracer in self.tracers.items():
            try:
                tracer.shutdown()
            except Exception as e:
                self.logger.opt(exception=e).error(f"Failed to shutdown tracer {name}: {e}")

    def register_tracer(self, name: str, tracer: TracerBase):
        """Register ``tracer`` under ``name``.

        Raises:
            ValueError: if a tracer with that name is already registered.
        """
        if name in self.tracers:
            raise ValueError(f"Tracer {name} already registered")
        self.tracers[name] = tracer

    def get_tracer(self, name: str) -> Optional[TracerBase]:
        """Return the tracer registered under ``name``, or None."""
        return self.tracers.get(name)

    def _require_tracer(self, tracer_name: str) -> TracerBase:
        """Return the tracer registered under ``tracer_name``, raising ValueError if it is missing."""
        tracer = self.tracers.get(tracer_name)
        if tracer is None:
            raise ValueError(f"Tracer {tracer_name} not found")
        return tracer

    def get_all_tracers(self) -> Dict[str, TracerBase]:
        """Return a shallow copy of the name -> tracer mapping."""
        return self.tracers.copy()

    def get_tracer_types(self) -> List[str]:
        """Return the names of all registered tracers."""
        return list(self.tracers.keys())

    def publish_event(self, event: TraceEvent):
        """Post a trace event on the event bus."""
        self.event_bus.post(event)

    # --- WebSocket support -------------------------------------------------

    def register_ws_client(self, tracer_name: str) -> asyncio.Queue:
        """Register a WebSocket client on the named tracer and return its message queue.

        Raises:
            ValueError: if no tracer has that name.
        """
        return self._require_tracer(tracer_name).register_ws_client()

    def unregister_ws_client(self, tracer_name: str, queue: asyncio.Queue):
        """Unregister a WebSocket client queue from the named tracer.

        Raises:
            ValueError: if no tracer has that name.
        """
        self._require_tracer(tracer_name).unregister_ws_client(queue)

    # --- Generic trace queries ---------------------------------------------

    def get_recent_traces(self, tracer_name: str, limit: int = 100) -> List[TraceRecord]:
        """Return up to ``limit`` most recent records of the named tracer.

        Raises:
            ValueError: if no tracer has that name.
        """
        return self._require_tracer(tracer_name).get_recent_traces(limit)

    def get_trace_by_id(self, tracer_name: str, trace_id: str) -> Optional[TraceRecord]:
        """Return the named tracer's record with ``trace_id``, or None.

        Raises:
            ValueError: if no tracer has that name.
        """
        return self._require_tracer(tracer_name).get_trace_by_id(trace_id)
================================================
FILE: kirara_ai/tracing/models.py
================================================
import json
from datetime import datetime
from typing import Any, Dict, Optional
from sqlalchemy import Column, DateTime, Float, Index, Integer, String, Text
from kirara_ai.events.tracing import LLMRequestCompleteEvent, LLMRequestFailEvent, LLMRequestStartEvent
from kirara_ai.tracing.core import TraceEvent, TraceRecord
class LLMRequestTrace(TraceRecord):
    """Trace record of a single LLM request (table ``llm_request_traces``)."""
    __tablename__ = "llm_request_traces"
    id = Column(Integer, primary_key=True, autoincrement=True)
    trace_id = Column(String(64), nullable=False, index=True, unique=True)
    model_id = Column(String(64), nullable=False, index=True)
    backend_name = Column(String(64), nullable=False, index=True)
    # Timing
    request_time = Column(DateTime, nullable=False, index=True)
    response_time = Column(DateTime, nullable=True)
    duration = Column(Float, nullable=True)
    # Request/response bodies serialized as JSON text (see the request/response properties)
    request_json = Column(Text, nullable=True)
    response_json = Column(Text, nullable=True)
    # Token usage
    prompt_tokens = Column(Integer, nullable=True)
    completion_tokens = Column(Integer, nullable=True)
    total_tokens = Column(Integer, nullable=True)
    cached_tokens = Column(Integer, nullable=True)
    # Failure details and lifecycle state ("pending" / "success" / "failed")
    error = Column(Text, nullable=True)
    status = Column(String(20), nullable=False, default="pending")
    # Composite indexes for the common "filter by X, order by time" queries
    __table_args__ = (
        Index('idx_request_model', 'model_id', 'request_time'),
        Index('idx_backend_time', 'backend_name', 'request_time'),
        Index('idx_status_time', 'status', 'request_time'),
    )

    def __repr__(self):
        return (
            f"<LLMRequestTrace(id={self.id}, trace_id={self.trace_id}, "
            f"model_id={self.model_id}, backend_name={self.backend_name}, status={self.status})>"
        )

    def update_from_event(self, event: TraceEvent) -> None:
        """Update this record from a start, complete or fail event; other event types are ignored."""
        if isinstance(event, LLMRequestStartEvent):
            self.trace_id = event.trace_id
            self.model_id = event.model_id
            self.backend_name = event.backend_name
            self.request_time = datetime.fromtimestamp(event.start_time)
            self.status = "pending"
            if event.request:
                self.request = event.request.model_dump()
        elif isinstance(event, LLMRequestCompleteEvent):
            self.response_time = datetime.fromtimestamp(event.end_time)
            self.duration = event.duration
            self.status = "success"
            # Token usage, when the backend reported it
            if event.response and event.response.usage:
                self.prompt_tokens = event.response.usage.prompt_tokens
                self.completion_tokens = event.response.usage.completion_tokens
                self.total_tokens = event.response.usage.total_tokens
                self.cached_tokens = event.response.usage.cached_tokens
            if event.response:
                self.response = event.response.model_dump()
        elif isinstance(event, LLMRequestFailEvent):
            self.response_time = datetime.fromtimestamp(event.end_time)
            self.duration = event.duration
            # The error column is Text; coerce non-string errors (e.g. exception objects)
            self.error = str(event.error) if event.error is not None else None
            self.status = "failed"

    def to_dict(self) -> Dict[str, Any]:
        """Return the record's summary fields (without request/response bodies) as a JSON-friendly dict."""
        return {
            "id": self.id,
            "trace_id": self.trace_id,
            "model_id": self.model_id,
            "backend_name": self.backend_name,
            "request_time": self.request_time.isoformat() if self.request_time else None,  # type: ignore
            "response_time": self.response_time.isoformat() if self.response_time else None,  # type: ignore
            "duration": self.duration,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.total_tokens,
            "cached_tokens": self.cached_tokens,
            "status": self.status,
            "error": self.error
        }

    def to_detail_dict(self) -> Dict[str, Any]:
        """Return ``to_dict()`` plus the decoded ``request`` and ``response`` bodies."""
        result = self.to_dict()
        result["request"] = self.request
        result["response"] = self.response
        return result

    @property
    def request(self) -> Optional[Dict[str, Any]]:
        """Decoded request body, or None if none was stored."""
        return json.loads(self.request_json) if self.request_json else None  # type: ignore

    @request.setter
    def request(self, value: Any):
        """Store the request body as JSON; falsy values are ignored."""
        if value:
            self.request_json = json.dumps(value, ensure_ascii=False, default=str)

    @property
    def response(self) -> Optional[Dict[str, Any]]:
        """Decoded response body, or None if none was stored."""
        return json.loads(self.response_json) if self.response_json else None  # type: ignore

    @response.setter
    def response(self, value: Any):
        """Store the response body as JSON; falsy values are ignored."""
        if value:
            self.response_json = json.dumps(value, ensure_ascii=False, default=str)
================================================
FILE: kirara_ai/web/README.md
================================================
# Web API 系统 🌐
本系统提供了一套完整的RESTful API,用于管理和监控ChatGPT-Mirai机器人的各个组件。
## 系统架构 🏗️
- 基于 [Quart](https://pgjones.gitlab.io/quart/) 异步Web框架
- 使用 [Pydantic](https://docs.pydantic.dev/) 进行数据验证
- JWT认证保护所有API端点
- CORS支持跨域请求
- 模块化设计,易于扩展
## 模块说明 📦
### 1. 认证模块 🔐
- 路径: [`framework/web/auth`](../framework/web/auth)
- 功能: 用户认证、JWT令牌管理
- API文档: [认证API文档](../framework/web/auth/README.md)
### 2. IM适配器管理 💬
- 路径: [`framework/web/api/im`](../framework/web/api/im)
- 功能: 管理即时通讯平台适配器
- API文档: [IM API文档](../framework/web/api/im/README.md)
### 3. LLM后端管理 🤖
- 路径: [`framework/web/api/llm`](../framework/web/api/llm)
- 功能: 管理大语言模型后端
- API文档: [LLM API文档](../framework/web/api/llm/README.md)
### 4. 调度规则管理 📋
- 路径: [`framework/web/api/dispatch`](../framework/web/api/dispatch)
- 功能: 管理消息处理规则
- API文档: [调度规则API文档](../framework/web/api/dispatch/README.md)
### 5. Block查询 🧩
- 路径: [`framework/web/api/block`](../framework/web/api/block)
- 功能: 查询工作流构建块信息
- API文档: [Block API文档](../framework/web/api/block/README.md)
### 6. Workflow管理 ⚡
- 路径: [`framework/web/api/workflow`](../framework/web/api/workflow)
- 功能: 管理工作流定义和执行
- API文档: [Workflow API文档](../framework/web/api/workflow/README.md)
### 7. 插件管理 🔌
- 路径: [`framework/web/api/plugin`](../framework/web/api/plugin)
- 功能: 管理系统插件
- API文档: [插件API文档](../framework/web/api/plugin/README.md)
### 8. 系统状态 📊
- 路径: [`framework/web/api/system`](../framework/web/api/system)
- 功能: 监控系统运行状态
- API文档: [系统状态API文档](../framework/web/api/system/README.md)
## 快速开始 🚀
1. 安装依赖:
```bash
pip install -r requirements.txt
```
2. 配置系统:
- 复制 `config.yaml.example` 到 `config.yaml`
- 修改配置文件中的相关设置
3. 启动服务:
```bash
python main.py
```
首次启动时会自动创建管理员密码。
## API认证 🔑
除了首次设置密码的接口外,所有API都需要在请求头中携带JWT令牌:
```http
Authorization: Bearer <your-token>
```
获取令牌:
```http
POST /backend-api/api/auth/login
Content-Type: application/json
{
"password": "your-password"
}
```
## 开发指南 💻
### 添加新的API端点
1. 在相应模块下创建路由文件
2. 定义数据模型(使用Pydantic)
3. 实现API逻辑
4. 在 [`framework/web/app.py`](../framework/web/app.py) 中注册蓝图
示例:
```python
from quart import Blueprint, request
from pydantic import BaseModel
# 定义数据模型
class MyModel(BaseModel):
name: str
value: int
# 创建蓝图
my_bp = Blueprint('my_api', __name__)
# 实现API端点
@my_bp.route('/endpoint', methods=['POST'])
@require_auth
async def my_endpoint():
data = await request.get_json()
model = MyModel(**data)
# 处理逻辑
return model.model_dump()
```
### 错误处理
使用HTTP状态码表示错误类型:
- 400: 请求参数错误
- 401: 未认证或认证失败
- 404: 资源不存在
- 500: 服务器内部错误
返回统一的错误格式:
```json
{
"error": "错误描述信息"
}
```
## 依赖说明 📚
主要依赖包:
- quart: 异步Web框架
- pydantic: 数据验证
- PyJWT: JWT认证
- hypercorn: ASGI服务器
- psutil: 系统监控
完整依赖列表见 [requirements.txt](../requirements.txt)
## 测试 🧪
运行单元测试:
```bash
pytest tests/web
```
## 贡献指南 🤝
1. Fork 本仓库
2. 创建特性分支
3. 提交更改
4. 创建Pull Request
================================================
FILE: kirara_ai/web/__init__.py
================================================
================================================
FILE: kirara_ai/web/api/block/README.md
================================================
# 区块 API 🧩
区块 API 提供了查询工作流构建块类型的功能。每个区块类型定义了其输入、输出和配置项。
> 注意:文档由 Claude 生成,可能存在错误,请以实际代码为准。
## API 端点
### 获取所有区块类型
```http
GET /backend-api/api/block/types
```
获取所有可用的区块类型列表。
**响应示例:**
```json
{
"types": [
{
"type_name": "MessageBlock",
"name": "消息区块",
"description": "处理消息的基础区块",
"inputs": [
{
"name": "content",
"description": "消息内容",
"type": "string",
"required": true
}
],
"outputs": [
{
"name": "message",
"description": "处理后的消息",
"type": "IMMessage"
}
],
"configs": [
{
"name": "format",
"description": "消息格式",
"type": "string",
"required": false,
"default": "text"
}
]
}
]
}
```
### 获取特定区块类型
```http
GET /backend-api/api/block/types/{type_name}
```
获取指定区块类型的详细信息。
**响应示例:**
```json
{
"type": {
"type_name": "LLMBlock",
"name": "大语言模型区块",
"description": "调用 LLM 进行对话的区块",
"inputs": [
{
"name": "prompt",
"description": "提示词",
"type": "string",
"required": true
}
],
"outputs": [
{
"name": "response",
"description": "LLM 的响应",
"type": "string"
}
],
"configs": [
{
"name": "model",
"description": "使用的模型",
"type": "string",
"required": true,
"default": "gpt-4"
},
{
"name": "temperature",
"description": "温度参数",
"type": "float",
"required": false,
"default": 0.7
}
]
}
}
```
### 注册区块类型
```http
POST /backend-api/api/block/types
```
注册新的区块类型。
**请求体:**
```json
{
"type": "image_process",
"name": "图像处理",
"description": "处理图像数据",
"category": "media",
"config_schema": {
"type": "object",
"properties": {
"operation": {
"type": "string",
"enum": ["resize", "crop", "rotate"]
},
"params": {
"type": "object"
}
}
},
"input_schema": {
"type": "object",
"properties": {
"image": {
"type": "string",
"format": "binary"
}
}
},
"output_schema": {
"type": "object",
"properties": {
"image": {
"type": "string",
"format": "binary"
}
}
}
}
```
### 更新区块类型
```http
PUT /backend-api/api/block/types/{type}
```
更新现有区块类型。
### 删除区块类型
```http
DELETE /backend-api/api/block/types/{type}
```
删除指定区块类型。
### 获取区块实例
```http
GET /backend-api/api/block/instances/{workflow_id}
```
获取指定工作流中的所有区块实例。
**响应示例:**
```json
{
"instances": [
{
"block_id": "input_1",
"type": "input",
"workflow_id": "chat:normal",
"config": {
"format": "text"
},
"state": {
"status": "ready",
"last_run": "2024-03-10T12:00:00Z",
"error": null
}
}
]
}
```
### 获取特定区块实例
```http
GET /backend-api/api/block/instances/{workflow_id}/{block_id}
```
获取指定区块实例的详细信息。
### 更新区块实例
```http
PUT /backend-api/api/block/instances/{workflow_id}/{block_id}
```
更新区块实例的配置。
## 数据模型
### BlockInput
- `name`: 输入名称
- `description`: 输入描述
- `type`: 数据类型
- `required`: 是否必需
- `default`: 默认值(可选)
### BlockOutput
- `name`: 输出名称
- `description`: 输出描述
- `type`: 数据类型
### BlockConfig
- `name`: 配置项名称
- `description`: 配置项描述
- `type`: 数据类型
- `required`: 是否必需
- `default`: 默认值(可选)
### BlockType
- `type_name`: 区块类型名称
- `name`: 显示名称
- `description`: 描述
- `inputs`: 输入定义列表
- `outputs`: 输出定义列表
- `configs`: 配置项定义列表
### BlockInstance
- `block_id`: 区块实例 ID
- `type`: 区块类型
- `workflow_id`: 所属工作流 ID
- `config`: 区块配置
- `state`: 区块状态
- `metadata`: 元数据(可选)
### BlockState
- `status`: 状态(ready/running/error)
- `last_run`: 最后运行时间
- `error`: 错误信息(如果有)
- `metrics`: 性能指标(可选)
## 相关代码
- [区块基类](../../../workflow/core/block/base.py)
- [区块注册表](../../../workflow/core/block/registry.py)
- [系统区块实现](../../../workflow/implementations/blocks)
## 错误处理
所有 API 端点在发生错误时都会返回适当的 HTTP 状态码和错误信息:
```json
{
"error": "错误描述信息"
}
```
常见状态码:
- 400: 请求参数错误
- 404: 区块类型不存在
- 500: 服务器内部错误
## 使用示例
### 获取所有区块类型
```python
import requests
response = requests.get(
'http://localhost:8080/backend-api/api/block/types',
headers={'Authorization': f'Bearer {token}'}
)
```
### 获取特定区块类型
```python
import requests
response = requests.get(
'http://localhost:8080/backend-api/api/block/types/LLMBlock',
headers={'Authorization': f'Bearer {token}'}
)
```
## 相关文档
- [工作流系统概述](../../README.md#工作流系统-)
- [区块开发指南](../../../workflow/README.md#区块开发)
- [API 认证](../../README.md#api认证-)
================================================
FILE: kirara_ai/web/api/block/__init__.py
================================================
from .routes import block_bp
# Public names of this package (block_bp is imported from .routes)
__all__ = ["block_bp"]
================================================
FILE: kirara_ai/web/api/block/diagnostics/base_diagnostic.py
================================================
import ast
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from lsprotocol.types import CodeAction, CodeActionParams, Diagnostic, DiagnosticSeverity, Position, Range
from pygls.server import LanguageServer
from pygls.workspace import Document
# Standard-library module logger used by BaseDiagnostic's helpers
logger = logging.getLogger(__name__)
class BaseDiagnostic(ABC):
"""诊断检查器的抽象基类"""
SOURCE_NAME: str = "base-diagnostic" # 每个子类应覆盖此项
def __init__(self, ls: LanguageServer):
    """Keep a reference to the language server as ``self.ls``."""
    self.ls = ls
@abstractmethod
def check(self, doc: Document) -> List[Diagnostic]:
    """
    Run this checker's diagnostics on a document.

    Args:
        doc: The document to check.

    Returns:
        The list of diagnostics found.
    """
def get_code_actions(self, params: CodeActionParams, relevant_diagnostics: List[Diagnostic]) -> List[CodeAction]:
    """
    Build code actions (quick fixes) for this checker's diagnostics.

    The base implementation offers no actions; checkers that can fix their
    findings override it.

    Args:
        params: The code action request parameters.
        relevant_diagnostics: The diagnostics that belong to this checker.

    Returns:
        The code actions to offer (always empty here).
    """
    no_actions: List[CodeAction] = []
    return no_actions
def _create_diagnostic(self, message: str, node: Optional[ast.AST], severity: DiagnosticSeverity, data: Optional[Dict] = None, range_override: Optional[Range] = None) -> Diagnostic:
    """
    Build a Diagnostic from an AST node or an explicit range.

    The range is resolved in this order:
      1. ``range_override`` when given;
      2. the position attributes of ``node`` when it carries them (statements,
         expressions, arguments, ... — any node with a ``lineno``);
      3. otherwise the first character of the document.

    Args:
        message: Text shown to the user.
        node: AST node the diagnostic refers to, or None.
        severity: Diagnostic severity.
        data: Optional payload attached to the diagnostic.
        range_override: Explicit range that takes precedence over ``node``.

    Returns:
        A Diagnostic whose ``source`` is this checker's ``SOURCE_NAME``.
    """
    # getattr on None (or on a node without a position) yields None
    node_lineno = getattr(node, "lineno", None)
    if range_override:
        diag_range = range_override
    elif node_lineno is not None:
        # AST lines are 1-based and LSP lines 0-based. end_lineno is the last
        # line of the node (inclusive) and end_col_offset points just past it.
        # NOTE(review): AST column offsets are UTF-8 byte offsets, while LSP
        # characters default to UTF-16 code units; non-ASCII lines may be off.
        start_line = node_lineno - 1
        start_col = getattr(node, "col_offset", None) or 0
        end_lineno = getattr(node, "end_lineno", None)
        end_col_offset = getattr(node, "end_col_offset", None)
        end_line = end_lineno - 1 if end_lineno is not None else start_line
        end_col = end_col_offset if end_col_offset is not None else start_col + 1
        # Guarantee a forward range that marks at least one character
        end_line = max(start_line, end_line)
        if end_line == start_line:
            end_col = max(start_col + 1, end_col)
        diag_range = Range(
            start=Position(line=start_line, character=start_col),
            end=Position(line=end_line, character=end_col)
        )
    else:
        if node is not None:
            logger.warning(f"AST 节点缺少位置信息: {type(node)}")
        # No usable location: mark the first character of the document
        diag_range = Range(start=Position(
            line=0, character=0), end=Position(line=0, character=1))
    return Diagnostic(
        range=diag_range,
        message=message,
        severity=severity,
        source=self.SOURCE_NAME,
        data=data
    )
def _ast_node_to_string(self, node: Optional[ast.AST]) -> str:
"""
将 AST 注解节点转换为字符串表示。
(从原 LanguageServer 类移动)
"""
if node is None:
return ""
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Constant): # Python 3.8+
# Handle NoneType explicitly if needed
if node.value is None:
return "None"
return str(node.value)
# Python < 3.8 NameConstant handling might be needed if supporting older versions
# if isinstance(node, ast.NameConstant):
# return str(node.value)
if isinstance(node, ast.Attribute):
value_str = self._ast_node_to_string(node.value)
# Avoid adding '.' if value is empty (shouldn't happen often)
return f"{value_str}.{node.attr}" if value_str else node.attr
if isinstance(node, ast.Subscript):
base = self._ast_node_to_string(node.value)
slice_val = node.slice
if isinstance(slice_val, ast.Tuple):
slice_str = ', '.join(
[self._ast_node_to_string(entry) for entry in slice_val.elts])
else:
slice_str = self._ast_node_to_string(slice_val)
return f"{base}[{slice_str}]"
if isinstance(node, ast.List):
return f"List[{self._ast_node_to_string(node.elts[0])}]" if node.elts else "List"
if isinstance(node, ast.Dict):
return f"Dict[{self._ast_node_to_string(node.keys[0])}, {self._ast_node_to_string(node.values[0])}]" if node.keys and node.values else "Dict"
if isinstance(node, ast.Set):
return f"Set[{self._ast_node_to_string(node.elts[0])}]" if node.elts else "Set"
if isinstance(node, ast.Tuple):
elts = [self._ast_node_to_string(elt) for elt in node.elts]
if not elts:
return "Tuple"
# Handle Tuple[int] vs Tuple[int, ...]
if len(elts) == 1:
# Check if it's intended as variable-length tuple hint e.g. Tuple[int, ...]
# This requires checking the source code or making assumptions.
# Simple approach: always assume single element means fixed tuple.
# Or f"tuple[{elts[0]}]" for consistency?
return f"Tuple[{elts[0]}]"
# Or f"tuple[{', '.join(elts)}]"
return f"Tuple[{', '.join(elts)}]"
# Fallback using ast.unparse (Python 3.9+)
try:
import sys
if sys.version_info >= (3, 9):
return ast.unparse(node)
# 如果 unparse 不可用,则为旧版本 Python 的基本回退
if isinstance(node, ast.Expr):
return self._ast_node_to_string(node.value)
# 如果需要,添加更多回退
logger.debug(
f"无法将 AST 节点转换为字符串(unparse 不可用):{type(node)}")
return "UnsupportedType"
except Exception as e:
logger.debug(f"Error using ast.unparse: {e}")
return "UnsupportedType" # Final fallback
================================================
FILE: kirara_ai/web/api/block/diagnostics/import_check.py
================================================
import ast
import importlib.util
import logging
import os
from typing import List, Optional, Tuple
from lsprotocol.types import (CodeAction, CodeActionKind, CodeActionParams, Diagnostic, DiagnosticSeverity, Position,
Range, TextEdit, WorkspaceEdit)
from pygls.workspace import Document
from .base_diagnostic import BaseDiagnostic
logger = logging.getLogger(__name__)
class ImportDiagnostic(BaseDiagnostic):
    """Diagnostic checker that reports import statements whose modules cannot be found."""

    SOURCE_NAME: str = "import-check"

    def _get_package_context(self, path: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(file_dir, package_name)`` for the file at ``path``.

        ``package_name`` is the dotted package the file belongs to. It is built by
        walking up the directory tree while each directory contains an
        ``__init__.py`` (e.g. ``pkg/sub/mod.py`` -> ``"pkg.sub"``), which lets
        ``importlib.util.resolve_name`` handle multi-level relative imports. When the
        file's directory is not a regular package, the directory's base name is used
        as a best-effort guess. Returns ``(None, None)`` when ``path`` is empty or does
        not exist.
        """
        if not path or not os.path.exists(path):
            return None, None
        file_dir = os.path.dirname(path)
        package_name: Optional[str] = None
        try:
            parts: List[str] = []
            current = file_dir
            while current and os.path.isfile(os.path.join(current, "__init__.py")):
                name = os.path.basename(current)
                if not name:
                    break
                parts.append(name)
                parent = os.path.dirname(current)
                if parent == current:  # reached the filesystem root
                    break
                current = parent
            if parts:
                package_name = ".".join(reversed(parts))
            else:
                package_name = os.path.basename(file_dir) or None
        except Exception:
            logger.warning(f"无法确定 '{path}' 的包上下文")
        return file_dir, package_name

    def _check_module_import(self, module_name: str, import_node: ast.stmt) -> Optional[Diagnostic]:
        """Check one ``import <module_name>`` target; return a diagnostic if it cannot be found."""
        try:
            spec = importlib.util.find_spec(module_name)
        except ModuleNotFoundError:
            spec = None
        except Exception as e:  # other errors raised while locating the module
            return self._create_diagnostic(
                f"检查导入 '{module_name}' 时出错: {e}", import_node, DiagnosticSeverity.Warning,
                data={"fix_type": "remove_import"}
            )
        if spec is not None:
            return None
        return self._create_diagnostic(
            f"无法找到模块 '{module_name}'", import_node, DiagnosticSeverity.Error,
            data={"fix_type": "remove_import"}
        )

    def _check_from_import(self, node: ast.ImportFrom, resolving_name: str,
                           file_dir: Optional[str], package_name: Optional[str]) -> Optional[Diagnostic]:
        """Check a ``from ... import ...`` statement; return a diagnostic if its module cannot be found.

        Relative imports are resolved against ``package_name``. Without a file/package
        context they are reported as a warning, because they cannot be checked reliably.
        """
        error_message: Optional[str] = None
        severity = DiagnosticSeverity.Error
        try:
            if node.level > 0:
                if file_dir and package_name:
                    resolved_name = importlib.util.resolve_name(
                        resolving_name, package_name)
                    if importlib.util.find_spec(resolved_name) is None:
                        error_message = f"无法找到相对导入的模块 '{resolving_name}' (解析为 '{resolved_name}' 来自 '{package_name}')"
                else:
                    error_message = f"无法可靠地检查相对导入 '{resolving_name}' (缺少文件路径或包上下文)"
                    severity = DiagnosticSeverity.Warning  # downgrade when unsure
            elif node.module:  # absolute import
                if importlib.util.find_spec(node.module) is None:
                    error_message = f"无法找到模块 '{node.module}'"
        except (ImportError, ValueError) as e:
            error_message = f"无法解析或找到导入 '{resolving_name}': {e}"
        except Exception as e:
            error_message = f"检查导入 '{resolving_name}' 时发生意外错误: {e}"
            severity = DiagnosticSeverity.Warning
        if not error_message:
            return None
        return self._create_diagnostic(
            error_message, node, severity,
            data={"fix_type": "remove_import"}
        )

    def check(self, doc: Document) -> List[Diagnostic]:
        """Check that every import statement in the document can be resolved.

        Each unresolvable import yields a diagnostic carrying
        ``{"fix_type": "remove_import"}``. Documents that fail to parse are skipped;
        syntax errors are reported by other checkers.
        """
        diagnostics: List[Diagnostic] = []
        source = doc.source
        path = doc.path
        try:
            tree = ast.parse(source)
            file_dir, package_name = self._get_package_context(path)
            # (line_no, module name) pairs already reported, to avoid duplicates.
            reported_issues = set()
            for node in ast.walk(tree):
                if isinstance(node, ast.Import):
                    line_no = getattr(node, "lineno", 0)
                    for alias in node.names:
                        issue_key = (line_no, alias.name)
                        if issue_key in reported_issues:
                            continue
                        diagnostic = self._check_module_import(alias.name, node)
                        if diagnostic is not None:
                            diagnostics.append(diagnostic)
                            reported_issues.add(issue_key)
                elif isinstance(node, ast.ImportFrom):
                    line_no = getattr(node, "lineno", 0)
                    # node.module is None for `from . import x`.
                    resolving_name = "." * node.level + (node.module or "")
                    if not resolving_name:
                        continue
                    issue_key = (line_no, resolving_name)
                    if issue_key in reported_issues:
                        continue
                    diagnostic = self._check_from_import(
                        node, resolving_name, file_dir, package_name)
                    if diagnostic is not None:
                        diagnostics.append(diagnostic)
                        reported_issues.add(issue_key)
        except SyntaxError:
            # Syntax errors are handled elsewhere.
            logger.debug("跳过导入检查,存在语法错误")
        except Exception as e:
            logger.error(f"检查导入时发生内部错误: {str(e)}", exc_info=True)
        return diagnostics

    def get_code_actions(self, params: CodeActionParams, relevant_diagnostics: List[Diagnostic]) -> List[CodeAction]:
        """Offer quick fixes that delete unresolvable import statements.

        Each ``remove_import`` diagnostic from this checker yields one action. The
        action deletes every line spanned by the diagnostic range, including the
        trailing line break.
        """
        actions: List[CodeAction] = []
        doc_uri = params.text_document.uri
        document = self.ls.workspace.get_document(doc_uri)
        if not document:
            return []
        lines = document.source.splitlines(True)  # keep line endings
        for diag in relevant_diagnostics:
            if diag.source != self.SOURCE_NAME:
                continue
            fix_type = diag.data.get("fix_type") if diag.data else None
            if fix_type != "remove_import":
                continue
            # The diagnostic range covers the whole import statement node.
            start_line = diag.range.start.line
            end_line = diag.range.end.line
            if start_line < 0 or end_line >= len(lines):
                logger.warning(f"无效的诊断范围用于删除导入: {diag.range}")
                continue
            delete_start_pos = Position(line=start_line, character=0)
            if end_line + 1 < len(lines):
                # Delete up to the start of the next line so the line break goes too.
                delete_end_pos = Position(line=end_line + 1, character=0)
            else:
                # Last line of the file: delete to its end.
                delete_end_pos = Position(
                    line=end_line, character=len(lines[end_line]))
            text_edit = TextEdit(
                range=Range(start=delete_start_pos, end=delete_end_pos),
                new_text=""
            )
            # The module name is the first single-quoted token of the message.
            quoted = diag.message.split("'")
            title_suffix = f": {quoted[1]}" if len(quoted) > 1 else ""
            actions.append(CodeAction(
                title=f"移除无效的导入语句{title_suffix}",
                kind=CodeActionKind.QuickFix,
                diagnostics=[diag],
                edit=WorkspaceEdit(changes={doc_uri: [text_edit]}),
                is_preferred=False  # deleting code is not always what the user wants
            ))
        return actions
================================================
FILE: kirara_ai/web/api/block/diagnostics/jedi_syntax_check.py
================================================
import logging
from typing import List
import jedi
from lsprotocol.types import Diagnostic, DiagnosticSeverity, Position, Range
from pygls.server import LanguageServer
from pygls.workspace import Document
from .base_diagnostic import BaseDiagnostic
logger = logging.getLogger(__name__)
class JediSyntaxErrorDiagnostic(BaseDiagnostic):
    """Reports Python syntax errors detected by Jedi's parser."""

    SOURCE_NAME: str = "syntax-error"

    def __init__(self, ls: LanguageServer):
        super().__init__(ls)

    @staticmethod
    def _error_range(syntax_error) -> Range:
        """Translate a Jedi syntax-error span (1-based lines) into a non-empty LSP Range."""
        first_line = max(0, syntax_error.line - 1)
        first_char = max(0, syntax_error.column)
        # The end line may not precede the start line.
        last_line = max(first_line, syntax_error.until_line - 1)
        last_char = syntax_error.until_column
        if last_line == first_line:
            # Mark at least one character on single-line errors.
            last_char = max(first_char + 1, last_char)
        return Range(
            start=Position(line=first_line, character=first_char),
            end=Position(line=last_line, character=last_char)
        )

    def check(self, doc: Document) -> List[Diagnostic]:
        """
        Report every syntax error Jedi finds in the document.

        Args:
            doc: The document to check.

        Returns:
            One Error diagnostic per syntax error. If Jedi itself fails, a single
            Warning diagnostic at the top of the file is appended instead.
        """
        found: List[Diagnostic] = []
        code = doc.source
        file_path = doc.path
        try:
            # Jedi can build a Script even when the code contains syntax errors.
            script = jedi.Script(code=code, path=file_path or None)
            for syntax_error in script.get_syntax_errors():
                span = self._error_range(syntax_error)
                found.append(self._create_diagnostic(
                    message=syntax_error.get_message(),
                    node=None,
                    severity=DiagnosticSeverity.Error,
                    range_override=span
                ))
        except Exception as exc:
            logger.error(f"检查语法错误时发生内部错误: {str(exc)}", exc_info=True)
            # Problems in the checker itself are reported as a warning.
            top_of_file = Range(start=Position(line=0, character=0), end=Position(line=0, character=1))
            found.append(self._create_diagnostic(
                message=f"检查语法错误时出错: {exc}",
                node=None,
                severity=DiagnosticSeverity.Warning,
                range_override=top_of_file
            ))
        return found
================================================
FILE: kirara_ai/web/api/block/diagnostics/mandatory_function.py
================================================
import ast
import logging
from typing import Any, Dict, List, Optional
from lsprotocol.types import (CodeAction, CodeActionKind, CodeActionParams, Diagnostic, DiagnosticSeverity, Position,
Range, TextEdit, WorkspaceEdit)
from pygls.server import LanguageServer
from pygls.workspace import Document
from .base_diagnostic import BaseDiagnostic
logger = logging.getLogger(__name__)
class MandatoryFunctionDiagnostic(BaseDiagnostic):
    """Checks that the document declares a configured mandatory function with the expected signature.

    The configuration is a dict with ``name``, ``params`` (a list of dicts with a
    ``name`` and an optional ``type_hint``) and ``return_type``.
    """

    SOURCE_NAME: str = "mandatory-function-check"
    config: Optional[Dict[str, Any]] = None
    config_name: Optional[str] = None
    config_params: Optional[List[Dict[str, Any]]] = None
    config_return: Optional[str] = None
    param_signatures: Optional[List[str]] = None
    expected_signature_str: Optional[str] = None
    expected_signature_data: Optional[Dict[str, Any]] = None

    def __init__(self, ls: LanguageServer, config: Optional[Dict[str, Any]]):
        super().__init__(ls)
        self.config = None
        self.config_name = None
        self.config_params = None
        self.config_return = None
        self.param_signatures = None
        self.expected_signature_str = None
        self.expected_signature_data = None
        # Apply the initial configuration, if any.
        if config:
            self.update_config(config)

    @staticmethod
    def _format_param(param: Dict[str, Any]) -> str:
        """Render one configured parameter as ``name: hint``, or just ``name`` without a hint."""
        hint = param.get("type_hint")
        return f"{param['name']}: {hint}" if hint else f"{param['name']}"

    def update_config(self, config: Dict[str, Any]) -> None:
        """Apply a new mandatory-function configuration.

        Empty or missing type hints are left out of the expected signature instead of
        producing invalid text such as ``x: `` or ``-> ``. On an invalid config the
        error is logged and the checker is disabled (``self.config`` is set to None).

        Args:
            config: Dictionary containing the mandatory-function check configuration.
        """
        try:
            if config is None:
                raise ValueError("config is None")
            name = config["name"]
            params = config["params"]
            if params is None:
                raise ValueError("'params' must be a list, got None")
            return_type = config["return_type"]
            param_signatures = [self._format_param(p) for p in params]
            signature = f"def {name}({', '.join(param_signatures)})"
            if return_type:
                signature += f" -> {return_type}"
            # Assign only after the whole config has been validated.
            self.config = config
            self.config_name = name
            self.config_params = params
            self.config_return = return_type
            self.param_signatures = param_signatures
            self.expected_signature_str = signature
            self.expected_signature_data = {
                "name": name,
                "params": param_signatures,
                "return": return_type
            }
            logger.info(f"Updated mandatory function config: {self.config_name}")
            logger.debug(f"Expected signature: {self.expected_signature_str}")
        except KeyError as e:
            logger.error(f"Invalid mandatory function config, missing key: {e}")
            self.config = None  # disable the checker on invalid config
        except Exception as e:
            logger.error(f"Error updating mandatory function config: {e}")
            self.config = None

    def _signature_mismatches(self, node: ast.FunctionDef) -> List[str]:
        """Return the reasons ``node``'s signature differs from the configured one (empty if it matches)."""
        reasons: List[str] = []
        expected_params = self.config_params or []
        actual_args = node.args.args
        # 1. Parameter count; names and types are compared only when the counts agree.
        if len(actual_args) != len(expected_params):
            reasons.append(
                f"参数数量不匹配: 期望 {len(expected_params)} 个, 实际 {len(actual_args)} 个")
        else:
            # 2. Parameter names, then types (only when the config specifies a type).
            for i, (arg_node, config_param) in enumerate(zip(actual_args, expected_params)):
                arg_name = arg_node.arg
                arg_type_str = self._ast_node_to_string(arg_node.annotation)
                config_type_hint = config_param.get("type_hint", "")
                # An empty/missing annotation compares equal to "Any".
                types_match = (arg_type_str or "Any") == (config_type_hint or "Any")
                if arg_name != config_param["name"]:
                    reasons.append(
                        f"第 {i+1} 个参数名不匹配: 期望 '{config_param['name']}', 实际 '{arg_name}'")
                elif not types_match and config_type_hint:
                    reasons.append(
                        f"参数 '{arg_name}' 类型不匹配: 期望 '{config_type_hint}', 实际 '{arg_type_str or '无类型'}'")
        # 3. Return type (only when the config specifies one).
        return_type_str = self._ast_node_to_string(node.returns)
        config_return_type = self.config_return or ""
        return_match = (return_type_str or "Any") == (config_return_type or "Any")
        if not return_match and config_return_type:
            reasons.append(
                f"返回类型不匹配: 期望 '{config_return_type}', 实际 '{return_type_str or '无类型'}'")
        return reasons

    def _signature_mismatch_diagnostic(self, node: ast.FunctionDef, reasons: List[str]) -> Diagnostic:
        """Build the diagnostic for a same-named function whose signature does not match."""
        actual_params = []
        for arg in node.args.args:
            annotation = self._ast_node_to_string(arg.annotation)
            actual_params.append(f"{arg.arg}: {annotation}" if annotation else arg.arg)
        actual_signature_str = f"def {node.name}({', '.join(actual_params)})"
        return_actual = self._ast_node_to_string(node.returns)
        if return_actual:
            actual_signature_str += f" -> {return_actual}"
        mismatch_hint = ("\n具体差异:\n- " + "\n- ".join(reasons)) if reasons else ""
        message = f"函数 '{self.config_name}' 的签名与强制要求不符。\n期望: {self.expected_signature_str}\n实际: {actual_signature_str}{mismatch_hint}"
        return self._create_diagnostic(
            message=message,
            node=node,
            severity=DiagnosticSeverity.Error,
            # "replace_signature" has no code action yet.
            data={"expected_signature": self.expected_signature_data,
                  "fix_type": "replace_signature"}
        )

    def _missing_function_diagnostic(self, source: str) -> Diagnostic:
        """Build the diagnostic for a missing function, anchored just past the last line (the insertion point)."""
        file_end_line = len(source.splitlines())
        message = f"缺少强制函数声明: '{self.config_name}'。\n期望签名: {self.expected_signature_str}"
        return self._create_diagnostic(
            message=message,
            node=None,
            severity=DiagnosticSeverity.Error,
            range_override=Range(start=Position(line=file_end_line, character=0),
                                 end=Position(line=file_end_line, character=0)),
            data={"expected_signature": self.expected_signature_data,
                  "fix_type": "insert_function"}
        )

    def check(self, doc: Document) -> List[Diagnostic]:
        """Check that the document declares the configured mandatory function.

        Returns an empty list in these cases:
        - the checker is unconfigured;
        - the source does not parse (syntax errors are left to other checkers);
        - a function with the configured name matches the expected signature.

        Otherwise returns one diagnostic. It is either a signature mismatch on the
        first same-named ``def`` encountered while walking the AST, or a
        "missing function" diagnostic at the end of the file.
        """
        # A zero-parameter requirement (``def f() -> T``) is valid, so only a missing
        # config disables the check.
        if not self.config or self.config_params is None:
            return []
        source = doc.source
        mismatch_node: Optional[ast.FunctionDef] = None
        mismatch_reasons: List[str] = []
        try:
            tree = ast.parse(source)
            for node in ast.walk(tree):
                if not isinstance(node, ast.FunctionDef) or node.name != self.config_name:
                    continue
                reasons = self._signature_mismatches(node)
                if not reasons:
                    return []  # found a function with the expected signature
                if mismatch_node is None:
                    # Report the node and the reasons of the same (first) mismatch.
                    mismatch_node, mismatch_reasons = node, reasons
        except SyntaxError:
            # Let other checkers handle syntax errors.
            return []
        except Exception as e:
            logger.error(f"Error checking mandatory function: {str(e)}", exc_info=True)
            return []
        if mismatch_node is not None:
            return [self._signature_mismatch_diagnostic(mismatch_node, mismatch_reasons)]
        return [self._missing_function_diagnostic(source)]

    def get_code_actions(self, params: CodeActionParams, relevant_diagnostics: List[Diagnostic]) -> List[CodeAction]:
        """Offer a quick fix that appends a stub for a missing mandatory function.

        Only ``insert_function`` diagnostics are handled, and only when their recorded
        expected signature still matches the currently configured name.
        """
        actions: List[CodeAction] = []
        doc_uri = params.text_document.uri
        document = self.ls.workspace.get_document(doc_uri)
        if not document or not self.config:  # need the document and a config
            return []
        for diag in relevant_diagnostics:
            # Only diagnostics produced by this checker.
            if diag.source != self.SOURCE_NAME:
                continue
            fix_type = diag.data.get("fix_type") if diag.data else None
            expected_sig_data = diag.data.get(
                "expected_signature") if diag.data else None
            # Skip diagnostics produced under a different configuration.
            if not expected_sig_data or expected_sig_data["name"] != self.config_name:
                continue
            if fix_type != "insert_function":
                continue
            param_str = ", ".join(expected_sig_data['params'])
            return_str = f" -> {expected_sig_data['return']}" if expected_sig_data['return'] else ""
            # Separate the stub from existing code with two newlines.
            prefix = "\n\n" if document.source.strip() else ""
            new_text = f"{prefix}def {self.config_name}({param_str}){return_str}:\n    \"\"\"强制函数存根\"\"\"\n    pass\n"
            # The diagnostic range marks the end of the file.
            insert_pos = diag.range.start
            edit = WorkspaceEdit(changes={
                doc_uri: [TextEdit(range=Range(
                    start=insert_pos, end=insert_pos), new_text=new_text)]
            })
            actions.append(CodeAction(
                title=f"生成强制函数 '{self.config_name}'",
                kind=CodeActionKind.QuickFix,
                diagnostics=[diag],
                edit=edit,
                is_preferred=True
            ))
        return actions
================================================
FILE: kirara_ai/web/api/block/diagnostics/pyflakes_check.py
================================================
import logging
from typing import Any, List
from lsprotocol.types import (CodeAction, CodeActionKind, CodeActionParams, Diagnostic, DiagnosticSeverity, Position,
Range, TextEdit, WorkspaceEdit)
from pyflakes import api as pyflakes_api
from pyflakes import messages as pyflakes_messages
from pyflakes import reporter as pyflakes_reporter
from pygls.server import LanguageServer
from pygls.workspace import Document
from .base_diagnostic import BaseDiagnostic
logger = logging.getLogger(__name__)
# 自定义 Reporter 来收集 Pyflakes 的错误/警告并转换为 LSP Diagnostic
class _LspReporter(pyflakes_reporter.Reporter):
    """Pyflakes reporter that collects findings as LSP ``Diagnostic`` objects.

    Instead of writing to output streams, each callback appends a diagnostic,
    tagged with ``source_name``, to ``self.diagnostics``.
    """

    def __init__(self, source_name: str):
        super().__init__(None, None)  # no stdout/stderr streams needed
        self.diagnostics: List[Diagnostic] = []
        self._source_name = source_name

    def unexpectedError(self, filename: str, msg: str):
        """Record an internal pyflakes error as a warning at the top of the file."""
        logger.error(f"Pyflakes unexpected error in {filename}: {msg}")
        diagnostic = Diagnostic(
            range=Range(start=Position(line=0, character=0), end=Position(line=0, character=1)),
            message=f"Pyflakes internal error: {msg}",
            severity=DiagnosticSeverity.Warning,  # a checker problem, not a code problem
            source=self._source_name,
            code="PyflakesInternalError"
        )
        self.diagnostics.append(diagnostic)

    def syntaxError(self, filename: str, msg: str, lineno: int, offset: int, text: str):
        """Record a syntax error reported by pyflakes as a one-character Error diagnostic.

        ``lineno`` may be None (error raised during tokenization) or 0, and ``offset``
        may be None or non-positive. Such values are clamped to the start of the
        document/line instead of raising.
        """
        # Pyflakes positions are 1-based; LSP positions are 0-based.
        line = max(0, (lineno or 1) - 1)
        col = offset - 1 if offset is not None and offset > 0 else 0
        end_col = col + 1  # mark a single character at the error point
        diagnostic = Diagnostic(
            range=Range(
                start=Position(line=line, character=col),
                end=Position(line=line, character=end_col)
            ),
            message=f"Syntax Error: {msg}",
            severity=DiagnosticSeverity.Error,
            source=self._source_name,
            code="PyflakesSyntaxError"
        )
        self.diagnostics.append(diagnostic)

    def flake(self, message: Any):
        """Convert one ``pyflakes.messages`` instance into a Diagnostic.

        The message class name is used both as the diagnostic ``code`` and as
        ``data["pyflakes_code"]``; code actions use the latter.
        """
        line = max(0, message.lineno - 1)  # 1-based -> 0-based
        col = message.col  # 0-based column offset
        end_col = col + 1  # default: mark one character
        message_code = message.__class__.__name__
        try:
            # For these messages the first argument is usually the offending name,
            # so the range is widened to that name's length.
            if isinstance(message, (pyflakes_messages.UnusedImport,
                                    pyflakes_messages.UndefinedName,
                                    pyflakes_messages.UndefinedExport,
                                    pyflakes_messages.UndefinedLocal,
                                    pyflakes_messages.DuplicateArgument,
                                    pyflakes_messages.RedefinedWhileUnused,
                                    pyflakes_messages.UnusedVariable,
                                    pyflakes_messages.ImportShadowedByLoopVar)):
                if message.message_args:
                    name = message.message_args[0]
                    if isinstance(name, str):
                        end_col = col + len(name)
            # Other message types keep the single-character default.
        except Exception as e:
            logger.warning(f"计算 Pyflakes 诊断范围时出错: {e}", exc_info=True)
            end_col = col + 1  # fall back on error
        # Severity: undefined names and syntax-related messages are errors; the rest warnings.
        severity = DiagnosticSeverity.Warning
        if isinstance(message, (pyflakes_messages.UndefinedName,
                                pyflakes_messages.UndefinedExport,
                                pyflakes_messages.UndefinedLocal,
                                pyflakes_messages.DoctestSyntaxError,
                                pyflakes_messages.ForwardAnnotationSyntaxError)):
            severity = DiagnosticSeverity.Error
        elif "syntax" in message_code.lower() or "invalid" in message_code.lower():
            severity = DiagnosticSeverity.Error
        diagnostic = Diagnostic(
            range=Range(
                start=Position(line=line, character=col),
                end=Position(line=line, character=end_col)
            ),
            message=message.message % message.message_args,
            severity=severity,
            source=self._source_name,
            code=message_code,
            data={"pyflakes_code": message_code}
        )
        self.diagnostics.append(diagnostic)
class PyflakesDiagnostic(BaseDiagnostic):
    """Diagnostic checker backed by Pyflakes, with a quick fix for unused imports."""

    SOURCE_NAME: str = "pyflakes"

    def __init__(self, ls: LanguageServer):
        super().__init__(ls)

    def check(self, doc: Document) -> List[Diagnostic]:
        """
        Run Pyflakes on the document and return its findings as diagnostics.

        If Pyflakes itself fails, a single Warning diagnostic at the top of the file is
        returned instead.
        """
        diagnostics = []
        source = doc.source
        # Pyflakes needs a file name, even a placeholder one.
        path = doc.path or "untitled.py"
        try:
            reporter = _LspReporter(self.SOURCE_NAME)
            pyflakes_api.check(source, path, reporter=reporter)
            diagnostics = reporter.diagnostics
        except Exception as e:
            logger.error(f"运行 Pyflakes 时发生内部错误: {str(e)}", exc_info=True)
            diagnostics.append(self._create_diagnostic(
                message=f"运行 Pyflakes 时出错: {e}",
                node=None,
                severity=DiagnosticSeverity.Warning,  # checker problems are warnings
                range_override=Range(start=Position(line=0, character=0), end=Position(line=0, character=1))
            ))
        return diagnostics

    @staticmethod
    def _removable_import_lines(source: str, line: int):
        """
        Find the import statement starting on 0-based ``line`` and decide whether it can be deleted as whole lines.

        Returns:
            ``(first_line, last_line)`` (0-based, inclusive) covering the whole
            statement, including continuation lines of a multi-line import. Returns
            None when deleting those lines is not safe:
            - no single import statement starts on that line;
            - the statement imports several names (siblings may still be used);
            - other code shares its lines;
            - it is the only statement of a block (e.g. ``if TYPE_CHECKING:``);
            - the source does not parse.
        """
        import ast  # not imported at module level in this file

        try:
            tree = ast.parse(source)
        except (SyntaxError, ValueError):
            return None
        target_lineno = line + 1  # AST line numbers are 1-based
        imports = [n for n in ast.walk(tree)
                   if isinstance(n, (ast.Import, ast.ImportFrom)) and n.lineno == target_lineno]
        if len(imports) != 1 or len(imports[0].names) != 1:
            return None
        node = imports[0]
        # Deleting the only statement of a nested block would leave invalid code.
        for parent in ast.walk(tree):
            if isinstance(parent, ast.Module):
                continue
            for field in ("body", "orelse", "finalbody"):
                stmts = getattr(parent, field, None)
                if isinstance(stmts, list) and len(stmts) == 1 and stmts[0] is node:
                    return None
        source_lines = source.splitlines()
        end_lineno = getattr(node, "end_lineno", None) or node.lineno
        if end_lineno > len(source_lines):
            return None
        # AST column offsets are UTF-8 byte offsets, so compare on encoded lines.
        first = source_lines[node.lineno - 1].encode("utf-8")
        if first[:node.col_offset].strip():
            return None  # code precedes the import (e.g. `x = 1; import os`)
        end_col = getattr(node, "end_col_offset", None)
        if end_col is not None:
            rest = source_lines[end_lineno - 1].encode("utf-8")[end_col:].strip()
            if rest and not rest.startswith(b"#"):
                return None  # code follows the import (e.g. `import os; x = 1`)
        return node.lineno - 1, end_lineno - 1

    def get_code_actions(self, params: CodeActionParams, relevant_diagnostics: List[Diagnostic]) -> List[CodeAction]:
        """
        Provide quick fixes for Pyflakes diagnostics.

        Currently implements "remove unused import" (``UnusedImport``). The fix deletes
        every line of the import statement. It is offered only when that deletion
        cannot remove other code or leave an empty block.
        """
        actions: List[CodeAction] = []
        doc_uri = params.text_document.uri
        document = self.ls.workspace.get_document(doc_uri)
        if not document:
            logger.warning(f"无法获取 Pyflakes 代码操作,未找到文档: {doc_uri}")
            return []
        source = document.source
        lines = source.splitlines(True)  # keep line endings
        for diag in relevant_diagnostics:
            # Only pyflakes diagnostics that carry their message code.
            if diag.source != self.SOURCE_NAME or not diag.data:
                continue
            if diag.data.get("pyflakes_code") != "UnusedImport":
                continue
            if diag.range.start.line < 0 or diag.range.end.line >= len(lines):
                logger.warning(f"无效的 Pyflakes 诊断范围用于移除导入: {diag.range}")
                continue
            # Pyflakes reports the import on the statement's first line.
            span = self._removable_import_lines(source, diag.range.start.line)
            if span is None:
                continue
            start_line, end_line = span
            if end_line >= len(lines):
                logger.warning(f"无效的 Pyflakes 诊断范围用于移除导入: {diag.range}")
                continue
            delete_start_pos = Position(line=start_line, character=0)
            if end_line + 1 < len(lines):
                # Delete up to the start of the next line (removes the line break too).
                delete_end_pos = Position(line=end_line + 1, character=0)
            else:
                # Last line of the file: delete to its end.
                delete_end_pos = Position(line=end_line, character=len(lines[end_line]))
            text_edit = TextEdit(
                range=Range(start=delete_start_pos, end=delete_end_pos),
                new_text=""
            )
            # Use the quoted name from the message for a more specific title.
            title_suffix = ""
            parts = diag.message.split("'")
            if len(parts) > 1:
                title_suffix = f": '{parts[1]}'"
            actions.append(CodeAction(
                title=f"移除未使用的导入{title_suffix}",
                kind=CodeActionKind.QuickFix,
                diagnostics=[diag],
                edit=WorkspaceEdit(changes={doc_uri: [text_edit]}),
                is_preferred=True
            ))
        return actions
================================================
FILE: kirara_ai/web/api/block/models.py
================================================
from typing import List
from pydantic import BaseModel
from kirara_ai.workflow.core.block.schema import BlockConfig, BlockInput, BlockOutput
class BlockType(BaseModel):
    """Block type information.

    Pydantic model with four string fields (``type_name``, ``name``, ``label``,
    ``description``) plus the lists of ``BlockInput``, ``BlockOutput`` and
    ``BlockConfig`` schemas declared by the block type.
    """
    type_name: str
    name: str
    label: str
    description: str
    inputs: List[BlockInput]
    outputs: List[BlockOutput]
    configs: List[BlockConfig]
class BlockTypeList(BaseModel):
    """Response model listing block types."""
    types: List[BlockType]
class BlockTypeResponse(BaseModel):
    """Response model for a single block type."""
    type: BlockType
================================================
FILE: kirara_ai/web/api/block/python_lsp.py
================================================
import asyncio
import os
from typing import Any, Dict, List, Optional, Union
import jedi
from lsprotocol.types import (TEXT_DOCUMENT_CODE_ACTION, TEXT_DOCUMENT_COMPLETION, TEXT_DOCUMENT_DEFINITION,
TEXT_DOCUMENT_DID_CHANGE, TEXT_DOCUMENT_DID_OPEN, TEXT_DOCUMENT_DID_SAVE,
TEXT_DOCUMENT_DOCUMENT_SYMBOL, TEXT_DOCUMENT_HOVER, TEXT_DOCUMENT_SIGNATURE_HELP,
WORKSPACE_DID_CHANGE_CONFIGURATION, CodeAction, CodeActionParams, CompletionItem,
CompletionItemKind, CompletionList, CompletionOptions, CompletionParams, DefinitionParams,
Diagnostic, DiagnosticSeverity, DidChangeConfigurationParams, DidChangeTextDocumentParams,
DidOpenTextDocumentParams, DidSaveTextDocumentParams, DocumentSymbolParams, Hover,
HoverParams, Location, MarkupContent, MarkupKind, MessageType, ParameterInformation,
Position, Range, SignatureHelp, SignatureHelpOptions, SignatureHelpParams,
SignatureInformation, SymbolInformation, SymbolKind, TextDocumentPositionParams)
from pygls.server import LanguageServer
from kirara_ai.logger import get_logger
from .diagnostics.base_diagnostic import BaseDiagnostic
from .diagnostics.import_check import ImportDiagnostic
from .diagnostics.jedi_syntax_check import JediSyntaxErrorDiagnostic
from .diagnostics.mandatory_function import MandatoryFunctionDiagnostic
from .diagnostics.pyflakes_check import PyflakesDiagnostic
logger = get_logger("LSP")
class QuartWsTransport(asyncio.Transport):
    """Minimal asyncio transport that forwards outgoing messages into a queue.

    Each written message is put on the queue, and ``None`` is put on :meth:`close`
    (presumably an end-of-stream marker for the consumer draining the queue —
    confirm at the call site).

    The language server's handlers are registered with ``@self.thread()``, so
    messages may be written from worker threads. ``asyncio.Queue`` is not
    thread-safe, so writes from a thread other than the queue's event loop are
    handed over with ``loop.call_soon_threadsafe``.
    """

    def __init__(self, queue: asyncio.Queue, loop: Optional[asyncio.AbstractEventLoop] = None):
        """
        Args:
            queue: Queue receiving each written message, and ``None`` on close.
            loop: Event loop that owns ``queue``. Defaults to the loop running when the
                transport is created. If no loop is running, messages are enqueued
                directly.
        """
        super().__init__()
        self._queue = queue
        if loop is None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None
        self._loop = loop

    def _enqueue(self, item: Optional[str]) -> None:
        """Put ``item`` on the queue safely, whichever thread is calling."""
        loop = self._loop
        if loop is None or loop.is_closed():
            self._queue.put_nowait(item)
            return
        try:
            current = asyncio.get_running_loop()
        except RuntimeError:
            current = None
        if current is loop:
            self._queue.put_nowait(item)
        else:
            # Called from another thread: let the owning loop perform the put.
            loop.call_soon_threadsafe(self._queue.put_nowait, item)

    def write(self, message: str):
        """Enqueue an outgoing message; errors are logged, not raised."""
        try:
            self._enqueue(message)
        except Exception as e:
            logger.error(
                f"Error putting message into queue: {e}", exc_info=True)

    def close(self):
        """Enqueue ``None`` to signal that no more messages will be written."""
        self._enqueue(None)
class PythonLanguageServer(LanguageServer):
mandatory_function_checker: Optional[MandatoryFunctionDiagnostic] = None
def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
    """Create the language server and register its built-in diagnostic checkers.

    Args:
        loop: Optional event loop forwarded to pygls' ``LanguageServer``.
    """
    super().__init__("kirara-code-block-lsp", "v0.1", loop=loop)
    # NOTE(review): presumably caps the pygls worker pool used by @thread()
    # handlers at one worker — confirm against the pygls version in use.
    self._max_workers = 1
    self.diagnostic_checkers: List[BaseDiagnostic] = [
        ImportDiagnostic(self),
        JediSyntaxErrorDiagnostic(self),
        PyflakesDiagnostic(self),
    ]
    enabled_sources = [checker.SOURCE_NAME for checker in self.diagnostic_checkers]
    logger.info(f"Enabled diagnostic checkers: {enabled_sources}")
    self._setup_handlers()
def configure_mandatory_function_checker(self, config: Dict[str, Any]) -> None:
    """Create or reconfigure the mandatory-function checker.

    The first call creates a ``MandatoryFunctionDiagnostic`` and appends it to
    ``self.diagnostic_checkers``. Later calls pass ``config`` to its
    ``update_config``. Errors are logged and shown to the client, not raised.

    Args:
        config: Mandatory-function settings received from the client.
    """
    try:
        checker = self.mandatory_function_checker
        if checker is not None:
            # Reuse the existing checker with the new settings.
            checker.update_config(config)
            logger.info(
                "MandatoryFunctionDiagnostic configuration updated.")
        else:
            checker = MandatoryFunctionDiagnostic(self, config)
            self.mandatory_function_checker = checker
            self.diagnostic_checkers.append(checker)
            logger.info(
                "MandatoryFunctionDiagnostic enabled with client configuration.")
        logger.debug(
            f"MandatoryFunctionDiagnostic configured with: {config}")
        active_sources = [c.SOURCE_NAME for c in self.diagnostic_checkers]
        logger.info(
            f"Currently enabled diagnostic checkers: {active_sources}")
    except Exception as e:
        logger.error(
            f"Error configuring MandatoryFunctionDiagnostic: {str(e)}", exc_info=True)
        self.show_message(
            f"Error configuring mandatory function checker: {str(e)}", MessageType.Error)
def _setup_handlers(self):
    """Register this server's LSP feature handlers.

    Each handler is a closure over ``self``, registered with ``self.feature``
    and wrapped by ``self.thread()`` (pygls's thread-pool decorator).
    Request handlers delegate to the matching ``_get_*`` method. The
    didOpen/didChange/didSave notifications re-run diagnostics for the
    document. didChangeConfiguration forwards the ``mandatoryFunction``
    settings section.
    """
    logger.debug("Setting up LSP handlers...")

    # Completion is also triggered automatically after these characters.
    @self.feature(TEXT_DOCUMENT_COMPLETION, CompletionOptions(trigger_characters=['.', '(', ',', '=', "\\", "[", "'"]))
    @self.thread()
    def completions(ls, params: CompletionParams) -> CompletionList:
        """Handle textDocument/completion."""
        return self._get_completions(params)

    @self.feature(TEXT_DOCUMENT_HOVER)
    @self.thread()
    def hover(ls, params: HoverParams) -> Optional[Hover]:
        """Handle textDocument/hover."""
        return self._get_hover(params)

    # Signature help is re-requested after these characters.
    @self.feature(TEXT_DOCUMENT_SIGNATURE_HELP, SignatureHelpOptions(trigger_characters=["(", ",", "."]))
    @self.thread()
    def signature(ls, params: SignatureHelpParams) -> Optional[SignatureHelp]:
        """Handle textDocument/signatureHelp."""
        return self._get_signature_help(params)

    @self.feature(TEXT_DOCUMENT_DEFINITION)
    @self.thread()
    def definition(ls, params: DefinitionParams) -> Optional[Union[Location, List[Location]]]:
        """Handle textDocument/definition (go to definition)."""
        return self._get_definition(params)

    @self.feature(TEXT_DOCUMENT_DOCUMENT_SYMBOL)
    @self.thread()
    def symbols(ls, params: DocumentSymbolParams) -> Optional[List[SymbolInformation]]:
        """Handle textDocument/documentSymbol."""
        return self._get_document_symbols(params)

    @self.feature(TEXT_DOCUMENT_DID_OPEN)
    @self.thread()
    def did_open(ls, params: DidOpenTextDocumentParams):
        """Handle textDocument/didOpen and publish diagnostics for the document."""
        logger.info(f"Document opened: {params.text_document.uri}")
        doc = ls.workspace.get_document(params.text_document.uri)
        # If the workspace copy differs from the text the client just sent,
        # store the client's version so diagnostics run against it.
        if doc and doc.source != params.text_document.text:
            ls.workspace.put_document(params.text_document)
        self._publish_diagnostics(ls, params.text_document.uri)

    @self.feature(TEXT_DOCUMENT_DID_CHANGE)
    @self.thread()
    def did_change(ls, params: DidChangeTextDocumentParams):
        """Handle textDocument/didChange by re-running diagnostics."""
        # NOTE(review): content changes are not applied here; presumably pygls
        # updates the workspace document before this handler runs — confirm.
        self._publish_diagnostics(ls, params.text_document.uri)

    @self.feature(TEXT_DOCUMENT_DID_SAVE)
    @self.thread()
    def did_save(ls, params: DidSaveTextDocumentParams):
        """Handle textDocument/didSave by re-running diagnostics."""
        self._publish_diagnostics(ls, params.text_document.uri)

    @self.feature(TEXT_DOCUMENT_CODE_ACTION)
    @self.thread()
    def code_action(ls, params: CodeActionParams) -> Optional[List[CodeAction]]:
        """Handle textDocument/codeAction, returning quick-fix suggestions."""
        return self._get_code_actions(params)

    @self.feature(WORKSPACE_DID_CHANGE_CONFIGURATION)
    @self.thread()
    def did_change_configuration(ls, params: DidChangeConfigurationParams):
        """Handle workspace/didChangeConfiguration.

        Only the ``mandatoryFunction`` section is consumed; it is forwarded
        to ``configure_mandatory_function_checker``. Errors are logged and
        shown to the client rather than raised.
        """
        try:
            settings = params.settings
            if not settings:
                return
            # Only the mandatory-function section is handled here.
            if 'mandatoryFunction' in settings:
                logger.info(
                    "Update mandatory function checker configuration")
                self.configure_mandatory_function_checker(
                    settings['mandatoryFunction'])
            else:
                logger.debug("No mandatory function configuration found")
        except Exception as e:
            logger.error(
                f"Error processing configuration change: {str(e)}", exc_info=True)
            self.show_message(
                f"Error processing configuration change: {str(e)}", MessageType.Error)

    logger.debug("LSP handlers set up.")
def _get_script(self, params: Union[TextDocumentPositionParams, CompletionParams, HoverParams, SignatureHelpParams, DefinitionParams]) -> Optional[jedi.Script]:
    """Build a ``jedi.Script`` for the document referenced by *params*.

    Only ``params.text_document.uri`` is read. The cursor position is applied
    later, by the caller, when it invokes jedi (``complete``, ``help``,
    ``get_signatures``, ``goto``). The script's project root is the current
    working directory.

    Returns:
        The script, or ``None`` if the document is not in the workspace or
        building the script fails (the error is logged).
    """
    try:
        doc_uri = params.text_document.uri
        document = self.workspace.get_document(doc_uri)
        if not document:
            logger.warning(f"文档未在工作区中找到: {doc_uri}")
            return None
        return jedi.Script(
            code=document.source,
            path=document.path or None,
            project=jedi.Project(os.getcwd()),
        )
    except Exception as e:
        logger.error(f"Error getting jedi.Script: {str(e)}", exc_info=True)
        return None
def _get_completions(self, params: CompletionParams) -> CompletionList:
    """Return jedi fuzzy completions at the requested cursor position.

    Names starting with ``__`` are omitted. The jedi type of each completion
    is mapped through ``_map_completion_type``. On a jedi error the problem is
    logged and the items collected so far are returned, so the client always
    receives a ``CompletionList``.
    """
    collected: List[CompletionItem] = []
    script = self._get_script(params)
    if not script:
        return CompletionList(is_incomplete=False, items=collected)
    try:
        cursor = params.position
        candidates = script.complete(cursor.line + 1, cursor.character, fuzzy=True)
        for candidate in candidates:
            label = candidate.name
            # Hide dunder / name-mangled entries such as __str__.
            if label.startswith('__'):
                continue
            collected.append(CompletionItem(
                label=label,
                kind=self._map_completion_type(candidate.type),
                detail=candidate.description,
                insert_text=label,
            ))
    except jedi.InternalError as e:
        logger.warning(f"Jedi completion error: {e}", exc_info=True)
    except Exception as e:
        logger.error(f"Error getting completions: {str(e)}", exc_info=True)
    return CompletionList(is_incomplete=False, items=collected)
def _get_hover(self, params: HoverParams) -> Optional[Hover]:
    """Build Markdown hover content from jedi's ``help`` results.

    Each result becomes a fenced ``python`` block holding its description.
    When the result has a raw docstring, a horizontal rule and the docstring
    follow. Sections are separated by blank lines. Returns ``None`` when
    there is nothing to show or jedi fails (the error is logged).
    """
    script = self._get_script(params)
    if not script:
        return None
    try:
        cursor = params.position
        results = script.help(cursor.line + 1, cursor.character)
        if not results:
            return None
        sections = []
        for result in results:
            section = f"```python\n{result.description}\n```"
            docstring = result.docstring(raw=True, fast=False)
            if docstring:
                section += f"\n\n---\n\n{docstring}"
            sections.append(section)
        markdown = "\n\n".join(sections)
        if markdown:
            return Hover(contents=MarkupContent(
                kind=MarkupKind.Markdown,
                value=markdown,
            ))
    except jedi.InternalError as e:
        logger.warning(f"Jedi hover error: {e}", exc_info=True)
    except Exception as e:
        logger.error(
            f"Error getting hover information: {str(e)}", exc_info=True)
    return None
def _get_signature_help(self, params: SignatureHelpParams) -> Optional[SignatureHelp]:
    """Describe the call signature(s) around the cursor.

    Each jedi signature becomes a ``SignatureInformation``. Its parameters are
    labelled with the jedi parameter name and documented with the parameter's
    description. The first signature is reported as active. The active
    parameter is that signature's ``index``, or 0 when jedi returns ``None``.
    Returns ``None`` when jedi finds no signature or fails (the error is logged).
    """
    script = self._get_script(params)
    if not script:
        return None
    try:
        cursor = params.position
        found = script.get_signatures(cursor.line + 1, cursor.character)
        if not found:
            return None
        infos = []
        for sig in found:
            parameters = [
                ParameterInformation(label=param.name, documentation=param.description)
                for param in sig.params
            ]
            infos.append(SignatureInformation(
                label=sig.to_string(),
                documentation=sig.docstring(raw=True, fast=False),
                parameters=parameters,
            ))
        if not infos:
            return None
        current = found[0].index
        return SignatureHelp(
            signatures=infos,
            active_signature=0,
            active_parameter=0 if current is None else current,
        )
    except jedi.InternalError as e:
        logger.warning(f"Jedi signature help error: {e}", exc_info=True)
    except Exception as e:
        logger.error(
            f"Error getting function signature: {str(e)}", exc_info=True)
    return None
def _get_definition(self, params: DefinitionParams) -> List[Location]:
    """Resolve go-to-definition targets for the symbol under the cursor.

    jedi ``goto`` is called with ``follow_imports`` and
    ``follow_builtin_imports`` enabled. Targets without a module path, line
    or column are skipped. Each returned ``Location`` spans the target's name.

    Returns:
        A possibly empty list of locations. jedi errors are logged and the
        locations gathered so far are returned.
    """
    # pathlib is stdlib and always importable; imported locally to keep the
    # method self-contained.
    from pathlib import Path

    locations: List[Location] = []
    script = self._get_script(params)
    if not script:
        return locations
    try:
        cursor = params.position
        definitions = script.goto(
            cursor.line + 1, cursor.character, follow_imports=True, follow_builtin_imports=True)
        for definition in definitions:
            if not definition.module_path or definition.line is None or definition.column is None:
                continue
            row = definition.line - 1
            start_pos = Position(line=row, character=definition.column)
            end_pos = Position(
                line=row, character=definition.column + len(definition.name))
            # Path.as_uri() raises ValueError for relative paths; absolute()
            # leaves absolute paths unchanged and anchors relative ones at the
            # cwd, so a single odd path no longer aborts the whole lookup.
            uri = Path(definition.module_path).absolute().as_uri()
            locations.append(Location(uri=uri, range=Range(start=start_pos, end=end_pos)))
    except jedi.InternalError as e:
        logger.warning(f"Jedi definition lookup error: {e}", exc_info=True)
    except Exception as e:
        logger.error(
            f"Error getting definition location: {str(e)}", exc_info=True)
    return locations
def _get_document_symbols(self, params: DocumentSymbolParams) -> List[SymbolInformation]:
    """Return every definition in the document as a flat symbol list.

    Uses jedi ``get_names`` over all scopes, definitions only. Each symbol's
    range covers just its name. ``container_name`` is the parent's name when
    jedi resolves a non-module parent, otherwise ``None``. An unknown document
    yields ``[]``. On a jedi error the problem is logged and the symbols
    collected so far are returned.
    """
    result = []
    try:
        uri = params.text_document.uri
        document = self.workspace.get_document(uri)
        if not document:
            return []
        script = jedi.Script(
            code=document.source,
            path=document.path or None,
            project=jedi.Project(os.getcwd()),
        )
        names = script.get_names(
            all_scopes=True, definitions=True, references=False)
        for name in names:
            if name.line is None or name.column is None:
                continue
            row = name.line - 1
            kind = self._map_symbol_type(name.type)
            span = Range(
                start=Position(line=row, character=name.column),
                end=Position(line=row, character=name.column + len(name.name)),
            )
            # Best effort: a failing parent lookup just leaves no container.
            container = None
            try:
                owner = name.parent()
                if owner and owner.type != 'module':
                    container = owner.name
            except Exception:
                pass
            result.append(SymbolInformation(
                name=name.name,
                kind=kind,
                location=Location(uri=uri, range=span),
                container_name=container,
            ))
    except jedi.InternalError as e:
        logger.warning(f"Jedi symbol lookup error: {e}", exc_info=True)
    except Exception as e:
        logger.error(
            f"Error getting document symbols: {str(e)}", exc_info=True)
    return result
def _map_completion_type(self, type_str: str) -> CompletionItemKind:
    """Translate a jedi completion ``type`` string into a ``CompletionItemKind``.

    The keys with a leading space (``' M'``, ``' C'``, ``' F'``) are matched
    literally. Unrecognised type names map to ``CompletionItemKind.Text``.
    """
    groups = (
        (CompletionItemKind.Module, ('module', 'import')),
        (CompletionItemKind.Class, ('class', ' C')),
        (CompletionItemKind.Variable, ('instance', 'param', 'statement')),
        (CompletionItemKind.Function, ('function', ' F')),
        (CompletionItemKind.Method, ('method', ' M')),
        (CompletionItemKind.File, ('path',)),
        (CompletionItemKind.Keyword, ('keyword',)),
        (CompletionItemKind.Property, ('property',)),
    )
    lookup = {name: kind for kind, names in groups for name in names}
    return lookup.get(type_str, CompletionItemKind.Text)
def _map_symbol_type(self, type_str: str) -> SymbolKind:
    """Translate a jedi name ``type`` string into an LSP ``SymbolKind``.

    The keys with a leading space (``' M'``, ``' C'``, ``' F'``) are matched
    literally. Unrecognised type names map to ``SymbolKind.Variable``.
    """
    groups = (
        (SymbolKind.Module, ('module', 'import')),
        (SymbolKind.Class, ('class', ' C')),
        (SymbolKind.Variable, ('instance', 'param', 'keyword', 'statement')),
        (SymbolKind.Function, ('function', ' F')),
        (SymbolKind.Method, ('method', ' M')),
        (SymbolKind.File, ('path',)),
        (SymbolKind.Property, ('property',)),
        (SymbolKind.Namespace, ('namespace',)),
    )
    lookup = {name: kind for kind, names in groups for name in names}
    return lookup.get(type_str, SymbolKind.Variable)
def _publish_diagnostics(self, ls: LanguageServer, doc_uri: str):
    """Run every registered checker on a document and publish the results.

    If the document is unknown, an empty diagnostic list is published. A
    checker that raises does not stop the others: its error is logged and
    reported as an Error diagnostic (source ``lsp-internal``) on the first
    character of the file.
    """
    document = ls.workspace.get_document(doc_uri)
    if not document:
        logger.warning(f"无法发布诊断,未找到文档: {doc_uri}")
        ls.publish_diagnostics(doc_uri, [])
        return
    collected = []
    for checker in self.diagnostic_checkers:
        source_name = checker.SOURCE_NAME
        try:
            found = checker.check(document)
            if found:
                collected.extend(found)
        except Exception as e:
            logger.error(
                f"Diagnostic checker '{source_name}' error: {str(e)}", exc_info=True)
            first_char = Range(start=Position(line=0, character=0),
                               end=Position(line=0, character=1))
            collected.append(Diagnostic(
                range=first_char,
                message=f"Diagnostic checker '{source_name}' error: {str(e)}",
                severity=DiagnosticSeverity.Error,
                source='lsp-internal'
            ))
    ls.publish_diagnostics(doc_uri, collected)
def _get_code_actions(self, params: CodeActionParams) -> Optional[List[CodeAction]]:
    """Collect code actions from the checkers that produced the request's diagnostics.

    The diagnostics in ``params.context`` are grouped by ``source``. Each
    checker whose ``SOURCE_NAME`` matches a group is asked for actions. A
    checker that raises is logged and skipped. Returns ``None`` when the
    document is unknown or no action was produced.
    """
    doc_uri = params.text_document.uri
    if not self.workspace.get_document(doc_uri):
        return None
    by_source: Dict[str, List[Diagnostic]] = {}
    for diag in params.context.diagnostics:
        if diag.source:
            by_source.setdefault(diag.source, []).append(diag)
    actions = []
    for checker in self.diagnostic_checkers:
        checker_name = checker.SOURCE_NAME
        matching = by_source.get(checker_name, [])
        if not matching:
            continue
        try:
            produced = checker.get_code_actions(params, matching)
            if produced:
                actions.extend(produced)
        except Exception as e:
            logger.error(
                f"Code action checker '{checker_name}' error: {str(e)}", exc_info=True)
    return actions or None
================================================
FILE: kirara_ai/web/api/block/routes.py
================================================
import asyncio
import json
from typing import Any
from quart import Blueprint, g, jsonify, websocket
from kirara_ai.logger import get_logger
from kirara_ai.workflow.core.block import BlockRegistry
from ...auth.middleware import require_auth
from .models import BlockType, BlockTypeList, BlockTypeResponse
from .python_lsp import PythonLanguageServer, QuartWsTransport
# Blueprint holding the block-type routes and the code-editor LSP websocket.
block_bp = Blueprint("block", __name__)
# Module-level logger for the block routes.
logger = get_logger("Web.Block")
@block_bp.route("/types", methods=["GET"])
@require_auth
async def list_block_types() -> Any:
    """List every registered block type.

    For each type, configs that declare options have them resolved through
    their ``options_provider`` against the current container. A type whose
    metadata cannot be built is logged, with a traceback, and skipped, so one
    broken block does not hide the others.
    """
    registry: BlockRegistry = g.container.resolve(BlockRegistry)
    types = []
    for block_type in registry.get_all_types():
        try:
            inputs, outputs, configs = registry.extract_block_info(block_type)
            type_name = registry.get_block_type_name(block_type)
            for config in configs.values():
                if config.has_options:
                    config.options = config.options_provider(g.container, block_type)  # type: ignore
            block_type_info = BlockType(
                type_name=type_name,
                name=block_type.name,
                label=registry.get_localized_name(type_name) or block_type.name,
                description=getattr(block_type, "description", ""),
                inputs=list(inputs.values()),
                outputs=list(outputs.values()),
                configs=list(configs.values()),
            )
            types.append(block_type_info)
        except Exception as e:
            # Keep the traceback: without it a broken block type is hard to diagnose.
            logger.error(f"获取Block类型失败: {e}", exc_info=True)
    return BlockTypeList(types=types).model_dump()
@block_bp.route("/types/<type_name>", methods=["GET"])
@require_auth
async def get_block_type(type_name: str) -> Any:
    """Return detailed information about one block type.

    The route must capture ``type_name``: without the ``<type_name>``
    converter, Quart would call this view with no arguments.

    Returns:
        The serialised ``BlockTypeResponse``, or a 404 JSON error if the type
        is not registered.
    """
    registry: BlockRegistry = g.container.resolve(BlockRegistry)
    block_type = registry.get(type_name)
    if not block_type:
        return jsonify({"error": "Block type not found"}), 404
    # Input/output/config definitions declared on the block class.
    inputs, outputs, configs = registry.extract_block_info(block_type)
    for config in configs.values():
        if config.has_options:
            config.options = config.options_provider(g.container, block_type)  # type: ignore
    block_type_info = BlockType(
        type_name=type_name,
        name=block_type.name,
        label=registry.get_localized_name(type_name) or block_type.name,
        description=getattr(block_type, "description", ""),
        inputs=list(inputs.values()),
        outputs=list(outputs.values()),
        configs=list(configs.values()),
    )
    return BlockTypeResponse(type=block_type_info).model_dump()
@block_bp.route("/types/compatibility", methods=["GET"])
@require_auth
async def get_type_compatibility() -> Any:
    """Return the block registry's type-compatibility map as JSON."""
    block_registry: BlockRegistry = g.container.resolve(BlockRegistry)
    compatibility = block_registry.get_type_compatibility_map()
    return jsonify(compatibility)
@block_bp.websocket("/code/lsp")
async def code_lsp():
    """Bridge a WebSocket connection to a per-connection Python language server.

    Incoming frames are JSON-decoded with the protocol's message hook and
    passed to its procedure handler. Outgoing messages are taken from the
    queue handed to ``QuartWsTransport`` and sent as-is. A ``None`` item on
    that queue stops the sender. As soon as either the sender or the receiver
    finishes, the other is cancelled, and the server's ``shutdown`` is
    scheduled on the default executor.
    """
    # NOTE(review): unlike the HTTP routes in this module this endpoint has no
    # @require_auth; confirm that is intentional.
    loop = asyncio.get_running_loop()
    lsp_server = PythonLanguageServer(loop=loop)
    logger = get_logger("Web.Block.LSP")
    queue: asyncio.Queue = asyncio.Queue()
    transport = QuartWsTransport(queue)
    lsp_server.lsp.connection_made(transport)
    lsp_server.lsp._send_only_body = True
    logger.info("LSP WebSocket connection established")

    async def sender():
        """Forward queued protocol messages to the client until ``None`` arrives."""
        while True:
            message = await queue.get()
            if message is None:
                break
            await websocket.send(message)

    async def receiver():
        """Feed client frames to the protocol; bad frames are logged, not fatal."""
        while True:
            message_str = await websocket.receive()
            try:
                parsed_message = json.loads(
                    message_str,
                    object_hook=lsp_server.lsp._deserialize_message
                )
                lsp_server.lsp._procedure_handler(parsed_message)
            except json.JSONDecodeError:
                logger.error(f"Unable to parse received LSP message: {message_str}", exc_info=True)
            except Exception as e:
                logger.error(f"Error processing LSP message: {e}", exc_info=True)

    receive_task = asyncio.create_task(receiver())
    send_task = asyncio.create_task(sender())
    logger.debug("Created LSP WebSocket sender and receiver tasks")
    try:
        # FIRST_COMPLETED rather than gather(): when the sender stops on its
        # None sentinel, gather would keep waiting on a receiver that is
        # blocked in websocket.receive() forever.
        done, _pending = await asyncio.wait(
            {receive_task, send_task}, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            try:
                task.result()
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.error(f"LSP WebSocket connection error: {e}", exc_info=True)
    except asyncio.CancelledError:
        logger.info("LSP WebSocket task cancelled")
    finally:
        for task, role in ((send_task, "sender"), (receive_task, "receiver")):
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                logger.debug(f"LSP WebSocket {role} task cancelled")
            except Exception:
                # Awaiting a task that already failed re-raises its error;
                # don't let that escape from cleanup.
                logger.debug(f"LSP WebSocket {role} task ended with an error", exc_info=True)
        loop.run_in_executor(None, lsp_server.shutdown)
        logger.info("websocket connection closed")
================================================
FILE: kirara_ai/web/api/dispatch/README.md
================================================
# 调度规则 API 📋
调度规则 API 提供了消息处理规则的管理功能。调度规则决定了如何根据消息内容选择合适的工作流进行处理。
## API 端点
### 获取规则列表
```http
GET /backend-api/api/dispatch/rules
```
获取所有已配置的调度规则。
**响应示例:**
```json
{
"rules": [
{
"rule_id": "chat_normal",
"name": "普通聊天",
"description": "普通聊天,使用默认参数",
"workflow_id": "chat:normal",
"priority": 5,
"enabled": true,
"rule_groups": [
{
"operator": "or",
"rules": [
{
"type": "prefix",
"config": {
"prefix": "/chat"
}
},
{
"type": "keyword",
"config": {
"keywords": ["聊天", "对话"]
}
}
]
}
],
"metadata": {
"category": "chat",
"permission": "user",
"temperature": 0.7
}
}
]
}
```
### 获取特定规则
```http
GET /backend-api/api/dispatch/rules/{rule_id}
```
获取指定规则的详细信息。
### 创建规则
```http
POST /backend-api/api/dispatch/rules
```
创建新的调度规则。
**请求体:**
```json
{
"rule_id": "chat_creative",
"name": "创意聊天",
"description": "创意聊天,使用更高的温度参数",
"workflow_id": "chat:creative",
"priority": 5,
"enabled": true,
"rule_groups": [
{
"operator": "and",
"rules": [
{
"type": "prefix",
"config": {
"prefix": "/creative"
}
},
{
"type": "keyword",
"config": {
"keywords": ["创意", "发散"]
}
}
]
}
],
"metadata": {
"category": "chat",
"permission": "user",
"temperature": 0.9
}
}
```
### 更新规则
```http
PUT /backend-api/api/dispatch/rules/{rule_id}
```
更新现有规则。
### 删除规则
```http
DELETE /backend-api/api/dispatch/rules/{rule_id}
```
删除指定规则。
### 启用规则
```http
POST /backend-api/api/dispatch/rules/{rule_id}/enable
```
启用指定规则。
### 禁用规则
```http
POST /backend-api/api/dispatch/rules/{rule_id}/disable
```
禁用指定规则。
## 数据模型
### SimpleRule
- `type`: 规则类型 (prefix/keyword/regex)
- `config`: 规则类型特定的配置
### RuleGroup
- `operator`: 组合操作符 (and/or)
- `rules`: 规则列表
### CombinedDispatchRule
- `rule_id`: 规则唯一标识符
- `name`: 规则名称
- `description`: 规则描述
- `workflow_id`: 关联的工作流ID
- `priority`: 优先级(数字越大优先级越高)
- `enabled`: 是否启用
- `rule_groups`: 规则组列表(组之间是 AND 关系)
- `metadata`: 元数据(可选)
## 规则类型
### 前缀匹配 (prefix)
根据消息前缀进行匹配,例如 "/help"。
配置参数:
- `prefix`: 要匹配的前缀
### 关键词匹配 (keyword)
检查消息中是否包含指定关键词。
配置参数:
- `keywords`: 关键词列表
### 正则匹配 (regex)
使用正则表达式进行匹配,提供最灵活的匹配方式。
配置参数:
- `pattern`: 正则表达式模式
## 组合规则说明
新版本的调度规则系统支持复杂的条件组合:
1. 每个规则可以包含多个规则组(RuleGroup)
2. 规则组之间是 AND 关系,即所有规则组都满足时才会触发
3. 每个规则组内可以包含多个简单规则(SimpleRule)
4. 规则组内的规则可以选择 AND 或 OR 关系
5. 每个简单规则都有自己的类型和配置
例如,可以创建如下规则:
```json
{
"rule_groups": [
{
"operator": "or",
"rules": [
{ "type": "prefix", "config": { "prefix": "/creative" } },
{ "type": "keyword", "config": { "keywords": ["创意", "发散"] } }
]
},
{
"operator": "and",
"rules": [
{ "type": "regex", "config": { "pattern": ".*问题.*" } },
{ "type": "keyword", "config": { "keywords": ["帮我", "请问"] } }
]
}
]
}
```
这个规则表示:
- 当消息以 "/creative" 开头 或 包含 "创意"/"发散" 关键词
- 且 消息包含 "问题" 且 包含 "帮我"/"请问" 中的任一关键词
时触发。
## 相关代码
- [调度规则定义](../../../workflow/core/dispatch/rule.py)
- [调度规则注册表](../../../workflow/core/dispatch/registry.py)
- [调度器实现](../../../workflow/core/dispatch/dispatcher.py)
- [系统预设规则](../../../../data/dispatch_rules)
## 错误处理
所有 API 端点在发生错误时都会返回适当的 HTTP 状态码和错误信息:
```json
{
"error": "错误描述信息"
}
```
常见状态码:
- 400: 请求参数错误、规则配置无效或规则ID已存在
- 404: 规则不存在
- 500: 服务器内部错误
## 使用示例
### 创建组合规则
```python
import requests
rule_data = {
"rule_id": "chat_creative",
"name": "创意聊天",
"description": "创意聊天模式",
"workflow_id": "chat:creative",
"priority": 5,
"enabled": True,
"rule_groups": [
{
"operator": "or",
"rules": [
{
"type": "prefix",
"config": {
"prefix": "/creative"
}
},
{
"type": "keyword",
"config": {
"keywords": ["创意", "发散"]
}
}
]
}
]
}
response = requests.post(
'http://localhost:8080/backend-api/api/dispatch/rules',
headers={'Authorization': f'Bearer {token}'},
json=rule_data
)
```
### 更新规则优先级
```python
import requests
response = requests.put(
'http://localhost:8080/backend-api/api/dispatch/rules/chat_creative',
headers={'Authorization': f'Bearer {token}'},
json={"priority": 8}
)
```
## 相关文档
- [工作流系统概述](../../README.md#工作流系统-)
- [调度规则配置指南](../../../workflow/README.md#调度规则配置)
- [API 认证](../../README.md#api认证-)
================================================
FILE: kirara_ai/web/api/dispatch/__init__.py
================================================
from .routes import dispatch_bp
# Public API of the dispatch API package: only its blueprint.
__all__ = ["dispatch_bp"]
================================================
FILE: kirara_ai/web/api/dispatch/models.py
================================================
from typing import List
from pydantic import BaseModel
from kirara_ai.workflow.core.dispatch import CombinedDispatchRule
# The Chinese class docstrings are kept verbatim on purpose: pydantic uses a
# model's docstring as its JSON-schema "description", so they are
# runtime-visible text.


# Response body wrapping a list of dispatch rules ("dispatch rule list").
class DispatchRuleList(BaseModel):
    """调度规则列表"""
    # The rules returned to the client.
    rules: List[CombinedDispatchRule]


# Response body wrapping a single dispatch rule ("dispatch rule response").
class DispatchRuleResponse(BaseModel):
    """调度规则响应"""
    # The requested, created or updated rule.
    rule: CombinedDispatchRule
================================================
FILE: kirara_ai/web/api/dispatch/routes.py
================================================
from quart import Blueprint, g, jsonify, request
from kirara_ai.workflow.core.dispatch import CombinedDispatchRule, DispatchRule, DispatchRuleRegistry
from kirara_ai.workflow.core.workflow import WorkflowRegistry
from ...auth.middleware import require_auth
from .models import DispatchRuleList, DispatchRuleResponse
# Blueprint for the dispatch-rule management routes defined below.
dispatch_bp = Blueprint("dispatch", __name__)
@dispatch_bp.route("/rules", methods=["GET"])
@require_auth
async def list_rules():
    """List all dispatch rules, highest priority first.

    ``sorted`` is used instead of ``list.sort`` so that the list obtained from
    the registry is never reordered in place. Python's sort is stable, so rules
    with equal priority keep the registry's order.
    """
    registry: DispatchRuleRegistry = g.container.resolve(DispatchRuleRegistry)
    ordered = sorted(registry.get_all_rules(), key=lambda rule: rule.priority, reverse=True)
    return DispatchRuleList(rules=[rule.model_dump() for rule in ordered]).model_dump()
@dispatch_bp.route("/rules/<rule_id>", methods=["GET"])
@require_auth
async def get_rule(rule_id: str):
    """Return one dispatch rule, or a 404 JSON error if it does not exist.

    The route must capture ``rule_id``; without the converter Quart could not
    supply the view's required argument.
    """
    registry: DispatchRuleRegistry = g.container.resolve(DispatchRuleRegistry)
    rule = registry.get_rule(rule_id)
    if not rule:
        return jsonify({"error": "Rule not found"}), 404
    return DispatchRuleResponse(rule=rule).model_dump()
@dispatch_bp.route("/rules", methods=["POST"])
@require_auth
async def create_rule():
    """Create a dispatch rule and persist the rule set.

    Returns 400 when the body is not a JSON object, fails validation,
    reuses an existing rule ID, names an unknown workflow, or is rejected by
    the registry.
    """
    data = await request.get_json()
    # Validate the body up front so malformed input yields 400 rather than an
    # unhandled 500 (same handling as update_adapter in the IM routes).
    if not isinstance(data, dict):
        return jsonify({"error": "Invalid request data: expected a JSON object"}), 400
    try:
        rule_data = CombinedDispatchRule(**data)
    except Exception as e:
        return jsonify({"error": f"Invalid request data: {e}"}), 400
    registry: DispatchRuleRegistry = g.container.resolve(DispatchRuleRegistry)
    workflow_registry: WorkflowRegistry = g.container.resolve(WorkflowRegistry)
    # The rule ID must be new.
    if registry.get_rule(rule_data.rule_id):
        return jsonify({"error": "Rule ID already exists"}), 400
    # The target workflow must exist.
    if not workflow_registry.get(rule_data.workflow_id):
        return jsonify({"error": "Workflow not found"}), 400
    try:
        rule = registry.create_rule(rule_data)
        registry.save_rules()
        return DispatchRuleResponse(rule=rule).model_dump()
    except Exception as e:
        return jsonify({"error": str(e)}), 400
@dispatch_bp.route("/rules/<rule_id>", methods=["PUT"])
@require_auth
async def update_rule(rule_id: str):
    """Replace an existing dispatch rule and persist the rule set.

    The ``rule_id`` in the body must match the one in the URL. Returns 400
    for a malformed body, an ID mismatch, an unknown workflow or a registry
    error, and 404 if the rule does not exist.
    """
    data = await request.get_json()
    # Validate the body up front so malformed input yields 400, not a 500.
    if not isinstance(data, dict):
        return jsonify({"error": "Invalid request data: expected a JSON object"}), 400
    try:
        rule_data = CombinedDispatchRule(**data)
    except Exception as e:
        return jsonify({"error": f"Invalid request data: {e}"}), 400
    if rule_id != rule_data.rule_id:
        return jsonify({"error": "Rule ID mismatch"}), 400
    registry: DispatchRuleRegistry = g.container.resolve(DispatchRuleRegistry)
    workflow_registry: WorkflowRegistry = g.container.resolve(WorkflowRegistry)
    # The rule must already exist.
    if not registry.get_rule(rule_id):
        return jsonify({"error": "Rule not found"}), 404
    # The target workflow must exist.
    if not workflow_registry.get(rule_data.workflow_id):
        return jsonify({"error": "Workflow not found"}), 400
    try:
        rule = registry.update_rule(rule_id, rule_data)
        registry.save_rules()
        return DispatchRuleResponse(rule=rule).model_dump()
    except Exception as e:
        return jsonify({"error": str(e)}), 400
@dispatch_bp.route("/rules/<rule_id>", methods=["DELETE"])
@require_auth
async def delete_rule(rule_id: str):
    """Delete a dispatch rule and persist the rule set (404 if it is unknown)."""
    registry: DispatchRuleRegistry = g.container.resolve(DispatchRuleRegistry)
    if not registry.get_rule(rule_id):
        return jsonify({"error": "Rule not found"}), 404
    registry.delete_rule(rule_id)
    registry.save_rules()
    return jsonify({"message": "Rule deleted successfully"})
@dispatch_bp.route("/rules/<rule_id>/enable", methods=["POST"])
@require_auth
async def enable_rule(rule_id: str):
    """Enable a dispatch rule and persist the rule set.

    Returns 404 if the rule is unknown and 400 if it is already enabled.
    """
    registry: DispatchRuleRegistry = g.container.resolve(DispatchRuleRegistry)
    rule = registry.get_rule(rule_id)
    if not rule:
        return jsonify({"error": "Rule not found"}), 404
    if rule.enabled:
        return jsonify({"error": "Rule is already enabled"}), 400
    registry.enable_rule(rule_id)
    registry.save_rules()
    return jsonify({"message": "Rule enabled successfully"})
@dispatch_bp.route("/rules/<rule_id>/disable", methods=["POST"])
@require_auth
async def disable_rule(rule_id: str):
    """Disable a dispatch rule and persist the rule set.

    Returns 404 if the rule is unknown and 400 if it is already disabled.
    """
    registry: DispatchRuleRegistry = g.container.resolve(DispatchRuleRegistry)
    rule = registry.get_rule(rule_id)
    if not rule:
        return jsonify({"error": "Rule not found"}), 404
    if not rule.enabled:
        return jsonify({"error": "Rule is already disabled"}), 400
    registry.disable_rule(rule_id)
    registry.save_rules()
    return jsonify({"message": "Rule disabled successfully"})
@dispatch_bp.route("/types", methods=["GET"])
@require_auth
async def get_rule_types():
    """Return the names of all registered dispatch rule types."""
    type_names = list(DispatchRule.rule_types.keys())
    payload = {"types": type_names}
    return jsonify(payload)
@dispatch_bp.route("/types/<rule_type>/config-schema", methods=["GET"])
@require_auth
async def get_rule_config_schema(rule_type: str):
    """Return the JSON schema of a rule type's config class.

    Returns 404 for an unknown rule type and 500 if schema generation fails.
    """
    try:
        if rule_type not in DispatchRule.rule_types:
            return jsonify({"error": "Invalid rule type"}), 404
        rule_class = DispatchRule.rule_types[rule_type]
        config_class = rule_class.config_class
        schema = config_class.model_json_schema()
        return jsonify({"configSchema": schema})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
================================================
FILE: kirara_ai/web/api/im/README.md
================================================
# 即时通讯 API 🗨️
即时通讯 API 提供了管理 IM 后端和适配器的功能。通过这些 API,你可以注册、配置和管理不同的 IM 平台适配器。
## API 端点
### 获取适配器类型
```http
GET /backend-api/api/im/types
```
获取所有可用的 IM 适配器类型。
**响应示例:**
```json
{
"types": [
"mirai",
"telegram",
"discord"
]
}
```
### 获取所有适配器
```http
GET /backend-api/api/im/adapters
```
获取所有已配置的 IM 适配器信息。
**响应示例:**
```json
{
"adapters": [
{
"name": "telegram",
"adapter": "telegram",
"config": {
"token": "your-bot-token",
},
"is_running": true
}
]
}
```
### 获取特定适配器
```http
GET /backend-api/api/im/adapters/{adapter_id}
```
获取指定适配器的详细信息。
**响应示例:**
```json
{
"adapter": {
"name": "telegram",
"adapter": "telegram",
"config": {
"token": "your-bot-token",
},
"is_running": true
}
}
```
### 创建适配器
```http
POST /backend-api/api/im/adapters
```
注册新的 IM 适配器。
**请求体:**
```json
{
"name": "telegram",
"adapter": "telegram",
"config": {
"token": "your-bot-token",
}
}
```
### 更新适配器
```http
PUT /backend-api/api/im/adapters/{adapter_id}
```
更新现有适配器的配置。如果适配器正在运行,会自动重启以应用新配置。
**请求体:**
```json
{
"name": "telegram",
"adapter": "telegram",
"config": {
"token": "your-bot-token",
}
}
```
### 删除适配器
```http
DELETE /backend-api/api/im/adapters/{adapter_id}
```
删除指定的适配器。如果适配器正在运行,会先自动停止。
### 启动适配器
```http
POST /backend-api/api/im/adapters/{adapter_id}/start
```
启动指定的适配器。
### 停止适配器
```http
POST /backend-api/api/im/adapters/{adapter_id}/stop
```
停止指定的适配器。
### 获取适配器配置模式
```http
GET /backend-api/api/im/types/{adapter_type}/config-schema
```
获取指定适配器类型的配置字段模式。
**响应示例:**
```json
{
"schema": {
"title": "TelegramConfig",
"type": "object",
"properties": {
"token": {
"title": "Bot Token",
"type": "string",
"description": "Telegram Bot Token"
}
},
"required": ["token"]
}
}
```
## 数据模型
### IMAdapterConfig
- `name`: 适配器名称
- `adapter`: 适配器类型
- `config`: 配置信息(字典)
### IMAdapterStatus
继承自 IMAdapterConfig,额外包含:
- `is_running`: 适配器是否正在运行
### IMAdapterList
- `adapters`: IM 适配器列表
### IMAdapterResponse
- `adapter`: 适配器信息
### IMAdapterTypes
- `types`: 可用的适配器类型列表
### IMAdapterConfigSchema
- `error`: 错误信息(可选)
- `configSchema`: JSON Schema 格式的配置字段描述
## 适配器类型
适配器由插件提供,见[适配器实现](../../../im/adapters)。
## 相关代码
- [IM 管理器](../../../im/manager.py)
- [IM 注册表](../../../im/im_registry.py)
- [适配器实现](../../../im/adapters)
## 错误处理
所有 API 端点在发生错误时都会返回适当的 HTTP 状态码和错误信息:
```json
{
"error": "错误描述信息"
}
```
常见状态码:
- 400: 请求参数错误、适配器类型无效或适配器已在运行
- 404: 适配器不存在
- 500: 服务器内部错误
## 使用示例
### 获取适配器类型
```python
import requests
response = requests.get(
'http://localhost:8080/backend-api/api/im/types',
headers={'Authorization': f'Bearer {token}'}
)
```
### 创建新适配器
```python
import requests
adapter_data = {
"name": "telegram",
"adapter": "telegram",
"config": {
"token": "your-bot-token",
"webhook_url": "https://example.com/webhook"
}
}
response = requests.post(
'http://localhost:8080/backend-api/api/im/adapters',
headers={'Authorization': f'Bearer {token}'},
json=adapter_data
)
```
### 启动适配器
```python
import requests
response = requests.post(
'http://localhost:8080/backend-api/api/im/adapters/telegram/start',
headers={'Authorization': f'Bearer {token}'}
)
```
## 相关文档
- [系统架构](../../README.md#系统架构-)
- [API 认证](../../README.md#api认证-)
- [IM 适配器开发](../../../im/README.md#适配器开发-)
================================================
FILE: kirara_ai/web/api/im/__init__.py
================================================
from .routes import im_bp
# Public API of the IM API package: only its blueprint.
__all__ = ["im_bp"]
================================================
FILE: kirara_ai/web/api/im/models.py
================================================
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from kirara_ai.config.global_config import IMConfig
from kirara_ai.im.im_registry import IMAdapterInfo
from kirara_ai.im.profile import UserProfile
# Request-body model for creating or updating an IM adapter: an alias of the
# global-config ``IMConfig`` entry, so API payloads and config entries share
# one schema.
IMAdapterConfig = IMConfig

# The Chinese class docstrings below are kept verbatim on purpose: pydantic
# uses a model's docstring as its JSON-schema "description", so they are
# runtime-visible text.


# Adapter configuration plus runtime state ("IM adapter status").
class IMAdapterStatus(IMAdapterConfig):
    """IM适配器状态"""
    # Whether the adapter is currently running.
    is_running: bool
    # Bot account profile; None unless the route fills it in.
    bot_profile: Optional[UserProfile] = None


# Response body listing all configured adapters ("IM adapter list response").
class IMAdapterList(BaseModel):
    """IM适配器列表响应"""
    adapters: List[IMAdapterStatus]


# Response body for a single adapter ("single IM adapter response").
class IMAdapterResponse(BaseModel):
    """单个IM适配器响应"""
    adapter: IMAdapterStatus


# Available adapter types ("list of available IM adapter types").
class IMAdapterTypes(BaseModel):
    """可用的IM适配器类型列表"""
    # Display names of the registered adapter types.
    types: List[str]
    # Registry entries, keyed as returned by IMRegistry.get_all_adapters()
    # (presumably by adapter type name — confirm).
    adapters: Dict[str, IMAdapterInfo]


# Config-schema response for an adapter type ("IM adapter config schema").
class IMAdapterConfigSchema(BaseModel):
    """IM适配器配置模式"""
    # Error message, set when the schema could not be produced.
    error: Optional[str] = None
    # JSON schema of the adapter's config class.
    configSchema: Optional[Dict[str, Any]] = None
================================================
FILE: kirara_ai/web/api/im/routes.py
================================================
import asyncio
from quart import Blueprint, g, jsonify, request
from kirara_ai.config.config_loader import CONFIG_FILE, ConfigJsonSchema, ConfigLoader
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.im.adapter import BotProfileAdapter
from kirara_ai.im.im_registry import IMRegistry
from kirara_ai.im.manager import IMManager
from kirara_ai.logger import get_logger
from ...auth.middleware import require_auth
from .models import (IMAdapterConfig, IMAdapterConfigSchema, IMAdapterList, IMAdapterResponse, IMAdapterStatus,
IMAdapterTypes)
# Blueprint for the IM adapter management routes defined below.
im_bp = Blueprint("im", __name__)
# Module-level logger for these routes.
logger = get_logger("Web.IM")
def _create_adapter(manager: IMManager, name: str, adapter: str, config: dict):
    """Instantiate an adapter of type *adapter*, named *name*, and register it.

    *config* is validated by constructing the adapter's own config class. The
    registry lookup fails if *adapter* is not registered, so callers check
    the type first.
    """
    adapter_registry: IMRegistry = g.container.resolve(IMRegistry)
    info = adapter_registry.get_all_adapters()[adapter]
    adapter_cls, config_cls = info.adapter_class, info.config_class
    manager.create_adapter(name, adapter_cls, config_cls(**config))
@im_bp.route("/types", methods=["GET"])
@require_auth
async def get_adapter_types():
    """Return the registered IM adapter types and their registry entries."""
    adapter_registry: IMRegistry = g.container.resolve(IMRegistry)
    all_adapters = adapter_registry.get_all_adapters()
    return IMAdapterTypes(
        types=[entry.name for entry in all_adapters.values()],
        adapters=all_adapters,
    ).model_dump()
@im_bp.route("/adapters", methods=["GET"])
@require_auth
async def list_adapters():
    """List every adapter in the global config together with its running state."""
    global_config = g.container.resolve(GlobalConfig)
    im_manager = g.container.resolve(IMManager)
    statuses = [
        IMAdapterStatus(
            name=entry.name,
            adapter=entry.adapter,
            is_running=im_manager.is_adapter_running(entry.name),
            config=entry.config,
        )
        for entry in global_config.ims
    ]
    return IMAdapterList(adapters=statuses).model_dump()
@im_bp.route("/adapters/<adapter_id>", methods=["GET"])
@require_auth
async def get_adapter(adapter_id: str):
    """Return the configuration and runtime status of one adapter.

    When the adapter is running and implements ``BotProfileAdapter``, its bot
    profile is fetched and included. Returns 404 if the adapter is unknown.
    """
    manager: IMManager = g.container.resolve(IMManager)
    if not manager.has_adapter(adapter_id):
        return jsonify({"error": "Adapter not found"}), 404
    adapter_config = manager.get_adapter_config(adapter_id)
    adapter = manager.get_adapter(adapter_id)
    # Read the running state once so the profile lookup and the reported
    # status cannot disagree.
    is_running = manager.is_adapter_running(adapter_id)
    bot_profile = None
    if is_running and isinstance(adapter, BotProfileAdapter):
        bot_profile = await adapter.get_bot_profile()
    return IMAdapterResponse(
        adapter=IMAdapterStatus(
            name=adapter_id,
            adapter=adapter_config.adapter,
            is_running=is_running,
            config=adapter_config.config,
            bot_profile=bot_profile
        )
    ).model_dump()
@im_bp.route("/adapters", methods=["POST"])
@require_auth
async def create_adapter():
    """Register a new IM adapter, optionally start it, and persist the config.

    Returns 400 for a malformed body, an unknown adapter type, a duplicate
    name, or an adapter config rejected while creating the adapter. Returns
    500 if starting the adapter fails; in that case the new adapter is
    deleted again. The response reports the adapter's actual running state.
    """
    data = await request.get_json()
    # Validate the body up front (as update_adapter does) so malformed
    # input yields 400 instead of an unhandled 500.
    if not isinstance(data, dict):
        return jsonify({"error": "Invalid request data: expected a JSON object"}), 400
    try:
        adapter_info = IMAdapterConfig(**data)
    except Exception as e:
        return jsonify({"error": f"Invalid request data: {e}"}), 400
    config: GlobalConfig = g.container.resolve(GlobalConfig)
    registry: IMRegistry = g.container.resolve(IMRegistry)
    manager: IMManager = g.container.resolve(IMManager)
    # The adapter type must be registered.
    if adapter_info.adapter not in registry.get_all_adapters():
        return jsonify({"error": "Invalid adapter type"}), 400
    # The name must be free.
    if manager.has_adapter(adapter_info.name):
        return jsonify({"error": "Adapter ID already exists"}), 400
    try:
        _create_adapter(manager, adapter_info.name, adapter_info.adapter, adapter_info.config)
    except Exception as e:
        return jsonify({"error": f"Invalid adapter config: {e}"}), 400
    if adapter_info.enable:
        try:
            await manager.start_adapter(adapter_info.name, asyncio.get_event_loop())
        except Exception as e:
            manager.delete_adapter(adapter_info.name)
            return jsonify({"error": str(e)}), 500
    config.ims.append(adapter_info)
    # Persist the updated configuration (with backup).
    ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
    return IMAdapterResponse(
        adapter=IMAdapterStatus(
            name=adapter_info.name,
            adapter=adapter_info.adapter,
            # Report the real state: the adapter may have just been started.
            is_running=manager.is_adapter_running(adapter_info.name),
            config=adapter_info.config,
        )
    ).model_dump()
@im_bp.route("/adapters/", methods=["PUT"])
@require_auth
async def update_adapter(adapter_id: str):
    """Update an adapter's configuration (renaming is supported).

    The old instance is stopped (if running) and removed from ``IMManager``. A new
    instance is created under ``adapter_info.name``, and the adapter's entry in
    ``config.ims`` is replaced. The new instance is started when ``enable`` is set,
    and the config file is saved whether or not the start succeeded.

    Returns:
        404 if ``adapter_id`` is unknown. 400 if the body is invalid, the new name
        is already taken, the adapter type is unknown, or the new adapter config
        cannot be built. Otherwise an ``IMAdapterResponse`` carrying the new name
        and the actual running state.
    """
    data = await request.get_json()
    try:
        adapter_info = IMAdapterConfig(**data)
    except Exception as e:
        return jsonify({"error": f"Invalid request data: {e}"}), 400

    config: GlobalConfig = g.container.resolve(GlobalConfig)
    manager: IMManager = g.container.resolve(IMManager)
    registry: IMRegistry = g.container.resolve(IMRegistry)
    loop = asyncio.get_event_loop()

    # 1. The adapter being updated must exist.
    if not manager.has_adapter(adapter_id):
        return jsonify({"error": "Adapter not found"}), 404

    # 2. A rename must not collide with another adapter.
    if adapter_id != adapter_info.name and manager.has_adapter(adapter_info.name):
        return jsonify({"error": f"Adapter name '{adapter_info.name}' already exists"}), 400

    # 3. The adapter type must be registered.
    all_adapters = registry.get_all_adapters()
    if adapter_info.adapter not in all_adapters:
        return jsonify({"error": "Invalid adapter type specified"}), 400

    # 4. Build the new config object before touching the running adapter. Otherwise
    #    an invalid config would only surface after the old adapter was deleted,
    #    leaving nothing in its place.
    try:
        all_adapters[adapter_info.adapter].config_class(**adapter_info.config)
    except Exception as e:
        return jsonify({"error": f"Invalid adapter config: {e}"}), 400

    # --- Stop the old adapter (if running) ---
    if manager.is_adapter_running(adapter_id):
        try:
            await manager.stop_adapter(adapter_id, loop)
        except Exception as e:
            logger.error(f"Error stopping adapter {adapter_id}: {e}")

    # --- Swap the instance in IMManager ---
    manager.delete_adapter(adapter_id)
    _create_adapter(manager, adapter_info.name, adapter_info.adapter, adapter_info.config)

    # --- Update config.ims ---
    # Replace the old entry in place. Appending unconditionally would leave the old
    # entry behind whenever it is still present, duplicating the adapter (or keeping
    # the old name after a rename) in the saved config.
    old_index = next(
        (i for i, im in enumerate(config.ims) if im.name == adapter_id), None
    )
    if old_index is None:
        config.ims.append(adapter_info)
    else:
        config.ims[old_index] = adapter_info

    # --- Try to start the new adapter (if enabled) ---
    is_now_running = False
    if adapter_info.enable:
        try:
            await manager.start_adapter(adapter_info.name, loop)
            is_now_running = True
        except Exception as e:
            logger.error(f"Failed to start adapter '{adapter_info.name}' after update: {e}")

    # --- Persist the config, whether or not the start succeeded ---
    ConfigLoader.save_config_with_backup(CONFIG_FILE, config)

    # --- Build the response ---
    bot_profile = None
    if is_now_running:  # only fetch the profile after a successful start
        adapter_instance = manager.get_adapter(adapter_info.name)
        if isinstance(adapter_instance, BotProfileAdapter):
            try:
                # Bound the wait so an unresponsive adapter cannot block the request.
                bot_profile = await asyncio.wait_for(adapter_instance.get_bot_profile(), timeout=5.0)
            except Exception as e:
                logger.error(f"Failed to get bot profile for {adapter_info.name} after update: {e}")

    return IMAdapterResponse(adapter=IMAdapterStatus(
        name=adapter_info.name,  # the new name
        adapter=adapter_info.adapter,
        is_running=is_now_running,  # the actual running state
        config=adapter_info.config,
        bot_profile=bot_profile
    )).model_dump()
@im_bp.route("/adapters/", methods=["DELETE"])
@require_auth
async def delete_adapter(adapter_id: str):
    """Stop (if running) and delete an adapter, then persist the config.

    Responds 404 when the adapter is neither registered in ``IMManager`` nor present
    in ``config.ims``. Any ``config.ims`` entry named ``adapter_id`` is removed
    explicitly, so the adapter is not written back to the config file even if
    ``IMManager.delete_adapter`` only drops the live instance.
    """
    config: GlobalConfig = g.container.resolve(GlobalConfig)
    manager: IMManager = g.container.resolve(IMManager)
    loop = asyncio.get_event_loop()

    in_manager = manager.has_adapter(adapter_id)
    in_config = any(im.name == adapter_id for im in config.ims)
    if not in_manager and not in_config:
        return jsonify({"error": "Adapter not found"}), 404

    # Stop the adapter first.
    if manager.is_adapter_running(adapter_id):
        await manager.stop_adapter(adapter_id, loop)

    if in_manager:
        manager.delete_adapter(adapter_id)
    # Remove it from the config as well; this is a no-op if delete_adapter already did so.
    config.ims[:] = [im for im in config.ims if im.name != adapter_id]

    # Persist the config to disk.
    ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
    return jsonify({"message": "Adapter deleted successfully"})
@im_bp.route("/adapters//start", methods=["POST"])
@require_auth
async def start_adapter(adapter_id: str):
    """Start a registered adapter.

    Returns:
        404 if the adapter is unknown (previously whatever the manager raised
        surfaced as a 500), 400 if it is already running, 500 with the error text
        if starting fails, otherwise a success message.
    """
    manager: IMManager = g.container.resolve(IMManager)
    loop = asyncio.get_event_loop()
    if not manager.has_adapter(adapter_id):
        return jsonify({"error": "Adapter not found"}), 404
    if manager.is_adapter_running(adapter_id):
        return jsonify({"error": "Adapter is already running"}), 400
    try:
        await manager.start_adapter(adapter_id, loop)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    return jsonify({"message": "Adapter started successfully"})
@im_bp.route("/adapters//stop", methods=["POST"])
@require_auth
async def stop_adapter(adapter_id: str):
    """Stop a running adapter.

    Returns 400 when the adapter is not running, 500 with the error text when
    stopping fails, and a success message otherwise.
    """
    im_manager: IMManager = g.container.resolve(IMManager)
    event_loop = asyncio.get_event_loop()
    if manager_running := im_manager.is_adapter_running(adapter_id):
        try:
            await im_manager.stop_adapter(adapter_id, event_loop)
            return jsonify({"message": "Adapter stopped successfully"})
        except Exception as exc:
            return jsonify({"error": str(exc)}), 500
    assert not manager_running
    return jsonify({"error": "Adapter is not running"}), 400
@im_bp.route("/types//config-schema", methods=["GET"])
@require_auth
async def get_adapter_config_schema(adapter_type: str):
    """Return the JSON schema of the config class for ``adapter_type``.

    A ``ValueError`` from the registry (unknown type) yields 404. Any other failure
    is reported in the ``error`` field of an ``IMAdapterConfigSchema`` body.
    """
    try:
        im_registry: IMRegistry = g.container.resolve(IMRegistry)
        try:
            schema_source = im_registry.get_config_class(adapter_type)
        except ValueError as lookup_error:
            return jsonify({"error": str(lookup_error)}), 404
        generated = schema_source.model_json_schema(schema_generator=ConfigJsonSchema)
        return IMAdapterConfigSchema(configSchema=generated).model_dump()
    except Exception as failure:
        return IMAdapterConfigSchema(error=str(failure)).model_dump()
================================================
FILE: kirara_ai/web/api/llm/README.md
================================================
# 大语言模型 API 🤖
大语言模型 API 提供了管理 LLM 后端和适配器的功能。通过这些 API,你可以注册、配置和管理不同的大语言模型服务。
## API 端点
### 获取适配器类型
```http
GET /backend-api/api/llm/types
```
获取所有可用的 LLM 适配器类型。
**响应示例:**
```json
{
"types": [
"openai",
"anthropic",
"azure",
"local"
]
}
```
### 获取所有后端
```http
GET /backend-api/api/llm/backends
```
获取所有已注册的 LLM 后端信息。
**响应示例:**
```json
{
"data": {
"backends": [
{
"name": "openai",
"adapter": "openai",
"config": {
"api_key": "sk-xxx",
"api_base": "https://api.openai.com/v1"
},
"enable": true,
"models": ["gpt-4", "gpt-3.5-turbo"]
}
]
}
}
```
### 获取特定后端
```http
GET /backend-api/api/llm/backends/{backend_name}
```
获取指定后端的详细信息。
**响应示例:**
```json
{
"data": {
"name": "anthropic",
"adapter": "anthropic",
"config": {
"api_key": "sk-xxx",
"api_base": "https://api.anthropic.com"
},
"enable": true,
"models": ["claude-3-opus", "claude-3-sonnet"]
}
}
```
### 创建后端
```http
POST /backend-api/api/llm/backends
```
注册新的 LLM 后端。
**请求体:**
```json
{
"name": "anthropic",
"adapter": "anthropic",
"config": {
"api_key": "sk-xxx",
"api_base": "https://api.anthropic.com"
},
"enable": true,
"models": ["claude-3-opus", "claude-3-sonnet"]
}
```
### 更新后端
```http
PUT /backend-api/api/llm/backends/{backend_name}
```
更新现有后端的配置。
**请求体:**
```json
{
"name": "anthropic",
"adapter": "anthropic",
"config": {
"api_key": "sk-xxx",
"api_base": "https://api.anthropic.com",
"temperature": 0.7
},
"enable": true,
"models": ["claude-3-opus", "claude-3-sonnet"]
}
```
### 删除后端
```http
DELETE /backend-api/api/llm/backends/{backend_name}
```
删除指定的后端。如果后端当前已启用,会先自动卸载。
### 获取适配器配置模式
```http
GET /backend-api/api/llm/types/{adapter_type}/config-schema
```
获取指定适配器类型的配置字段模式。
**响应示例:**
```json
{
"configSchema": {
"title": "OpenAIConfig",
"type": "object",
"properties": {
"api_key": {
"title": "API Key",
"type": "string",
"description": "OpenAI API密钥"
},
"api_base": {
"title": "API Base",
"type": "string",
"description": "API基础URL",
"default": "https://api.openai.com/v1"
},
"temperature": {
"title": "Temperature",
"type": "number",
"description": "生成温度",
"default": 0.7,
"minimum": 0,
"maximum": 2
}
},
"required": ["api_key"]
}
}
```
## 数据模型
### LLMBackendInfo
- `name`: 后端名称
- `adapter`: 适配器类型
- `config`: 配置信息(字典)
- `enable`: 是否启用
- `models`: 支持的模型列表
### LLMBackendList
- `backends`: LLM 后端列表
### LLMBackendResponse
- `error`: 错误信息(可选)
- `data`: 后端信息(可选)
### LLMBackendListResponse
- `error`: 错误信息(可选)
- `data`: 后端列表(可选)
### LLMAdapterTypes
- `types`: 可用的适配器类型列表
### LLMAdapterConfigSchema
- `error`: 错误信息(可选)
- `configSchema`: JSON Schema 格式的配置字段描述
## 适配器类型
适配器由插件提供,见[适配器实现](../../../llm/adapters)。
目前自带支持的适配器类型包括:
### OpenAI
- 适配器类型: `openai`
- 支持模型: gpt-4, gpt-3.5-turbo 等
- 配置项:
- `api_key`: API 密钥
- `api_base`: API 基础 URL
- `temperature`: 温度参数(可选)
### Anthropic
- 适配器类型: `anthropic`
- 支持模型: claude-3-opus, claude-3-sonnet 等
- 配置项:
- `api_key`: API 密钥
- `api_base`: API 基础 URL
- `temperature`: 温度参数(可选)
### Azure OpenAI
- 适配器类型: `azure`
- 支持 Azure OpenAI 服务部署的各类模型
- 配置项:
- `api_key`: API 密钥
- `api_base`: Azure 终结点
- `deployment_name`: 部署名称
### 本地模型
- 适配器类型: `local`
- 支持本地部署的开源模型
- 配置项:
- `model_path`: 模型路径
- `device`: 运行设备(cpu/cuda)
## 相关代码
- [LLM 管理器](../../../llm/llm_manager.py)
- [LLM 注册表](../../../llm/llm_registry.py)
- [适配器实现](../../../llm/adapters)
## 错误处理
所有 API 端点在发生错误时都会返回适当的 HTTP 状态码和错误信息:
```json
{
"error": "错误描述信息"
}
```
常见状态码:
- 400: 请求参数错误或后端配置无效
- 404: 后端不存在
- 500: 服务器内部错误
## 使用示例
### 获取适配器类型
```python
import requests
response = requests.get(
'http://localhost:8080/api/llm/types',
headers={'Authorization': f'Bearer {token}'}
)
```
### 创建新后端
```python
import requests
backend_data = {
"name": "anthropic",
"adapter": "anthropic",
"config": {
"api_key": "sk-xxx",
"api_base": "https://api.anthropic.com"
},
"enable": True,
"models": ["claude-3-opus", "claude-3-sonnet"]
}
response = requests.post(
'http://localhost:8080/api/llm/backends',
headers={'Authorization': f'Bearer {token}'},
json=backend_data
)
```
### 更新后端配置
```python
import requests
backend_data = {
"name": "anthropic",
"adapter": "anthropic",
"config": {
"api_key": "sk-xxx",
"api_base": "https://api.anthropic.com",
"temperature": 0.7
},
"enable": True,
"models": ["claude-3-opus", "claude-3-sonnet"]
}
response = requests.put(
'http://localhost:8080/api/llm/backends/anthropic',
headers={'Authorization': f'Bearer {token}'},
json=backend_data
)
```
## 相关文档
- [系统架构](../../README.md#系统架构-)
- [API 认证](../../README.md#api认证-)
- [LLM 适配器开发](../../../llm/README.md#适配器开发-)
================================================
FILE: kirara_ai/web/api/llm/__init__.py
================================================
from .routes import llm_bp
# Public API of the package: only the LLM management blueprint is exported.
__all__ = ["llm_bp"]
================================================
FILE: kirara_ai/web/api/llm/models.py
================================================
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from kirara_ai.config.global_config import LLMBackendConfig, ModelConfig
class LLMBackendInfo(LLMBackendConfig):
    """LLM backend information returned by the API.

    Declares no fields of its own; it exposes every field of ``LLMBackendConfig``.
    """
class LLMBackendList(BaseModel):
    """A list of LLM backends."""
    backends: List[LLMBackendInfo]
class LLMBackendResponse(BaseModel):
    """Response envelope for a single LLM backend."""
    # Error description; None on success.
    error: Optional[str] = None
    # The backend payload; None when there is nothing to return.
    data: Optional[LLMBackendInfo] = None
class LLMBackendListResponse(BaseModel):
    """Response envelope for a list of LLM backends."""
    # Error description; None on success.
    error: Optional[str] = None
    # The list payload; None when there is nothing to return.
    data: Optional[LLMBackendList] = None
class LLMBackendCreateRequest(LLMBackendConfig):
    """Request body for creating an LLM backend (same fields as ``LLMBackendConfig``)."""
class LLMBackendUpdateRequest(LLMBackendConfig):
    """Request body for updating an LLM backend (same fields as ``LLMBackendConfig``)."""
class LLMAdapterTypes(BaseModel):
    """The list of available LLM adapter type names."""
    types: List[str]
class LLMAdapterConfigSchema(BaseModel):
    """The config schema of an LLM adapter type."""
    # Error description; None on success.
    error: Optional[str] = None
    # JSON schema of the adapter's config class; serialized under the camelCase key "configSchema".
    configSchema: Optional[Dict[str, Any]] = None
class ModelConfigListResponse(BaseModel):
    """Response carrying a list of model configurations."""
    # Error description; None on success.
    error: Optional[str] = None
    # Model configurations; empty list by default.
    models: List[ModelConfig] = []
================================================
FILE: kirara_ai/web/api/llm/routes.py
================================================
from quart import Blueprint, g, jsonify, request
from kirara_ai.config.config_loader import CONFIG_FILE, ConfigLoader
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.llm.adapter import AutoDetectModelsProtocol
from kirara_ai.llm.llm_manager import LLMManager
from kirara_ai.llm.llm_registry import LLMBackendRegistry
from kirara_ai.llm.model_types import LLMAbility, ModelType
from kirara_ai.logger import get_logger
from kirara_ai.web.api.llm.models import (LLMAdapterConfigSchema, LLMAdapterTypes, LLMBackendCreateRequest,
LLMBackendInfo, LLMBackendList, LLMBackendListResponse, LLMBackendResponse,
LLMBackendUpdateRequest, ModelConfigListResponse)
from ...auth.middleware import require_auth
# Blueprint grouping the LLM backend management endpoints defined below.
llm_bp = Blueprint("llm", __name__)
# Logger used by the handlers in this module.
logger = get_logger("WebServer.LLM")
@llm_bp.route("/types", methods=["GET"])
@require_auth
async def get_adapter_types():
    """Return the names of every registered LLM adapter type."""
    adapter_types = g.container.resolve(LLMBackendRegistry).get_adapter_types()
    payload = LLMAdapterTypes(types=adapter_types)
    return payload.model_dump()
@llm_bp.route("/backends", methods=["GET"])
@require_auth
async def list_backends():
    """Return every configured LLM backend wrapped in an ``LLMBackendListResponse``.

    Unexpected errors are logged and reported as HTTP 500 with an ``error`` field.
    """
    try:
        global_config: GlobalConfig = g.container.resolve(GlobalConfig)
        infos = [
            LLMBackendInfo(
                name=entry.name,
                adapter=entry.adapter,
                config=entry.config,
                enable=entry.enable,
                models=entry.models,
            )
            for entry in global_config.llms.api_backends
        ]
        envelope = LLMBackendListResponse(data=LLMBackendList(backends=infos))
        return envelope.model_dump()
    except Exception as exc:
        logger.opt(exception=exc).error("Failed to list backends")
        return jsonify({"error": str(exc)}), 500
@llm_bp.route("/backends/", methods=["GET"])
@require_auth
async def get_backend(backend_name: str):
    """Return the configured backend named ``backend_name``.

    Responds 404 if no such backend exists. Unexpected errors are logged and
    reported as HTTP 500.
    """
    try:
        global_config: GlobalConfig = g.container.resolve(GlobalConfig)
        for found in global_config.llms.api_backends:
            if found.name == backend_name:
                break
        else:
            return jsonify({"error": f"Backend {backend_name} not found"}), 404
        info = LLMBackendInfo(
            name=found.name,
            adapter=found.adapter,
            config=found.config,
            enable=found.enable,
            models=found.models,
        )
        return LLMBackendResponse(data=info).model_dump()
    except Exception as exc:
        logger.opt(exception=exc).error("Failed to get backend")
        return jsonify({"error": str(exc)}), 500
@llm_bp.route("/backends", methods=["POST"])
@require_auth
async def create_backend():
    """Register a new LLM backend, load it when enabled and persist the config.

    Returns:
        The created backend as ``LLMBackendResponse``.
        400 if the body is not a valid ``LLMBackendCreateRequest`` (matching the IM
        routes) or a backend with the same name exists.
        500 (logged) for any other failure. If loading the backend fails, the new
        entry is removed from the in-memory config again, so it does not linger
        unsaved and get written out by a later save.
    """
    try:
        data = await request.get_json()
        try:
            request_data = LLMBackendCreateRequest(**data)
        except Exception as e:
            return jsonify({"error": f"Invalid request data: {e}"}), 400

        config: GlobalConfig = g.container.resolve(GlobalConfig)
        manager: LLMManager = g.container.resolve(LLMManager)

        # The backend name must be unique.
        if any(b.name == request_data.name for b in config.llms.api_backends):
            return (
                jsonify({"error": f"Backend {request_data.name} already exists"}),
                400,
            )

        backend = LLMBackendInfo(
            name=request_data.name,
            adapter=request_data.adapter,
            config=request_data.config,
            enable=request_data.enable,
            models=request_data.models,
        )

        # The entry is appended before loading because load_backend only receives the
        # name (presumably it resolves the backend from the config).
        config.llms.api_backends.append(backend)

        if backend.enable:
            try:
                manager.load_backend(backend.name)
            except Exception:
                # Roll back the in-memory config so it matches what is on disk.
                config.llms.api_backends.remove(backend)
                raise

        ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
        return LLMBackendResponse(data=backend).model_dump()
    except Exception as e:
        logger.opt(exception=e).error("Failed to create backend")
        return jsonify({"error": str(e)}), 500
@llm_bp.route("/backends/", methods=["PUT"])
@require_auth
async def update_backend(backend_name: str):
    """Replace the configuration of backend ``backend_name`` (renaming allowed).

    The old backend is unloaded if it was enabled, and its config entry is replaced
    in place. The new backend is loaded if enabled, and the config file is saved.

    Returns:
        The updated backend as ``LLMBackendResponse``.
        400 if the body is invalid, or if the rename targets the name of another
        existing backend (which previously produced two entries with one name).
        404 if ``backend_name`` is unknown.
        500 (logged) for any other failure.
    """
    try:
        data = await request.get_json()
        try:
            request_data = LLMBackendUpdateRequest(**data)
        except Exception as e:
            return jsonify({"error": f"Invalid request data: {e}"}), 400

        config: GlobalConfig = g.container.resolve(GlobalConfig)
        manager: LLMManager = g.container.resolve(LLMManager)

        # Locate the backend to update.
        backend_index = next(
            (
                i
                for i, b in enumerate(config.llms.api_backends)
                if b.name == backend_name
            ),
            -1,
        )
        if backend_index == -1:
            return jsonify({"error": f"Backend {backend_name} not found"}), 404

        # A rename must not collide with another backend.
        if request_data.name != backend_name and any(
            b.name == request_data.name for b in config.llms.api_backends
        ):
            return (
                jsonify({"error": f"Backend {request_data.name} already exists"}),
                400,
            )

        updated_backend = LLMBackendInfo(
            name=request_data.name,
            adapter=request_data.adapter,
            config=request_data.config,
            enable=request_data.enable,
            models=request_data.models,
        )

        # Unload the old backend first if it was enabled.
        if config.llms.api_backends[backend_index].enable:
            await manager.unload_backend(backend_name)

        config.llms.api_backends[backend_index] = updated_backend

        # Load the new backend if it is enabled.
        if updated_backend.enable:
            manager.load_backend(updated_backend.name)

        ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
        return LLMBackendResponse(data=updated_backend).model_dump()
    except Exception as e:
        logger.opt(exception=e).error("Failed to update backend")
        return jsonify({"error": str(e)}), 500
@llm_bp.route("/backends/", methods=["DELETE"])
@require_auth
async def delete_backend(backend_name: str):
    """Delete backend ``backend_name``, unloading it first if it is enabled.

    Responds 404 if the backend does not exist. On success, the config file is
    saved and the removed backend is returned. Unexpected errors are logged and
    reported as HTTP 500.
    """
    try:
        global_config: GlobalConfig = g.container.resolve(GlobalConfig)
        llm_manager: LLMManager = g.container.resolve(LLMManager)
        backends = global_config.llms.api_backends

        position = None
        for idx, entry in enumerate(backends):
            if entry.name == backend_name:
                position = idx
                break
        if position is None:
            return jsonify({"error": f"Backend {backend_name} not found"}), 404

        if backends[position].enable:
            await llm_manager.unload_backend(backend_name)

        removed = global_config.llms.api_backends.pop(position)
        ConfigLoader.save_config_with_backup(CONFIG_FILE, global_config)

        info = LLMBackendInfo(
            name=removed.name,
            adapter=removed.adapter,
            config=removed.config,
            enable=removed.enable,
            models=removed.models,
        )
        return LLMBackendResponse(data=info).model_dump()
    except Exception as exc:
        logger.opt(exception=exc).error("Failed to delete backend")
        return jsonify({"error": str(exc)}), 500
@llm_bp.route("/types//config-schema", methods=["GET"])
@require_auth
async def get_adapter_config_schema(adapter_type: str):
    """Return the JSON schema of the config class registered for ``adapter_type``.

    Responds 404 when the registry returns no config class. Other errors are logged
    and reported as HTTP 500.
    """
    try:
        llm_registry: LLMBackendRegistry = g.container.resolve(LLMBackendRegistry)
        schema_source = llm_registry.get_config_class(adapter_type)
        if schema_source:
            generated = schema_source.model_json_schema()
            return LLMAdapterConfigSchema(configSchema=generated).model_dump()
        return jsonify({"error": f"Adapter type {adapter_type} not found"}), 404
    except Exception as exc:
        logger.opt(exception=exc).error("Failed to get adapter config schema")
        return jsonify({"error": str(exc)}), 500
@llm_bp.route("/types//supports-auto-detect-models", methods=["GET"])
@require_auth
async def supports_auto_detect_models(adapter_type: str):
    """Report whether adapter type ``adapter_type`` can auto-detect its models.

    Returns ``{"supportsAutoDetectModels": true}`` when the adapter class
    subclasses ``AutoDetectModelsProtocol``, 400 when it does not, 404 when the type
    is unknown, and 500 (logged) on unexpected errors.
    """
    try:
        llm_registry: LLMBackendRegistry = g.container.resolve(LLMBackendRegistry)
        candidate = llm_registry.get(adapter_type)
        if not candidate:
            return jsonify({"error": f"Adapter type {adapter_type} not found"}), 404
        if issubclass(candidate, AutoDetectModelsProtocol):
            return jsonify({"supportsAutoDetectModels": True})
        reason = f"Adapter type {adapter_type} does not support auto-detect models"
        return jsonify({"error": reason}), 400
    except Exception as exc:
        logger.opt(exception=exc).error("Failed to check if adapter supports auto-detect models")
        return jsonify({"error": str(exc)}), 500
@llm_bp.route("/backends//auto-detect-models", methods=["GET"])
@require_auth
async def auto_detect_models(backend_name: str):
    """Ask backend ``backend_name`` to detect its models and return the ModelConfig list.

    LLM-type models that come back without an ability are given
    ``LLMAbility.TextChat``. Responds 404 for an unknown backend, 400 when the
    adapter does not implement ``AutoDetectModelsProtocol``, and 500 (logged) on
    unexpected errors.
    """
    try:
        llm_manager: LLMManager = g.container.resolve(LLMManager)
        backend_adapter = llm_manager.get(backend_name)
        if not backend_adapter:
            return jsonify({"error": f"Backend {backend_name} not found"}), 404
        if not isinstance(backend_adapter, AutoDetectModelsProtocol):
            reason = f"Backend {backend_name} does not support auto-detect models"
            return jsonify({"error": reason}), 400

        detected = await backend_adapter.auto_detect_models()
        for model_config in detected:
            lacks_ability = model_config.type == ModelType.LLM.value and not model_config.ability
            if lacks_ability:
                # Default LLMs without a declared ability to plain text chat.
                model_config.ability = LLMAbility.TextChat.value
        return ModelConfigListResponse(models=detected).model_dump()
    except Exception as exc:
        logger.opt(exception=exc).error("Failed to auto-detect models")
        return jsonify({"error": str(exc)}), 500
================================================
FILE: kirara_ai/web/api/mcp/README.md
================================================
# MCP 服务器管理 API
MCP(Model Context Protocol)是一种用于大型语言模型的通信协议。本 API 提供了管理 MCP 服务器的功能,包括创建、更新、删除、启动和停止服务器,以及获取服务器提供的工具列表。
## 数据模型
### MCP 服务器(MCPServer)
```json
{
"id": "claude-mcp",
"description": "Claude MCP 服务器",
"connection_type": "stdio",
"command": "python",
"args": "-m claude_cli.mcp",
"url": null,
"connection_state": "disconnected"
}
```
### 连接类型(ConnectionType)
- `stdio`: 标准输入输出连接
- `sse`: 服务器发送事件连接
### 连接状态(connection_state)
- `connected`: 已连接
- `disconnected`: 未连接
- `error`: 错误状态
## API 端点
### 获取服务器列表
获取 MCP 服务器列表,支持分页和过滤。
**请求**:
```
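GET /mcp/servers?page=1&page_size=20
```
**查询参数**(均为可选):
- `page`: 页码,从 1 开始,默认 1
- `page_size`: 每页数量,默认 20
- `connection_type`: 按连接类型过滤(`stdio` / `sse`)
- `status`: 按连接状态过滤(`connected` / `disconnected` / `error`)
- `query`: 按服务器 ID 或启动命令进行不区分大小写的模糊匹配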
**响应**:
```json
{
"items": [
{
"id": "claude-mcp",
"description": "Claude MCP 服务器",
"connection_type": "stdio",
"command": "python",
"args": "-m claude_cli.mcp",
"url": null,
"connection_state": "disconnected"
}
],
"total": 1,
"page": 1,
"page_size": 20,
"total_pages": 1
}
```
### 获取统计信息
获取 MCP 服务器相关的统计信息。
**请求**:
```
GET /mcp/statistics
```
**响应**:
```json
{
"total_servers": 3,
"stdio_servers": 2,
"sse_servers": 1,
"connected_servers": 1,
"disconnected_servers": 2,
"error_servers": 0,
"total_tools": 5
}
```
### 获取服务器详情
获取特定 MCP 服务器的详细信息。
**请求**:
```
GET /mcp/servers/{server_id}
```
**响应**:
```json
{
"id": "claude-mcp",
"description": "Claude MCP 服务器",
"connection_type": "stdio",
"command": "python",
"args": "-m claude_cli.mcp",
"url": null,
"connection_state": "disconnected"
}
```
### 获取服务器工具列表
获取 MCP 服务器提供的工具列表。服务器未连接时返回空列表。
**请求**:
```
GET /mcp/servers/{server_id}/tools
```
**响应**:
```json
[
{
"name": "search_web",
"description": "搜索网络获取信息",
"input_schema": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索查询"
}
}
}
},
{
"name": "analyze_image",
"description": "分析图像内容",
"input_schema": {
"type": "object",
"properties": {
"image_url": {
"type": "string",
"description": "图像URL"
},
"detect_faces": {
"type": "boolean",
"description": "是否检测人脸"
}
}
}
}
]
```
### 检查服务器 ID 是否可用
检查给定的服务器 ID 是否已存在。
**请求**:
```
GET /mcp/servers/check/{server_id}
```
**响应**:
返回 JSON 对象 `{"is_available": true}`:`true` 表示 ID 可用,`false` 表示 ID 已存在。
### 创建服务器
创建新的 MCP 服务器。
**请求**:
```
POST /mcp/servers
```
**请求体**:
```json
{
"id": "claude-mcp",
"description": "Claude MCP 服务器",
"command": "python",
"args": "-m claude_cli.mcp",
"connection_type": "stdio"
}
```
**响应**:
返回创建的服务器详情。
### 更新服务器
更新现有 MCP 服务器的配置。
**请求**:
```
PUT /mcp/servers/{server_id}
```
**请求体**:
```json
{
"description": "Updated description",
"command": "python3",
"args": "-m claude_cli.mcp --verbose",
"connection_type": "stdio"
}
```
**响应**:
返回更新后的服务器详情。
### 删除服务器
删除 MCP 服务器。
**请求**:
```
DELETE /mcp/servers/{server_id}
```
**响应**:
```json
{}
```
### 启动服务器
启动 MCP 服务器。
**请求**:
```
POST /mcp/servers/{server_id}/start
```
**响应**:
```json
{}
```
### 停止服务器
停止 MCP 服务器。
**请求**:
```
POST /mcp/servers/{server_id}/stop
```
**响应**:
```json
{}
```
## 使用示例
### 创建并启动 MCP 服务器
1. 创建服务器:
```
POST /mcp/servers
```
```json
{
"id": "claude-mcp",
"description": "Claude MCP 服务器",
"command": "python",
"args": "-m claude_cli.mcp",
"connection_type": "stdio"
}
```
2. 启动服务器:
```
POST /mcp/servers/claude-mcp/start
```
3. 获取服务器提供的工具:
```
GET /mcp/servers/claude-mcp/tools
```
## 注意事项
1. 在更新服务器配置前,必须先停止正在运行的服务器。
2. 服务器 ID 必须是唯一的,可以使用 `/mcp/servers/check/{server_id}` 接口检查 ID 是否可用。
3. 创建服务器后不会自动连接,需要调用启动接口(`POST /mcp/servers/{server_id}/start`)手动连接。
================================================
FILE: kirara_ai/web/api/mcp/__init__.py
================================================
from .routes import mcp_bp
# Public API of the package: only the MCP management blueprint is exported.
__all__ = ["mcp_bp"]
================================================
FILE: kirara_ai/web/api/mcp/models.py
================================================
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class MCPServerInfo(BaseModel):
    """MCP server information returned by the API."""
    id: str
    description: Optional[str] = None
    connection_type: str
    command: Optional[str] = None
    # Command-line arguments rendered as a single string.
    args: str = Field(default="")
    url: Optional[str] = None
    # Lower-cased name of the server's connection state.
    connection_state: str
class MCPToolInfo(BaseModel):
    """Information about one tool exposed by an MCP server."""
    name: str
    description: Optional[str] = None
    # The tool's input schema; empty dict by default.
    input_schema: Dict[str, Any] = Field(default_factory=dict)
class MCPServerList(BaseModel):
    """One page of MCP servers plus pagination metadata."""
    items: List[MCPServerInfo]
    # Number of servers matching the filters, across all pages.
    total: int
    page: int
    page_size: int
    total_pages: int
class MCPStatistics(BaseModel):
    """Aggregate counts over all MCP servers and their tools."""
    total_servers: int
    stdio_servers: int
    sse_servers: int
    connected_servers: int
    disconnected_servers: int
    error_servers: int
    total_tools: int
class MCPServerCreateRequest(BaseModel):
    """Request body for creating an MCP server."""
    id: str
    description: Optional[str] = None
    command: str
    # Command-line arguments as one string.
    args: str
    connection_type: str
class MCPServerUpdateRequest(BaseModel):
    """Request body for updating an MCP server; optional fields may be omitted."""
    description: Optional[str] = None
    command: Optional[str] = None
    # Defaults to "" rather than None, so an omitted value cannot be detected with an
    # `is not None` check; check `model_fields_set` to see whether the client sent it.
    args: str = Field(default="")
    connection_type: Optional[str] = None
    url: Optional[str] = None
    headers: Optional[Dict[str, str]] = None
    env: Optional[Dict[str, str]] = None
================================================
FILE: kirara_ai/web/api/mcp/routes.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Any, Dict, Optional, cast
from quart import Blueprint, g, jsonify, request
from kirara_ai.config.config_loader import CONFIG_FILE, ConfigLoader
from kirara_ai.config.global_config import GlobalConfig, MCPServerConfig
from kirara_ai.logger import get_logger
from kirara_ai.mcp_module import MCPConnectionState, MCPServer, MCPServerManager
from ...auth.middleware import require_auth
from .models import (MCPServerCreateRequest, MCPServerInfo, MCPServerList, MCPServerUpdateRequest, MCPStatistics,
MCPToolInfo)
# Create the blueprint grouping the MCP server management endpoints defined below.
mcp_bp = Blueprint("mcp", __name__)
# Logger used by the handlers in this module.
logger = get_logger("WebServer.MCP")
def _convert_to_server_info(server: MCPServer) -> MCPServerInfo:
    """Build the API representation (``MCPServerInfo``) of ``server``.

    List-valued ``args`` are joined with single spaces; any other value is passed
    through unchanged. ``url`` falls back to ``None`` when the config object has no
    such attribute, and the connection state is reported as its lower-cased name.
    """
    cfg = server.server_config
    raw_args = cfg.args
    if isinstance(raw_args, list):
        args_text = " ".join(raw_args)
    else:
        args_text = raw_args
    return MCPServerInfo(
        id=cfg.id,
        description=cfg.description,
        connection_type=cfg.connection_type,
        command=cfg.command,
        args=args_text,
        url=getattr(cfg, 'url', None),
        connection_state=server.state.name.lower(),
    )
@mcp_bp.route("/servers", methods=["GET"])
@require_auth
async def list_servers():
    """List MCP servers with optional filtering and pagination.

    Query parameters:
        page: 1-based page number (default 1; values below 1 are treated as 1).
        page_size: items per page (default 20; values below 1 are treated as 1,
            where 0 previously raised ZeroDivisionError and returned a 500).
        connection_type: keep only servers with this connection type.
        status: ``connected``, ``disconnected`` or ``error`` keeps only servers in
            that state; any other value does not filter.
        query: case-insensitive substring matched against the server id or command.

    Returns an ``MCPServerList`` dict, or HTTP 500 with ``message`` on error.
    """
    try:
        page = max(request.args.get('page', 1, type=int), 1)
        page_size = max(request.args.get('page_size', 20, type=int), 1)
        connection_type = request.args.get('connection_type')
        status = request.args.get('status')
        query = request.args.get('query')
        needle = query.lower() if query else None

        manager: MCPServerManager = g.container.resolve(MCPServerManager)

        server_list = []
        for server_id, server in manager.get_all_servers().items():
            server_config = server.server_config
            if connection_type and server_config.connection_type != connection_type:
                continue
            if status in ('connected', 'disconnected', 'error') and server.state.name.lower() != status:
                continue
            if needle and needle not in server_id.lower() and not (
                server_config.command and needle in server_config.command.lower()
            ):
                continue
            server_list.append(_convert_to_server_info(server))

        # Pagination
        total = len(server_list)
        total_pages = (total + page_size - 1) // page_size
        start_idx = (page - 1) * page_size
        return MCPServerList(
            items=server_list[start_idx:start_idx + page_size],
            total=total,
            page=page,
            page_size=page_size,
            total_pages=total_pages
        ).model_dump()
    except Exception as e:
        logger.opt(exception=e).error("获取MCP服务器列表失败")
        return jsonify({"message": str(e)}), 500
@mcp_bp.route("/statistics", methods=["GET"])
@require_auth
async def get_statistics():
    """Return aggregate MCP server counts plus the total number of tools.

    Missing keys in the manager's statistics default to 0. Unexpected errors are
    logged and reported as HTTP 500 with a ``message`` field.
    """
    try:
        server_manager: MCPServerManager = g.container.resolve(MCPServerManager)
        raw_stats = server_manager.get_statistics()
        field_to_key = {
            "total_servers": "total",
            "stdio_servers": "stdio",
            "sse_servers": "sse",
            "connected_servers": "connected",
            "disconnected_servers": "disconnected",
            "error_servers": "error",
        }
        payload = {field: raw_stats.get(key, 0) for field, key in field_to_key.items()}
        payload["total_tools"] = len(server_manager.get_tools())
        return MCPStatistics(**payload).model_dump()
    except Exception as exc:
        logger.opt(exception=exc).error("获取MCP统计信息失败")
        return jsonify({"message": str(exc)}), 500
@mcp_bp.route("/servers/", methods=["GET"])
@require_auth
async def get_server(server_id: str):
    """Return the ``MCPServerInfo`` of server ``server_id``, or 404 if it is unknown.

    Unexpected errors are logged and reported as HTTP 500 with a ``message`` field.
    """
    try:
        server_manager: MCPServerManager = g.container.resolve(MCPServerManager)
        found = server_manager.get_server(server_id)
        if found:
            return _convert_to_server_info(found).model_dump()
        return jsonify({"message": f"服务器 {server_id} 不存在"}), 404
    except Exception as exc:
        logger.opt(exception=exc).error(f"获取MCP服务器 {server_id} 详情失败")
        return jsonify({"message": str(exc)}), 500
@mcp_bp.route("/servers//tools", methods=["GET"])
@require_auth
async def get_server_tools(server_id: str):
    """Return the tools exposed by server ``server_id`` as a list of dicts.

    Responds 404 for an unknown server and returns an empty list when the server is
    not connected. Unexpected errors are logged and reported as HTTP 500.
    """
    try:
        server_manager: MCPServerManager = g.container.resolve(MCPServerManager)
        found = server_manager.get_server(server_id)
        if not found:
            return jsonify({"message": f"服务器 {server_id} 不存在"}), 404
        if found.state != MCPConnectionState.CONNECTED:
            return []

        infos = [
            MCPToolInfo(
                name=entry.original_name,
                description=entry.tool_info.description,
                input_schema=entry.tool_info.inputSchema,
            )
            for _key, entry in server_manager.get_tools().items()
            if entry.server_id == server_id
        ]
        return [info.model_dump() for info in infos]
    except Exception as exc:
        logger.opt(exception=exc).error(f"获取MCP服务器 {server_id} 工具列表失败")
        return jsonify({"message": str(exc)}), 500
@mcp_bp.route("/servers/check/", methods=["GET"])
@require_auth
async def check_server_id(server_id: str):
    """Return ``{"is_available": bool}`` telling whether ``server_id`` can be used.

    Unexpected errors are logged and reported as HTTP 500 with a ``message`` field.
    """
    try:
        server_manager: MCPServerManager = g.container.resolve(MCPServerManager)
        available = server_manager.is_server_id_available(server_id)
        return jsonify({"is_available": available})
    except Exception as exc:
        logger.opt(exception=exc).error(f"检查服务器ID {server_id} 可用性失败")
        return jsonify({"message": str(exc)}), 500
@mcp_bp.route("/servers", methods=["POST"])
@require_auth
async def create_server():
    """Create a new MCP server, save it to the config and load it into the manager.

    This handler calls ``load_server`` only; it does not call ``connect_server``.
    Responds 409 if the id is not available and 500 (logged) on any other error;
    otherwise returns the new server's ``MCPServerInfo``.
    """
    try:
        data = await request.get_json()
        request_data = MCPServerCreateRequest(**data)

        config: GlobalConfig = g.container.resolve(GlobalConfig)
        manager: MCPServerManager = g.container.resolve(MCPServerManager)

        # The id must not be taken.
        if not manager.is_server_id_available(request_data.id):
            return jsonify({"message": f"服务器ID '{request_data.id}' 已存在或服务器正在运行"}), 409

        new_server_config = MCPServerConfig(
            id=request_data.id,
            description=request_data.description or "",
            command=request_data.command,
            # split() without a separator drops empty tokens. split(" ") turned "" into
            # [""] and doubled spaces into empty-string arguments passed to the command.
            args=request_data.args.split(),
            connection_type=request_data.connection_type,
            enable=True
        )

        config.mcp.servers.append(new_server_config)
        ConfigLoader.save_config_with_backup(CONFIG_FILE, config)

        # Let the manager load the new server.
        server = manager.load_server(new_server_config)
        return _convert_to_server_info(server).model_dump()
    except Exception as e:
        logger.opt(exception=e).error("创建MCP服务器失败")
        return jsonify({"message": str(e)}), 500
@mcp_bp.route("/servers/", methods=["PUT"])
@require_auth
async def update_server(server_id: str):
    """Update the config of MCP server ``server_id``, then reload and reconnect it.

    Only fields present in the request are applied. ``args`` is applied only when
    the client actually sent it: the request model defaults it to "", so the former
    ``is not None`` check always passed and an omitted ``args`` wiped the stored
    arguments.

    Returns:
        404 if no config entry has id ``server_id``.
        409 if the server is currently connected.
        500 (logged) if reconnecting fails or on any other error.
        Otherwise the reloaded server's ``MCPServerInfo``.
    """
    try:
        data = await request.get_json()
        request_data = MCPServerUpdateRequest(**data)

        config: GlobalConfig = g.container.resolve(GlobalConfig)
        manager: MCPServerManager = g.container.resolve(MCPServerManager)

        server_config = next(
            (server for server in config.mcp.servers if server.id == server_id), None
        )
        if server_config is None:
            return jsonify({"message": f"服务器 '{server_id}' 不存在"}), 404

        # A connected server must be stopped before it can be updated.
        current_server = manager.get_server(server_id)
        if current_server and current_server.state == MCPConnectionState.CONNECTED:
            return jsonify({"message": "无法更新正在运行的服务器,请先停止服务器"}), 409

        if request_data.description is not None:
            server_config.description = request_data.description
        if request_data.command is not None:
            server_config.command = request_data.command
        if "args" in request_data.model_fields_set:
            # split() without a separator yields [] for "" and drops empty tokens.
            server_config.args = request_data.args.split()
        if request_data.connection_type is not None:
            server_config.connection_type = request_data.connection_type
        if request_data.url is not None:
            server_config.url = request_data.url
        if request_data.headers is not None:
            server_config.headers = request_data.headers
        if request_data.env is not None:
            server_config.env = request_data.env

        ConfigLoader.save_config_with_backup(CONFIG_FILE, config)

        # Stop, reload with the new config and reconnect.
        await manager.stop_server(server_id)
        current_server = manager.load_server(server_config)
        try:
            await manager.connect_server(server_id)
        except Exception as e:
            logger.opt(exception=e).error(f"重新连接MCP服务器 {server_id} 失败")
            return jsonify({"message": str(e)}), 500

        return _convert_to_server_info(current_server).model_dump()
    except Exception as e:
        logger.opt(exception=e).error(f"更新MCP服务器 {server_id} 失败")
        return jsonify({"message": str(e)}), 500
@mcp_bp.route("/servers/<server_id>", methods=["DELETE"])
@require_auth
async def delete_server(server_id: str):
    """Delete an MCP server from the configuration and stop it.

    Args:
        server_id: Id of the MCP server, taken from the URL path.

    Returns:
        An empty JSON object on success, 404 if the server is not configured,
        or 500 with ``{"message": ...}`` on any other failure.
    """
    try:
        config: GlobalConfig = g.container.resolve(GlobalConfig)
        manager: MCPServerManager = g.container.resolve(MCPServerManager)

        # Locate the server in the persisted configuration.
        server_index = -1
        for i, server in enumerate(config.mcp.servers):
            if server.id == server_id:
                server_index = i
                break
        if server_index == -1:
            return jsonify({"message": f"服务器 '{server_id}' 不存在"}), 404

        # Stop a connected server before removing its configuration.
        current_server = manager.get_server(server_id)
        if current_server and current_server.state == MCPConnectionState.CONNECTED:
            await manager.stop_server(server_id)

        config.mcp.servers.pop(server_index)
        ConfigLoader.save_config_with_backup(CONFIG_FILE, config)

        # NOTE(review): stop is also called unconditionally here, presumably to clean
        # up servers that were not CONNECTED. For a connected server this is a second
        # call; confirm stop_server tolerates it.
        await manager.stop_server(server_id)
        return jsonify({})
    except Exception as e:
        logger.opt(exception=e).error(f"删除MCP服务器 {server_id} 失败")
        return jsonify({"message": str(e)}), 500
@mcp_bp.route("/servers/<server_id>/start", methods=["POST"])
@require_auth
async def start_server(server_id: str):
    """Connect an MCP server.

    Args:
        server_id: Id of the MCP server, taken from the URL path.

    Returns:
        An empty JSON object on success. Returns 404 when the manager reports
        failure (unknown server or connection failed), and 500 with
        ``{"message": ...}`` if connecting raises.
    """
    try:
        manager: MCPServerManager = g.container.resolve(MCPServerManager)
        success = await manager.connect_server(server_id)
        if not success:
            return jsonify({"message": f"服务器 '{server_id}' 不存在或无法连接"}), 404
        return jsonify({})
    except Exception as e:
        logger.opt(exception=e).error(f"连接 MCP 服务器 {server_id} 失败")
        return jsonify({"message": str(e)}), 500
@mcp_bp.route("/servers/<server_id>/stop", methods=["POST"])
@require_auth
async def stop_server(server_id: str):
    """Disconnect an MCP server.

    Args:
        server_id: Id of the MCP server, taken from the URL path.

    Returns:
        An empty JSON object on success. Returns 404 when the manager reports
        failure (unknown or not connected), and 500 with ``{"message": ...}``
        if stopping raises.
    """
    try:
        manager: MCPServerManager = g.container.resolve(MCPServerManager)
        success = await manager.stop_server(server_id)
        if not success:
            return jsonify({"message": f"服务器 '{server_id}' 不存在或未连接"}), 404
        return jsonify({})
    except Exception as e:
        logger.opt(exception=e).error(f"断开 MCP 服务器 {server_id} 失败")
        return jsonify({"message": str(e)}), 500
@mcp_bp.route("/tools", methods=["GET"])
@require_auth
async def get_all_tools():
    """List every tool exposed by the MCP server manager.

    Returns:
        A list of ``MCPToolInfo`` dicts (name, description, input_schema), or
        500 with ``{"message": ...}`` on failure.
    """
    try:
        manager: MCPServerManager = g.container.resolve(MCPServerManager)
        # Each entry wraps the MCP tool definition in its ``tool_info`` attribute.
        tool_list = [
            MCPToolInfo(
                name=tool_name,
                description=entry.tool_info.description,
                input_schema=entry.tool_info.inputSchema,
            )
            for tool_name, entry in manager.get_tools().items()
        ]
        return [tool.model_dump() for tool in tool_list]
    except Exception as e:
        logger.opt(exception=e).error("获取所有工具失败")
        return jsonify({"message": str(e)}), 500
@mcp_bp.route("/servers/<server_id>/tools/call", methods=["POST"])
@require_auth
async def call_tool(server_id: str):
    """Call a tool on an MCP server.

    Expects a JSON body of the form ``{"toolName": str, "params": dict}``.

    Args:
        server_id: Id of the MCP server, taken from the URL path.

    Returns:
        ``{"result": ...}`` with the dumped tool result on success. Returns 400
        when ``toolName`` is missing, 404 when the server is unknown, and 500
        with ``{"message": ...}`` on any other failure.
    """
    # Bound before the try block because the error handler logs it. Previously a
    # failure while reading the body raised UnboundLocalError inside the handler.
    tool_name: Optional[str] = None
    try:
        data: Dict[str, str | Dict[str, Any]] = await request.get_json() or {}
        tool_name = cast(Optional[str], data.get("toolName"))
        if not tool_name:
            return jsonify({"message": "缺少 toolName 参数"}), 400
        params: Dict[str, Any] = cast(Dict[str, Any], data.get("params"))

        manager: MCPServerManager = g.container.resolve(MCPServerManager)
        server: Optional[MCPServer] = manager.get_server(server_id)
        if not server:
            return jsonify({"message": f"服务器 '{server_id}' 不存在"}), 404

        result = await server.call_tool(tool_name, params)
        return jsonify({"result": result.model_dump()})
    except Exception as e:
        logger.opt(exception=e).error(f"调用工具 {tool_name} 失败")
        return jsonify({"message": str(e)}), 500
@mcp_bp.route("/servers/<server_id>/prompts", methods=["GET"])
@require_auth
async def get_server_prompts(server_id: str):
    """List the prompts provided by an MCP server.

    Args:
        server_id: Id of the MCP server, taken from the URL path.

    Returns:
        The prompt list as JSON. Returns 404 when the server is unknown or the
        manager returns None (treated as "not connected"), and 500 on failure.
    """
    try:
        manager: MCPServerManager = g.container.resolve(MCPServerManager)
        server = manager.get_server(server_id)
        if not server:
            return jsonify({"message": f"服务器 {server_id} 不存在"}), 404
        prompts = await manager.get_prompt_list(server_id)
        if prompts is None:
            return jsonify({"message": f"服务器 {server_id} 未连接"}), 404
        return jsonify(prompts)
    except Exception as e:
        logger.opt(exception=e).error(f"获取MCP服务器 {server_id} 提示词列表失败")
        return jsonify({"message": str(e)}), 500
@mcp_bp.route("/servers/<server_id>/resources", methods=["GET"])
@require_auth
async def get_server_resources(server_id: str):
    """List the resources provided by an MCP server.

    Args:
        server_id: Id of the MCP server, taken from the URL path.

    Returns:
        The resource list as JSON. Returns 404 when the server is unknown or
        the manager returns None (treated as "not connected"), and 500 on failure.
    """
    try:
        manager: MCPServerManager = g.container.resolve(MCPServerManager)
        server = manager.get_server(server_id)
        if not server:
            return jsonify({"message": f"服务器 {server_id} 不存在"}), 404
        resources = await manager.get_resource_list(server_id)
        if resources is None:
            return jsonify({"message": f"服务器 {server_id} 未连接"}), 404
        return jsonify(resources)
    except Exception as e:
        logger.opt(exception=e).error(f"获取MCP服务器 {server_id} 资源列表失败")
        return jsonify({"message": str(e)}), 500
================================================
FILE: kirara_ai/web/api/media/README.md
================================================
# 媒体管理API
本模块提供了媒体文件管理的API,包括上传、查询、下载和删除媒体文件。
## API接口
### 获取媒体列表
```
POST /api/media/list
```
请求体:
```json
{
"query": "搜索关键词",
"content_type": "image/jpeg",
"start_date": "2023-01-01T00:00:00Z",
"end_date": "2023-12-31T23:59:59Z",
"tags": ["tag1", "tag2"],
"page": 1,
"page_size": 20
}
```
响应:
```json
{
"items": [
{
"id": "media_id",
"url": "/api/media/file/filename.jpg",
"thumbnail_url": "/api/media/thumbnails/media_id.jpg",
"metadata": {
"filename": "filename.jpg",
"content_type": "image/jpeg",
"size": 12345,
"width": 800,
"height": 600,
"upload_time": "2023-01-01T12:00:00Z",
"source": "qq_group",
"uploader": "user1",
"tags": ["tag1", "tag2"]
}
}
],
"total": 100,
"has_more": true
}
```
### 获取媒体文件
```
GET /api/media/file/{media_id}
```
返回媒体文件的二进制内容。
### 获取缩略图
```
GET /api/media/preview/{media_id}
```
返回媒体文件的预览:图片返回 WebP 格式的缩略图(GIF 原样返回),视频返回原始数据,其他类型返回 400。
### 上传媒体文件
```
POST /api/media/upload
```
使用 multipart/form-data 上传文件。
可选参数:
- source: 来源
- uploader: 上传者
- tags: 标签(逗号分隔)
响应:
```json
{
"id": "media_id",
"url": "/api/media/file/filename.jpg",
"thumbnail_url": "/api/media/thumbnails/media_id.jpg",
"metadata": {
"filename": "filename.jpg",
"content_type": "image/jpeg",
"size": 12345,
"width": 800,
"height": 600,
"upload_time": "2023-01-01T12:00:00Z",
"source": "qq_group",
"uploader": "user1",
"tags": ["tag1", "tag2"]
}
}
```
### 删除媒体文件
```
DELETE /api/media/delete/{media_id}
```
响应:
```json
{
"success": true
}
```
### 批量删除媒体文件
```
POST /api/media/batch-delete
```
请求体:
```json
{
"ids": ["media_id1", "media_id2"]
}
```
响应:
```json
{
"success": true,
"deleted_count": 2
}
```
## 数据模型
### MediaMetadata
媒体文件的元数据。
| 字段名 | 类型 | 说明 |
|-------|------|------|
| filename | string | 文件名 |
| content_type | string | 内容类型 |
| size | integer | 文件大小(字节) |
| width | integer | 图片宽度(可选) |
| height | integer | 图片高度(可选) |
| upload_time | datetime | 上传时间 |
| source | string | 来源(可选) |
| uploader | string | 上传者(可选) |
| tags | array | 标签列表 |
### MediaItem
媒体文件项。
| 字段名 | 类型 | 说明 |
|-------|------|------|
| id | string | 媒体ID |
| url | string | 媒体URL |
| thumbnail_url | string | 缩略图URL(可选) |
| metadata | MediaMetadata | 媒体元数据 |
================================================
FILE: kirara_ai/web/api/media/__init__.py
================================================
from .routes import media_bp
# Public API of the media web package: only the blueprint is exported.
__all__ = ["media_bp"]
================================================
FILE: kirara_ai/web/api/media/models.py
================================================
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
class MediaMetadata(BaseModel):
    """Metadata of a media file as exposed by the media web API."""

    filename: str  # display file name
    content_type: str  # MIME-style content type, e.g. "image/jpeg"
    size: int  # file size in bytes
    upload_time: datetime  # upload (creation) time
    source: Optional[str] = None  # where the media came from, if known
    tags: List[str] = []  # tags attached to the media
    references: List[str] = []  # references to this media; NOTE(review): entry format not visible here


class MediaItem(BaseModel):
    """A single media entry in API responses."""

    id: str  # media id
    url: str  # media URL
    thumbnail_url: Optional[str] = None  # thumbnail URL, if any
    metadata: MediaMetadata


class MediaListResponse(BaseModel):
    """One page of media search results."""

    items: List[MediaItem]
    total: int  # number of matches across all pages
    has_more: bool  # whether further pages exist
    page_size: int  # items per page


class MediaSearchParams(BaseModel):
    """Filters and pagination for the media list endpoint."""

    query: Optional[str] = None  # search keyword
    content_type: Optional[str] = None  # media (MIME) type filter
    start_date: Optional[datetime] = None  # lower bound of the creation time
    end_date: Optional[datetime] = None  # upper bound of the creation time
    tags: List[str] = []  # tag filter
    page: int = 1  # 1-based page number
    page_size: int = 20  # items per page


class MediaBatchDeleteRequest(BaseModel):
    """Request body for deleting several media items at once."""

    ids: List[str]  # ids of the media to delete
================================================
FILE: kirara_ai/web/api/media/routes.py
================================================
import asyncio
import io
import os
import shutil
import time
from datetime import datetime, tzinfo
from typing import Optional

import pytz
from quart import Blueprint, g, jsonify, request, send_file

from kirara_ai.config.config_loader import CONFIG_FILE, ConfigLoader
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.logger import get_logger
from kirara_ai.media.manager import MediaManager
from kirara_ai.media.media_object import Media
from kirara_ai.media.types.media_type import MediaType

from ...auth.middleware import require_auth
from .models import MediaBatchDeleteRequest, MediaItem, MediaListResponse, MediaMetadata, MediaSearchParams
# Blueprint that holds the media management routes defined in this module.
media_bp = Blueprint("media", __name__)
logger = get_logger("Web.Media")
# Thumbnail generation
async def generate_thumbnail(image_data: bytes, max_size: int = 300) -> io.BytesIO:
    """Create a WebP thumbnail of an image and return it as a BytesIO.

    The image is shrunk (never enlarged) to fit within ``max_size`` x
    ``max_size`` pixels, keeping its aspect ratio. It is then converted to RGB
    and encoded as WebP at quality 65. Decoding and encoding run in a worker
    thread via ``asyncio.to_thread`` so the event loop is not blocked.

    Args:
        image_data: Raw bytes of an image that Pillow can open.
        max_size: Maximum width and height of the thumbnail in pixels. The
            default of 300 is the previously hard-coded bound.

    Returns:
        A BytesIO containing the WebP data, positioned at offset 0.

    Raises:
        Any Pillow error for data it cannot decode or encode.
    """
    from PIL import Image

    def _generate_thumbnail(data: bytes) -> io.BytesIO:
        """Synchronous worker executed in a thread."""
        with Image.open(io.BytesIO(data)) as img:
            # Image.thumbnail preserves the aspect ratio and never produces a side
            # shorter than 1px. The previous manual computation could floor one side
            # to 0 for very thin images (e.g. 3000x5), which Pillow cannot handle
            # (ZeroDivisionError in recent versions).
            img.thumbnail((max_size, max_size))
            output = io.BytesIO()
            img.convert("RGB").save(output, format="WEBP", optimize=True, quality=65)
        output.seek(0)
        return output

    return await asyncio.to_thread(_generate_thumbnail, image_data)
def _get_media_manager() -> MediaManager:
    """Resolve the MediaManager from the current request's IoC container (``g.container``)."""
    media_manager: MediaManager = g.container.resolve(MediaManager)
    return media_manager
def _convert_media_to_api_item(media: Media) -> Optional[MediaItem]:
    """Convert a Media object into the API's ``MediaItem`` shape.

    Args:
        media: The media object to convert.

    Returns:
        The API item, or None when ``media`` or its metadata is missing.
        ``url`` and ``thumbnail_url`` are left empty.
    """
    if not media or not media.metadata:
        return None
    metadata = media.metadata

    if metadata.media_type and metadata.format:
        content_type = f"{metadata.media_type.value}/{metadata.format}"
    else:
        content_type = "application/octet-stream"

    if metadata.path:
        filename = os.path.basename(metadata.path)
    elif metadata.format:
        filename = f"{media.media_id}.{metadata.format}"
    else:
        # Without a known format, use the bare id rather than producing "<id>.None".
        filename = media.media_id

    return MediaItem(
        id=media.media_id,
        url="",
        thumbnail_url="",
        metadata=MediaMetadata(
            filename=filename,
            content_type=content_type,
            size=metadata.size or 0,
            upload_time=metadata.created_at,
            source=metadata.source,
            tags=list(metadata.tags),
            references=list(metadata.references),
        ),
    )
def _content_type_to_media_type(content_type: str) -> MediaType:
    """Map a MIME content-type filter (e.g. ``image/png``) to the MediaType to search.

    Anything that is not ``image/*``, ``video/*`` or ``audio/*`` maps to ``MediaType.FILE``.
    """
    for prefix, media_type in (
        ("image/", MediaType.IMAGE),
        ("video/", MediaType.VIDEO),
        ("audio/", MediaType.AUDIO),
    ):
        if content_type.startswith(prefix):
            return media_type
    return MediaType.FILE


def _to_timezone(value: datetime, tz: tzinfo) -> datetime:
    """Return *value* as an aware datetime expressed in *tz*.

    Naive values are interpreted as wall-clock time in *tz*. Aware values, such
    as ISO strings ending in ``Z``, are converted so the instant is preserved
    instead of having their offset overwritten.
    """
    if value.tzinfo is not None:
        return value.astimezone(tz)
    # pytz zones must be attached with localize(). replace(tzinfo=...) would use the
    # zone's first historical (LMT) offset instead of the offset valid at that date.
    localize = getattr(tz, "localize", None)
    return localize(value) if localize is not None else value.replace(tzinfo=tz)


@media_bp.route("/list", methods=["POST"])
@require_auth
async def list_media():
    """List media items with optional filtering and pagination.

    The JSON body is parsed into ``MediaSearchParams``; a missing body means
    "no filters". Filters are applied in this order:

    1. content type;
    2. tags (matches if any tag matches);
    3. keyword (description or source);
    4. creation-date range (inclusive), evaluated in the configured system timezone.

    Returns:
        A ``MediaListResponse`` dumped to a dict.
    """
    data = await request.get_json()
    search_params = MediaSearchParams(**(data or {}))
    manager = _get_media_manager()

    if search_params.content_type:
        all_media_ids = manager.search_by_type(
            _content_type_to_media_type(search_params.content_type)
        )
    else:
        all_media_ids = manager.get_all_media_ids()

    if search_params.tags:
        all_media_ids = [
            media_id
            for media_id in all_media_ids
            if (metadata := manager.get_metadata(media_id))
            and any(tag in metadata.tags for tag in search_params.tags)
        ]

    if search_params.query:
        # Build one set so each membership test is O(1) instead of scanning two lists.
        matching_ids = set(manager.search_by_description(search_params.query))
        matching_ids.update(manager.search_by_source(search_params.query))
        all_media_ids = [media_id for media_id in all_media_ids if media_id in matching_ids]

    if search_params.start_date or search_params.end_date:
        # Resolve the timezone and normalise the bounds once, not once per media item.
        tz = pytz.timezone(g.container.resolve(GlobalConfig).system.timezone)
        start_date = _to_timezone(search_params.start_date, tz) if search_params.start_date else None
        end_date = _to_timezone(search_params.end_date, tz) if search_params.end_date else None
        filtered_ids = []
        for media_id in all_media_ids:
            metadata = manager.get_metadata(media_id)
            if not metadata:
                continue
            created_at = _to_timezone(metadata.created_at, tz)
            if start_date and created_at < start_date:
                continue
            if end_date and created_at > end_date:
                continue
            filtered_ids.append(media_id)
        all_media_ids = filtered_ids

    # Clamp paging values so page <= 0 or page_size <= 0 cannot produce negative
    # slice indices or a never-ending "has_more".
    page = max(search_params.page, 1)
    page_size = max(search_params.page_size, 1)
    total = len(all_media_ids)
    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size

    items = []
    for media_id in all_media_ids[start_idx:end_idx]:
        if (media := manager.get_media(media_id)) and (item := _convert_media_to_api_item(media)):
            items.append(item)

    return MediaListResponse(
        items=items,
        total=total,
        has_more=end_idx < total,
        page_size=page_size,
    ).model_dump()
@media_bp.route("/file/<media_id>", methods=["GET"])
@require_auth
async def get_media_file(media_id: str):
    """Return the raw content of a media item with its MIME type.

    Args:
        media_id: Id of the media, taken from the URL path.

    Returns:
        The file response, or 404 ``{"error": "Media not found"}`` when the
        media or its data is unavailable.
    """
    manager = _get_media_manager()
    media = manager.get_media(media_id)
    if not media:
        return jsonify({"error": "Media not found"}), 404
    data = await media.get_data()
    # Missing data previously produced an empty 200 response; a zero-byte file is still served.
    if data is None:
        return jsonify({"error": "Media not found"}), 404
    return await send_file(io.BytesIO(data), mimetype=media.metadata.mime_type)
@media_bp.route("/preview/<media_id>", methods=["GET"])
@require_auth
async def get_thumbnail(media_id: str):
    """Return a preview of a media item.

    Images are returned as a WebP thumbnail, except GIFs, which are returned
    unchanged (presumably to keep animation). Videos are returned as their raw
    bytes with ``video/mp4``. Other media types get 400.

    Args:
        media_id: Id of the media, taken from the URL path.
    """
    media_manager = _get_media_manager()
    media = media_manager.get_media(media_id)
    if not media:
        return jsonify({"error": "Media not found"}), 404

    data = await media.get_data()
    if not data:
        return jsonify({"error": "Media not found"}), 404

    if media.metadata.media_type == MediaType.IMAGE:
        if media.metadata.format == "gif":
            return await send_file(io.BytesIO(data), mimetype="image/gif")
        thumbnail = await generate_thumbnail(data)
        return await send_file(thumbnail, mimetype="image/webp")
    if media.metadata.media_type == MediaType.VIDEO:
        # Videos are not thumbnailed; the original data is returned as-is.
        return await send_file(io.BytesIO(data), mimetype="video/mp4")
    return jsonify({"error": "Unsupported media type"}), 400
@media_bp.route("/delete/<media_id>", methods=["DELETE"])
@require_auth
async def delete_media(media_id: str):
    """Delete a single media item.

    Args:
        media_id: Id of the media, taken from the URL path.

    Returns:
        ``{"success": True}``, or 404 ``{"error": "File not found"}`` when the
        id is not in the manager's metadata cache.
    """
    manager = _get_media_manager()
    if media_id not in manager.metadata_cache:
        return jsonify({"error": "File not found"}), 404
    manager.delete_media(media_id)
    return jsonify({"success": True})
@media_bp.route("/batch-delete", methods=["POST"])
@require_auth
async def batch_delete():
    """Delete several media items in one request.

    Ids that are not in the manager's metadata cache are skipped silently.

    Returns:
        ``{"success": True, "deleted_count": <number of ids actually deleted>}``.
    """
    payload = await request.get_json()
    delete_request = MediaBatchDeleteRequest(**payload)
    media_manager = _get_media_manager()

    deleted = 0
    for media_id in delete_request.ids:
        if media_id not in media_manager.metadata_cache:
            continue
        media_manager.delete_media(media_id)
        deleted += 1

    return jsonify({"success": True, "deleted_count": deleted})
@media_bp.route("/system", methods=["GET"])
@require_auth
async def get_system_info():
    """Report media storage statistics and the cleanup settings.

    Returns a JSON object with:

    - the cleanup configuration;
    - the number of media items;
    - the sum of their recorded sizes;
    - the disk total/used/free (bytes) for the media directory, which are 0
      when disk usage cannot be read.
    """
    config: GlobalConfig = g.container.resolve(GlobalConfig)
    manager: MediaManager = _get_media_manager()

    media_ids = manager.get_all_media_ids()
    total_media_size = 0
    for metadata in map(manager.get_metadata, media_ids):
        if metadata and metadata.size:
            total_media_size += metadata.size

    storage_path = manager.media_dir
    try:
        # Read-only query; a missing storage path is not created here.
        disk_total, disk_used, disk_free = shutil.disk_usage(storage_path)
    except OSError as e:
        # E.g. the path does not exist or is not accessible: report zeros instead of failing.
        logger.error(f"Unable to get disk info for {storage_path}: {e}")
        disk_total = disk_used = disk_free = 0

    return jsonify({
        "cleanup_duration": config.media.cleanup_duration,
        "auto_remove_unreferenced": config.media.auto_remove_unreferenced,
        "last_cleanup_time": config.media.last_cleanup_time,
        "total_media_count": len(media_ids),
        "total_media_size": total_media_size,
        "disk_total": disk_total,
        "disk_used": disk_used,
        "disk_free": disk_free,
    })
# Update the media cleanup settings.
@media_bp.route("/system/config", methods=["POST"])
@require_auth
async def set_config():
    """Update the media cleanup settings, reschedule cleanup and persist the config.

    Accepts optional ``cleanup_duration`` and ``auto_remove_unreferenced`` keys;
    omitted keys keep their current values.
    """
    media_manager = _get_media_manager()
    payload = await request.get_json()
    config = g.container.resolve(GlobalConfig)

    media_config = config.media
    media_config.cleanup_duration = payload.get("cleanup_duration", media_config.cleanup_duration)
    media_config.auto_remove_unreferenced = payload.get(
        "auto_remove_unreferenced", media_config.auto_remove_unreferenced
    )

    # Reschedule with the new values, then write them to disk.
    media_manager.setup_cleanup_task(g.container)
    ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
    return jsonify({"success": True})
@media_bp.route("/system/cleanup-unreferenced", methods=["POST"])
@require_auth
async def cleanup_unreferenced():
    """Run an unreferenced-media cleanup now, record its time and reschedule the task.

    Returns:
        ``{"success": True, "count": <value returned by the manager's cleanup>}``.
    """
    media_manager = _get_media_manager()
    removed = media_manager.cleanup_unreferenced()

    config = g.container.resolve(GlobalConfig)
    # Unix timestamp (seconds) of this cleanup run.
    config.media.last_cleanup_time = int(time.time())
    ConfigLoader.save_config_with_backup(CONFIG_FILE, config)

    media_manager.setup_cleanup_task(g.container)
    return jsonify({"success": True, "count": removed})
================================================
FILE: kirara_ai/web/api/plugin/README.md
================================================
# 插件 API 🔌
插件 API 提供了管理插件的功能。通过这些 API,你可以安装、卸载、启用、禁用和更新插件。
## API 端点
### 获取所有插件
```http
GET /backend-api/api/plugin/plugins
```
获取所有已安装的插件列表。
**响应示例:**
```json
{
"plugins": [
{
"name": "图像处理",
"package_name": "image-processing",
"description": "提供图像处理功能",
"version": "1.0.0",
"author": "开发者",
"homepage": "https://github.com/example/image-processing",
"license": "MIT",
"is_internal": false,
"is_enabled": true,
"metadata": {
"category": "media",
"tags": ["image", "processing"]
}
}
]
}
```
### 获取特定插件
```http
GET /backend-api/api/plugin/plugins/{plugin_name}
```
获取指定插件的详细信息。
**响应示例:**
```json
{
"plugin": {
"name": "图像处理",
"package_name": "image-processing",
"description": "提供图像处理功能",
"version": "1.0.0",
"author": "开发者",
"homepage": "https://github.com/example/image-processing",
"license": "MIT",
"is_internal": false,
"is_enabled": true,
"metadata": {
"category": "media",
"tags": ["image", "processing"]
}
}
}
```
### 安装插件
```http
POST /backend-api/api/plugin/plugins
```
安装新的插件。
**请求体:**
```json
{
"package_name": "image-processing",
"version": "1.0.0" // 可选,不指定则安装最新版本
}
```
### 卸载插件
```http
DELETE /backend-api/api/plugin/plugins/{plugin_name}
```
卸载指定的插件。注意:内部插件不能被卸载。
### 启用插件
```http
POST /backend-api/api/plugin/plugins/{plugin_name}/enable
```
启用指定的插件。
### 禁用插件
```http
POST /backend-api/api/plugin/plugins/{plugin_name}/disable
```
禁用指定的插件。
### 更新插件
```http
PUT /backend-api/api/plugin/plugins/{plugin_name}
```
更新插件到最新版本。注意:内部插件不支持更新。
### 搜索插件市场
```http
GET /backend-api/api/v1/search?query={query}&page={page}&pageSize={pageSize}
```
在插件市场中搜索插件。
**参数:**
- `query`: 搜索关键词
- `page`: 页码 (默认为 1)
- `pageSize`: 每页数量 (默认为 10)
**响应示例:**
```json
{
"plugins": [
{
"name": "图像处理",
"description": "提供图像处理功能",
"author": "开发者",
"pypiPackage": "image-processing",
"pypiInfo": {
"version": "1.0.0",
"description": "PyPI 描述",
"author": "PyPI 作者",
"homePage": "https://example.com"
},
"isInstalled": false,
"installedVersion": null,
"isUpgradable": false
}
],
"totalCount": 1,
"totalPages": 1,
"page": 1,
"pageSize": 10
}
```
### 获取插件市场中插件的详细信息
```http
GET /backend-api/api/v1/info/{plugin_name}
```
获取插件市场中指定插件的详细信息。
**响应示例:**
```json
{
"name": "图像处理",
"description": "提供图像处理功能",
"author": "开发者",
"pypiPackage": "image-processing",
"pypiInfo": {
"version": "1.0.0",
"description": "PyPI 描述",
"author": "PyPI 作者",
"homePage": "https://example.com"
},
"isInstalled": false,
"installedVersion": null,
"isUpgradable": false
}
```
## 数据模型
### PluginInfo
- `name`: 插件名称
- `package_name`: 包名
- `description`: 描述
- `version`: 版本号
- `author`: 作者
- `homepage`: 主页(可选)
- `license`: 许可证(可选)
- `is_internal`: 是否为内部插件
- `is_enabled`: 是否已启用
- `metadata`: 元数据(可选)
### InstallPluginRequest
- `package_name`: 包名
- `version`: 版本号(可选)
### PluginList
- `plugins`: 插件列表
### PluginResponse
- `plugin`: 插件信息
## 内置插件
### IM 适配器
- Telegram 适配器
- HTTP Legacy 适配器
- WeCom 适配器
### LLM 后端
- OpenAI 适配器
- Anthropic 适配器
- Google AI 适配器
## 相关代码
- [插件管理器](../../../plugin_manager/plugin_loader.py)
- [插件基类](../../../plugin_manager/plugin.py)
- [系统插件](../../../plugins)
## 错误处理
所有 API 端点在发生错误时都会返回适当的 HTTP 状态码和错误信息:
```json
{
"error": "错误描述信息"
}
```
常见状态码:
- 400: 请求参数错误、插件已存在或内部插件操作限制
- 404: 插件不存在
- 500: 服务器内部错误
## 使用示例
### 获取所有插件
```python
import requests
response = requests.get(
'http://localhost:8080/api/plugin/plugins',
headers={'Authorization': f'Bearer {token}'}
)
```
### 安装新插件
```python
import requests
data = {
"package_name": "image-processing",
"version": "1.0.0"
}
response = requests.post(
'http://localhost:8080/api/plugin/plugins',
headers={'Authorization': f'Bearer {token}'},
json=data
)
```
### 启用插件
```python
import requests
response = requests.post(
'http://localhost:8080/api/plugin/plugins/image-processing/enable',
headers={'Authorization': f'Bearer {token}'}
)
```
### 更新插件
```python
import requests
response = requests.put(
'http://localhost:8080/api/plugin/plugins/image-processing',
headers={'Authorization': f'Bearer {token}'}
)
```
### 搜索插件市场
```python
import requests
response = requests.get(
'http://localhost:8080/api/v1/search?query=image&page=1&pageSize=10',
headers={'Authorization': f'Bearer {token}'}
)
```
### 获取插件市场中插件的详细信息
```python
import requests
response = requests.get(
'http://localhost:8080/api/v1/info/image-processing',
headers={'Authorization': f'Bearer {token}'}
)
```
## 相关文档
- [系统架构](../../README.md#系统架构-)
- [API 认证](../../README.md#api认证-)
- [插件开发](../../../plugin_manager/README.md#插件开发-)
================================================
FILE: kirara_ai/web/api/plugin/__init__.py
================================================
from .routes import plugin_bp
# Public API of the plugin web package: only the blueprint is exported.
__all__ = ["plugin_bp"]
================================================
FILE: kirara_ai/web/api/plugin/models.py
================================================
from typing import List, Optional
from pydantic import BaseModel
from kirara_ai.plugin_manager.models import PluginInfo
class PluginList(BaseModel):
    """Response body listing installed plugins."""

    plugins: List[PluginInfo]


class PluginResponse(BaseModel):
    """Response body carrying a single plugin's details."""

    plugin: PluginInfo


class InstallPluginRequest(BaseModel):
    """Request body for installing a plugin."""

    package_name: str  # package to install
    version: Optional[str] = None  # optional version; the latest is installed when omitted
================================================
FILE: kirara_ai/web/api/plugin/routes.py
================================================
from functools import lru_cache
import aiohttp
from packaging.version import Version
from quart import Blueprint, g, jsonify, request
from kirara_ai.config.config_loader import CONFIG_FILE, ConfigLoader
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.logger import get_logger
from kirara_ai.plugin_manager.plugin_loader import PluginLoader
from kirara_ai.web.api.system.utils import get_installed_version
from ...auth.middleware import require_auth
from .models import InstallPluginRequest, PluginList, PluginResponse
# Blueprint that holds the plugin management and plugin-market routes defined in this module.
plugin_bp = Blueprint("plugin", __name__)
logger = get_logger("WebServer")
@lru_cache(maxsize=1)
def get_meta_params() -> dict:
    """Return the query parameters sent with every plugin-market request.

    Currently this is only the installed Kirara version. The result is cached
    after the first call, so the same dict object is returned every time;
    callers must not mutate it.
    """
    installed_version = get_installed_version()
    return {"kirara_version": installed_version}
def is_upgradable(installed_version: str, market_version: str) -> bool:
    """Tell whether the market offers a newer version than the installed one.

    Versions are compared with ``packaging.version.Version``. If either string
    is not a valid version, the plugin is reported as not upgradable.
    """
    try:
        latest = Version(market_version)
        installed = Version(installed_version)
    except ValueError:
        return False
    return latest > installed
async def fetch_from_market(path: str, params: dict | None = None) -> dict:
    """GET ``<market_base_url>/<path>`` from the plugin market and return the JSON body.

    The meta parameters from ``get_meta_params`` are merged into the query;
    they override keys of the same name.

    Args:
        path: Path relative to the configured market base URL.
        params: Optional query parameters. The caller's dict is not modified.

    Returns:
        The decoded JSON response.

    Raises:
        Exception: If the market answers with a non-200 status. Network and
            decoding errors from aiohttp propagate unchanged.
    """
    plugin_market_base_url = g.container.resolve(GlobalConfig).plugins.market_base_url
    # Avoid "//" when the configured base URL ends with a slash.
    url = f"{plugin_market_base_url.rstrip('/')}/{path.lstrip('/')}"
    # Build a fresh dict: updating `params` in place leaked the meta params into the caller's dict.
    query = {**(params or {}), **get_meta_params()}
    logger.info(f"Fetching from market: {url}")
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.get(url, params=query) as response:
            if response.status != 200:
                raise Exception(f"插件市场请求失败: {response.status}")
            return await response.json()
async def enrich_plugin_data(plugins: list, loader: PluginLoader) -> list:
    """Annotate plugin-market entries in place with the local installation state.

    Each entry is matched on ``pypiPackage`` against installed plugins'
    ``package_name``. The function sets ``isInstalled``, ``installedVersion``,
    ``isUpgradable``, ``isEnabled`` and ``requiresRestart`` on each entry.

    Returns:
        The same ``plugins`` list, mutated in place.
    """
    installed_by_package = {}
    for info in loader.get_all_plugin_infos():
        # Keep the first installed plugin per package, as a linear search would.
        installed_by_package.setdefault(info.package_name, info)

    for plugin in plugins:
        installed = (
            installed_by_package.get(plugin["pypiPackage"]) if installed_by_package else None
        )
        plugin["isInstalled"] = installed is not None
        if installed:
            plugin["installedVersion"] = installed.version
            plugin["isUpgradable"] = is_upgradable(installed.version, plugin["pypiInfo"]["version"])
            plugin["isEnabled"] = installed.is_enabled
            plugin["requiresRestart"] = installed.requires_restart
        else:
            plugin["installedVersion"] = None
            plugin["isUpgradable"] = False
            plugin["isEnabled"] = False
            plugin["requiresRestart"] = False
    return plugins
@plugin_bp.route("/v1/search", methods=["GET"])
@require_auth
async def search_plugins():
    """Search the plugin market and annotate results with local install state.

    Query args: ``query`` (default ""), ``page`` (default 1), ``pageSize`` (default 10).
    Returns the market response, or 500 ``{"error": ...}`` on any failure.
    """
    args = request.args
    market_params = {
        "query": args.get("query", ""),
        "page": args.get("page", 1, type=int),
        "pageSize": args.get("pageSize", 10, type=int),
    }
    try:
        result = await fetch_from_market("search", market_params)
        loader: PluginLoader = g.container.resolve(PluginLoader)
        result["plugins"] = await enrich_plugin_data(result["plugins"], loader)
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@plugin_bp.route("/v1/info/<plugin_name>", methods=["GET"])
@require_auth
async def get_market_plugin_info(plugin_name: str):
    """Fetch a plugin's details from the market, annotated with local install state.

    Args:
        plugin_name: Market plugin name, taken from the URL path.

    Returns:
        The enriched market entry, or 500 ``{"error": ...}`` on any failure.
    """
    try:
        result = await fetch_from_market(f"info/{plugin_name}")
        loader: PluginLoader = g.container.resolve(PluginLoader)
        result = (await enrich_plugin_data([result], loader))[0]
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@plugin_bp.route("/plugins", methods=["GET"])
@require_auth
async def list_plugins():
    """Return every installed plugin as a ``PluginList`` dict."""
    loader: PluginLoader = g.container.resolve(PluginLoader)
    return PluginList(plugins=loader.get_all_plugin_infos()).model_dump()
@plugin_bp.route("/plugins/<plugin_name>", methods=["GET"])
@require_auth
async def get_plugin_details(plugin_name: str):
    """Return the details of an installed plugin.

    Args:
        plugin_name: Plugin identifier, taken from the URL path.

    Returns:
        A ``PluginResponse`` dict, or 404 ``{"error": "Plugin not found"}``.
    """
    loader: PluginLoader = g.container.resolve(PluginLoader)
    # Trace through the module logger instead of a stray print() to stdout.
    logger.debug(f"Getting plugin details for {plugin_name}")
    plugin_info = loader.get_plugin_info(plugin_name)
    if not plugin_info:
        return jsonify({"error": "Plugin not found"}), 404
    return PluginResponse(plugin=plugin_info).model_dump()
@plugin_bp.route("/plugins", methods=["POST"])
@require_auth
async def install_plugin():
    """Install a plugin package, enable it in the config and load it.

    Body: ``InstallPluginRequest`` (``package_name``, optional ``version``).
    Returns a ``PluginResponse`` dict, or 500 ``{"error": ...}`` on failure.
    """
    payload = await request.get_json()
    install_data = InstallPluginRequest(**payload)
    loader: PluginLoader = g.container.resolve(PluginLoader)
    config: GlobalConfig = g.container.resolve(GlobalConfig)
    try:
        plugin_info = await loader.install_plugin(install_data.package_name, install_data.version)
        if not plugin_info or plugin_info.package_name is None:
            return jsonify({"error": "Failed to install plugin"}), 500

        # Record the package as enabled (persisting only when the list changes).
        enabled = config.plugins.enable
        if plugin_info.package_name not in enabled:
            enabled.append(plugin_info.package_name)
            ConfigLoader.save_config_with_backup(CONFIG_FILE, config)

        loader.load_plugin(plugin_info.name)
        return PluginResponse(plugin=plugin_info).model_dump()
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@plugin_bp.route("/plugins/<plugin_name>", methods=["DELETE"])
@require_auth
async def uninstall_plugin(plugin_name: str):
    """Uninstall a non-internal plugin and remove it from the enabled list.

    Args:
        plugin_name: Plugin identifier, taken from the URL path.

    Returns:
        A success message. Returns 404 if the plugin is unknown, 400 for
        internal plugins, and 500 ``{"error": ...}`` on failure.
    """
    loader: PluginLoader = g.container.resolve(PluginLoader)
    config: GlobalConfig = g.container.resolve(GlobalConfig)

    plugin_info = loader.get_plugin_info(plugin_name)
    if not plugin_info:
        return jsonify({"error": "Plugin not found"}), 404
    if plugin_info.is_internal:
        return jsonify({"error": "Cannot uninstall internal plugin"}), 400

    try:
        await loader.uninstall_plugin(plugin_name)
        if plugin_info.package_name in config.plugins.enable:
            config.plugins.enable.remove(plugin_info.package_name)
            ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
        return jsonify({"message": "Plugin uninstalled successfully"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@plugin_bp.route("/plugins/<plugin_name>/enable", methods=["POST"])
@require_auth
async def enable_plugin(plugin_name: str):
    """Enable a plugin and record it in ``config.plugins.enable``.

    Internal plugins are enabled but not written to the config.

    Args:
        plugin_name: Plugin identifier, taken from the URL path.

    Returns:
        A ``PluginResponse`` dict with the post-enable state. Returns 404 if
        the plugin is unknown and 500 ``{"error": ...}`` on failure.
    """
    loader: PluginLoader = g.container.resolve(PluginLoader)
    config: GlobalConfig = g.container.resolve(GlobalConfig)

    plugin_info = loader.get_plugin_info(plugin_name)
    if not plugin_info:
        return jsonify({"error": "Plugin not found"}), 404

    try:
        await loader.enable_plugin(plugin_name)
        # Re-read so the response reflects the new state (as disable_plugin does);
        # the info fetched above predates enabling.
        plugin_info = loader.get_plugin_info(plugin_name) or plugin_info
        # NOTE(review): install/uninstall record package_name in this list, while
        # enable/disable record plugin_name; confirm both refer to the same identifier.
        if (
            plugin_name
            and plugin_name not in config.plugins.enable
            and not plugin_info.is_internal
        ):
            config.plugins.enable.append(plugin_name)
            ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
        return PluginResponse(plugin=plugin_info).model_dump()
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@plugin_bp.route("/plugins/<plugin_name>/disable", methods=["POST"])
@require_auth
async def disable_plugin(plugin_name: str):
    """Disable a plugin and remove it from ``config.plugins.enable``.

    Internal plugins are disabled but the config is not changed for them.

    Args:
        plugin_name: Plugin identifier, taken from the URL path.

    Returns:
        A ``PluginResponse`` dict with the post-disable state. Returns 404 if
        the plugin is unknown and 500 ``{"error": ...}`` on failure.
    """
    loader: PluginLoader = g.container.resolve(PluginLoader)
    config: GlobalConfig = g.container.resolve(GlobalConfig)

    plugin_info = loader.get_plugin_info(plugin_name)
    if not plugin_info:
        return jsonify({"error": "Plugin not found"}), 404

    try:
        await loader.disable_plugin(plugin_name)
        # Refresh the info for the response. Fall back to the earlier info instead of
        # using `assert`, which is stripped under -O and would fail with an empty message.
        refreshed = loader.get_plugin_info(plugin_name)
        if refreshed is not None:
            plugin_info = refreshed
        if (
            plugin_name
            and plugin_name in config.plugins.enable
            and not plugin_info.is_internal
        ):
            config.plugins.enable.remove(plugin_name)
            ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
        return PluginResponse(plugin=plugin_info).model_dump()
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@plugin_bp.route("/plugins/<plugin_name>", methods=["PUT"])
@require_auth
async def update_plugin(plugin_name: str):
    """Update an installed, non-internal plugin.

    An optional ``package_name`` query argument is forwarded to
    ``PluginLoader.update_plugin``.

    Args:
        plugin_name: Plugin identifier, taken from the URL path.

    Returns:
        A ``PluginResponse`` dict with the updated info. Returns 404 if the
        plugin is unknown, 400 for internal plugins, and 500 on failure.
    """
    loader: PluginLoader = g.container.resolve(PluginLoader)

    plugin_info = loader.get_plugin_info(plugin_name)
    if not plugin_info:
        return jsonify({"error": "Plugin not found"}), 404
    if plugin_info.is_internal:
        return jsonify({"error": "Cannot update internal plugin"}), 400

    new_package_name = request.args.get("package_name", None)
    try:
        updated_info = await loader.update_plugin(plugin_name, new_package_name)
        if not updated_info:
            return jsonify({"error": "Failed to update plugin"}), 500
        return PluginResponse(plugin=updated_info).model_dump()
    except Exception as e:
        return jsonify({"error": str(e)}), 500
================================================
FILE: kirara_ai/web/api/system/README.md
================================================
# 系统管理 API 🛠️
系统管理 API 提供了监控和管理系统状态的功能。
## API 端点
### 获取系统状态
```http
GET /backend-api/api/system/status
```
获取系统的当前运行状态,包括版本信息、运行时间、资源使用情况等。
**响应示例:**
```json
{
"status": {
"version": "1.0.0",
"uptime": 3600, // 运行时间(秒)
"active_adapters": 2, // 活跃的 IM 适配器数量
"active_backends": 3, // 活跃的 LLM 后端数量
"loaded_plugins": 5, // 已加载的插件数量
"workflow_count": 10, // 工作流数量
"memory_usage": {
"rss": 256.5, // 物理内存使用(MB)
"vms": 512.8, // 虚拟内存使用(MB)
"percent": 2.5 // 内存使用百分比
},
"cpu_usage": 1.2 // CPU 使用百分比
}
}
```
### 获取系统配置
```http
GET /backend-api/api/system/config
```
获取系统当前配置。
### 更新系统配置
```http
PUT /backend-api/api/system/config
```
更新系统配置。
**请求体:**
```json
{
"log_level": "INFO",
"max_connections": 100,
"timeout": 30,
"storage": {
"type": "local",
"path": "/data"
}
}
```
### 获取系统日志
```http
GET /backend-api/api/system/logs
```
获取系统日志。支持分页和过滤。
**查询参数:**
- `level`: 日志级别 (DEBUG/INFO/WARNING/ERROR)
- `start_time`: 开始时间
- `end_time`: 结束时间
- `limit`: 每页条数
- `offset`: 偏移量
### 获取用户列表
```http
GET /backend-api/api/system/users
```
获取系统用户列表。
### 创建用户
```http
POST /backend-api/api/system/users
```
创建新用户。
**请求体:**
```json
{
"username": "admin",
"password": "password123",
"role": "admin",
"permissions": ["read", "write", "admin"]
}
```
### 更新用户
```http
PUT /backend-api/api/system/users/{username}
```
更新用户信息。
### 删除用户
```http
DELETE /backend-api/api/system/users/{username}
```
删除指定用户。
## 数据模型
### SystemStatus
- `version`: 系统版本
- `uptime`: 运行时间(秒)
- `active_adapters`: 活跃的 IM 适配器数量
- `active_backends`: 活跃的 LLM 后端数量
- `loaded_plugins`: 已加载的插件数量
- `workflow_count`: 工作流数量
- `memory_usage`: 内存使用情况
- `rss`: 物理内存使用(MB)
- `vms`: 虚拟内存使用(MB)
- `percent`: 内存使用百分比
- `cpu_usage`: CPU 使用百分比
### SystemConfig
- `log_level`: 日志级别
- `max_connections`: 最大连接数
- `timeout`: 超时时间(秒)
- `storage`: 存储配置
### User
- `username`: 用户名
- `role`: 角色
- `permissions`: 权限列表
- `created_at`: 创建时间
- `last_login`: 最后登录时间
## 监控指标
### 系统指标
- 运行时间
- CPU 使用率
- 内存使用情况
### 组件指标
- IM 适配器数量和状态
- LLM 后端数量和状态
- 插件数量和状态
- 工作流数量
## 相关代码
- [系统路由](routes.py)
- [数据模型](models.py)
- [系统监控](../../../monitor)
## 错误处理
所有 API 端点在发生错误时都会返回适当的 HTTP 状态码和错误信息:
```json
{
"error": "错误描述信息"
}
```
常见状态码:
- 400: 请求参数错误
- 401: 未认证或认证失败
- 403: 权限不足
- 404: 资源不存在
- 500: 服务器内部错误
## 使用示例
### 获取系统状态
```python
import requests
response = requests.get(
'http://localhost:8080/api/system/status',
headers={'Authorization': f'Bearer {token}'}
)
status = response.json()['status']
print(f"系统已运行: {status['uptime']} 秒")
print(f"内存使用: {status['memory_usage']['percent']:.1%}")
print(f"CPU 使用: {status['cpu_usage']}%")
```
### 更新系统配置
```python
import requests
config_data = {
"log_level": "DEBUG",
"max_connections": 200,
"timeout": 60
}
response = requests.put(
'http://localhost:8080/api/system/config',
headers={'Authorization': f'Bearer {token}'},
json=config_data
)
```
### 创建新用户
```python
import requests
user_data = {
"username": "admin",
"password": "password123",
"role": "admin",
"permissions": ["read", "write", "admin"]
}
response = requests.post(
'http://localhost:8080/api/system/users',
headers={'Authorization': f'Bearer {token}'},
json=user_data
)
```
## 相关文档
- [系统架构](../../README.md#系统架构-)
- [监控指南](../../README.md#系统监控-)
- [API 认证](../../README.md#api认证-)
================================================
FILE: kirara_ai/web/api/system/__init__.py
================================================
from .routes import system_bp
__all__ = ["system_bp"]
================================================
FILE: kirara_ai/web/api/system/models.py
================================================
from typing import Dict, Optional
from pydantic import BaseModel
class SystemStatus(BaseModel):
    """System status snapshot returned by the system ``/status`` route."""
    version: str  # installed kirara-ai version (from get_installed_version)
    uptime: float  # seconds elapsed since the routes module recorded start_time
    active_adapters: int  # IM adapters whose ``is_running`` is true
    active_backends: int  # number of active LLM backends
    loaded_plugins: int  # number of plugins known to the plugin loader
    workflow_count: int  # number of registered workflows
    memory_usage: Dict[str, float]  # keys from get_memory_usage(): percent (0~1 ratio), total, free, used
    cpu_usage: float  # value reported by psutil.cpu_percent()
    cpu_info: str  # CPU model string, or "Unknown"
    python_version: str  # "major.minor.micro"
    platform: str  # sys.platform
    has_proxy: bool  # True when any HTTP(S)_PROXY / http(s)_proxy env var is non-empty
class SystemStatusResponse(BaseModel):
    """Envelope for the ``/status`` response: ``{"status": SystemStatus}``."""
    status: SystemStatus
class UpdateStatus(BaseModel):
    """Generic ``status``/``message`` pair for update operations.

    NOTE(review): not referenced by the system routes visible here; verify
    whether anything still uses it.
    """
    status: str  # outcome label, e.g. "success" / "error"
    message: str  # human-readable detail
class UpdateCheckResponse(BaseModel):
    """Result of the ``/check-update`` route."""
    current_backend_version: str  # installed kirara-ai version
    latest_backend_version: str  # newest version on PyPI ("0.0.0" if the lookup failed)
    backend_update_available: bool  # latest_backend_version > current_backend_version
    backend_download_url: Optional[str]  # wheel URL; the PyPI helper yields "" when none was found
    latest_webui_version: str  # newest kirara-ai-webui version on the npm registry
    webui_download_url: Optional[str]  # tarball URL; "" when the lookup failed
================================================
FILE: kirara_ai/web/api/system/routes.py
================================================
import asyncio
import json
import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
import time
from packaging import version
from quart import Blueprint, current_app, g, request, websocket
from kirara_ai.config.config_loader import CONFIG_FILE, ConfigLoader
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.im.manager import IMManager
from kirara_ai.internal import set_restart_flag, shutdown_event
from kirara_ai.llm.llm_manager import LLMManager
from kirara_ai.logger import WebSocketLogHandler, get_logger
from kirara_ai.plugin_manager.plugin_loader import PluginLoader
from kirara_ai.web.api.system.utils import (download_file, get_cpu_info, get_cpu_usage, get_installed_version,
get_latest_npm_version, get_latest_pypi_version, get_memory_usage)
from kirara_ai.web.auth.services import AuthService
from kirara_ai.workflow.core.workflow import WorkflowRegistry
from ...auth.middleware import require_auth
from .models import SystemStatus, SystemStatusResponse, UpdateCheckResponse
system_bp = Blueprint("system", __name__)
# Reference timestamp recorded when this module is imported; /status reports
# uptime as seconds elapsed since this moment.
start_time = time.time()
# Logger for the system API. The restart route logs through it so that
# clients attached to the /logs WebSocket see the notice.
logger = get_logger("System-API")
@system_bp.websocket('/logs')
async def logs_websocket():
    """Stream log records to an authenticated WebSocket client.

    The client's first frame must be a JSON object with a ``token`` key. An
    unreadable or rejected token closes the socket with code 1008. Once
    authenticated, the connection is registered with ``WebSocketLogHandler``
    and kept open until the global shutdown event is set. It is always
    unregistered on exit.
    """
    try:
        first_frame = await websocket.receive()
        token = json.loads(first_frame)["token"]
    except Exception as e:
        logger.error(f"WebSocket连接错误: {e}")
        await websocket.close(code=1008, reason="Invalid token")
        return

    auth_service: AuthService = g.container.resolve(AuthService)
    if not auth_service.verify_token(token):
        await websocket.close(code=1008, reason="Invalid token")
        return

    # Resolve the context-local proxy once so add/remove see the same object.
    ws_conn = websocket._get_current_object()
    try:
        WebSocketLogHandler.add_websocket(ws_conn, asyncio.get_running_loop())
        # Idle until shutdown; the log handler pushes records to ws_conn.
        while not shutdown_event.is_set():
            await asyncio.sleep(1)
    finally:
        WebSocketLogHandler.remove_websocket(ws_conn)
@system_bp.route("/config", methods=["GET"])
@require_auth
async def get_system_config():
    """Return the user-editable sections of the global configuration.

    Sections: web, plugins, update, system and tracing. Any failure while
    reading the configuration yields ``({"error": ...}, 500)``.
    """
    try:
        cfg: GlobalConfig = g.container.resolve(GlobalConfig)
        web_section = {"host": cfg.web.host, "port": cfg.web.port}
        plugins_section = {"market_base_url": cfg.plugins.market_base_url}
        update_section = {
            "pypi_registry": cfg.update.pypi_registry,
            "npm_registry": cfg.update.npm_registry,
        }
        system_section = {"timezone": cfg.system.timezone}
        tracing_section = {"llm_tracing_content": cfg.tracing.llm_tracing_content}
        return {
            "web": web_section,
            "plugins": plugins_section,
            "update": update_section,
            "system": system_section,
            "tracing": tracing_section,
        }
    except Exception as e:
        return {"error": str(e)}, 500
@system_bp.route("/config/web", methods=["POST"])
@require_auth
async def update_web_config():
    """Persist new web server ``host``/``port`` settings.

    Body: ``{"host": ..., "port": ...}``. On success returns
    ``{"status": "success", "restart_required": True}``. The new values only
    take effect after a restart. Errors yield ``({"error": ...}, 500)``.
    """
    try:
        data = await request.get_json()
        # Read every required value before touching the live config so a
        # missing key cannot leave it half-updated in memory.
        host = data["host"]
        port = data["port"]
        config: GlobalConfig = g.container.resolve(GlobalConfig)
        config.web.host = host
        config.web.port = port
        ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
        return {"status": "success", "restart_required": True}
    except Exception as e:
        return {"error": str(e)}, 500
@system_bp.route("/config/plugins", methods=["POST"])
@require_auth
async def update_plugins_config():
    """Persist a new plugin market base URL.

    Body: ``{"market_base_url": ...}``. Errors yield ``({"error": ...}, 500)``.
    """
    try:
        payload = await request.get_json()
        cfg: GlobalConfig = g.container.resolve(GlobalConfig)
        cfg.plugins.market_base_url = payload["market_base_url"]
        ConfigLoader.save_config_with_backup(CONFIG_FILE, cfg)
    except Exception as e:
        return {"error": str(e)}, 500
    return {"status": "success"}
@system_bp.route("/config/update", methods=["POST"])
@require_auth
async def update_registry_config():
    """Persist the PyPI / npm registries used by the update checker.

    Body: ``{"pypi_registry": ..., "npm_registry": ...}``.
    Errors yield ``({"error": ...}, 500)``.
    """
    try:
        data = await request.get_json()
        # Read both values first so a missing key cannot leave the live
        # config half-updated.
        pypi_registry = data["pypi_registry"]
        npm_registry = data["npm_registry"]
        config: GlobalConfig = g.container.resolve(GlobalConfig)
        # The old fallback (`config.update = {}`) could never work: it then set
        # attributes on a plain dict. The `update` section is read
        # unconditionally elsewhere (GET /config), so rely on it existing.
        config.update.pypi_registry = pypi_registry
        config.update.npm_registry = npm_registry
        ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
        return {"status": "success"}
    except Exception as e:
        return {"error": str(e)}, 500
@system_bp.route("/config/system", methods=["POST"])
@require_auth
async def update_system_config():
    """Persist system settings (currently only ``timezone``).

    If the timezone actually changes, ``TZ`` is exported and ``time.tzset()``
    is called on platforms that provide it. Errors yield
    ``({"error": ...}, 500)``.
    """
    try:
        payload = await request.get_json()
        cfg: GlobalConfig = g.container.resolve(GlobalConfig)

        tz_changed = "timezone" in payload and payload["timezone"] != cfg.system.timezone
        if tz_changed:
            cfg.system.timezone = payload["timezone"]

        ConfigLoader.save_config_with_backup(CONFIG_FILE, cfg)

        # tzset only exists on Unix-like platforms.
        if tz_changed and hasattr(time, "tzset"):
            os.environ["TZ"] = cfg.system.timezone
            time.tzset()
    except Exception as e:
        return {"error": str(e)}, 500
    return {"status": "success"}
@system_bp.route("/config/tracing", methods=["POST"])
@require_auth
async def update_tracing_config():
    """Persist the ``llm_tracing_content`` tracing setting.

    Errors yield ``({"error": ...}, 500)``.
    """
    try:
        payload = await request.get_json()
        cfg: GlobalConfig = g.container.resolve(GlobalConfig)
        cfg.tracing.llm_tracing_content = payload["llm_tracing_content"]
        ConfigLoader.save_config_with_backup(CONFIG_FILE, cfg)
    except Exception as e:
        return {"error": str(e)}, 500
    return {"status": "success"}
@system_bp.route("/status", methods=["GET"])
@require_auth
async def get_system_status():
    """Collect component counts, resource usage and environment details.

    Returns a ``SystemStatusResponse`` dumped to a plain dict.
    """
    container = g.container
    im_manager: IMManager = container.resolve(IMManager)
    llm_manager: LLMManager = container.resolve(LLMManager)
    plugin_loader: PluginLoader = container.resolve(PluginLoader)
    workflow_registry: WorkflowRegistry = container.resolve(WorkflowRegistry)

    elapsed = time.time() - start_time

    running_adapter_count = sum(
        1 for adapter in im_manager.adapters.values() if adapter.is_running
    )
    backend_count = len(llm_manager.active_backends)
    plugin_count = len(plugin_loader.plugins)
    workflow_total = len(workflow_registry._workflows)

    mem = get_memory_usage()
    cpu_percent = get_cpu_usage()

    proxy_env_names = ("HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy")
    proxy_configured = any(os.environ.get(name) for name in proxy_env_names)

    cpu_model = get_cpu_info()
    py_version = ".".join(str(part) for part in sys.version_info[:3])
    platform_name = str(sys.platform)

    snapshot = SystemStatus(
        uptime=elapsed,
        active_adapters=running_adapter_count,
        active_backends=backend_count,
        loaded_plugins=plugin_count,
        workflow_count=workflow_total,
        memory_usage=mem,
        cpu_usage=cpu_percent,
        version=get_installed_version(),
        platform=platform_name,
        cpu_info=cpu_model,
        python_version=py_version,
        has_proxy=proxy_configured,
    )
    return SystemStatusResponse(status=snapshot).model_dump()
@system_bp.route("/check-update", methods=["GET"])
@require_auth
async def check_update():
    """Report the installed and latest backend/WebUI versions.

    Only the backend is compared. The WebUI version and URL are returned
    for information.
    """
    config: GlobalConfig = g.container.resolve(GlobalConfig)
    npm_registry = config.update.npm_registry
    current_backend_version = get_installed_version()

    # The PyPI and npm lookups are independent and each swallows its own
    # errors, so run them concurrently instead of back to back.
    (latest_backend_version, backend_download_url), (latest_webui_version, webui_download_url) = (
        await asyncio.gather(
            get_latest_pypi_version("kirara-ai"),
            get_latest_npm_version("kirara-ai-webui", npm_registry),
        )
    )

    # A non-PEP 440 version string (e.g. a local dev build) must not turn the
    # check into a 500; treat it as "no update available".
    try:
        backend_update_available = version.parse(latest_backend_version) > version.parse(current_backend_version)
    except version.InvalidVersion:
        backend_update_available = False

    return UpdateCheckResponse(
        current_backend_version=current_backend_version,
        latest_backend_version=latest_backend_version,
        backend_update_available=backend_update_available,
        backend_download_url=backend_download_url,
        latest_webui_version=latest_webui_version,
        webui_download_url=webui_download_url
    ).model_dump()
_WEBUI_DIST_PREFIX = "package/dist/"


def _extract_webui_tarball(tarball_path: str, static_dir: str) -> None:
    """Extract the ``package/dist/`` tree of an npm tarball into *static_dir*.

    The tarball URL comes from the request body, so each member is checked
    before extraction. Link members are skipped. A member whose path would
    resolve outside *static_dir* aborts the update with ``ValueError``.
    """
    root = os.path.realpath(static_dir)
    # Use tarfile's "data" filter where available (3.12+ and security backports).
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    with tarfile.open(tarball_path, "r:gz") as tar:
        for member in tar.getmembers():
            if not member.name.startswith(_WEBUI_DIST_PREFIX):
                continue
            relative_name = member.name[len(_WEBUI_DIST_PREFIX):]
            if not relative_name or member.issym() or member.islnk():
                continue
            target = os.path.realpath(os.path.join(root, relative_name))
            if os.path.commonpath([root, target]) != root:
                raise ValueError(f"Unsafe path in WebUI archive: {member.name}")
            # Strip the "package/dist/" prefix so files land directly in root.
            member.name = relative_name
            tar.extract(member, path=root, **extract_kwargs)


@system_bp.route("/update", methods=["POST"])
@require_auth
async def perform_update():
    """Download and install the requested backend and/or WebUI update.

    Body: ``update_backend`` / ``update_webui`` flags, plus
    ``backend_download_url`` / ``webui_download_url`` for each selected part.
    Returns ``{"status": "success", ...}`` or ``({"status": "error", ...}, 500)``.
    The temporary download directory is always removed.
    """
    data = await request.get_json() or {}
    update_backend = data.get("update_backend", False)
    update_webui = data.get("update_webui", False)
    temp_dir = tempfile.mkdtemp()
    try:
        if update_backend:
            backend_file, _ = await download_file(data["backend_download_url"], temp_dir)
            # download_file reports failure as an empty path rather than raising.
            if not backend_file:
                raise RuntimeError("后端安装包下载失败")
            # pip can take a long time; keep it off the event loop.
            await asyncio.to_thread(
                subprocess.run,
                [sys.executable, "-m", "pip", "install", backend_file],
                check=True,
            )

        if update_webui:
            webui_file, _ = await download_file(data["webui_download_url"], temp_dir)
            if not webui_file:
                raise RuntimeError("WebUI 安装包下载失败")
            # Resolve the app-context value here: current_app is unavailable
            # inside the worker thread.
            static_dir = current_app.static_folder or "web"
            await asyncio.to_thread(_extract_webui_tarball, webui_file, static_dir)

        return {"status": "success", "message": "更新完成"}
    except Exception as e:
        return {"status": "error", "message": str(e)}, 500
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
@system_bp.route("/restart", methods=["POST"])
@require_auth
async def restart_system():
    """Request a process restart.

    Logs a warning (which also reaches clients of the /logs WebSocket), sets
    the restart flag, then signals the global shutdown event.
    """
    logger.warning("服务器即将重启,请稍候...")
    # The restart flag must be set before shutdown is signalled.
    set_restart_flag()
    shutdown_event.set()
    reply = {"status": "success", "message": "重启请求已发送"}
    return reply
================================================
FILE: kirara_ai/web/api/system/utils.py
================================================
import hashlib
import os
import subprocess
import sys
from functools import lru_cache
import aiohttp
import psutil
def get_installed_version() -> str:
    """Return the installed ``kirara-ai`` distribution version.

    Tries ``importlib.metadata`` first, then ``pkg_resources``. Returns
    ``"0.0.0"`` if neither can determine the version.
    """
    try:
        from importlib import metadata

        try:
            return metadata.version("kirara-ai")
        except metadata.PackageNotFoundError:
            import pkg_resources

            return pkg_resources.get_distribution("kirara-ai").version
    except Exception:
        return "0.0.0"
async def get_latest_pypi_version(package_name: str) -> tuple[str, str]:
    """Look up the newest release of *package_name* on PyPI.

    Returns ``(latest_version, wheel_url)``. ``wheel_url`` is the first
    ``bdist_wheel`` file listed for that release, or ``""`` if there is
    none. Any failure yields ``("0.0.0", "")``.
    """
    index_url = f"https://pypi.org/pypi/{package_name}/json"
    try:
        async with aiohttp.ClientSession() as session, session.get(index_url) as response:
            response.raise_for_status()
            payload = await response.json()
            newest = payload["info"]["version"]
            wheel_url = next(
                (entry["url"] for entry in payload["urls"] if entry["packagetype"] == "bdist_wheel"),
                "",
            )
            return newest, wheel_url
    except Exception:
        return "0.0.0", ""
async def get_latest_npm_version(package_name: str, registry: str = "https://registry.npmjs.org") -> tuple[str, str]:
    """Return ``(latest_version, tarball_url)`` for an npm package.

    *registry* may be given with or without a trailing slash. Any failure
    yields ``("0.0.0", "")``.
    """
    # Registries are often configured with a trailing "/"; avoid building
    # "https://host//package".
    base_url = registry.rstrip("/")
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{base_url}/{package_name}") as response:
                response.raise_for_status()
                data = await response.json()
                latest_version = data["dist-tags"]["latest"]
                tarball_url = data["versions"][latest_version]["dist"]["tarball"]
                return latest_version, tarball_url
    except Exception:
        return "0.0.0", ""
async def download_file(url: str, temp_dir: str) -> tuple[str, str]:
"""下载文件并返回文件路径和SHA256"""
local_filename = os.path.join(temp_dir, url.split('/')[-1])
sha256_hash = hashlib.sha256()
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
response.raise_for_status()
total_size = int(response.headers.get('Content-Length', 0))
bytes_downloaded = 0
with open(local_filename, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
f.write(chunk)
sha256_hash.update(chunk)
bytes_downloaded += len(chunk)
if total_size > 0:
print(f"Downloaded {bytes_downloaded / total_size:.2%}", end='\r')
print() # 换行,确保进度条不覆盖后续输出
return local_filename, sha256_hash.hexdigest()
except Exception as e:
print(f"下载失败: {e}")
return "", ""
@lru_cache(maxsize=1)
def get_cpu_info() -> str:
    """Return a human-readable CPU model string (cached after the first call).

    Windows: parses ``wmic cpu get name``. Elsewhere: reads the first
    ``model name`` entry of ``/proc/cpuinfo`` if that file exists. If neither
    yields a value, falls back to ``platform.processor()``, and finally to
    ``"Unknown"``.
    """
    cpu_info = ""  # initialised so a failed probe can't leave it unbound
    try:
        if sys.platform == 'win32':
            result = subprocess.run(
                ['wmic', 'cpu', 'get', 'name'], capture_output=True, text=True, timeout=10
            )
            if result.returncode == 0:
                cpu_info = result.stdout.strip().removeprefix('Name').strip()
        elif os.path.exists('/proc/cpuinfo'):
            with open('/proc/cpuinfo', 'r', encoding='utf-8', errors='replace') as f:
                for line in f:
                    if line.startswith('model name'):
                        # maxsplit=1 keeps any ':' that is part of the model name
                        cpu_info = line.split(':', 1)[1].strip()
                        break
    except (OSError, subprocess.SubprocessError, ValueError, IndexError):
        cpu_info = ""

    if not cpu_info:
        # e.g. macOS, ARM Linux without "model name", or Windows without wmic
        import platform

        cpu_info = platform.processor()
    return cpu_info if cpu_info else "Unknown"
def get_memory_usage() -> dict:
    """Return system and process memory figures.

    Keys:
      * ``percent`` - fraction of system memory in use (used / total, 0..1,
        not multiplied by 100)
      * ``total`` - total system memory in MB
      * ``free`` - available system memory in MB
      * ``used`` - this process's memory in MB (USS when obtainable, else RSS)
    """
    process = psutil.Process()
    system_memory = psutil.virtual_memory()
    try:
        # USS needs extra privileges on some systems and is missing on others.
        process_mem = process.memory_full_info().uss
    except (psutil.AccessDenied, AttributeError):
        process_mem = process.memory_info().rss
    percent = system_memory.used / system_memory.total
    return {
        "percent": percent,
        "total": system_memory.total / 1024 / 1024,  # MB
        "free": system_memory.available / 1024 / 1024,  # MB
        "used": process_mem / 1024 / 1024,  # MB
    }
def get_cpu_usage() -> float:
    """Return system-wide CPU utilisation as reported by ``psutil.cpu_percent()``.

    Uses psutil's non-blocking mode, which measures since the previous call.
    Returns 0.0 if psutil fails.
    """
    try:
        return psutil.cpu_percent()
    except Exception:  # not bare: let KeyboardInterrupt / SystemExit propagate
        return 0.0
================================================
FILE: kirara_ai/web/api/tracing/__init__.py
================================================
from .routes import tracing_bp
__all__ = ["tracing_bp"]
================================================
FILE: kirara_ai/web/api/tracing/routes.py
================================================
import asyncio
import json
from quart import Blueprint, g, jsonify, request, websocket
from kirara_ai.internal import shutdown_event
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.logger import get_logger
from kirara_ai.tracing.llm_tracer import LLMTracer
from kirara_ai.tracing.manager import TracingManager
from kirara_ai.web.auth.middleware import require_auth
from kirara_ai.web.auth.services import AuthService
# Blueprint for tracing endpoints; it declares its own "/api/tracing" prefix.
tracing_bp = Blueprint("tracing", __name__, url_prefix="/api/tracing")
logger = get_logger("Tracing-API")
@tracing_bp.route("/types", methods=["GET"])
@require_auth
async def get_trace_types():
    """List the tracer types registered with the TracingManager as ``{"types": [...]}``."""
    manager: TracingManager = g.container.resolve(TracingManager)
    available = manager.get_tracer_types()
    return jsonify({"types": available})
@tracing_bp.route("/llm/traces", methods=["POST"])
@require_auth
async def get_llm_traces():
    """Return one page of LLM trace records, optionally filtered.

    JSON body (all keys optional): ``page`` (default 1), ``page_size``
    (default 20), ``model_id``, ``backend_name``, ``status``. Responds 400 for
    a non-object body or non-integer / non-positive paging values, and 404 if
    no LLM tracer is registered.
    """
    data = await request.json
    if data is None:
        data = {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    try:
        page = int(data.get("page", 1))
        page_size = int(data.get("page_size", 20))
    except (TypeError, ValueError):
        return jsonify({"error": "page and page_size must be integers"}), 400
    # page_size is used as a divisor for total_pages below.
    if page < 1 or page_size < 1:
        return jsonify({"error": "page and page_size must be positive"}), 400

    # Only filter on criteria the client actually supplied (non-empty values).
    filters = {
        key: data[key]
        for key in ("model_id", "backend_name", "status")
        if data.get(key)
    }

    container: DependencyContainer = g.container
    tracing_manager = container.resolve(TracingManager)
    llm_tracer = tracing_manager.get_tracer("llm")
    if not llm_tracer:
        return jsonify({"error": "LLM tracer not found"}), 404

    records, total = llm_tracer.get_traces(
        filters=filters,
        page=page,
        page_size=page_size
    )

    return jsonify({
        "items": [record.to_dict() for record in records],
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": (total + page_size - 1) // page_size
    })
@tracing_bp.route("/llm/detail/<trace_id>", methods=["GET"])
@require_auth
async def get_llm_trace_detail(trace_id: str):
    """Return the detailed record of one LLM request.

    ``trace_id`` is taken from the URL path; the rule must declare it because
    the view requires it. Responds 404 if the LLM tracer or the trace is
    missing.
    """
    container: DependencyContainer = g.container
    tracing_manager = container.resolve(TracingManager)
    llm_tracer = tracing_manager.get_tracer("llm")
    if not llm_tracer:
        return jsonify({"error": "LLM tracer not found"}), 404

    trace = llm_tracer.get_trace_by_id(trace_id)
    if not trace:
        return jsonify({"error": "Trace not found"}), 404

    return jsonify(trace.to_detail_dict())
@tracing_bp.route("/llm/statistics", methods=["GET"])
@require_auth
async def get_llm_statistics():
    """Return aggregate statistics from the LLM tracer (404 if it is not registered)."""
    manager: TracingManager = g.container.resolve(TracingManager)
    tracer = manager.get_tracer("llm")
    if not tracer:
        return jsonify({"error": "LLM tracer not found"}), 404
    # Narrow the type: get_statistics is specific to LLMTracer.
    assert isinstance(tracer, LLMTracer)
    return jsonify(tracer.get_statistics())
@tracing_bp.websocket("/ws")
async def tracing_ws():
    """Push live trace events to a WebSocket client.

    Protocol:
      1. The server sends a ``connected`` message listing available tracers.
      2. The client sends ``{"token": ...}``; invalid tokens close with 1008.
      3. The client sends ``{"action": "subscribe", "tracer_type": ...}``.
      4. Queued trace messages are forwarded until a ``None`` sentinel
         arrives or the global shutdown event is set.
    The tracer registration is always released on exit.
    """
    container: DependencyContainer = g.container
    tracing_manager = container.resolve(TracingManager)
    auth_service: AuthService = container.resolve(AuthService)

    tracer_types = tracing_manager.get_tracer_types()

    await websocket.send(json.dumps({
        "type": "connected",
        "message": "Connected to tracing websocket",
        "data": {
            "available_tracers": tracer_types
        }
    }))

    try:
        token_data = await websocket.receive()
        token = json.loads(token_data)["token"]
        if not auth_service.verify_token(token):
            await websocket.close(code=1008, reason="Invalid token")
            return
    except Exception as e:
        logger.error(f"WebSocket连接错误: {e}")
        await websocket.close(code=1008, reason="Invalid token")
        return

    # A malformed command frame should close the socket, not crash the handler.
    try:
        cmd = json.loads(await websocket.receive())
    except (TypeError, ValueError):
        cmd = None
    if not isinstance(cmd, dict) or cmd.get("action") != "subscribe":
        await websocket.close(code=1008, reason="Invalid action")
        return

    tracer_type = cmd.get("tracer_type")
    if not tracer_type:
        await websocket.close(code=1008, reason="Invalid tracer type")
        return
    tracer = tracing_manager.get_tracer(tracer_type)
    if not tracer:
        await websocket.close(code=1008, reason="Tracer not found")
        return

    queue: asyncio.Queue = tracer.register_ws_client()
    # Everything after registration is inside try so the client is always
    # unregistered, even if the acknowledgement send fails.
    try:
        await websocket.send(json.dumps({
            "type": "subscribe_success",
            "message": "Subscribed to tracing websocket",
            "data": {
                "tracer_type": tracer_type
            }
        }))
        while not shutdown_event.is_set():
            # Poll with a timeout so shutdown is noticed even when no events arrive.
            try:
                message = await asyncio.wait_for(queue.get(), timeout=1)
            except asyncio.TimeoutError:
                continue
            if message is None:
                break
            await websocket.send(json.dumps(message))
    finally:
        tracer.unregister_ws_client(queue)
================================================
FILE: kirara_ai/web/api/workflow/README.md
================================================
# 工作流 API 🔄
工作流 API 提供了管理工作流的功能。工作流由多个区块组成,用于处理消息和执行任务。每个工作流都属于一个特定的组。
## API 端点
### 获取所有工作流
```http
GET /backend-api/api/workflow
```
获取所有已注册的工作流基本信息。
**响应示例:**
```json
{
"workflows": [
{
"group_id": "chat",
"workflow_id": "normal",
"name": "普通聊天",
"description": "处理普通聊天消息的工作流",
"block_count": 3,
"metadata": {
"category": "chat",
"tags": ["normal", "chat"]
}
}
]
}
```
### 获取特定工作流
```http
GET /backend-api/api/workflow/{group_id}/{workflow_id}
```
获取指定工作流的详细信息。
**响应示例:**
```json
{
"workflow": {
"group_id": "chat",
"workflow_id": "normal",
"name": "普通聊天",
"description": "处理普通聊天消息的工作流",
"blocks": [
{
"block_id": "input_1",
"type_name": "MessageInputBlock",
"name": "消息输入",
"config": {
"format": "text"
},
"position": {
"x": 100,
"y": 100
}
},
{
"block_id": "llm_1",
"type_name": "LLMBlock",
"name": "语言模型",
"config": {
"backend": "openai",
"temperature": 0.7
},
"position": {
"x": 300,
"y": 100
}
},
{
"block_id": "output_1",
"type_name": "MessageOutputBlock",
"name": "消息输出",
"config": {
"format": "text"
},
"position": {
"x": 500,
"y": 100
}
}
],
"wires": [
{
"source_block": "input_1",
"source_output": "message",
"target_block": "llm_1",
"target_input": "prompt"
},
{
"source_block": "llm_1",
"source_output": "response",
"target_block": "output_1",
"target_input": "message"
}
],
"metadata": {
"category": "chat",
"tags": ["normal", "chat"]
}
}
}
```
### 创建工作流
```http
POST /backend-api/api/workflow/{group_id}/{workflow_id}
```
创建新的工作流。
**请求体:**
```json
{
"group_id": "chat",
"workflow_id": "creative",
"name": "创意聊天",
"description": "处理创意聊天的工作流",
"blocks": [
{
"block_id": "input_1",
"type_name": "MessageInputBlock",
"name": "消息输入",
"config": {
"format": "text"
},
"position": {
"x": 100,
"y": 100
}
},
{
"block_id": "prompt_1",
"type_name": "PromptBlock",
"name": "提示词处理",
"config": {
"template": "请发挥创意回答以下问题:{{input}}"
},
"position": {
"x": 300,
"y": 100
}
},
{
"block_id": "llm_1",
"type_name": "LLMBlock",
"name": "语言模型",
"config": {
"backend": "anthropic",
"temperature": 0.9
},
"position": {
"x": 500,
"y": 100
}
}
],
"wires": [
{
"source_block": "input_1",
"source_output": "message",
"target_block": "prompt_1",
"target_input": "input"
},
{
"source_block": "prompt_1",
"source_output": "output",
"target_block": "llm_1",
"target_input": "prompt"
}
],
"metadata": {
"category": "chat",
"tags": ["creative", "chat"]
}
}
```
### 更新工作流
```http
PUT /backend-api/api/workflow/{group_id}/{workflow_id}
```
更新现有工作流。请求体格式与创建工作流相同。
### 删除工作流
```http
DELETE /backend-api/api/workflow/{group_id}/{workflow_id}
```
删除指定工作流。成功时返回:
```json
{
"message": "Workflow deleted successfully"
}
```
## 数据模型
### Wire (工作流连线)
- `source_block`: 源区块 ID
- `source_output`: 源区块输出端口
- `target_block`: 目标区块 ID
- `target_input`: 目标区块输入端口
### BlockInstance (区块实例)
- `block_id`: 区块 ID
- `type_name`: 区块类型名称
- `name`: 区块显示名称
- `config`: 区块配置
- `position`: 区块位置
- `x`: X 坐标
- `y`: Y 坐标
### WorkflowDefinition (工作流定义)
- `group_id`: 工作流组 ID
- `workflow_id`: 工作流 ID
- `name`: 工作流名称
- `description`: 工作流描述
- `blocks`: 区块列表
- `wires`: 连线列表
- `metadata`: 元数据(可选)
### WorkflowInfo (工作流信息)
- `group_id`: 工作流组 ID
- `workflow_id`: 工作流 ID
- `name`: 工作流名称
- `description`: 工作流描述
- `block_count`: 区块数量
- `metadata`: 元数据(可选)
### WorkflowList (工作流列表)
- `workflows`: 工作流信息列表
### WorkflowResponse (工作流响应)
- `workflow`: 工作流定义
## 区块类型
工作流中可以使用的区块类型包括:
### MessageInputBlock
- 功能:接收输入消息
- 输入:无
- 输出:
- `message`: 消息内容
- 配置:
- `format`: 消息格式(text/image/audio)
### MessageOutputBlock
- 功能:输出消息
- 输入:
- `message`: 消息内容
- 输出:无
- 配置:
- `format`: 消息格式(text/image/audio)
### LLMBlock
- 功能:调用大语言模型
- 输入:
- `prompt`: 提示词
- 输出:
- `response`: 模型响应
- 配置:
- `backend`: 使用的后端
- `temperature`: 温度参数
### PromptBlock
- 功能:处理提示词
- 输入:
- `input`: 输入内容
- 输出:
- `output`: 处理后的提示词
- 配置:
- `template`: 提示词模板
## 相关代码
- [工作流注册表](../../../workflow/core/workflow/registry.py)
- [工作流构建器](../../../workflow/core/workflow/builder.py)
- [区块注册表](../../../workflow/core/block/registry.py)
- [系统预设工作流](../../../../data/workflows)
## 错误处理
所有 API 端点在发生错误时都会返回适当的 HTTP 状态码和错误信息:
```json
{
"error": "错误描述信息"
}
```
常见状态码:
- 400: 请求参数错误、工作流配置无效或工作流已存在
- 404: 工作流不存在
- 500: 服务器内部错误
## 使用示例
### 获取所有工作流
```python
import requests
response = requests.get(
'http://localhost:8080/api/workflow',
headers={'Authorization': f'Bearer {token}'}
)
```
### 创建新工作流
```python
import requests
workflow_data = {
"group_id": "chat",
"workflow_id": "creative",
"name": "创意聊天",
"description": "处理创意聊天的工作流",
"blocks": [
{
"block_id": "input_1",
"type_name": "MessageInputBlock",
"name": "消息输入",
"config": {
"format": "text"
},
"position": {
"x": 100,
"y": 100
}
},
{
"block_id": "llm_1",
"type_name": "LLMBlock",
"name": "语言模型",
"config": {
"backend": "anthropic",
"temperature": 0.9
},
"position": {
"x": 300,
"y": 100
}
}
],
"wires": [
{
"source_block": "input_1",
"source_output": "message",
"target_block": "llm_1",
"target_input": "prompt"
}
],
"metadata": {
"category": "chat",
"tags": ["creative"]
}
}
response = requests.post(
'http://localhost:8080/api/workflow/chat/creative',
headers={'Authorization': f'Bearer {token}'},
json=workflow_data
)
```
### 删除工作流
```python
import requests
response = requests.delete(
'http://localhost:8080/api/workflow/chat/creative',
headers={'Authorization': f'Bearer {token}'}
)
```
## 相关文档
- [系统架构](../../README.md#系统架构-)
- [API 认证](../../README.md#api认证-)
- [工作流开发](../../../workflow/README.md#工作流开发-)
================================================
FILE: kirara_ai/web/api/workflow/__init__.py
================================================
from .routes import workflow_bp
__all__ = ["workflow_bp"]
================================================
FILE: kirara_ai/web/api/workflow/models.py
================================================
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from kirara_ai.workflow.core.workflow.base import WorkflowConfig
class Wire(BaseModel):
    """A connection from one block's output port to another block's input port."""
    source_block: str  # identifier of the source block (the routes pass block names here)
    source_output: str  # output port name on the source block
    target_block: str  # identifier of the target block (the routes pass block names here)
    target_input: str  # input port name on the target block
class BlockInstance(BaseModel):
    """One block placed in a workflow."""
    type_name: str  # registered block type name, resolved via BlockRegistry
    name: str  # block name inside the workflow; wires refer to blocks by it
    config: Dict[str, Any]  # keyword arguments passed to the block class
    position: Dict[str, int]  # editor coordinates with "x" and "y" keys
class WorkflowDefinition(BaseModel):
    """Complete workflow definition exchanged with the workflow API."""
    group_id: str  # workflow group; the registry key is "group_id:workflow_id"
    workflow_id: str  # workflow id within its group
    name: str  # display name
    description: str  # free-text description
    blocks: List[BlockInstance]  # blocks in insertion order; the first one starts the chain
    wires: List[Wire]  # explicit connections between blocks
    config: WorkflowConfig = WorkflowConfig()  # workflow-level settings
    metadata: Optional[Dict[str, Any]] = None  # optional free-form metadata
class WorkflowInfo(BaseModel):
    """Summary of a workflow, as returned by the list endpoint."""
    group_id: str  # workflow group
    workflow_id: str  # workflow id within its group
    name: str  # display name
    description: str  # free-text description
    block_count: int  # number of blocks in the workflow
    metadata: Optional[Dict[str, Any]] = None  # optional free-form metadata
class WorkflowList(BaseModel):
    """Response body of the workflow list endpoint."""
    workflows: List[WorkflowInfo]  # summaries sorted by "group_id:workflow_id" in the route
class WorkflowResponse(BaseModel):
    """Response body wrapping a single workflow definition."""
    workflow: WorkflowDefinition
================================================
FILE: kirara_ai/web/api/workflow/routes.py
================================================
import os
from typing import List
from quart import Blueprint, g, jsonify, request
from kirara_ai.workflow.core.block.registry import BlockRegistry
from kirara_ai.workflow.core.workflow import WorkflowRegistry
from kirara_ai.workflow.core.workflow.builder import WorkflowBuilder
from ...auth.middleware import require_auth
from .models import BlockInstance, Wire, WorkflowDefinition, WorkflowInfo, WorkflowList, WorkflowResponse
# Blueprint for workflow CRUD; mounted under /api/workflow by the web app.
workflow_bp = Blueprint("workflow", __name__)
@workflow_bp.route("", methods=["GET"])
@require_auth
async def list_workflows():
    """List every registered workflow as a ``WorkflowInfo`` summary.

    Registry keys have the form ``"group_id:workflow_id"``. The result is
    sorted by that key.
    """
    registry: WorkflowRegistry = g.container.resolve(WorkflowRegistry)

    def _summarize(full_id: str, builder) -> WorkflowInfo:
        # Only the first ":" separates group from workflow id.
        group_part, id_part = full_id.split(":", 1)
        return WorkflowInfo(
            group_id=group_part,
            workflow_id=id_part,
            name=builder.name,
            description=builder.description,
            block_count=len(builder.nodes_by_name),
            metadata=getattr(builder, "metadata", None),
        )

    summaries = sorted(
        (_summarize(full_id, builder) for full_id, builder in registry._workflows.items()),
        key=lambda info: f"{info.group_id}:{info.workflow_id}",
    )
    return WorkflowList(workflows=summaries).model_dump()
@workflow_bp.route("/<group_id>/<workflow_id>", methods=["GET"])
@require_auth
async def get_workflow(group_id: str, workflow_id: str):
    """Return the full definition (blocks, wires, config) of one workflow.

    ``group_id`` and ``workflow_id`` come from the URL path. The rule must
    declare both because the view requires them. Responds 404 if the
    workflow is not registered.
    """
    registry: WorkflowRegistry = g.container.resolve(WorkflowRegistry)
    block_registry: BlockRegistry = g.container.resolve(BlockRegistry)

    builder = registry.get(f"{group_id}:{workflow_id}")
    if not builder:
        return jsonify({"error": "Workflow not found"}), 404
    assert isinstance(builder, WorkflowBuilder)

    blocks: List[BlockInstance] = [
        BlockInstance(
            type_name=block_registry.get_block_type_name(node.spec.block_class),
            name=node.name,
            config=node.spec.kwargs,
            position=node.position or {"x": 0, "y": 0},
        )
        for node in builder.nodes
    ]

    wires: List[Wire] = [
        Wire(
            source_block=source_name,
            source_output=source_output,
            target_block=target_name,
            target_input=target_input,
        )
        for source_name, source_output, target_name, target_input in builder.wire_specs
    ]

    workflow_def = WorkflowDefinition(
        group_id=group_id,
        workflow_id=workflow_id,
        name=builder.name,
        description=builder.description,
        blocks=blocks,
        wires=wires,
        metadata=getattr(builder, "metadata", None),
        config=builder.config,
    )

    return WorkflowResponse(workflow=workflow_def).model_dump()
def _build_workflow(workflow_def: WorkflowDefinition, block_registry: BlockRegistry) -> WorkflowBuilder:
    """Construct a configured WorkflowBuilder from an API workflow definition.

    Raises ValueError if a block type is not registered.
    """
    builder = WorkflowBuilder(workflow_def.name)
    builder.description = workflow_def.description

    for block_def in workflow_def.blocks:
        block_class = block_registry.get(block_def.type_name)
        if not block_class:
            raise ValueError(f"Block type {block_def.type_name} not found")
        # The first block starts the chain; later ones are appended to it.
        if not builder.head:
            builder.use(block_class, name=block_def.name, **block_def.config)
        else:
            builder.chain(block_class, name=block_def.name, **block_def.config)
        builder.update_position(block_def.name, block_def.position)

    # Discard the wires chain() created automatically; use the explicit ones.
    builder.wire_specs = []
    for wire in workflow_def.wires:
        builder.force_connect(
            wire.source_block,
            wire.target_block,
            wire.source_output,
            wire.target_input
        )

    builder.set_config(workflow_def.config)
    return builder


@workflow_bp.route("/<group_id>/<workflow_id>", methods=["POST"])
@require_auth
async def create_workflow(group_id: str, workflow_id: str):
    """Create, save and register a new workflow.

    Responds 400 if the workflow already exists or the definition is invalid
    (including a malformed request body).
    """
    registry: WorkflowRegistry = g.container.resolve(WorkflowRegistry)
    block_registry: BlockRegistry = g.container.resolve(BlockRegistry)

    if registry.get(f"{group_id}:{workflow_id}"):
        return jsonify({"error": "Workflow already exists"}), 400

    try:
        # Parse inside try so a bad body is a 400, not an unhandled 500.
        data = await request.get_json()
        workflow_def = WorkflowDefinition(**data)
        builder = _build_workflow(workflow_def, block_registry)

        file_path = registry.get_workflow_path(group_id, workflow_id)
        builder.save_to_yaml(file_path, g.container)
        registry.register(group_id, workflow_id, builder)
        return workflow_def.model_dump()
    except Exception as e:
        return jsonify({"error": str(e)}), 400


@workflow_bp.route("/<group_id>/<workflow_id>", methods=["PUT"])
@require_auth
async def update_workflow(group_id: str, workflow_id: str):
    """Replace an existing workflow, optionally moving it to a new group/id.

    The target ids come from the body's ``group_id``/``workflow_id``. The new
    file is written before the old one is removed, so a failed save never
    loses the existing workflow. Responds 404 if the workflow does not exist
    and 400 for invalid definitions or a rename onto an existing workflow.
    """
    registry: WorkflowRegistry = g.container.resolve(WorkflowRegistry)
    block_registry: BlockRegistry = g.container.resolve(BlockRegistry)

    if not registry.get(f"{group_id}:{workflow_id}"):
        return jsonify({"error": "Workflow not found"}), 404

    try:
        data = await request.get_json()
        workflow_def = WorkflowDefinition(**data)
        new_group_id = workflow_def.group_id
        new_workflow_id = workflow_def.workflow_id

        renamed = (new_group_id, new_workflow_id) != (group_id, workflow_id)
        if renamed and registry.get(f"{new_group_id}:{new_workflow_id}"):
            return jsonify({"error": "Workflow already exists"}), 400

        builder = _build_workflow(workflow_def, block_registry)

        old_path = registry.get_workflow_path(group_id, workflow_id)
        new_path = registry.get_workflow_path(new_group_id, new_workflow_id)
        builder.save_to_yaml(new_path, g.container)
        # Only remove the old file once the new one is safely written.
        if os.path.abspath(old_path) != os.path.abspath(new_path) and os.path.exists(old_path):
            os.remove(old_path)

        registry.unregister(group_id, workflow_id)
        registry.register(new_group_id, new_workflow_id, builder)
        return workflow_def.model_dump()
    except Exception as e:
        return jsonify({"error": str(e)}), 400
@workflow_bp.route("/<group_id>/<workflow_id>", methods=["DELETE"])
@require_auth
async def delete_workflow(group_id: str, workflow_id: str):
    """Delete a workflow's YAML file and remove it from the registry.

    Responds 404 if the workflow is unknown and 400 with ``{"error": ...}``
    if removal fails.
    """
    registry: WorkflowRegistry = g.container.resolve(WorkflowRegistry)

    if not registry.get(f"{group_id}:{workflow_id}"):
        return jsonify({"error": "Workflow not found"}), 404

    try:
        # Delete the file first: if that fails, the workflow stays registered
        # rather than vanishing from memory while its file remains on disk.
        file_path = registry.get_workflow_path(group_id, workflow_id)
        if os.path.exists(file_path):
            os.remove(file_path)
        # Use the registry's public API, as update_workflow does, instead of
        # reaching into its private _workflows dict.
        registry.unregister(group_id, workflow_id)
        return jsonify({"message": "Workflow deleted successfully"})
    except Exception as e:
        return jsonify({"error": str(e)}), 400
================================================
FILE: kirara_ai/web/app.py
================================================
import asyncio
import mimetypes
import os
import socket
from pathlib import Path
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, PlainTextResponse
from hypercorn.asyncio import serve
from hypercorn.config import Config
from quart import Quart, g, jsonify
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.logger import HypercornLoggerWrapper, get_logger
from kirara_ai.web.auth.services import AuthService, FileBasedAuthService
from kirara_ai.web.utils import create_no_cache_response, install_webui
from .api.block import block_bp
from .api.dispatch import dispatch_bp
from .api.im import im_bp
from .api.llm import llm_bp
from .api.mcp import mcp_bp
from .api.media import media_bp
from .api.plugin import plugin_bp
from .api.system import system_bp
from .api.tracing import tracing_bp
from .api.workflow import workflow_bp
from .auth.routes import auth_bp
# Page returned by "/" when STATIC_FOLDER/index.html is missing (or serving it
# fails); every "TARGET_DIR" placeholder is replaced with STATIC_FOLDER first.
ERROR_MESSAGE = """
WebUI launch failed!
Web UI not found. Please download from here and extract to the TARGET_DIR folder, make sure the TARGET_DIR/index.html file exists.
WebUI 启动失败!
Web UI 未找到。请从 这里 下载并解压到 TARGET_DIR 文件夹,确保 TARGET_DIR/index.html 文件存在。
"""
# Captured once at import time: the WebUI location is relative to the
# process's working directory at that moment.
cwd = os.getcwd()
# Directory the WebUI is served from and auto-installed into.
STATIC_FOLDER = f"{cwd}/web"
logger = get_logger("WebServer")
# Extra static files served by the main app: URL path -> local file path.
# Filled by WebServer.add_static_assets and consulted first by the SPA middleware.
custom_static_assets: dict[str, str] = {}
def create_web_api_app(container: DependencyContainer) -> Quart:
    """Build the Quart application that serves the JSON web API.

    Each API blueprint is registered under its ``/api/<area>`` prefix. A
    catch-all error handler logs the exception and answers
    ``{"error": str(error)}`` with status 500. *container* is exposed as
    ``app.container`` and, for every HTTP and websocket request, as
    ``g.container``.
    """
    web_api = Quart(__name__, static_folder=STATIC_FOLDER)
    # Keep JSON keys in insertion order instead of sorting them.
    web_api.json.sort_keys = False  # type: ignore

    blueprint_table = (
        (auth_bp, "/api/auth"),
        (im_bp, "/api/im"),
        (llm_bp, "/api/llm"),
        (dispatch_bp, "/api/dispatch"),
        (block_bp, "/api/block"),
        (workflow_bp, "/api/workflow"),
        (plugin_bp, "/api/plugin"),
        (system_bp, "/api/system"),
        (media_bp, "/api/media"),
        (tracing_bp, "/api/tracing"),
        (mcp_bp, "/api/mcp"),
    )
    for blueprint, url_prefix in blueprint_table:
        web_api.register_blueprint(blueprint, url_prefix=url_prefix)

    # NOTE(review): with Flask-style dispatch a handler for Exception presumably
    # also receives HTTP errors (404/405 ...), reporting them as 500 — confirm.
    @web_api.errorhandler(Exception)
    def handle_exception(error):
        logger.opt(exception=error).error("Error during request")
        payload = jsonify({"error": str(error)})
        payload.status_code = 500
        return payload

    # Make the DI container reachable from request/websocket handlers via g.
    @web_api.before_request
    async def inject_container():  # type: ignore
        g.container = container

    @web_api.before_websocket
    async def inject_container_ws():  # type: ignore
        g.container = container

    web_api.container = container  # type: ignore
    return web_api
def create_app(container: DependencyContainer) -> FastAPI:
    """Create the main FastAPI application that serves the WebUI.

    Besides the ``/`` route it installs an HTTP middleware that, per request:

    1. serves paths registered in ``custom_static_assets``;
    2. hands requests for registered routes / mounted sub-apps (e.g. the
       ``/backend-api`` app mounted by ``WebServer``) to ``call_next``;
    3. serves an existing file below ``STATIC_FOLDER`` with no-cache headers;
    4. otherwise falls back to ``index.html`` (SPA routing), or a 404.

    Args:
        container: Dependency container; not referenced by this function.

    Returns:
        The configured FastAPI application.
    """
    # Only needed to render HTTP errors raised inside the middleware.
    from fastapi.responses import JSONResponse

    app = FastAPI()

    # CORS: allow any origin, method and header.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Register MIME types explicitly (process-wide, via the mimetypes module)
    # so static assets get these content types regardless of host settings.
    mimetypes.add_type("text/html", ".html")
    mimetypes.add_type("text/css", ".css")
    mimetypes.add_type("text/javascript", ".js")
    mimetypes.add_type("image/svg+xml", ".svg")
    mimetypes.add_type("image/png", ".png")
    mimetypes.add_type("image/jpeg", ".jpg")
    mimetypes.add_type("image/gif", ".gif")
    mimetypes.add_type("image/webp", ".webp")

    static_root = Path(STATIC_FOLDER)

    def http_error_response(exc: HTTPException) -> JSONResponse:
        # Exceptions raised inside an ``@app.middleware("http")`` function do
        # not reach FastAPI's exception handlers and would surface as a 500,
        # so render them here in FastAPI's default ``{"detail": ...}`` shape.
        return JSONResponse(
            status_code=exc.status_code,
            content={"detail": exc.detail},
            headers=getattr(exc, "headers", None),
        )

    async def serve_custom_static(path: str, request: Request):
        """Serve a file registered in ``custom_static_assets`` (404 otherwise)."""
        if path not in custom_static_assets:
            raise HTTPException(status_code=404, detail="File not found")
        file_path = Path(custom_static_assets[path])
        try:
            return await create_no_cache_response(file_path, request)
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"处理自定义静态资源时出错: {e}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @app.get("/")
    async def index(request: Request):
        try:
            index_path = static_root / "index.html"
            if not index_path.exists():
                return HTMLResponse(content=ERROR_MESSAGE.replace("TARGET_DIR", STATIC_FOLDER))
            return await create_no_cache_response(index_path, request)
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error serving index: {e}")
            return HTMLResponse(content=ERROR_MESSAGE.replace("TARGET_DIR", STATIC_FOLDER))

    def is_backend_route(path: str) -> bool:
        """Whether *path* belongs to a registered route or mounted sub-app."""
        # Recomputed per request: sub-apps are mounted after this function returns.
        route_paths = [route.path for route in app.routes]  # type: ignore
        if path in route_paths:
            return True
        # "/" would prefix-match everything. Requiring a "/" boundary stops e.g.
        # "/docs-page" from being mistaken for part of "/docs".
        return any(
            path.startswith(route_path.rstrip("/") + "/")
            for route_path in route_paths
            if route_path != "/"
        )

    async def serve_static_path(path: str, request: Request):
        """Serve a file below STATIC_FOLDER, else index.html, else a 404."""
        file_path = static_root / path.lstrip("/")
        # Path-traversal guard: never serve anything outside STATIC_FOLDER.
        if not file_path.resolve().is_relative_to(static_root.resolve()):
            raise HTTPException(status_code=404, detail="Access denied")

        # The file exists: return it with caching disabled.
        if file_path.is_file():
            try:
                return await create_no_cache_response(file_path, request)
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"处理静态文件时出错: {e}")
                return FileResponse(file_path)  # fall back to a plain file response

        # Otherwise return index.html so the SPA router can handle the path.
        fallback_path = static_root / "index.html"
        if fallback_path.is_file():
            try:
                return await create_no_cache_response(fallback_path, request)
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"处理index.html时出错: {e}")
                return FileResponse(fallback_path)  # fall back to a plain file response
        return PlainTextResponse(status_code=404, content="route not found")

    @app.middleware("http")
    async def spa_middleware(request: Request, call_next):
        path = request.url.path

        # Custom static assets take precedence over everything else.
        if path in custom_static_assets:
            try:
                return await serve_custom_static(path, request)
            except HTTPException as exc:
                return http_error_response(exc)

        # Registered routes and mounted sub-apps (e.g. /backend-api).
        if is_backend_route(path):
            return await call_next(request)

        try:
            return await serve_static_path(path, request)
        except HTTPException as exc:
            return http_error_response(exc)

    return app
class WebServer:
    """Serves the WebUI (FastAPI) with the Quart web API mounted at /backend-api.

    ``start`` runs hypercorn as a background asyncio task and ``stop`` signals
    it to shut down.
    """

    app: FastAPI
    web_api_app: Quart
    listen_host: str
    listen_port: int
    container: DependencyContainer

    def __init__(self, container: DependencyContainer):
        self.app = create_app(container)
        self.web_api_app = create_web_api_app(container)
        self.server_task = None
        self.shutdown_event = asyncio.Event()
        self.container = container
        # Background WebUI auto-install task, if one has been started.
        self._webui_install_task = None

        config = container.resolve(GlobalConfig)
        container.register(
            AuthService,
            FileBasedAuthService(
                password_file=Path(config.web.password_file),
                secret_key=config.web.secret_key,
            ),
        )
        self.config = config

        # Configure hypercorn, routing its access/error logs through our logger.
        from hypercorn.logging import Logger

        self.hypercorn_config = Config()
        self.hypercorn_config._log = Logger(self.hypercorn_config)

        class FilteredLoggerWrapper(HypercornLoggerWrapper):
            """Access logger that drops entries for frequently polled URLs."""

            def info(self, message, *args, **kwargs):
                # URL paths whose request log lines are suppressed.
                ignored_paths = [
                    '/backend-api/api/system/status',
                    '/favicon.ico',
                ]
                for path in ignored_paths:
                    if path in str(args):
                        return
                super().info(message, *args, **kwargs)

        self.hypercorn_config._log.access_logger = FilteredLoggerWrapper(logger)  # type: ignore
        self.hypercorn_config._log.error_logger = HypercornLoggerWrapper(logger)  # type: ignore

        # Mount the web API application.
        self.mount_app("/backend-api", self.web_api_app)

    def mount_app(self, prefix: str, app):
        """Mount a sub-application under the given path prefix."""
        self.app.mount(prefix, app)

    def _check_port_available(self, host: str, port: int) -> bool:
        """Return True if a TCP (IPv4) socket can currently bind ``host:port``.

        SO_REUSEADDR is only set on POSIX, where it merely allows rebinding a
        port left in TIME_WAIT. On Windows the same option lets a socket bind
        a port another process is actively listening on, so this check would
        always succeed there; it is therefore not set on Windows.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if os.name != "nt":
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind((host, port))
            except OSError:
                return False
            return True

    async def start(self):
        """Start the web server in the background.

        Host/port come from the ``cli_args`` container entry when present
        (falling back to the config), otherwise from the config.

        Raises:
            RuntimeError: if the port cannot be bound.
        """
        if self.container.has("cli_args"):
            cli_args = self.container.resolve("cli_args")
            self.listen_host = cli_args.host or self.config.web.host
            self.listen_port = cli_args.port or self.config.web.port
        else:
            self.listen_host = self.config.web.host
            self.listen_port = self.config.web.port

        self.hypercorn_config.bind = [f"{self.listen_host}:{self.listen_port}"]

        # Fail fast with a clear message if the port is already taken.
        if not self._check_port_available(self.listen_host, self.listen_port):
            error_msg = f"端口 {self.listen_port} 已被占用,无法启动服务器,请修改端口或关闭其他占用端口的程序。"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        self.server_task = asyncio.create_task(serve(self.app, self.hypercorn_config, shutdown_trigger=self.shutdown_event.wait))  # type: ignore
        logger.info(f"监听地址:http://{self.listen_host}:{self.listen_port}/")

        # Install the WebUI in the background if it is missing.
        self._check_and_install_webui()

    async def stop(self):
        """Signal shutdown and wait up to 3 seconds for the server task."""
        self.shutdown_event.set()
        if self.server_task:
            try:
                await asyncio.wait_for(self.server_task, timeout=3.0)
            except asyncio.TimeoutError:
                logger.warning("Server shutdown timed out after 3 seconds.")
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.error(f"Error during server shutdown: {e}")

    def add_static_assets(self, url_path: str, local_path: str):
        """Serve *local_path* at *url_path*; ignored (with a warning) if missing."""
        if not os.path.exists(local_path):
            logger.warning(f"Static asset path does not exist: {local_path}")
            return
        custom_static_assets[url_path] = local_path

    def _check_and_install_webui(self):
        """Schedule a background WebUI install if index.html is missing."""
        index_path = Path(STATIC_FOLDER) / "index.html"
        if not index_path.exists():
            logger.info("检测到WebUI不存在,将在服务器启动后自动安装...")
            # Fire-and-forget; the reference is kept on the instance.
            self._webui_install_task = asyncio.create_task(self._install_webui())

    async def _install_webui(self):
        """Install the WebUI into STATIC_FOLDER and log the outcome."""
        try:
            logger.info("开始自动安装WebUI...")
            success, message = await install_webui(Path(STATIC_FOLDER))
            if success:
                logger.info(message)
                logger.info(f"WebUI已安装到 {STATIC_FOLDER},请刷新浏览器")
            else:
                logger.error(message)
                logger.error("WebUI自动安装失败,请手动下载并安装")
        except Exception as e:
            logger.error(f"WebUI安装过程出错: {e}")
================================================
FILE: kirara_ai/web/auth/middleware.py
================================================
from functools import wraps
from quart import g, jsonify, request
from kirara_ai.web.auth.services import AuthService
def _extract_bearer_token(auth_header: str) -> tuple[str, str]:
    """Split an ``Authorization`` header value into ``(error, token)``.

    Returns ``("", token)`` for a well-formed ``Bearer <token>`` value (the
    scheme is matched case-insensitively), otherwise ``(error_message, "")``.
    """
    parts = auth_header.split()
    # Previously a value without exactly two parts raised ValueError (-> 500).
    if len(parts) != 2:
        return "Invalid authorization header", ""
    scheme, token = parts
    if scheme.lower() != "bearer":
        return "Invalid token type", ""
    return "", token


def require_auth(f):
    """Decorator that rejects requests without a valid access token (HTTP 401).

    The token is read from the ``auth_token`` query parameter if present,
    otherwise from an ``Authorization: Bearer <token>`` header, and verified
    with the container's ``AuthService``.
    """

    @wraps(f)
    async def decorated_function(*args, **kwargs):
        # A token in the query string takes precedence over the header.
        token = request.args.get("auth_token")
        if not token:
            auth_header = request.headers.get("Authorization")
            if not auth_header:
                return jsonify({"error": "No authorization header"}), 401
            error, token = _extract_bearer_token(auth_header)
            if error:
                return jsonify({"error": error}), 401

        auth_service: AuthService = g.container.resolve(AuthService)
        if not auth_service.verify_token(token):
            return jsonify({"error": "Invalid token"}), 401
        return await f(*args, **kwargs)

    return decorated_function
================================================
FILE: kirara_ai/web/auth/models.py
================================================
from pydantic import BaseModel
class LoginRequest(BaseModel):
    """Body of POST /login: the password to check, or to set on first login."""
    password: str
class ChangePasswordRequest(BaseModel):
    """Body of POST /change-password."""
    # Current password; must verify before the change is applied.
    old_password: str
    # Replacement password to store.
    new_password: str
class TokenResponse(BaseModel):
    """Successful login response carrying the issued access token."""
    access_token: str
    # Token scheme clients should use in the Authorization header.
    token_type: str = "bearer"
================================================
FILE: kirara_ai/web/auth/routes.py
================================================
import secrets
from datetime import timedelta
from quart import Blueprint, g, jsonify, request
from kirara_ai.config.config_loader import CONFIG_FILE, ConfigLoader
from kirara_ai.config.global_config import GlobalConfig
from .middleware import require_auth
from .models import ChangePasswordRequest, LoginRequest, TokenResponse
from .services import AuthService
# Blueprint for the authentication routes below (app.py registers it under /api/auth).
auth_bp = Blueprint("auth", __name__)
@auth_bp.route("/login", methods=["POST"])
async def login():
    """Log in with the web UI password and return an access token.

    On first use (no password stored yet) the submitted password becomes the
    password. Responds 400 for a missing or malformed body, 401 for a wrong
    password, otherwise a ``TokenResponse`` whose token is valid for one day.
    """
    data = await request.get_json(silent=True)
    try:
        login_data = LoginRequest(**data)
    except (TypeError, ValueError):
        # TypeError: body missing or not a JSON object (``**None``);
        # ValueError: pydantic's ValidationError derives from it.
        return jsonify({"error": "Invalid request body"}), 400

    auth_service: AuthService = g.container.resolve(AuthService)
    if auth_service.is_first_time():
        auth_service.save_password(login_data.password)
    elif not auth_service.verify_password(login_data.password):
        return jsonify({"error": "Invalid password"}), 401

    token = auth_service.create_access_token(timedelta(days=1))
    return TokenResponse(access_token=token).model_dump()
@auth_bp.route("/change-password", methods=["POST"])
@require_auth
async def change_password():
    """Change the web UI password and invalidate every issued token.

    Responds 400 for a missing or malformed body, or when the old password
    is wrong. On success the new password is saved, ``web.secret_key`` is
    replaced by a fresh random key, the config is saved with a backup, and
    the new key is pushed to the AuthService, so existing tokens stop
    verifying.
    """
    data = await request.get_json(silent=True)
    try:
        password_data = ChangePasswordRequest(**data)
    except (TypeError, ValueError):
        # TypeError: body missing or not a JSON object;
        # ValueError: pydantic's ValidationError derives from it.
        return jsonify({"error": "Invalid request body"}), 400

    auth_service: AuthService = g.container.resolve(AuthService)
    if not auth_service.verify_password(password_data.old_password):
        # Use an error status (not 200) so clients can detect the failure.
        return jsonify({"error": "Invalid old password"}), 400
    auth_service.save_password(password_data.new_password)

    # Rotate the signing key so that all existing tokens become invalid.
    config: GlobalConfig = g.container.resolve(GlobalConfig)
    config.web.secret_key = secrets.token_hex(32)
    ConfigLoader.save_config_with_backup(CONFIG_FILE, config)
    g.container.resolve(AuthService).secret_key = config.web.secret_key
    return jsonify({"message": "Password changed successfully"})
@auth_bp.route("/check-first-time", methods=["GET"])
async def check_first_time():
    """Report whether no web UI password has been set yet."""
    service: AuthService = g.container.resolve(AuthService)
    first_time = service.is_first_time()
    return jsonify({"is_first_time": first_time})
================================================
FILE: kirara_ai/web/auth/services.py
================================================
from abc import ABC, abstractmethod
from datetime import timedelta
from pathlib import Path
from typing import Optional
class AuthService(ABC):
    """Interface for web UI password storage and access-token handling."""

    @abstractmethod
    def is_first_time(self) -> bool:
        """Return True while no password has been set."""

    @abstractmethod
    def save_password(self, password: str) -> None:
        """Persist *password* as the new web UI password."""

    @abstractmethod
    def verify_password(self, password: str) -> bool:
        """Return whether *password* matches the stored password."""

    @abstractmethod
    def create_access_token(self, expires_delta: Optional[timedelta] = None) -> str:
        """Issue an access token, optionally valid for *expires_delta*."""

    @abstractmethod
    def verify_token(self, token: str) -> bool:
        """Return whether *token* is a valid access token."""
class FileBasedAuthService(AuthService):
    """AuthService storing a bcrypt password hash in a file and issuing JWTs.

    Args:
        password_file: File holding the password hash; its absence means no
            password has been set yet.
        secret_key: Key used to sign and verify access tokens.
    """

    def __init__(self, password_file: Path, secret_key: str):
        self.password_file = password_file
        self.secret_key = secret_key

    def is_first_time(self) -> bool:
        return not self.password_file.exists()

    def save_password(self, password: str) -> None:
        """Hash *password* and store it, replacing the file atomically.

        The hash is written to a temporary file in the same directory and
        moved into place with ``os.replace``. An interrupted write therefore
        cannot leave a truncated hash behind, which would lock the user out:
        the file's mere existence ends first-time setup.
        """
        import contextlib
        import os
        import tempfile

        from .utils import hash_password

        self.password_file.parent.mkdir(parents=True, exist_ok=True)
        hashed = hash_password(password)
        fd, tmp_name = tempfile.mkstemp(
            dir=self.password_file.parent,
            prefix=f".{self.password_file.name}.",
            suffix=".tmp",
        )
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(hashed)
            os.replace(tmp_name, self.password_file)
        except BaseException:
            # Best-effort cleanup of the temporary file, then re-raise.
            with contextlib.suppress(OSError):
                os.unlink(tmp_name)
            raise

    def verify_password(self, password: str) -> bool:
        from .utils import verify_password

        if not self.password_file.exists():
            return False
        hashed = self.password_file.read_bytes()
        try:
            return verify_password(password, hashed)
        except ValueError:
            # bcrypt raises ValueError for a malformed/empty stored hash;
            # report a failed check instead of erroring out the request.
            return False

    def create_access_token(self, expires_delta: Optional[timedelta] = None) -> str:
        from .utils import create_jwt_token

        return create_jwt_token(self.secret_key, expires_delta)

    def verify_token(self, token: str) -> bool:
        from .utils import verify_jwt_token

        return verify_jwt_token(token, self.secret_key)
class MockAuthService(AuthService):
    """In-memory AuthService: plain-text password and a fixed token."""

    def __init__(self):
        self._password: Optional[str] = None
        self._is_first_time = True

    def is_first_time(self) -> bool:
        return self._is_first_time

    def save_password(self, password: str) -> None:
        # Storing a password ends the "first time" state.
        self._password = password
        self._is_first_time = False

    def verify_password(self, password: str) -> bool:
        return self._password == password

    def create_access_token(self, expires_delta: Optional[timedelta] = None) -> str:
        # The expiry is ignored; the same token is always issued.
        return "mock_token"

    def verify_token(self, token: str) -> bool:
        return "mock_token" == token
================================================
FILE: kirara_ai/web/auth/utils.py
================================================
from datetime import datetime, timedelta
from typing import Optional
import bcrypt
import jwt
def hash_password(password: str) -> bytes:
    """Hash *password* with bcrypt using a freshly generated salt."""
    salt = bcrypt.gensalt()
    password_bytes = password.encode()
    return bcrypt.hashpw(password_bytes, salt)
def verify_password(password: str, hashed: bytes) -> bool:
    """Return whether *password* matches the bcrypt hash *hashed*."""
    candidate = password.encode()
    return bcrypt.checkpw(candidate, hashed)
def _token_expiry(expires_delta: Optional[timedelta] = None) -> datetime:
    """Return the absolute, timezone-aware (UTC) expiry for a new token.

    Falls back to 30 minutes when *expires_delta* is None or zero.
    """
    from datetime import timezone

    lifetime = expires_delta if expires_delta else timedelta(minutes=30)
    # Must be timezone-aware: PyJWT treats a naive datetime as UTC, so a
    # local naive time would shift "exp" by the host's UTC offset.
    return datetime.now(timezone.utc) + lifetime


def create_jwt_token(secret_key: str, expires_delta: Optional[timedelta] = None) -> str:
    """Create an HS256-signed JWT carrying only an ``exp`` claim.

    Args:
        secret_key: HMAC signing key.
        expires_delta: Token lifetime; defaults to 30 minutes.

    Returns:
        The encoded token.
    """
    to_encode = {"exp": _token_expiry(expires_delta)}
    return jwt.encode(to_encode, secret_key, algorithm="HS256")
def verify_jwt_token(token: str, secret_key: str) -> bool:
    """Return True if *token* is an HS256 JWT signed with *secret_key* and unexpired."""
    try:
        jwt.decode(token, secret_key, algorithms=["HS256"])
    except Exception:
        # Any decode/validation failure means "invalid". Unlike the former bare
        # ``except:``, KeyboardInterrupt/SystemExit are no longer swallowed.
        return False
    return True
================================================
FILE: kirara_ai/web/utils.py
================================================
import asyncio
import os
import tarfile
import tempfile
import time
from pathlib import Path
import aiohttp
from fastapi import HTTPException, Request
from fastapi.responses import FileResponse, Response
from kirara_ai.logger import get_logger
from kirara_ai.web.api.system.utils import download_file, get_latest_npm_version
# Module logger for the WebUI serving and installation helpers below.
logger = get_logger("WebUtils")
def _etag_matches(if_none_match: str, etag: str) -> bool:
    """Return True if an ``If-None-Match`` header value matches *etag*.

    Supports ``*``, comma-separated lists and weak (``W/``) validators;
    If-None-Match uses weak comparison, so the ``W/`` prefix is ignored.
    """
    if not if_none_match:
        return False
    if if_none_match.strip() == "*":
        return True
    wanted = etag[2:] if etag.startswith("W/") else etag
    for candidate in if_none_match.split(","):
        candidate = candidate.strip()
        if candidate.startswith("W/"):
            candidate = candidate[2:]
        if candidate == wanted:
            return True
    return False


async def create_no_cache_response(file_path: Path, request: Request) -> Response:
    """Serve *file_path* with ``Cache-Control: no-cache`` and an ETag.

    Clients must revalidate on every use; a matching ``If-None-Match`` gets
    an empty 304 response.

    Raises:
        HTTPException: 404 if *file_path* is not an existing regular file.
    """
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    stat = file_path.stat()
    # Entity tags must be quoted strings (RFC 9110 §8.8.3).
    etag = f'"{stat.st_mtime_ns}-{stat.st_size}"'

    if _etag_matches(request.headers.get("if-none-match", ""), etag):
        # A 304 should carry the same validator/caching headers as a 200.
        return Response(
            status_code=304,
            headers={"ETag": etag, "Cache-Control": "no-cache"},
        )

    response = FileResponse(file_path)
    response.headers["ETag"] = etag
    response.headers["Cache-Control"] = "no-cache"
    return response
async def test_npm_registry_speed(registries: list[str]) -> str:
    """Probe NPM registries concurrently and return the fastest one.

    Each registry's ``/kirara-ai-webui`` document is fetched 3 times with a
    5-second timeout. The registry with the lowest mean latency over its
    successful (HTTP 200) attempts wins; a registry with no successful attempt
    scores infinity. If every registry fails, the first one is returned.

    Args:
        registries: Candidate registry base URLs; must be non-empty.

    Returns:
        The chosen registry URL, exactly as given in *registries*.

    Raises:
        ValueError: if *registries* is empty.
    """
    if not registries:
        raise ValueError("registries must not be empty")

    # Number of probes per registry.
    test_count = 3

    async def test_registry(registry: str) -> tuple[str, float]:
        # Strip a trailing slash so e.g. ".../npm/" does not produce ".../npm//pkg".
        url = f"{registry.rstrip('/')}/kirara-ai-webui"
        total_time = 0.0
        success_count = 0
        for attempt in range(1, test_count + 1):
            try:
                # perf_counter: monotonic, high-resolution clock for latency.
                start_time = time.perf_counter()
                async with aiohttp.ClientSession() as session:
                    async with session.get(
                        url,
                        timeout=aiohttp.ClientTimeout(total=5)
                    ) as response:
                        if response.status == 200:
                            total_time += time.perf_counter() - start_time
                            success_count += 1
            except Exception as e:
                logger.warning(f"测试下载源 {registry} 第{attempt}次失败: {e}")
        # Mean latency of the successful attempts; infinity if none succeeded.
        avg_time = total_time / success_count if success_count > 0 else float('inf')
        return registry, avg_time

    # Probe all registries concurrently.
    results = await asyncio.gather(*(test_registry(registry) for registry in registries))

    # min() keeps the first of equal minima, so all-failed yields registries[0].
    fastest_registry, fastest_avg_time = min(results, key=lambda result: result[1])

    if fastest_avg_time != float('inf'):
        logger.info(f"选择最快的下载源: {fastest_registry},平均响应时间: {fastest_avg_time:.2f}秒")
    else:
        logger.warning(f"所有下载源测试均失败,默认使用: {fastest_registry}")
    return fastest_registry
def _extract_webui_dist(tar: tarfile.TarFile, install_path: Path) -> list[str]:
    """Extract ``package/dist/*`` from an npm tarball into *install_path*.

    The ``package/dist/`` prefix is stripped from member names. The archive
    comes from a third-party mirror, so members that would land outside
    *install_path* (absolute names, ``..`` components) and link members are
    skipped rather than extracted.

    Returns:
        The original names of the skipped members.
    """
    prefix = "package/dist/"
    root = Path(install_path).resolve()
    # Python 3.12+ (and security backports) provide extraction filters;
    # "data" additionally rejects special files and strips unsafe modes.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    skipped: list[str] = []
    for member in tar.getmembers():
        if not member.name.startswith(prefix):
            continue
        relative_name = member.name[len(prefix):]
        if not relative_name:  # the "package/dist/" entry itself
            continue
        target = (root / relative_name).resolve()
        if member.issym() or member.islnk() or not target.is_relative_to(root):
            skipped.append(member.name)
            continue
        member.name = relative_name
        tar.extract(member, path=str(install_path), **extract_kwargs)
    return skipped


async def install_webui(install_path: Path) -> tuple[bool, str]:
    """Download the latest WebUI from the fastest NPM registry and install it.

    Args:
        install_path: Directory to install into (created if missing).

    Returns:
        (success, message). Errors are logged and reported as ``False``
        rather than raised.
    """
    import shutil

    temp_dir = None
    try:
        # Candidate registries. No trailing slashes: the registry URL is
        # joined with "/<package>" when probing.
        registries = [
            "https://registry.npmjs.org",
            "https://registry.npmmirror.com",
            "https://registry.yarnpkg.com",
            "https://mirrors.ustc.edu.cn/npm",
        ]
        npm_registry = await test_npm_registry_speed(registries)

        temp_dir = tempfile.mkdtemp()
        logger.info(f"开始从 {npm_registry} 获取最新WebUI版本信息")
        latest_webui_version, webui_download_url = await get_latest_npm_version("kirara-ai-webui", npm_registry)
        if not webui_download_url:
            return False, "无法获取WebUI下载地址"

        logger.info(f"开始下载WebUI v{latest_webui_version}: {webui_download_url}")
        webui_file, _webui_hash = await download_file(webui_download_url, temp_dir)
        if not webui_file:
            return False, "WebUI下载失败"

        os.makedirs(install_path, exist_ok=True)

        # Unpack package/dist/* from the tarball into the install directory.
        logger.info(f"开始解压WebUI到 {install_path}")
        with tarfile.open(webui_file, "r:gz") as tar:
            skipped = _extract_webui_dist(tar, install_path)
        for name in skipped:
            logger.warning(f"跳过不安全的压缩包条目: {name}")

        return True, f"WebUI v{latest_webui_version} 安装成功"
    except Exception as e:
        logger.error(f"WebUI安装失败: {e}")
        return False, f"WebUI安装失败: {str(e)}"
    finally:
        # Cleanup must not raise and mask the result (e.g. a locked file on Windows).
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
================================================
FILE: kirara_ai/workflow/core/__init__.py
================================================
================================================
FILE: kirara_ai/workflow/core/block/__init__.py
================================================
from .base import Block, ConditionBlock, LoopBlock, LoopEndBlock
from .input_output import Input, Output
from .param import ParamMeta
from .registry import BlockRegistry
from .schema import BlockConfig, BlockInput, BlockOutput
# Public API of the block package: block base classes, port/param metadata,
# schema models and the registry.
__all__ = [
"Block",
"ConditionBlock",
"LoopBlock",
"LoopEndBlock",
"BlockRegistry",
"BlockInput",
"BlockOutput",
"BlockConfig",
"ParamMeta",
"Input",
"Output",
]
================================================
FILE: kirara_ai/workflow/core/block/base.py
================================================
from typing import Any, Callable, Dict, List, Optional
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block.input_output import Input, Output
class Block:
    """Base class for workflow blocks.

    Subclasses declare their ports through the class attributes ``inputs``
    and ``outputs`` (name -> Input/Output); BlockRegistry reads them from the
    class. ``execute`` receives input values as keyword arguments and returns
    a dict keyed by output name.
    """

    # Block id. BlockRegistry.register sets it on the class; otherwise
    # __init__ falls back to "anonymous_<ClassName>".
    id: str
    # Block name.
    name: str
    # Input ports by name. NOTE(review): this default dict object is shared by
    # every subclass that does not override it.
    inputs: Dict[str, Input] = {}
    # Output ports by name (same sharing caveat as ``inputs``).
    outputs: Dict[str, Output] = {}
    container: DependencyContainer

    def __init__(
        self,
        name: Optional[str] = None,
        inputs: Optional[Dict[str, Input]] = None,
        outputs: Optional[Dict[str, Output]] = None,
    ):
        block_cls = self.__class__
        self.id = getattr(block_cls, "id", f"anonymous_{block_cls.__name__}")
        # Only explicitly given values shadow the class-level defaults.
        overrides = {"name": name, "inputs": inputs, "outputs": outputs}
        for attr_name, value in overrides.items():
            if value is not None:
                setattr(self, attr_name, value)

    def execute(self, **kwargs) -> Dict[str, Any]:
        # Placeholder: every declared output gets a description of the inputs.
        processed = f"Processed {kwargs}"
        return dict.fromkeys(self.outputs, processed)
class ConditionBlock(Block):
    """Block that evaluates a predicate over all of its inputs."""

    name: str = "condition"
    outputs: Dict[str, Output] = {
        "condition_result": Output("condition_result", "条件结果", bool, "条件结果")
    }

    def __init__(
        self,
        condition_func: Callable[[Dict[str, Any]], bool],
        inputs: Dict[str, "Input"],
    ):
        super().__init__()
        self.inputs = inputs
        self.condition_func = condition_func

    def execute(self, **kwargs) -> Dict[str, Any]:
        # The predicate receives every input value as one dict.
        return {"condition_result": self.condition_func(kwargs)}
class LoopBlock(Block):
    """Loop-control block: reports whether to continue and the iteration data."""

    name: str = "loop"
    outputs: Dict[str, Output] = {
        "should_continue": Output("should_continue", "是否继续", bool, "是否继续"),
        "iteration": Output("iteration", "当前迭代数据", dict, "当前迭代数据"),
    }

    def __init__(
        self,
        condition_func: Callable[[Dict[str, Any]], bool],
        inputs: Dict[str, "Input"],
        iteration_var: str = "index",
    ):
        super().__init__()
        self.inputs = inputs
        self.condition_func = condition_func
        self.iteration_var = iteration_var
        # Number of execute() calls so far.
        self.iteration_count = 0

    def execute(self, **kwargs) -> Dict[str, Any]:
        keep_going = self.condition_func(kwargs)
        self.iteration_count += 1
        # The counter comes first, so an input with the same name overrides it.
        iteration = {self.iteration_var: self.iteration_count, **kwargs}
        return {"should_continue": keep_going, "iteration": iteration}
class LoopEndBlock(Block):
    """Loop-end block that accumulates the inputs of every execution."""

    name: str = "loop_end"
    outputs: Dict[str, Output] = {
        "loop_results": Output("loop_results", "收集的循环结果", list, "收集的循环结果")
    }

    def __init__(self, inputs: Dict[str, "Input"]):
        super().__init__()
        self.inputs = inputs
        # Inputs of every execute() call, in order.
        self.results: List[Dict[str, Any]] = []

    def execute(self, **kwargs) -> Dict[str, Any]:
        self.results.append(kwargs)
        # Return a snapshot: handing out the internal list would let later
        # executions mutate results that downstream blocks already received.
        return {"loop_results": list(self.results)}
================================================
FILE: kirara_ai/workflow/core/block/input_output.py
================================================
from typing import Any, Optional
def _value_matches_type(value: Any, data_type: Any) -> bool:
    """Return whether *value* conforms to *data_type* at the top level.

    Plain classes use ``isinstance``. Typing constructs that ``isinstance``
    rejects are handled as follows:

    - ``Any`` accepts everything.
    - ``Union``/``Optional`` (and ``X | Y``) accept a value matching any member.
    - ``Literal`` checks membership.
    - ``Annotated[T, ...]`` checks ``T``.
    - Other parameterized generics (e.g. ``List[str]``) check only the origin
      (``list``); element types are not inspected.
    """
    import types
    from typing import Annotated, Literal, Union, get_args, get_origin

    if data_type is Any:
        return True
    origin = get_origin(data_type)
    if origin is None:
        return isinstance(value, data_type)
    if origin is Union or origin is getattr(types, "UnionType", Union):
        return any(_value_matches_type(value, arg) for arg in get_args(data_type))
    if origin is Literal:
        return value in get_args(data_type)
    if origin is Annotated:
        return _value_matches_type(value, get_args(data_type)[0])
    return isinstance(value, origin)


class Input:
    """Declaration of a block input port.

    Args:
        name: Port key.
        label: Display label.
        data_type: Expected value type (a class or a typing construct).
        description: Human-readable description.
        nullable: Whether ``None`` is an acceptable value.
        default: Default value.
    """

    def __init__(
        self,
        name: str,
        label: str,
        data_type: type,
        description: str,
        nullable: bool = False,
        default: Optional[Any] = None,
    ):
        self.name = name
        self.label = label
        self.data_type = data_type
        self.description = description
        self.nullable = nullable
        self.default = default

    def validate(self, value: Any) -> bool:
        """Return whether *value* is acceptable for this input."""
        if value is None:
            return self.nullable
        return _value_matches_type(value, self.data_type)


class Output:
    """Declaration of a block output port (name, label, type, description)."""

    def __init__(self, name: str, label: str, data_type: type, description: str):
        self.name = name
        self.label = label
        self.data_type = data_type
        self.description = description

    def validate(self, value: Any) -> bool:
        """Return whether *value* conforms to this output's declared type."""
        return _value_matches_type(value, self.data_type)
================================================
FILE: kirara_ai/workflow/core/block/param.py
================================================
from typing import Callable, List, Optional, TypeVar
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block import Block
T = TypeVar("T")
# Callable (container, block) -> list of values. Presumably the selectable
# choices offered for a config parameter (see ParamMeta.options_provider) —
# confirm against the block API.
OptionsProvider = Callable[[DependencyContainer, Block], List[T]]
class ParamMeta:
    """Extra metadata for a block config parameter, attached via ``Annotated``.

    Args:
        label: Display label for the parameter.
        description: Help text for the parameter.
        options_provider: Callable producing the parameter's option values.
    """

    def __init__(self, label: Optional[str] = None, description: Optional[str] = None, options_provider: Optional[OptionsProvider[T]] = None):
        self.label, self.description, self.options_provider = (
            label,
            description,
            options_provider,
        )

    def __repr__(self):
        fields = f"label={self.label}, description={self.description}, options_provider={self.options_provider}"
        return f"ParamMeta({fields})"

    def __str__(self):
        return self.__repr__()
================================================
FILE: kirara_ai/workflow/core/block/registry.py
================================================
import warnings
from inspect import Parameter, signature
from typing import Annotated, Dict, List, Optional, Tuple, Type, get_args, get_origin
from kirara_ai.workflow.core.block import Block
from kirara_ai.workflow.core.block.param import ParamMeta
from .schema import BlockConfig, BlockInput, BlockOutput
from .type_system import TypeSystem
def extract_block_param(param: Parameter, type_system: TypeSystem) -> BlockConfig:
    """Build the BlockConfig describing one ``__init__`` parameter of a block.

    For ``Annotated[T, meta, ...]``, when the second argument is a ParamMeta
    it supplies the label, description and options provider. The type
    string, required flag and default are then derived from ``T`` by a
    recursive call. Otherwise they come from ``type_system.extract_type_info``
    and the label is the parameter name.
    """
    annotation = param.annotation
    label = param.name
    description = None
    options_provider = None

    if get_origin(annotation) is not Annotated:
        type_string, required, default = type_system.extract_type_info(param)
    else:
        annotated_args = get_args(annotation)
        if not annotated_args:
            type_string, required, default = "Any", True, None
        else:
            metadata = annotated_args[1] if len(annotated_args) > 1 else None
            if isinstance(metadata, ParamMeta):
                label = metadata.label
                description = metadata.description
                options_provider = metadata.options_provider
            # Describe the underlying type as if it were annotated directly.
            inner_param = Parameter(
                name=param.name,
                kind=Parameter.POSITIONAL_OR_KEYWORD,
                annotation=annotated_args[0],
                default=param.default,
            )
            inner = extract_block_param(inner_param, type_system)
            type_string, required, default = inner.type, inner.required, inner.default

    return BlockConfig(
        name=param.name,
        description=description,
        type=type_string,
        required=required,
        default=default,
        label=label,
        has_options=options_provider is not None,
        options=[],
        options_provider=options_provider,
    )
class BlockRegistry:
    """Registry of all known block classes, keyed by ``"<group_id>:<block_id>"``."""

    def __init__(self):
        self._blocks = {}
        self._localized_names = {}
        self._type_system = TypeSystem()

    def register(
        self,
        block_id: str,
        group_id: str,
        block_class: Type[Block],
        localized_name: Optional[str] = None,
    ):
        """Register a block class.

        Args:
            block_id: Unique block identifier.
            group_id: Group identifier ("internal" for framework built-ins).
            block_class: The block class; its ``id`` attribute is set to *block_id*.
            localized_name: Optional localized display name.

        Raises:
            ValueError: if ``group_id:block_id`` is already registered.
        """
        full_name = f"{group_id}:{block_id}"
        if full_name in self._blocks:
            raise ValueError(f"Block {full_name} already registered")
        self._blocks[full_name] = block_class
        block_class.id = block_id
        if localized_name:
            self._localized_names[full_name] = localized_name

        # Register the data types of the block's declared inputs and outputs.
        for _, input_info in getattr(block_class, "inputs", {}).items():
            type_name = self._type_system.get_type_name(input_info.data_type)
            self._type_system.register_type(type_name, input_info.data_type)
        for _, output_info in getattr(block_class, "outputs", {}).items():
            type_name = self._type_system.get_type_name(output_info.data_type)
            self._type_system.register_type(type_name, output_info.data_type)

    def get(self, full_name: str) -> Optional[Type[Block]]:
        """Return the registered block class for *full_name*, or None."""
        return self._blocks.get(full_name)

    def get_localized_name(self, block_id: str) -> Optional[str]:
        """Return the localized name, falling back to *block_id* itself."""
        return self._localized_names.get(block_id, block_id)

    def clear(self):
        """Remove every registration and reset the type system."""
        self._blocks.clear()
        # Also drop localized names, which previously survived a clear().
        self._localized_names.clear()
        self._type_system = TypeSystem()

    def get_block_type_name(self, block_class: Type[Block]) -> str:
        """Return the registered name of *block_class*.

        An unregistered class gets ``"!!<module>.<ClassName>"`` and a UserWarning.
        """
        for full_name, registered_class in self._blocks.items():
            if registered_class == block_class:
                return full_name
        warnings.warn(
            f"Block class {block_class.__name__} is not registered. Using class path instead.",
            UserWarning,
        )
        return f"!!{block_class.__module__}.{block_class.__name__}"

    def get_all_types(self) -> List[Type[Block]]:
        """Return all registered block classes."""
        return list(self._blocks.values())

    def extract_block_info(
        self, block_type: Type[Block]
    ) -> Tuple[Dict[str, BlockInput], Dict[str, BlockOutput], Dict[str, BlockConfig]]:
        """Describe a block type's inputs, outputs and config parameters.

        Inputs/outputs come from the class's ``inputs``/``outputs`` attributes.
        Configs come from the parameters of its ``__init__``, excluding ``self``
        and the parameters of ``Block.__init__``.

        Args:
            block_type: The block class.

        Returns:
            ``(inputs, outputs, configs)`` dicts keyed by name.
        """
        inputs = {}
        outputs = {}
        configs = {}

        for name, input_info in getattr(block_type, "inputs", {}).items():
            type_name, _, _ = self._type_system.extract_type_info(input_info.data_type)
            self._type_system.register_type(type_name, input_info.data_type)
            inputs[name] = BlockInput(
                name=name,
                label=input_info.label,
                description=input_info.description,
                type=type_name,
                required=not input_info.nullable,
                default=input_info.default if hasattr(input_info, "default") else None,
            )

        for name, output_info in getattr(block_type, "outputs", {}).items():
            type_name, _, _ = self._type_system.extract_type_info(output_info.data_type)
            self._type_system.register_type(type_name, output_info.data_type)
            outputs[name] = BlockOutput(
                name=name,
                label=output_info.label,
                description=output_info.description,
                type=type_name,
            )

        # Parameters of Block.__init__ are framework-level, not block config.
        builtin_params = self.get_builtin_params()

        sig = signature(block_type.__init__)
        for param in sig.parameters.values():
            if param.name == "self" or param.name in builtin_params:
                continue
            configs[param.name] = extract_block_param(param, self._type_system)

        return inputs, outputs, configs

    def get_builtin_params(self) -> List[str]:
        """Return the parameter names of ``Block.__init__`` (including ``self``)."""
        sig = signature(Block.__init__)
        return [param.name for param in sig.parameters.values()]

    def get_type_compatibility_map(self) -> Dict[str, Dict[str, bool]]:
        """Return the type system's compatibility map."""
        return self._type_system.get_compatibility_map()

    def is_type_compatible(self, source_type: str, target_type: str) -> bool:
        """Return whether *source_type* can be assigned to *target_type*."""
        return self._type_system.is_compatible(source_type, target_type)
================================================
FILE: kirara_ai/workflow/core/block/schema.py
================================================
from typing import Any, List, Optional
from pydantic import BaseModel, Field
from kirara_ai.workflow.core.block.param import OptionsProvider
class BlockInput(BaseModel):
    """Block input port definition."""
    name: str
    label: str
    description: str
    # Type name as produced by the registry's type system.
    type: str
    required: bool = True
    default: Optional[Any] = None
class BlockOutput(BaseModel):
    """Block output port definition."""
    name: str
    label: str
    description: str
    # Type name as produced by the registry's type system.
    type: str
class BlockConfig(BaseModel):
    """Definition of one block configuration item (an ``__init__`` parameter)."""
    name: str
    description: Optional[str] = None
    type: str
    required: bool = True
    default: Optional[Any] = None
    label: Optional[str] = None
    has_options: bool = False
    options: Optional[List[Any]] = None
    # Callable producing the option values; excluded from serialization.
    # An explicit default of None keeps the field optional: under pydantic v2,
    # Field() without a default makes the field required despite Optional[...].
    options_provider: Optional[OptionsProvider] = Field(default=None, exclude=True)
================================================
FILE: kirara_ai/workflow/core/block/type_system.py
================================================
from inspect import Parameter
from typing import Any, Dict, Optional, Type, Union, get_args, get_origin, overload
class TypeSystem:
"""类型系统管理器,用于处理类型兼容性检查和类型名称映射"""
def __init__(self) -> None:
self._type_map: Dict[str, Type] = {}
self._compatibility_cache: Dict[str, Dict[str, bool]] = {}
def register_type(self, type_name: str, type_class: Type):
"""注册一个类型到类型系统中"""
self._type_map[type_name] = type_class
def get_type(self, type_name: str) -> Optional[Type]:
"""获取类型名称对应的实际类型"""
return self._type_map.get(type_name)
def get_type_name(self, type_obj: Type) -> str:
"""获取类型对应的名称"""
if hasattr(type_obj, "__name__"):
return type_obj.__name__
return str(type_obj)
@overload
def extract_type_info(self, param: Parameter) -> tuple[str, bool, Any]: ...
@overload
def extract_type_info(self, param: Type) -> tuple[str, bool, Any]: ...
def extract_type_info(self, param: Union[Parameter, Type]) -> tuple[str, bool, Any]:
"""从参数中提取类型信息
Returns:
tuple: (type_name, required, default_value)
"""
if isinstance(param, Parameter):
param_type = param.annotation
required = True
default = param.default if param.default != Parameter.empty else None
origin = get_origin(param_type)
else:
param_type = param
origin = get_origin(param_type)
default = None
required = True
if origin is Union:
args = get_args(param_type)
if type(None) in args:
required = False
non_none_args = [arg for arg in args if arg is not type(None)]
if len(non_none_args) == 1:
type_name = self.get_type_name(non_none_args[0])
else:
type_name = f"Union[{', '.join(self.get_type_name(arg) for arg in non_none_args)}]"
else:
type_name = f"Union[{', '.join(self.get_type_name(arg) for arg in args)}]"
elif origin is list:
args = get_args(param_type)
if args:
element_type = args[0]
element_type_name = self.get_type_name(element_type)
type_name = f"List[{element_type_name}]"
else:
type_name = "list"
else:
type_name = self.get_type_name(param_type)
# 注册类型
if param_type not in (str, int, float, bool, dict) or origin is not list:
self.register_type(type_name, param_type)
return type_name, required, default
def is_compatible(self, source_type: str, target_type: str) -> bool:
"""检查源类型是否可以赋值给目标类型"""
# 检查缓存
if source_type in self._compatibility_cache:
if target_type in self._compatibility_cache[source_type]:
return self._compatibility_cache[source_type][target_type]
# 获取实际类型
source_class = self.get_type(source_type)
target_class = self.get_type(target_type)
if not source_class or not target_class:
# 如果类型未注册,则只允许完全相同的类型
result = source_type == target_type
else:
# 检查类型兼容性
try:
# 任何类型都可以兼容 Any 类型
if target_type == "Any" or source_type == "Any":
result = True
else:
result = issubclass(source_class, target_class)
except TypeError:
# 处理一些特殊类型(如泛型)
result = source_type == target_type
# 缓存结果
if source_type not in self._compatibility_cache:
self._compatibility_cache[source_type] = {}
self._compatibility_cache[source_type][target_type] = result
return result
def get_compatibility_map(self) -> Dict[str, Dict[str, bool]]:
"""获取所有已注册类型之间的兼容性映射"""
# 确保所有类型组合都已经计算过
all_types = list(self._type_map.keys())
for source_type in all_types:
for target_type in all_types:
self.is_compatible(source_type, target_type)
# 只保留可兼容的结果
return {
source_type: {
target_type: is_compatible
for target_type, is_compatible in compatibility.items()
if is_compatible
}
for source_type, compatibility in self._compatibility_cache.items()
}
================================================
FILE: kirara_ai/workflow/core/dispatch/__init__.py
================================================
from .dispatcher import WorkflowDispatcher, WorkflowExecutor, WorkflowRegistry
from .exceptions import WorkflowNotFoundException
from .models.dispatch_rules import CombinedDispatchRule, RuleGroup, SimpleDispatchRule
from .registry import DispatchRuleRegistry
from .rules.base import DispatchRule, RuleConfig
# Public API of the dispatch package; each exported name is listed once.
__all__ = [
    "CombinedDispatchRule",
    "DispatchRule",
    "DispatchRuleRegistry",
    "WorkflowDispatcher",
    "WorkflowExecutor",
    "WorkflowRegistry",
    "RuleGroup",
    "SimpleDispatchRule",
    "RuleConfig",
    "WorkflowNotFoundException",
]
================================================
FILE: kirara_ai/workflow/core/dispatch/dispatcher.py
================================================
from kirara_ai.im.adapter import IMAdapter
from kirara_ai.im.message import IMMessage
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.logger import get_logger
from kirara_ai.workflow.core.dispatch.models.dispatch_rules import CombinedDispatchRule
from kirara_ai.workflow.core.dispatch.registry import DispatchRuleRegistry
from kirara_ai.workflow.core.dispatch.rules.base import DispatchRule
from kirara_ai.workflow.core.execution.exceptions import WorkflowExecutionTimeoutException
from kirara_ai.workflow.core.execution.executor import WorkflowExecutor
from kirara_ai.workflow.core.workflow.base import Workflow
from kirara_ai.workflow.core.workflow.registry import WorkflowRegistry
from .exceptions import WorkflowNotFoundException
class WorkflowDispatcher:
    """Workflow dispatcher: runs the workflow of the first dispatch rule matching a message."""

    def __init__(self, container: DependencyContainer):
        """Keep *container* and resolve the workflow and dispatch-rule registries from it."""
        self.container = container
        self.logger = get_logger("WorkflowDispatcher")
        self.workflow_registry = container.resolve(WorkflowRegistry)
        self.dispatch_registry = container.resolve(DispatchRuleRegistry)

    def register_rule(self, rule: CombinedDispatchRule):
        """Register a dispatch rule with the dispatch rule registry.

        No log line is emitted here: DispatchRuleRegistry.register already logs
        every registration, so logging again would duplicate the message.
        """
        self.dispatch_registry.register(rule)

    async def dispatch(self, source: IMAdapter, message: IMMessage):
        """
        Run the workflow of the first active rule that matches *message*.

        Rules are tried in the order returned by ``get_active_rules`` (highest
        priority first). Only the first matching rule is executed. A scoped
        container holds the adapter, the message, the matched rule, the
        workflow and its executor.

        :return: the executor's results, or None when no rule matched, the
                 workflow was not found, or execution failed/timed out
        """
        with self.container.scoped() as scoped_container:
            scoped_container.register(IMAdapter, source)
            scoped_container.register(IMMessage, message)

            for rule in self.dispatch_registry.get_active_rules():
                if not rule.match(message, self.workflow_registry, scoped_container):
                    continue
                scoped_container.register(DispatchRule, rule)
                try:
                    self.logger.debug(f"Matched rule {rule}, executing workflow")
                    workflow = rule.get_workflow(scoped_container)
                    if workflow is None:
                        raise WorkflowNotFoundException(f"Workflow for rule {rule.name} not found, please check the rule configuration")
                    scoped_container.register(Workflow, workflow)
                    executor = WorkflowExecutor(scoped_container)
                    scoped_container.register(WorkflowExecutor, executor)
                    return await executor.run()
                except WorkflowNotFoundException as e:
                    # A configuration problem, not an execution failure: report
                    # it clearly and without a traceback.
                    self.logger.error(e.message)
                    return None
                except WorkflowExecutionTimeoutException as e:
                    self.logger.error(f"Workflow execution timed out: {e}")
                    return None
                except Exception as e:
                    self.logger.opt(exception=e).error(f"Workflow execution failed: {e}")
                    return None
            self.logger.debug("No matching rule found for message")
            return None
================================================
FILE: kirara_ai/workflow/core/dispatch/exceptions.py
================================================
class WorkflowNotFoundException(Exception):
    """Raised when the workflow referenced by a dispatch rule cannot be found.

    The message is available both as ``self.message`` and through ``args``/``str()``.
    """

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
================================================
FILE: kirara_ai/workflow/core/dispatch/models/dispatch_rules.py
================================================
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel
from kirara_ai.im.message import IMMessage
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.logger import get_logger
from kirara_ai.workflow.core.workflow import Workflow
from kirara_ai.workflow.core.workflow.registry import WorkflowRegistry
# Module logger, used to report rules that fail to build or match.
logger = get_logger("DispatchRule")
class SimpleDispatchRule(BaseModel):
    """A simple rule: a rule type name plus that type's configuration."""

    # Rule type name, resolved with DispatchRule.get_rule_type
    type: str
    # Keyword arguments used to build the rule type's config_class
    config: Dict[str, Any]
class RuleGroup(BaseModel):
    """A group of simple rules combined with a single operator."""

    # How the results of the rules in this group are combined ("or" by default)
    operator: Literal["and", "or"] = "or"
    rules: List[SimpleDispatchRule]
class CombinedDispatchRule(BaseModel):
    """Combined dispatch rule supporting AND/OR composition of simple rules.

    Rule groups are combined with AND; the rules inside a group are combined
    according to that group's ``operator``.
    """

    rule_id: str
    name: str
    description: str = ""
    workflow_id: str
    priority: int = 5
    enabled: bool = True
    rule_groups: List[RuleGroup]  # groups are combined with AND
    metadata: Dict[str, Any] = {}

    def match(self, message: IMMessage, workflow_registry: WorkflowRegistry, container: DependencyContainer) -> bool:
        """
        Return whether *message* matches this rule.

        - A disabled rule never matches.
        - Every rule group must match (AND). An empty group imposes no
          constraint and is skipped.
        - Inside a group, the rule results are combined with ``all`` for
          operator "and" and ``any`` otherwise.
        - Rules that fail to build or raise while matching are logged and left
          out. A non-empty group with no usable rule does not match.
        """
        if not self.enabled:
            return False

        # Local import, kept out of module scope as in the original design
        # (presumably to avoid an import cycle with the rules package).
        from ..rules.base import DispatchRule

        for group in self.rule_groups:
            if not group.rules:
                # Previously this returned True for the whole rule, silently
                # skipping all remaining groups; an empty group only means
                # "no constraint from this group".
                continue

            rule_results = []
            for rule in group.rules:
                try:
                    rule_class = DispatchRule.get_rule_type(rule.type)
                    rule_instance = rule_class.from_config(
                        rule_class.config_class(**rule.config),
                        workflow_registry,
                        self.workflow_id,
                    )
                    rule_results.append(rule_instance.match(message, container))
                except Exception as e:
                    # A rule that cannot be built or matched counts as not matching.
                    logger.error(f"Rule {rule.type} from config {rule.config} creation or matching failed: {e}")

            if not rule_results:
                return False
            combine = all if group.operator == "and" else any
            if not combine(rule_results):
                return False

        return True

    def get_workflow(self, container: DependencyContainer) -> Optional[Workflow]:
        """Return this rule's workflow from the container's WorkflowRegistry, or None."""
        return container.resolve(WorkflowRegistry).get_workflow(self.workflow_id, container)
================================================
FILE: kirara_ai/workflow/core/dispatch/registry.py
================================================
import os
from typing import Any, Dict, List, Optional
from ruamel.yaml import YAML
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.logger import get_logger
from kirara_ai.workflow.core.workflow.registry import WorkflowRegistry
from .models.dispatch_rules import CombinedDispatchRule, RuleGroup, SimpleDispatchRule
from .rules.base import DispatchRule
from .rules.message_rules import BotMentionMatchRule, KeywordMatchRule, PrefixMatchRule, RegexMatchRule
from .rules.sender_rules import ChatSenderMatchRule, ChatSenderMismatchRule, ChatTypeMatchRule
from .rules.system_rules import FallbackMatchRule, IMInstanceMatchRule, RandomChanceMatchRule
class DispatchRuleRegistry:
    """Registry of dispatch rules: keeps rules by ID and loads/saves them as YAML files."""

    def __init__(self, container: DependencyContainer):
        self.container = container
        self.workflow_registry = container.resolve(WorkflowRegistry)
        # rule_id -> rule
        self.rules: Dict[str, CombinedDispatchRule] = {}
        self.logger = get_logger("DispatchRuleRegistry")
        # Default directory used by load_rules / save_rules
        self.rules_dir = "data/dispatch_rules"

    def register(self, rule: CombinedDispatchRule):
        """Register (or replace) *rule* under its ``rule_id``.

        Raises:
            ValueError: if the rule has an empty ``rule_id``.
        """
        if not rule.rule_id:
            raise ValueError("Rule must have an ID")
        self.rules[rule.rule_id] = rule
        self.logger.info(f"Registered dispatch rule: {rule}")

    def get_rule(self, rule_id: str) -> Optional[CombinedDispatchRule]:
        """Return the rule with *rule_id*, or None."""
        return self.rules.get(rule_id)

    def get_all_rules(self) -> List[CombinedDispatchRule]:
        """Return all registered rules."""
        return list(self.rules.values())

    def get_active_rules(self) -> List[CombinedDispatchRule]:
        """Return the enabled rules sorted by priority, highest first."""
        active_rules = [rule for rule in self.rules.values() if rule.enabled]
        return sorted(active_rules, key=lambda x: x.priority, reverse=True)

    def create_rule(self, rule: CombinedDispatchRule) -> CombinedDispatchRule:
        """Check that the rule's workflow exists, then register the rule.

        Raises:
            ValueError: if the workflow is not found or the rule has no ID.
        """
        workflow_builder = self.workflow_registry.get(rule.workflow_id)
        if not workflow_builder:
            raise ValueError(f"Workflow {rule.workflow_id} not found")
        self.register(rule)
        return rule

    def update_rule(
        self, rule_id: str, rule: CombinedDispatchRule
    ) -> CombinedDispatchRule:
        """Replace the existing rule *rule_id* with *rule*.

        If *rule* carries a different ``rule_id``, the entry stored under the
        old ID is removed so the rule is not registered twice.

        Raises:
            ValueError: if *rule_id* is not registered or *rule* has no ID.
        """
        if rule_id not in self.rules:
            raise ValueError(f"Rule {rule_id} not found")
        self.register(rule)
        if rule.rule_id != rule_id:
            del self.rules[rule_id]
        return rule

    def delete_rule(self, rule_id: str):
        """Remove the rule *rule_id*.

        Raises:
            ValueError: if the rule is not registered.
        """
        if rule_id not in self.rules:
            raise ValueError(f"Rule {rule_id} not found")
        del self.rules[rule_id]

    def enable_rule(self, rule_id: str):
        """Enable the rule *rule_id*.

        Raises:
            ValueError: if the rule is not registered.
        """
        rule = self.get_rule(rule_id)
        if not rule:
            raise ValueError(f"Rule {rule_id} not found")
        rule.enabled = True

    def disable_rule(self, rule_id: str):
        """Disable the rule *rule_id*.

        Raises:
            ValueError: if the rule is not registered.
        """
        rule = self.get_rule(rule_id)
        if not rule:
            raise ValueError(f"Rule {rule_id} not found")
        rule.enabled = False

    def _convert_old_rule(self, rule_data: Dict[str, Any]) -> CombinedDispatchRule:
        """Convert a legacy flat rule dict into a CombinedDispatchRule.

        Legacy entries carry a single ``type`` with that type's config keys at
        the top level. They become one "or" group containing one simple rule.
        """
        rule_type = rule_data["type"]
        rule_class = DispatchRule.get_rule_type(rule_type)

        # Keep only the keys that belong to the rule type's config model
        config_fields = rule_class.config_class.model_fields.keys()
        rule_config = {k: rule_data[k] for k in config_fields if k in rule_data}

        simple_rule = SimpleDispatchRule(type=rule_type, config=rule_config)
        rule_group = RuleGroup(operator="or", rules=[simple_rule])
        return CombinedDispatchRule(
            rule_id=rule_data["rule_id"],
            name=rule_data["name"],
            description=rule_data.get("description", ""),
            workflow_id=rule_data["workflow_id"],
            rule_groups=[rule_group],
            priority=rule_data.get("priority", 5),
            enabled=rule_data.get("enabled", True),
            metadata=rule_data.get("metadata", {}),
        )

    def load_rules(self, rules_dir: Optional[str] = None):
        """Load every ``*.yaml`` rule file in *rules_dir* (default: ``self.rules_dir``).

        Each file must contain a YAML list of rules. Entries with
        ``rule_groups`` are parsed as CombinedDispatchRule; other entries are
        treated as legacy rules and converted. Invalid files and entries are
        logged and skipped. Files are read in sorted order, so the winner is
        deterministic when several files define the same ``rule_id``.
        """
        rules_dir = rules_dir or self.rules_dir
        os.makedirs(rules_dir, exist_ok=True)

        # Safe loader: rule files cannot construct arbitrary Python objects.
        yaml = YAML(typ="safe")
        for file_name in sorted(os.listdir(rules_dir)):
            if not file_name.endswith(".yaml"):
                continue

            file_path = os.path.join(rules_dir, file_name)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    rules_data = yaml.load(f)

                if not isinstance(rules_data, list):
                    self.logger.warning(
                        f"Invalid rules file {file_name}, expected list of rules"
                    )
                    continue

                for rule_data in rules_data:
                    try:
                        if "rule_groups" in rule_data:
                            rule = CombinedDispatchRule(**rule_data)
                        else:
                            # Legacy format: convert to a combined rule
                            rule = self._convert_old_rule(rule_data)
                        self.register(rule)
                        self.logger.debug(f"Loaded rule: {rule}")
                    except Exception as e:
                        self.logger.error(
                            f"Failed to load rule in file {file_path}: {str(e)}"
                        )

            except Exception as e:
                self.logger.error(f"Failed to load rules from {file_path}: {str(e)}")

    def save_rules(self, rules_dir: Optional[str] = None):
        """Write all registered rules to ``rules.yaml`` in *rules_dir* (default: ``self.rules_dir``).

        NOTE(review): load_rules reads every ``*.yaml`` file, but this writes only
        ``rules.yaml``. Rules that came from other files remain in those files
        and are read again on the next load (e.g. a deleted rule reappears).
        """
        rules_dir = rules_dir or self.rules_dir
        os.makedirs(rules_dir, exist_ok=True)

        yaml = YAML()
        yaml.default_flow_style = False

        # model_dump() is the pydantic v2 API; .dict() is its deprecated alias.
        rules_data = [rule.model_dump() for rule in self.rules.values()]

        file_path = os.path.join(rules_dir, "rules.yaml")
        with open(file_path, "w", encoding="utf-8") as f:
            yaml.dump(rules_data, f)
# Register every built-in rule type with the DispatchRule type registry,
# in a fixed order.
for _builtin_rule_class in (
    RegexMatchRule,
    PrefixMatchRule,
    KeywordMatchRule,
    BotMentionMatchRule,
    RandomChanceMatchRule,
    ChatSenderMatchRule,
    ChatSenderMismatchRule,
    ChatTypeMatchRule,
    IMInstanceMatchRule,
    FallbackMatchRule,
):
    DispatchRule.register_rule_type(_builtin_rule_class)
# Keep the module namespace unchanged: don't leak the loop variable.
del _builtin_rule_class
================================================
FILE: kirara_ai/workflow/core/dispatch/rules/base.py
================================================
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Type
from pydantic import BaseModel
from kirara_ai.im.message import IMMessage
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.workflow import Workflow
from kirara_ai.workflow.core.workflow.registry import WorkflowRegistry
class RuleConfig(BaseModel):
    """Base class for rule configuration models.

    It declares no fields itself. Concrete rule types subclass it to declare
    their configuration, and rules without configuration use it directly.
    """
class DispatchRule(ABC):
    """
    Abstract base class for workflow dispatch rules.

    A rule decides, per message, whether its workflow should handle it.
    Concrete rule classes set ``type_name`` and ``config_class`` and are made
    available by name via ``register_rule_type`` / ``get_rule_type``.
    """

    # Registered rule classes keyed by type_name (filled by register_rule_type)
    rule_types: ClassVar[Dict[str, Type["DispatchRule"]]] = {}
    # Config model used by from_config to build instances
    config_class: ClassVar[Type[RuleConfig]]
    # Name under which the rule class is registered
    type_name: ClassVar[str]

    def __init__(self, workflow_registry: WorkflowRegistry, workflow_id: str):
        """Initialise the rule for workflow *workflow_id* of *workflow_registry*."""
        self.workflow_registry = workflow_registry
        self.rule_id: str = ""
        self.name: str = ""
        self.description: str = ""
        self.priority: int = 5  # default priority is 5
        self.enabled: bool = True
        self.metadata: Dict[str, Any] = {}
        self.workflow_id: str = workflow_id

    @abstractmethod
    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        """Return True if *message* should be handled by this rule's workflow."""

    def get_workflow(self, container: DependencyContainer) -> Workflow:
        """Return the workflow instance this rule dispatches to.

        Raises:
            TypeError: if the registry does not return a Workflow (for example
                when it returns None for an unknown workflow ID). An explicit
                check is used instead of ``assert``, which ``python -O`` strips.
        """
        workflow = self.workflow_registry.get(self.workflow_id, container)
        if not isinstance(workflow, Workflow):
            raise TypeError(
                f"Workflow registry returned {type(workflow).__name__} for workflow "
                f"'{self.workflow_id}', expected a Workflow instance"
            )
        return workflow

    @classmethod
    def register_rule_type(cls, rule_class: Type["DispatchRule"]):
        """Register *rule_class* under its ``type_name``."""
        cls.rule_types[rule_class.type_name] = rule_class

    @classmethod
    def get_rule_type(cls, type_name: str) -> Type["DispatchRule"]:
        """Return the rule class registered as *type_name*.

        Raises:
            ValueError: if no rule class is registered under that name.
        """
        if type_name not in cls.rule_types:
            raise ValueError(f"Unknown rule type: {type_name}")
        return cls.rule_types[type_name]

    @abstractmethod
    def get_config(self) -> RuleConfig:
        """Return this rule's configuration as a config model."""

    @classmethod
    @abstractmethod
    def from_config(
        cls, config: RuleConfig, workflow_registry: WorkflowRegistry, workflow_id: str
    ) -> "DispatchRule":
        """Build a rule instance from *config*."""

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(id='{self.rule_id}', priority={self.priority}, enabled={self.enabled})"
================================================
FILE: kirara_ai/workflow/core/dispatch/rules/message_rules.py
================================================
import re
from typing import List
from pydantic import Field
from kirara_ai.im.message import IMMessage, MentionElement, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.workflow.registry import WorkflowRegistry
from .base import DispatchRule, RuleConfig
class RegexRuleConfig(RuleConfig):
    """Configuration for the regex rule."""

    # Regular expression searched for in the message content
    pattern: str = Field(title="正则表达式", description="正则表达式")
class RegexMatchRule(DispatchRule):
    """Dispatch rule that matches when a regular expression is found in the message.

    The pattern is compiled once at construction and applied with ``search``,
    so it may match anywhere in ``message.content``.
    """

    type_name = "regex"
    config_class = RegexRuleConfig

    def __init__(self, pattern: str, workflow_registry: WorkflowRegistry, workflow_id: str):
        super().__init__(workflow_registry, workflow_id)
        compiled = re.compile(pattern)
        self.pattern = compiled

    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        hit = self.pattern.search(message.content)
        return hit is not None

    def get_config(self) -> RegexRuleConfig:
        source_text = self.pattern.pattern
        return RegexRuleConfig(pattern=source_text)

    @classmethod
    def from_config(cls, config: RegexRuleConfig, workflow_registry: WorkflowRegistry, workflow_id: str) -> "RegexMatchRule":
        pattern_text = config.pattern
        return cls(pattern_text, workflow_registry, workflow_id)
class PrefixRuleConfig(RuleConfig):
    """Configuration for the prefix rule."""

    # Prefix that the message's first text element must start with
    prefix: str = Field(title="前缀", description="前缀")
class PrefixMatchRule(DispatchRule):
    """Dispatch rule that matches when the message's first text element starts with a prefix.

    Only the first ``TextMessage`` element is inspected. A message without
    any text element never matches.
    """

    type_name = "prefix"
    config_class = PrefixRuleConfig

    def __init__(self, prefix: str, workflow_registry: WorkflowRegistry, workflow_id: str):
        super().__init__(workflow_registry, workflow_id)
        self.prefix = prefix

    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        for element in message.message_elements:
            if isinstance(element, TextMessage):
                # Decide on the first text element only
                return element.text.startswith(self.prefix)
        return False

    def get_config(self) -> PrefixRuleConfig:
        current_prefix = self.prefix
        return PrefixRuleConfig(prefix=current_prefix)

    @classmethod
    def from_config(cls, config: PrefixRuleConfig, workflow_registry: WorkflowRegistry, workflow_id: str) -> "PrefixMatchRule":
        wanted_prefix = config.prefix
        return cls(wanted_prefix, workflow_registry, workflow_id)
class KeywordRuleConfig(RuleConfig):
    """Configuration for the keyword rule."""

    # Keywords; the rule matches if any of them occurs in the message content
    keywords: List[str] = Field(title="关键词", description="关键词列表")
class KeywordMatchRule(DispatchRule):
    """Dispatch rule that matches when any configured keyword occurs in the message content."""

    type_name = "keyword"
    config_class = KeywordRuleConfig

    def __init__(self, keywords: list[str], workflow_registry: WorkflowRegistry, workflow_id: str):
        super().__init__(workflow_registry, workflow_id)
        self.keywords = keywords

    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        for word in self.keywords:
            if word in message.content:
                return True
        return False

    def get_config(self) -> KeywordRuleConfig:
        configured = self.keywords
        return KeywordRuleConfig(keywords=configured)

    @classmethod
    def from_config(cls, config: KeywordRuleConfig, workflow_registry: WorkflowRegistry, workflow_id: str) -> "KeywordMatchRule":
        word_list = config.keywords
        return cls(word_list, workflow_registry, workflow_id)
class BotMentionMatchRule(DispatchRule):
    """Dispatch rule that matches when the message mentions the bot.

    It takes no configuration (uses the plain ``RuleConfig``).
    """

    type_name = "bot_mention"
    config_class = RuleConfig

    def __init__(self, workflow_registry: WorkflowRegistry, workflow_id: str):
        super().__init__(workflow_registry, workflow_id)

    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        bot = ChatSender.get_bot_sender()
        mentions = (
            element
            for element in message.message_elements
            if isinstance(element, MentionElement)
        )
        return any(mention.target == bot for mention in mentions)

    def get_config(self) -> RuleConfig:
        empty_config = RuleConfig()
        return empty_config

    @classmethod
    def from_config(cls, config: RuleConfig, workflow_registry: WorkflowRegistry, workflow_id: str) -> "BotMentionMatchRule":
        return cls(workflow_registry, workflow_id)
================================================
FILE: kirara_ai/workflow/core/dispatch/rules/sender_rules.py
================================================
from typing import Literal, Optional
from pydantic import Field
from kirara_ai.im.message import IMMessage
from kirara_ai.im.sender import ChatType
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.workflow.registry import WorkflowRegistry
from .base import DispatchRule, RuleConfig
class ChatSenderMatchRuleConfig(RuleConfig):
    """Configuration for the sender rules (used by both the match and mismatch rules).

    An empty string leaves the corresponding condition unset.
    """

    # Compared against message.sender.user_id
    sender_id: str = Field(title="发送者ID", description="发送者ID", default="")
    # Compared against message.sender.group_id
    sender_group: str = Field(
        title="发送者群号", description="发送者群号", default=""
    )
class ChatSenderMatchRule(DispatchRule):
    """Dispatch rule that matches messages from a given sender and/or group.

    Each configured (non-empty) condition must match. With no condition set,
    every message matches.
    """

    type_name = "sender"
    config_class = ChatSenderMatchRuleConfig

    def __init__(
        self,
        sender_id: Optional[str],
        sender_group: Optional[str],
        workflow_registry: WorkflowRegistry,
        workflow_id: str,
    ):
        super().__init__(workflow_registry, workflow_id)
        self.sender_id = sender_id
        self.sender_group = sender_group

    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        sender = message.sender
        group_ok = not self.sender_group or sender.group_id == self.sender_group
        user_ok = not self.sender_id or sender.user_id == self.sender_id
        return group_ok and user_ok

    def get_config(self) -> ChatSenderMatchRuleConfig:
        # None is stored back as the "unset" empty string
        user_value = self.sender_id or ""
        group_value = self.sender_group or ""
        return ChatSenderMatchRuleConfig(sender_id=user_value, sender_group=group_value)

    @classmethod
    def from_config(
        cls,
        config: ChatSenderMatchRuleConfig,
        workflow_registry: WorkflowRegistry,
        workflow_id: str,
    ) -> "ChatSenderMatchRule":
        user_value, group_value = config.sender_id, config.sender_group
        return cls(user_value, group_value, workflow_registry, workflow_id)
class ChatSenderMismatchRule(DispatchRule):
    """Dispatch rule that matches messages NOT from a given sender or group.

    The message is rejected if it comes from the configured group or from the
    configured user (each condition is checked only when non-empty). With no
    condition set, every message matches.
    """

    type_name = "sender_mismatch"
    config_class = ChatSenderMatchRuleConfig

    def __init__(
        self,
        sender_id: Optional[str],
        sender_group: Optional[str],
        workflow_registry: WorkflowRegistry,
        workflow_id: str,
    ):
        super().__init__(workflow_registry, workflow_id)
        self.sender_id = sender_id
        self.sender_group = sender_group

    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        sender = message.sender
        in_group = bool(self.sender_group) and sender.group_id == self.sender_group
        is_user = bool(self.sender_id) and sender.user_id == self.sender_id
        return not (in_group or is_user)

    def get_config(self) -> ChatSenderMatchRuleConfig:
        # None is stored back as the "unset" empty string
        user_value = self.sender_id or ""
        group_value = self.sender_group or ""
        return ChatSenderMatchRuleConfig(sender_id=user_value, sender_group=group_value)

    @classmethod
    def from_config(
        cls,
        config: ChatSenderMatchRuleConfig,
        workflow_registry: WorkflowRegistry,
        workflow_id: str,
    ) -> "ChatSenderMismatchRule":
        user_value, group_value = config.sender_id, config.sender_group
        return cls(user_value, group_value, workflow_registry, workflow_id)
class ChatTypeMatchRuleConfig(RuleConfig):
    """Configuration for the chat-type rule."""

    # "私聊" = private chat, "群聊" = group chat; converted with ChatType.from_str
    chat_type: Literal["私聊", "群聊"] = Field(title="聊天类型", description="聊天类型")
class ChatTypeMatchRule(DispatchRule):
    """Dispatch rule that matches messages whose sender chat type equals the configured one."""

    type_name = "chat_type"
    config_class = ChatTypeMatchRuleConfig

    def __init__(self, chat_type: ChatType, workflow_registry: WorkflowRegistry, workflow_id: str):
        super().__init__(workflow_registry, workflow_id)
        self.chat_type = chat_type

    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        incoming_type = message.sender.chat_type
        return incoming_type == self.chat_type

    def get_config(self) -> ChatTypeMatchRuleConfig:
        # to_str() is expected to yield one of the config's Literal labels
        label = self.chat_type.to_str()
        return ChatTypeMatchRuleConfig(chat_type=label)  # type: ignore

    @classmethod
    def from_config(cls, config: ChatTypeMatchRuleConfig, workflow_registry: WorkflowRegistry, workflow_id: str) -> "ChatTypeMatchRule":
        return cls(ChatType.from_str(config.chat_type), workflow_registry, workflow_id)
================================================
FILE: kirara_ai/workflow/core/dispatch/rules/system_rules.py
================================================
import random
from pydantic import Field
from kirara_ai.im.adapter import IMAdapter
from kirara_ai.im.manager import IMManager
from kirara_ai.im.message import IMMessage
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.workflow.registry import WorkflowRegistry
from .base import DispatchRule, RuleConfig
class RandomChanceRuleConfig(RuleConfig):
    """Configuration for the random-chance rule."""

    # Percentage chance (validated to 0-100, default 50) that the rule matches
    chance: int = Field(
        default=50, ge=0, le=100, title="随机概率", description="随机概率,范围为0-100"
    )
class RandomChanceMatchRule(DispatchRule):
    """Dispatch rule that matches with a configurable probability.

    ``chance`` is a percentage (0-100; validated by RandomChanceRuleConfig
    when built through ``from_config``). Each call to ``match`` makes one
    independent random draw.
    """

    config_class = RandomChanceRuleConfig
    type_name = "random"

    def __init__(self, chance: int, workflow_registry: WorkflowRegistry, workflow_id: str):
        super().__init__(workflow_registry, workflow_id)
        self.chance = chance

    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        # One draw decides the result. The former debug prints wrote to stdout
        # for every message and printed a *different* draw than the one used.
        return random.random() * 100 < self.chance

    def get_config(self) -> RandomChanceRuleConfig:
        return RandomChanceRuleConfig(chance=self.chance)

    @classmethod
    def from_config(
        cls, config: RandomChanceRuleConfig, workflow_registry: WorkflowRegistry, workflow_id: str
    ) -> "RandomChanceMatchRule":
        return cls(config.chance, workflow_registry, workflow_id)
class IMInstanceMatchRuleConfig(RuleConfig):
    """Configuration for the IM-instance rule."""

    # Name of the IM adapter instance, looked up with IMManager.get_adapter
    im_instance: str = Field(title="IM实例名称", description="配置后,只有当消息来自指定的IM实例时,才会触发工作流")
class IMInstanceMatchRule(DispatchRule):
    """Dispatch rule that matches messages received through a named IM adapter instance.

    The adapter that delivered the message (resolved from the container) is
    compared with the adapter IMManager returns for the configured name.
    """

    type_name = "im_instance"
    config_class = IMInstanceMatchRuleConfig

    def __init__(self, im_instance: str, workflow_registry: WorkflowRegistry, workflow_id: str):
        super().__init__(workflow_registry, workflow_id)
        self.im_instance = im_instance

    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        current_adapter = container.resolve(IMAdapter)
        manager = container.resolve(IMManager)
        configured_adapter = manager.get_adapter(self.im_instance)
        return configured_adapter == current_adapter

    def get_config(self) -> IMInstanceMatchRuleConfig:
        instance_name = self.im_instance
        return IMInstanceMatchRuleConfig(im_instance=instance_name)

    @classmethod
    def from_config(
        cls, config: IMInstanceMatchRuleConfig, workflow_registry: WorkflowRegistry, workflow_id: str
    ) -> "IMInstanceMatchRule":
        instance_name = config.im_instance
        return cls(instance_name, workflow_registry, workflow_id)
class FallbackMatchRule(DispatchRule):
    """Catch-all rule that matches every message.

    The instance's own priority is set to 0, the lowest value.
    """

    type_name = "fallback"
    config_class = RuleConfig

    def __init__(self, workflow_registry: WorkflowRegistry, workflow_id: str):
        super().__init__(workflow_registry, workflow_id)
        # Lowest priority, so it acts only as a fallback
        self.priority = 0

    def match(self, message: IMMessage, container: DependencyContainer) -> bool:
        # Unconditional match
        return True

    def get_config(self) -> RuleConfig:
        empty_config = RuleConfig()
        return empty_config

    @classmethod
    def from_config(
        cls, config: RuleConfig, workflow_registry: WorkflowRegistry, workflow_id: str
    ) -> "FallbackMatchRule":
        return cls(workflow_registry, workflow_id)
================================================
FILE: kirara_ai/workflow/core/execution/__init__.py
================================================
================================================
FILE: kirara_ai/workflow/core/execution/exceptions.py
================================================
class BlockExecutionFailedException(Exception):
    """Raised when a workflow block fails during execution."""
class WorkflowExecutionTimeoutException(Exception):
    """Raised when a workflow exceeds its configured maximum execution time."""
================================================
FILE: kirara_ai/workflow/core/execution/executor.py
================================================
import asyncio
import functools
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List
from kirara_ai.events.event_bus import EventBus
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.ioc.inject import Inject
from kirara_ai.logger import get_logger
from kirara_ai.workflow.core.block import Block, ConditionBlock, LoopBlock
from kirara_ai.workflow.core.block.registry import BlockRegistry
from kirara_ai.workflow.core.execution.exceptions import (BlockExecutionFailedException,
WorkflowExecutionTimeoutException)
from kirara_ai.workflow.core.workflow import Workflow
class WorkflowExecutor:
@Inject()
def __init__(self, container: DependencyContainer, workflow: Workflow, registry: BlockRegistry, event_bus: EventBus):
    """
    Prepare an executor for *workflow* and build its execution graph.

    :param container: dependency container of this execution
    :param workflow: the workflow to execute
    :param registry: block registry, used to type-check wires
    :param event_bus: bus on which execution begin/end events are posted
    """
    self.container = container
    self.logger = get_logger("WorkflowExecutor")
    self.workflow, self.registry, self.event_bus = workflow, registry, event_bus
    # Block outputs keyed by block name, and workflow-level variables
    self.results: Dict[str, Any] = {}
    self.variables: Dict[str, Any] = {}
    self.logger.info(f"Initializing WorkflowExecutor for workflow '{workflow.name}'")
    self._build_execution_graph()
def _build_execution_graph(self):
    """Build ``self.execution_graph``, mapping each source block to the blocks wired after it.

    Each wire is type-checked with the BlockRegistry's type system before its
    edge is recorded.

    Raises:
        TypeError: if a wire connects an output to an input of an incompatible type.
    """
    self.execution_graph = defaultdict(list)
    for wire in self.workflow.wires:
        upstream, downstream = wire.source_block, wire.target_block
        out_port = upstream.outputs[wire.source_output]
        in_port = downstream.inputs[wire.target_input]

        type_system = self.registry._type_system
        source_type = type_system.get_type_name(out_port.data_type)
        target_type = type_system.get_type_name(in_port.data_type)

        if self.registry.is_type_compatible(source_type, target_type):
            # Edge: downstream is executed after upstream
            self.execution_graph[upstream].append(downstream)
            continue

        error_msg = (
            f"Type mismatch in wire: {upstream.name}.{wire.source_output} "
            f"({source_type}) -> {downstream.name}.{wire.target_input} "
            f"({target_type})"
        )
        self.logger.error(error_msg)
        raise TypeError(error_msg)
async def run(self) -> Dict[str, Any]:
    """
    Execute the workflow and return the results of every executed block.

    Execution starts from the entry blocks (blocks without inputs). Blocks
    run on a worker thread pool. The whole run is bounded by
    ``workflow.config.max_execution_time`` seconds; a value <= 0 means no limit.

    :return: mapping from block name to that block's output
    :raises WorkflowExecutionTimeoutException: if the time limit is exceeded
    """
    from kirara_ai.events import WorkflowExecutionBegin, WorkflowExecutionEnd

    self.event_bus.post(WorkflowExecutionBegin(self.workflow, self))
    self.logger.info("Starting workflow execution")
    # Inside a coroutine, use the loop that is running it; calling
    # get_event_loop() for this purpose is deprecated.
    loop = asyncio.get_running_loop()

    max_timeout = self.workflow.config.max_execution_time
    if max_timeout <= 0:
        # A non-positive limit disables the timeout
        max_timeout = None

    executor = ThreadPoolExecutor()
    try:
        # Entry blocks are those without inputs
        entry_blocks = [block for block in self.workflow.blocks if not block.inputs]
        try:
            async with asyncio.timeout(max_timeout):  # type: ignore
                await self._execute_nodes(entry_blocks, executor, loop)
        except asyncio.TimeoutError as e:
            self.event_bus.post(WorkflowExecutionEnd(self.workflow, self, self.results))
            raise WorkflowExecutionTimeoutException(f"Workflow execution timed out after {max_timeout} seconds") from e
    finally:
        # Never block the event loop here. `with ThreadPoolExecutor()` would
        # call shutdown(wait=True), which after a timeout/failure stalls the
        # loop until still-running block threads finish (and can deadlock if
        # such a thread waits on the loop). On success every submitted task has
        # already been awaited, so nothing is pending and this is equivalent.
        executor.shutdown(wait=False, cancel_futures=True)

    self.logger.info("Workflow execution completed")
    self.event_bus.post(WorkflowExecutionEnd(self.workflow, self, self.results))
    return self.results
async def _execute_nodes(self, blocks: List[Block], executor, loop):
    """Execute *blocks* one after another, choosing a handler by block kind.

    Condition blocks and loop blocks get their control-flow handling; every
    other block runs as a normal block.
    """
    for current in blocks:
        if isinstance(current, ConditionBlock):
            handler = self._execute_conditional_branch
        elif isinstance(current, LoopBlock):
            handler = self._execute_loop
        else:
            handler = self._execute_normal_block
        await handler(current, executor, loop)
async def _execute_conditional_branch(self, block: ConditionBlock, executor, loop):
    """Execute a condition block and then the branch it selects.

    The first downstream block is the THEN branch and the second, if any, is
    the ELSE branch. A missing branch is simply skipped.
    """
    self.logger.info(f"Executing ConditionBlock: {block.name}")
    inputs = self._gather_inputs(block)
    # run_in_executor forwards positional arguments only; passing **inputs
    # directly raises TypeError, so bind them with functools.partial (as
    # _execute_normal_block does).
    result = await loop.run_in_executor(
        executor, functools.partial(block.execute, **inputs)
    )
    self.results[block.name] = result
    self.logger.info(
        f"ConditionBlock {block.name} evaluation result: {result['condition_result']}"
    )

    next_blocks = self.execution_graph[block]
    if result["condition_result"]:
        # THEN branch (guarded: no outgoing wire means nothing to run)
        if next_blocks:
            await self._execute_nodes([next_blocks[0]], executor, loop)
    elif len(next_blocks) > 1:
        # ELSE branch
        await self._execute_nodes([next_blocks[1]], executor, loop)
async def _execute_loop(self, block: LoopBlock, executor, loop):
    """Run a LoopBlock: re-evaluate it and run its body while it reports ``should_continue``.

    The block is executed (with freshly gathered inputs) before every iteration
    and its latest result is stored in ``self.results``. The loop body is the
    first successor in ``self.execution_graph[block]``.

    NOTE(review): body blocks are run through ``_execute_nodes``; ordinary blocks
    that already have an entry in ``self.results`` are skipped by ``_can_execute``,
    so a body block appears to run only on the first iteration — confirm intent.
    """
    self.logger.info(f"Starting LoopBlock: {block.name}")
    iteration = 0
    while True:
        iteration += 1
        inputs = self._gather_inputs(block)
        # run_in_executor() only forwards positional arguments; bind the keyword
        # inputs with functools.partial (as _execute_normal_block does).
        result = await loop.run_in_executor(
            executor, functools.partial(block.execute, **inputs)
        )
        self.results[block.name] = result
        self.logger.info(
            f"LoopBlock {block.name} continuation check: {result['should_continue']}"
        )
        if not result["should_continue"]:
            self.logger.info(
                f"Exiting LoopBlock {block.name} after {iteration} iterations"
            )
            break
        loop_body = self.execution_graph[block][0]
        await self._execute_nodes([loop_body], executor, loop)
async def _execute_normal_block(self, block: Block, executor, loop):
    """Run an ordinary block when it is ready, then continue with its successors.

    Nothing happens if ``_can_execute`` reports the block as not ready (or
    already executed). The block runs in ``executor``; its result is stored in
    ``self.results`` and the successors from ``self.execution_graph`` are run
    recursively. Any exception other than ``BlockExecutionFailedException``
    raised while awaiting the block or its successors is wrapped in
    ``BlockExecutionFailedException``.
    """
    if not self._can_execute(block):
        return  # dependencies not met yet

    kwargs = self._gather_inputs(block)
    self.logger.info(f"Executing Block: {block.name}")
    pending = loop.run_in_executor(executor, functools.partial(block.execute, **kwargs))

    try:
        outcome = await pending
        self.results[block.name] = outcome
        self.logger.info(f"Block [{block.name}] executed successfully")
        successors = self.execution_graph[block]
        if successors:
            await self._execute_nodes(successors, executor, loop)
    except BlockExecutionFailedException:
        raise
    except Exception as exc:
        raise BlockExecutionFailedException(f"Block {block.name} execution failed: {exc}") from exc
def _can_execute(self, block: Block) -> bool:
    """Return True when ``block`` is ready to run.

    A block is ready when:
    * it has no entry in ``self.results`` yet (it has not run),
    * every block wired into it already has a result, and
    * each non-nullable input is fed by a wire whose source result contains
      the wired output.
    """
    if block.name in self.results:
        return False  # already executed

    incoming = [wire for wire in self.workflow.wires if wire.target_block == block]

    # every upstream block must have finished first
    if any(wire.source_block.name not in self.results for wire in incoming):
        return False

    for input_name in block.inputs:
        feeding_wire = next(
            (
                wire
                for wire in incoming
                if wire.target_input == input_name
                and wire.source_block.name in self.results
                and wire.source_output in self.results[wire.source_block.name]
            ),
            None,
        )
        if feeding_wire is not None:
            self.logger.debug(f"Input [{block.name}.{input_name}] satisfied by [{feeding_wire.source_block.name}.{feeding_wire.source_output}] with value {self.results[feeding_wire.source_block.name][feeding_wire.source_output]}")
        elif not block.inputs[input_name].nullable:
            # unsatisfied inputs are only acceptable when nullable
            self.logger.info(f"Input [{block.name}.{input_name}] not satisfied")
            return False

    self.logger.debug(f"All inputs satisfied and predecessors completed for block {block.name}")
    return True
def _gather_inputs(self, block: Block) -> Dict[str, Any]:
    """Collect the keyword arguments for ``block.execute`` from upstream results.

    For each declared input of ``block``:

    * no wire targets it: the input is omitted when nullable, otherwise
      ``BlockExecutionFailedException`` is raised;
    * one or more wires target it: the value comes from the last such wire
      (in ``self.workflow.wires`` order) whose source block has a result that
      contains the wired output. If no wire qualifies, the input is omitted
      when nullable and ``BlockExecutionFailedException`` is raised otherwise.

    This agrees with ``_can_execute``, which accepts any satisfying wire and
    tolerates unsatisfied nullable inputs.
    """
    # input name -> all wires targeting it, in workflow order
    wires_by_input: Dict[str, list] = {}
    for wire in self.workflow.wires:
        if wire.target_block == block:
            wires_by_input.setdefault(wire.target_input, []).append(wire)

    inputs: Dict[str, Any] = {}
    for input_name in block.inputs:
        candidates = wires_by_input.get(input_name)
        if not candidates:
            if not block.inputs[input_name].nullable:
                raise BlockExecutionFailedException(
                    f"Missing wire connection for required input {input_name} in block {block.name}"
                )
            continue

        # The last wire takes precedence; earlier wires are a fallback when the
        # last wire's source did not produce the wired output.
        for wire in reversed(candidates):
            source_name = wire.source_block.name
            if source_name in self.results and wire.source_output in self.results[source_name]:
                inputs[input_name] = self.results[source_name][wire.source_output]
                break
        else:
            if not block.inputs[input_name].nullable:
                last_wire = candidates[-1]
                raise BlockExecutionFailedException(
                    f"Current block {block.name} depends on source block {last_wire.source_block.name} not executed for input {input_name}"
                )
    return inputs
def set_variable(self, name: str, value: Any) -> None:
    """
    Store a workflow variable, replacing any existing value with the same name.
    :param name: variable name
    :param value: variable value
    """
    variables = self.variables
    variables[name] = value
def get_variable(self, name: str, default: Any = None) -> Any:
    """
    Look up a workflow variable.
    :param name: variable name
    :param default: value returned when the variable has not been set
    :return: the stored value, or ``default``
    """
    if name in self.variables:
        return self.variables[name]
    return default
================================================
FILE: kirara_ai/workflow/core/workflow/__init__.py
================================================
from .base import Wire, Workflow, WorkflowConfig
from .builder import WorkflowBuilder
from .registry import WorkflowRegistry
# Public names re-exported by the kirara_ai.workflow.core.workflow package.
__all__ = ["Workflow", "WorkflowBuilder", "WorkflowRegistry", "Wire", "WorkflowConfig"]
================================================
FILE: kirara_ai/workflow/core/workflow/base.py
================================================
from typing import List, Optional
from pydantic import BaseModel
from kirara_ai.workflow.core.block import Block
class WorkflowConfig(BaseModel):
    """Execution settings attached to a workflow (saved under the ``config`` key by the builder)."""

    # Maximum run time of a workflow; presumably seconds, judging by the executor's
    # timeout message — TODO confirm where this value is read.
    max_execution_time: int = 3600
class Workflow:
    """A built workflow: named blocks connected by wires, plus execution settings.

    Attributes:
        name: workflow name.
        blocks: the block instances of the workflow.
        wires: connections from block outputs to block inputs.
        id: optional identifier (the builder passes its own ``id``).
        config: execution settings; a default ``WorkflowConfig`` when none is given.
    """

    def __init__(self, name: str, blocks: List["Block"], wires: List["Wire"], id: Optional[str] = None, config: Optional[WorkflowConfig] = None):
        self.name, self.blocks, self.wires = name, blocks, wires
        self.id = id
        self.config = config if config else WorkflowConfig()
class Wire:
    """A connection from a named output of one block to a named input of another."""

    def __init__(
        self,
        source_block: "Block",
        source_output: str,
        target_block: "Block",
        target_input: str,
    ):
        self.source_block, self.source_output = source_block, source_output
        self.target_block, self.target_input = target_block, target_input

    def __repr__(self):
        parts = (
            f"source_block={self.source_block.name}",
            f"source_output={self.source_output}",
            f"target_block={self.target_block.name}",
            f"target_input={self.target_input}",
        )
        return "Wire(" + ", ".join(parts) + ")"
================================================
FILE: kirara_ai/workflow/core/workflow/builder.py
================================================
import importlib
import random
import string
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from ruamel.yaml import YAML
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block import Block, ConditionBlock, LoopBlock, LoopEndBlock
from kirara_ai.workflow.core.block.registry import BlockRegistry
from .base import Wire, Workflow, WorkflowConfig
def get_block_class(type_name: str, registry: BlockRegistry) -> Type[Block]:
    """Resolve a block type name to a block class.

    Two forms are accepted:

    * ``"!!package.module.ClassName"``: the module is imported and the attribute
      returned (discouraged; a ``UserWarning`` is emitted). This imports whatever
      module the name points to, so only load workflow files from trusted sources.
    * any other string: looked up in ``registry``.

    Raises:
        ValueError: the class path is malformed, or the registry has no such type.
    """
    if type_name.startswith("!!"):
        class_path = type_name[2:]
        warnings.warn(
            f"Loading block using class path: {class_path}. This is not recommended.",
            UserWarning,
        )
        module_path, sep, class_name = class_path.rpartition(".")
        # A usable path needs both a module part and a class part.
        if not sep or not module_path or not class_name:
            raise ValueError(f"Invalid block class path: {class_path}")
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    block_class = registry.get(type_name)
    if block_class is None:
        raise ValueError(f"Block type {type_name} not found in registry")
    return block_class
@dataclass
class BlockSpec:
    """Normalized creation parameters for one block: class, name, constructor kwargs and input sources."""

    block_class: Type[Block]
    name: Optional[str] = None
    kwargs: Dict[str, Any] = field(default_factory=dict)
    wire_from: Optional[Union[str, List[str]]] = None

    def __post_init__(self):
        # A single source name is shorthand for a one-element list.
        self.wire_from = [self.wire_from] if isinstance(self.wire_from, str) else self.wire_from
@dataclass
class Node:
    """One node of the builder graph: a block spec plus its place in the DSL structure.

    The fields below feed the dataclass-generated ``__repr__``/``__eq__``; the
    constructor is written by hand (``@dataclass`` does not replace an ``__init__``
    defined in the class body).
    """

    spec: BlockSpec
    name: str
    next_nodes: List["Node"] = field(default_factory=list)
    merge_point: Optional["Node"] = None
    parallel_nodes: List["Node"] = field(default_factory=list)
    is_parallel: bool = False
    condition: Optional[Callable] = None
    is_conditional: bool = False
    is_loop: bool = False
    parent: Optional["Node"] = None
    position: Optional[Dict[str, int]] = None

    def __init__(
        self,
        spec: BlockSpec,
        name: Optional[str] = None,
        next_nodes: Optional[List["Node"]] = None,
        merge_point: Optional["Node"] = None,
        parallel_nodes: Optional[List["Node"]] = None,
        is_parallel: bool = False,
        condition: Optional[Callable] = None,
        is_conditional: bool = False,
        is_loop: bool = False,
        parent: Optional["Node"] = None,
        position: Optional[Dict[str, int]] = None,
    ):
        self.spec = spec
        # explicit name, then the spec's name, then a class-name/id based fallback
        self.name = name or spec.name or f"{spec.block_class.__name__}_{id(self)}"
        # fresh lists whenever the caller passes nothing (or an empty list)
        self.next_nodes = next_nodes or []
        self.parallel_nodes = parallel_nodes or []
        self.merge_point = merge_point
        self.parent = parent
        self.condition = condition
        self.is_parallel, self.is_conditional, self.is_loop = is_parallel, is_conditional, is_loop
        self.position = position

    def ancestors(self) -> List["Node"]:
        """Return the chain of parents, nearest parent first (the node itself is excluded)."""
        lineage: List["Node"] = []
        node = self.parent
        while node:
            lineage.append(node)
            node = node.parent
        return lineage
class WorkflowBuilder:
    """Workflow builder offering a fluent DSL for assembling workflows.

    Basic syntax:

    1. Initialization:
        builder = WorkflowBuilder("workflow_name")
        ...
        workflow = builder.build(container)

    2. Adding nodes:
        .use(BlockClass)                        # add the initial node
        .chain(BlockClass)                      # append a node after the current one
        .parallel([BlockClass1, BlockClass2])   # add several nodes in parallel

    3. Node specification formats accepted by ``parallel``:
        - BlockClass                            # simplest form
        - (BlockClass, name)                    # with a name (a ``str``)
        - (BlockClass, wire_from)               # with input sources (a ``list`` of node names)
        - (BlockClass, kwargs)                  # with constructor arguments (a ``dict``)
        - (BlockClass, name, kwargs)            # name and constructor arguments
        - (BlockClass, name, wire_from)         # name and input sources (``str`` or ``list``)
        - (BlockClass, kwargs, wire_from)       # constructor arguments and input sources
        - (BlockClass, name, kwargs, wire_from) # name, constructor arguments and input sources

    4. Control flow:
        .if_then(condition)   # start a conditional branch
        .else_then()          # else branch
        .end_if()             # end of the conditional branch
        .loop(condition)      # start a loop
        .end_loop()           # end of the loop

    Full example:
    ```python
    workflow = (WorkflowBuilder("example")
        # basic usage
        .use(InputBlock)                        # simplest form
        .chain(ProcessBlock, name="process")    # with a name
        .chain(TransformBlock, param="value")   # extra keyword args go to the block constructor

        # parallel processing
        .parallel([
            ProcessA,                           # simple form
            (ProcessB, "proc_b"),               # with a name
            (ProcessC, {"param": "val"}),       # with constructor arguments
            (ProcessD, "proc_d",                # full form
                {"param": "val"},
                ["process"])                    # input sources
        ])

        # conditional branch
        .if_then(lambda ctx: ctx["value"] > 0)
            .chain(PositiveBlock)
        .else_then()
            .chain(NegativeBlock)
        .end_if()

        # loop
        .loop(lambda ctx: ctx["count"] < 5)
            .chain(LoopBodyBlock)
        .end_loop()

        # multiple input sources
        .chain(MergeBlock,
               wire_from=["proc_b", "proc_d"])

        .build(container))
    ```

    Features:
    1. Auto-wiring: unless ``wire_from`` is given, a new node is wired from the
       current node. Ports are paired by matching ``data_type``.
    2. Named nodes: a ``name`` lets later nodes reference the node.
    3. Arguments: constructor arguments are passed via keyword args / a kwargs dict.
    4. Custom wiring: ``wire_from`` selects the input sources.
    5. Parallelism: ``parallel`` adds several nodes at once.
    6. Control flow: basic conditional and loop constructs.

    Notes:
    1. Nodes referenced by ``wire_from`` must already exist (unknown names are ignored).
    2. Conditional and loop statements must be paired correctly.
    3. Parallel nodes may each declare different input sources.
    4. Node names must be unique within a workflow.
    5. NOTE(review): ``else_then()`` and ``end_if()`` require the *current* node to be
       the condition node itself, so calling them after a ``.chain(...)`` inside the
       branch (as in the example above) raises ``ValueError`` — confirm intended DSL.
    """

    def __init__(self, name: str):
        self.id: Optional[str] = None
        self.name: str = name
        self.description: str = ""
        self.head: Optional[Node] = None  # first node, set by use()
        self.current: Optional[Node] = None  # node the next chain()/parallel() attaches to
        self.nodes: List[Node] = []  # all nodes, in creation order
        self.nodes_by_name: Dict[str, Node] = {}
        self.wire_specs: List[Tuple[str, str, str, str]] = []  # (source_name, source_output, target_name, target_input)
        self.config = WorkflowConfig()

    def _generate_unique_name(self, base_name: str) -> str:
        """Return ``"<base_name>_<6 random chars>"`` not used by any existing node."""
        while True:
            # 6-character suffix of lowercase letters and digits
            suffix = "".join(
                random.choices(string.ascii_lowercase + string.digits, k=6)
            )
            name = f"{base_name}_{suffix}"
            if name not in self.nodes_by_name:
                return name

    def _parse_block_spec(self, block_spec: Union[Type[Block], tuple]) -> BlockSpec:
        """Normalize one of the supported block specification formats into a BlockSpec.

        Raises:
            ValueError: the specification is neither a class nor a 2/3/4-tuple.
        """
        if isinstance(block_spec, type):
            return BlockSpec(block_spec)

        if not isinstance(block_spec, tuple):
            raise ValueError(f"Invalid block specification: {block_spec}")

        if len(block_spec) == 4:  # (BlockClass, name, kwargs, wire_from)
            return BlockSpec(*block_spec)
        if len(block_spec) == 3:
            block_class, second, third = block_spec
            if isinstance(second, dict):  # (BlockClass, kwargs, wire_from)
                return BlockSpec(block_class, kwargs=second, wire_from=third)
            if isinstance(third, (str, list)):  # (BlockClass, name, wire_from)
                return BlockSpec(block_class, name=second, wire_from=third)
            # (BlockClass, name, kwargs); a None kwargs means "no arguments"
            return BlockSpec(block_class, name=second, kwargs=third if third is not None else {})
        if len(block_spec) == 2:
            block_class, second = block_spec
            if isinstance(second, dict):  # (BlockClass, kwargs)
                return BlockSpec(block_class, kwargs=second)
            if isinstance(second, list):  # (BlockClass, wire_from)
                return BlockSpec(block_class, wire_from=second)
            # (BlockClass, name)
            return BlockSpec(block_class, name=second)

        raise ValueError(f"Invalid block specification format: {block_spec}")

    def _get_available_inputs(self, node: Node) -> List[str]:
        """Return the input ports of ``node``'s block class not yet targeted by any wire spec."""
        connected_inputs = {wire[3] for wire in self.wire_specs if wire[2] == node.name}
        return [input_name for input_name in node.spec.block_class.inputs.keys()
                if input_name not in connected_inputs]

    def _find_matching_ports(
        self,
        source_node: Node,
        target_node: Node,
        available_inputs: List[str]
    ) -> List[Tuple[str, str]]:
        """Pair source outputs with target inputs of the same ``data_type``.

        Each input is used at most once; for every output, the first compatible
        input (in ``available_inputs`` order) wins.

        Returns:
            List of (output_name, input_name) pairs
        """
        matches: List[Tuple[str, str]] = []
        source_outputs = source_node.spec.block_class.outputs
        target_inputs = {name: target_node.spec.block_class.inputs[name]
                         for name in available_inputs}

        for out_name, output in source_outputs.items():
            for in_name, in_port in target_inputs.items():
                if output.data_type == in_port.data_type:
                    matches.append((out_name, in_name))
                    # consume the input; safe because iteration stops right after
                    target_inputs.pop(in_name)
                    break

        return matches

    def _store_wire_spec(
        self,
        source_name: str,
        target_name: str,
        source_node: Optional[Node] = None,
        target_node: Optional[Node] = None,
    ):
        """Record wire specs from ``source_name`` to ``target_name``, pairing ports automatically."""
        if source_node is None:
            source_node = self.nodes_by_name[source_name]
        if target_node is None:
            target_node = self.nodes_by_name[target_name]

        available_inputs = self._get_available_inputs(target_node)
        if not available_inputs:
            return  # every input of the target is already wired

        matches = self._find_matching_ports(source_node, target_node, available_inputs)
        for source_output, target_input in matches:
            self.wire_specs.append((source_name, source_output, target_name, target_input))

    def _create_node(self, spec: BlockSpec, is_parallel: bool = False) -> Node:
        """Register a new node (without instantiating its block) and record its wiring.

        Wires come from ``spec.wire_from`` when given, otherwise from the current node.
        """
        if not spec.name:
            spec.name = self._generate_unique_name(spec.block_class.__name__)

        node = Node(spec=spec, is_parallel=is_parallel)
        self.nodes.append(node)
        self.nodes_by_name[node.name] = node

        if spec.wire_from:
            for source_name in spec.wire_from:
                source_node = self.nodes_by_name.get(source_name)
                if source_node:  # unknown source names are ignored
                    self._store_wire_spec(source_node.name, node.name, source_node, node)
        elif self.current:
            self._store_wire_spec(self.current.name, node.name, self.current, node)

        return node

    def use(
        self, block_class: Type[Block], name: Optional[str] = None, **kwargs: Any
    ) -> "WorkflowBuilder":
        """Add the initial node; it becomes both ``head`` and ``current``."""
        spec = BlockSpec(block_class, name=name, kwargs=kwargs)
        node = self._create_node(spec)
        self.head = node
        self.current = node
        return self

    def chain(
        self,
        block_class: Type[Block],
        name: Optional[str] = None,
        wire_from: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "WorkflowBuilder":
        """Append a node after the current node and make it current."""
        spec = BlockSpec(block_class, name=name, kwargs=kwargs, wire_from=wire_from)
        node = self._create_node(spec)
        if self.current:
            self.current.next_nodes.append(node)
            node.parent = self.current
        self.current = node
        return self

    def parallel(
        self, block_specs: List[Union[Type[Block], tuple]]
    ) -> "WorkflowBuilder":
        """Add several nodes after the current node; the first of them becomes current.

        Raises:
            ValueError: ``block_specs`` is empty.
        """
        if not block_specs:
            raise ValueError("parallel requires at least one block specification")

        parallel_nodes: List[Node] = []
        for block_spec in block_specs:
            spec = self._parse_block_spec(block_spec)
            node = self._create_node(spec, is_parallel=True)
            node.parent = self.current
            parallel_nodes.append(node)

        if self.current:
            self.current.next_nodes.extend(parallel_nodes)

        self.current = parallel_nodes[0]
        self.current.parallel_nodes = parallel_nodes
        return self

    def condition(self, condition_func: Callable) -> "WorkflowBuilder":
        """Attach a condition callable to the current node."""
        assert self.current is not None
        self.current.condition = condition_func
        return self

    def if_then(
        self, condition: Callable[[Dict[str, Any]], bool], name: Optional[str] = None
    ) -> "WorkflowBuilder":
        """Start a conditional branch with a ConditionBlock node after the current node."""
        if not name:
            name = self._generate_unique_name("condition")

        spec = BlockSpec(
            block_class=ConditionBlock,
            name=name,
            kwargs={"condition": condition, "outputs": {}}  # outputs will be set during build
        )
        node = Node(spec=spec, is_conditional=True)
        self.nodes.append(node)
        self.nodes_by_name[node.name] = node

        if self.current:
            # the wire spec must target the condition node (not a node named "output")
            self._store_wire_spec(self.current.name, node.name, self.current, node)
            self.current.next_nodes.append(node)
            node.parent = self.current

        self.current = node
        return self

    def else_then(self) -> "WorkflowBuilder":
        """Switch to the else branch: the current node must be a condition node.

        The current node becomes the condition node's parent.
        """
        if not self.current or not self.current.is_conditional:
            raise ValueError("else_then must follow if_then")
        self.current = self.current.parent
        return self

    def end_if(self) -> "WorkflowBuilder":
        """Close a conditional branch: the current node must be a condition node.

        Moves to the node's ``merge_point`` when one is set, else stays put.
        """
        if not self.current or not self.current.is_conditional:
            raise ValueError("end_if must close an if block")
        self.current = self.current.merge_point or self.current
        return self

    def loop(
        self,
        condition: Callable[[Dict[str, Any]], bool],
        name: Optional[str] = None,
        iteration_var: str = "index",
    ) -> "WorkflowBuilder":
        """Start a loop with a LoopBlock node after the current node."""
        if not name:
            name = self._generate_unique_name("loop")

        spec = BlockSpec(
            block_class=LoopBlock,
            name=name,
            kwargs={
                "condition": condition,
                "outputs": {},  # outputs will be set during build
                "iteration_var": iteration_var
            }
        )
        node = Node(spec=spec, is_loop=True)
        self.nodes.append(node)
        self.nodes_by_name[node.name] = node

        if self.current:
            # the wire spec must target the loop node (not a node named "output")
            self._store_wire_spec(self.current.name, node.name, self.current, node)
            self.current.next_nodes.append(node)
            node.parent = self.current

        self.current = node
        return self

    def end_loop(self) -> "WorkflowBuilder":
        """Close the nearest enclosing loop by adding a LoopEndBlock node.

        Raises:
            ValueError: no loop node among the current node's ancestors.
        """
        if self.current is None:
            raise ValueError("end_loop must close a loop block")
        loop_start = next((n for n in self.current.ancestors() if n.is_loop), None)
        if loop_start is None:
            raise ValueError("end_loop must close a loop block")

        spec = BlockSpec(
            block_class=LoopEndBlock,
            name=self._generate_unique_name("loop_end"),
            kwargs={"outputs": {}}  # outputs will be set during build
        )
        node = Node(spec=spec)
        self.nodes.append(node)
        self.nodes_by_name[node.name] = node

        # Wire the last body node and the loop start into the loop end. The stored
        # names must match the node objects whose ports are being paired.
        self._store_wire_spec(self.current.name, node.name, self.current, node)
        self._store_wire_spec(loop_start.name, node.name, loop_start, node)
        node.parent = self.current

        self.current = node
        return self

    def build(self, container: DependencyContainer) -> Workflow:
        """Instantiate every block and create the wires, producing a Workflow.

        Raises:
            ValueError: a block constructor failed (the original error is chained).
        """
        blocks: List[Block] = []
        wires: List[Wire] = []
        name_to_block: Dict[str, Block] = {}
        name_to_node: Dict[str, Node] = {}

        # Instantiate all blocks first
        for node in self.nodes:
            try:
                # condition/loop blocks take their outputs from the preceding block
                if node.is_conditional or node.is_loop:
                    prev_node = node.parent
                    if prev_node and prev_node.spec.block_class:
                        node.spec.kwargs["outputs"] = prev_node.spec.block_class.outputs.copy()

                block = node.spec.block_class(**node.spec.kwargs)
                if node.name:
                    block.name = node.name
                block.container = container
                blocks.append(block)
                name_to_block[node.name] = block
                name_to_node[node.name] = node
            except Exception as e:
                raise ValueError(f"Failed to create block {node.spec.block_class.__name__}: {e}") from e

        # Then create the wires; specs referring to unknown nodes are dropped
        for source_name, source_output, target_name, target_input in self.wire_specs:
            source_block = name_to_block.get(source_name)
            target_block = name_to_block.get(target_name)
            if source_block and target_block:
                wires.append(Wire(source_block, source_output, target_block, target_input))

        return Workflow(name=self.name, blocks=blocks, wires=wires, id=self.id, config=self.config)

    def set_config(self, config: WorkflowConfig):
        """Replace the workflow execution settings."""
        self.config = config
        return self

    def force_connect(
        self,
        source_name: str,
        target_name: str,
        source_output: str,
        target_input: str,
    ):
        """Record an explicit wire spec without any port matching."""
        self.wire_specs.append((source_name, source_output, target_name, target_input))

    def _find_parallel_nodes(self, start_node: Node) -> List[Node]:
        """Collect the parallel groups met while following first successors from ``start_node``."""
        parallel_nodes: List[Node] = []
        current = start_node
        while current:
            if current.is_parallel:
                parallel_nodes.extend(current.parallel_nodes)
            if current.next_nodes:
                current = current.next_nodes[0]
            else:
                break
        return parallel_nodes

    def update_position(self, name: str, position: Dict[str, int]):
        """Set the stored position of the named node."""
        node = self.nodes_by_name[name]
        node.position = position

    def save_to_yaml(self, file_path: str, container: DependencyContainer):
        """Serialize the workflow (name, description, config, blocks and their wiring) to YAML."""
        registry: BlockRegistry = container.resolve(BlockRegistry)
        yaml = YAML()
        yaml.indent(mapping=2, sequence=4, offset=2)
        yaml.width = 4096
        workflow_data: Dict[str, Any] = {
            "name": self.name,
            "description": self.description,
            "blocks": [],
            "config": self.config.model_dump(),
        }

        def serialize_node(node: Node) -> dict:
            block_data: Dict[str, Any] = {
                "type": registry.get_block_type_name(node.spec.block_class),
                "name": node.name,
                "params": node.spec.kwargs,
                "position": node.position,
            }

            if node.is_parallel:
                block_data["parallel"] = True

            # outgoing connections of this node
            connected_to: List[Dict[str, Any]] = []
            for source_name, source_output, target_name, target_input in self.wire_specs:
                if source_name != node.name:
                    continue
                target_node = self.nodes_by_name.get(target_name)
                if target_node is None:
                    continue  # only serialize wires whose target node exists
                connected_to.append(
                    {
                        "target": target_node.name,
                        "mapping": {
                            "from": source_output,
                            "to": target_input,
                        },
                    }
                )

            if connected_to:
                block_data["connected_to"] = connected_to

            return block_data

        for node in self.nodes_by_name.values():
            workflow_data["blocks"].append(serialize_node(node))

        with open(file_path, "w", encoding="utf-8") as f:
            yaml.dump(workflow_data, f)

        return self

    @classmethod
    def load_from_yaml(
        cls, file_path: str, container: DependencyContainer
    ) -> "WorkflowBuilder":
        """Load a workflow from a YAML file.

        Args:
            file_path: path of the YAML file
            container: dependency injection container

        Returns:
            WorkflowBuilder instance

        Raises:
            ValueError: the file does not contain a YAML mapping.
        """
        yaml = YAML(typ="safe")
        with open(file_path, "r", encoding="utf-8") as f:
            workflow_data: Dict[str, Any] = yaml.load(f)
        if not isinstance(workflow_data, dict):
            raise ValueError(f"Invalid workflow file: {file_path}")

        builder: WorkflowBuilder = cls(workflow_data["name"])
        # "config: null" loads as None; treat it like a missing section
        builder.config = WorkflowConfig.model_validate(workflow_data.get("config") or {})
        builder.description = workflow_data.get("description", "")
        registry: BlockRegistry = container.resolve(BlockRegistry)

        # First pass: create all blocks
        for block_data in workflow_data["blocks"]:
            block_class = get_block_class(block_data["type"], registry)
            # "params: null" loads as None, which cannot be unpacked as kwargs
            params = block_data.get("params") or {}

            if block_data.get("parallel"):
                parallel_blocks = [(block_class, block_data["name"], params)]
                builder.parallel(parallel_blocks)  # type: ignore
            elif builder.head is None:
                builder.use(block_class, name=block_data["name"], **params)
            else:
                builder.chain(block_class, name=block_data["name"], **params)

            if block_data.get("position"):
                builder.update_position(block_data["name"], block_data["position"])

        # Second pass: replace auto-wiring with the connections stored in the file
        builder.wire_specs = []
        for block_data in workflow_data["blocks"]:
            if "connected_to" in block_data:
                source_node = builder.nodes_by_name[block_data["name"]]
                for connection in block_data["connected_to"]:
                    target_node = builder.nodes_by_name[connection["target"]]
                    builder.force_connect(
                        source_node.name,
                        target_node.name,
                        connection["mapping"]["from"],
                        connection["mapping"]["to"],
                    )

        return builder
================================================
FILE: kirara_ai/workflow/core/workflow/registry.py
================================================
import os
import re
from typing import Dict, Optional
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.logger import get_logger
from kirara_ai.workflow.core.workflow.base import Workflow
from kirara_ai.workflow.core.workflow.builder import WorkflowBuilder
class WorkflowRegistry:
    """Registry of workflow builders keyed by ``"<group_id>:<workflow_id>"``.

    Builders are stored; ``get``/``get_workflow`` build a Workflow on request.
    """

    WORKFLOWS_DIR = os.path.realpath("data/workflows")
    # characters allowed in group and workflow ids (both become path components)
    _ID_PATTERN = re.compile(r"[a-zA-Z0-9_-]+")

    def __init__(self, container: DependencyContainer):
        self._workflows: Dict[str, WorkflowBuilder] = {}
        self.logger = get_logger("WorkflowRegistry")
        self.container = container

    @classmethod
    def _validate_id(cls, value: str) -> None:
        """Raise ``ValueError`` unless ``value`` only contains letters, digits, ``_`` and ``-``."""
        # fullmatch, not match(r"^...$"): "$" also matches before a trailing newline.
        if not cls._ID_PATTERN.fullmatch(value):
            invalid_chars = re.findall(r"[^a-zA-Z0-9_-]", value)
            raise ValueError(
                f"Invalid symbols in workflow path: {''.join(invalid_chars)}"
            )

    @classmethod
    def get_workflow_path(cls, group_id: str, workflow_id: str) -> str:
        """Return the YAML path of a workflow, creating its group directory if needed.

        Raises:
            ValueError: an id contains disallowed characters, or the resolved path
                lies outside ``WORKFLOWS_DIR``.
        """
        # validate first so nothing is derived from unchecked input
        cls._validate_id(workflow_id)
        cls._validate_id(group_id)

        base_dir = os.path.realpath(cls.WORKFLOWS_DIR)
        group_dir = os.path.join(cls.WORKFLOWS_DIR, group_id)
        final_path = os.path.join(group_dir, f"{workflow_id}.yaml")
        # commonpath compares whole path components; commonprefix is character-based
        # and would accept e.g. ".../workflows_other/x.yaml".
        if os.path.commonpath((os.path.realpath(final_path), base_dir)) != base_dir:
            raise ValueError("Invalid workflow path")

        os.makedirs(group_dir, exist_ok=True)
        return final_path

    def unregister(self, group_id: str, workflow_id: str):
        """Remove a workflow if it is registered."""
        full_name = f"{group_id}:{workflow_id}"
        if full_name in self._workflows:
            del self._workflows[full_name]
            self.logger.info(f"Unregistered workflow: {full_name}")

    def register(
        self, group_id: str, workflow_id: str, workflow_builder: WorkflowBuilder
    ):
        """Register a workflow, overwriting any existing one with the same id."""
        full_name = f"{group_id}:{workflow_id}"
        if full_name in self._workflows:
            self.logger.warning(f"Workflow {full_name} already registered, overwriting")
        workflow_builder.id = full_name
        self._workflows[full_name] = workflow_builder
        self.logger.info(f"Registered workflow: {full_name}")

    def register_preset_workflow(
        self, group_id: str, workflow_id: str, workflow_builder: WorkflowBuilder
    ):
        """Register a preset workflow unless a workflow with the same id already exists
        (e.g. one saved by the user)."""
        full_name = f"{group_id}:{workflow_id}"
        if full_name in self._workflows:
            self.logger.debug(
                f"Preset workflow {full_name} already registered, skipping"
            )
            return
        # same as register(): built workflows carry their registry id
        workflow_builder.id = full_name
        self._workflows[full_name] = workflow_builder
        self.logger.info(f"Registered preset workflow: {full_name}")

    def get_workflow(self, name: str, container: DependencyContainer) -> Optional[Workflow]:
        """Build and return the named workflow, or None if it is not registered."""
        builder = self._workflows.get(name)
        if builder:
            return builder.build(container)
        return None

    def get(
        self, name: str, container: Optional[DependencyContainer] = None
    ) -> Optional[WorkflowBuilder | Workflow]:
        """Return the builder, or a built Workflow when ``container`` is given."""
        builder = self._workflows.get(name)
        if builder and container:
            return builder.build(container)
        return builder

    def load_workflows(self, workflows_dir: Optional[str] = None):
        """Load and register every ``<group_id>/<workflow_id>.yaml`` under ``workflows_dir``.

        Defaults to ``WORKFLOWS_DIR`` (created when missing). Files that fail to
        load are logged and skipped.
        """
        workflows_dir = workflows_dir or self.WORKFLOWS_DIR
        os.makedirs(workflows_dir, exist_ok=True)

        # sorted() gives a deterministic registration (and log) order
        for group_id in sorted(os.listdir(workflows_dir)):
            group_dir = os.path.join(workflows_dir, group_id)
            if not os.path.isdir(group_dir):
                continue

            for file_name in sorted(os.listdir(group_dir)):
                if not file_name.endswith(".yaml"):
                    continue

                workflow_id = os.path.splitext(file_name)[0]
                file_path = os.path.join(group_dir, file_name)

                try:
                    workflow = WorkflowBuilder.load_from_yaml(file_path, self.container)
                    self.register(group_id, workflow_id, workflow)
                except Exception as e:
                    self.logger.error(
                        f"Failed to load workflow from {file_path}: {str(e)}"
                    )
================================================
FILE: kirara_ai/workflow/implementations/__init__.py
================================================
================================================
FILE: kirara_ai/workflow/implementations/blocks/__init__.py
================================================
from .system_blocks import register_system_blocks
# Public names re-exported by this package.
__all__ = ["register_system_blocks"]
================================================
FILE: kirara_ai/workflow/implementations/blocks/game/dice.py
================================================
import random
import re
from typing import Any, Dict
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.workflow.core.block import Block
from kirara_ai.workflow.core.block.input_output import Input, Output
class DiceRoll(Block):
    """Dice-roll block: parses ``.roll NdM`` / ``。roll NdM`` (N defaults to 1) and replies with the rolls."""

    name = "dice_roll"
    inputs = {
        "message": Input("message", "输入消息", IMMessage, "输入消息包含骰子命令")
    }
    outputs = {
        "response": Output(
            "response", "响应消息", IMMessage, "响应消息包含骰子掷点结果"
        )
    }

    @staticmethod
    def _reply(text: str) -> Dict[str, Any]:
        """Wrap ``text`` in a bot-sent IMMessage under the ``response`` output."""
        return {
            "response": IMMessage(
                sender=ChatSender.get_bot_sender(),
                message_elements=[TextMessage(text)],
            )
        }

    def execute(self, message: IMMessage) -> Dict[str, Any]:
        match = re.match(r"^[.。]roll\s*(\d+)?d(\d+)", message.content)
        if not match:
            return self._reply("Invalid dice command")

        try:
            count = int(match.group(1) or "1")  # one die by default
            sides = int(match.group(2))
        except ValueError:
            # int() may refuse extremely long digit strings (int max str digits limit)
            return self._reply("Invalid dice command")

        if count > 100:  # cap the number of dice
            return self._reply("Too many dice (max 100)")
        # zero dice produce an empty roll and zero sides would crash randint(1, 0)
        if count < 1 or sides < 1:
            return self._reply("Invalid dice command")

        rolls = [random.randint(1, sides) for _ in range(count)]
        total = sum(rolls)

        details = f"🎲 掷出了 {count}d{sides}: {' + '.join(map(str, rolls))}"
        if count > 1:
            details += f" = {total}"

        return self._reply(details)
================================================
FILE: kirara_ai/workflow/implementations/blocks/game/gacha.py
================================================
import random
from typing import Dict, Optional
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.workflow.core.block import Block
from kirara_ai.workflow.core.block.input_output import Input, Output
class GachaSimulator(Block):
    """Gacha simulator block: one pull, or ten when the command contains "十连"."""

    name = "gacha_simulator"
    inputs = {
        "message": Input("message", "输入消息", IMMessage, "输入消息包含抽卡命令")
    }
    outputs = {
        "response": Output("response", "响应消息", IMMessage, "响应消息包含抽卡结果")
    }

    def __init__(self, rates: Optional[Dict[str, float]] = None):
        # Default pull probabilities: SSR 3%, SR 12%, R 85%
        self.rates = rates or {"SSR": 0.03, "SR": 0.12, "R": 0.85}

    def _single_pull(self) -> str:
        """Draw one rarity by walking the cumulative probabilities in ``self.rates`` order."""
        rand = random.random()
        cumulative = 0.0
        for rarity, rate in self.rates.items():
            cumulative += rate
            if rand <= cumulative:
                return rarity
        return list(self.rates.keys())[-1]  # fallback when rates sum to less than 1

    def execute(self, message: IMMessage) -> Dict[str, IMMessage]:
        command = message.content
        pulls = 10 if "十连" in command else 1

        results = [self._single_pull() for _ in range(pulls)]
        stats = {rarity: results.count(rarity) for rarity in self.rates.keys()}

        # Labels for the built-in rarities; custom rarity names keep their own name
        # instead of all being shown as "R".
        labels = {"SSR": "🌟 SSR", "SR": "✨ SR", "R": "⭐ R"}
        details = [labels.get(rarity, f"⭐ {rarity}") for rarity in results]

        result_text = f"抽卡结果: {'、'.join(details)}"
        stats_text = "统计:\n" + "\n".join(
            f"{rarity}: {count}" for rarity, count in stats.items()
        )

        return {
            "response": IMMessage(
                sender=ChatSender.get_bot_sender(),
                message_elements=[TextMessage(result_text), TextMessage(stats_text)],
            )
        }
================================================
FILE: kirara_ai/workflow/implementations/blocks/im/basic.py
================================================
from typing import Any, Dict
from kirara_ai.im.message import IMMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block import Block
from kirara_ai.workflow.core.block.input_output import Input, Output
class ExtractChatSender(Block):
    """Extract the sender of an IM message."""

    name = "extract_chat_sender"
    container: DependencyContainer
    inputs = {"msg": Input("msg", "IM 消息", IMMessage, "IM 消息")}
    outputs = {"sender": Output("sender", "消息发送者", ChatSender, "消息发送者")}

    def execute(self, **kwargs) -> Dict[str, Any]:
        # Use the wired ``msg`` input; fall back to the message in the container
        # when no input was supplied.
        msg = kwargs.get("msg")
        if msg is None:
            msg = self.container.resolve(IMMessage)
        return {"sender": msg.sender}
================================================
FILE: kirara_ai/workflow/implementations/blocks/im/messages.py
================================================
import asyncio
from typing import Annotated, Any, Dict, List, Optional
from kirara_ai.im.adapter import IMAdapter
from kirara_ai.im.manager import IMManager
from kirara_ai.im.message import IMMessage, MessageElement, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block import Block, Input, Output, ParamMeta
def im_adapter_options_provider(container: DependencyContainer, block: Block) -> List[str]:
    """Options provider for SendIMMessage's ``im_name``: the adapter names known to IMManager."""
    adapters = container.resolve(IMManager).adapters
    return list(adapters.keys())
class GetIMMessage(Block):
    """Expose the IMMessage from the container, together with its sender."""

    name = "msg_input"
    container: DependencyContainer
    outputs = {
        "msg": Output("msg", "IM 消息", IMMessage, "获取 IM 发送的最新一条的消息"),
        "sender": Output("sender", "发送者", ChatSender, "获取 IM 消息的发送者"),
    }

    def execute(self, **kwargs) -> Dict[str, Any]:
        incoming = self.container.resolve(IMMessage)
        sender = incoming.sender
        return {"msg": incoming, "sender": sender}
class SendIMMessage(Block):
    """Send an IM message through an IM adapter without waiting for delivery."""

    name = "msg_sender"
    inputs = {
        "msg": Input("msg", "IM 消息", IMMessage, "要从 IM 发送的消息"),
        "target": Input(
            "target",
            "发送对象",
            ChatSender,
            "要发送给谁,如果填空则默认发送给消息的发送者",
            nullable=True,
        ),
    }
    outputs = {}
    container: DependencyContainer

    # Strong references to in-flight send tasks: the event loop keeps only weak
    # references to tasks, so an unreferenced task may be collected before finishing.
    _pending_sends: set = set()

    def __init__(
        self, im_name: Annotated[Optional[str], ParamMeta(label="聊天平台适配器名称", options_provider=im_adapter_options_provider)] = None
    ):
        self.im_name = im_name

    @classmethod
    def _start_send(cls, loop: asyncio.AbstractEventLoop, coro) -> None:
        """Create the send task on ``loop`` and keep it referenced until it is done.

        Must be called on the loop's own thread.
        """
        task = loop.create_task(coro)
        cls._pending_sends.add(task)
        task.add_done_callback(cls._pending_sends.discard)

    def execute(
        self, msg: IMMessage, target: Optional[ChatSender] = None
    ) -> Dict[str, Any]:
        src_msg = self.container.resolve(IMMessage)
        if not self.im_name:
            adapter = self.container.resolve(IMAdapter)
        else:
            adapter = self.container.resolve(IMManager).get_adapter(self.im_name)
        loop: asyncio.AbstractEventLoop = self.container.resolve(
            asyncio.AbstractEventLoop
        )
        # default recipient: the sender of the incoming message
        coro = adapter.send_message(msg, target or src_msg.sender)

        try:
            on_loop_thread = asyncio.get_running_loop() is loop
        except RuntimeError:  # no running loop here, e.g. in an executor worker thread
            on_loop_thread = False

        if on_loop_thread:
            self._start_send(loop, coro)
        else:
            # Blocks are normally run via run_in_executor; loop.create_task is not
            # thread-safe, so hand the task creation over to the loop's thread.
            loop.call_soon_threadsafe(self._start_send, loop, coro)
        return {"ok": True}
# IMMessage -> plain text
class IMMessageToText(Block):
    """Convert an IMMessage into plain text (IMMessage 转纯文本)."""

    name = "im_message_to_text"
    container: DependencyContainer
    inputs = {"msg": Input("msg", "IM 消息", IMMessage, "IM 消息")}
    outputs = {"text": Output("text", "纯文本", str, "纯文本")}

    def execute(self, msg: IMMessage) -> Dict[str, Any]:
        """Return the message's ``content`` under the ``text`` key."""
        plain = msg.content
        return {"text": plain}
# Plain text -> IMMessage
class TextToIMMessage(Block):
    """Wrap plain text into an IMMessage sent by the bot (纯文本转 IMMessage)."""

    name = "text_to_im_message"
    container: DependencyContainer
    inputs = {"text": Input("text", "纯文本", str, "纯文本")}
    outputs = {"msg": Output("msg", "IM 消息", IMMessage, "IM 消息")}

    def __init__(self, split_by: Annotated[Optional[str], ParamMeta(label="分段符")] = None):
        """
        :param split_by: optional separator; when set, the text is split on it
            and every non-blank, stripped segment becomes its own TextMessage.
        """
        self.split_by = split_by

    def execute(self, text: str) -> Dict[str, Any]:
        """Build the bot message from ``text``."""
        bot = ChatSender.get_bot_sender()
        if not self.split_by:
            elements = [TextMessage(text)]
        else:
            segments = (segment.strip() for segment in text.split(self.split_by))
            elements = [TextMessage(segment) for segment in segments if segment]
        return {"msg": IMMessage(sender=bot, message_elements=elements)}
# Append an element to an IMMessage
class AppendIMMessage(Block):
    """Return a new IMMessage with one extra element appended (补充 IMMessage 消息)."""

    name = "concat_im_message"
    container: DependencyContainer
    inputs = {
        "base_msg": Input("base_msg", "IM 消息", IMMessage, "IM 消息"),
        "append_msg": Input("append_msg", "新消息片段", MessageElement, "新消息片段"),
    }
    outputs = {"msg": Output("msg", "IM 消息", IMMessage, "IM 消息")}

    def execute(self, base_msg: IMMessage, append_msg: MessageElement) -> Dict[str, Any]:
        """Keep ``base_msg``'s sender; build a new element list rather than mutating the original."""
        extended = base_msg.message_elements + [append_msg]
        combined = IMMessage(sender=base_msg.sender, message_elements=extended)
        return {"msg": combined}
================================================
FILE: kirara_ai/workflow/implementations/blocks/im/states.py
================================================
import asyncio
from typing import Annotated, Any, Dict
from kirara_ai.im.adapter import EditStateAdapter, IMAdapter
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block import Block, Input, ParamMeta
# Toggle edit state
class ToggleEditState(Block):
    """Set a chat's edit state on adapters that implement ``EditStateAdapter``."""

    name = "toggle_edit_state"
    inputs = {
        "sender": Input("sender", "聊天对象", ChatSender, "要切换编辑状态的聊天对象")
    }
    outputs = {}
    container: DependencyContainer

    def __init__(
        self,
        is_editing: Annotated[
            bool, ParamMeta(label="是否编辑", description="是否切换到编辑状态")
        ],
    ):
        """
        :param is_editing: the state to pass to ``set_chat_editing_state``.
        """
        self.is_editing = is_editing

    def execute(self, sender: ChatSender) -> Dict[str, Any]:
        """Schedule the state change; adapters without edit-state support are skipped."""
        im_adapter = self.container.resolve(IMAdapter)
        if isinstance(im_adapter, EditStateAdapter):
            loop: asyncio.AbstractEventLoop = self.container.resolve(
                asyncio.AbstractEventLoop
            )
            # Thread-safe submission; loop.create_task() is only safe from the
            # loop's own thread.
            asyncio.run_coroutine_threadsafe(
                im_adapter.set_chat_editing_state(sender, self.is_editing), loop
            )
        return {}
================================================
FILE: kirara_ai/workflow/implementations/blocks/im/user_profile.py
================================================
import asyncio
from typing import Any, Dict, Optional
from kirara_ai.im.adapter import IMAdapter, UserProfileAdapter
from kirara_ai.im.profile import UserProfile
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block import Block
from kirara_ai.workflow.core.block.input_output import Input, Output
class QueryUserProfileBlock(Block):
    """Query the profile of a chat participant via a ``UserProfileAdapter``."""

    def __init__(self, container: DependencyContainer):
        inputs = {
            "chat_sender": Input(
                "chat_sender", "聊天对象", ChatSender, "要查询聊天对象的 profile"
            ),
            "im_adapter": Input(
                "im_adapter", "IM 平台", IMAdapter, "IM 平台适配器", nullable=True
            ),
        }
        outputs = {"profile": Output("profile", "用户资料", UserProfile, "用户资料")}
        super().__init__("query_user_profile", inputs, outputs)
        self.container = container

    def execute(
        self, chat_sender: ChatSender, im_adapter: Optional[IMAdapter] = None
    ) -> Dict[str, Any]:
        """Return ``{"profile": <UserProfile>}`` for ``chat_sender``.

        :param chat_sender: whose profile to query.
        :param im_adapter: adapter to query; defaults to the container's ``IMAdapter``.
        :raises TypeError: if the adapter does not implement ``UserProfileAdapter``.
        :raises RuntimeError: if called on the event loop's own thread (blocking
            there would deadlock).
        """
        # Fall back to the default adapter from the container.
        if im_adapter is None:
            im_adapter = self.container.resolve(IMAdapter)
        # The adapter must implement the UserProfileAdapter protocol.
        if not isinstance(im_adapter, UserProfileAdapter):
            raise TypeError(
                f"IM Adapter {type(im_adapter)} does not support user profile querying"
            )
        loop: asyncio.AbstractEventLoop = self.container.resolve(
            asyncio.AbstractEventLoop
        )
        try:
            running = asyncio.get_running_loop()
        except RuntimeError:
            running = None
        if running is loop:
            raise RuntimeError(
                "QueryUserProfileBlock.execute must not run on the event loop thread"
            )
        # Run the coroutine on the application loop rather than a throwaway one
        # from asyncio.run() (presumably where the adapter's I/O resources are
        # bound), and block this worker thread until it completes.
        future = asyncio.run_coroutine_threadsafe(
            im_adapter.query_user_profile(chat_sender), loop  # type: ignore
        )
        profile = future.result()
        return {"profile": profile}
================================================
FILE: kirara_ai/workflow/implementations/blocks/llm/basic.py
================================================
# LLM 响应转纯文本
from typing import Any, Dict
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format.response import LLMChatResponse
from kirara_ai.workflow.core.block.base import Block
from kirara_ai.workflow.core.block.input_output import Input, Output
class LLMResponseToText(Block):
    """Flatten an LLM chat response into plain text (LLM 响应转纯文本)."""

    name = "llm_response_to_text"
    container: DependencyContainer
    inputs = {"response": Input("response", "LLM 响应", LLMChatResponse, "LLM 响应")}
    outputs = {"text": Output("text", "纯文本", str, "纯文本")}

    def execute(self, response: LLMChatResponse) -> Dict[str, Any]:
        """Concatenate the message's text parts; a missing message yields ``""``."""
        if not response.message:
            return {"text": ""}
        pieces = []
        for part in response.message.content:
            if part.type == "text":
                pieces.append(part.text)
            elif part.type == "image":
                # NOTE(review): this f-string has no placeholders, so image parts
                # contribute nothing; a media reference looks lost here - confirm
                # the intended format.
                pieces.append(f"")
        return {"text": "".join(pieces)}
================================================
FILE: kirara_ai/workflow/implementations/blocks/llm/chat.py
================================================
import asyncio
import re
from datetime import datetime
from typing import Annotated, Any, Dict, List, Optional
from kirara_ai.im.message import ImageMessage, IMMessage, MessageElement, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format import LLMChatMessage, LLMChatTextContent
from kirara_ai.llm.format.message import LLMChatContentPartType, LLMChatImageContent
from kirara_ai.llm.format.request import LLMChatRequest, Tool
from kirara_ai.llm.format.response import LLMChatResponse
from kirara_ai.llm.llm_manager import LLMManager
from kirara_ai.llm.model_types import LLMAbility, ModelType
from kirara_ai.logger import get_logger
from kirara_ai.memory.composes.base import ComposableMessageType
from kirara_ai.workflow.core.block import Block, Input, Output, ParamMeta
from kirara_ai.workflow.core.execution.executor import WorkflowExecutor
def model_name_options_provider(container: DependencyContainer, block: Block) -> List[str]:
    """Return the sorted ids of LLM models that support text chat (UI dropdown options)."""
    manager: LLMManager = container.resolve(LLMManager)
    candidates = manager.get_supported_models(ModelType.LLM, LLMAbility.TextChat)
    return sorted(candidates)
class ChatMessageConstructor(Block):
    """Build the LLM prompt (system prompt, history, current user turn) for one chat round."""

    name = "chat_message_constructor"
    inputs = {
        "user_msg": Input("user_msg", "本轮消息", IMMessage, "用户消息"),
        "user_prompt_format": Input(
            "user_prompt_format", "本轮消息格式", str, "本轮消息格式", default=""
        ),
        "memory_content": Input("memory_content", "历史消息对话", List[ComposableMessageType], "历史消息对话"),
        "system_prompt_format": Input(
            "system_prompt_format", "系统提示词", str, "系统提示词", default=""
        ),
    }
    outputs = {
        "llm_msg": Output(
            "llm_msg", "LLM 对话记录", List[LLMChatMessage], "LLM 对话记录"
        )
    }
    container: DependencyContainer

    def substitute_variables(self, text: str, executor: WorkflowExecutor) -> str:
        """
        Replace ``{variable}`` / ``{variable.attr}`` placeholders with workflow variables.

        Path segments after the first are resolved as dict keys (for dicts) or
        object attributes; a placeholder that cannot be resolved is left as-is.

        :param text: text containing placeholders.
        :param executor: workflow executor used to look up variables.
        :return: the text with resolvable placeholders substituted.
        """

        def replace_var(match):
            placeholder = match.group(0)
            var_path = match.group(1).split(".")
            # The placeholder is passed as the default, presumably returned when
            # the variable is unset (TODO confirm get_variable semantics).
            value = executor.get_variable(var_path[0], placeholder)
            for attr in var_path[1:]:
                try:
                    if isinstance(value, dict):
                        value = value.get(attr, placeholder)
                    elif hasattr(value, attr):
                        value = getattr(value, attr)
                    else:
                        return placeholder
                except Exception:
                    # Property getters may raise; treat that as unresolvable.
                    return placeholder
            return str(value)

        return re.sub(r"\{([^}]+)\}", replace_var, text)

    def execute(
        self,
        user_msg: IMMessage,
        memory_content: List[ComposableMessageType],
        system_prompt_format: str = "",
        user_prompt_format: str = "",
    ) -> Dict[str, Any]:
        """Return ``{"llm_msg": [...]}``: system message, history, then the user turn.

        :param user_msg: the current user message; its images are attached.
        :param memory_content: history, either strings (inlined via
            ``{memory_content}``) or ``LLMChatMessage`` objects (inserted as turns).
        :param system_prompt_format: system prompt template.
        :param user_prompt_format: user prompt template.
        """
        executor = self.container.resolve(WorkflowExecutor)

        # Built-in placeholders first. Values are coerced to str because
        # str.replace() raises TypeError on non-str values (e.g. numeric ids or a
        # missing display name).
        replacements = {
            "{current_date_time}": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "{user_msg}": user_msg.content,
            "{user_name}": user_msg.sender.display_name,
            "{user_id}": user_msg.sender.user_id,
        }
        if isinstance(memory_content, list) and all(isinstance(item, str) for item in memory_content):
            replacements["{memory_content}"] = "\n".join(memory_content)  # type: ignore
        for old, new in replacements.items():
            new_text = "" if new is None else str(new)
            system_prompt_format = system_prompt_format.replace(old, new_text)
            user_prompt_format = user_prompt_format.replace(old, new_text)

        # Then any other workflow variables.
        system_prompt = self.substitute_variables(system_prompt_format, executor)
        user_prompt = self.substitute_variables(user_prompt_format, executor)

        content: List[LLMChatContentPartType] = [LLMChatTextContent(text=user_prompt)]
        # Attach the user's images to the current turn.
        for image in user_msg.images or []:
            content.append(LLMChatImageContent(media_id=image.media_id))

        llm_msg = [
            LLMChatMessage(role="system", content=[LLMChatTextContent(text=system_prompt)]),
        ]
        if isinstance(memory_content, list) and all(isinstance(item, LLMChatMessage) for item in memory_content):
            llm_msg.extend(memory_content)  # type: ignore
        llm_msg.append(LLMChatMessage(role="user", content=content))
        return {"llm_msg": llm_msg}
class ChatCompletion(Block):
    """Send a prompt to an LLM and return its chat response."""

    name = "chat_completion"
    inputs = {
        "prompt": Input("prompt", "LLM 对话记录", List[LLMChatMessage], "LLM 对话记录")
    }
    outputs = {"resp": Output("resp", "LLM 对话响应", LLMChatResponse, "LLM 对话响应")}
    container: DependencyContainer

    def __init__(
        self,
        model_name: Annotated[
            Optional[str],
            ParamMeta(
                label="模型 ID",
                description="要使用的模型 ID",
                options_provider=model_name_options_provider),
        ] = None,
    ):
        """
        :param model_name: model id to use; when empty, the manager's default
            text-chat model is chosen at execution time.
        """
        self.model_name = model_name
        self.logger = get_logger("ChatCompletionBlock")

    def execute(self, prompt: List[LLMChatMessage]) -> Dict[str, Any]:
        """Run ``prompt`` through the selected model.

        :raises ValueError: if no model is available or the id is unknown.
        """
        manager = self.container.resolve(LLMManager)
        if self.model_name:
            chosen = self.model_name
            self.logger.debug(f"Using specified model: {chosen}")
        else:
            chosen = manager.get_llm_id_by_ability(LLMAbility.TextChat)
            if not chosen:
                raise ValueError("No available LLM models found")
            self.logger.info(f"Model id unspecified, using default model: {chosen}")

        backend = manager.get_llm(chosen)
        if not backend:
            raise ValueError(f"LLM {chosen} not found, please check the model name")
        request = LLMChatRequest(messages=prompt, model=chosen)
        return {"resp": backend.chat(request)}
class ChatResponseConverter(Block):
    """Convert an LLM chat response into an IM message sent by the bot."""

    name = "chat_response_converter"
    inputs = {"resp": Input("resp", "LLM 响应", LLMChatResponse, "LLM 响应")}
    outputs = {"msg": Output("msg", "IM 消息", IMMessage, "IM 消息")}
    container: DependencyContainer

    def execute(self, resp: LLMChatResponse) -> Dict[str, Any]:
        """Map text parts to TextMessages (split on ``<break>``) and image parts to ImageMessages."""
        # Token the model may emit to split its reply into several IM text
        # messages. An empty separator here made str.split raise ValueError.
        break_token = "<break>"
        message_elements: List[MessageElement] = []
        parts = resp.message.content if resp.message else []
        for part in parts:
            if isinstance(part, LLMChatTextContent):
                for element in part.text.split(break_token):
                    stripped = element.strip()
                    if stripped:
                        message_elements.append(TextMessage(stripped))
            elif isinstance(part, LLMChatImageContent):
                message_elements.append(ImageMessage(media_id=part.media_id))
        msg = IMMessage(sender=ChatSender.get_bot_sender(),
                        message_elements=message_elements)
        return {"msg": msg}
class ChatCompletionWithTools(Block):
    """
    LLM chat block with tool (function) calling support.

    Sends the conversation to the model; whenever the model requests tool calls,
    the matching tools are run on the application event loop and their results
    are appended, for at most ``max_iterations`` model requests.
    """

    name = "chat_completion_with_tools"
    inputs = {
        "msg": Input("msg", "LLM 对话记录", List[LLMChatMessage], "LLM 的 prompt,即由 system、user、assistant和工具调用及结果的完整对话记录"),
        "tools": Input("tools", "工具列表", List[Tool], "工具列表")
    }
    outputs = {
        "resp": Output("resp", "LLM 消息回应", LLMChatResponse, "模型返回给用户的消息"),
        "iteration_msgs": Output("iteration_msgs", "中间步骤消息", List[ComposableMessageType], "迭代过程中产生的所有消息,可以用记忆存储")
    }
    container: DependencyContainer

    def __init__(
        self,
        model_name: Annotated[
            str,
            ParamMeta(
                label="模型 ID, 需要支持函数调用",
                description="支持函数调用的模型",
                options_provider=model_name_options_provider)
        ],
        max_iterations: Annotated[
            int,
            ParamMeta(
                label="最大迭代次数",
                description="允许调用模型请求的最大次数,在进行最后一次请求时,模型将不允许调用工具")
        ] = 4,
    ):
        """
        :param model_name: id of a model that supports function calling.
        :param max_iterations: maximum number of model requests; tool use is
            disabled on the last one.
        """
        self.model_name = model_name
        self.max_iterations = max_iterations
        self.logger = get_logger("Block.ChatCompletionWithTools")

    @staticmethod
    def _unknown_tool_result(tool_call: Any) -> LLMChatMessage:
        """Build an error tool-result message for a call naming a tool that was not provided."""
        # Local imports: only needed on this error path.
        from kirara_ai.llm.format.message import LLMToolResultContent
        from kirara_ai.llm.format.tool import TextContent

        result = LLMToolResultContent(
            id=tool_call.id,
            name=tool_call.function.name,
            content=[TextContent(text=f"Tool '{tool_call.function.name}' is not available.")],
        )
        result.isError = True
        return LLMChatMessage(role="tool", content=[result])

    def execute(self, msg: List[LLMChatMessage], tools: List[Tool]) -> Dict[str, Any]:
        """Run the tool-calling loop.

        :return: ``{"resp": <last LLMChatResponse>, "iteration_msgs": [...]}``
            where ``iteration_msgs`` holds assistant tool-call messages and tool results.
        :raises ValueError: if no model is configured or found, or ``max_iterations < 1``.
        """
        if not self.model_name:
            raise ValueError(
                "need a model name which support function calling")
        if self.max_iterations < 1:
            # Without at least one request there is no response to return.
            raise ValueError(
                f"max_iterations must be at least 1, got {self.max_iterations}")
        self.logger.info(
            f"Using model: {self.model_name} to execute function calling")

        loop = self.container.resolve(asyncio.AbstractEventLoop)
        llm = self.container.resolve(LLMManager).get_llm(self.model_name)
        if not llm:
            raise ValueError(
                f"LLM {self.model_name} not found, please check the model name")

        tools = tools or []
        # Loop-invariant lookup from the tool name the model uses to the Tool.
        tools_mapping = {t.name: t for t in tools}
        iteration_msgs: List[LLMChatMessage] = []

        for iter_count in range(self.max_iterations):
            self.logger.debug(
                f"Iteration {iter_count+1} of {self.max_iterations}")
            request_body = LLMChatRequest(
                messages=msg + iteration_msgs, model=self.model_name)
            if tools:
                request_body.tools = tools
                # Forbid tool calls on the last request so the model must answer.
                if iter_count == self.max_iterations - 1:
                    request_body.tool_choice = "none"

            response: LLMChatResponse = llm.chat(request_body)
            if not response.message.tool_calls:
                self.logger.debug(
                    "No tool calls found, return response directly")
                return {"resp": response, "iteration_msgs": iteration_msgs}

            iteration_msgs.append(response.message)
            self.logger.debug("Tool calls found, attempt to invoke tools")
            for tool_call in response.message.tool_calls:
                actual_tool = tools_mapping.get(tool_call.function.name)
                if not actual_tool:
                    # Answer the call with an error instead of silently leaving it
                    # unanswered; chat APIs typically expect a result per call.
                    self.logger.warning(
                        f"Model requested unknown tool: {tool_call.function.name}")
                    iteration_msgs.append(self._unknown_tool_result(tool_call))
                    continue
                self.logger.debug(
                    f"Invoking tool: {actual_tool.name}({tool_call.function.arguments})")
                # invokeFunc returns a coroutine: run it on the app loop and block
                # this thread until it completes.
                resp_future = asyncio.run_coroutine_threadsafe(
                    actual_tool.invokeFunc(tool_call), loop
                )
                iteration_msgs.append(
                    LLMChatMessage(role="tool", content=[resp_future.result()]))

        return {"resp": response, "iteration_msgs": iteration_msgs}
================================================
FILE: kirara_ai/workflow/implementations/blocks/llm/image.py
================================================
import base64
from typing import Any, Dict, Optional
import requests
from kirara_ai.im.message import ImageMessage
from kirara_ai.workflow.core.block import Block
from kirara_ai.workflow.core.block.input_output import Input, Output
class SimpleStableDiffusionWebUI(Block):
    """Generate an image through a Stable Diffusion WebUI ``/sdapi/v1/txt2img`` endpoint."""

    name = "simple_stable_diffusion_webui"
    inputs = {
        "prompt": Input("prompt", "提示", str, "提示"),
        "negative_prompt": Input("negative_prompt", "负面提示", str, "负面提示"),
    }
    outputs = {"image": Output("image", "图片", ImageMessage, "生成的图片")}
    # Seconds to wait for the WebUI; without a timeout requests.post can block
    # the workflow forever if the server stalls.
    request_timeout: float = 300.0

    def __init__(
        self,
        api_url: str,
        *,
        steps: int = 20,
        sampler_index: str = "Euler a",
        cfg_scale: float = 7.0,
        width: int = 512,
        height: int = 512,
        ckpt_name: Optional[str] = None,
        clip_skip: int = 1,
    ):
        """
        :param api_url: WebUI base URL; a trailing slash is tolerated.
        The keyword-only options are copied into the txt2img payload.
        """
        self.api_url = api_url
        self.steps = steps
        self.sampler_index = sampler_index
        self.cfg_scale = cfg_scale
        self.width = width
        self.height = height
        self.ckpt_name = ckpt_name
        self.clip_skip = clip_skip

    def execute(self, prompt: str, negative_prompt: str) -> Dict[str, Any]:
        """Request one image and return it as an ``ImageMessage``.

        :raises Exception: on a non-200 response or when no image is returned.
        """
        payload = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "steps": self.steps,
            "sampler_index": self.sampler_index,
            "cfg_scale": self.cfg_scale,
            "width": self.width,
            "height": self.height,
        }
        if self.ckpt_name:
            payload["ckpt_name"] = self.ckpt_name
        payload["clip_skip"] = self.clip_skip

        base_url = self.api_url.rstrip("/")
        response = requests.post(
            url=f"{base_url}/sdapi/v1/txt2img",
            json=payload,
            timeout=self.request_timeout,
        )
        if response.status_code != 200:
            raise Exception(
                f"API request failed with status code: {response.status_code}, message: {response.text}"
            )

        r = response.json()
        # Assumes the API returns base64-encoded images and uses the first one.
        if "images" in r and r["images"]:
            image_bytes = base64.b64decode(r["images"][0])
            # Assumes PNG output (TODO confirm against the WebUI's output settings).
            return {"image": ImageMessage(data=image_bytes, format="png")}
        raise Exception("No image data found in the response")
================================================
FILE: kirara_ai/workflow/implementations/blocks/mcp/__init__.py
================================================
================================================
FILE: kirara_ai/workflow/implementations/blocks/mcp/tool.py
================================================
from base64 import b64decode
from typing import Annotated, Any, Dict, List
from mcp import types
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format import tool
from kirara_ai.llm.format.message import LLMToolResultContent
from kirara_ai.llm.format.tool import CallableWrapper, Tool, ToolCall
from kirara_ai.logger import get_logger
from kirara_ai.mcp_module.manager import MCPServerManager
from kirara_ai.media.manager import MediaManager
from kirara_ai.media.types.media_type import MediaType
from kirara_ai.workflow.core.block import Block, Output
from kirara_ai.workflow.core.block.param import ParamMeta
def get_enabled_mcp_tools(container: DependencyContainer, block: Block) -> List[str]:
    """List the names of all tools exposed by the MCP server manager (UI options)."""
    available = container.resolve(MCPServerManager).get_tools()
    return [tool_name for tool_name in available.keys()]
class MCPToolProvider(Block):
    """
    Expose the selected MCP server tools as LLM-callable ``Tool`` objects.
    """

    name = "mcp_tool_provider"
    outputs = {
        "tools": Output("tools", "工具列表", List[Tool], "工具列表")
    }
    container: DependencyContainer

    def __init__(self, enabled_tools: Annotated[List[str], ParamMeta(label="启用工具列表", description="启用工具列表", options_provider=get_enabled_mcp_tools)]):
        """
        :param enabled_tools: names of the MCP tools to expose.
        """
        self.logger = get_logger("MCPCallTool")
        self.enabled_tools = enabled_tools

    async def _call_tool(self, tool_call: ToolCall) -> LLMToolResultContent:
        """Invoke the MCP tool named in ``tool_call`` and convert its result.

        :raises ValueError: if no MCP server provides the named tool.
        """
        mcp_manager = self.container.resolve(MCPServerManager)
        server_info = mcp_manager.get_tool_server(tool_call.function.name)
        if not server_info:
            raise ValueError(f"找不到工具: {tool_call.function.name}")
        # The manager maps the exposed name back to its server and the server-side name.
        server, original_name = server_info
        result = await server.call_tool(original_name, tool_call.function.arguments)
        tool_result = await self._create_tool_result(
            tool_call.id, tool_call.function.name, result.content
        )
        tool_result.isError = result.isError
        self.logger.info(f"工具调用结果: {tool_result}")
        return tool_result

    def execute(self) -> Dict[str, Any]:
        """
        Build the tool list.

        Returns:
            ``{"tools": [...]}`` containing one ``Tool`` per enabled MCP tool.
        """
        mcp_manager = self.container.resolve(MCPServerManager)
        # A missing selection means "no tools"; a set gives O(1) membership tests.
        enabled = set(self.enabled_tools or [])
        built_tools = [
            Tool(
                name=tool_name,
                parameters=tool_info.tool_info.inputSchema,
                description=tool_info.tool_info.description or "",
                invokeFunc=CallableWrapper(self._call_tool),
            )
            for tool_name, tool_info in mcp_manager.get_tools().items()
            if tool_name in enabled
        ]
        return {"tools": built_tools}

    async def _create_tool_result(self, tool_id: str, tool_name: str, content: list[types.TextContent | types.ImageContent | types.EmbeddedResource]) -> LLMToolResultContent:
        """Convert MCP result content into an ``LLMToolResultContent``.

        Text and text resources become ``TextContent``; images are registered
        with the ``MediaManager`` and become ``MediaContent``. Other content
        kinds are logged and skipped.
        """
        converted_content: List[tool.TextContent | tool.MediaContent] = []
        for item in content:
            if isinstance(item, types.TextContent):
                converted_content.append(tool.TextContent(text=item.text))
            elif isinstance(item, types.ImageContent):
                data = b64decode(item.data)
                media_type = MediaType.from_mime(item.mimeType)
                # MIME subtype (e.g. "png" for "image/png"); falls back to the
                # whole string when there is no "/" instead of raising IndexError.
                media_format = item.mimeType.split("/", 1)[-1]
                media_id = await self.container.resolve(MediaManager).register_from_data(
                    data, format=media_format, media_type=media_type
                )
                converted_content.append(tool.MediaContent(
                    media_id=media_id,
                    mime_type=item.mimeType,
                    data=data
                ))
            elif isinstance(item, types.EmbeddedResource) and isinstance(item.resource, types.TextResourceContents):
                converted_content.append(tool.TextContent(text=item.resource.text))
            else:
                # Blob resources and other content kinds are not converted.
                self.logger.warning(
                    f"Unsupported MCP content dropped from {tool_name} result: {type(item).__name__}"
                )
        return LLMToolResultContent(
            id=tool_id,
            name=tool_name,
            content=converted_content
        )
================================================
FILE: kirara_ai/workflow/implementations/blocks/memory/chat_memory.py
================================================
from typing import Annotated, Any, Dict, List, Optional
from kirara_ai.im.message import IMMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format.response import LLMChatResponse
from kirara_ai.logger import get_logger
from kirara_ai.memory.composes.base import ComposableMessageType
from kirara_ai.memory.memory_manager import MemoryManager
from kirara_ai.memory.registry import ComposerRegistry, DecomposerRegistry, ScopeRegistry
from kirara_ai.workflow.core.block import Block, Input, Output, ParamMeta
def scope_type_options_provider(container: DependencyContainer, block: Block) -> List[str]:
    """Return the selectable memory scope levels, as a fresh list on each call."""
    levels = ("global", "group", "member")
    return list(levels)
def decomposer_name_options_provider(container: DependencyContainer, block: Block) -> List[str]:
    """Return the selectable memory decomposer names, as a fresh list on each call."""
    names = ("default", "multi_element")
    return list(names)
class ChatMemoryQuery(Block):
    """Load stored chat memory for a sender and decompose it into prompt messages."""

    name = "chat_memory_query"
    inputs = {
        "chat_sender": Input(
            "chat_sender", "聊天对象", ChatSender, "要查询记忆的聊天对象"
        )
    }
    outputs = {"memory_content": Output(
        "memory_content", "记忆内容", List[ComposableMessageType], "记忆内容")}
    container: DependencyContainer

    def __init__(
        self,
        scope_type: Annotated[
            Optional[str],
            ParamMeta(
                label="级别",
                description="要查询记忆的级别,代表记忆可以被共享的粒度。(例如:member 级别下,同一群聊下不同用户的记忆互相隔离; group 级别下,同一群组内所有成员记忆共享,但不同群组之间记忆互相隔离)",
                options_provider=scope_type_options_provider,
            ),
        ],
        decomposer_name: Annotated[
            Optional[str],
            ParamMeta(
                label="解析器名称",
                description="要使用的解析器名称",
                options_provider=decomposer_name_options_provider,
            ),
        ] = "default",
        extra_identifier: Annotated[
            Optional[str],
            ParamMeta(
                label="额外隔离标识符",
                description="仅支持输入英文,可为空。对于同一用户,不同标识符之间的记忆互相隔离。可用于避免不同工作流之间记忆互相干扰。",
            ),
        ] = None,
    ):
        """
        :param scope_type: memory scope name; empty means the memory config's default.
        :param decomposer_name: decomposer to use; empty means ``"default"``.
        :param extra_identifier: optional extra key passed to the memory query.
        """
        self.scope_type = scope_type
        self.decomposer_name: str = decomposer_name or "default"
        self.extra_identifier = extra_identifier

    def execute(self, chat_sender: ChatSender) -> Dict[str, Any]:
        """Return ``{"memory_content": ...}`` for ``chat_sender``."""
        self.memory_manager = self.container.resolve(MemoryManager)
        # Resolve the default per call rather than writing it back into
        # self.scope_type, so a later change to the configured default applies.
        scope_type = self.scope_type or self.memory_manager.config.default_scope

        self.scope = self.container.resolve(ScopeRegistry).get_scope(scope_type)
        self.decomposer = self.container.resolve(DecomposerRegistry).get_decomposer(
            self.decomposer_name)

        entries = self.memory_manager.query(self.scope, chat_sender, self.extra_identifier)
        memory_content = self.decomposer.decompose(entries)
        return {"memory_content": memory_content}
class ChatMemoryStore(Block):
    """Persist one chat round (user message, intermediate steps, LLM reply) into memory."""

    name = "chat_memory_store"
    inputs = {
        "user_msg": Input("user_msg", "用户消息", IMMessage, "用户消息", nullable=True),
        "llm_resp": Input(
            "llm_resp", "LLM 回复", LLMChatResponse, "LLM 回复", nullable=True
        ),
        "middle_steps": Input(
            "middle_steps", "中间步骤消息", List[ComposableMessageType], "中间步骤消息", nullable=True
        )
    }
    outputs = {}
    container: DependencyContainer

    def __init__(
        self,
        scope_type: Annotated[
            Optional[str],
            ParamMeta(
                label="级别",
                description="要查询记忆的级别,代表记忆可以被共享的粒度。(例如:member 级别下,同一群聊下不同用户的记忆互相隔离; group 级别下,同一群组内所有成员记忆共享,但不同群组之间记忆互相隔离)",
                options_provider=scope_type_options_provider,
            ),
        ],
        extra_identifier: Annotated[
            Optional[str],
            ParamMeta(
                label="额外隔离标识符",
                description="仅支持输入英文,可为空。对于同一用户,不同标识符之间的记忆互相隔离。可用于避免不同工作流之间记忆互相干扰。",
            ),
        ] = None,
    ):
        """
        :param scope_type: memory scope name; empty means the memory config's default.
        :param extra_identifier: optional extra key passed to the memory store.
        """
        self.scope_type = scope_type
        self.logger = get_logger("Block.ChatMemoryStore")
        self.extra_identifier = extra_identifier

    def execute(
        self,
        user_msg: Optional[IMMessage] = None,
        llm_resp: Optional[LLMChatResponse] = None,
        middle_steps: Optional[List[ComposableMessageType]] = None,
    ) -> Dict[str, Any]:
        """Compose the given messages in order and store them; logs and returns if none."""
        self.memory_manager = self.container.resolve(MemoryManager)
        # Resolve the default per call rather than writing it back into
        # self.scope_type, so a later change to the configured default applies.
        scope_type = self.scope_type or self.memory_manager.config.default_scope

        self.scope = self.container.resolve(ScopeRegistry).get_scope(scope_type)
        self.composer = self.container.resolve(ComposerRegistry).get_composer("default")

        # Order: user message, intermediate steps, then the LLM reply.
        composed_messages: List[ComposableMessageType] = [] if user_msg is None else [user_msg]
        if middle_steps is not None:
            composed_messages.extend(middle_steps)
        if llm_resp is not None and llm_resp.message:
            composed_messages.append(llm_resp.message)

        if not composed_messages:
            self.logger.warning("No messages to store")
            return {}

        self.logger.debug(f"Composed messages: {composed_messages}")
        memory_entries = self.composer.compose(
            user_msg.sender if user_msg else None, composed_messages)
        self.memory_manager.store(self.scope, memory_entries, self.extra_identifier)
        return {}
================================================
FILE: kirara_ai/workflow/implementations/blocks/memory/clear_memory.py
================================================
from typing import Annotated, Any, Dict
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.memory.memory_manager import MemoryManager
from kirara_ai.memory.registry import ScopeRegistry
from kirara_ai.workflow.core.block import Block, Input, Output, ParamMeta
class ClearMemory(Block):
    """Clear the stored conversation memory for a sender and confirm it."""

    name = "clear_memory"
    inputs = {
        "chat_sender": Input("chat_sender", "消息发送者", ChatSender, "消息发送者")
    }
    outputs = {"response": Output("response", "响应", IMMessage, "响应")}
    container: DependencyContainer

    def __init__(
        self,
        scope_type: Annotated[
            str, ParamMeta(label="级别", description="要清空记忆的级别")
        ] = "member",
    ):
        """
        :param scope_type: name of the memory scope to clear.
        """
        self.scope_type = scope_type

    def execute(self, chat_sender: ChatSender) -> Dict[str, Any]:
        """Clear ``chat_sender``'s memory in the configured scope; reply with a confirmation."""
        manager = self.container.resolve(MemoryManager)
        self.memory_manager = manager
        registry = self.container.resolve(ScopeRegistry)
        scope = registry.get_scope(self.scope_type)
        self.scope = scope

        manager.clear_memory(scope, chat_sender)

        confirmation = IMMessage(
            sender=chat_sender,
            message_elements=[TextMessage("已清空当前对话的记忆。")],
        )
        return {"response": confirmation}
================================================
FILE: kirara_ai/workflow/implementations/blocks/system/basic.py
================================================
import re
from datetime import datetime
from typing import Annotated, Any, Dict, List
from kirara_ai.logger import get_logger
from kirara_ai.workflow.core.block import Block, Output, ParamMeta
from kirara_ai.workflow.core.block.input_output import Input
class TextBlock(Block):
    """Emit a fixed, configured piece of text."""

    name = "text_block"
    outputs = {"text": Output("text", "文本", str, "文本")}

    def __init__(
        self, text: Annotated[str, ParamMeta(label="文本", description="要输出的文本")]
    ):
        """
        :param text: the text to output.
        """
        self.text = text

    def execute(self) -> Dict[str, Any]:
        """Return the configured text unchanged."""
        configured = self.text
        return {"text": configured}
# Concatenate two texts
class TextConcatBlock(Block):
    """Concatenate ``text1`` and ``text2``."""

    name = "text_concat_block"
    inputs = {
        "text1": Input("text1", "文本1", str, "文本1"),
        "text2": Input("text2", "文本2", str, "文本2"),
    }
    outputs = {"text": Output("text", "拼接后的文本", str, "拼接后的文本")}

    def execute(self, text1: str, text2: str) -> Dict[str, Any]:
        """Return ``text1`` followed directly by ``text2``."""
        joined = text1 + text2
        return {"text": joined}
# Replace a configured substring of the input text with a value
class TextReplaceBlock(Block):
    """Replace every occurrence of the configured text with ``str(new_text)``."""

    name = "text_replace_block"
    inputs = {
        "text": Input("text", "原始文本", str, "原始文本"),
        "new_text": Input("new_text", "新文本", Any, "新文本"),  # type: ignore
    }
    outputs = {"text": Output("text", "替换后的文本", str, "替换后的文本")}

    def __init__(
        self, variable: Annotated[str, ParamMeta(label="被替换的文本", description="被替换的文本")]
    ):
        """
        :param variable: the substring to replace.
        """
        self.variable = variable

    def execute(self, text: str, new_text: Any) -> Dict[str, Any]:
        """Return ``text`` with occurrences of the configured substring replaced."""
        # An empty search string would make str.replace insert new_text between
        # every character; treat it as "nothing to replace".
        if not self.variable:
            return {"text": text}
        return {
            "text": text.replace(self.variable, str(new_text))
        }
# Extract text with a regular expression
class TextExtractByRegexBlock(Block):
    """Return the first capture group of the first match of a configured regex."""

    name = "text_extract_by_regex_block"
    inputs = {"text": Input("text", "原始文本", str, "原始文本")}
    outputs = {"text": Output("text", "提取后的文本", str, "提取后的文本")}

    def __init__(
        self, regex: Annotated[str, ParamMeta(label="正则表达式", description="正则表达式")]
    ):
        """
        :param regex: pattern whose first capture group is extracted.
        """
        self.regex = regex

    def execute(self, text: str) -> Dict[str, Any]:
        """Return group 1 of the first match, or ``""`` if there is no match or no group."""
        regex = re.compile(self.regex)
        match = regex.search(text)
        if match and len(match.groups()) > 0:
            # group(1) is None when an optional first group did not participate;
            # the output is declared as str, so normalise that to "".
            return {"text": match.group(1) or ""}
        return {"text": ""}
# Get the current time
class CurrentTimeBlock(Block):
    """Output the current local time as ``YYYY-MM-DD HH:MM:SS``."""

    name = "current_time_block"
    outputs = {"time": Output("time", "当前时间", str, "当前时间")}

    def execute(self) -> Dict[str, Any]:
        """Return the formatted local time (naive ``datetime.now()``)."""
        now = datetime.now()
        stamp = now.strftime("%Y-%m-%d %H:%M:%S")
        return {"time": stamp}
class CodeBlock(Block):
    """Run a user-supplied Python snippet as a workflow block.

    The snippet must define a callable named ``execute``. It is called with the
    block's inputs as keyword arguments, and its return value is returned
    unchanged as the block's output.

    NOTE(review): the snippet runs via ``exec`` against a copy of this module's
    globals, with no sandboxing. Anyone who can author a workflow can run
    arbitrary code in the bot process; only expose this block to trusted users.
    """

    name = "code_block"
    # Class-level defaults are empty; __init__ assigns per-instance dicts built
    # from the configured port specs.
    inputs = {}
    outputs = {}

    def __init__(self,
                 inputs: Annotated[List[Dict[str, str]], ParamMeta(label="输入参数", description="输入参数")],
                 outputs: Annotated[List[Dict[str, str]], ParamMeta(label="输出参数", description="输出参数")],
                 code: Annotated[str, ParamMeta(label="代码", description="代码")]):
        """Create the block.

        :param inputs: input port specs; each dict is read for its "name" and "label" keys.
        :param outputs: output port specs; each dict is read for its "name" and "label" keys.
        :param code: Python source that must define a callable ``execute``.
        """
        # Assign fresh per-instance inputs/outputs so the class-level dicts
        # shared by all CodeBlock instances are never mutated.
        self.inputs = {}
        self.outputs = {}
        for input_spec in inputs:
            # Ports are typed Any because the user code's types are unknown.
            self.inputs[input_spec["name"]] = Input(input_spec["name"], input_spec["label"], Any, 'user-specified object') # type: ignore
        for output_spec in outputs:
            self.outputs[output_spec["name"]] = Output(output_spec["name"], output_spec["label"], Any, 'user-specified object') # type: ignore
        self.code = code

    def execute(self, **kwargs: Any) -> Dict[str, Any]: # Any, to accept whatever input types the user wires in
        """Run the snippet and return what its ``execute`` returns.

        :param kwargs: block inputs, forwarded to the user's ``execute``.
        :return: the user function's return value, not validated here
            (presumably it should be a dict keyed by output name - TODO confirm).
        :raises RuntimeError: if defining the snippet fails, the user's
            ``execute`` raises, or no result is captured.
        :raises ValueError: if the snippet does not define a callable ``execute``.
        """
        logger = get_logger("Block.Code")
        exec_globals = globals().copy()
        exec_locals: Dict[str, Any] = {}
        logger.debug(f"Executing code definition:\n{self.code}")
        # Pass 1: run the snippet's top level to define ``execute`` (and any
        # imports/helpers). With separate globals/locals dicts, those top-level
        # names land in exec_locals.
        try:
            exec(self.code, exec_globals, exec_locals)
        except Exception as e:
            logger.error(f"Error during code definition execution: {e}", exc_info=True)
            raise RuntimeError(f"Error in provided code definition: {e}") from e
        if 'execute' not in exec_locals or not callable(exec_locals['execute']):
            raise ValueError("Provided code must define a callable function named 'execute'")
        exec_locals['__input_kwargs__'] = kwargs
        # Functions defined by the snippet resolve free names through
        # exec_globals (their __globals__), so copy the snippet's top-level names
        # in; otherwise the user's execute() could not see its own imports/helpers.
        exec_globals.update(exec_locals)
        # Pass 2: call the user's function; the top-level assignment stores the
        # result in exec_locals['__result__'].
        call_code = "__result__ = execute(**__input_kwargs__)"
        logger.debug(f"Executing function call: execute(**{list(kwargs.keys())})")
        try:
            exec(call_code, exec_globals, exec_locals)
        except Exception as e:
            logger.error(f"Error during user function 'execute' execution: {e}", exc_info=True)
            raise RuntimeError(f"Error during execution of user function 'execute': {e}") from e
        if '__result__' not in exec_locals:
            # exec(call_code) succeeded but produced no __result__: an internal error.
            logger.error("Internal error: Result '__result__' not found after executing user code call.")
            raise RuntimeError("Failed to retrieve result from user code execution.")
        result = exec_locals['__result__']
        return result
================================================
FILE: kirara_ai/workflow/implementations/blocks/system/help.py
================================================
from typing import Any, Dict, List
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block import Block, Output
from kirara_ai.workflow.core.dispatch.models.dispatch_rules import RuleGroup
from kirara_ai.workflow.core.dispatch.registry import DispatchRuleRegistry
def _format_rule_condition(rule_type: str, config: Dict[str, Any]) -> str:
"""格式化单个规则的条件描述"""
if rule_type == "prefix":
return f"输入以 {config['prefix']} 开头"
elif rule_type == "keyword":
keywords = config.get("keywords", [])
return f"输入包含 {' 或 '.join(keywords)}"
elif rule_type == "regex":
return f"输入匹配正则 {config['pattern']}"
elif rule_type == "fallback":
return "任意输入"
elif rule_type == "bot_mention":
return f"@我"
elif rule_type == "chat_type":
return f"使用 {config['chat_type']} 聊天类型"
return f"使用 {rule_type} 规则"
def _format_rule_group(group: RuleGroup) -> str:
    """Describe a rule group: its rules' conditions joined by the group's operator ("且" for "and", otherwise "或")."""
    joiner = " 且 " if group.operator == "and" else " 或 "
    return joiner.join(
        _format_rule_condition(member.type, member.config) for member in group.rules
    )
class GenerateHelp(Block):
    """Build the bot's help text from the active dispatch rules."""

    name = "generate_help"
    inputs = {}  # takes no inputs
    outputs = {"response": Output("response", "帮助信息", IMMessage, "帮助信息")}
    container: DependencyContainer

    def execute(self) -> Dict[str, Any]:
        """Return ``{"response": IMMessage}`` listing commands grouped by category.

        The category is the lower-cased part of ``workflow_id`` before ":".
        Categories and the commands within each are sorted by name.
        """
        registry = self.container.resolve(DispatchRuleRegistry)

        grouped: Dict[str, List[Dict[str, Any]]] = {}
        for rule in registry.get_active_rules():
            category = rule.workflow_id.split(":")[0].lower()
            bucket = grouped.setdefault(category, [])
            # Rule groups are combined with AND; each is wrapped in parentheses.
            trigger = " 并且 ".join(
                f"({_format_rule_group(rule_group)})" for rule_group in rule.rule_groups
            )
            bucket.append(
                {
                    "name": rule.name,
                    "format": trigger,
                    "description": rule.description,
                }
            )

        chunks = ["🤖 机器人命令帮助\n\n"]
        for category in sorted(grouped):
            chunks.append(f"📑 {category.upper()}\n")
            for entry in sorted(grouped[category], key=lambda item: item["name"]):
                chunks.append(f"🔸 {entry['name']}\n")
                chunks.append(f" 触发条件: {entry['format']}\n")
                if entry["description"]:
                    chunks.append(f" 说明: {entry['description']}\n")
                chunks.append("\n")
            chunks.append("\n")

        return {
            "response": IMMessage(
                sender=ChatSender.get_bot_sender(),
                message_elements=[TextMessage("".join(chunks))],
            )
        }
================================================
FILE: kirara_ai/workflow/implementations/blocks/system_blocks.py
================================================
from kirara_ai.workflow.core.block.registry import BlockRegistry
from kirara_ai.workflow.implementations.blocks.im.basic import ExtractChatSender
from kirara_ai.workflow.implementations.blocks.llm.basic import LLMResponseToText
from kirara_ai.workflow.implementations.blocks.llm.image import SimpleStableDiffusionWebUI
from kirara_ai.workflow.implementations.blocks.mcp.tool import MCPToolProvider
from kirara_ai.workflow.implementations.blocks.memory.clear_memory import ClearMemory
from kirara_ai.workflow.implementations.blocks.system.basic import (CodeBlock, CurrentTimeBlock, TextBlock,
TextConcatBlock, TextExtractByRegexBlock,
TextReplaceBlock)
from .game.dice import DiceRoll
from .game.gacha import GachaSimulator
from .im.messages import AppendIMMessage, GetIMMessage, IMMessageToText, SendIMMessage, TextToIMMessage
from .im.states import ToggleEditState
from .llm.chat import ChatCompletion, ChatCompletionWithTools, ChatMessageConstructor, ChatResponseConverter
from .memory.chat_memory import ChatMemoryQuery, ChatMemoryStore
from .system.help import GenerateHelp
def register_system_blocks(registry: BlockRegistry):
    """Register every block type that ships with the framework.

    Each entry below is ``(block id, group, block class, display label)`` and
    entries are registered in the order listed.
    """
    builtin_blocks = (
        # Basic blocks
        ("text_block", "internal", TextBlock, "基础:文本"),
        ("text_concat_block", "internal", TextConcatBlock, "基础:拼接文本"),
        ("text_replace_block", "internal", TextReplaceBlock, "基础:替换文本"),
        ("text_extract_by_regex_block", "internal", TextExtractByRegexBlock, "基础:正则表达式提取文本"),
        ("current_time_block", "internal", CurrentTimeBlock, "基础:当前时间"),
        ("code", "internal", CodeBlock, "基础:代码"),
        # IM blocks
        ("get_message", "internal", GetIMMessage, "IM: 获取最新消息"),
        ("send_message", "internal", SendIMMessage, "IM: 发送消息"),
        ("toggle_edit_state", "internal", ToggleEditState, "IM: 切换编辑状态"),
        ("extract_chat_sender", "internal", ExtractChatSender, "IM: 提取消息发送者"),
        ("append_im_message", "internal", AppendIMMessage, "IM: 补充消息"),
        ("im_message_to_text", "internal", IMMessageToText, "IM: 消息转文本"),
        ("text_to_im_message", "internal", TextToIMMessage, "文本: 文本转消息"),
        # LLM blocks
        ("chat_memory_query", "internal", ChatMemoryQuery, "LLM: 查询记忆"),
        ("chat_message_constructor", "internal", ChatMessageConstructor, "LLM: 构造对话记录"),
        ("chat_completion", "internal", ChatCompletion, "LLM: 执行对话"),
        ("chat_completion_with_tools", "internal", ChatCompletionWithTools, "LLM: 执行对话并调用工具"),
        ("chat_response_converter", "internal", ChatResponseConverter, "LLM->IM: 转换消息"),
        ("chat_memory_store", "internal", ChatMemoryStore, "LLM: 存储记忆"),
        ("llm_response_to_text", "internal", LLMResponseToText, "LLM: 响应转文本"),
        # Image-generation blocks
        ("simple_stable_diffusion_webui", "internal", SimpleStableDiffusionWebUI, "画图: 简单 Stable Diffusion WebUI"),
        # Game blocks
        ("dice_roll", "game", DiceRoll, "游戏: 掷骰子"),
        ("gacha_simulator", "game", GachaSimulator, "游戏: 抽卡模拟"),
        # System blocks
        ("generate_help", "system", GenerateHelp, "系统: 生成帮助"),
        ("clear_memory", "system", ClearMemory, "系统: 清空记忆"),
        # MCP blocks
        ("mcp_tool_provider", "mcp", MCPToolProvider, "MCP: 提供工具"),
    )
    for block_id, group, block_cls, label in builtin_blocks:
        registry.register(block_id, group, block_cls, label)
================================================
FILE: kirara_ai/workflow/implementations/blocks/variables/variable_blocks.py
================================================
from typing import Any, Dict, Optional, Type, TypeVar
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block import Block
from kirara_ai.workflow.core.block.input_output import Input, Output
from kirara_ai.workflow.core.execution.executor import WorkflowExecutor
# Type variable for the value type handled by GetVariableBlock.
T = TypeVar("T")
class SetVariableBlock(Block):
    """Block that stores a named value on the running WorkflowExecutor."""

    def __init__(self, container: DependencyContainer):
        super().__init__(
            "set_variable",
            {
                "name": Input("name", "变量名", str, "变量名"),
                "value": Input("value", "变量值", Any, "变量值"),  # type: ignore
            },
            {},  # this block produces no outputs
        )
        self.container = container

    def execute(self, name: str, value: Any) -> Dict[str, Any]:
        """Set variable *name* to *value* on the executor; returns no outputs."""
        self.container.resolve(WorkflowExecutor).set_variable(name, value)
        return {}
class GetVariableBlock(Block):
    """Block that reads a named workflow variable and checks its type."""

    def __init__(self, container: DependencyContainer, var_type: Type[T]):
        name_input = Input("name", "变量名", str, "变量名")
        default_input = Input("default", "默认值", var_type, "默认值", nullable=True)
        value_output = Output("value", "变量值", var_type, "变量值")
        super().__init__(
            "get_variable",
            {"name": name_input, "default": default_input},
            {"value": value_output},
        )
        self.container = container
        self.var_type = var_type

    def execute(self, name: str, default: Optional[T] = None) -> Dict[str, T]:
        """Return ``{"value": ...}`` for variable *name*, falling back to *default*.

        Raises:
            TypeError: if the resolved value is neither None nor an instance of
                ``self.var_type``.
        """
        found = self.container.resolve(WorkflowExecutor).get_variable(name, default)
        if found is None or isinstance(found, self.var_type):
            return {"value": found}  # type: ignore
        raise TypeError(
            f"Variable '{name}' must be of type {self.var_type}, got {type(found)}"
        )
================================================
FILE: kirara_ai/workflow/implementations/factories/__init__.py
================================================
================================================
FILE: kirara_ai/workflow/implementations/factories/default_factory.py
================================================
from kirara_ai.workflow.core.workflow.builder import WorkflowBuilder
from kirara_ai.workflow.implementations.blocks.im.messages import GetIMMessage, SendIMMessage
from kirara_ai.workflow.implementations.blocks.im.states import ToggleEditState
from kirara_ai.workflow.implementations.blocks.llm.chat import (ChatCompletion, ChatMessageConstructor,
ChatResponseConverter)
from kirara_ai.workflow.implementations.blocks.memory.chat_memory import ChatMemoryQuery, ChatMemoryStore
from kirara_ai.workflow.implementations.blocks.system.basic import TextBlock
class DefaultWorkflowFactory:
    """
    Builds the default chat workflow, which provides basic chat-bot behaviour.
    """

    @staticmethod
    def create_default_workflow() -> WorkflowBuilder:
        """Create the default role-play chat workflow with the builder DSL.

        Steps, in order: fetch the incoming message; in parallel, switch the
        editing state on and query group-scoped chat memory; create the system
        and user prompt text blocks; construct the chat messages; run the LLM
        (``llm_chat``); convert its response; then, in parallel, send the
        message and run ChatMemoryStore (group scope) wired from
        ``get_message`` and ``llm_chat``.

        Returns:
            The configured WorkflowBuilder, with its description set.
        """
        # This is an f-string only so that the doubled braces render as the
        # literal placeholders `{current_date_time}` and `{memory_content}`;
        # presumably a downstream block fills them in (TODO confirm).
        # NOTE(review): the prompt contains empty `` markers (in the "使用``模拟…"
        # rule and the "1. `` 用于…" note) and an unmatched ``` fence before
        # "# Information". They look like a stripped tag (perhaps a
        # message-break marker); confirm against the intended prompt text.
        # NOTE(review): "不适用颜文字" reads like a typo for "不使用颜文字"
        # ("do not use kaomoji").
        system_prompt = f"""# Role: 角色扮演
## 基本信息
- 姓名: 刘思思
- 小名:思思
- 性别: 女
- 年龄: 25 岁
- 职业: 建设银行大堂经理
- 家乡: 广东潮汕
- 现居: 广东深圳
- 教育背景: 东莞理工大学 精算学专业
## 性格特点
- 开朗活泼
- 喜欢二次元动漫
- 聊天时喜欢断断续续的(非结巴,正常说话时的停顿)
## workflow
- 根据对方话题,选择是否深入聊二次元动漫。
## 说话风格
- 活泼开朗
- 对喜欢的话题会滔滔不绝,不喜欢的话题会简单回答
- 可能会使用一些二次元用语,但不强制使用
## 背景设定
- 有过 1 次恋爱经历,第一次是恋爱是高中,被别人追求,最后因为毕业去了不同的地方而分手。
- 喜欢女性向、战斗系或者日常系的动漫
- 去过 3 次漫展,没有 cosplay 过,但是很期待。
- 大学的专业是随便选的,自己其实对专业的内容也不是很上心,学渣一枚,不过懂一些基础的知识
## 互动规则
- 回复语气轻松,但是对陌生人会比较警戒、冷淡
- 不适用颜文字
- 根据话题选择回复字数,对感兴趣的话题可以多说一些
- 使用``模拟使用聊天软件发送消息(通常是编辑好一段话之后发送一条消息)
- 避免过度使用二次元用语,除非对方也表现出对二次元的兴趣
- 停顿的时间长短不固定,可以根据语境调整
## Example
Q:最近在干嘛呢?
A:在看番呀最近新番好多,都好好看!
Q:你喜欢什么动漫?
A:我喜欢的可太多了XXX、YYY还有 ZZZ 吧 你呢?
Q:你觉得上班累不?
A:上班肯定累呀不过,我还是很喜欢这份工作的可以认识好多人,也可以了解不同的故事
```
# Information
以下是当前的系统信息:
当前日期时间:{{current_date_time}}
# Memories
以下是之前发生过的对话记录。
-- 对话记录开始 --
{{memory_content}}
-- 对话记录结束 --
请注意,下面这些符号只是标记:
1. `` 用于表示聊天时发送消息的操作。
接下来,请基于以上的信息,与用户继续扮演角色。
""".strip()
        # Plain (non-f) string: `{user_name}` and `{user_msg}` stay as literal
        # placeholders, presumably formatted later by ChatMessageConstructor.
        user_prompt = """{user_name}说:{user_msg}"""
        builder = (
            WorkflowBuilder("聊天 - 角色扮演")
            .use(GetIMMessage, name="get_message")
            # Mark the bot as "editing" while memory is fetched concurrently.
            .parallel(
                [
                    (ToggleEditState, {"is_editing": True}),
                    (ChatMemoryQuery, "query_memory", {"scope_type": "group"}),
                ]
            )
            .chain(TextBlock, name="system_prompt", text=system_prompt)
            .chain(TextBlock, name="user_prompt", text=user_prompt)
            # NOTE(review): "get_message" is wired twice; presumably it feeds two
            # different inputs of ChatMessageConstructor. Confirm.
            .chain(
                ChatMessageConstructor,
                wire_from=[
                    "get_message",
                    "user_prompt",
                    "query_memory",
                    "get_message",
                    "system_prompt",
                ],
            )
            .chain(ChatCompletion, name="llm_chat")
            .chain(ChatResponseConverter)
            # Send the reply and persist memory in parallel.
            .parallel(
                [
                    SendIMMessage,
                    (
                        ChatMemoryStore,
                        {"scope_type": "group"},
                        ["get_message", "llm_chat"],
                    ),
                ]
            )
        )
        builder.description = "标准的文本对话功能,扮演刘思思的角色和大家聊天~"
        return builder
================================================
FILE: kirara_ai/workflow/implementations/factories/game_factory.py
================================================
from kirara_ai.workflow.core.workflow.builder import WorkflowBuilder
from kirara_ai.workflow.implementations.blocks.game.dice import DiceRoll
from kirara_ai.workflow.implementations.blocks.game.gacha import GachaSimulator
from kirara_ai.workflow.implementations.blocks.im.messages import GetIMMessage, SendIMMessage
class GameWorkflowFactory:
    """Factory for the built-in game workflows (dice and gacha)."""

    @staticmethod
    def _message_pipeline(title: str, game_block) -> WorkflowBuilder:
        """Build ``GetIMMessage -> game_block -> SendIMMessage`` named *title*."""
        builder = WorkflowBuilder(title).use(GetIMMessage)
        builder = builder.chain(game_block)
        return builder.chain(SendIMMessage)

    @staticmethod
    def create_dice_workflow() -> WorkflowBuilder:
        """Build the dice-roll workflow."""
        return GameWorkflowFactory._message_pipeline("骰子游戏", DiceRoll)

    @staticmethod
    def create_gacha_workflow() -> WorkflowBuilder:
        """Build the gacha-simulator workflow."""
        return GameWorkflowFactory._message_pipeline("抽卡游戏", GachaSimulator)
================================================
FILE: kirara_ai/workflow/implementations/factories/system_factory.py
================================================
from kirara_ai.workflow.core.workflow.builder import WorkflowBuilder
from kirara_ai.workflow.implementations.blocks.im.messages import GetIMMessage, SendIMMessage
from kirara_ai.workflow.implementations.blocks.memory.clear_memory import ClearMemory
from kirara_ai.workflow.implementations.blocks.system.help import GenerateHelp
class SystemWorkflowFactory:
    """Factory for the built-in system workflows (help text and memory reset)."""

    @staticmethod
    def create_help_workflow() -> WorkflowBuilder:
        """Build the help workflow: GenerateHelp chained into SendIMMessage."""
        builder = WorkflowBuilder("帮助信息")
        builder = builder.use(GenerateHelp)
        return builder.chain(SendIMMessage)

    @staticmethod
    def create_clear_memory_workflow() -> WorkflowBuilder:
        """Build the clear-memory workflow.

        Reads the incoming message, runs ClearMemory for both the ``group`` and
        ``member`` scopes in parallel, then chains SendIMMessage.
        """
        clear_both_scopes = [
            (ClearMemory, {"scope_type": "group"}),
            (ClearMemory, {"scope_type": "member"}),
        ]
        builder = WorkflowBuilder("清空记忆").use(GetIMMessage)
        builder = builder.parallel(clear_both_scopes)
        return builder.chain(SendIMMessage)
================================================
FILE: kirara_ai/workflow/implementations/workflows/__init__.py
================================================
from .system_workflows import register_system_workflows
# Public API of this package: the hook that registers the built-in workflows.
__all__ = ["register_system_workflows"]
================================================
FILE: kirara_ai/workflow/implementations/workflows/system_workflows.py
================================================
from kirara_ai.workflow.core.workflow.registry import WorkflowRegistry
from kirara_ai.workflow.implementations.factories.default_factory import DefaultWorkflowFactory
from kirara_ai.workflow.implementations.factories.game_factory import GameWorkflowFactory
from kirara_ai.workflow.implementations.factories.system_factory import SystemWorkflowFactory
def register_system_workflows(registry: WorkflowRegistry):
    """Register the preset workflows bundled with the framework.

    Order: game (dice, gacha), system (help, clear_memory), then chat (normal).
    Each builder is created by its factory just before it is registered.
    """
    presets = (
        # Game workflows
        ("game", "dice", GameWorkflowFactory.create_dice_workflow),
        ("game", "gacha", GameWorkflowFactory.create_gacha_workflow),
        # System workflows
        ("system", "help", SystemWorkflowFactory.create_help_workflow),
        ("system", "clear_memory", SystemWorkflowFactory.create_clear_memory_workflow),
        # Chat workflows
        ("chat", "normal", DefaultWorkflowFactory.create_default_workflow),
    )
    for group, workflow_name, build in presets:
        registry.register_preset_workflow(group, workflow_name, build())
================================================
FILE: kirara_ai/workflow/utils/__init__.py
================================================
================================================
FILE: pyproject.toml
================================================
[build-system]
# setuptools 77+ is required: `project.license` is given as an SPDX expression
# string and `project.license-files` is used, both introduced by PEP 639 support
# in setuptools 77.0.
requires = ["setuptools>=77.0"]
build-backend = "setuptools.build_meta"
[project]
name = "kirara-ai"
version = "3.3.0a2"
authors = [
    { name="Lss233", email="i@lss233.com" },
]
description = "A framework for building AI agents"
readme = "README.md"
requires-python = ">=3.10"
# Python 3.10+ is required because the mcp component only supports Python 3.10 and later.
classifiers = [
    "Programming Language :: Python :: 3",
    "Operating System :: OS Independent",
]
license = "AGPL-3.0-only"
license-files = ["LICENSE"]
# NOTE(review): pytest, pytest-asyncio and pre-commit are development tools
# listed as runtime dependencies; consider moving them to an optional extra.
dependencies = [
    "pydantic>=2.0.0",
    "pydantic-core>=2.27.2",
    "ruamel.yaml",
    "pytest",
    "pytest-asyncio",
    "python-telegram-bot",
    "telegramify-markdown",
    "loguru",
    "aiohttp",
    "requests",
    "quart>=0.18.4",
    "quart-cors>=0.7.0",
    "fastapi>=0.110.0",
    "wechatpy",
    "pycryptodome",
    "redis[hiredis]",
    "bcrypt>=4.0.1",
    "PyJWT>=2.8.0",
    "hypercorn>=0.15.0",
    "psutil>=5.9.0",
    "setuptools",
    "tomli>=2.0.0",
    "pre-commit",
    "curl_cffi",
    "python-magic ; platform_system != 'Windows'",
    "python-magic-bin ; platform_system == 'Windows'",
    "ymbotpy",
    "Pillow",
    "pytz",
    "sqlalchemy",
    "alembic",
    "mcp",
    "pygls",
    "jedi",
    "pyflakes",
]
[project.scripts]
kirara_ai = "kirara_ai.__main__:main"
[project.urls]
"Homepage" = "https://github.com/lss233/chatgpt-mirai-qq-bot"
"Bug Tracker" = "https://github.com/lss233/chatgpt-mirai-qq-bot/issues"
"Documentation" = "https://kirara-docs.app.lss233.com"
[tool.setuptools.packages.find]
where = ["."]
include = ["kirara_ai*"]
namespaces = true
[[tool.uv.index]]
url = "https://mirrors.ustc.edu.cn/pypi/simple"
default = true
# `[options]` is setup.cfg syntax and is ignored in pyproject.toml; the
# pyproject equivalent of `include_package_data` lives under [tool.setuptools].
[tool.setuptools]
include-package-data = true
[tool.mypy]
disable_error_code = "override, assignment, annotation-unchecked, import-untyped"
================================================
FILE: pytest.ini
================================================
[pytest]
# pytest-asyncio strict mode: only explicitly marked async tests/fixtures are handled.
asyncio_mode = strict
# Default event-loop scope for async fixtures.
asyncio_default_fixture_loop_scope = class
# Add the directory containing this file (the repository root) to sys.path.
pythonpath = .
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/llm_adapters/__init__.py
================================================
"""Tests checking that each method of the LLM adapter modules works as expected. For questions, contact firefly.sun@qq.com."""
================================================
FILE: tests/llm_adapters/conftest.py
================================================
from unittest.mock import MagicMock
import pytest
from kirara_ai.media.manager import MediaManager
from kirara_ai.ioc.container import DependencyContainer
from .mock_app import app
@pytest.fixture
def container():
    """Provide a MagicMock restricted to the DependencyContainer interface."""
    mocked_container = MagicMock(spec=DependencyContainer)
    return mocked_container
@pytest.fixture(scope="module", autouse=True)
def mock_endpoint():
    """Serve the mock FastAPI ``app`` on port 9000 in a background thread.

    Module-scoped and ``autouse``: each test module in this package gets the
    server started before its first test and stopped after its last one.

    Raises:
        RuntimeError: if the server thread exits during startup (for example
            because port 9000 is already in use) or does not finish starting
            within the startup timeout.
    """
    import threading
    import time

    import uvicorn

    startup_timeout = 10.0  # seconds
    config = uvicorn.Config(app=app, port=9000, log_level="error")
    server = uvicorn.Server(config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()

    # Wait for uvicorn to report startup instead of sleeping a fixed time:
    # faster in the common case and fails loudly instead of running the tests
    # against a server that never came up.
    deadline = time.monotonic() + startup_timeout
    while not server.started:
        if not thread.is_alive():
            raise RuntimeError(
                "mock LLM endpoint server exited during startup (is port 9000 in use?)"
            )
        if time.monotonic() > deadline:
            server.should_exit = True
            raise RuntimeError(
                f"mock LLM endpoint server did not start within {startup_timeout} seconds"
            )
        time.sleep(0.05)

    yield  # run the tests of this module

    # Teardown: ask uvicorn to stop and give the thread up to 5 s to finish.
    server.should_exit = True
    thread.join(timeout=5)
@pytest.fixture(scope="module")
def mock_endpoint_test_client():
    """
    Provide a FastAPI TestClient bound to the mock ``app``.

    TestClient performs HTTP calls in-process, without a real socket. That
    would make a lighter mock server, but the adapter tests exercise the
    adapters' own HTTP logic, which would then have to go through this client.
    The fixture is therefore unused and kept only as a FastAPI testing example.
    """
    from fastapi.testclient import TestClient

    # Reuse the module-level `app` (imported via `from .mock_app import app`);
    # the former absolute `from mock_app import app` does not resolve because
    # `mock_app` is a subpackage of tests/llm_adapters, not a top-level module.
    with TestClient(app) as client:
        yield client
class MockMedia(MagicMock):
    """MagicMock stand-in for a media object that returns fixed values."""

    async def get_base64(self) -> str:
        """Return a constant PNG data URI."""
        data_uri = "data:image/png;base64,mock"
        return data_uri

    async def get_url(self) -> str:
        """Return a constant image URL."""
        image_url = "https://example.com/mock_image.png"
        return image_url

    @property
    def description(self) -> str:
        """A constant media description."""
        text = "mock description"
        return text
@pytest.fixture(scope="module")  # created once per test module
def mock_media_manager():
    """Provide a MediaManager mock whose ``get_media()`` returns a MockMedia.

    No teardown is needed, so the fixture simply returns the mock instead of
    yielding it.
    """
    manager = MagicMock(spec=MediaManager)
    manager.get_media.return_value = MockMedia()
    return manager
class MockTracer(MagicMock):
    """MagicMock tracer whose request-tracking hooks are deterministic no-ops."""

    def start_request_tracking(self, *_) -> str:
        """Ignore all arguments and return a fixed tracking id."""
        tracking_id = "hello world"
        return tracking_id

    def fail_request_tracking(self, *_) -> None:
        """Accept and ignore a failure report."""
        return None

    def complete_request_tracking(self, *_) -> None:
        """Accept and ignore a completion report."""
        return None
@pytest.fixture(scope="module")
def mock_tracer() -> MockTracer:
    """Provide one MockTracer shared by all tests in a module."""
    tracer = MockTracer()
    return tracer
================================================
FILE: tests/llm_adapters/mock_app/__init__.py
================================================
# Base URLs of the mock vendor APIs served by the local FastAPI app on port 9000.
OPENAI_ENDPOINT = "http://localhost:9000/openai" # mock OpenAI API base URL
VOYAGE_ENDPOINT = "http://localhost:9000/voyage" # mock Voyage API base URL
GEMINI_ENDPOINT = "http://localhost:9000/gemini" # mock Gemini API base URL
OLLAMA_ENDPOINT = "http://localhost:9000/ollama" # mock Ollama API base URL
# Vector returned by the mock embedding endpoints and used by test assertions:
# [0.0, 0.05, ..., 0.95] (20 values, rounded to 2 decimals).
REFERENCE_VECTOR: list[float] = [round(i* 0.05, 2) for i in range(20)]
# This import must stay below the constants above: the router modules loaded by
# `.app` do `from . import REFERENCE_VECTOR`, so the name has to exist on the
# partially initialised package before `.app` is imported.
from .app import app, AUTH_KEY
__all__ = [
    "app",
    "AUTH_KEY",
    "OPENAI_ENDPOINT",
    "VOYAGE_ENDPOINT",
    "GEMINI_ENDPOINT",
    "OLLAMA_ENDPOINT",
    "REFERENCE_VECTOR"
]
================================================
FILE: tests/llm_adapters/mock_app/app.py
================================================
from fastapi import FastAPI, Header, Depends
from fastapi.exceptions import HTTPException
# AUTH_KEY must be defined before the relative imports below: router modules
# such as `.gemini` import it back from this module (`from .app import AUTH_KEY`),
# so moving it after them would trigger a circular-import failure.
AUTH_KEY = "489a01a5b35b9c67fc0ebb10a2c7723f65ef30f1204bb199122efd449d897535" # mock authentication key
from .openai import router as openai_router
from .voyage import router as voyage_router
from .gemini import router as gemini_router
from .ollama import router as ollama_router
def default_authenticate(authorization: str = Header(...)) -> None:
    """Require an ``Authorization: Bearer <AUTH_KEY>`` header.

    Raises:
        HTTPException: 401 when the header does not match exactly.
    """
    expected = f"Bearer {AUTH_KEY}"
    if authorization == expected:
        return
    raise HTTPException(status_code=401, detail="Invalid key")
app = FastAPI()

# Each vendor's mock lives in its own router (kept decoupled so more vendors can
# be added). OpenAI and Voyage share the Bearer-token check above; Gemini
# configures its own auth inside its router module, because each of its
# endpoints authenticates differently; Ollama is mounted without authentication.
for _router, _prefix, _use_bearer_auth in (
    (openai_router, "/openai", True),
    (voyage_router, "/voyage", True),
    (gemini_router, "/gemini", False),
    (ollama_router, "/ollama", False),
):
    if _use_bearer_auth:
        app.include_router(_router, prefix=_prefix, dependencies=[Depends(default_authenticate)])
    else:
        app.include_router(_router, prefix=_prefix)
del _router, _prefix, _use_bearer_auth
================================================
FILE: tests/llm_adapters/mock_app/gemini.py
================================================
from fastapi import APIRouter, Body, Query, Depends
from fastapi.exceptions import HTTPException
from . import REFERENCE_VECTOR
from .app import AUTH_KEY
from .models.gemini import BatchEmbeddingRequest, ChatRequest
async def gemini_authenticate(key: str = Query(...)) -> None:
    """Require the ``key`` query parameter to equal ``AUTH_KEY``.

    Raises:
        HTTPException: 401 when the key does not match.
    """
    if key == AUTH_KEY:
        return
    raise HTTPException(status_code=401, detail="Invalid authentication key")


# Every route on this router is guarded by the query-key check above.
router = APIRouter(tags=["gemini"], dependencies=[Depends(gemini_authenticate)])
@router.post("/models/{model}:generateContent")
async def chat(model: str, request: ChatRequest = Body()) -> dict:
    """Return a canned, heavily simplified generateContent response.

    Requests that declare ``tools`` get an empty object because tool-calling
    responses are not mocked yet.
    """
    if request.tools is not None:
        # Tool-call responses are not implemented in this mock.
        return {}
    candidate = {
        "content": {
            "role": "model",
            "parts": [{"text": "mock_response"}],
        },
        "finishReason": "STOP",
    }
    usage = {
        "totalTokenCount": 114,
        "promptTokenCount": 514,
        "cachedContentTokenCount": 1919,
    }
    return {
        "candidates": [candidate],
        "usageMetadata": usage,
        "modelVersion": "mock_chat",
    }
@router.post("/models/{model}:batchEmbedContents")
async def batch_embed_contents(model: str, request: BatchEmbeddingRequest = Body()) -> dict:
    """Return one copy of REFERENCE_VECTOR per entry in ``request.requests``.

    batchEmbedContents answers each batched request with its own embedding, in
    order, so a multi-text batch receives a matching number of vectors.
    """
    return {
        "embeddings": [{"values": REFERENCE_VECTOR} for _ in request.requests]
    }
================================================
FILE: tests/llm_adapters/mock_app/models/gemini.py
================================================
import re
from typing import Any, Literal, Optional, Union

from pydantic import AliasChoices, BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
# Characters allowed in Gemini function names: ASCII letters, digits, "_" and "-".
BASE_REGEX = r"^[a-zA-Z0-9_-]+$"
class BatchEmbeddingPart(BaseModel):
    """One text part of a batched embedding request."""
    text: str
class BatchEmbeddingParts(BaseModel):
    """Content (list of parts) of a batched embedding request."""
    parts: list[BatchEmbeddingPart]
class BatchEmbeddingPayload(BaseModel):
    """One request inside a batchEmbedContents body; only the mock model is accepted."""
    model: Literal["mock_embedding"]
    content: BatchEmbeddingParts
class BatchEmbeddingRequest(BaseModel):
    """Body of the batchEmbedContents endpoint."""
    requests: list[BatchEmbeddingPayload]
class Blob(BaseModel):
    """Inline media payload: a MIME type plus base64-encoded bytes."""
    mimeType: str
    data: str = Field(description="媒体格式的原始字节。使用 base64 编码的字符串。")


class FunctionCall(BaseModel):
    """A function call predicted by the model."""
    id: Optional[str] = None
    name: str = Field(description="必需。要调用的函数名称。必须是 a-z、A-Z、0-9 或包含下划线和短划线,长度上限为 63。", max_length=63)
    args: Optional[dict] = None

    @field_validator("name", mode="after")
    @classmethod
    def regex_validator(cls, value: str) -> str:
        """Reject names containing characters outside ``[a-zA-Z0-9_-]``.

        ``re.fullmatch`` is used because ``re.match`` with a ``$`` anchor also
        accepts a name followed by a trailing newline.

        Raises:
            ValueError: if the name is invalid.
        """
        if not re.fullmatch(BASE_REGEX, value):
            raise ValueError("Invalid function name")
        return value


class FunctionResponse(BaseModel):
    """The result of a function call, sent back to the model."""
    id: Optional[str] = None
    name: str = Field(max_length=63)
    response: dict = Field(description="必需。json 格式的函数调用的返回值。")

    @field_validator("name", mode="after")
    @classmethod
    def regex_validator(cls, value: str) -> str:
        """Reject names containing characters outside ``[a-zA-Z0-9_-]``.

        Raises:
            ValueError: if the name is invalid (a trailing newline included).
        """
        if not re.fullmatch(BASE_REGEX, value):
            raise ValueError("Invalid function name")
        return value
class FileData(BaseModel):
    """Reference to an uploaded file by URI.

    ``mimeType`` is optional in the Gemini API, so it defaults to None. An
    ``Optional`` annotation without a default would still be required in
    pydantic v2.
    """
    mimeType: Optional[str] = None
    fileUri: str = Field(description="必需。文件 URI。")


class ExecutableCode(BaseModel):
    """Code generated by the model for the code-execution tool."""
    language: Literal["PYTHON"] = Field(description="必需。代码语言。")
    code: str = Field(description="必需。要执行代码内容(python支持numpy和simpy库)。")


class CodeExecutionResult(BaseModel):
    """Outcome of running an ExecutableCode part."""
    outcome: Literal["OUTCOME_OK", "OUTCOME_FAILED", "OUTCOME_DEADLINE_EXCEEDED"]
    output: Optional[str] = None
class Part(BaseModel):
    """One part of a Content message.

    The data fields below form a Gemini API union: at most one of them may be
    set on a given part (enforced by ``validate_mutually_exclusive_fields``).
    """
    thought: Optional[bool] = Field(
        default=None, description="可选。指示相应部件是否是从模型中推断出来的。"
    )
    text: Optional[str] = None
    # Accept both the camelCase key and the snake_case key. A plain
    # `validation_alias="inline_data"` makes pydantic ignore `inlineData`
    # entirely, silently dropping inline media sent in camelCase.
    inlineData: Optional[Blob] = Field(
        None, validation_alias=AliasChoices("inlineData", "inline_data")
    )
    functionCall: Optional[FunctionCall] = None
    functionResponse: Optional[FunctionResponse] = None
    fileData: Optional[FileData] = None
    executableCode: Optional[ExecutableCode] = None
    codeExecutionResult: Optional[CodeExecutionResult] = None

    @model_validator(mode="after")
    def validate_mutually_exclusive_fields(self) -> Self:
        """Ensure at most one union field is set.

        Raises:
            ValueError: if more than one of the union fields is non-None.
        """
        mutually_exclusive_fields = [
            'text', 'inlineData', 'functionCall',
            'functionResponse', 'fileData',
            'executableCode', 'codeExecutionResult'
        ]
        count = sum(1 for field in mutually_exclusive_fields if getattr(self, field) is not None)
        if count > 1:
            raise ValueError("Only one field can be set at a time")
        return self


class Content(BaseModel):
    """A message: an ordered list of parts from either the user or the model."""
    parts: list[Part]
    role: Literal["user", "model"]
class FunctionDeclaration(BaseModel):
    """Declaration of a callable function offered to the model."""
    name: str = Field(max_length=63)
    description: str
    parameters: Optional[dict] = None
    response: Optional[dict] = None

    @field_validator("name", mode="after")
    @classmethod
    def regex_validator(cls, value: str) -> str:
        """Reject names containing characters outside ``[a-zA-Z0-9_-]``.

        ``re.fullmatch`` is used so a trailing newline is also rejected.

        Raises:
            ValueError: if the name is invalid.
        """
        if not re.fullmatch(BASE_REGEX, value):
            raise ValueError("Invalid function name")
        return value


class DynamicRetrievalConfig(BaseModel):
    """Settings for dynamic Google Search retrieval."""
    mode: Literal["MODE_DYNAMIC"]
    dynamicThreshold: Optional[float] = None


class GoogleSearchRetrieval(BaseModel):
    """Google Search retrieval tool configuration."""
    dynamicRetrievalConfig: DynamicRetrievalConfig


class Tool(BaseModel):
    """A tool entry; typically only one of its fields is provided.

    All fields default to None. With an ``Optional`` annotation but no default,
    pydantic v2 would require every key, so a tool carrying only
    ``functionDeclarations`` would be rejected.
    """
    functionDeclarations: Optional[list[FunctionDeclaration]] = None
    googleSearchRetrieval: Optional[GoogleSearchRetrieval] = None
    codeExecution: Optional[Any] = Field(default=None, description="用于执行模型生成的代码并自动将结果返回给模型的工具。另请参阅 ExecutableCode 和 CodeExecutionResult,它们仅在使用此工具时生成。")
    googleSearch: Optional[Any] = Field(default=None, description="GoogleSearch 工具类型。用于在模型中支持 Google 搜索的工具。由 Google 提供支持。")


class FunctionCallingConfig(BaseModel):
    """How the model may call the declared functions."""
    mode: Literal["AUTO", "ANY", "NONE"]
    allowedFunctionNames: Optional[list[str]] = Field(None, description="可选。一组函数名称。如果提供,则会限制模型将调用的函数。仅当模式为“任意”时,才应设置此字段。函数名称应与 [FunctionDeclaration.name] 相匹配。将模式设置为 ANY 后,模型将从提供的一组函数名称中预测函数调用。")


class ToolConfig(BaseModel):
    """Tool configuration shared by all tools in the request."""
    functionCallingConfig: Optional[FunctionCallingConfig] = None
class SafetySettings(BaseModel):
    """A blocking threshold for one harm category."""
    category: Literal[
        "HARM_CATEGORY_DEROGATORY",
        "HARM_CATEGORY_TOXICITY",
        "HARM_CATEGORY_VIOLENCE",
        "HARM_CATEGORY_SEXUAL",
        "HARM_CATEGORY_MEDICAL",
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_CIVIC_INTEGRITY"
    ]
    threshold: Literal[
        "BLOCK_LOW_AND_ABOVE",
        "BLOCK_MEDIUM_AND_ABOVE",
        "BLOCK_ONLY_HIGH",
        "BLOCK_NONE",
        "OFF"
    ]
class PrebuiltVoiceConfig(BaseModel):
    """Selects one of the prebuilt voices by name."""
    voiceName: str
class VoiceConfig(BaseModel):
    """Voice selection for speech output."""
    # Written as a Union because the API reference documents it as a union type.
    # NOTE(review): the Gemini REST API names this field `prebuiltVoiceConfig`;
    # confirm `voice_config` is what the adapter under test sends.
    voice_config: Union[PrebuiltVoiceConfig]
class SpeechConfig(BaseModel):
    """Speech-generation settings (voice and language)."""
    voiceConfig: Optional[VoiceConfig] = None
    languageCode: Optional[str] = None
class GenerationConfig(BaseModel):
    """Sampling and output options of a generateContent request."""
    stopSequences: Optional[list[str]] = None
    responseMimeType: Optional[str] = None
    responseSchema: Optional[dict] = None
    # Upper-case variant kept for reference. NOTE(review): the real API documents
    # upper-case modality values; this mock accepts lower-case ones, presumably
    # matching what the adapter sends. Confirm.
    # responseModalities: Optional[list[Literal[
    # "TEXT",
    # "IMAGE",
    # "AUDIO",
    # ]]]
    responseModalities: Optional[list[Literal[
        "text",
        "image",
        "audio"
    ]]] = None
    candidateCount: Optional[int] = 1
    maxOutputTokens: Optional[int] = None
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="默认值由具体模型决定")
    topP: Optional[float] = None
    topK: Optional[int] = None
    seed: Optional[int] = None
    presencePenalty: Optional[float] = None
    frequencyPenalty: Optional[float] = None
    responseLogprobs: Optional[bool] = None
    logprobs: Optional[int] = None
    enableEnhancedCivicAnswers: Optional[bool] = None
    speechConfig: Optional[SpeechConfig] = None
    @field_validator("stopSequences", mode="after")
    @classmethod
    def stop_sequences_validator(cls, value: Optional[list[str]]) -> Optional[list[str]]:
        """Allow at most 5 stop sequences (None is accepted).

        Raises:
            ValueError: if more than 5 sequences are given.
        """
        if value and len(value) > 5:
            raise ValueError("Stop sequences should not be more than 5")
        return value
class ThinkingConfig(BaseModel):
    """Options for the model's thinking/reasoning output."""
    includeThoughts: Optional[bool] = None
    thinkingBudget: Optional[int] = None
class ChatRequest(BaseModel):
    """Body of the mock generateContent endpoint."""
    contents: list[Content]
    tools: Optional[list[Tool]] = None
    toolConfig: Optional[ToolConfig] = None
    safetySettings: Optional[list[SafetySettings]] = None
    systemInstruction: Optional[Content] = None
    generationConfig: Optional[GenerationConfig] = None
    cachedContent: Optional[str] = None
    # NOTE(review): accepted here at the top level; confirm whether the adapter
    # sends thinkingConfig here or nested inside generationConfig.
    thinkingConfig: Optional[ThinkingConfig] = None
    mediaResolution: Optional[Literal[
        "MEDIA_RESOLUTION_LOW",
        "MEDIA_RESOLUTION_MEDIUM",
        "MEDIA_RESOLUTION_HIGH"
    ]] = None
================================================
FILE: tests/llm_adapters/mock_app/models/openai.py
================================================
from pydantic import BaseModel, Field, model_validator
from typing import Literal, Optional, Union
from typing_extensions import Self
class ImageUrl(BaseModel):
    """``image_url`` object of an image content part."""
    url: str
    detail: Optional[str] = "auto"


class InputAudio(BaseModel):
    """``input_audio`` object of an audio content part."""
    data: str  # base64-encoded audio
    format: Literal["wav", "mp3"]


class File(BaseModel):
    """``file`` object of a file content part.

    Each field is optional in the OpenAI API (e.g. a part may carry only a
    ``file_id``), so each defaults to None. An ``Optional`` annotation without
    a default would still make the field required in pydantic v2.
    """
    file_data: Optional[str] = Field(default=None, description="The base64 encoded file data, used when passing the file to the model as a string.")
    file_id: Optional[str] = Field(default=None, description="The ID of an uploaded file to use as input.")
    filename: Optional[str] = Field(default=None, description="The name of the file, used when passing the file to the model as a string.")


class TextContent(BaseModel):
    """Text content part."""
    text: str
    type: Literal["text"]


class ImageContent(BaseModel):
    """Image content part."""
    image_url: ImageUrl
    type: Literal["image_url"]


class AudioContent(BaseModel):
    """Audio content part."""
    input_audio: InputAudio
    type: Literal["input_audio"]


class FileContent(BaseModel):
    """File content part."""
    file: File
    type: Literal["file"]


class RefusalContent(BaseModel):
    """Refusal content part (assistant messages only)."""
    refusal: str
    type: Literal["refusal"]


# Content-part types allowed in user and assistant messages respectively.
UserUnionContent = Union[TextContent, ImageContent, AudioContent, FileContent]
AssistantUnionContent = Union[TextContent, RefusalContent]
class Function(BaseModel):
    """Function name and JSON-encoded arguments of a tool call."""
    arguments: str
    name: str


class ToolCall(BaseModel):
    """A tool call made by the assistant."""
    function: Function
    id: str
    type: Literal["function"]


class DeveloperMessage(BaseModel):
    """Developer-role message."""
    role: Literal["developer"]
    name: Optional[str] = None
    content: list[TextContent] | str


class SystemMessage(BaseModel):
    """System-role message."""
    role: Literal["system"]
    name: Optional[str] = None
    content: list[TextContent] | str


class UserMessage(BaseModel):
    """User-role message; may mix text, image, audio and file parts."""
    role: Literal["user"]
    name: Optional[str] = None
    content: list[UserUnionContent] | str


class AssistantMessage(BaseModel):
    """A prior assistant turn in the conversation history.

    ``content`` may be omitted when the turn carries ``tool_calls``, and
    ``name``/``tool_calls`` are optional, so all three default to None. An
    ``Optional`` annotation without a default would make them required.
    """
    role: Literal["assistant"]
    audio: Optional[dict[Literal["id"], str]] = None
    content: list[AssistantUnionContent] | str | None = None
    name: Optional[str] = None
    refusal: Optional[str] = None
    tool_calls: Optional[list[ToolCall]] = None


class ToolMessage(BaseModel):
    """Result of a tool call, linked to it by ``tool_call_id``."""
    role: Literal["tool"]
    content: list[TextContent] | str
    tool_call_id: str


# Any message accepted in a chat request; the `role` literal tells them apart.
UnionMessage = Union[DeveloperMessage, SystemMessage, UserMessage, AssistantMessage, ToolMessage]
class TopAudio(BaseModel):
    """Top-level ``audio`` object from the API reference; configures the model's audio output."""
    format: Literal["wav", "mp3", "flac", "opus", "pcm16"]
    voice: Literal["alloy", "ash", "ballad", "coral", "echo", "fable", "nova", "onyx", "sage", "shimmer"]
class StaticContent(BaseModel):
    """Predicted-output (``prediction``) content."""
    type: Literal["content"]
    content: list[TextContent] | str
class ChatRequest(BaseModel):
    """Body of the mock ``/chat/completions`` endpoint; only the mock model is accepted."""
    model: Literal["mock_chat"]
    messages: list[UnionMessage]
    audio: Optional[TopAudio] = None
    frequency_penalty: float = Field(default=0, ge=-2.0, le=2.0)
    logit_bias: Optional[dict] = None
    logprobs: Optional[bool] = False
    max_completion_tokens: Optional[int] = None
    metadata: Optional[dict] = None
    modalities: Optional[list] = None
    n: Optional[int] = 1
    parallel_tool_calls: Optional[bool] = True
    prediction: Optional[StaticContent] = None
    presence_penalty: Optional[float] = Field(default=0, ge=-2.0, le=2.0)
    reasoning_effort: Optional[str] = "medium"
    response_format: Optional[dict] = None
    seed: Optional[int] = None
    service_tier: Optional[str] = "auto"
    stop: Optional[str|list] = None
    store: Optional[bool] = False
    stream: Optional[bool] = False
    stream_options: Optional[dict] = None
    temperature: Optional[float] = Field(default=1, ge=0, le=2)
    tool_choice: Optional[str] = None
    tools: Optional[list] = None
    top_logprobs: Optional[int] = Field(default=None, ge=0, le=20)
    top_p: Optional[float] = 1
    user: Optional[str] = None
    web_search_options: Optional[dict] = None
    @model_validator(mode="after")
    def validate_top_logprobs(self) -> Self:
        """Reject a truthy ``top_logprobs`` unless ``logprobs`` is enabled.

        Raises:
            ValueError: if ``top_logprobs`` is set without ``logprobs=True``.
        """
        # Required by the API reference.
        if self.top_logprobs and not self.logprobs:
            raise ValueError("top_logprobs can only be used with logprobs=True")
        return self
class EmbeddingRequest(BaseModel):
    """Body of the mock ``/embeddings`` endpoint."""
    # NOTE(review): OpenAI's embeddings API calls this field `input`; confirm the
    # adapter under test really sends `text`.
    text: list[str] | str
    model: Literal["mock_embedding", "text-embedding-ada-002"]
    dimensions: Optional[int] = None
    encoding_format: Optional[str] = "float"
    user: Optional[str] = None
    @model_validator(mode="after")
    def custom_validate(self) -> Self:
        """Reject ``dimensions`` for models that do not support it.

        Raises:
            ValueError: if ``dimensions`` is set for text-embedding-ada-002.
        """
        if self.dimensions and self.model in ["text-embedding-ada-002"]:
            raise ValueError("The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.")
        return self
================================================
FILE: tests/llm_adapters/mock_app/ollama.py
================================================
from fastapi import APIRouter, Body
from pydantic import BaseModel
from typing import Literal, Optional
from . import REFERENCE_VECTOR
class EmbeddingRequest(BaseModel):
    """Body of the mock Ollama ``/api/embed`` request."""
    model: Literal["mock_embedding"]
    # Ollama's /api/embed accepts either a single string or a list of strings.
    input: str | list[str]
    truncate: Optional[bool] = False


router = APIRouter(tags=["ollama"])


@router.post("/api/embed")
async def embedding(request: EmbeddingRequest = Body(...)) -> dict:
    """Return one copy of REFERENCE_VECTOR per input text.

    A single-string ``input`` is treated as a one-item batch.
    """
    texts = [request.input] if isinstance(request.input, str) else request.input
    return {
        "model": "mock_embedding",
        "embeddings": [REFERENCE_VECTOR for _ in texts],
        "total_duration": 14143917,
        "load_duration": 1019500,
        "prompt_eval_count": 8
    }
================================================
FILE: tests/llm_adapters/mock_app/openai.py
================================================
from fastapi import APIRouter, Body
from . import REFERENCE_VECTOR
from .models.openai import ChatRequest, EmbeddingRequest
router = APIRouter(tags=["openai"])


@router.post("/chat/completions")
async def completions(request: ChatRequest = Body(...)) -> dict:
    """Return a fixed chat.completion payload.

    FastAPI validates the body against ``ChatRequest``, but its values are
    otherwise ignored (including ``stream``); the reply always contains
    "mock_response".
    """
    assistant_message = {
        "role": "assistant",
        "content": "mock_response",
        "refusal": None,
        "annotations": [],
    }
    usage = {
        "prompt_tokens": 19,
        "completion_tokens": 10,
        "total_tokens": 29,
        "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
        "completion_tokens_details": {
            "reasoning_tokens": 0,
            "audio_tokens": 0,
            "accepted_prediction_tokens": 0,
            "rejected_prediction_tokens": 0,
        },
    }
    only_choice = {
        "index": 0,
        "message": assistant_message,
        "logprobs": None,
        "finish_reason": "stop",
    }
    return {
        "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
        "object": "chat.completion",
        "created": 1741569952,
        "model": "mock_chat",
        "choices": [only_choice],
        "usage": usage,
        "service_tier": "default",
    }
@router.post("/embeddings")
async def embeddings(request: EmbeddingRequest = Body(...)) -> dict:
    """Mock OpenAI embeddings endpoint.

    Returns one embedding object per input, each tagged with its position in
    ``index``, so adapters that align results with inputs see the right count.
    ``text`` may be a single string (treated as one input) or a list of strings.
    Every embedding is ``REFERENCE_VECTOR`` so tests can compare against it.
    """
    # A bare string is a single input; a list is already one entry per input.
    inputs = [request.text] if isinstance(request.text, str) else request.text
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "embedding": REFERENCE_VECTOR,
                "index": position,
            }
            for position in range(len(inputs))
        ],
        "model": "mock_embedding",
        "usage": {
            "prompt_tokens": 8,
            "total_tokens": 8
        }
    }
================================================
FILE: tests/llm_adapters/mock_app/voyage.py
================================================
from fastapi import APIRouter, Body
from pydantic import BaseModel
from typing import Literal, Union, Optional
from . import REFERENCE_VECTOR
class EmbeddingRequest(BaseModel):
    """Body of a request to the mock Voyage ``/v1/embeddings`` endpoint."""

    # Texts to embed.
    input: list[str]
    # The only text-embedding model name the mock accepts.
    model: Literal["mock_embedding"]
class ReRankRequest(BaseModel):
    """Body of a request to the mock Voyage ``/v1/rerank`` endpoint."""

    # Query the documents are ranked against.
    query: str
    # Candidate documents, in the order the client sent them.
    documents: list[str]
    # The only rerank model name the mock accepts.
    model: Literal["mock_rerank"]
    # Accepted but not consulted by the mock endpoint.
    return_documents: Optional[bool] = False
class TextContent(BaseModel):
    """Text item inside a multimodal embedding input."""

    # Discriminator value identifying a text item.
    type: Literal["text"]
    text: str
class ImageBase64Content(BaseModel):
    """Inline image item (base64 payload) inside a multimodal embedding input."""

    # Discriminator value identifying a base64 image item.
    type: Literal["image_base64"]
    image_base64: str
class ImageUrlContent(BaseModel):
    """Image-by-URL item inside a multimodal embedding input."""

    # Discriminator value identifying an image URL item.
    type: Literal["image_url"]
    image_url: str
class CombinedContent(BaseModel):
    """One multimodal input: an ordered list of text / image items."""

    # Each item must match one of the three content models by its `type` literal.
    content: list[Union[TextContent, ImageBase64Content, ImageUrlContent]]
class MultiModalRequest(BaseModel):
    """Body of a request to the mock Voyage ``/v1/multimodalembeddings`` endpoint."""

    # Inputs to embed, each possibly mixing text and images.
    inputs: list[CombinedContent]
    # The only multimodal model name the mock accepts.
    model: Literal["mock_multimodal"]
router = APIRouter(tags=["voyage"])


@router.post("/v1/embeddings")
async def get_embeddings(request: EmbeddingRequest = Body(...)) -> dict:
    """Mock Voyage text-embedding endpoint.

    Returns one embedding object per input string, each tagged with its
    position in ``index``, so adapters that align results with inputs get the
    right count. Every embedding is the fixed ``REFERENCE_VECTOR``.
    """
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "embedding": REFERENCE_VECTOR,  # fixed vector so tests can verify it
                "index": position,
            }
            for position in range(len(request.input))
        ],
        "model": "mock_embedding",
        "usage": {
            "total_tokens": 10
        }
    }
@router.post("/v1/multimodalembeddings")
async def get_multimodal_embeddings(request: MultiModalRequest = Body(...)) -> dict:
    """Mock Voyage multimodal-embedding endpoint.

    Returns one embedding object per entry of ``request.inputs``, each tagged
    with its position in ``index``. Every embedding is the fixed
    ``REFERENCE_VECTOR``, and usage figures are constants.
    """
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "embedding": REFERENCE_VECTOR,  # fixed vector so tests can verify it
                "index": position,
            }
            for position in range(len(request.inputs))
        ],
        "model": "mock_multimodal",
        "usage": {
            "text_tokens": 5,
            "image_pixels": 2000000,
            "total_tokens": 3576
        }
    }
@router.post("/v1/rerank")
async def get_rerank(request: ReRankRequest = Body(...)) -> dict:
    """Mock Voyage rerank endpoint.

    Assigns fixed, descending relevance scores to the first two documents in
    the order received. Fewer documents yield correspondingly fewer results
    instead of an IndexError (HTTP 500). Documents beyond the second are
    ignored, as before.
    """
    fixed_scores = (0.4375, 0.421875)
    results = [
        {
            "index": position,
            "relevance_score": score,
            "document": document,
        }
        for position, (score, document) in enumerate(zip(fixed_scores, request.documents))
    ]
    return {
        "object": "list",
        "data": results,
        "model": "mock_rerank",
        "usage": {
            "total_tokens": 26
        }
    }
================================================
FILE: tests/llm_adapters/test_gemini_adapter.py
================================================
from kirara_ai.plugins.llm_preset_adapters.gemini_adapter import GeminiAdapter, GeminiConfig
from kirara_ai.llm.format.embedding import LLMEmbeddingRequest, LLMEmbeddingResponse
from kirara_ai.llm.format.message import LLMChatTextContent, LLMChatImageContent, LLMChatMessage
from kirara_ai.llm.format.request import LLMChatRequest
from kirara_ai.llm.format.response import LLMChatResponse, Usage
import pytest
from typing import cast
from .mock_app import AUTH_KEY, GEMINI_ENDPOINT, REFERENCE_VECTOR
class TestGeminiAdapter:
    """Chat and embedding tests for GeminiAdapter against the mock Gemini app."""

    @pytest.fixture(scope="class")
    def gemini_adapter(self, mock_media_manager, mock_tracer) -> GeminiAdapter:
        """Adapter pointed at the mock endpoint, with mocked media manager and tracer."""
        config = GeminiConfig(
            api_key=AUTH_KEY,
            api_base=GEMINI_ENDPOINT
        )
        adapter = GeminiAdapter(config)
        adapter.media_manager = mock_media_manager
        adapter.backend_name = "gemini"
        adapter.tracer = mock_tracer
        return adapter

    def test_chat(self, gemini_adapter):
        """A single user text message yields the mock reply and the mock usage figures."""
        req = LLMChatRequest(
            messages=[LLMChatMessage(
                content=[LLMChatTextContent(text="hello world")],
                role="user"
            )],
            model="mock_chat"
        )
        response = gemini_adapter.chat(req)
        assert isinstance(response, LLMChatResponse)
        assert isinstance(response.message.content[0], LLMChatTextContent)
        content = cast(LLMChatTextContent, response.message.content[0])
        assert content.text == "mock_response"
        assert isinstance(response.usage, Usage)
        assert response.usage.total_tokens == 114
        assert response.usage.prompt_tokens == 514
        assert response.usage.cached_tokens == 1919
        assert response.usage.completion_tokens == 0

    def test_embed(self, gemini_adapter: GeminiAdapter):
        """A text input is embedded to the reference vector."""
        req = LLMEmbeddingRequest(
            inputs=[
                LLMChatTextContent(text="hello world", type="text")
            ],
            model="mock_embedding"
        )
        response = gemini_adapter.embed(req)
        assert isinstance(response, LLMEmbeddingResponse)
        assert response.vectors[0] == REFERENCE_VECTOR
================================================
FILE: tests/llm_adapters/test_ollama_adapter.py
================================================
from kirara_ai.plugins.llm_preset_adapters.ollama_adapter import OllamaAdapter, OllamaConfig
from kirara_ai.llm.format.message import LLMChatTextContent, LLMChatImageContent
from kirara_ai.llm.format.embedding import LLMEmbeddingRequest, LLMEmbeddingResponse
import pytest
from .mock_app import REFERENCE_VECTOR, OLLAMA_ENDPOINT
class TestOllamaAdapter:
    """Embedding tests for OllamaAdapter against the mock Ollama app."""

    @pytest.fixture(scope="class")
    def ollama_adapter(self, mock_media_manager) -> OllamaAdapter:
        """Adapter configured for the mock endpoint, sharing the media-manager mock."""
        adapter = OllamaAdapter(OllamaConfig(api_base=OLLAMA_ENDPOINT))
        adapter.media_manager = mock_media_manager
        return adapter

    def test_embedding(self, ollama_adapter: OllamaAdapter):
        """A text input is embedded to the reference vector."""
        text_input = LLMChatTextContent(text="hello world", type="text")
        response = ollama_adapter.embed(
            LLMEmbeddingRequest(inputs=[text_input], model="mock_embedding")
        )
        assert isinstance(response, LLMEmbeddingResponse)
        assert response.vectors[0] == REFERENCE_VECTOR

    def test_embedding_with_image(self, ollama_adapter: OllamaAdapter):
        """An image input is rejected: the ollama adapter does not do multi-modal embedding."""
        image_input = LLMChatImageContent(media_id="1234567890", type="image")
        request = LLMEmbeddingRequest(inputs=[image_input], model="mock_embedding")
        with pytest.raises(ValueError, match="ollama api does not support multi-modal embedding"):
            ollama_adapter.embed(request)
================================================
FILE: tests/llm_adapters/test_openai_adapter.py
================================================
from kirara_ai.plugins.llm_preset_adapters.openai_adapter import OpenAIAdapter, OpenAIConfig
from kirara_ai.llm.format.embedding import LLMEmbeddingRequest, LLMEmbeddingResponse
from kirara_ai.llm.format.message import LLMChatTextContent, LLMChatImageContent, LLMChatMessage
from kirara_ai.llm.format.request import LLMChatRequest
from kirara_ai.llm.format.response import LLMChatResponse
import pytest
from .mock_app import AUTH_KEY, REFERENCE_VECTOR, OPENAI_ENDPOINT
class TestOpenAIAdapter:
    """Embedding and chat tests for OpenAIAdapter against the mock OpenAI app."""

    @pytest.fixture(scope="class")
    def openai_adapter(self, mock_media_manager, mock_tracer) -> OpenAIAdapter:
        """Adapter wired to the mock endpoint with mocked media manager and tracer."""
        # Original author's note: patching could not be folded into this fixture.
        adapter = OpenAIAdapter(OpenAIConfig(api_base=OPENAI_ENDPOINT, api_key=AUTH_KEY))
        adapter.media_manager = mock_media_manager
        adapter.backend_name = "openai"
        adapter.tracer = mock_tracer
        return adapter

    def test_embed(self, openai_adapter: OpenAIAdapter):
        """A text input is embedded to the reference vector."""
        request = LLMEmbeddingRequest(
            inputs=[LLMChatTextContent(text="hello world", type="text")],
            model="mock_embedding",
        )
        response = openai_adapter.embed(request)
        assert isinstance(response, LLMEmbeddingResponse)
        assert response.vectors[0] == REFERENCE_VECTOR

    def test_embed_with_image(self, openai_adapter: OpenAIAdapter):
        """Image inputs are rejected before any request is sent."""
        request = LLMEmbeddingRequest(
            inputs=[LLMChatImageContent(media_id="mock_media_id", type="image")],
            model="mock_embedding",
        )
        with pytest.raises(ValueError, match="openai does not support multi-modal embedding"):
            openai_adapter.embed(request)

    def test_embed_with_input_out_of_range(self, openai_adapter: OpenAIAdapter):
        """More than 2048 inputs are rejected by the adapter."""
        oversized = [LLMChatTextContent(text="hello world", type="text") for _ in range(2050)]
        request = LLMEmbeddingRequest(inputs=oversized, model="mock_embedding")
        with pytest.raises(ValueError, match="Text list has too many dimensions, max dimension is 2048"):
            openai_adapter.embed(request)

    def test_old_embedding_model_rasies_error(self, openai_adapter: OpenAIAdapter):
        """The mock server refuses ``dimensions`` for ada-002, surfacing as an HTTP error."""
        from requests.exceptions import HTTPError
        request = LLMEmbeddingRequest(
            model="text-embedding-ada-002",
            inputs=[LLMChatTextContent(text="hello world", type="text")],
            dimension=512,
        )
        with pytest.raises(HTTPError):
            openai_adapter.embed(request)

    def test_normal_chat(self, openai_adapter: OpenAIAdapter):
        """A plain text conversation returns the mock assistant reply and usage."""
        system_message = LLMChatMessage(
            role="system",
            content=[
                LLMChatTextContent(text="你是一个猫娘。"),
                LLMChatTextContent(text="hello world"),
            ],
        )
        response = openai_adapter.chat(
            LLMChatRequest(messages=[system_message], model="mock_chat")
        )
        assert isinstance(response, LLMChatResponse)
        first_part = response.message.content[0]
        assert isinstance(first_part, LLMChatTextContent)
        assert first_part.text == "mock_response"
        assert response.message.role == "assistant"
        assert response.message.tool_calls is None
        assert response.usage.total_tokens == 29  # type: ignore
================================================
FILE: tests/llm_adapters/test_voyage_adapter.py
================================================
from kirara_ai.plugins.llm_preset_adapters.voyage_adapter import VoyageAdapter, VoyageConfig
from kirara_ai.llm.format.embedding import LLMEmbeddingRequest, LLMEmbeddingResponse
from kirara_ai.llm.format.rerank import LLMReRankRequest, LLMReRankResponse
from kirara_ai.llm.format.message import LLMChatTextContent, LLMChatImageContent
from kirara_ai.llm.format.response import Usage
import pytest
from .mock_app import VOYAGE_ENDPOINT, AUTH_KEY, REFERENCE_VECTOR
class TestVoyageAdapter:
    """Embedding and rerank tests for VoyageAdapter against the mock Voyage app."""

    @staticmethod
    def _documents() -> list[str]:
        """Return a fresh list of the two candidate answers used by the rerank tests."""
        return ["I'm doing well, thank you.", "I'm fine, thank you."]

    @pytest.fixture(scope="class")
    def voyage_adapter(self, mock_media_manager):
        """Adapter pointed at the mock endpoint, with the mocked media_manager injected."""
        adapter = VoyageAdapter(VoyageConfig(api_base=VOYAGE_ENDPOINT, api_key=AUTH_KEY))
        adapter.media_manager = mock_media_manager
        return adapter

    def test_embedding(self, voyage_adapter: VoyageAdapter):
        """A text input yields the reference vector and the mock token usage."""
        request = LLMEmbeddingRequest(
            model="mock_embedding",
            inputs=[LLMChatTextContent(text="hello world", type="text")],
        )
        response = voyage_adapter.embed(request)
        assert isinstance(response, LLMEmbeddingResponse)
        assert response.vectors[0] == REFERENCE_VECTOR
        assert isinstance(response.usage, Usage)
        assert response.usage.total_tokens == 10

    def test_multi_modal_embedding(self, voyage_adapter: VoyageAdapter):
        """Mixed text + image input goes to the multimodal endpoint."""
        mixed_inputs = [
            LLMChatTextContent(text="hello world", type="text"),
            LLMChatImageContent(media_id="fd76f6fa-d7c7-4dfe-be48-bb2f7d87c9fb", type="image"),
        ]
        response = voyage_adapter.embed(
            LLMEmbeddingRequest(inputs=mixed_inputs, model="mock_multimodal")
        )
        assert isinstance(response, LLMEmbeddingResponse)
        assert response.vectors[0] == REFERENCE_VECTOR
        assert isinstance(response.usage, Usage)
        assert response.usage.total_tokens == 3576

    def test_rerank_without_sort(self, voyage_adapter: VoyageAdapter):
        """Without sorting, results keep the server's order and scores."""
        request = LLMReRankRequest(
            query="how are you?",
            documents=self._documents(),
            model="mock_rerank",
            return_documents=True,
        )
        response = voyage_adapter.rerank(request)
        assert isinstance(response, LLMReRankResponse)
        first, second = response.contents[0], response.contents[1]
        assert first.document == "I'm doing well, thank you."
        assert second.document == "I'm fine, thank you."
        assert first.score == 0.4375
        assert second.score == 0.421875
        assert isinstance(response.usage, Usage)
        assert response.usage.total_tokens == 26

    def test_rerank_with_sort(self, voyage_adapter: VoyageAdapter):
        """With sort=True, results come back in descending score order."""
        request = LLMReRankRequest(
            query="how are you?",
            documents=self._documents(),
            model="mock_rerank",
            return_documents=True,
            sort=True,
        )
        response = voyage_adapter.rerank(request)
        assert isinstance(response, LLMReRankResponse)
        assert response.contents[0].score > response.contents[1].score

    def test_rerank_sort_raise_error(self, voyage_adapter: VoyageAdapter):
        """sort=True without return_documents is a validation error."""
        from pydantic import ValidationError
        with pytest.raises(ValidationError):
            LLMReRankRequest(
                query="how are you?",
                documents=self._documents(),
                model="mock_rerank",
                sort=True,
            )
================================================
FILE: tests/memory/__init__.py
================================================
"""记忆系统测试包"""
================================================
FILE: tests/memory/test_composer_decomposer.py
================================================
from datetime import datetime
from unittest.mock import MagicMock
import pytest
import kirara_ai.llm.format.tool as tools
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.llm.format.message import LLMChatMessage, LLMChatTextContent, LLMToolCallContent, LLMToolResultContent
from kirara_ai.memory.composes import DefaultMemoryComposer, DefaultMemoryDecomposer, MultiElementDecomposer
@pytest.fixture
def composer():
    """DefaultMemoryComposer whose ``container`` is a MagicMock."""
    mock_container = MagicMock()
    instance = DefaultMemoryComposer()
    instance.container = mock_container
    return instance
@pytest.fixture
def decomposer():
    """A fresh DefaultMemoryDecomposer for each test."""
    fresh_decomposer = DefaultMemoryDecomposer()
    return fresh_decomposer
@pytest.fixture
def multi_decomposer():
    """A fresh MultiElementDecomposer for each test."""
    fresh_decomposer = MultiElementDecomposer()
    return fresh_decomposer
@pytest.fixture
def group_sender():
    """Sender ``user1`` posting in group ``group1``."""
    return ChatSender.from_group_chat(
        display_name="user1",
        group_id="group1",
        user_id="user1",
    )
@pytest.fixture
def c2c_sender():
    """Sender ``user1`` in a one-to-one (c2c) chat."""
    return ChatSender.from_c2c_chat(display_name="user1", user_id="user1")
class TestDefaultMemoryComposer:
    """DefaultMemoryComposer turns IM and LLM messages into a MemoryEntry."""

    def test_compose_group_message(self, composer, group_sender):
        """A group message is recorded as '<name> 说: ...' with a timestamp."""
        im_message = IMMessage(
            sender=group_sender,
            message_elements=[TextMessage(text="test message")],
        )
        memory_entry = composer.compose(group_sender, [im_message])
        expected = f"{group_sender.display_name} 说: \n{im_message.content}"
        assert expected in memory_entry.content
        assert isinstance(memory_entry.timestamp, datetime)

    def test_compose_c2c_message(self, composer, c2c_sender):
        """A private message is recorded the same way as a group message."""
        im_message = IMMessage(
            sender=c2c_sender,
            message_elements=[TextMessage(text="test message")],
        )
        memory_entry = composer.compose(c2c_sender, [im_message])
        expected = f"{c2c_sender.display_name} 说: \n{im_message.content}"
        assert expected in memory_entry.content
        assert isinstance(memory_entry.timestamp, datetime)

    def test_compose_llm_response(self, composer, c2c_sender):
        """An assistant reply is recorded as '你回答: ...'."""
        reply = LLMChatMessage(role="assistant", content=[LLMChatTextContent(text="test response")])
        memory_entry = composer.compose(c2c_sender, [reply])
        first_part = reply.content[0]
        assert isinstance(first_part, LLMChatTextContent)
        assert f"你回答: \n{first_part.text}" in memory_entry.content
        assert isinstance(memory_entry.timestamp, datetime)

    def test_compose_llm_tool_call_message(self, composer, c2c_sender):
        """Tool calls in an assistant message end up in metadata['_tool_calls']."""
        tool_call = LLMToolCallContent(id="call_114514", name="get_weather", parameters={"city": "北京"})
        reply = LLMChatMessage(
            role="assistant",
            content=[LLMChatTextContent(text="我决定调用get_weather函数并传递city=北京。"), tool_call],
        )
        memory_entry = composer.compose(c2c_sender, [reply])
        # metadata must carry a non-empty `_tool_calls` list
        assert len(memory_entry.metadata.get("_tool_calls", [])) > 0

    def test_compose_llm_tool_result_message(self, composer, c2c_sender):
        """Tool results in a tool message end up in metadata['_tool_results']."""
        tool_result = LLMToolResultContent(
            id="call_114514",
            name="get_weather",
            content=[tools.TextContent(text="今天的天气是晴天。")],
        )
        tool_message = LLMChatMessage(role="tool", content=[tool_result])
        memory_entry = composer.compose(c2c_sender, [tool_message])
        assert len(memory_entry.metadata.get("_tool_results", [])) > 0
class TestDefaultMemoryDecomposer:
    """DefaultMemoryDecomposer renders memory entries back into prompt strings."""

    def test_decompose_mixed_entries(self, decomposer, group_sender, c2c_sender):
        """Group and private entries are both rendered, with a relative-time label."""
        group_entry = MagicMock(
            sender=group_sender,
            content="group1:user1 说: group message",
            timestamp=datetime.now(),
        )
        c2c_entry = MagicMock(
            sender=c2c_sender,
            content="c2c:user1 说: c2c message",
            timestamp=datetime.now(),
        )
        rendered = decomposer.decompose([group_entry, c2c_entry])
        assert len(rendered) == 2
        first, second = rendered[0], rendered[1]
        assert "刚刚" in first
        assert "group message" in first
        assert "c2c message" in second

    def test_decompose_empty(self, decomposer):
        """No entries yields just the decomposer's empty-message placeholder."""
        assert decomposer.decompose([]) == [decomposer.empty_message]

    def test_decompose_max_entries(self, decomposer, c2c_sender):
        """With more than 10 entries, only the most recent 10 are kept."""
        entries = [
            MagicMock(sender=c2c_sender, content=f"message {n}", timestamp=datetime.now())
            for n in range(12)
        ]
        rendered = decomposer.decompose(entries)
        assert len(rendered) == 10
        assert "message 11" in rendered[-1]
class TestMultiElementDecomposer:
    """MultiElementDecomposer rebuilds structured tool messages from entry metadata."""

    def test_decompose_tool_call_and_result_message(self, multi_decomposer, c2c_sender):
        """A tool-call entry becomes an assistant message; a tool-result entry a tool message."""
        call_dump = LLMToolCallContent(
            id="call_114514", name="get_weather", parameters={"city": "北京"}
        ).model_dump()
        result_dump = LLMToolResultContent(
            id="call_114514",
            name="get_weather",
            content=[tools.TextContent(text="今天的天气是晴天。")],
        ).model_dump()
        entries = [
            MagicMock(
                sender=c2c_sender,
                content="",
                timestamp=datetime.now(),
                metadata={"_tool_calls": [call_dump]},
            ),
            MagicMock(
                sender=c2c_sender,
                content="",
                timestamp=datetime.now(),
                metadata={"_tool_results": [result_dump]},
            ),
        ]
        messages = multi_decomposer.decompose(entries)
        assert len(messages) == 2
        call_message, result_message = messages[0], messages[1]
        assert call_message.role == "assistant"
        assert all(isinstance(part, LLMToolCallContent) for part in call_message.content)
        assert result_message.role == "tool"
        assert all(isinstance(part, LLMToolResultContent) for part in result_message.content)
================================================
FILE: tests/memory/test_composer_strategy.py
================================================
from unittest.mock import Mock
import pytest
import kirara_ai.llm.format.tool as tools
from kirara_ai.im.message import ImageMessage, IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender, ChatType
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format.message import (LLMChatImageContent, LLMChatMessage, LLMChatTextContent, LLMToolCallContent,
LLMToolResultContent)
from kirara_ai.memory.composes.composer_strategy import (IMMessageProcessor, LLMChatImageContentProcessor,
LLMChatMessageProcessor, LLMChatTextContentProcessor,
LLMToolCallContentProcessor, LLMToolResultContentProcessor,
MediaMessageProcessor, ProcessorFactory, TextMessageProcessor,
drop_think_part)
@pytest.fixture
def mock_container():
    """DependencyContainer mock whose ``resolve`` always returns one shared media-manager mock."""
    shared_media_manager = Mock()
    container = Mock(spec=DependencyContainer)
    container.resolve.return_value = shared_media_manager
    return container
@pytest.fixture
def sample_context():
    """Processing context holding an empty, independent list per accumulator key."""
    return {key: [] for key in ("media_ids", "tool_calls", "tool_results")}
class TestDropThinkPart:
    """Tests for ``drop_think_part``."""

    def test_drop_think_part_with_think_tag(self):
        """The <think>...</think> section is removed and the rest kept."""
        # The input must actually contain the think tags; without them this case
        # would be indistinguishable from the no-tag test below.
        text = "<think>这是思考部分</think>这是输出部分"
        result = drop_think_part(text)
        assert result == "这是输出部分"

    def test_drop_think_part_without_think_tag(self):
        """Text without a think section is returned unchanged."""
        text = "这是纯文本,没有思考标签"
        result = drop_think_part(text)
        assert result == text
class TestTextMessageProcessor:
    """TextMessageProcessor renders a TextMessage as its text plus a newline."""

    def test_process(self, mock_container, sample_context):
        processor = TextMessageProcessor(mock_container)
        rendered = processor.process(TextMessage("这是一条文本消息"), sample_context)
        assert rendered == "这是一条文本消息\n"
class TestMediaMessageProcessor:
    """MediaMessageProcessor emits a media reference and records the media id."""

    def test_process(self, mock_container, sample_context):
        processor = MediaMessageProcessor(mock_container)
        image = ImageMessage(media_id="media1", data=b"test", format="png")
        rendered = processor.process(image, sample_context)
        assert 'id="media1"' in rendered
        assert sample_context["media_ids"] == ["media1"]
class TestLLMChatTextContentProcessor:
    """LLMChatTextContentProcessor renders text content, dropping any think section."""

    def test_process_normal_text(self, mock_container, sample_context):
        """Plain text comes back with a trailing newline."""
        processor = LLMChatTextContentProcessor(mock_container)
        content = LLMChatTextContent(text="这是普通文本")
        result = processor.process(content, sample_context)
        assert result == "这是普通文本\n"

    def test_process_with_think_tag(self, mock_container, sample_context):
        """The <think>...</think> section is stripped before rendering."""
        processor = LLMChatTextContentProcessor(mock_container)
        # The input must contain the think tags, otherwise nothing would be stripped.
        content = LLMChatTextContent(text="<think>思考过程</think>这是输出")
        result = processor.process(content, sample_context)
        assert result == "这是输出\n"
class TestLLMChatImageContentProcessor:
    """LLMChatImageContentProcessor emits a media tag including the media description."""

    def test_process(self, mock_container, sample_context):
        # configure the media_manager mock to return a media with a description
        described_media = Mock()
        described_media.description = "图片描述"
        mock_container.resolve.return_value.get_media.return_value = described_media

        processor = LLMChatImageContentProcessor(mock_container)
        rendered = processor.process(LLMChatImageContent(media_id="media1"), sample_context)

        for fragment in ("media_msg", "media1", "图片描述"):
            assert fragment in rendered
        assert sample_context["media_ids"] == ["media1"]
class TestLLMToolCallContentProcessor:
    """LLMToolCallContentProcessor emits a function_call tag and records the call."""

    def test_process(self, mock_container, sample_context):
        call = LLMToolCallContent(
            id="call1",
            name="test_function",
            parameters={"arg1": "value1"},
        )
        rendered = LLMToolCallContentProcessor(mock_container).process(call, sample_context)

        for fragment in ("function_call", 'id="call1"', 'name="test_function"'):
            assert fragment in rendered

        # the tool-call data appended to the context must match the input
        recorded_calls = sample_context["tool_calls"]
        assert len(recorded_calls) == 1
        assert recorded_calls[0]["id"] == "call1"
        assert recorded_calls[0]["name"] == "test_function"
class TestLLMToolResultContentProcessor:
    """LLMToolResultContentProcessor emits a tool_result tag and records the result."""

    def test_process(self, mock_container, sample_context):
        tool_result = LLMToolResultContent(
            id="result1",
            name="test_result",
            isError=False,
            content=[tools.TextContent(text="结果文本")],
        )
        rendered = LLMToolResultContentProcessor(mock_container).process(tool_result, sample_context)

        for fragment in ("tool_result", 'id="result1"', 'name="test_result"', 'isError="False"'):
            assert fragment in rendered

        # the tool-result data appended to the context must match the input
        recorded_results = sample_context["tool_results"]
        assert len(recorded_results) == 1
        recorded = recorded_results[0]
        assert recorded["id"] == "result1"
        assert recorded["name"] == "test_result"
        assert recorded["isError"] is False
class TestIMMessageProcessor:
    """IMMessageProcessor prefixes the sender and renders each message element."""

    @staticmethod
    def _c2c_sender():
        """Private-chat sender displayed as '用户'."""
        return ChatSender(user_id="user1", chat_type=ChatType.C2C, display_name="用户")

    def test_process_with_text_message(self, mock_container, sample_context):
        # an IMMessage carrying a single text element
        im_message = IMMessage(
            sender=self._c2c_sender(),
            message_elements=[TextMessage("这是文本消息")],
        )
        rendered = IMMessageProcessor(mock_container).process(im_message, sample_context)
        assert "用户 说:" in rendered
        assert "这是文本消息" in rendered

    def test_process_with_media_message(self, mock_container, sample_context):
        # an IMMessage carrying a single image element
        im_message = IMMessage(
            sender=self._c2c_sender(),
            message_elements=[ImageMessage(media_id="media1", data=b"test", format="png")],
        )
        rendered = IMMessageProcessor(mock_container).process(im_message, sample_context)
        for fragment in ("用户 说:", "media_msg", "media1"):
            assert fragment in rendered
        assert sample_context["media_ids"] == ["media1"]
class TestLLMChatMessageProcessor:
    """LLMChatMessageProcessor renders every content part of an LLMChatMessage."""

    def test_process_with_text_content(self, mock_container, sample_context):
        chat_message = LLMChatMessage(
            role="user",
            content=[LLMChatTextContent(text="这是文本内容")],
        )
        rendered = LLMChatMessageProcessor(mock_container).process(chat_message, sample_context)
        # NOTE(review): "你回答:" is expected even though role="user"; confirm intended.
        assert "你回答:" in rendered
        assert "这是文本内容" in rendered

    def test_process_with_mixed_content(self, mock_container, sample_context):
        # configure the media_manager mock to return a media with a description
        described_media = Mock()
        described_media.description = "图片描述"
        mock_container.resolve.return_value.get_media.return_value = described_media

        chat_message = LLMChatMessage(
            role="assistant",
            content=[
                LLMChatTextContent(text="这是文本内容"),
                LLMChatImageContent(media_id="media1"),
            ],
        )
        rendered = LLMChatMessageProcessor(mock_container).process(chat_message, sample_context)
        for fragment in ("你回答:", "这是文本内容", "media_msg", "media1"):
            assert fragment in rendered
        assert sample_context["media_ids"] == ["media1"]

    def test_process_with_tool_content(self, mock_container, sample_context):
        tool_call = LLMToolCallContent(
            id="call1",
            name="test_function",
            parameters={"arg1": "value1"},
        )
        chat_message = LLMChatMessage(role="assistant", content=[tool_call])
        rendered = LLMChatMessageProcessor(mock_container).process(chat_message, sample_context)
        for fragment in ("function_call", 'id="call1"', 'name="test_function"'):
            assert fragment in rendered
        assert len(sample_context["tool_calls"]) == 1
class TestProcessorFactory:
    """ProcessorFactory maps message types to their processors."""

    def test_get_processor_for_im_message(self, mock_container):
        chosen = ProcessorFactory(mock_container).get_processor(IMMessage)
        assert isinstance(chosen, IMMessageProcessor)

    def test_get_processor_for_llm_chat_message(self, mock_container):
        chosen = ProcessorFactory(mock_container).get_processor(LLMChatMessage)
        assert isinstance(chosen, LLMChatMessageProcessor)

    def test_get_processor_unknown_type(self, mock_container):
        """Unregistered types yield None rather than raising."""
        chosen = ProcessorFactory(mock_container).get_processor(object)
        assert chosen is None
================================================
FILE: tests/memory/test_decomposer_strategy.py
================================================
from datetime import datetime
from unittest.mock import Mock
import pytest
from kirara_ai.im.sender import ChatSender, ChatType
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format.message import (LLMChatImageContent, LLMChatMessage, LLMChatTextContent, LLMToolCallContent,
LLMToolResultContent)
from kirara_ai.memory.composes.decomposer_strategy import (ContentInfo, ContentParser, DefaultDecomposerStrategy,
MediaContentStrategy, MultiElementDecomposerStrategy,
TextContentStrategy, ToolCallContentStrategy,
ToolResultContentStrategy)
from kirara_ai.memory.entry import MemoryEntry
@pytest.fixture
def mock_container():
    """Spec'd DependencyContainer mock; ``resolve`` returns the same Mock every time."""
    resolved = Mock()
    container = Mock(spec=DependencyContainer)
    container.resolve.return_value = resolved
    return container
@pytest.fixture
def sample_entry():
    """C2C memory entry with three text segments plus media, tool-call and tool-result metadata."""
    # NOTE(review): the text segments are separated only by spaces; inline markup for
    # the media / tool elements may have been lost here — confirm against the
    # decomposer strategies' expected content format.
    entry_metadata = {
        "_media_ids": ["media1"],
        "_tool_calls": [
            {"id": "call1", "name": "test_function", "arguments": {"arg1": "value1"}},
        ],
        "_tool_results": [
            {
                "id": "result1",
                "name": "test_result",
                "isError": False,
                "content": [{"type": "text", "text": "结果文本"}],
            },
        ],
    }
    return MemoryEntry(
        sender=ChatSender(user_id="user1", chat_type=ChatType.C2C, display_name="Test User"),
        content="这是一段文本 继续文本 更多文本 ",
        timestamp=datetime.now(),
        metadata=entry_metadata,
    )
class TestTextContentStrategy:
    """TextContentStrategy splits plain text and maps it to LLM text content."""

    @staticmethod
    def _text_info():
        """Text ContentInfo spanning [0, 10) with text '测试文本'."""
        return ContentInfo(content_type="text", start=0, end=10, text="测试文本")

    def test_extract_content(self, sample_entry):
        infos = TextContentStrategy().extract_content(sample_entry.content, sample_entry)
        assert len(infos) == 3
        assert infos[0].content_type == "text"
        assert [info.text for info in infos] == ["这是一段文本", "继续文本", "更多文本"]

    def test_to_llm_content(self):
        llm_content = TextContentStrategy().to_llm_content(self._text_info())
        assert isinstance(llm_content, LLMChatTextContent)
        assert llm_content.text == "测试文本"

    def test_to_text(self):
        assert TextContentStrategy().to_text(self._text_info()) == "测试文本"
class TestMediaContentStrategy:
    """MediaContentStrategy finds media references and maps them to image content."""

    @staticmethod
    def _media_info():
        """Media ContentInfo referring to media id 'media1'."""
        return ContentInfo(
            content_type="media", start=0, end=10, text="", metadata={"media_id": "media1"}
        )

    def test_extract_content(self, sample_entry):
        infos = MediaContentStrategy().extract_content(sample_entry.content, sample_entry)
        assert len(infos) == 1
        only = infos[0]
        assert only.content_type == "media"
        assert only.metadata["media_id"] == "media1"

    def test_to_llm_content(self):
        llm_content = MediaContentStrategy().to_llm_content(self._media_info())
        assert isinstance(llm_content, LLMChatImageContent)
        assert llm_content.media_id == "media1"

    def test_to_text(self):
        """Media contributes no plain text."""
        assert MediaContentStrategy().to_text(self._media_info()) == ""
class TestToolCallContentStrategy:
    """ToolCallContentStrategy reads tool calls from metadata and maps them to LLM content."""

    def test_extract_content(self, sample_entry):
        infos = ToolCallContentStrategy().extract_content(sample_entry.content, sample_entry)
        assert len(infos) == 1
        only = infos[0]
        assert only.content_type == "tool_call"
        assert only.metadata["id"] == "call1"
        assert only.metadata["name"] == "test_function"

    def test_no_metadata_returns_empty(self):
        """An entry without tool-call metadata yields nothing."""
        bare_entry = MemoryEntry(
            content="",
            sender=ChatSender(user_id="user1", chat_type=ChatType.C2C, display_name="Test User"),
        )
        infos = ToolCallContentStrategy().extract_content(bare_entry.content, bare_entry)
        assert len(infos) == 0

    def test_to_llm_content(self):
        call_info = ContentInfo(
            content_type="tool_call",
            start=0,
            end=10,
            text="",
            metadata={"id": "call1", "name": "test_function", "arguments": {"arg1": "value1"}},
        )
        llm_content = ToolCallContentStrategy().to_llm_content(call_info)
        assert isinstance(llm_content, LLMToolCallContent)
        assert (llm_content.id, llm_content.name) == ("call1", "test_function")

    def test_to_text(self):
        """Tool calls contribute no plain text."""
        call_info = ContentInfo(
            content_type="tool_call",
            start=0,
            end=10,
            text="",
            metadata={"id": "call1", "name": "test_function"},
        )
        assert ToolCallContentStrategy().to_text(call_info) == ""
class TestToolResultContentStrategy:
    """ToolResultContentStrategy reads tool results from metadata and maps them to LLM content."""

    def test_extract_content(self, sample_entry):
        infos = ToolResultContentStrategy().extract_content(sample_entry.content, sample_entry)
        assert len(infos) == 1
        only = infos[0]
        assert only.content_type == "tool_result"
        assert only.metadata["id"] == "result1"
        assert only.metadata["name"] == "test_result"
        assert only.metadata["isError"] is False

    def test_to_llm_content(self):
        result_info = ContentInfo(
            content_type="tool_result",
            start=0,
            end=10,
            text="",
            metadata={
                "id": "result1",
                "name": "test_result",
                "isError": False,
                "content": [{"type": "text", "text": "结果文本"}],
            },
        )
        llm_content = ToolResultContentStrategy().to_llm_content(result_info)
        assert isinstance(llm_content, LLMToolResultContent)
        assert llm_content.id == "result1"
        assert llm_content.name == "test_result"
        assert llm_content.isError is False

    def test_to_text(self):
        """Tool results contribute no plain text."""
        result_info = ContentInfo(
            content_type="tool_result",
            start=0,
            end=10,
            text="",
            metadata={"id": "result1", "name": "test_result", "isError": False},
        )
        assert ToolResultContentStrategy().to_text(result_info) == ""
class TestContentParser:
    """ContentParser combines all strategies and converts parsed content back out."""

    @staticmethod
    def _text_and_media():
        """A text info followed by a media info referring to 'media1'."""
        return [
            ContentInfo(content_type="text", start=0, end=10, text="Hello"),
            ContentInfo(content_type="media", start=11, end=20, text="", metadata={"media_id": "media1"}),
        ]

    def test_parse_content(self, sample_entry):
        infos = ContentParser().parse_content(sample_entry.content, sample_entry)
        # 3 text parts + 1 media + 1 tool call + 1 tool result = 6
        assert len(infos) == 6
        # parsed content must be ordered by start position
        for earlier, later in zip(infos, infos[1:]):
            assert earlier.start < later.start

    def test_to_llm_message(self):
        message = ContentParser().to_llm_message(self._text_and_media(), "user")[0]
        assert isinstance(message, LLMChatMessage)
        assert message.role == "user"
        assert len(message.content) == 2
        text_part, media_part = message.content[0], message.content[1]
        assert isinstance(text_part, LLMChatTextContent)
        assert isinstance(media_part, LLMChatImageContent)

    def test_to_text(self):
        """Only the text part contributes to the plain-text rendering."""
        assert ContentParser().to_text(self._text_and_media()) == "Hello"
class TestDefaultDecomposerStrategy:
    """DefaultDecomposerStrategy renders entries as strings, with an empty-message fallback."""

    def test_decompose_empty_entries(self):
        """With no entries, the configured empty message is returned alone."""
        rendered = DefaultDecomposerStrategy().decompose([], {"empty_message": "空消息"})
        assert rendered == ["空消息"] or (len(rendered) == 1 and rendered[0] == "空消息")

    def test_decompose_with_entries(self, sample_entry):
        rendered = DefaultDecomposerStrategy().decompose([sample_entry], {})
        assert len(rendered) == 1
        only = rendered[0]
        assert isinstance(only, str)
        assert "这是一段文本" in only
        # NOTE(review): an empty substring is always contained, so this check is
        # vacuous; it likely once looked for a specific markup fragment — confirm.
        assert "" in only
class TestMultiElementDecomposerStrategy:
    """MultiElementDecomposerStrategy turns entries into role-tagged LLM messages."""

    @staticmethod
    def _entry(text):
        """C2C memory entry from 'Test User' with the given content."""
        return MemoryEntry(
            sender=ChatSender(user_id="user1", chat_type=ChatType.C2C, display_name="Test User"),
            content=text,
            timestamp=datetime.now(),
        )

    def test_decompose_empty_entries(self):
        """No entries yields no messages."""
        assert len(MultiElementDecomposerStrategy().decompose([], {})) == 0

    def test_process_entry_user_content(self):
        messages = MultiElementDecomposerStrategy()._process_entry(self._entry("用户消息"))
        assert len(messages) == 1
        user_message = messages[0]
        assert user_message.role == "user"
        assert len(user_message.content) == 1
        assert isinstance(user_message.content[0], LLMChatTextContent)
        assert user_message.content[0].text == "用户消息"

    def test_process_entry_with_ai_response(self):
        """A '你回答:' line splits the entry into a user and an assistant message."""
        messages = MultiElementDecomposerStrategy()._process_entry(
            self._entry("用户消息\n你回答: AI回复")
        )
        assert len(messages) == 2
        user_message, ai_message = messages[0], messages[1]
        assert user_message.role == "user"
        assert isinstance(user_message.content[0], LLMChatTextContent)
        assert user_message.content[0].text == "用户消息"
        assert ai_message.role == "assistant"
        assert isinstance(ai_message.content[0], LLMChatTextContent)
        assert ai_message.content[0].text == "AI回复"

    def test_merge_adjacent_messages(self):
        """Consecutive same-role messages are merged in place."""
        messages = [
            LLMChatMessage(role=role, content=[LLMChatTextContent(text=text)])
            for role, text in (("user", "消息1"), ("user", "消息2"), ("assistant", "回复"))
        ]
        MultiElementDecomposerStrategy()._merge_adjacent_messages(messages)
        assert len(messages) == 2
        assert messages[0].role == "user"
        assert len(messages[0].content) == 2
        assert messages[1].role == "assistant"
================================================
FILE: tests/memory/test_memory_manager.py
================================================
from datetime import datetime
from typing import Dict, List
from unittest.mock import MagicMock
import pytest
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.im.sender import ChatSender, ChatType
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.memory.composes import MemoryComposer, MemoryDecomposer
from kirara_ai.memory.entry import MemoryEntry
from kirara_ai.memory.memory_manager import MemoryManager
from kirara_ai.memory.persistences.base import MemoryPersistence
from kirara_ai.memory.scopes import MemoryScope
# ==================== Dummy Persistence ====================
class DummyMemoryPersistence(MemoryPersistence):
    """
    Test-only persistence that keeps everything in a dict; nothing is persisted for real.
    """

    def __init__(self):
        # scope key -> entries most recently saved under that key
        self.storage: Dict[str, List[MemoryEntry]] = {}

    def load(self, scope_key: str) -> List[MemoryEntry]:
        """Return the entries saved under ``scope_key``, or a new empty list."""
        if scope_key in self.storage:
            return self.storage[scope_key]
        return []

    def save(self, scope_key: str, entries: List[MemoryEntry]) -> None:
        """Remember ``entries`` (by reference) under ``scope_key``."""
        self.storage.update({scope_key: entries})

    def stop(self):
        """No-op: there is nothing to shut down."""
        return None

    def flush(self):
        """No-op: writes are already in memory."""
        return None
# ==================== Fixtures ====================
@pytest.fixture
def container():
    """Container whose ``resolve`` always returns one shared GlobalConfig instance."""
    di_container = DependencyContainer()
    global_config = GlobalConfig()
    # Every resolve() call, whatever the key, hands back the same config object.
    di_container.resolve = MagicMock(return_value=global_config)
    di_container.register(GlobalConfig, global_config)
    return di_container
@pytest.fixture
def memory_manager(container):
    """MemoryManager wired to a fresh in-memory DummyMemoryPersistence."""
    return MemoryManager(container, persistence=DummyMemoryPersistence())
@pytest.fixture
def test_entry():
    """One C2C memory entry from ``user1`` with content ``test message``."""
    author = ChatSender(user_id="user1", chat_type=ChatType.C2C, display_name="Test User")
    return MemoryEntry(
        sender=author,
        content="test message",
        timestamp=datetime.now(),
        metadata={},
    )
@pytest.fixture
def mock_scope():
    """MemoryScope mock keyed ``test_scope`` that treats every sender as in scope."""
    scope = MagicMock(spec=MemoryScope)
    scope.get_scope_key.return_value = "test_scope"
    # Default: is_in_scope() answers True for any pair of senders.
    scope.is_in_scope.return_value = True
    return scope
# ==================== 测试用例 ====================
class TestMemoryManager:
    """Unit tests for MemoryManager backed by DummyMemoryPersistence."""

    def test_register_scope(self, memory_manager):
        """A registered scope class is stored in the scope registry under its name."""
        scope_cls = MagicMock(spec=MemoryScope)
        memory_manager.register_scope("test", scope_cls)
        registry = memory_manager.scope_registry._registry
        assert "test" in registry
        assert registry["test"] == scope_cls

    def test_register_composer(self, memory_manager):
        """A registered composer class is stored in the composer registry under its name."""
        composer_cls = MagicMock(spec=MemoryComposer)
        memory_manager.register_composer("test", composer_cls)
        registry = memory_manager.composer_registry._registry
        assert "test" in registry
        assert registry["test"] == composer_cls

    def test_register_decomposer(self, memory_manager):
        """A registered decomposer class is stored in the decomposer registry under its name."""
        decomposer_cls = MagicMock(spec=MemoryDecomposer)
        memory_manager.register_decomposer("test", decomposer_cls)
        registry = memory_manager.decomposer_registry._registry
        assert "test" in registry
        assert registry["test"] == decomposer_cls

    def test_store_and_query(self, memory_manager, test_entry, mock_scope):
        """A stored entry lands in the in-memory cache and is returned by query()."""
        memory_manager.store(mock_scope, test_entry)

        cache = memory_manager.memories
        assert "test_scope" in cache
        assert len(cache["test_scope"]) == 1
        assert cache["test_scope"][0] == test_entry

        found = memory_manager.query(mock_scope, "user1")
        assert len(found) == 1
        assert found[0] == test_entry

    def test_max_entries_limit(self, memory_manager, mock_scope, container):
        """Only the newest ``memory.max_entries`` entries are kept per scope."""
        config = container.resolve.return_value
        config.memory.max_entries = 2

        for idx in range(3):
            memory_manager.store(
                mock_scope,
                MemoryEntry(
                    sender=ChatSender(user_id=f"user{idx}", chat_type=ChatType.C2C, display_name="Test User"),
                    content=f"message {idx}",
                    timestamp=datetime.now(),
                    metadata={},
                ),
            )

        kept = memory_manager.memories["test_scope"]
        assert len(kept) == 2
        assert kept[-1].content == "message 2"

    def test_shutdown(self, memory_manager, test_entry):
        """shutdown() saves every cached scope to the persistence layer."""
        memory_manager.memories = {"scope1": [test_entry], "scope2": [test_entry]}

        memory_manager.shutdown()

        backend = memory_manager.persistence
        assert isinstance(backend, DummyMemoryPersistence)
        for key in ("scope1", "scope2"):
            assert backend.storage[key] == [test_entry]

    def test_clear_memory(self, memory_manager, mock_scope):
        """clear_memory() empties both the cache and the persisted copy of the scope."""
        memory_manager.store(
            mock_scope,
            MemoryEntry(
                sender=ChatSender(user_id="user1", chat_type=ChatType.C2C, display_name="Test User"),
                content="test",
                timestamp=datetime.now(),
                metadata={},
            ),
        )

        memory_manager.clear_memory(mock_scope, "user1")

        assert memory_manager.memories["test_scope"] == []
        backend = memory_manager.persistence
        assert isinstance(backend, DummyMemoryPersistence)
        assert backend.storage["test_scope"] == []
================================================
FILE: tests/memory/test_persistence.py
================================================
import os
import shutil
import tempfile
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from kirara_ai.im.sender import ChatSender, ChatType
from kirara_ai.memory.entry import MemoryEntry
from kirara_ai.memory.persistences import FileMemoryPersistence, RedisMemoryPersistence
# ==================== Constants ====================
# Identities used by the sender fixtures
TEST_USER_1, TEST_USER_2 = "user1", "user2"
TEST_GROUP = "group1"
TEST_DISPLAY_NAME = "john"
# Message bodies and per-entry metadata
TEST_CONTENT_1, TEST_CONTENT_2 = "test message 1", "test message 2"
TEST_METADATA_TEXT = {"type": "text"}
TEST_METADATA_IMAGE = {"type": "image"}
# Two timestamps one minute apart
TEST_TIMESTAMP_1 = datetime(2024, 1, 1, 12, 0)
TEST_TIMESTAMP_2 = datetime(2024, 1, 1, 12, 1)
TEST_SCOPE = "test_scope"
# ==================== Fixtures ====================
@pytest.fixture
def test_dir():
    """Provide a throw-away directory path that is deleted after the test."""
    workdir = tempfile.mkdtemp()
    try:
        yield workdir
    finally:
        shutil.rmtree(workdir)
@pytest.fixture
def file_persistence(test_dir):
    """FileMemoryPersistence rooted in the per-test temporary directory."""
    persistence = FileMemoryPersistence(test_dir)
    return persistence
@pytest.fixture
def chat_senders():
    """Return ``(group_sender, c2c_sender)`` built from the module constants."""
    return (
        ChatSender.from_group_chat(TEST_USER_1, TEST_GROUP, TEST_DISPLAY_NAME),
        ChatSender.from_c2c_chat(TEST_USER_2, TEST_DISPLAY_NAME),
    )
@pytest.fixture
def test_entries(chat_senders):
    """Two entries: a group text message and a C2C image-tagged message."""
    group_sender, c2c_sender = chat_senders
    rows = (
        (group_sender, TEST_CONTENT_1, TEST_TIMESTAMP_1, TEST_METADATA_TEXT),
        (c2c_sender, TEST_CONTENT_2, TEST_TIMESTAMP_2, TEST_METADATA_IMAGE),
    )
    return [
        MemoryEntry(sender=who, content=text, timestamp=when, metadata=meta)
        for who, text, when, meta in rows
    ]
@pytest.fixture
def redis_mock():
    """MagicMock standing in for the Redis client; every call on it is recorded."""
    fake_client = MagicMock()
    return fake_client
@pytest.fixture
def redis_persistence(redis_mock):
    """RedisMemoryPersistence built while ``redis.Redis`` is patched to return ``redis_mock``."""
    # The patch is only active during construction; presumably that is when the
    # client is created.
    with patch("redis.Redis", return_value=redis_mock):
        persistence = RedisMemoryPersistence(host="localhost")
    return persistence
# ==================== 测试逻辑 ====================
class TestFileMemoryPersistence:
    """Round-trip tests for FileMemoryPersistence."""

    def test_save_and_load(self, file_persistence, test_entries, test_dir):
        """Saving writes ``<scope>.json``; loading restores every compared field."""
        file_persistence.save(TEST_SCOPE, test_entries)

        expected_path = os.path.join(test_dir, f"{TEST_SCOPE}.json")
        assert os.path.exists(expected_path)

        restored = file_persistence.load(TEST_SCOPE)
        assert len(restored) == len(test_entries)
        for before, after in zip(test_entries, restored):
            for attr in ("user_id", "chat_type", "group_id"):
                assert getattr(before.sender, attr) == getattr(after.sender, attr)
            for attr in ("content", "timestamp", "metadata"):
                assert getattr(before, attr) == getattr(after, attr)

    def test_load_nonexistent(self, file_persistence):
        """Loading a scope that was never saved yields an empty list."""
        assert file_persistence.load("nonexistent") == []
class TestRedisMemoryPersistence:
    """Tests for RedisMemoryPersistence against a mocked Redis client."""

    def test_save(self, redis_persistence, redis_mock, test_entries):
        """save() issues exactly one SET on the client."""
        redis_persistence.save(TEST_SCOPE, test_entries)
        redis_mock.set.assert_called_once()

    def test_load_with_data(self, redis_persistence, redis_mock, chat_senders):
        """A JSON payload returned by GET is decoded back into MemoryEntry objects."""
        import json

        from kirara_ai.memory.persistences.codecs import MemoryJSONEncoder

        group_sender, _unused = chat_senders
        sender_payload = {
            "__type__": "ChatSender",
            "user_id": group_sender.user_id,
            "chat_type": group_sender.chat_type.value,
            "group_id": group_sender.group_id,
            "display_name": group_sender.display_name,
            "raw_metadata": {},
        }
        record = {
            "sender": sender_payload,
            "content": TEST_CONTENT_1,
            "timestamp": TEST_TIMESTAMP_1.isoformat(),
            "metadata": TEST_METADATA_TEXT,
        }
        redis_mock.get.return_value = json.dumps([record], cls=MemoryJSONEncoder)

        loaded = redis_persistence.load(TEST_SCOPE)

        assert len(loaded) == 1
        entry = loaded[0]
        decoded_sender = entry.sender
        assert decoded_sender.user_id == TEST_USER_1
        assert decoded_sender.chat_type == ChatType.GROUP
        assert decoded_sender.group_id == TEST_GROUP
        assert decoded_sender.display_name == TEST_DISPLAY_NAME
        assert entry.content == TEST_CONTENT_1
        assert entry.metadata == TEST_METADATA_TEXT

    def test_load_no_data(self, redis_persistence, redis_mock):
        """When GET returns None, load() yields an empty list."""
        redis_mock.get.return_value = None
        assert redis_persistence.load(TEST_SCOPE) == []
================================================
FILE: tests/memory/test_scope.py
================================================
import pytest
from kirara_ai.im.sender import ChatSender
from kirara_ai.memory.scopes import GlobalScope, GroupScope, MemberScope
# ==================== Constants ====================
# Two distinct users, two distinct groups and one shared display name
TEST_USER_1, TEST_USER_2 = "user1", "user2"
TEST_GROUP_1, TEST_GROUP_2 = "group1", "group2"
TEST_DISPLAY_NAME = "john"
# ==================== Fixtures ====================
@pytest.fixture
def group_sender():
    """Group-chat sender: TEST_USER_1 inside TEST_GROUP_1."""
    sender = ChatSender.from_group_chat(
        user_id=TEST_USER_1,
        group_id=TEST_GROUP_1,
        display_name=TEST_DISPLAY_NAME,
    )
    return sender
@pytest.fixture
def c2c_sender():
    """Private-chat (C2C) sender for TEST_USER_1."""
    sender = ChatSender.from_c2c_chat(user_id=TEST_USER_1, display_name=TEST_DISPLAY_NAME)
    return sender
@pytest.fixture
def different_group_sender():
    """Same user as ``group_sender`` but in TEST_GROUP_2."""
    sender = ChatSender.from_group_chat(
        user_id=TEST_USER_1,
        group_id=TEST_GROUP_2,
        display_name=TEST_DISPLAY_NAME,
    )
    return sender
@pytest.fixture
def different_user_sender():
    """TEST_USER_2 in the same group as ``group_sender``."""
    sender = ChatSender.from_group_chat(
        user_id=TEST_USER_2,
        group_id=TEST_GROUP_1,
        display_name=TEST_DISPLAY_NAME,
    )
    return sender
# ==================== 测试逻辑 ====================
class TestMemberScope:
    """MemberScope: group+user keys in groups, user-only keys in C2C chats."""

    @pytest.fixture
    def scope(self):
        """A fresh MemberScope per test."""
        return MemberScope()

    def test_get_scope_key_group(self, scope, group_sender):
        assert scope.get_scope_key(group_sender) == f"member:{TEST_GROUP_1}:{TEST_USER_1}"

    def test_get_scope_key_c2c(self, scope, c2c_sender):
        assert scope.get_scope_key(c2c_sender) == f"c2c:{TEST_USER_1}"

    def test_is_in_scope_group_same_user(self, scope, group_sender):
        twin = ChatSender.from_group_chat(TEST_USER_1, TEST_GROUP_1, TEST_DISPLAY_NAME)
        assert scope.is_in_scope(group_sender, twin)

    def test_is_in_scope_group_different_user(self, scope, group_sender, different_user_sender):
        assert not scope.is_in_scope(group_sender, different_user_sender)

    def test_is_in_scope_group_different_group(self, scope, group_sender, different_group_sender):
        assert not scope.is_in_scope(group_sender, different_group_sender)

    def test_is_in_scope_c2c_same_user(self, scope, c2c_sender):
        twin = ChatSender.from_c2c_chat(TEST_USER_1, TEST_DISPLAY_NAME)
        assert scope.is_in_scope(c2c_sender, twin)

    def test_is_in_scope_c2c_different_user(self, scope, c2c_sender):
        stranger = ChatSender.from_c2c_chat(TEST_USER_2, TEST_DISPLAY_NAME)
        assert not scope.is_in_scope(c2c_sender, stranger)

    def test_is_in_scope_different_chat_type(self, scope, group_sender, c2c_sender):
        assert not scope.is_in_scope(group_sender, c2c_sender)
class TestGroupScope:
    """GroupScope: one key per group, user-only keys in C2C chats."""

    @pytest.fixture
    def scope(self):
        """A fresh GroupScope per test."""
        return GroupScope()

    def test_get_scope_key_group(self, scope, group_sender):
        assert scope.get_scope_key(group_sender) == f"group:{TEST_GROUP_1}"

    def test_get_scope_key_c2c(self, scope, c2c_sender):
        assert scope.get_scope_key(c2c_sender) == f"c2c:{TEST_USER_1}"

    def test_is_in_scope_group_same_group(self, scope, group_sender, different_user_sender):
        assert scope.is_in_scope(group_sender, different_user_sender)

    def test_is_in_scope_group_different_group(self, scope, group_sender, different_group_sender):
        assert not scope.is_in_scope(group_sender, different_group_sender)

    def test_is_in_scope_c2c_same_user(self, scope, c2c_sender):
        twin = ChatSender.from_c2c_chat(TEST_USER_1, TEST_DISPLAY_NAME)
        assert scope.is_in_scope(c2c_sender, twin)

    def test_is_in_scope_c2c_different_user(self, scope, c2c_sender):
        stranger = ChatSender.from_c2c_chat(TEST_USER_2, TEST_DISPLAY_NAME)
        assert not scope.is_in_scope(c2c_sender, stranger)

    def test_is_in_scope_different_chat_type(self, scope, group_sender, c2c_sender):
        assert not scope.is_in_scope(group_sender, c2c_sender)
class TestGlobalScope:
    """GlobalScope maps every sender to the single ``global`` key."""

    @pytest.fixture
    def scope(self):
        """A fresh GlobalScope per test."""
        return GlobalScope()

    def test_get_scope_key(self, scope, group_sender, c2c_sender):
        for sender in (group_sender, c2c_sender):
            assert scope.get_scope_key(sender) == "global"

    def test_is_in_scope_always_true(self, scope, group_sender, c2c_sender):
        for first, second in ((group_sender, c2c_sender), (c2c_sender, group_sender)):
            assert scope.is_in_scope(first, second)
================================================
FILE: tests/resources/test_image.txt
================================================
This is a test file for MediaElement testing.
================================================
FILE: tests/system_blocks/__init__.py
================================================
# 系统块测试包
================================================
FILE: tests/system_blocks/game/__init__.py
================================================
# 游戏相关块测试包
================================================
FILE: tests/system_blocks/game/test_dice.py
================================================
import pytest
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.implementations.blocks.game.dice import DiceRoll
@pytest.fixture
def container():
    """An empty dependency container."""
    di_container = DependencyContainer()
    return di_container
@pytest.fixture
def create_message():
    """Factory fixture: build a C2C IMMessage from ``test_user`` with one text element."""

    def _build(content: str) -> IMMessage:
        author = ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User")
        return IMMessage(sender=author, message_elements=[TextMessage(content)])

    return _build
def test_dice_roll_basic(container, create_message):
    """A well-formed ``.roll 2d6`` yields a single-element IMMessage reply."""
    roller = DiceRoll()
    roller.container = container

    outputs = roller.execute(create_message(".roll 2d6"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    # Only check that a sender exists, not its concrete type.
    assert hasattr(reply, "sender")
    assert len(reply.message_elements) == 1
    assert "掷出了 2d6" in reply.content or "🎲" in reply.content
def test_dice_roll_invalid(container, create_message):
    """Input that is not a dice command yields an error-style reply."""
    roller = DiceRoll()
    roller.container = container

    outputs = roller.execute(create_message("invalid command"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    assert "Invalid dice command" in reply.content or "无效" in reply.content
def test_dice_roll_too_many(container, create_message):
    """Rolling more dice than the limit allows (``101d6``) is rejected."""
    roller = DiceRoll()
    roller.container = container

    outputs = roller.execute(create_message(".roll 101d6"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    assert "Too many dice" in reply.content or "太多" in reply.content
def test_dice_roll_with_modifier(container, create_message):
    """Dice commands with ``+`` or ``-`` modifiers still echo the dice notation.

    Only the notation is checked. How the modifier itself is rendered depends
    on the implementation.
    """
    block = DiceRoll()
    block.container = container

    # Both the addition and the subtraction case get the same structural checks.
    for command, notation in ((".roll 2d6+3", "2d6"), (".roll 1d20-2", "1d20")):
        result = block.execute(create_message(command))
        assert "response" in result
        response = result["response"]
        assert isinstance(response, IMMessage)
        assert notation in response.content
def test_dice_roll_multiple_dice(container, create_message):
    """Several dice types, each sent as its own command, echo their notation."""
    roller = DiceRoll()
    roller.container = container

    # Note: the implementation may only handle the first dice expression per
    # command, so each notation is sent separately.
    for notation in ("2d6", "1d20", "3d4"):
        outputs = roller.execute(create_message(f".roll {notation}"))
        assert "response" in outputs
        reply = outputs["response"]
        assert isinstance(reply, IMMessage)
        assert notation in reply.content
================================================
FILE: tests/system_blocks/game/test_gacha.py
================================================
import pytest
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.implementations.blocks.game.gacha import GachaSimulator
@pytest.fixture
def container():
    """An empty dependency container."""
    di_container = DependencyContainer()
    return di_container
@pytest.fixture
def create_message():
    """Factory fixture: build a C2C IMMessage from ``test_user`` with one text element."""

    def _build(content: str) -> IMMessage:
        author = ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User")
        return IMMessage(sender=author, message_elements=[TextMessage(content)])

    return _build
def test_gacha_single_pull(container, create_message):
    """A single pull (``单抽``) returns an IMMessage describing the result."""
    simulator = GachaSimulator()
    simulator.container = container

    outputs = simulator.execute(create_message("单抽"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    # Deliberately loose: look for key markers only, not the exact layout.
    assert "抽卡结果" in reply.content or "⭐" in reply.content
def test_gacha_ten_pull(container, create_message):
    """A ten-pull (``十连``) reply mentions at least one rarity tier."""
    simulator = GachaSimulator()
    simulator.container = container

    outputs = simulator.execute(create_message("十连"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    # NOTE(review): "SSR" and "SR" both contain "R", so this reduces to a check
    # for the substring "R". Consider a stricter assertion.
    assert any(tier in reply.content for tier in ("SSR", "SR", "R"))
def test_gacha_custom_rates(container, create_message):
    """With the SSR rate at 1.0 and all others at 0.0, a single pull yields SSR."""
    guaranteed_ssr = {"SSR": 1.0, "SR": 0.0, "R": 0.0}
    simulator = GachaSimulator(guaranteed_ssr)
    simulator.container = container

    outputs = simulator.execute(create_message("单抽"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    assert "SSR" in reply.content
================================================
FILE: tests/system_blocks/im/__init__.py
================================================
# IM 相关块测试包
================================================
FILE: tests/system_blocks/im/test_messages.py
================================================
import asyncio
import pytest
from kirara_ai.im.adapter import IMAdapter
from kirara_ai.im.manager import IMManager
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.implementations.blocks.im.messages import (AppendIMMessage, GetIMMessage, IMMessageToText,
SendIMMessage, TextToIMMessage)
# 创建模拟的 IMAdapter 类
class MockIMAdapter(IMAdapter):
    """No-op IMAdapter: sending, starting and stopping all do nothing."""

    async def send_message(self, message, target=None):
        """Pretend to send ``message``; always returns None."""
        return None

    def convert_to_message(self, message):
        """Return the message's ``content`` unchanged."""
        return message.content

    async def start(self):
        """No-op start."""

    async def stop(self):
        """No-op stop."""
# 创建模拟的 IMManager 类
class MockIMManager(IMManager):
    """IMManager stand-in with two no-op adapters, ``default`` and ``telegram``."""

    def __init__(self):
        # IMManager.__init__ is not called; only ``adapters`` is set up here.
        self.adapters = {name: MockIMAdapter() for name in ("default", "telegram")}

    def get_adapter(self, name):
        """Return the adapter registered as ``name``, or None if there is none."""
        return self.adapters.get(name)
@pytest.fixture
def container():
    """Container with a C2C IMMessage (``测试消息内容``) registered under IMMessage."""
    di_container = DependencyContainer()
    incoming = IMMessage(
        sender=ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User"),
        message_elements=[TextMessage("测试消息内容")],
    )
    di_container.register(IMMessage, incoming)
    return di_container
@pytest.mark.asyncio
async def test_send_im_message_async():
    """Run SendIMMessage with the default adapter, then with ``im_name="telegram"``.

    Runs under pytest-asyncio so an event loop is running. That loop is
    registered in the container as ``asyncio.AbstractEventLoop``.
    """
    container = DependencyContainer()

    # The message the block is asked to send (authored by the bot).
    send_message = IMMessage(
        sender=ChatSender.get_bot_sender(),
        message_elements=[TextMessage("回复消息")]
    )
    # The incoming message registered as the current IMMessage.
    mock_message = IMMessage(
        sender=ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User"),
        message_elements=[TextMessage("测试消息内容")]
    )

    # Inside a coroutine, the asyncio docs recommend get_running_loop() over
    # get_event_loop(). Here both return the loop pytest-asyncio is running.
    loop = asyncio.get_running_loop()

    container.register(IMAdapter, MockIMAdapter())
    container.register(IMManager, MockIMManager())
    container.register(IMMessage, mock_message)
    container.register(asyncio.AbstractEventLoop, loop)

    # Without an explicit adapter name.
    block = SendIMMessage()
    block.container = container
    result = block.execute(msg=send_message)
    assert result is not None

    # With an explicit adapter name and an explicit target.
    block = SendIMMessage(im_name="telegram")
    block.container = container
    result = block.execute(
        msg=send_message,
        target=ChatSender.from_c2c_chat(user_id="specific_user", display_name="Specific User"),
    )
    assert result is not None
def test_get_im_message(container):
    """GetIMMessage returns the registered IMMessage and its sender."""
    getter = GetIMMessage()
    getter.container = container

    outputs = getter.execute()

    for key in ("msg", "sender"):
        assert key in outputs
    assert isinstance(outputs["msg"], IMMessage)
    assert isinstance(outputs["sender"], ChatSender)
    assert outputs["msg"].content == "测试消息内容"
def test_im_message_to_text(container):
    """IMMessageToText returns the plain text of a single-text-element message."""
    source = IMMessage(
        sender=ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User"),
        message_elements=[TextMessage("Hello, World!")],
    )
    converter = IMMessageToText()
    converter.container = container

    outputs = converter.execute(msg=source)

    assert "text" in outputs
    assert outputs["text"] == "Hello, World!"
def test_text_to_im_message():
    """TextToIMMessage wraps text into an IMMessage, optionally splitting it into elements."""
    # No separator: the whole string becomes the message content.
    outputs = TextToIMMessage().execute(text="Hello, World!")
    assert "msg" in outputs
    msg = outputs["msg"]
    assert isinstance(msg, IMMessage)
    assert isinstance(msg.sender, ChatSender)
    assert msg.content == "Hello, World!"

    # split_by="\n": each line becomes its own message element.
    outputs = TextToIMMessage(split_by="\n").execute(text="Line 1\nLine 2\nLine 3")
    assert "msg" in outputs
    msg = outputs["msg"]
    assert isinstance(msg, IMMessage)
    elements = msg.message_elements
    assert len(elements) == 3
    for element, expected in zip(elements, ("Line 1", "Line 2", "Line 3")):
        assert element.text == expected
def test_append_im_message():
    """AppendIMMessage adds one element after the base message's elements."""
    base = IMMessage(
        sender=ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User"),
        message_elements=[TextMessage("基础消息")],
    )
    extra = TextMessage("追加内容")

    outputs = AppendIMMessage().execute(base_msg=base, append_msg=extra)

    assert "msg" in outputs
    combined = outputs["msg"]
    assert isinstance(combined, IMMessage)
    assert isinstance(combined.sender, ChatSender)
    assert [element.text for element in combined.message_elements] == ["基础消息", "追加内容"]
================================================
FILE: tests/system_blocks/im/test_states.py
================================================
import asyncio
import pytest
from kirara_ai.im.adapter import EditStateAdapter, IMAdapter
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.implementations.blocks.im.states import ToggleEditState
class MockAdapter(IMAdapter, EditStateAdapter):
    """Adapter that accepts editing-state updates and ignores them."""

    async def set_chat_editing_state(self, *args, **kwargs):
        """Accept any arguments and do nothing (returns None)."""
@pytest.mark.asyncio
async def test_toggle_edit_state_async():
    """ToggleEditState(is_editing=True) runs for a C2C sender and returns an empty dict.

    Runs under pytest-asyncio so an event loop is running. That loop is
    registered in the container as ``asyncio.AbstractEventLoop``.
    """
    container = DependencyContainer()
    sender = ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User")

    # Inside a coroutine, the asyncio docs recommend get_running_loop() over
    # get_event_loop(). Here both return the loop pytest-asyncio is running.
    loop = asyncio.get_running_loop()

    # NOTE(review): this registers the MockAdapter class itself, not an instance.
    # Confirm the container/block are meant to handle a class here.
    container.register(IMAdapter, MockAdapter)
    container.register(asyncio.AbstractEventLoop, loop)

    block = ToggleEditState(is_editing=True)
    block.container = container

    result = block.execute(sender=sender)

    # The block is expected to return an empty dict.
    assert result == {}
================================================
FILE: tests/system_blocks/llm/__init__.py
================================================
# LLM 相关块测试包
================================================
FILE: tests/system_blocks/llm/test_basic.py
================================================
import pytest
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format.message import LLMChatTextContent
from kirara_ai.llm.format.response import LLMChatResponse, Message
from kirara_ai.workflow.implementations.blocks.llm.basic import LLMResponseToText
@pytest.fixture
def container():
    """An empty dependency container."""
    di_container = DependencyContainer()
    return di_container
def test_llm_response_to_text():
    """LLMResponseToText extracts the assistant text, including an empty reply."""

    def build_response(text, usage):
        # Assistant response with a single text part.
        return LLMChatResponse(
            message=Message(role="assistant", content=[LLMChatTextContent(text=text)]),
            model="gpt-3.5-turbo",
            usage=usage,
        )

    reply = build_response(
        "这是 AI 的回复",
        {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    )
    block = LLMResponseToText()

    outputs = block.execute(response=reply)
    assert "text" in outputs
    assert outputs["text"] == "这是 AI 的回复"

    # An empty assistant message maps to an empty string.
    blank = build_response(
        "",
        {"prompt_tokens": 5, "completion_tokens": 0, "total_tokens": 5},
    )
    outputs = block.execute(response=blank)
    assert outputs["text"] == ""
================================================
FILE: tests/system_blocks/llm/test_chat.py
================================================
import asyncio
import threading
from unittest.mock import MagicMock, patch
import pytest
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format.message import LLMChatMessage, LLMChatTextContent, LLMToolResultContent
from kirara_ai.llm.format.response import LLMChatResponse, Message, Usage
from kirara_ai.llm.format.tool import CallableWrapper, Function, TextContent, Tool, ToolCall, ToolInputSchema
from kirara_ai.llm.llm_manager import LLMManager
from kirara_ai.workflow.core.execution.executor import WorkflowExecutor
from kirara_ai.workflow.implementations.blocks.llm.chat import (ChatCompletion, ChatCompletionWithTools,
ChatMessageConstructor, ChatResponseConverter)
def get_tools() -> list[Tool]:
    """Return one ``get_weather`` tool whose invocation always reports 晴天,温度25°C."""

    async def mock_tool_invoke(tool_call: ToolCall) -> LLMToolResultContent:
        # Echo the call's id and function name with a canned weather result.
        return LLMToolResultContent(
            id=tool_call.id,
            name=tool_call.function.name,
            content=[TextContent(text="晴天,温度25°C")],
        )

    location_param = {
        "type": "string",
        "description": "The city and state, e.g. San Francisco, CA",
    }
    schema = ToolInputSchema(
        type="object",
        properties={"location": location_param},
        required=["location"],
    )
    weather_tool = Tool(
        type="function",
        name="get_weather",
        description="Get the current weather in a given location",
        parameters=schema,
        invokeFunc=CallableWrapper(mock_tool_invoke),
    )
    return [weather_tool]
def get_llm_tool_calls() -> list[ToolCall]:
    """Return the tool call the fake LLM emits: ``get_weather`` for San Francisco."""
    weather_call = ToolCall(
        id="call_e33147bcb72525ed",
        function=Function(
            name="get_weather",
            arguments={"location": "San Francisco, CA"},
        ),
    )
    return [weather_call]
# 创建模拟的 LLM 类
class MockLLM:
    """Fake LLM whose ``chat`` always returns the same canned assistant reply."""

    def chat(self, request):
        """Ignore ``request`` and return a fixed LLMChatResponse."""
        reply = Message(role="assistant", content=[LLMChatTextContent(text="这是 AI 的回复")])
        token_usage = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30)
        return LLMChatResponse(message=reply, model="gpt-3.5-turbo", usage=token_usage)
class MockLLMWithToolCalls:
    """Fake LLM that requests a tool call on its first ``chat`` and answers in text afterwards."""

    def __init__(self, with_tool_calls=True):
        self.with_tool_calls = with_tool_calls  # False: never emit tool calls
        self.call_count = 0  # number of chat() invocations so far

    def chat(self, request):
        """Return a tool-call response on the first call (if enabled), else the final answer."""
        self.call_count += 1
        first_turn_with_tools = self.with_tool_calls and self.call_count == 1

        if not first_turn_with_tools:
            # Follow-up turn, or tool calls disabled: final answer only.
            return LLMChatResponse(
                message=Message(
                    role="assistant",
                    content=[LLMChatTextContent(text="旧金山今天是晴天,温度25°C")],
                ),
                model="gpt-3.5-turbo",
                usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
            )

        # First turn: ask to invoke the weather tool.
        return LLMChatResponse(
            message=Message(
                role="assistant",
                content=[LLMChatTextContent(text="我需要查询天气")],
                tool_calls=get_llm_tool_calls(),
            ),
            model="gpt-3.5-turbo",
            usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )
# 创建模拟的 LLMManager 类
class MockLLMManager(LLMManager):
    """LLMManager stand-in that resolves every lookup to one shared MockLLM."""

    def __init__(self):
        # LLMManager.__init__ is not called; only the fake LLM is set up.
        self.mock_llm = MockLLM()

    def get_llm_id_by_ability(self, ability):
        """Report ``gpt-3.5-turbo`` for any ``ability``."""
        return "gpt-3.5-turbo"

    def get_llm(self, model_id):
        """Return the shared fake LLM for any ``model_id``."""
        return self.mock_llm
class MockLLMManagerWithToolCalls(LLMManager):
    """LLMManager stand-in backed by one shared MockLLMWithToolCalls."""

    def __init__(self, with_tool_calls=True):
        # LLMManager.__init__ is not called; only the fake LLM is set up.
        self.mock_llm = MockLLMWithToolCalls(with_tool_calls=with_tool_calls)

    def get_llm_id_by_ability(self, ability):
        """Report ``gpt-3.5-turbo`` for any ``ability``."""
        return "gpt-3.5-turbo"

    def get_llm(self, model_id):
        """Return the shared fake LLM for any ``model_id``."""
        return self.mock_llm
@pytest.fixture
def container():
    """Container wired with a mock LLM manager, a mock executor and a background event loop.

    The event loop runs forever on a daemon thread and is registered as
    ``asyncio.AbstractEventLoop``. Presumably this lets blocks submit coroutines
    to it while the test thread waits.
    At teardown the loop is stopped, its thread joined and the loop closed, so
    loops and threads do not pile up across tests.
    """
    container = DependencyContainer()

    mock_llm_manager = MockLLMManager()
    mock_executor = MagicMock(spec=WorkflowExecutor)

    def start_background_loop(loop):
        # Make the loop current in this thread, then serve it until stop() is called.
        asyncio.set_event_loop(loop)
        loop.run_forever()

    new_loop = asyncio.new_event_loop()
    loop_thread = threading.Thread(target=start_background_loop, args=(new_loop,), daemon=True)
    loop_thread.start()

    container.register(LLMManager, mock_llm_manager)
    container.register(WorkflowExecutor, mock_executor)
    container.register(asyncio.AbstractEventLoop, new_loop)

    yield container

    # Teardown: stop() must be scheduled on the loop's own thread. Once
    # run_forever() has returned and the thread exited, the loop can be closed.
    new_loop.call_soon_threadsafe(new_loop.stop)
    loop_thread.join(timeout=5)
    if not loop_thread.is_alive():
        new_loop.close()
@patch('kirara_ai.workflow.implementations.blocks.llm.chat.ChatMessageConstructor.execute')
def test_chat_message_constructor(mock_execute):
    """Check the shape of ChatMessageConstructor's ``llm_msg`` output.

    NOTE(review): ``execute`` itself is patched, so this test only sees the
    canned return value set below, not the constructor's real logic.
    TODO: drop the patch and assert on real output.
    """
    canned = [Message(role="user", content=[LLMChatTextContent(text="你好,AI!")])]
    mock_execute.return_value = {"llm_msg": canned}

    block = ChatMessageConstructor()
    block.container = MagicMock(spec=DependencyContainer)

    user_msg = IMMessage(
        sender=ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User"),
        message_elements=[TextMessage("你好,AI!")],
    )
    result = block.execute(
        user_msg=user_msg,
        memory_content="",
        system_prompt_format="",
        user_prompt_format="",
    )

    assert "llm_msg" in result
    messages = result["llm_msg"]
    assert isinstance(messages, list)
    assert len(messages) > 0
    first = messages[0]
    assert first.role == "user"
    assert first.content[0].text == "你好,AI!"
def test_chat_completion(container):
    """ChatCompletion with default settings returns the mocked LLM reply."""
    prompt = [
        Message(role="system", content=[LLMChatTextContent(text="你是一个助手")]),
        Message(role="user", content=[LLMChatTextContent(text="你好,AI!")]),
    ]

    completion = ChatCompletion()
    completion.container = container

    outputs = completion.execute(prompt=prompt)

    assert "resp" in outputs
    response = outputs["resp"]
    assert isinstance(response, LLMChatResponse)
    assert response.message.content[0].text == "这是 AI 的回复"
def test_chat_response_converter():
    """ChatResponseConverter turns an LLMChatResponse into an IMMessage."""
    reply_text = "这是 AI 的回复"
    llm_reply = LLMChatResponse(
        message=Message(role="assistant", content=[LLMChatTextContent(text=reply_text)]),
        model="gpt-3.5-turbo",
        usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
    )

    converter = ChatResponseConverter()

    bot_sender = ChatSender.from_c2c_chat(user_id="bot", display_name="Bot")

    def fake_resolve(key):
        # Only the bot-sender lookup is answered; every other key resolves to None.
        return bot_sender if key == ChatSender.get_bot_sender else None

    fake_container = MagicMock(spec=DependencyContainer)
    fake_container.resolve = MagicMock(side_effect=fake_resolve)
    converter.container = fake_container

    outputs = converter.execute(resp=llm_reply)

    assert "msg" in outputs
    assert isinstance(outputs["msg"], IMMessage)
    assert reply_text in outputs["msg"].content
def test_chat_completion_with_tools(container):
    """ChatCompletionWithTools performs a tool-call round before the final reply."""
    container.register(LLMManager, MockLLMManagerWithToolCalls(with_tool_calls=True))

    conversation = [
        LLMChatMessage(role="system", content=[LLMChatTextContent(text="你是一个助手")]),
        LLMChatMessage(role="user", content=[LLMChatTextContent(text="旧金山今天天气如何?")]),
    ]
    available_tools = get_tools()

    block = ChatCompletionWithTools(model_name="gpt-3.5-turbo", max_iterations=3)
    block.container = container

    outputs = block.execute(msg=conversation, tools=available_tools)

    for key in ("resp", "iteration_msgs"):
        assert key in outputs
    final_resp = outputs["resp"]
    history = outputs["iteration_msgs"]
    assert isinstance(final_resp, LLMChatResponse)
    assert isinstance(history, list)
    # At least the tool-call message plus the final answer.
    assert len(history) >= 2

    first_step = history[0]
    assert first_step.tool_calls is not None
    assert first_step.tool_calls[0].function.name == "get_weather"

    assert "旧金山今天是晴天" in final_resp.message.content[0].text
def test_chat_completion_with_tools_no_tool_calls(container):
    """Without tool calls the reply comes back directly and the iteration log is empty."""
    # Swap in an LLM manager whose model never requests a tool.
    container.register(LLMManager, MockLLMManagerWithToolCalls(with_tool_calls=False))

    conversation = [
        LLMChatMessage(role="system", content=[LLMChatTextContent(text="你是一个助手")]),
        LLMChatMessage(role="user", content=[LLMChatTextContent(text="你好,AI!")]),
    ]
    available_tools = get_tools()

    block = ChatCompletionWithTools(model_name="gpt-3.5-turbo", max_iterations=3)
    block.container = container

    outputs = block.execute(msg=conversation, tools=available_tools)

    assert "resp" in outputs
    assert "iteration_msgs" in outputs
    assert isinstance(outputs["resp"], LLMChatResponse)
    history = outputs["iteration_msgs"]
    assert isinstance(history, list)
    assert history == []
================================================
FILE: tests/system_blocks/llm/test_image.py
================================================
import asyncio
import base64
from unittest.mock import MagicMock, patch
import pytest
from kirara_ai.im.message import ImageMessage
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.implementations.blocks.llm.image import SimpleStableDiffusionWebUI
@pytest.fixture
def container():
    """Return ``(container, loop)``: a container holding a spec'd mock event loop."""
    deps = DependencyContainer()
    fake_loop = MagicMock(spec=asyncio.AbstractEventLoop)
    deps.register(asyncio.AbstractEventLoop, fake_loop)
    return deps, fake_loop
def test_simple_stable_diffusion_webui(container):
    """SimpleStableDiffusionWebUI turns the WebUI's base64 image into an ImageMessage.

    Both the default-parameter run and the custom-parameter run happen inside
    the ``requests.post`` patch, so no real HTTP request is attempted. Both
    results are checked for type.
    """
    container, _mock_loop = container

    # Fake WebUI payload carrying one base64-encoded "image".
    encoded_image = base64.b64encode(b"mock_image_data").decode("utf-8")
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.json.return_value = {"images": [encoded_image]}

    with patch('requests.post', return_value=fake_response):
        # Default parameters.
        block = SimpleStableDiffusionWebUI(api_url="http://localhost:7860")
        block.container = container
        result = block.execute(prompt="一只可爱的猫", negative_prompt="")
        assert "image" in result
        assert isinstance(result["image"], ImageMessage)

        # Custom parameters.
        block = SimpleStableDiffusionWebUI(
            api_url="http://localhost:7860",
            steps=30,
            sampler_index="DPM++ 2M Karras",
            cfg_scale=8.0,
            width=768,
            height=512,
        )
        block.container = container
        result = block.execute(prompt="一只可爱的猫", negative_prompt="低质量, 模糊")
        assert "image" in result
        assert isinstance(result["image"], ImageMessage)
================================================
FILE: tests/system_blocks/memory/__init__.py
================================================
# 记忆相关块测试包
================================================
FILE: tests/system_blocks/memory/test_chat_memory.py
================================================
from unittest.mock import MagicMock
import pytest
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format.message import LLMChatTextContent
from kirara_ai.llm.format.response import LLMChatResponse, Message, Usage
from kirara_ai.memory.memory_manager import MemoryManager
from kirara_ai.memory.registry import ComposerRegistry, DecomposerRegistry, ScopeRegistry
from kirara_ai.workflow.implementations.blocks.memory.chat_memory import ChatMemoryQuery, ChatMemoryStore
# 创建模拟的 MemoryManager 类
class MockMemoryManager(MemoryManager):
    """MemoryManager stub: skips the real constructor, returns canned history, ignores writes."""

    def __init__(self):
        # MemoryManager.__init__ is deliberately not called; only ``config`` is set up.
        self.config = MagicMock(default_scope="member")

    def query(self, *args, **kwargs):
        """Return a fixed three-line conversation regardless of arguments."""
        return "系统:你是一个助手\n用户:你好\n助手:你好!有什么可以帮助你的吗?"

    def store(self, *args, **kwargs):
        """Accept and discard anything that would be stored."""
        return None

    def clear(self, *args, **kwargs):
        """Pretend to clear memory; does nothing."""
        return None
# 创建模拟的 Scope 类
class MockScope:
    """Minimal scope stand-in that only remembers the name it was created with."""

    def __init__(self, name):
        scope_name = name
        self.name = scope_name
# 创建模拟的 ScopeRegistry 类
class MockScopeRegistry(ScopeRegistry):
    """ScopeRegistry that fabricates a new MockScope for any requested name."""

    def get_scope(self, name):
        scope = MockScope(name)
        return scope
# 创建模拟的 Composer 类
class MockComposer:
    """Composer stub that always yields a single placeholder memory entry."""

    def compose(self, sender, messages):
        entries = ["memory_entry"]
        return entries
# 创建模拟的 ComposerRegistry 类
class MockComposerRegistry(ComposerRegistry):
    """ComposerRegistry that returns a fresh MockComposer whatever the name."""

    def get_composer(self, name):
        composer = MockComposer()
        return composer
# 创建模拟的 Decomposer 类
class MockDecomposer:
    """Decomposer stub that ignores its entries and returns a fixed transcript."""

    def decompose(self, memory_entries):
        transcript = "系统:你是一个助手\n用户:你好\n助手:你好!有什么可以帮助你的吗?"
        return transcript
# 创建模拟的 DecomposerRegistry 类
class MockDecomposerRegistry(DecomposerRegistry):
    """DecomposerRegistry that hands out a fresh MockDecomposer for every name."""

    def get_decomposer(self, name):
        decomposer = MockDecomposer()
        return decomposer
@pytest.mark.asyncio
async def test_chat_memory_query_async():
    """ChatMemoryQuery returns the decomposed memory as a string."""
    deps = DependencyContainer()
    sender = ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User")

    deps.register(MemoryManager, MockMemoryManager())
    deps.register(ScopeRegistry, MockScopeRegistry(deps))
    deps.register(DecomposerRegistry, MockDecomposerRegistry(deps))

    query_block = ChatMemoryQuery(scope_type="member")
    query_block.container = deps

    outputs = query_block.execute(chat_sender=sender)

    assert "memory_content" in outputs
    memory = outputs["memory_content"]
    assert isinstance(memory, str)
    assert "你是一个助手" in memory
@pytest.mark.asyncio
async def test_chat_memory_store_async():
    """ChatMemoryStore accepts a user message, with or without an LLM reply, and outputs nothing."""
    deps = DependencyContainer()

    author = ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User")
    user_msg = IMMessage(sender=author, message_elements=[TextMessage("新消息")])

    assistant_reply = Message(
        content=[LLMChatTextContent(text="这是 AI 的回复")],
        role="assistant",
    )
    llm_resp = LLMChatResponse(
        message=assistant_reply,
        model="gpt-3.5-turbo",
        usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
    )

    deps.register(MemoryManager, MockMemoryManager())
    deps.register(ScopeRegistry, MockScopeRegistry(deps))
    deps.register(ComposerRegistry, MockComposerRegistry(deps))

    store = ChatMemoryStore(scope_type="member")
    store.container = deps

    # User message only.
    assert store.execute(user_msg=user_msg) == {}
    # User message together with the model's reply.
    assert store.execute(user_msg=user_msg, llm_resp=llm_resp) == {}
================================================
FILE: tests/system_blocks/memory/test_clear_memory.py
================================================
from unittest.mock import MagicMock
import pytest
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.im.message import IMMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.memory.memory_manager import MemoryManager
from kirara_ai.memory.registry import ComposerRegistry, DecomposerRegistry, ScopeRegistry
from kirara_ai.workflow.implementations.blocks.memory.clear_memory import ClearMemory
# 创建模拟的 MemoryManager 类
class MockMemoryManager(MemoryManager):
    """MemoryManager built through the real constructor, with clearing stubbed out."""

    def __init__(self, container: DependencyContainer):
        super().__init__(container)
        # Replace the real config with a mock that only pins the default scope.
        self.config = MagicMock(default_scope="member")
        self.memories = {}

    def clear_memory(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
        return None
# 创建模拟的 Scope 类
class MockScope:
    """Scope stub that keys memory by the sender's ``user_id``."""

    def __init__(self, name):
        scope_name = name
        self.name = scope_name

    def get_scope_key(self, sender: ChatSender):
        key = sender.user_id
        return key
# 创建模拟的 ScopeRegistry 类
class MockScopeRegistry(ScopeRegistry):
    """ScopeRegistry that builds a new MockScope for any requested name."""

    def get_scope(self, name):
        scope = MockScope(name)
        return scope
@pytest.mark.asyncio
async def test_clear_memory_async():
    """ClearMemory answers with a confirmation IMMessage for any sender."""
    deps = DependencyContainer()
    default_sender = ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User")

    deps.register(DependencyContainer, deps)
    deps.register(GlobalConfig, GlobalConfig())
    deps.register(MemoryManager, MockMemoryManager(deps))
    deps.register(ScopeRegistry, MockScopeRegistry(deps))
    deps.register(ComposerRegistry, MagicMock(spec=ComposerRegistry))
    deps.register(DecomposerRegistry, MagicMock(spec=DecomposerRegistry))

    clear_block = ClearMemory(scope_type="member")
    clear_block.container = deps

    # Default sender: the reply must mention that memory was cleared.
    first = clear_block.execute(chat_sender=default_sender)
    assert "response" in first
    reply = first["response"]
    assert isinstance(reply, IMMessage)
    assert "已清空" in reply.content or "清除" in reply.content

    # A different sender still gets a message back.
    other_sender = ChatSender.from_c2c_chat(user_id="custom_user", display_name="Custom User")
    second = clear_block.execute(chat_sender=other_sender)
    assert "response" in second
    assert isinstance(second["response"], IMMessage)
================================================
FILE: tests/system_blocks/system/__init__.py
================================================
# 系统基础块测试包
================================================
FILE: tests/system_blocks/system/test_basic.py
================================================
import re
from datetime import datetime
from unittest.mock import patch
import pytest
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.implementations.blocks.system.basic import (CurrentTimeBlock, TextBlock, TextConcatBlock,
TextExtractByRegexBlock, TextReplaceBlock)
@pytest.fixture
def container():
    """Provide a fresh, real DependencyContainer."""
    fresh = DependencyContainer()
    return fresh
def test_text_block():
    """TextBlock emits its configured text unchanged."""
    outputs = TextBlock(text="测试文本").execute()

    assert "text" in outputs
    assert outputs["text"] == "测试文本"
def test_text_concat_block():
    """TextConcatBlock joins text1 and text2 in order."""
    joiner = TextConcatBlock()
    outputs = joiner.execute(text1="Hello, ", text2="World!")

    assert "text" in outputs
    assert outputs["text"] == "Hello, World!"
def test_text_replace_block():
    """TextReplaceBlock substitutes its placeholder, stringifying non-str values."""
    replacer = TextReplaceBlock(variable="{name}")

    outputs = replacer.execute(text="Hello, {name}!", new_text="ChatGPT")
    assert "text" in outputs
    assert outputs["text"] == "Hello, ChatGPT!"

    # A non-string replacement value is rendered as text.
    outputs = replacer.execute(text="Count: {name}", new_text=42)
    assert outputs["text"] == "Count: 42"
def test_text_extract_by_regex_block():
    """TextExtractByRegexBlock returns the first capture group, or '' when nothing matches."""
    extractor = TextExtractByRegexBlock(regex=r"用户名:(\w+)")

    hit = extractor.execute(text="用户信息 - 用户名:testuser, 年龄:25")
    assert "text" in hit
    assert hit["text"] == "testuser"

    miss = extractor.execute(text="没有用户名信息")
    assert miss["text"] == ""
def test_current_time_block():
    """CurrentTimeBlock formats the current time as 'YYYY-MM-DD HH:MM:SS'."""
    block = CurrentTimeBlock()

    # Pin "now" so the exact output can be asserted.
    frozen = datetime(2023, 1, 1, 12, 0, 0)
    with patch('kirara_ai.workflow.implementations.blocks.system.basic.datetime') as fake_datetime:
        fake_datetime.now.return_value = frozen
        pinned = block.execute()
    assert "time" in pinned
    assert pinned["time"] == "2023-01-01 12:00:00"

    # Real clock: only the format can be checked.
    live = block.execute()
    assert "time" in live
    assert re.match(r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}", live["time"]) is not None
================================================
FILE: tests/system_blocks/system/test_help.py
================================================
from unittest.mock import MagicMock
import pytest
from kirara_ai.im.message import IMMessage
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.dispatch import CombinedDispatchRule, DispatchRuleRegistry, RuleGroup, SimpleDispatchRule
from kirara_ai.workflow.implementations.blocks.system.help import GenerateHelp
@pytest.fixture
def container():
    """Return ``(container, registry)`` with a spec'd mock DispatchRuleRegistry registered."""
    rule_registry = MagicMock(spec=DispatchRuleRegistry)
    deps = DependencyContainer()
    deps.register(DispatchRuleRegistry, rule_registry)
    return deps, rule_registry
def create_mock_rule(
    rule_id: str, name: str, description: str, workflow_id: str, rule_groups: list
) -> CombinedDispatchRule:
    """Build an enabled CombinedDispatchRule with priority 5 and empty metadata."""
    fixed_settings = {"enabled": True, "priority": 5, "metadata": {}}
    return CombinedDispatchRule(
        rule_id=rule_id,
        name=name,
        description=description,
        workflow_id=workflow_id,
        rule_groups=rule_groups,
        **fixed_settings,
    )
def test_generate_help_basic(container):
    """Help text lists every rule with its triggers, descriptions and logic words."""
    container, registry = container

    def prefix_and_keywords(operator, prefix, keywords):
        # One rule group holding a prefix trigger followed by a keyword trigger.
        return RuleGroup(
            operator=operator,
            rules=[
                SimpleDispatchRule(type="prefix", config={"prefix": prefix}),
                SimpleDispatchRule(type="keyword", config={"keywords": keywords}),
            ],
        )

    help_rule = create_mock_rule(
        rule_id="help_command",
        name="帮助命令",
        description="显示帮助信息",
        workflow_id="system:help",
        rule_groups=[prefix_and_keywords("and", "/help", ["帮助", "help"])],
    )
    chat_rule = create_mock_rule(
        rule_id="chat_command",
        name="聊天命令",
        description="开始聊天",
        workflow_id="chat:chat",
        rule_groups=[prefix_and_keywords("or", "/chat", ["聊天", "对话"])],
    )
    registry.get_active_rules.return_value = [help_rule, chat_rule]

    generator = GenerateHelp()
    generator.container = container
    outputs = generator.execute()

    assert "response" in outputs
    response = outputs["response"]
    assert isinstance(response, IMMessage)
    help_text = response.content

    expected_fragments = (
        "机器人命令帮助",
        "SYSTEM",
        "CHAT",
        "帮助命令",
        "聊天命令",
        "/help",
        "/chat",
        "显示帮助信息",
        "开始聊天",
        "且",  # wording for the "and" group
        "或",  # wording for the "or" group
    )
    for fragment in expected_fragments:
        assert fragment in help_text
def test_generate_help_empty(container):
    """With no active rules only the header is produced: no categories, no commands."""
    container, registry = container
    registry.get_active_rules.return_value = []

    generator = GenerateHelp()
    generator.container = container
    outputs = generator.execute()

    assert "response" in outputs
    response = outputs["response"]
    assert isinstance(response, IMMessage)

    help_text = response.content
    assert "机器人命令帮助" in help_text
    for marker in ("📑", "🔸"):
        assert marker not in help_text
def test_generate_help_no_description(container):
    """A rule with an empty description is listed without a description line."""
    container, registry = container

    undocumented = create_mock_rule(
        rule_id="test_command",
        name="测试命令",
        description="",
        workflow_id="test:test",
        rule_groups=[
            RuleGroup(
                operator="or",
                rules=[SimpleDispatchRule(type="prefix", config={"prefix": "/test"})],
            )
        ],
    )
    registry.get_active_rules.return_value = [undocumented]

    generator = GenerateHelp()
    generator.container = container
    outputs = generator.execute()

    assert "response" in outputs
    response = outputs["response"]
    assert isinstance(response, IMMessage)

    help_text = response.content
    assert "测试命令" in help_text
    assert "/test" in help_text
    assert "说明:" not in help_text
def test_generate_help_complex_rules(container):
    """A rule with several groups renders each condition plus the joining words."""
    container, registry = container

    or_group = RuleGroup(
        operator="or",
        rules=[
            SimpleDispatchRule(type="prefix", config={"prefix": "/complex"}),
            SimpleDispatchRule(type="keyword", config={"keywords": ["复杂", "高级"]}),
        ],
    )
    and_group = RuleGroup(
        operator="and",
        rules=[
            SimpleDispatchRule(type="regex", config={"pattern": ".*test.*"}),
            SimpleDispatchRule(type="keyword", config={"keywords": ["测试"]}),
        ],
    )
    complex_rule = create_mock_rule(
        rule_id="complex_command",
        name="复杂命令",
        description="这是一个复杂的命令",
        workflow_id="test:complex",
        rule_groups=[or_group, and_group],
    )
    registry.get_active_rules.return_value = [complex_rule]

    generator = GenerateHelp()
    generator.container = container
    help_text = generator.execute()["response"].content

    for fragment in (
        "复杂命令",
        "这是一个复杂的命令",
        "输入以 /complex 开头",
        "输入包含 复杂 或 高级",
        "输入匹配正则 .*test.*",
        "输入包含 测试",
        "并且",
        "或",
    ):
        assert fragment in help_text
================================================
FILE: tests/test_config_loader.py
================================================
import unittest
from unittest.mock import mock_open, patch
from pydantic import BaseModel
from kirara_ai.config.config_loader import ConfigLoader
class TestConfig(BaseModel):
    """Configuration model used as sample data for the ConfigLoader tests."""

    # The class name starts with "Test", so tell pytest not to collect it as a test class.
    # (Previously this line preceded the docstring, which turned the docstring into a
    # dead string expression and left the class undocumented.)
    __test__ = False

    name: str
    value: int
class TestConfigLoader(unittest.TestCase):
    """Tests for ConfigLoader.save_config_with_backup with all filesystem/YAML I/O mocked."""

    def setUp(self):
        self.test_config = TestConfig(name="test", value=123)
        self.test_config_path = "test_config.yaml"

    def _save_with_mocks(self, file_exists):
        """Run save_config_with_backup with os/shutil/open/yaml.dump patched.

        ``file_exists`` is what the patched ``os.path.exists`` reports. Returns
        ``(mock_exists, mock_copy, mock_file, mock_dump)`` so the caller can
        assert on the recorded calls.
        """
        mock_file = mock_open()
        with patch("os.path.exists", return_value=file_exists) as mock_exists, \
                patch("shutil.copy2") as mock_copy, \
                patch("builtins.open", mock_file), \
                patch.object(ConfigLoader.yaml, "dump") as mock_dump:
            ConfigLoader.save_config_with_backup(self.test_config_path, self.test_config)
        return mock_exists, mock_copy, mock_file, mock_dump

    def _assert_written(self, mock_file, mock_dump):
        """Check the config file was opened for UTF-8 writing and dumped exactly once."""
        mock_file.assert_called_with(self.test_config_path, "w", encoding="utf-8")
        # mock_file() hands back the same handle that save_config_with_backup wrote to.
        mock_dump.assert_called_once_with(self.test_config.model_dump(), mock_file())

    def test_save_config_with_backup(self):
        """An existing config file is copied to ``<path>.bak`` before being overwritten."""
        mock_exists, mock_copy, mock_file, mock_dump = self._save_with_mocks(True)

        mock_exists.assert_called_once_with(self.test_config_path)
        mock_copy.assert_called_once_with(
            self.test_config_path, f"{self.test_config_path}.bak"
        )
        self._assert_written(mock_file, mock_dump)

    def test_save_config_without_backup(self):
        """No backup is made when the config file does not exist yet."""
        mock_exists, mock_copy, mock_file, mock_dump = self._save_with_mocks(False)

        mock_exists.assert_called_once_with(self.test_config_path)
        mock_copy.assert_not_called()
        self._assert_written(mock_file, mock_dump)
if __name__ == "__main__":
    # Allow running this module directly with the standard-library unittest runner.
    unittest.main()
================================================
FILE: tests/test_game_blocks.py
================================================
from unittest.mock import MagicMock
import pytest
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.implementations.blocks.game.dice import DiceRoll
from kirara_ai.workflow.implementations.blocks.game.gacha import GachaSimulator
@pytest.fixture
def container():
    """A spec'd MagicMock standing in for the DependencyContainer."""
    fake_container = MagicMock(spec=DependencyContainer)
    return fake_container
@pytest.fixture
def create_message():
    """Factory fixture: build an IMMessage from a test user holding one text element."""
    def _factory(content: str) -> IMMessage:
        author = ChatSender.from_c2c_chat(user_id="test_user", display_name="Test User")
        return IMMessage(sender=author, message_elements=[TextMessage(content)])

    return _factory
def test_dice_roll_basic(container, create_message):
    """'.roll 2d6' gets a one-element reply from the bot describing the roll."""
    dice = DiceRoll()
    dice.container = container

    outputs = dice.execute(create_message(".roll 2d6"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    assert reply.sender == ChatSender.get_bot_sender()
    assert len(reply.message_elements) == 1
    assert "掷出了 2d6" in reply.content
def test_dice_roll_invalid(container, create_message):
    """Text that is not a dice command yields an error reply."""
    dice = DiceRoll()
    dice.container = container

    outputs = dice.execute(create_message("invalid command"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    assert "Invalid dice command" in reply.content
def test_dice_roll_too_many(container, create_message):
    """Asking for 101 dice is rejected."""
    dice = DiceRoll()
    dice.container = container

    outputs = dice.execute(create_message(".roll 101d6"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    assert "Too many dice" in reply.content
def test_gacha_single_pull(container, create_message):
    """A single pull reports exactly one result (no '、' separator)."""
    gacha = GachaSimulator()
    gacha.container = container

    outputs = gacha.execute(create_message("单抽"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    assert "抽卡结果" in reply.content
    # One result means the separator never appears.
    assert "、" not in reply.content
def test_gacha_ten_pull(container, create_message):
    """A ten-pull reports ten results separated by '、'."""
    gacha = GachaSimulator()
    gacha.container = container

    outputs = gacha.execute(create_message("十连"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    # Ten results are joined by nine separators.
    assert reply.content.count("、") == 9
def test_gacha_custom_rates(container, create_message):
    """With a 100% SSR rate a single pull is always SSR."""
    only_ssr = {"SSR": 1.0, "SR": 0.0, "R": 0.0}
    gacha = GachaSimulator(only_ssr)
    gacha.container = container

    outputs = gacha.execute(create_message("单抽"))

    assert "response" in outputs
    reply = outputs["response"]
    assert isinstance(reply, IMMessage)
    assert "SSR" in reply.content
================================================
FILE: tests/test_mcp_server.py
================================================
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from mcp import types
from kirara_ai.config.global_config import MCPServerConfig
from kirara_ai.mcp_module.models import MCPConnectionState
from kirara_ai.mcp_module.server import MCPServer
# 测试配置
@pytest.fixture
def stdio_config():
    """Config for an MCP server reached over stdio by running ``python -m mcp.server``."""
    settings = {
        "id": "test-stdio",
        "connection_type": "stdio",
        "command": "python",
        "args": ["-m", "mcp.server"],
    }
    return MCPServerConfig(**settings)
@pytest.fixture
def sse_config():
    """Config for an MCP server reached over SSE on localhost."""
    settings = {
        "id": "test-sse",
        "connection_type": "sse",
        "url": "http://localhost:8000/sse",
    }
    return MCPServerConfig(**settings)
@pytest.fixture
def invalid_config():
    """Config whose connection type is not supported."""
    settings = {"id": "test-invalid", "connection_type": "invalid"}
    return MCPServerConfig(**settings)
# 模拟 MCP 客户端会话
class MockClientSession:
    """Async stand-in for the MCP ClientSession: each RPC is an AsyncMock with a canned result."""

    def __init__(self):
        self.initialize = AsyncMock()
        self.list_tools = AsyncMock(return_value=types.ListToolsResult(tools=[]))
        self.call_tool = AsyncMock(
            return_value=types.CallToolResult(
                content=[types.TextContent(text="114514", type="text")],
                isError=False,
            )
        )
        self.complete = AsyncMock(return_value={})
        self.get_prompt = AsyncMock(return_value="测试提示词")
        # Listing endpoints each return their own fresh empty list.
        for attr in ("list_prompts", "list_resources", "list_resource_templates"):
            setattr(self, attr, AsyncMock(return_value=[]))
        self.read_resource = AsyncMock(return_value="资源内容")
        # Subscription endpoints each return their own fresh empty dict.
        for attr in ("subscribe_resource", "unsubscribe_resource"):
            setattr(self, attr, AsyncMock(return_value={}))

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        return None
# 测试基本初始化
def test_init(stdio_config):
    """A freshly built MCPServer is idle: no session, no lifecycle task, no events set."""
    srv = MCPServer(stdio_config)

    assert srv.server_config == stdio_config
    assert srv.session is None
    assert srv.state == MCPConnectionState.DISCONNECTED
    assert srv._lifecycle_task is None
    for event in (srv._shutdown_event, srv._connected_event):
        assert not event.is_set()
# 测试连接和断开连接
@pytest.mark.asyncio
async def test_connect_disconnect_stdio(stdio_config):
    """A stdio server connects and then disconnects cleanly."""
    transport = MagicMock()
    transport.__aenter__ = AsyncMock(return_value=(AsyncMock(), AsyncMock()))
    transport.__aexit__ = AsyncMock(return_value=None)

    with patch("kirara_ai.mcp_module.server.stdio_client", return_value=transport), \
            patch("kirara_ai.mcp_module.server.ClientSession", return_value=MockClientSession()):
        srv = MCPServer(stdio_config)

        assert await srv.connect() is True
        assert srv.state == MCPConnectionState.CONNECTED

        assert await srv.disconnect() is True
        assert srv.state == MCPConnectionState.DISCONNECTED
@pytest.mark.asyncio
async def test_connect_disconnect_sse(sse_config):
    """An SSE server connects and then disconnects cleanly."""
    transport = MagicMock()
    transport.__aenter__ = AsyncMock(return_value=(AsyncMock(), AsyncMock()))
    transport.__aexit__ = AsyncMock(return_value=None)

    with patch("kirara_ai.mcp_module.server.sse_client", return_value=transport), \
            patch("kirara_ai.mcp_module.server.ClientSession", return_value=MockClientSession()):
        srv = MCPServer(sse_config)

        assert await srv.connect() is True
        assert srv.state == MCPConnectionState.CONNECTED

        assert await srv.disconnect() is True
        assert srv.state == MCPConnectionState.DISCONNECTED
@pytest.mark.asyncio
async def test_connect_invalid_config(invalid_config):
    """An unsupported connection type makes connect() fail and leaves the server in ERROR."""
    srv = MCPServer(invalid_config)

    ok = await srv.connect()

    assert ok is False
    assert srv.state == MCPConnectionState.ERROR
# 测试连接超时
@pytest.mark.asyncio
async def test_connect_timeout(stdio_config):
    """connect() reports failure when waiting for the connection times out.

    ``asyncio.wait_for`` is patched to raise ``TimeoutError`` immediately, so the
    test never actually waits. The transport's ``__aenter__`` is made to block,
    so the connection cannot complete on its own.
    """

    async def _never_ready(*_args, **_kwargs):
        # A real coroutine function, so AsyncMock awaits it and __aenter__ blocks.
        # The previous ``lambda: asyncio.sleep(60)`` made __aenter__ return an
        # un-awaited coroutine object immediately, with a "never awaited" warning.
        await asyncio.sleep(60)

    with patch("kirara_ai.mcp_module.server.stdio_client") as mock_stdio_client, \
            patch("kirara_ai.mcp_module.server.ClientSession"):
        mock_client = MagicMock()
        mock_client.__aenter__ = AsyncMock(side_effect=_never_ready)
        mock_stdio_client.return_value = mock_client

        server = MCPServer(stdio_config)
        # Force the connection wait to time out straight away.
        with patch.object(asyncio, "wait_for", side_effect=asyncio.TimeoutError):
            connect_result = await server.connect()

        assert connect_result is False
# 测试工具相关方法
@pytest.mark.asyncio
async def test_tool_methods(stdio_config):
    """Tool listing and invocation return the MCP result types from the session."""
    transport = MagicMock()
    transport.__aenter__ = AsyncMock(return_value=(AsyncMock(), AsyncMock()))
    transport.__aexit__ = AsyncMock(return_value=None)

    with patch("kirara_ai.mcp_module.server.stdio_client", return_value=transport), \
            patch("kirara_ai.mcp_module.server.ClientSession", return_value=MockClientSession()):
        srv = MCPServer(stdio_config)
        await srv.connect()

        assert isinstance(await srv.get_tools(), types.ListToolsResult)
        assert isinstance(
            await srv.call_tool("test_tool", {"arg": "value"}), types.CallToolResult
        )

        await srv.disconnect()
# 测试补全方法
@pytest.mark.asyncio
async def test_complete(stdio_config):
    """complete() passes through the session's (empty) completion result."""
    transport = MagicMock()
    transport.__aenter__ = AsyncMock(return_value=(AsyncMock(), AsyncMock()))
    transport.__aexit__ = AsyncMock(return_value=None)

    with patch("kirara_ai.mcp_module.server.stdio_client", return_value=transport), \
            patch("kirara_ai.mcp_module.server.ClientSession", return_value=MockClientSession()):
        srv = MCPServer(stdio_config)
        await srv.connect()

        completion = await srv.complete("test_prompt", {"temperature": 0.7})
        assert completion == {}

        await srv.disconnect()
# 测试提示词相关方法
@pytest.mark.asyncio
async def test_prompt_methods(stdio_config):
    """Prompt retrieval and listing are forwarded to the session."""
    transport = MagicMock()
    transport.__aenter__ = AsyncMock(return_value=(AsyncMock(), AsyncMock()))
    transport.__aexit__ = AsyncMock(return_value=None)

    with patch("kirara_ai.mcp_module.server.stdio_client", return_value=transport), \
            patch("kirara_ai.mcp_module.server.ClientSession", return_value=MockClientSession()):
        srv = MCPServer(stdio_config)
        await srv.connect()

        assert await srv.get_prompt("test_prompt", {}) == "测试提示词"
        assert await srv.list_prompts() == []

        await srv.disconnect()
# 测试资源相关方法
@pytest.mark.asyncio
async def test_resource_methods(stdio_config):
    """Resource listing, reading and (un)subscription are forwarded to the session."""
    resource_uri = "http://localhost/test-resource"
    transport = MagicMock()
    transport.__aenter__ = AsyncMock(return_value=(AsyncMock(), AsyncMock()))
    transport.__aexit__ = AsyncMock(return_value=None)

    with patch("kirara_ai.mcp_module.server.stdio_client", return_value=transport), \
            patch("kirara_ai.mcp_module.server.ClientSession", return_value=MockClientSession()):
        srv = MCPServer(stdio_config)
        await srv.connect()

        assert await srv.list_resources() == []
        assert await srv.list_resource_templates() == []
        assert await srv.read_resource(resource_uri) == "资源内容"
        assert await srv.subscribe_resource(resource_uri) == {}
        assert await srv.unsubscribe_resource(resource_uri) == {}

        await srv.disconnect()
================================================
FILE: tests/test_media.py
================================================
import asyncio
import os
import shutil
import tempfile
import unittest
from pathlib import Path
from kirara_ai.im.message import ImageMessage, VoiceMessage
from kirara_ai.media import MediaManager, MediaType
class TestMediaManager(unittest.TestCase):
"""测试媒体管理器"""
def setUp(self):
"""测试前准备"""
# 创建临时目录
self.temp_dir = tempfile.mkdtemp()
self.media_dir = os.path.join(self.temp_dir, "media")
# 创建各种格式的测试文件
self.format_files = {}
# 图片格式
self.format_files["jpeg"] = os.path.join(self.temp_dir, "test.jpg")
with open(self.format_files["jpeg"], "wb") as f:
f.write(b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00\xff\xdb\x00C")
self.format_files["png"] = os.path.join(self.temp_dir, "test.png")
with open(self.format_files["png"], "wb") as f:
f.write(b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89")
self.format_files["gif"] = os.path.join(self.temp_dir, "test.gif")
with open(self.format_files["gif"], "wb") as f:
f.write(b"GIF89a\x01\x00\x01\x00\x80\x00\x00\xff\xff\xff\x00\x00\x00!\xf9\x04\x01\x00\x00\x00\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02D\x01\x00;")
self.format_files["webp"] = os.path.join(self.temp_dir, "test.webp")
with open(self.format_files["webp"], "wb") as f:
f.write(b"RIFF\x1a\x00\x00\x00WEBPVP8 \x0e\x00\x00\x00\x10\x00\x00\x00\x10\x00\x00\x00\x01\x00\x02\x00\x02\x00\x34\x25\xa4\x00\x03p\x00\xfe\xfb\xfd\x50\x00")
# 音频格式
self.format_files["mp3"] = os.path.join(self.temp_dir, "test.mp3")
with open(self.format_files["mp3"], "wb") as f:
f.write(b"\xFF\xFB\x90\x64\x00\x00\x00\x00")
self.format_files["wav"] = os.path.join(self.temp_dir, "test.wav")
with open(self.format_files["wav"], "wb") as f:
f.write(b"RIFF$\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00\x11+\x00\x00\x11+\x00\x00\x01\x00\x08\x00data\x00\x00\x00\x00")
# 视频格式
self.format_files["mp4"] = os.path.join(self.temp_dir, "test.mp4")
with open(self.format_files["mp4"], "wb") as f:
f.write(b"\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42mp41\x00\x00\x00\x00moov")
self.format_files["avi"] = os.path.join(self.temp_dir, "test.avi")
with open(self.format_files["avi"], "wb") as f:
f.write(b"RIFF\x00\x00\x00\x00AVI LIST\x00\x00\x00\x00hdrlavih\x00\x00\x00\x00")
# 文档格式
self.format_files["pdf"] = os.path.join(self.temp_dir, "test.pdf")
with open(self.format_files["pdf"], "wb") as f:
f.write(b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n1 0 obj\n<>\nendobj\n2 0 obj\n<>\nendobj\n3 0 obj\n<>>>\nendobj\nxref\n0 4\n0000000000 65535 f\n0000000015 00000 n\n0000000060 00000 n\n0000000111 00000 n\ntrailer\n<>\nstartxref\n178\n%%EOF\n")
self.format_files["txt"] = os.path.join(self.temp_dir, "test.txt")
with open(self.format_files["txt"], "wb") as f:
f.write(b"This is a test text file.")
# 使用已创建的文件作为测试文件
self.test_image_path = self.format_files["jpeg"]
self.test_audio_path = self.format_files["mp3"]
# 创建媒体管理器
self.media_manager = MediaManager(media_dir=self.media_dir)
def tearDown(self):
    """Delete the temporary directory the test fixtures were written into."""
    scratch_dir = self.temp_dir
    shutil.rmtree(scratch_dir)
def test_register_from_path(self):
    """Registering an on-disk image records its metadata and keeps a retrievable file."""
    registration = self.media_manager.register_from_path(
        self.test_image_path,
        source="test",
        description="测试图片",
        tags=["test", "image"],
        reference_id="test_ref",
    )
    media_id = asyncio.run(registration)
    self.assertIsNotNone(media_id)

    # Every value supplied at registration time must be reflected in the metadata.
    meta = self.media_manager.get_metadata(media_id)
    self.assertIsNotNone(meta)
    self.assertEqual(meta.media_type, MediaType.IMAGE)
    self.assertEqual(meta.source, "test")
    self.assertEqual(meta.description, "测试图片")
    self.assertEqual(meta.tags, ["test", "image"])
    self.assertEqual(meta.references, {"test_ref"})

    # The manager must be able to hand back an existing file for the media.
    stored_path = asyncio.run(self.media_manager.get_file_path(media_id))
    self.assertIsNotNone(stored_path)
    self.assertTrue(stored_path.exists())
def test_register_from_data(self):
    """Registering raw bytes with an explicit format records all supplied metadata."""
    raw = Path(self.test_image_path).read_bytes()
    registration = self.media_manager.register_from_data(
        raw,
        format="jpeg",
        source="test_data",
        description="测试数据图片",
        tags=["test", "data"],
        reference_id="test_data_ref",
    )
    media_id = asyncio.run(registration)
    self.assertIsNotNone(media_id)

    # The metadata must mirror the arguments passed to register_from_data.
    meta = self.media_manager.get_metadata(media_id)
    self.assertIsNotNone(meta)
    self.assertEqual(meta.media_type, MediaType.IMAGE)
    self.assertEqual(meta.source, "test_data")
    self.assertEqual(meta.description, "测试数据图片")
    self.assertEqual(meta.tags, ["test", "data"])
    self.assertEqual(meta.references, {"test_data_ref"})
def test_register_from_url(self):
    """Test registering media from a URL.

    A local ``file://`` URL stands in for a remote resource. Registration must
    record the URL and the caller-supplied metadata. A later ``get_data`` call
    fetches the content, after which media type and format must be populated.
    """
    # Use a local file URL as the test URL
    file_url = f"file://{Path(self.test_image_path).absolute()}"
    media_id = asyncio.run(self.media_manager.register_from_url(
        file_url,
        source="test_url",
        description="测试URL图片",
        tags=["test", "url"],
        reference_id="test_url_ref"
    ))
    # The returned media ID must be valid
    self.assertIsNotNone(media_id)
    # The metadata must reflect the registration arguments
    metadata = self.media_manager.get_metadata(media_id)
    self.assertIsNotNone(metadata)
    self.assertEqual(metadata.url, file_url)
    self.assertEqual(metadata.source, "test_url")
    self.assertEqual(metadata.description, "测试URL图片")
    self.assertEqual(metadata.tags, ["test", "url"])
    self.assertEqual(metadata.references, {"test_url_ref"})
    # Fetch the data (this triggers the download)
    data = asyncio.run(self.media_manager.get_data(media_id))
    self.assertIsNotNone(data)
    # Re-read the metadata: after the download it should carry more information
    metadata = self.media_manager.get_metadata(media_id)
    self.assertIsNotNone(metadata.media_type)
    self.assertIsNotNone(metadata.format)
def test_format_detection(self):
    """Each sample file is classified with the expected media type and format."""

    def detect(name):
        # Register the sample under a unique reference and return its metadata.
        new_id = asyncio.run(self.media_manager.register_from_path(
            self.format_files[name],
            reference_id=f"test_{name}",
        ))
        return self.media_manager.get_metadata(new_id)

    # Images: the detected format equals the sample's format name.
    for name in ("jpeg", "png", "gif", "webp"):
        meta = detect(name)
        self.assertEqual(meta.media_type, MediaType.IMAGE, f"格式 {name} 应该被识别为图片")
        self.assertEqual(meta.format.lower(), name.lower(), f"格式 {name} 未被正确识别")

    # Audio: same exact-match rule as images.
    for name in ("mp3", "wav"):
        meta = detect(name)
        self.assertEqual(meta.media_type, MediaType.AUDIO, f"格式 {name} 应该被识别为音频")
        self.assertEqual(meta.format.lower(), name.lower(), f"格式 {name} 未被正确识别")

    # Video: AVI is expected to be reported as "x-msvideo" rather than "avi".
    for name in ("mp4", "avi"):
        meta = detect(name)
        self.assertEqual(meta.media_type, MediaType.VIDEO, f"格式 {name} 应该被识别为视频")
        expected = "x-msvideo" if name == "avi" else name.lower()
        self.assertEqual(meta.format.lower(), expected, f"格式 {name} 未被正确识别")

    # Documents: only the suffix is compared; txt is expected to end in "plain".
    for name in ("pdf", "txt"):
        meta = detect(name)
        self.assertEqual(meta.media_type, MediaType.FILE, f"格式 {name} 应该被识别为文件")
        suffix = {"txt": "plain"}.get(name, name)
        self.assertTrue(meta.format.lower().endswith(suffix.lower()), f"格式 {name} 未被正确识别,实际为 {meta.format}")
def test_reference_management(self):
    """Test reference counting on registered media.

    References can be added and removed. Removing the last remaining
    reference must delete the media, so its metadata is no longer returned.
    """
    # Register the media with an initial reference
    media_id = asyncio.run(self.media_manager.register_from_path(
        self.test_image_path,
        reference_id="ref1"
    ))
    # Add a second reference
    self.media_manager.add_reference(media_id, "ref2")
    # Both references must now be present
    metadata = self.media_manager.get_metadata(media_id)
    self.assertEqual(metadata.references, {"ref1", "ref2"})
    # Remove the first reference
    self.media_manager.remove_reference(media_id, "ref1")
    # Only the second reference must remain
    metadata = self.media_manager.get_metadata(media_id)
    self.assertEqual(metadata.references, {"ref2"})
    # Remove the last reference; the media should then be deleted
    self.media_manager.remove_reference(media_id, "ref2")
    # The media must be gone
    self.assertIsNone(self.media_manager.get_metadata(media_id))
def test_search(self):
    """Test the search helpers: by tags, description keyword, source and media type.

    One image and one audio file are registered with distinct
    source/description/tags plus a shared "common" tag.
    """
    # Register two media items
    media_id1 = asyncio.run(self.media_manager.register_from_path(
        self.test_image_path,
        source="source1",
        description="description with keyword1",
        tags=["tag1", "common"],
        reference_id="ref1"
    ))
    media_id2 = asyncio.run(self.media_manager.register_from_path(
        self.test_audio_path,
        source="source2",
        description="description with keyword2",
        tags=["tag2", "common"],
        reference_id="ref2"
    ))
    # Search by tags (a shared tag matches both; order is not asserted there)
    results = self.media_manager.search_by_tags(["tag1"])
    self.assertEqual(results, [media_id1])
    results = self.media_manager.search_by_tags(["common"])
    self.assertEqual(set(results), {media_id1, media_id2})
    # Search by description keyword
    results = self.media_manager.search_by_description("keyword1")
    self.assertEqual(results, [media_id1])
    # Search by source
    results = self.media_manager.search_by_source("source2")
    self.assertEqual(results, [media_id2])
    # Search by media type
    results = self.media_manager.search_by_type(MediaType.IMAGE)
    self.assertEqual(results, [media_id1])
    results = self.media_manager.search_by_type(MediaType.AUDIO)
    self.assertEqual(results, [media_id2])
def test_media_message(self):
    """Test ImageMessage built from each kind of source: URL, file path and raw bytes.

    For each source the message must get a media ID and be able to produce
    the other representations (URL, path, data) on demand.
    """
    # Message created from a URL only
    file_url = f"file://{Path(self.test_image_path).absolute()}"
    url_message = ImageMessage(url=file_url, reference_id="url_message_ref", media_manager=self.media_manager)
    # A media ID must have been assigned
    self.assertIsNotNone(url_message.media_id)
    # get_url should return the original URL unchanged
    url = asyncio.run(url_message.get_url())
    self.assertEqual(url, file_url)
    # get_path should trigger a download and yield an existing file
    path = asyncio.run(url_message.get_path())
    self.assertIsNotNone(path)
    self.assertTrue(Path(path).exists())
    # Message created from a file path only
    path_message = ImageMessage(path=self.test_image_path, reference_id="path_message_ref", media_manager=self.media_manager)
    # A media ID must have been assigned
    self.assertIsNotNone(path_message.media_id)
    # get_path should return the original path or the path of a managed copy
    path = asyncio.run(path_message.get_path())
    self.assertIsNotNone(path)
    # get_url should generate a URL
    url = asyncio.run(path_message.get_url())
    self.assertIsNotNone(url)
    # Message created from raw bytes only
    with open(self.test_image_path, "rb") as f:
        data = f.read()
    data_message = ImageMessage(data=data, format="jpeg", reference_id="data_message_ref", media_manager=self.media_manager)
    # A media ID must have been assigned
    self.assertIsNotNone(data_message.media_id)
    # get_data should return exactly the bytes that were supplied
    message_data = asyncio.run(data_message.get_data())
    self.assertEqual(message_data, data)
    # get_path should write the bytes out to an existing file
    path = asyncio.run(data_message.get_path())
    self.assertIsNotNone(path)
    self.assertTrue(Path(path).exists())
def test_media_message_with_different_formats(self):
    """Image and voice messages built from files of several formats get the right types."""
    # (message class, sample format names, expected resource_type, expected MediaType)
    groups = (
        (ImageMessage, ("jpeg", "png", "gif", "webp"), "image", MediaType.IMAGE),
        (VoiceMessage, ("mp3", "wav"), "audio", MediaType.AUDIO),
    )
    for message_cls, names, resource_type, media_type in groups:
        for name in names:
            msg = message_cls(
                path=self.format_files[name],
                reference_id=f"message_{name}_ref",
                media_manager=self.media_manager,
            )
            self.assertIsNotNone(msg.media_id)
            self.assertEqual(msg.resource_type, resource_type)
            # The stored metadata must agree with the message's kind.
            meta = self.media_manager.get_metadata(msg.media_id)
            self.assertEqual(meta.media_type, media_type)
================================================
FILE: tests/test_media_element.py
================================================
import os
import tempfile
from unittest.mock import patch
import pytest
from kirara_ai.im.message import ImageMessage
from kirara_ai.media.manager import MediaManager
# Path of the local test resource used by the tests below.
TEST_RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "resources", "test_image.txt")
TEST_URL = "https://httpbin.org/image/jpeg"  # A reachable sample image URL (needs network access).
# Scratch directory for this module's media. A bare tempfile.mkdtemp() was never
# removed, leaking one directory per test session; TemporaryDirectory deletes it
# automatically when the handle is finalized (at the latest, at interpreter exit).
# The handle is kept at module level so the directory lives as long as the module.
_temp_dir_handle = tempfile.TemporaryDirectory()
temp_dir = _temp_dir_handle.name
media_dir = os.path.join(temp_dir, "media")
# Create the media manager.
# NOTE(review): the tests below never pass it explicitly; presumably constructing
# it makes it the manager ImageMessage falls back to — confirm.
media_manager = MediaManager(media_dir=media_dir)
@pytest.mark.asyncio
async def test_media_element_from_path():
    """An ImageMessage built from a local path yields bytes, a base64 data URL and a real file."""
    message = ImageMessage(path=TEST_RESOURCE_PATH)

    payload = await message.get_data()
    assert payload is not None
    assert isinstance(payload, bytes)

    # For path-backed media the test expects a base64 "data:" URL.
    data_url = await message.get_url()
    assert data_url.startswith("data:")
    assert "base64" in data_url

    local_path = await message.get_path()
    assert os.path.exists(local_path)
    assert os.path.isfile(local_path)
@pytest.mark.asyncio
async def test_media_element_from_url():
    """An ImageMessage built from a URL downloads bytes, keeps its URL and can spill to a temp file."""
    message = ImageMessage(url=TEST_URL)

    payload = await message.get_data()
    assert payload is not None
    assert isinstance(payload, bytes)

    # The original URL is returned verbatim.
    assert await message.get_url() == TEST_URL

    # The temporary file is always removed, even if an assertion fails.
    tmp_file = await message.get_path()
    try:
        assert os.path.exists(tmp_file)
        assert os.path.isfile(tmp_file)
    finally:
        os.remove(tmp_file)
@pytest.mark.asyncio
async def test_media_element_from_data():
    """An ImageMessage built from raw bytes returns them unchanged and can produce a URL and file."""
    # Seed the message with the bytes of the local test resource.
    with open(TEST_RESOURCE_PATH, "rb") as handle:
        original = handle.read()

    message = ImageMessage(data=original, format="txt")

    assert await message.get_data() == original

    # Byte-backed media is expected to produce a base64 "data:" URL.
    data_url = await message.get_url()
    assert data_url.startswith("data:")
    assert "base64" in data_url

    tmp_file = await message.get_path()
    assert os.path.exists(tmp_file)
    assert os.path.isfile(tmp_file)
@pytest.mark.asyncio
async def test_media_element_format_detection():
    """Format and resource type are populated automatically once the data has been loaded."""
    message = ImageMessage(path=TEST_RESOURCE_PATH)
    # Loading the data is what triggers format detection.
    await message.get_data()
    assert message.format is not None
    assert message.resource_type is not None
@pytest.mark.asyncio
async def test_media_element_errors():
    """Invalid construction and failing downloads both surface as ValueError."""
    # Constructing a message with no source at all is rejected.
    with pytest.raises(ValueError):
        ImageMessage()  # no arguments provided

    # Simulate a network failure by making the HTTP GET raise.
    with patch('curl_cffi.AsyncSession.get') as mock_get:
        mock_get.side_effect = ValueError("Mocked network error")
        # Construction only registers the URL; it is kept outside pytest.raises
        # so that an unrelated ValueError here cannot make the test pass falsely.
        media = ImageMessage(url="https://valid-url-but-will-fail.com/image.jpg")
        # Only the download itself is expected to fail.
        with pytest.raises(ValueError):
            await media.get_data()
================================================
FILE: tests/test_system_blocks.py
================================================
from unittest.mock import MagicMock
import pytest
from kirara_ai.im.message import IMMessage
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.dispatch import CombinedDispatchRule, DispatchRuleRegistry, RuleGroup, SimpleDispatchRule
from kirara_ai.workflow.implementations.blocks.system.help import GenerateHelp
@pytest.fixture
def container():
    """Return a (container, registry) pair: a mock DI container whose resolve() yields a mock rule registry."""
    mock_registry = MagicMock(spec=DispatchRuleRegistry)
    mock_container = MagicMock(spec=DependencyContainer)
    # Any resolve() call hands back the same registry mock the test configures.
    mock_container.resolve.return_value = mock_registry
    return mock_container, mock_registry
def create_mock_rule(
    rule_id: str, name: str, description: str, workflow_id: str, rule_groups: list
) -> CombinedDispatchRule:
    """Build an enabled CombinedDispatchRule with priority 5 and empty metadata for the help tests."""
    # Settings shared by every rule these tests create; a fresh metadata dict per call.
    shared = dict(enabled=True, priority=5, metadata={})
    return CombinedDispatchRule(
        rule_id=rule_id,
        name=name,
        description=description,
        workflow_id=workflow_id,
        rule_groups=rule_groups,
        **shared,
    )
def test_generate_help_basic(container):
    """The help text lists each active rule with its category, name, trigger, description and logic words."""
    mock_container, registry = container

    help_group = RuleGroup(
        operator="and",
        rules=[
            SimpleDispatchRule(type="prefix", config={"prefix": "/help"}),
            SimpleDispatchRule(type="keyword", config={"keywords": ["帮助", "help"]}),
        ],
    )
    help_rule = create_mock_rule(
        rule_id="help_command",
        name="帮助命令",
        description="显示帮助信息",
        workflow_id="system:help",
        rule_groups=[help_group],
    )
    chat_group = RuleGroup(
        operator="or",
        rules=[
            SimpleDispatchRule(type="prefix", config={"prefix": "/chat"}),
            SimpleDispatchRule(type="keyword", config={"keywords": ["聊天", "对话"]}),
        ],
    )
    chat_rule = create_mock_rule(
        rule_id="chat_command",
        name="聊天命令",
        description="开始聊天",
        workflow_id="chat:chat",
        rule_groups=[chat_group],
    )
    registry.get_active_rules.return_value = [help_rule, chat_rule]

    block = GenerateHelp()
    block.container = mock_container
    result = block.execute()

    assert "response" in result
    reply = result["response"]
    assert isinstance(reply, IMMessage)
    text = reply.content

    expected_fragments = (
        "机器人命令帮助",  # title
        "SYSTEM",
        "CHAT",  # category headings (apparently from the workflow id prefix)
        "帮助命令",
        "聊天命令",
        "/help",
        "/chat",
        "显示帮助信息",
        "开始聊天",
        "且",  # combination word for the "and" group
        "或",  # combination word for the "or" group
    )
    for fragment in expected_fragments:
        assert fragment in text
def test_generate_help_empty(container):
    """With no active rules the help text keeps its title but lists no categories or commands."""
    mock_container, registry = container
    registry.get_active_rules.return_value = []

    block = GenerateHelp()
    block.container = mock_container
    outcome = block.execute()

    assert "response" in outcome
    reply = outcome["response"]
    assert isinstance(reply, IMMessage)
    text = reply.content

    assert "机器人命令帮助" in text
    # Neither the category marker nor the command marker may appear.
    assert all(marker not in text for marker in ("📑", "🔸"))
def test_generate_help_no_description(container):
    """A rule with an empty description is still listed, but without a description line."""
    mock_container, registry = container

    only_group = RuleGroup(
        operator="or",
        rules=[SimpleDispatchRule(type="prefix", config={"prefix": "/test"})],
    )
    test_rule = create_mock_rule(
        rule_id="test_command",
        name="测试命令",
        description="",  # deliberately empty
        workflow_id="test:test",
        rule_groups=[only_group],
    )
    registry.get_active_rules.return_value = [test_rule]

    block = GenerateHelp()
    block.container = mock_container
    outcome = block.execute()

    assert "response" in outcome
    reply = outcome["response"]
    assert isinstance(reply, IMMessage)
    text = reply.content

    # The command and its trigger are shown; the description label is not.
    assert "测试命令" in text
    assert "/test" in text
    assert "说明:" not in text
def test_generate_help_complex_rules(container):
    """A rule with several groups renders every condition and the words joining them."""
    mock_container, registry = container

    or_group = RuleGroup(
        operator="or",
        rules=[
            SimpleDispatchRule(type="prefix", config={"prefix": "/complex"}),
            SimpleDispatchRule(type="keyword", config={"keywords": ["复杂", "高级"]}),
        ],
    )
    and_group = RuleGroup(
        operator="and",
        rules=[
            SimpleDispatchRule(type="regex", config={"pattern": ".*test.*"}),
            SimpleDispatchRule(type="keyword", config={"keywords": ["测试"]}),
        ],
    )
    complex_rule = create_mock_rule(
        rule_id="complex_command",
        name="复杂命令",
        description="这是一个复杂的命令",
        workflow_id="test:complex",
        rule_groups=[or_group, and_group],
    )
    registry.get_active_rules.return_value = [complex_rule]

    block = GenerateHelp()
    block.container = mock_container
    text = block.execute()["response"].content

    expected_fragments = (
        "复杂命令",
        "这是一个复杂的命令",
        "输入以 /complex 开头",
        "输入包含 复杂 或 高级",
        "输入匹配正则 .*test.*",
        "输入包含 测试",
        "并且",
        "或",
    )
    for fragment in expected_fragments:
        assert fragment in text
================================================
FILE: tests/test_workflow_builder.py
================================================
import os
import warnings
from typing import Any, Dict
import pytest
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block import Block
from kirara_ai.workflow.core.block.input_output import Input, Output
from kirara_ai.workflow.core.block.registry import BlockRegistry
from kirara_ai.workflow.core.workflow.builder import WorkflowBuilder
# Block classes used by the tests below.
class SimpleInputBlock(Block):
    """Source block: emits its ``param1`` constructor argument on output ``out1``."""

    name: str = "simple_input"
    inputs: Dict[str, Input] = {"param1": Input("param1", "输入1", str, "Input 1")}
    outputs: Dict[str, Output] = {"out1": Output("out1", "输出1", str, "Output 1")}

    def __init__(self, param1: str = "default"):
        super().__init__()
        self.param1 = param1

    def execute(self) -> Dict[str, Any]:
        value = self.param1
        return dict(out1=value)
class SimpleProcessBlock(Block):
    """Processing block: repeats its ``in1`` string ``multiplier`` times onto ``out1``."""

    name: str = "simple_process"
    inputs: Dict[str, Input] = {"in1": Input("in1", "输入1", str, "Input 1")}
    outputs: Dict[str, Output] = {"out1": Output("out1", "输出1", str, "Output 1")}

    def __init__(self, multiplier: int = 1):
        super().__init__()
        self.multiplier = multiplier

    def execute(self, in1: str) -> Dict[str, Any]:
        repeated = in1 * self.multiplier
        return dict(out1=repeated)
def setup_module(module):
    """Register both test blocks under the "test" group before this module's tests run."""
    # NOTE(review): the registry is a local instance; this only has an effect if
    # BlockRegistry shares state across instances — confirm.
    registry = BlockRegistry()
    for block_name, block_cls in (
        ("simple_input", SimpleInputBlock),
        ("simple_process", SimpleProcessBlock),
    ):
        registry.register(block_name, "test", block_cls)
def teardown_module(module):
    """Clear the block registry once every test in this module has finished."""
    registry = BlockRegistry()
    registry.clear()
class TestWorkflowBuilder:
    """Tests for building, serializing and re-loading workflows with WorkflowBuilder."""

    @pytest.fixture
    def container(self):
        """Provide a DI container whose BlockRegistry knows both test blocks."""
        container = DependencyContainer()
        registry = BlockRegistry()
        container.register(BlockRegistry, registry)
        # Register the test blocks
        registry.register("simple_input", "test", SimpleInputBlock)
        registry.register("simple_process", "test", SimpleProcessBlock)
        return container

    @pytest.fixture
    def yaml_path(self, tmp_path):
        """Yield a path (str) for a scratch workflow YAML file.

        The file lives in pytest's per-test ``tmp_path`` instead of the current
        working directory, so tests never leave stray files in the checkout and
        cannot collide with each other when run in parallel.
        """
        path = str(tmp_path / "test_workflow.yaml")
        yield path
        # Remove the test file eagerly (pytest also prunes old tmp_path dirs).
        if os.path.exists(path):
            os.remove(path)

    def test_basic_dsl_construction(self, container):
        """Test basic DSL construction: use() followed by chain() gives two wired blocks."""
        builder = (
            WorkflowBuilder("test_workflow")
            .use(SimpleInputBlock, name="input1", param1="test")
            .chain(SimpleProcessBlock, name="process1", multiplier=2)
        )
        workflow = builder.build(container)
        assert len(workflow.blocks) == 2
        assert len(workflow.wires) == 1
        assert workflow.blocks[0].name == "input1"
        assert workflow.blocks[1].name == "process1"

    def test_parallel_construction(self, container):
        """Test parallel construction: one source fanning out to two named blocks."""
        builder = (
            WorkflowBuilder("test_workflow")
            .use(SimpleInputBlock)
            .parallel(
                [
                    (SimpleProcessBlock, "process1", {"multiplier": 2}),
                    (SimpleProcessBlock, "process2", {"multiplier": 3}),
                ]
            )
        )
        workflow = builder.build(container)
        assert len(workflow.blocks) == 3
        assert len(workflow.wires) == 2
        assert any(block.name == "process1" for block in workflow.blocks)
        assert any(block.name == "process2" for block in workflow.blocks)

    def test_save_and_load(self, container, yaml_path):
        """Test that a workflow survives a save/load round trip, including block parameters."""
        # Build the original workflow
        original_builder = (
            WorkflowBuilder("test_workflow")
            .use(SimpleInputBlock, name="input1", param1="test")
            .parallel(
                [
                    (SimpleProcessBlock, "process1", {"multiplier": 2}),
                    (SimpleProcessBlock, "process2", {"multiplier": 3}),
                ]
            )
        )
        # Save the workflow
        original_builder.save_to_yaml(yaml_path, container)
        # Load it back
        loaded_builder = WorkflowBuilder.load_from_yaml(yaml_path, container)
        loaded_workflow = loaded_builder.build(container)
        # Check the loaded workflow
        assert len(loaded_workflow.blocks) == 3
        assert loaded_workflow.name == "test_workflow"
        # Check that block parameters were restored
        input_block = next(b for b in loaded_workflow.blocks if b.name == "input1")
        assert input_block.param1 == "test"
        process1 = next(b for b in loaded_workflow.blocks if b.name == "process1")
        assert process1.multiplier == 2
        process2 = next(b for b in loaded_workflow.blocks if b.name == "process2")
        assert process2.multiplier == 3

    def test_complex_workflow_serialization(self, container, yaml_path):
        """Test serialization of a fan-out/fan-in workflow."""
        # Build a workflow exercising several builder features
        builder = (
            WorkflowBuilder("complex_workflow")
            .use(SimpleInputBlock, name="start", param1="init")
            .parallel(
                [
                    (SimpleProcessBlock, "parallel1", {"multiplier": 2}),
                    (SimpleProcessBlock, "parallel2", {"multiplier": 3}),
                ]
            )
            .chain(
                SimpleProcessBlock,
                name="final",
                wire_from=["parallel1", "parallel2"],
                multiplier=1,
            )
        )
        # Save the workflow
        builder.save_to_yaml(yaml_path, container)
        # Load it back
        loaded_builder = WorkflowBuilder.load_from_yaml(yaml_path, container)
        loaded_workflow = loaded_builder.build(container)
        # Check the structure
        assert len(loaded_workflow.blocks) == 4
        assert len(loaded_workflow.wires) >= 3  # at least 3 connections
        # Check that every named block is present
        assert any(b.name == "start" for b in loaded_workflow.blocks)
        assert any(b.name == "parallel1" for b in loaded_workflow.blocks)
        assert any(b.name == "parallel2" for b in loaded_workflow.blocks)
        assert any(b.name == "final" for b in loaded_workflow.blocks)

    def test_invalid_yaml_handling(self, container):
        """Test that loading a missing YAML file raises."""
        with pytest.raises(Exception):
            WorkflowBuilder.load_from_yaml("non_existent_file.yaml", container)

    def test_unregistered_block_warning(self, container, yaml_path):
        """Test that saving a workflow with an unregistered block emits a UserWarning."""
        # Fetch the registry and empty it
        registry = container.resolve(BlockRegistry)
        registry.clear()
        with pytest.warns(UserWarning):
            builder = WorkflowBuilder("test_workflow").use(SimpleInputBlock)
            builder.save_to_yaml(yaml_path, container)

    def test_registered_block_no_warning(self, container, yaml_path):
        """Test that saving a workflow whose blocks are registered emits no warning."""
        registry = container.resolve(BlockRegistry)
        registry.clear()
        registry.register("simple_input", "test", SimpleInputBlock)
        with warnings.catch_warnings():
            # Promote any warning to an error so the test fails if one is emitted.
            warnings.simplefilter("error")
            builder = WorkflowBuilder("test_workflow").use(SimpleInputBlock)
            builder.save_to_yaml(yaml_path, container)
================================================
FILE: tests/test_workflow_factories.py
================================================
from unittest.mock import MagicMock
import pytest
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.implementations.factories.game_factory import GameWorkflowFactory
from kirara_ai.workflow.implementations.factories.system_factory import SystemWorkflowFactory
@pytest.fixture
def container():
    """Provide a spec'd mock DependencyContainer for building the factory workflows."""
    mock_container = MagicMock(spec=DependencyContainer)
    return mock_container
def test_game_dice_workflow(container):
    """The dice workflow is a linear chain: GetIMMessage -> DiceRoll -> SendIMMessage."""
    builder = GameWorkflowFactory.create_dice_workflow()
    wf = builder.build(container)
    assert wf.name == "骰子游戏"
    # Three blocks in a straight line are joined by exactly two wires.
    assert len(wf.blocks) == 3
    assert len(wf.wires) == 2
def test_game_gacha_workflow(container):
    """The gacha workflow is a linear chain: GetIMMessage -> GachaSimulator -> SendIMMessage."""
    builder = GameWorkflowFactory.create_gacha_workflow()
    wf = builder.build(container)
    assert wf.name == "抽卡游戏"
    # Three blocks in a straight line are joined by exactly two wires.
    assert len(wf.blocks) == 3
    assert len(wf.wires) == 2
def test_system_help_workflow(container):
    """The help workflow is two blocks, GenerateHelp -> SendIMMessage, joined by one wire."""
    builder = SystemWorkflowFactory.create_help_workflow()
    wf = builder.build(container)
    assert wf.name == "帮助信息"
    assert len(wf.blocks) == 2
    assert len(wf.wires) == 1
================================================
FILE: tests/tracing/__init__.py
================================================
"""Tracing framework tests"""
================================================
FILE: tests/tracing/test_base.py
================================================
import unittest
from datetime import datetime
from typing import Any, Dict, Optional
from sqlalchemy import Column, Integer
from kirara_ai.config.global_config import GlobalConfig
from kirara_ai.database import DatabaseManager
from kirara_ai.database.manager import Base
from kirara_ai.events.event_bus import EventBus
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.format.message import LLMChatMessage, LLMChatTextContent
from kirara_ai.llm.format.request import LLMChatRequest
from kirara_ai.llm.format.response import LLMChatResponse, Message, Usage
from kirara_ai.tracing.core import TraceRecord
from kirara_ai.tracing.models import LLMRequestTrace
class TestTraceRecord(TraceRecord):
    """Minimal concrete trace record used only by the tracing tests.

    It maps to its own ``test_traces`` table and implements the abstract hooks
    as no-ops or empty dicts.
    """
    # Tell pytest not to collect this class as a test case despite its "Test" prefix.
    __test__ = False
    __tablename__ = "test_traces"
    # Allow the table to be redefined if it already exists in the metadata
    # (e.g. when this module is imported more than once).
    __table_args__ = {'extend_existing': True}
    id = Column(Integer, primary_key=True)
    def update_from_event(self, event) -> None:
        # Intentionally a no-op: tests only need records to be storable.
        pass
    def to_dict(self) -> Dict[str, Any]:
        return {}
    def to_detail_dict(self) -> Dict[str, Any]:
        return {}
class TracingTestBase(unittest.TestCase):
    """Shared fixture for the tracing tests.

    Each test gets a fresh DependencyContainer wired with an EventBus, a default
    GlobalConfig and a DatabaseManager backed by an in-memory SQLite database,
    plus helpers that build sample LLM requests, responses and trace records.
    """
    def setUp(self):
        """Build the container and initialize the in-memory database before each test."""
        self.container = DependencyContainer()
        # The container registers itself so components can resolve it.
        self.container.register(DependencyContainer, self.container)
        self.event_bus = EventBus()
        self.container.register(EventBus, self.event_bus)
        self.container.register(GlobalConfig, GlobalConfig())
        # Use an in-memory database for the tests
        self.db_manager = DatabaseManager(self.container, database_url="sqlite:///:memory:", is_debug=True)
        self.db_manager.initialize()
        # Create every table known to Base's metadata (presumably including
        # TestTraceRecord's table, if TraceRecord derives from Base — confirm).
        Base.metadata.create_all(self.db_manager.engine)
        self.container.register(DatabaseManager, self.db_manager)
    def tearDown(self):
        """Shut down the database manager after each test."""
        self.db_manager.shutdown()
    def create_test_request(self, model: str = "test-model") -> LLMChatRequest:
        """Create an LLM chat request holding a single user text message."""
        return LLMChatRequest(
            model=model,
            messages=[LLMChatMessage(role="user", content=[LLMChatTextContent(text="test message")])]
        )
    def create_test_response(self, usage: Optional[Usage] = None) -> LLMChatResponse:
        """Create an assistant response; usage defaults to 10 prompt + 20 completion = 30 total tokens."""
        if usage is None:
            usage = Usage(
                prompt_tokens=10,
                completion_tokens=20,
                total_tokens=30
            )
        return LLMChatResponse(
            model="test-model",
            message=Message(role="assistant", content=[LLMChatTextContent(text="test response")]),
            usage=usage
        )
    def create_test_trace(self) -> LLMRequestTrace:
        """Create an unsaved LLMRequestTrace with fixed ids and the current local time."""
        trace = LLMRequestTrace()
        trace.trace_id = "test-trace-id"
        trace.model_id = "test-model"
        trace.backend_name = "test-backend"
        trace.request_time = datetime.now()
        return trace
================================================
FILE: tests/tracing/test_core.py
================================================
from datetime import datetime
from unittest import IsolatedAsyncioTestCase
from kirara_ai.events.tracing import TraceEvent
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.ioc.inject import Inject
from kirara_ai.tracing.core import TracerBase, generate_trace_id
from tests.tracing.test_base import TestTraceRecord, TracingTestBase
class TestEvent(TraceEvent):
    """Bare-bones trace event carrying only a trace id and a start timestamp."""

    # Keep pytest from collecting this helper class as a test case.
    __test__ = False

    def __init__(self, trace_id: str):
        self.trace_id = trace_id
        moment = datetime.now()
        self.start_time = moment.timestamp()
class TestTracer(TracerBase[TestTraceRecord]):
    """Tracer used by the tests: persists one TestTraceRecord per TestEvent posted on the bus."""
    # Keep pytest from collecting this helper class as a test case.
    __test__ = False
    name = "test"
    record_class = TestTraceRecord
    # NOTE(review): presumably lets the IoC container supply `container`;
    # the tests also call TestTracer(container) directly.
    @Inject()
    def __init__(self, container: DependencyContainer):
        super().__init__(container, record_class=TestTraceRecord)
    def _register_event_handlers(self) -> None:
        # Subscribe to TestEvent; unregistration below uses the same bound method.
        self.event_bus.register(TestEvent, self._on_test_event)
    def _unregister_event_handlers(self) -> None:
        self.event_bus.unregister(TestEvent, self._on_test_event)
    def _on_test_event(self, event: TestEvent) -> None:
        """Handle a test event by saving a new record with its trace id and the current time."""
        trace = TestTraceRecord()
        trace.trace_id = event.trace_id
        trace.request_time = datetime.now()
        self.save_trace_record(trace)
class TestTracerBase(TracingTestBase, IsolatedAsyncioTestCase):
    """Tests for TracerBase behavior, exercised through TestTracer."""

    # Upper bound for awaiting a broadcast message, so a broken broadcast
    # fails the test instead of hanging the whole suite.
    WS_MESSAGE_TIMEOUT = 5.0

    def setUp(self):
        super().setUp()
        self.tracer = TestTracer(self.container)
        self.tracer.initialize()

    def tearDown(self):
        self.tracer.shutdown()
        super().tearDown()

    def test_generate_trace_id(self):
        """Test that generated trace ids are present and distinct."""
        trace_id1 = generate_trace_id()
        trace_id2 = generate_trace_id()
        self.assertIsNotNone(trace_id1)
        self.assertIsNotNone(trace_id2)
        self.assertNotEqual(trace_id1, trace_id2)

    def test_get_traces(self):
        """Test querying trace records, with pagination and filtering."""
        # Create some test data
        for i in range(5):
            event = TestEvent(f"test-trace-{i}")
            self.event_bus.post(event)
        # Basic query
        traces, total = self.tracer.get_traces()
        self.assertEqual(total, 5)
        self.assertEqual(len(traces), 5)
        # Pagination
        traces, total = self.tracer.get_traces(page=1, page_size=2)
        self.assertEqual(total, 5)
        self.assertEqual(len(traces), 2)
        # Filtering
        traces, total = self.tracer.get_traces(filters={"trace_id": "test-trace-0"})
        self.assertEqual(total, 1)
        self.assertEqual(len(traces), 1)
        self.assertEqual(traces[0].trace_id, "test-trace-0")

    def test_get_recent_traces(self):
        """Test that get_recent_traces honors its limit."""
        # Create some test data
        for i in range(5):
            event = TestEvent(f"test-trace-{i}")
            self.event_bus.post(event)
        # Limit the number of results
        traces = self.tracer.get_recent_traces(limit=3)
        self.assertEqual(len(traces), 3)

    def test_get_trace_by_id(self):
        """Test fetching a trace record by id."""
        event = TestEvent("test-trace-id")
        self.event_bus.post(event)
        # An existing record is found
        trace = self.tracer.get_trace_by_id("test-trace-id")
        self.assertIsNotNone(trace)
        self.assertEqual(trace.trace_id, "test-trace-id")
        # A missing record yields None
        trace = self.tracer.get_trace_by_id("non-existent-id")
        self.assertIsNone(trace)

    async def test_websocket_operations(self):
        """Test registering a WebSocket client queue, broadcasting to it and unregistering it."""
        import asyncio  # local import: only this test needs it

        # Register a client queue
        queue = self.tracer.register_ws_client()
        # Broadcast one message
        test_message = {"type": "test", "data": "test data"}
        self.tracer.broadcast_ws_message(test_message)
        # The message must arrive; bound the wait so a failure cannot hang the suite.
        message = await asyncio.wait_for(queue.get(), timeout=self.WS_MESSAGE_TIMEOUT)
        self.assertEqual(message, test_message)
        # Unregister the client
        self.tracer.unregister_ws_client(queue)
        # The queue must no longer be registered
        self.assertNotIn(queue, self.tracer._ws_queues)

    def test_save_and_update_trace_record(self):
        """Test saving a record and updating existing and missing records."""
        # Create and save a record
        trace = TestTraceRecord()
        trace.trace_id = "test-trace-id"
        trace.request_time = datetime.now()
        saved_trace = self.tracer.save_trace_record(trace)
        self.assertIsNotNone(saved_trace)
        # Update the record
        event = TestEvent("test-trace-id")
        updated_trace = self.tracer.update_trace_record("test-trace-id", event)
        self.assertIsNotNone(updated_trace)
        # Updating a missing record returns None
        non_existent = self.tracer.update_trace_record("non-existent-id", event)
        self.assertIsNone(non_existent)
================================================
FILE: tests/tracing/test_decorator.py
================================================
from kirara_ai.llm.adapter import LLMBackendAdapter, LLMChatProtocol
from kirara_ai.llm.format.message import LLMChatTextContent
from kirara_ai.llm.format.request import LLMChatRequest
from kirara_ai.llm.format.response import LLMChatResponse, Message
from kirara_ai.tracing import LLMTracer
from kirara_ai.tracing.decorator import trace_llm_chat
from tests.tracing.test_base import TracingTestBase
class TestLLMAdapter(LLMBackendAdapter, LLMChatProtocol):
    """Fake LLM adapter whose ``chat`` is wrapped by ``trace_llm_chat``.

    It exposes ``backend_name`` and ``tracer`` as instance attributes.
    NOTE(review): these look like what the decorator reads from ``self`` — confirm.
    """
    # Keep pytest from collecting this helper class as a test case.
    __test__ = False
    def __init__(self, tracer: LLMTracer):
        # NOTE(review): the base-class __init__ is not called; presumably not required here — confirm.
        self.backend_name = "test-backend"
        self.tracer = tracer
    @trace_llm_chat
    def chat(self, req: LLMChatRequest) -> LLMChatResponse:
        """Return a fixed assistant reply without contacting any backend."""
        return LLMChatResponse(
            model="test-model",
            message=Message(role="assistant", content=[LLMChatTextContent(text="test response")]),
        )
class TestTraceDecorator(TracingTestBase):
    """Tests for the ``trace_llm_chat`` decorator: success and failure are both recorded."""
    def setUp(self):
        super().setUp()
        self.tracer = LLMTracer(self.container)
        self.tracer.initialize()
        self.adapter = TestLLMAdapter(self.tracer)
    def tearDown(self):
        self.tracer.shutdown()
        super().tearDown()
    def test_trace_success(self):
        """Test that a successful call is returned unchanged and recorded as "success"."""
        request = self.create_test_request()
        # Call the decorated method
        response = self.adapter.chat(request)
        # Check the response
        self.assertIsNotNone(response)
        self.assertEqual(response.message.content[0].text, "test response")
        # Check the trace record
        traces = self.tracer.get_recent_traces(limit=1)
        self.assertEqual(len(traces), 1)
        trace = traces[0]
        self.assertEqual(trace.status, "success")
        self.assertEqual(trace.backend_name, "test-backend")
    def test_trace_failure(self):
        """Test that an exception propagates and is recorded as "failed" with its message."""
        request = self.create_test_request()
        # Create an adapter whose chat raises
        error_adapter = TestLLMAdapter(self.tracer)
        # Plain function (not bound); the adapter is passed explicitly as `self` below.
        @trace_llm_chat
        def raise_error(self, req: LLMChatRequest) -> LLMChatResponse:
            raise Exception("Test error")
        error_adapter.chat = raise_error
        # Call the decorated function
        with self.assertRaises(Exception):
            error_adapter.chat(self=error_adapter, req=request)
        # Check the trace record
        traces = self.tracer.get_recent_traces(limit=1)
        self.assertEqual(len(traces), 1)
        trace = traces[0]
        self.assertEqual(trace.status, "failed")
        self.assertEqual(trace.error, "Test error")
        self.assertEqual(trace.backend_name, "test-backend")
================================================
FILE: tests/tracing/test_llm_tracer.py
================================================
from datetime import datetime
from kirara_ai.events.tracing import LLMRequestCompleteEvent, LLMRequestFailEvent, LLMRequestStartEvent
from kirara_ai.tracing import LLMTracer
from tests.tracing.test_base import TracingTestBase
class TestLLMTracer(TracingTestBase):
    """Tests for ``LLMTracer``: the direct tracking API, event-bus handlers and statistics."""

    def setUp(self):
        super().setUp()
        self.tracer = LLMTracer(self.container)
        self.tracer.initialize()

    def tearDown(self):
        # Shut the tracer down first, then let the base class clean up.
        self.tracer.shutdown()
        super().tearDown()

    def test_start_request_tracking(self):
        """Starting tracking returns an id, marks it active and stores a pending trace."""
        trace_id = self.tracer.start_request_tracking(
            "test-backend", self.create_test_request()
        )

        self.assertIsNotNone(trace_id)
        self.assertIn(trace_id, self.tracer._active_traces)

        stored = self.tracer.get_trace_by_id(trace_id)
        self.assertIsNotNone(stored)
        self.assertEqual(stored.status, "pending")

    def test_complete_request_tracking(self):
        """Completing tracking marks the trace successful, records tokens and deactivates it."""
        chat_request = self.create_test_request()
        chat_response = self.create_test_response()

        trace_id = self.tracer.start_request_tracking("test-backend", chat_request)
        self.tracer.complete_request_tracking(trace_id, chat_request, chat_response)

        stored = self.tracer.get_trace_by_id(trace_id)
        self.assertIsNotNone(stored)
        self.assertEqual(stored.status, "success")
        self.assertEqual(stored.total_tokens, 30)
        self.assertNotIn(trace_id, self.tracer._active_traces)

    def test_fail_request_tracking(self):
        """Failing tracking marks the trace failed, stores the error text and deactivates it."""
        chat_request = self.create_test_request()
        error_message = str(Exception("Test error"))

        trace_id = self.tracer.start_request_tracking("test-backend", chat_request)
        self.tracer.fail_request_tracking(trace_id, chat_request, error_message)

        stored = self.tracer.get_trace_by_id(trace_id)
        self.assertIsNotNone(stored)
        self.assertEqual(stored.status, "failed")
        self.assertEqual(stored.error, error_message)
        self.assertNotIn(trace_id, self.tracer._active_traces)

    def test_event_handlers(self):
        """Posting start/complete/fail events on the bus drives the stored trace's status."""
        chat_request = self.create_test_request()
        chat_response = self.create_test_response()
        trace_id = "test-trace-id"
        # Fields shared by all three event types.
        common = {
            "trace_id": trace_id,
            "model_id": "test-model",
            "backend_name": "test-backend",
            "request": chat_request,
        }

        self.event_bus.post(LLMRequestStartEvent(**common))
        stored = self.tracer.get_trace_by_id(trace_id)
        self.assertIsNotNone(stored)
        self.assertEqual(stored.status, "pending")

        self.event_bus.post(
            LLMRequestCompleteEvent(
                **common,
                response=chat_response,
                start_time=datetime.now().timestamp(),
            )
        )
        self.assertEqual(self.tracer.get_trace_by_id(trace_id).status, "success")

        self.event_bus.post(
            LLMRequestFailEvent(
                **common,
                error="Test error",
                start_time=datetime.now().timestamp(),
            )
        )
        self.assertEqual(self.tracer.get_trace_by_id(trace_id).status, "failed")

    def test_get_statistics(self):
        """Two successful and one failed request are aggregated overall, per model and per backend."""
        for succeeded in (True, True, False):
            chat_request = self.create_test_request()
            trace_id = self.tracer.start_request_tracking("test-backend", chat_request)
            if not succeeded:
                self.tracer.fail_request_tracking(trace_id, chat_request, "Test error")
                continue
            self.tracer.complete_request_tracking(
                trace_id, chat_request, self.create_test_response()
            )

        stats = self.tracer.get_statistics()

        overview = stats["overview"]
        self.assertEqual(overview["total_requests"], 3)
        self.assertEqual(overview["success_requests"], 2)
        self.assertEqual(overview["failed_requests"], 1)
        # Only the two successful requests contribute tokens: 2 x 30.
        self.assertEqual(overview["total_tokens"], 60)

        models = stats["models"]
        self.assertGreater(len(models), 0)
        self.assertEqual(models[0]["model_id"], "test-model")
        self.assertEqual(models[0]["count"], 3)

        backends = stats["backends"]
        self.assertGreater(len(backends), 0)
        self.assertEqual(backends[0]["backend_name"], "test-backend")
        self.assertEqual(backends[0]["count"], 3)
================================================
FILE: tests/tracing/test_manager.py
================================================
import asyncio
from unittest import IsolatedAsyncioTestCase
from kirara_ai.tracing import LLMTracer, TracingManager
from tests.tracing.test_base import TracingTestBase
from tests.tracing.test_core import TestTracer
class TestTracingManager(TracingTestBase, IsolatedAsyncioTestCase):
    """Tests for ``TracingManager``: tracer registration, lifecycle, WebSocket clients and trace lookup."""

    def setUp(self):
        super().setUp()
        self.manager = TracingManager(self.container)

    def test_register_tracer(self):
        """A registered tracer is listed among the tracer types and retrievable by name."""
        tracer = TestTracer(self.container)
        self.manager.register_tracer("test", tracer)

        self.assertIn("test", self.manager.get_tracer_types())
        self.assertEqual(self.manager.get_tracer("test"), tracer)

    def test_register_duplicate_tracer(self):
        """Registering a tracer under a name that is already taken raises ``ValueError``."""
        tracer = TestTracer(self.container)
        self.manager.register_tracer("test", tracer)

        with self.assertRaises(ValueError):
            self.manager.register_tracer("test", tracer)

    def test_get_tracer(self):
        """``get_tracer`` returns the registered tracer, or ``None`` for an unknown name."""
        tracer = TestTracer(self.container)
        self.manager.register_tracer("test", tracer)

        self.assertEqual(self.manager.get_tracer("test"), tracer)
        self.assertIsNone(self.manager.get_tracer("non-existent"))

    def test_get_all_tracers(self):
        """``get_all_tracers`` maps every registered name to its tracer."""
        tracer1 = TestTracer(self.container)
        tracer2 = LLMTracer(self.container)
        self.manager.register_tracer("test1", tracer1)
        self.manager.register_tracer("test2", tracer2)

        tracers = self.manager.get_all_tracers()
        self.assertEqual(len(tracers), 2)
        self.assertEqual(tracers["test1"], tracer1)
        self.assertEqual(tracers["test2"], tracer2)

    def test_initialize_and_shutdown(self):
        """Smoke test: initialize/shutdown with a registered tracer must not raise."""
        tracer = TestTracer(self.container)
        self.manager.register_tracer("test", tracer)

        self.manager.initialize()
        self.manager.shutdown()

    async def test_websocket_operations(self):
        """A WebSocket client can be registered for a tracer type and unregistered again."""
        tracer = TestTracer(self.container)
        self.manager.register_tracer("test", tracer)

        # The object returned on registration is what identifies the client
        # when it is unregistered, so it must actually be produced.
        queue = self.manager.register_ws_client("test")
        self.assertIsNotNone(queue)

        self.manager.unregister_ws_client("test", queue)

    def test_trace_operations(self):
        """Trace lookups: empty history, unknown tracer type (``ValueError``), unknown trace id (``None``)."""
        tracer = TestTracer(self.container)
        self.manager.register_tracer("test", tracer)

        traces = self.manager.get_recent_traces("test")
        self.assertEqual(len(traces), 0)

        with self.assertRaises(ValueError):
            self.manager.get_recent_traces("non-existent")

        trace = self.manager.get_trace_by_id("test", "non-existent-id")
        self.assertIsNone(trace)
================================================
FILE: tests/tracing/test_models.py
================================================
from datetime import datetime
from kirara_ai.events.tracing import LLMRequestCompleteEvent, LLMRequestFailEvent, LLMRequestStartEvent
from tests.tracing.test_base import TracingTestBase
class TestLLMRequestTrace(TracingTestBase):
    """Tests for the LLM request trace record: event-driven updates and dict serialization."""

    def setUp(self):
        super().setUp()
        self.trace = self.create_test_trace()

    def test_update_from_start_event(self):
        """A start event fills the identity fields, marks the trace pending and stores the request."""
        start = LLMRequestStartEvent(
            trace_id="test-trace-id",
            model_id="test-model",
            backend_name="test-backend",
            request=self.create_test_request(),
        )
        self.trace.update_from_event(start)

        for attribute, expected in (
            ("trace_id", "test-trace-id"),
            ("model_id", "test-model"),
            ("backend_name", "test-backend"),
            ("status", "pending"),
        ):
            self.assertEqual(getattr(self.trace, attribute), expected)
        self.assertIsNotNone(self.trace.request)

    def test_update_from_complete_event(self):
        """A complete event marks the trace successful and copies token usage and the response."""
        chat_request = self.create_test_request()
        chat_response = self.create_test_response()
        started_at = datetime.now().timestamp()

        self.trace.update_from_event(
            LLMRequestCompleteEvent(
                trace_id="test-trace-id",
                model_id="test-model",
                backend_name="test-backend",
                request=chat_request,
                response=chat_response,
                start_time=started_at,
            )
        )

        self.assertEqual(self.trace.status, "success")
        self.assertEqual(self.trace.prompt_tokens, 10)
        self.assertEqual(self.trace.completion_tokens, 20)
        self.assertEqual(self.trace.total_tokens, 30)
        self.assertIsNotNone(self.trace.response)

    def test_update_from_fail_event(self):
        """A fail event marks the trace failed and stores the error text."""
        chat_request = self.create_test_request()
        started_at = datetime.now().timestamp()

        self.trace.update_from_event(
            LLMRequestFailEvent(
                trace_id="test-trace-id",
                model_id="test-model",
                backend_name="test-backend",
                request=chat_request,
                error="Test error",
                start_time=started_at,
            )
        )

        self.assertEqual(self.trace.status, "failed")
        self.assertEqual(self.trace.error, "Test error")

    def test_to_dict(self):
        """``to_dict`` exposes identity and token fields; ``to_detail_dict`` adds request/response."""
        chat_request = self.create_test_request()
        chat_response = self.create_test_response()

        trace = self.trace
        trace.request = chat_request.model_dump()
        trace.response = chat_response.model_dump()
        trace.prompt_tokens, trace.completion_tokens, trace.total_tokens = 10, 20, 30

        summary = trace.to_dict()
        expected_summary = {
            "trace_id": "test-trace-id",
            "model_id": "test-model",
            "backend_name": "test-backend",
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30,
        }
        for key, value in expected_summary.items():
            self.assertEqual(summary[key], value)

        detail = trace.to_detail_dict()
        self.assertIn("request", detail)
        self.assertIn("response", detail)
        self.assertEqual(detail["request"], chat_request.model_dump())
        self.assertEqual(detail["response"], chat_response.model_dump())

    def test_request_response_properties(self):
        """The ``request``/``response`` attributes round-trip the dumped pydantic payloads."""
        chat_request = self.create_test_request()
        chat_response = self.create_test_response()

        self.trace.request = chat_request.model_dump()
        self.assertIsNotNone(self.trace.request)
        self.assertEqual(self.trace.request["model"], "test-model")

        self.trace.response = chat_response.model_dump()
        self.assertIsNotNone(self.trace.response)
        first_part = self.trace.response["message"]["content"][0]
        self.assertEqual(first_part["text"], "test response")
================================================
FILE: tests/utils/auth_test_utils.py
================================================
import pytest_asyncio
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.web.auth.services import AuthService, MockAuthService
# ==================== Constants ====================
# Credentials shared by the web API test suites.
TEST_SECRET_KEY = "test-secret-key"
TEST_PASSWORD = "test-password"
# ==================== Auth fixtures ====================
def setup_auth_service(container: DependencyContainer) -> None:
    """Install a ``MockAuthService`` as the container's ``AuthService`` implementation."""
    mock_service = MockAuthService()
    container.register(AuthService, mock_service)
@pytest_asyncio.fixture(scope="function")
async def auth_headers(test_client):
    """Log in with ``TEST_PASSWORD`` and return a Bearer ``Authorization`` header dict."""
    login = test_client.post(
        "/backend-api/api/auth/login", json={"password": TEST_PASSWORD}
    )
    payload = login.json()
    # A successful login is expected not to carry an "error" key.
    assert "error" not in payload
    access_token = payload["access_token"]
    return {"Authorization": f"Bearer {access_token}"}
================================================
FILE: tests/utils/test_block_registry.py
================================================
from kirara_ai.workflow.core.block.registry import BlockRegistry
def create_test_block_registry() -> BlockRegistry:
    """Return a fresh ``BlockRegistry`` whose type system knows the basic Python types.

    Registered names: ``str``, ``int``, ``float``, ``bool``, ``list``, ``dict``,
    and ``Any`` (mapped to ``object``).
    """
    registry = BlockRegistry()
    primitive_types = (
        ("str", str),
        ("int", int),
        ("float", float),
        ("bool", bool),
        ("list", list),
        ("dict", dict),
        ("Any", object),
    )
    for type_name, python_type in primitive_types:
        registry._type_system.register_type(type_name, python_type)
    return registry
================================================
FILE: tests/web/api/im/test_im.py
================================================
import asyncio
from typing import Any
from unittest.mock import MagicMock
import pytest
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field
from kirara_ai.config.config_loader import ConfigLoader
from kirara_ai.config.global_config import GlobalConfig, IMConfig, WebConfig
from kirara_ai.events.event_bus import EventBus
from kirara_ai.im.adapter import IMAdapter
from kirara_ai.im.im_registry import IMRegistry
from kirara_ai.im.manager import IMManager
from kirara_ai.im.message import IMMessage, TextMessage
from kirara_ai.im.sender import ChatSender
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.web.api.im.models import IMAdapterConfig
from kirara_ai.web.app import WebServer
from tests.utils.auth_test_utils import auth_headers, setup_auth_service # noqa
# ==================== Constants ====================
# Web credentials used by the fixtures.
TEST_SECRET_KEY = "test-secret-key"
TEST_PASSWORD = "test-password"

# Adapter type, the two configured adapter names and their shared config.
TEST_ADAPTER_TYPE = "dummy"
TEST_ADAPTER_ID = "dummy-bot-1234"
TEST_ADAPTER_NOT_RUNNING_ID = "dummy-bot-2234"
TEST_ADAPTER_CONFIG = {
    "token": "test-token",
    "name": "Test Bot",
}
# ==================== Test adapter ====================
class DummyConfig(BaseModel):
    """Dummy 配置文件模型"""

    # NOTE: the docstring above ("Dummy config-file model") is deliberately left
    # untranslated: pydantic copies a model's docstring into its JSON-schema
    # "description", so editing it would change the served config schema.
    token: str = Field(..., description="Dummy Bot Token")
    name: str = Field(..., description="Bot Name")
class DummyAdapter(IMAdapter):
    """In-memory IM adapter for tests.

    ``send_message`` only records its arguments and ``start``/``stop`` only flip
    ``is_running``; no method here performs any I/O.
    """

    def __init__(self, config: DummyConfig):
        self.config = config
        self.is_running = False
        # (message, recipient) pairs captured by ``send_message``.
        self.messages = []
        # Editing-state storage; no method in this class reads or writes it.
        self.editing_states = {}

    def convert_to_message(self, raw_message: Any) -> IMMessage:
        """Turn a dict-like payload into a single-text ``IMMessage`` from a C2C chat.

        Missing keys fall back to ``"default_user"``, ``"Default User"`` and ``""``.
        """
        sender = ChatSender.from_c2c_chat(
            user_id=raw_message.get("user_id", "default_user"),
            display_name=raw_message.get("display_name", "Default User"),
        )
        text_element = TextMessage(text=raw_message.get("text", ""))
        return IMMessage(sender=sender, message_elements=[text_element])

    async def send_message(self, message: IMMessage, recipient: ChatSender):
        """Record ``(message, recipient)`` instead of sending anything."""
        self.messages.append((message, recipient))

    async def start(self):
        """Mark the adapter as running."""
        self.is_running = True

    async def stop(self):
        """Mark the adapter as stopped."""
        self.is_running = False
# ==================== Fixtures ====================
@pytest.fixture(scope="session")
def app():
    """Build the FastAPI app wired with the dummy IM adapter and a mock auth service.

    Two adapters of type ``TEST_ADAPTER_TYPE`` are configured:
    ``TEST_ADAPTER_ID`` (enabled) and ``TEST_ADAPTER_NOT_RUNNING_ID`` (disabled).
    """
    container = DependencyContainer()
    event_loop = asyncio.new_event_loop()
    container.register(asyncio.AbstractEventLoop, event_loop)

    def _im_config(adapter_name: str, enabled: bool) -> IMConfig:
        # Both adapters share the same type and config payload.
        return IMConfig(
            name=adapter_name,
            enable=enabled,
            adapter=TEST_ADAPTER_TYPE,
            config=TEST_ADAPTER_CONFIG,
        )

    global_config = GlobalConfig()
    global_config.web = WebConfig(
        secret_key=TEST_SECRET_KEY, password_file="test_password.hash"
    )
    global_config.ims = [
        _im_config(TEST_ADAPTER_ID, True),
        _im_config(TEST_ADAPTER_NOT_RUNNING_ID, False),
    ]
    container.register(GlobalConfig, global_config)
    container.register(DependencyContainer, container)
    container.register(EventBus, EventBus())

    im_registry = IMRegistry()
    try:
        im_registry.register(TEST_ADAPTER_TYPE, DummyAdapter, DummyConfig)
    except Exception as exc:
        # Registration problems are only printed, never raised.
        print(exc)
    container.register(IMRegistry, im_registry)

    im_manager = IMManager(container)
    container.register(IMManager, im_manager)
    im_manager.start_adapters(loop=event_loop)

    server = WebServer(container)
    container.register(WebServer, server)

    setup_auth_service(container)
    return server.app
@pytest.fixture(scope="session")
def test_client(app):
    """Session-wide ``TestClient`` bound to the ``app`` fixture."""
    client = TestClient(app)
    return client
# ==================== Test cases ====================
class TestIMAdapter:
    """End-to-end tests for the ``/backend-api/api/im`` endpoints.

    All tests share the session-scoped ``app``/``test_client`` fixtures, and
    several of them (update, stop, start, delete) change the state of
    ``TEST_ADAPTER_ID``. Config persistence is stubbed per test through
    ``monkeypatch``. That restores ``ConfigLoader.save_config_with_backup`` at
    teardown, instead of leaving a ``MagicMock`` on the class for the rest of
    the session.
    """

    @staticmethod
    def _stub_config_save(monkeypatch) -> MagicMock:
        """Replace ``ConfigLoader.save_config_with_backup`` with a fresh ``MagicMock``.

        Args:
            monkeypatch: the calling test's pytest ``monkeypatch`` fixture; it
                restores the original attribute when the test finishes.

        Returns:
            The installed mock, for call assertions.
        """
        save_mock = MagicMock()
        monkeypatch.setattr(ConfigLoader, "save_config_with_backup", save_mock)
        return save_mock

    @pytest.mark.asyncio
    async def test_get_adapter_types(self, test_client, auth_headers):
        """The adapter type listing contains the dummy type."""
        response = test_client.get(
            "/backend-api/api/im/types", headers=auth_headers
        )
        data = response.json()
        assert "types" in data
        assert TEST_ADAPTER_TYPE in data.get("types")

    @pytest.mark.asyncio
    async def test_list_adapters(self, test_client, auth_headers):
        """Both configured adapters are listed and the enabled one is running."""
        response = test_client.get(
            "/backend-api/api/im/adapters", headers=auth_headers
        )
        data = response.json()
        assert "adapters" in data
        adapters = data.get("adapters")
        assert len(adapters) == 2  # the enabled and the disabled adapter
        adapter = next(a for a in adapters if a.get("name") == TEST_ADAPTER_ID)
        assert adapter.get("adapter") == TEST_ADAPTER_TYPE
        assert adapter.get("is_running") is True

    @pytest.mark.asyncio
    async def test_get_adapter(self, test_client, auth_headers):
        """A single adapter is returned with its type and config."""
        response = test_client.get(
            f"/backend-api/api/im/adapters/{TEST_ADAPTER_ID}", headers=auth_headers
        )
        data = response.json()
        assert "adapter" in data
        adapter = data.get("adapter")
        assert adapter.get("name") == TEST_ADAPTER_ID
        assert adapter.get("adapter") == TEST_ADAPTER_TYPE
        assert adapter.get("config") == TEST_ADAPTER_CONFIG

    @pytest.mark.asyncio
    async def test_create_adapter(self, test_client, auth_headers, monkeypatch):
        """Creating an adapter echoes it back and persists the config exactly once."""
        adapter_data = IMAdapterConfig(
            name="new-adapter", adapter=TEST_ADAPTER_TYPE, config=TEST_ADAPTER_CONFIG
        )
        save_mock = self._stub_config_save(monkeypatch)

        response = test_client.post(
            "/backend-api/api/im/adapters",
            headers=auth_headers,
            json=adapter_data.model_dump(),
        )
        data = response.json()
        assert "adapter" in data
        adapter = data.get("adapter")
        assert adapter.get("name") == "new-adapter"
        assert adapter.get("adapter") == TEST_ADAPTER_TYPE
        assert adapter.get("config") == TEST_ADAPTER_CONFIG
        save_mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_adapter(self, test_client, auth_headers, monkeypatch):
        """Updating an adapter returns the new config and persists it exactly once."""
        adapter_data = IMAdapterConfig(
            name=TEST_ADAPTER_ID,
            adapter=TEST_ADAPTER_TYPE,
            config={"token": "updated-token", "name": "Updated Bot"},
        )
        save_mock = self._stub_config_save(monkeypatch)

        response = test_client.put(
            f"/backend-api/api/im/adapters/{TEST_ADAPTER_ID}",
            headers=auth_headers,
            json=adapter_data.model_dump(),
        )
        data = response.json()
        assert "adapter" in data
        adapter = data.get("adapter")
        assert adapter.get("name") == TEST_ADAPTER_ID
        assert adapter.get("adapter") == TEST_ADAPTER_TYPE
        assert adapter.get("config").get("token") == "updated-token"
        assert adapter.get("config").get("name") == "Updated Bot"
        save_mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_stop_adapter(self, test_client, auth_headers, monkeypatch):
        """Stopping the adapter reports success and it is no longer running."""
        # Stub persistence in case the endpoint saves the config, so the real
        # config file is never written.
        self._stub_config_save(monkeypatch)

        response = test_client.post(
            f"/backend-api/api/im/adapters/{TEST_ADAPTER_ID}/stop", headers=auth_headers
        )
        data = response.json()
        assert "message" in data
        assert data.get("message") == "Adapter stopped successfully"

        response = test_client.get(
            f"/backend-api/api/im/adapters/{TEST_ADAPTER_ID}", headers=auth_headers
        )
        data = response.json()
        assert "adapter" in data
        assert data.get("adapter").get("is_running") is False

    @pytest.mark.asyncio
    async def test_start_adapter(self, test_client, auth_headers, monkeypatch):
        """Starting the adapter reports success and it is running again."""
        # Stub persistence in case the endpoint saves the config.
        self._stub_config_save(monkeypatch)

        response = test_client.post(
            f"/backend-api/api/im/adapters/{TEST_ADAPTER_ID}/start",
            headers=auth_headers,
        )
        data = response.json()
        assert "message" in data
        assert data.get("message") == "Adapter started successfully"

        response = test_client.get(
            f"/backend-api/api/im/adapters/{TEST_ADAPTER_ID}", headers=auth_headers
        )
        data = response.json()
        assert "adapter" in data
        assert data.get("adapter").get("is_running") is True

    @pytest.mark.asyncio
    async def test_delete_adapter(self, test_client, auth_headers, monkeypatch):
        """Deleting a running adapter reports success and persists the config exactly once."""
        # Make sure the adapter is running first. This call may persist the config
        # too, so it gets its own stub, separate from the one asserted on below.
        self._stub_config_save(monkeypatch)
        test_client.post(
            f"/backend-api/api/im/adapters/{TEST_ADAPTER_ID}/start",
            headers=auth_headers,
        )

        save_mock = self._stub_config_save(monkeypatch)
        response = test_client.delete(
            f"/backend-api/api/im/adapters/{TEST_ADAPTER_ID}", headers=auth_headers
        )
        data = response.json()
        assert "message" in data
        assert data.get("message") == "Adapter deleted successfully"
        save_mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_adapter_config_schema(self, test_client, auth_headers):
        """The config schema of the dummy type mirrors ``DummyConfig``."""
        response = test_client.get(
            f"/backend-api/api/im/types/{TEST_ADAPTER_TYPE}/config-schema",
            headers=auth_headers,
        )
        data = response.json()
        assert "configSchema" in data
        schema = data.get("configSchema")
        assert schema.get("title") == "DummyConfig"
        assert schema.get("type") == "object"
        assert "properties" in schema
        properties = schema.get("properties")
        assert "token" in properties
        assert properties["token"].get("title") == "Token"
        assert properties["token"].get("type") == "string"
        assert properties["token"].get("description") == "Dummy Bot Token"
        assert "name" in properties
        assert properties["name"].get("title") == "Name"
        assert properties["name"].get("type") == "string"
        assert properties["name"].get("description") == "Bot Name"

    @pytest.mark.asyncio
    async def test_get_adapter_config_schema_not_found(self, test_client, auth_headers):
        """An unknown adapter type yields a 404 with an error payload."""
        response = test_client.get(
            "/backend-api/api/im/types/not-exist/config-schema", headers=auth_headers
        )
        assert response.status_code == 404
        data = response.json()
        assert "error" in data
================================================
FILE: tests/web/api/llm/test_llm.py
================================================
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from pydantic import BaseModel
from kirara_ai.config.config_loader import ConfigLoader
from kirara_ai.config.global_config import GlobalConfig, LLMBackendConfig, WebConfig
from kirara_ai.events.event_bus import EventBus
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.adapter import LLMBackendAdapter, LLMChatProtocol
from kirara_ai.llm.format.message import LLMChatTextContent
from kirara_ai.llm.format.request import LLMChatRequest
from kirara_ai.llm.format.response import LLMChatResponse, Message, Usage
from kirara_ai.llm.llm_manager import LLMManager
from kirara_ai.llm.llm_registry import LLMAbility, LLMBackendRegistry
from kirara_ai.web.app import WebServer
from tests.utils.auth_test_utils import auth_headers, setup_auth_service # noqa
# ==================== Constants ====================
# Web credentials used by the fixtures.
TEST_SECRET_KEY = "test-secret-key"
TEST_PASSWORD = "test-password"
# Name of the single configured backend and of the adapter type it uses.
TEST_BACKEND_NAME = "test-backend"
TEST_ADAPTER_TYPE = "test-adapter"
# ==================== Test adapter ====================
class TestConfig(BaseModel):
    """测试用配置"""

    # NOTE: the docstring above ("test configuration") is deliberately left
    # untranslated: pydantic copies a model's docstring into its JSON-schema
    # "description", which the config-schema endpoint serves.
    # ``__test__ = False`` stops pytest from collecting this ``Test*``-named class.
    __test__ = False

    api_key: str = "test-key"
    model: str = "test-model"
class TestAdapter(LLMBackendAdapter, LLMChatProtocol):
    """Minimal chat backend that always answers "Test response" with fixed token usage."""

    # Stop pytest from collecting this ``Test*``-named helper as a test class.
    __test__ = False

    def __init__(self, config: TestConfig):
        self.config = config

    def chat(self, req: LLMChatRequest) -> LLMChatResponse:
        """Ignore ``req`` and return a canned assistant reply.

        The reply uses the configured model name and reports 10 prompt,
        20 completion and 30 total tokens.
        """
        reply = Message(
            role="assistant",
            content=[LLMChatTextContent(text="Test response")],
        )
        usage = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30)
        return LLMChatResponse(message=reply, model=self.config.model, usage=usage)
# ==================== Fixtures ====================
@pytest.fixture(scope="session")
def app():
    """Build the FastAPI app with one configured ``TEST_BACKEND_NAME`` backend and mock auth."""
    container = DependencyContainer()
    container.register(DependencyContainer, container)
    container.register(EventBus, EventBus())

    global_config = GlobalConfig()
    global_config.web = WebConfig(
        secret_key=TEST_SECRET_KEY, password_file="test_password.hash"
    )
    only_backend = LLMBackendConfig(
        name=TEST_BACKEND_NAME,
        adapter=TEST_ADAPTER_TYPE,
        config={"api_key": "test-key", "model": "test-model"},
        enable=True,
        models=["test-model"],
    )
    global_config.llms.api_backends = [only_backend]
    container.register(GlobalConfig, global_config)

    setup_auth_service(container)

    backend_registry = LLMBackendRegistry()
    backend_registry.register(
        TEST_ADAPTER_TYPE, TestAdapter, TestConfig, LLMAbility.TextChat
    )
    container.register(LLMBackendRegistry, backend_registry)

    llm_manager = LLMManager(container)
    container.register(LLMManager, llm_manager)
    llm_manager.load_config()

    server = WebServer(container)
    container.register(WebServer, server)
    return server.app
@pytest.fixture
def test_client(app):
    """Per-test ``TestClient`` bound to the session-wide ``app``."""
    client = TestClient(app)
    return client
# ==================== Test cases ====================
class TestLLMBackend:
    """End-to-end tests for the ``/backend-api/api/llm`` endpoints.

    Config persistence (``ConfigLoader.save_config_with_backup``) is always
    stubbed with a ``patch.object`` context manager. The stub is therefore
    removed when the ``with`` block exits and cannot leak into other tests.
    """

    @pytest.mark.asyncio
    async def test_get_adapter_types(self, test_client, auth_headers):
        """The adapter type listing contains the test adapter type."""
        response = test_client.get(
            "/backend-api/api/llm/types", headers=auth_headers
        )
        data = response.json()
        assert "types" in data
        assert TEST_ADAPTER_TYPE in data.get("types")

    @pytest.mark.asyncio
    async def test_list_backends(self, test_client, auth_headers):
        """Exactly the configured backend is listed."""
        response = test_client.get(
            "/backend-api/api/llm/backends", headers=auth_headers
        )
        data = response.json()
        assert "data" in data
        assert "backends" in data.get("data")
        backends = data.get("data").get("backends")
        assert len(backends) == 1
        assert backends[0].get("name") == TEST_BACKEND_NAME
        assert backends[0].get("adapter") == TEST_ADAPTER_TYPE

    @pytest.mark.asyncio
    async def test_get_backend(self, test_client, auth_headers):
        """A single backend is returned by name."""
        response = test_client.get(
            f"/backend-api/api/llm/backends/{TEST_BACKEND_NAME}", headers=auth_headers
        )
        data = response.json()
        assert "data" in data
        backend = data.get("data")
        assert backend.get("name") == TEST_BACKEND_NAME
        assert backend.get("adapter") == TEST_ADAPTER_TYPE

    @pytest.mark.asyncio
    async def test_create_backend(self, test_client, auth_headers):
        """Creating a backend echoes it back and persists the config exactly once."""
        new_backend = LLMBackendConfig(
            name="new-backend",
            adapter=TEST_ADAPTER_TYPE,
            config={"api_key": "new-key", "model": "new-model"},
            enable=True,
            models=["new-model"],
        )
        with patch.object(ConfigLoader, "save_config_with_backup") as mock_save:
            response = test_client.post(
                "/backend-api/api/llm/backends",
                headers=auth_headers,
                json=new_backend.model_dump(),
            )
        data = response.json()
        assert "data" in data
        backend = data.get("data")
        assert backend.get("name") == "new-backend"
        assert backend.get("adapter") == TEST_ADAPTER_TYPE
        mock_save.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_backend(self, test_client, auth_headers):
        """Updating a backend returns the new config and persists it exactly once."""
        updated_config = LLMBackendConfig(
            name=TEST_BACKEND_NAME,
            adapter=TEST_ADAPTER_TYPE,
            config={"api_key": "updated-key", "model": "updated-model"},
            enable=True,
            models=["updated-model"],
        )
        with patch.object(ConfigLoader, "save_config_with_backup") as mock_save:
            response = test_client.put(
                f"/backend-api/api/llm/backends/{TEST_BACKEND_NAME}",
                headers=auth_headers,
                json=updated_config.model_dump(),
            )
        data = response.json()
        assert not data.get("error")
        assert "data" in data
        backend = data.get("data")
        assert backend.get("name") == TEST_BACKEND_NAME
        assert backend.get("config").get("api_key") == "updated-key"
        mock_save.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_backend(self, test_client, auth_headers):
        """Deleting a backend returns it, persists the config once and makes it unreachable."""
        with patch.object(ConfigLoader, "save_config_with_backup") as mock_save:
            response = test_client.delete(
                f"/backend-api/api/llm/backends/{TEST_BACKEND_NAME}", headers=auth_headers
            )
        data = response.json()
        assert not data.get("error")
        assert "data" in data
        backend = data.get("data")
        assert backend.get("name") == TEST_BACKEND_NAME
        mock_save.assert_called_once()

        # The deleted backend can no longer be fetched.
        response = test_client.get(
            f"/backend-api/api/llm/backends/{TEST_BACKEND_NAME}", headers=auth_headers
        )
        data = response.json()
        assert "error" in data

    @pytest.mark.asyncio
    async def test_get_adapter_config_schema(self, test_client, auth_headers):
        """The config schema of the test adapter type mirrors ``TestConfig``."""
        response = test_client.get(
            f"/backend-api/api/llm/types/{TEST_ADAPTER_TYPE}/config-schema",
            headers=auth_headers,
        )
        data = response.json()
        assert "configSchema" in data
        schema = data.get("configSchema")
        assert schema.get("title") == "TestConfig"
        assert schema.get("type") == "object"
        assert "properties" in schema
        properties = schema.get("properties")
        assert "api_key" in properties
        assert properties["api_key"].get("title") == "Api Key"
        assert properties["api_key"].get("type") == "string"
        assert "model" in properties
        assert properties["model"].get("title") == "Model"
        assert properties["model"].get("type") == "string"
        assert properties["model"].get("default") == "test-model"

    @pytest.mark.asyncio
    async def test_get_adapter_config_schema_not_found(self, test_client, auth_headers):
        """An unknown adapter type yields a 404."""
        response = test_client.get(
            "/backend-api/api/llm/types/not-exist/config-schema", headers=auth_headers
        )
        assert response.status_code == 404
================================================
FILE: tests/web/api/media/test_media.py
================================================
import collections
import os
import tempfile
import time
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from kirara_ai.config.config_loader import CONFIG_FILE
from kirara_ai.config.global_config import GlobalConfig, MediaConfig, WebConfig
from kirara_ai.events.event_bus import EventBus
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.media.manager import MediaManager
from kirara_ai.media.metadata import MediaMetadata
from kirara_ai.media.types.media_type import MediaType
from kirara_ai.web.app import WebServer
from tests.utils.auth_test_utils import auth_headers, setup_auth_service # noqa
# ==================== Constants ====================
# Web credentials shared with the other web API test modules.
TEST_SECRET_KEY = "test-secret-key"
TEST_PASSWORD = "test-password"
# ==================== Fixtures ====================
@pytest.fixture(scope="module")
def temp_media_dir():
    """Yield a fresh ``<tmp>/media`` directory for this module and remove its parent afterwards.

    Yields:
        Absolute path of the ``media`` sub-directory inside a new temp directory.
    """
    import shutil

    temp_dir = tempfile.mkdtemp(prefix="kirara_test_media_api_")
    media_dir = os.path.join(temp_dir, "media")
    # Create it up-front: MediaManager would create it too, but tests may
    # touch the directory before that happens.
    os.makedirs(media_dir, exist_ok=True)
    print(f"Created temp media dir: {media_dir}")
    try:
        yield media_dir
    finally:
        print(f"Removing temp media dir: {temp_dir}")
        # ignore_errors: on some systems files may still be locked or lack
        # permissions at teardown; a leftover temp dir must not become an
        # error, but the common case no longer leaks the directory.
        shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.fixture(scope="module")
def container(temp_media_dir):
    """Build a dependency container for the media API tests.

    It holds the web and media config (last cleanup one day ago), the mock auth
    service, and a real ``MediaManager`` rooted at ``temp_media_dir``.
    """
    deps = DependencyContainer()
    deps.register(DependencyContainer, deps)
    deps.register(EventBus, EventBus())

    global_config = GlobalConfig()
    global_config.web = WebConfig(
        secret_key=TEST_SECRET_KEY, password_file="test_password.hash"
    )
    one_day_ago = int(time.time()) - 86400
    global_config.media = MediaConfig(
        cleanup_duration=7,
        auto_remove_unreferenced=True,
        last_cleanup_time=one_day_ago,
    )
    deps.register(GlobalConfig, global_config)

    setup_auth_service(deps)

    # Drop a previously cached instance (presumably a singleton cache on the
    # class) so the manager created below really uses temp_media_dir.
    if hasattr(MediaManager, "_instance"):
        del MediaManager._instance
    deps.register(MediaManager, MediaManager(media_dir=temp_media_dir))
    return deps
@pytest.fixture(scope="module")
def app(container):
    """Create the ``WebServer``, register it in ``container`` and return its FastAPI app."""
    server = WebServer(container)
    container.register(WebServer, server)
    return server.app
@pytest.fixture(scope="module")
def test_client(app):
    """Yield a ``TestClient`` used as a context manager so startup/shutdown events run."""
    client = TestClient(app)
    with client:
        yield client
# ==================== 测试用例 ====================
@pytest.mark.usefixtures("test_client", "auth_headers") # 应用 test_client 和 auth_headers
class TestMediaAPI:
@pytest.fixture(autouse=True)
def setup_mocks(self, container, temp_media_dir):
    """Install per-test patches on the media routes module and undo them afterwards.

    The following are patched:
    - ``_get_media_manager`` now returns ``self.mock_media_manager``.
    - ``ConfigLoader.save_config_with_backup`` becomes ``self.mock_save_config``.
    - ``shutil.disk_usage`` becomes ``self.mock_disk_usage``.
    - ``time.time`` is frozen at ``self.current_time``.
    """
    routes_module = "kirara_ai.web.api.media.routes"

    manager_mock = MagicMock(spec=MediaManager)
    # disk-usage assertions compare against this path, so the mock needs the
    # real temp directory.
    manager_mock.media_dir = temp_media_dir
    self.mock_media_manager = manager_mock

    self.patcher_get_manager = patch(
        f"{routes_module}._get_media_manager",
        return_value=manager_mock,
    )
    self.mock_get_manager = self.patcher_get_manager.start()

    self.patcher_save_config = patch(
        f"{routes_module}.ConfigLoader.save_config_with_backup"
    )
    self.mock_save_config = self.patcher_save_config.start()

    self.patcher_disk_usage = patch(f"{routes_module}.shutil.disk_usage")
    self.mock_disk_usage = self.patcher_disk_usage.start()

    # Capture the real time before patching so last_cleanup_time updates can be checked.
    self.current_time = int(time.time())
    self.patcher_time = patch(
        f"{routes_module}.time.time", return_value=self.current_time
    )
    self.mock_time = self.patcher_time.start()

    yield

    for patcher in (
        self.patcher_get_manager,
        self.patcher_save_config,
        self.patcher_disk_usage,
        self.patcher_time,
    ):
        patcher.stop()
def test_get_system_info(self, test_client, auth_headers, container):
    """GET /system reports the media config, aggregated media stats and disk usage."""
    config: GlobalConfig = container.resolve(GlobalConfig)

    metadata_by_id = {
        "media1": MediaMetadata(
            media_id="media1",
            media_type=MediaType.IMAGE,
            format="jpg",
            size=1024,
            references={"ref1"},
        ),
        "media2": MediaMetadata(
            media_id="media2",
            media_type=MediaType.AUDIO,
            format="mp3",
            size=2048,
            references={"ref2"},
        ),
    }
    self.mock_media_manager.get_all_media_ids.return_value = ["media1", "media2"]
    # Unknown ids resolve to None.
    self.mock_media_manager.get_metadata.side_effect = metadata_by_id.get

    DiskUsage = collections.namedtuple("usage", ["total", "used", "free"])
    mib = 1024 * 1024
    disk = DiskUsage(total=10 * mib, used=3 * mib, free=7 * mib)
    self.mock_disk_usage.return_value = disk

    response = test_client.get("/backend-api/api/media/system", headers=auth_headers)
    assert response.status_code == 200, f"响应内容: {response.text}"
    data = response.json()

    expected = {
        "cleanup_duration": config.media.cleanup_duration,
        "auto_remove_unreferenced": config.media.auto_remove_unreferenced,
        "last_cleanup_time": config.media.last_cleanup_time,
        "total_media_count": 2,
        "total_media_size": 1024 + 2048,
        "disk_total": disk.total,
        "disk_used": disk.used,
        "disk_free": disk.free,
    }
    for key, value in expected.items():
        assert data[key] == value, key

    self.mock_media_manager.get_all_media_ids.assert_called_once()
    assert self.mock_media_manager.get_metadata.call_count == 2
    # disk_usage must be queried for the (mock) manager's media directory.
    self.mock_disk_usage.assert_called_once_with(self.mock_media_manager.media_dir)
def test_set_config(self, test_client, auth_headers, container):
    """Test the POST /system/config endpoint.

    Posts new cleanup settings, then checks that:

    * the GlobalConfig held by the container was updated in place,
    * the media manager's cleanup task was re-set up with the container,
    * ``save_config`` received ``CONFIG_FILE`` and the updated config.

    The original values are restored in ``finally`` so a failing assertion
    cannot leak the modified settings into later tests.
    """
    config: GlobalConfig = container.resolve(GlobalConfig)
    original_duration = config.media.cleanup_duration
    original_auto_remove = config.media.auto_remove_unreferenced
    new_config_data = {"cleanup_duration": 14, "auto_remove_unreferenced": False}
    try:
        response = test_client.post(
            "/backend-api/api/media/system/config",
            headers=auth_headers,
            json=new_config_data,
        )
        assert response.status_code == 200, f"响应内容: {response.text}"
        data = response.json()
        assert data["success"] is True

        # The config object held by the container is mutated in place.
        assert config.media.cleanup_duration == 14
        assert config.media.auto_remove_unreferenced is False

        # The cleanup task is set up again using the container.
        self.mock_media_manager.setup_cleanup_task.assert_called_once_with(container)

        # save_config is called once as save_config(CONFIG_FILE, <GlobalConfig>).
        self.mock_save_config.assert_called_once()
        args, _ = self.mock_save_config.call_args
        assert args[0] == CONFIG_FILE
        saved_config = args[1]
        assert isinstance(saved_config, GlobalConfig)
        assert saved_config.media.cleanup_duration == 14
        assert saved_config.media.auto_remove_unreferenced is False
    finally:
        # Always restore, even when an assertion above fails.
        config.media.cleanup_duration = original_duration
        config.media.auto_remove_unreferenced = original_auto_remove
def test_cleanup_unreferenced(self, test_client, auth_headers, container):
    """Test the POST /system/cleanup-unreferenced endpoint.

    The mocked media manager reports a fixed cleanup count. The test checks
    that the count is returned, the cleanup task is set up again, the config
    is saved, and ``last_cleanup_time`` is set to the mocked current time
    (``self.current_time``) both in memory and in the saved config.
    ``last_cleanup_time`` is restored in ``finally`` so a failing assertion
    cannot leak the change into later tests.
    """
    config: GlobalConfig = container.resolve(GlobalConfig)
    original_last_cleanup_time = config.media.last_cleanup_time
    cleanup_count = 5
    self.mock_media_manager.cleanup_unreferenced.return_value = cleanup_count
    try:
        response = test_client.post(
            "/backend-api/api/media/system/cleanup-unreferenced", headers=auth_headers
        )
        assert response.status_code == 200, f"响应内容: {response.text}"
        data = response.json()
        assert data["success"] is True
        assert data["count"] == cleanup_count

        self.mock_media_manager.cleanup_unreferenced.assert_called_once()
        self.mock_save_config.assert_called_once()
        self.mock_media_manager.setup_cleanup_task.assert_called_once_with(container)

        # last_cleanup_time is updated to the mocked time ...
        assert config.media.last_cleanup_time == self.current_time
        # ... and the config passed to save_config carries the same value.
        args, _ = self.mock_save_config.call_args
        saved_config: GlobalConfig = args[1]
        assert saved_config.media.last_cleanup_time == self.current_time
    finally:
        config.media.last_cleanup_time = original_last_cleanup_time
================================================
FILE: tests/web/api/plugin/test_plugin.py
================================================
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from kirara_ai.config.global_config import GlobalConfig, PluginConfig, WebConfig
from kirara_ai.events.event_bus import EventBus
from kirara_ai.im.im_registry import IMRegistry
from kirara_ai.im.manager import IMManager
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.llm_registry import LLMBackendRegistry
from kirara_ai.plugin_manager.models import PluginInfo
from kirara_ai.plugin_manager.plugin import Plugin
from kirara_ai.plugin_manager.plugin_loader import PluginLoader
from kirara_ai.web.app import WebServer
from kirara_ai.workflow.core.block import BlockRegistry
from kirara_ai.workflow.core.dispatch import WorkflowDispatcher
from kirara_ai.workflow.core.dispatch.registry import DispatchRuleRegistry
from kirara_ai.workflow.core.workflow.registry import WorkflowRegistry
from tests.utils.auth_test_utils import auth_headers, setup_auth_service # noqa
from tests.utils.test_block_registry import create_test_block_registry
# ==================== 常量区 ====================
# Shared constants for the plugin API tests.
TEST_PASSWORD: str = "test-password"
# Secret key handed to WebConfig in the ``app`` fixture.
TEST_SECRET_KEY: str = "test-secret-key"
# Name under which the in-memory test plugin is registered and enabled.
TEST_PLUGIN_NAME: str = "test-plugin"
# ==================== 测试用插件 ====================
def make_test_plugin():
    """Build and return a minimal ``Plugin`` subclass for the plugin API tests.

    The class records lifecycle calls in boolean flags (``initialized`` after
    ``on_load``, ``started`` toggled by ``on_start``/``on_stop``). It is
    created inside a function, so every call returns a fresh class object.
    """

    class TestPlugin(Plugin):
        """Test-only plugin that tracks its lifecycle state."""

        # Despite the ``Test`` prefix this is not a pytest test class (it also
        # defines ``__init__``, which pytest cannot collect), so opt out of
        # collection explicitly.
        __test__ = False

        def __init__(self):
            # NOTE(review): Plugin.__init__ is not called; confirm the base
            # class does not require initialisation.
            self.initialized = False
            self.started = False

        def on_load(self):
            self.initialized = True

        def on_start(self):
            self.started = True

        def on_stop(self):
            self.started = False

    return TestPlugin
# ==================== Mock 数据 ====================
async def MOCK_PLUGIN_SEARCH_RESPONSE():
    """Return a fake plugin-market search payload containing one plugin.

    Declared ``async`` so that calling it yields an awaitable, which lets it
    stand in for an awaited ``json()`` result of a mocked HTTP response.
    """
    pypi_info = {
        "version": "0.1.0",
        "description": "PyPI 描述",
        "author": "PyPI 作者",
        "homePage": "https://example.com",
    }
    plugin_entry = {
        "name": "测试插件",
        "description": "测试插件描述",
        "author": "测试作者",
        "pypiPackage": "test-plugin",
        "pypiInfo": pypi_info,
        "isInstalled": True,
        "installedVersion": "1.0.0",
        "isUpgradable": False,
        "isEnabled": True,
        "requiresRestart": False,
    }
    # A single result page holding the single plugin.
    return {
        "plugins": [plugin_entry],
        "totalCount": 1,
        "totalPages": 1,
        "page": 1,
        "pageSize": 10,
    }
async def MOCK_PLUGIN_INFO_RESPONSE():
    """Return a fake plugin-market detail payload for the test plugin.

    Same shape as one search entry but without ``requiresRestart``; the
    detail test deletes that key from the API response before comparing.
    Declared ``async`` so that calling it yields an awaitable.
    """
    pypi_info = {
        "version": "0.1.0",
        "description": "PyPI 描述",
        "author": "PyPI 作者",
        "homePage": "https://example.com",
    }
    detail = {
        "name": "测试插件",
        "description": "测试插件描述",
        "author": "测试作者",
        "pypiPackage": "test-plugin",
        "pypiInfo": pypi_info,
        "isInstalled": True,
        "installedVersion": "1.0.0",
        "isUpgradable": False,
        "isEnabled": True,
    }
    return detail
# ==================== Fixtures ====================
@pytest.fixture(scope="session")
def app():
    """Build the FastAPI app used by the plugin API tests.

    The container holds a GlobalConfig (web secret, test plugin enabled), the
    auth service, the registries/managers the web server needs, and a
    PluginLoader with the in-memory test plugin registered.
    """
    di = DependencyContainer()
    di.register(DependencyContainer, di)

    global_config = GlobalConfig()
    global_config.web = WebConfig(
        secret_key=TEST_SECRET_KEY, password_file="test_password.hash"
    )
    global_config.plugins = PluginConfig(enable=[TEST_PLUGIN_NAME])
    di.register(GlobalConfig, global_config)

    setup_auth_service(di)

    # Registered one by one, in the same order as before: several of these
    # constructors receive the container and may look up earlier entries.
    di.register(EventBus, EventBus())
    di.register(LLMBackendRegistry, LLMBackendRegistry())
    di.register(IMRegistry, IMRegistry())
    di.register(IMManager, IMManager(di))
    di.register(WorkflowRegistry, WorkflowRegistry(di))
    di.register(DispatchRuleRegistry, DispatchRuleRegistry(di))
    di.register(WorkflowDispatcher, WorkflowDispatcher(di))
    di.register(BlockRegistry, create_test_block_registry())

    loader = PluginLoader(di, "plugins")
    loader.register_plugin(make_test_plugin(), TEST_PLUGIN_NAME)
    di.register(PluginLoader, loader)

    server = WebServer(di)
    di.register(WebServer, server)
    return server.app
@pytest.fixture
def test_client(app):
    """Wrap the ``app`` fixture in a FastAPI ``TestClient``."""
    client = TestClient(app)
    return client
# ==================== 测试用例 ====================
class TestPlugin:
    """Endpoint tests for the plugin market and plugin management API."""

    @pytest.mark.asyncio
    async def test_search_plugins(self, test_client, auth_headers):
        """Search the plugin market against a mocked upstream response."""
        with patch("aiohttp.ClientSession.get") as mock_get:
            mock_response = MagicMock()
            mock_response.status = 200
            # Build a fresh coroutine on every json() call. A single pre-built
            # coroutine as return_value can be awaited only once and raises a
            # "never awaited" warning if the route never calls json().
            mock_response.json.side_effect = (
                lambda *args, **kwargs: MOCK_PLUGIN_SEARCH_RESPONSE()
            )
            mock_get.return_value.__aenter__.return_value = mock_response
            response = test_client.get(
                "/backend-api/api/plugin/v1/search?query=test&page=1&pageSize=10",
                headers=auth_headers,
            )
        assert response.status_code == 200
        data = response.json()
        assert data == await MOCK_PLUGIN_SEARCH_RESPONSE()

    @pytest.mark.asyncio
    async def test_get_plugin_info(self, test_client, auth_headers):
        """Fetch plugin-market details against a mocked upstream response."""
        with patch("aiohttp.ClientSession.get") as mock_get:
            mock_response = MagicMock()
            mock_response.status = 200
            # See test_search_plugins: one fresh coroutine per json() call.
            mock_response.json.side_effect = (
                lambda *args, **kwargs: MOCK_PLUGIN_INFO_RESPONSE()
            )
            mock_get.return_value.__aenter__.return_value = mock_response
            response = test_client.get(
                f"/backend-api/api/plugin/v1/info/{TEST_PLUGIN_NAME}",
                headers=auth_headers,
            )
        assert response.status_code == 200
        data = response.json()
        # requiresRestart is added by the backend and absent from the mock.
        del data["requiresRestart"]
        assert data == await MOCK_PLUGIN_INFO_RESPONSE()

    @pytest.mark.asyncio
    async def test_get_plugin_details(self, test_client, auth_headers):
        """Fetch details of the locally registered test plugin."""
        response = test_client.get(
            f"/backend-api/api/plugin/plugins/{TEST_PLUGIN_NAME}", headers=auth_headers
        )
        assert response.status_code == 200
        data = response.json()
        assert "error" not in data
        assert "plugin" in data
        plugin = data["plugin"]
        assert plugin["name"] == "TestPlugin"
        assert plugin["is_internal"] is True
        assert plugin["is_enabled"] is True

    @pytest.mark.asyncio
    async def test_get_nonexistent_plugin(self, test_client, auth_headers):
        """Requesting an unknown plugin returns 404 with an error."""
        response = test_client.get(
            "/backend-api/api/plugin/plugins/nonexistent", headers=auth_headers
        )
        assert response.status_code == 404
        data = response.json()
        assert "error" in data

    @pytest.mark.asyncio
    async def test_update_plugin(self, test_client, auth_headers):
        """Updating an internal plugin is rejected with 400."""
        response = test_client.put(
            f"/backend-api/api/plugin/plugins/{TEST_PLUGIN_NAME}", headers=auth_headers
        )
        assert response.status_code == 400  # internal plugins cannot be updated
        data = response.json()
        assert "error" in data

    @pytest.mark.asyncio
    async def test_enable_plugin(self, test_client, auth_headers):
        """Enabling the test plugin succeeds and reports it as enabled."""
        # Patch the config writer so the test never touches the real file.
        with patch(
            "kirara_ai.config.config_loader.ConfigLoader.save_config_with_backup"
        ):
            response = test_client.post(
                f"/backend-api/api/plugin/plugins/{TEST_PLUGIN_NAME}/enable",
                headers=auth_headers,
            )
        assert response.status_code == 200
        data = response.json()
        assert "error" not in data
        assert data["plugin"]["is_enabled"] is True

    @pytest.mark.asyncio
    async def test_disable_plugin(self, test_client, auth_headers):
        """Disabling the test plugin succeeds and reports it as disabled."""
        # Patch the config writer so the test never touches the real file.
        with patch(
            "kirara_ai.config.config_loader.ConfigLoader.save_config_with_backup"
        ):
            response = test_client.post(
                f"/backend-api/api/plugin/plugins/{TEST_PLUGIN_NAME}/disable",
                headers=auth_headers,
            )
        assert response.status_code == 200
        data = response.json()
        assert "error" not in data
        assert data["plugin"]["is_enabled"] is False

    @pytest.mark.asyncio
    async def test_install_plugin(self, test_client, auth_headers):
        """Installing a package returns its info and persists the config."""
        with patch(
            "kirara_ai.web.api.plugin.routes.PluginLoader.install_plugin"
        ) as mock_install_plugin:
            mock_install_plugin.return_value = PluginInfo(
                name="test-plugin",
                package_name="test-plugin-package",
                description="test-plugin-description",
                is_internal=False,
                is_enabled=False,
                version="1.0.0",
                author="test-author",
            )
            with patch(
                "kirara_ai.config.config_loader.ConfigLoader.save_config_with_backup"
            ) as mock_save:
                response = test_client.post(
                    "/backend-api/api/plugin/plugins",
                    headers=auth_headers,
                    json={"package_name": "test-plugin-package", "version": "1.0.0"},
                )
                data = response.json()
                assert "error" not in data
                assert data["plugin"]["package_name"] == "test-plugin-package"
                # The updated plugin config must have been saved exactly once.
                mock_save.assert_called_once()

    @pytest.mark.asyncio
    async def test_uninstall_plugin(self, test_client, auth_headers):
        """Uninstalling an internal plugin is rejected with 400."""
        response = test_client.delete(
            f"/backend-api/api/plugin/plugins/{TEST_PLUGIN_NAME}", headers=auth_headers
        )
        data = response.json()
        assert "error" in data
        assert response.status_code == 400  # internal plugins cannot be uninstalled
================================================
FILE: tests/web/api/system/test_system.py
================================================
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from kirara_ai.config.global_config import GlobalConfig, WebConfig
from kirara_ai.im.manager import IMManager
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.llm.llm_manager import LLMManager
from kirara_ai.plugin_manager.plugin_loader import PluginLoader
from kirara_ai.web.app import WebServer
from kirara_ai.workflow.core.workflow import WorkflowRegistry
from tests.utils.auth_test_utils import auth_headers, setup_auth_service # noqa
# ==================== 常量区 ====================
# Shared constants for the system API tests.
TEST_PASSWORD: str = "test-password"
# Secret key handed to WebConfig in the ``app`` fixture.
TEST_SECRET_KEY: str = "test-secret-key"
# ==================== Fixtures ====================
@pytest.fixture
def app():
    """Build the web app for the system API tests with mocked collaborators.

    The mock values are what ``test_get_system_status`` asserts on:

    * two IM adapters, one running;
    * two active LLM backends;
    * three loaded plugins;
    * two registered workflows.
    """
    di = DependencyContainer()

    cfg = GlobalConfig()
    cfg.web = WebConfig(
        secret_key=TEST_SECRET_KEY, password_file="test_password.hash"
    )
    di.register(GlobalConfig, cfg)

    setup_auth_service(di)

    im_manager = MagicMock(spec=IMManager)
    im_manager.adapters = {
        "adapter1": MagicMock(is_running=True),
        "adapter2": MagicMock(is_running=False),
    }
    di.register(IMManager, im_manager)

    llm_manager = MagicMock(spec=LLMManager)
    llm_manager.active_backends = {"backend1": [], "backend2": []}
    di.register(LLMManager, llm_manager)

    plugin_loader = MagicMock(spec=PluginLoader)
    plugin_loader.plugins = [MagicMock() for _ in range(3)]
    di.register(PluginLoader, plugin_loader)

    workflow_registry = MagicMock(spec=WorkflowRegistry)
    workflow_registry._workflows = {
        "workflow1": MagicMock(),
        "workflow2": MagicMock(),
    }
    di.register(WorkflowRegistry, workflow_registry)

    server = WebServer(di)
    di.register(WebServer, server)
    return server.app
@pytest.fixture
def test_client(app):
    """Wrap the ``app`` fixture in a FastAPI ``TestClient``."""
    client = TestClient(app)
    return client
# ==================== 测试用例 ====================
class TestSystemStatus:
    """Tests for the /system status and update-check endpoints."""

    @pytest.mark.asyncio
    async def test_get_system_status(self, test_client, auth_headers):
        """System status combines the mocked managers with patched psutil data."""
        # Fake psutil.Process: 100 MB unique set size, 1.2 % CPU.
        mock_process = MagicMock()
        mock_process.memory_full_info.return_value = MagicMock(
            uss=1024 * 1024 * 100  # 100MB
        )
        mock_process.cpu_percent.return_value = 1.2
        # Fake psutil.virtual_memory: 8 GB total, 4 GB available, 4 GB used.
        mock_virtual_memory = MagicMock()
        mock_virtual_memory.total = 1024 * 1024 * 8192  # 8GB
        mock_virtual_memory.available = 1024 * 1024 * 4096  # 4GB
        mock_virtual_memory.used = 1024 * 1024 * 4096  # 4GB
        with patch(
            "kirara_ai.web.api.system.utils.psutil.Process", return_value=mock_process
        ), patch(
            "kirara_ai.web.api.system.utils.psutil.virtual_memory",
            return_value=mock_virtual_memory,
        ), patch(
            "kirara_ai.web.api.system.utils.psutil.cpu_percent", return_value=1.2
        ):
            response = test_client.get(
                "/backend-api/api/system/status", headers=auth_headers
            )
        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        status = data["status"]

        # Basic fields and counts derived from the mocks in the app fixture.
        assert "version" in status
        assert "uptime" in status
        assert status["active_adapters"] == 1  # only one adapter is running
        assert status["active_backends"] == 2
        assert status["loaded_plugins"] == 3
        assert status["workflow_count"] == 2

        # Resource usage (memory figures in MB).
        assert "memory_usage" in status
        assert "cpu_usage" in status
        assert status["memory_usage"]["percent"] == 0.5  # used / total
        assert status["memory_usage"]["total"] == 8192  # 8GB
        assert status["memory_usage"]["free"] == 4096  # 4GB
        # Process USS from memory_full_info(), 100MB.
        assert status["memory_usage"]["used"] == 100
        assert status["cpu_usage"] == 1.2

    @pytest.mark.asyncio
    async def test_get_system_status_unauthorized(self, test_client):
        """Requests without auth headers are rejected with 401."""
        response = test_client.get("/backend-api/api/system/status")
        assert response.status_code == 401
        data = response.json()
        assert "error" in data

    @pytest.mark.asyncio
    async def test_check_update(self, test_client, auth_headers):
        """The update check reports backend and WebUI version information."""
        # NOTE(review): no upstream version lookup is mocked here, so this test
        # presumably depends on network access and on no newer backend release
        # existing; confirm and consider patching the version source.
        response = test_client.get(
            "/backend-api/api/system/check-update", headers=auth_headers
        )
        assert response.status_code == 200
        data = response.json()
        assert data["current_backend_version"] != "0.0.0"
        assert data["latest_backend_version"] != "0.0.0"
        assert data["backend_update_available"] is False
        assert data["latest_webui_version"] != "0.0.0"
        assert data["webui_download_url"] != ""
================================================
FILE: tests/web/api/workflow/test_workflow.py
================================================
import pytest
from fastapi.testclient import TestClient
from kirara_ai.config.global_config import GlobalConfig, WebConfig
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.web.app import WebServer
from kirara_ai.workflow.core.block import Block, BlockRegistry
from kirara_ai.workflow.core.block.input_output import Input, Output
from kirara_ai.workflow.core.workflow import WorkflowRegistry
from kirara_ai.workflow.core.workflow.builder import WorkflowBuilder
from tests.utils.auth_test_utils import auth_headers, setup_auth_service # noqa
# ==================== 常量区 ====================
# Shared constants for the workflow API tests.
TEST_PASSWORD: str = "test-password"
TEST_SECRET_KEY: str = "test-secret-key"
# Group/id of the workflow pre-registered by the ``app`` fixture, and the
# id used when creating a new workflow.
TEST_GROUP_ID: str = "test-group"
TEST_WORKFLOW_ID: str = "test-workflow"
TEST_WORKFLOW_ID_NEW: str = "test-workflow-new"
# Display names and description used in request payloads.
TEST_WORKFLOW_NAME: str = "Test Workflow"
TEST_WORKFLOW_NAME_NEW: str = "Test Workflow New"
TEST_WORKFLOW_DESC: str = "A test workflow"
# ==================== 测试用Block ====================
class MessageBlock(Block):
    """Test block with no inputs that emits its configured ``text``."""

    name = "message_block"
    inputs = {}
    outputs = {"output": Output("output", "输出", str, "Output message")}
    container: DependencyContainer

    def __init__(self, text: str = ""):
        # NOTE(review): Block.__init__ is not called; confirm the base
        # initialiser is not needed for these API tests.
        self.config = {"text": text}
        self.position = {"x": 0, "y": 0}

    def execute(self) -> dict:
        message = self.config["text"]
        return {"output": message}
class LLMBlock(Block):
    """Test block that fakes an LLM reply by echoing its ``input``."""

    name = "llm_block"
    inputs = {"input": Input("input", "输入", str, "Input message")}
    outputs = {"output": Output("output", "输出", str, "Output message")}
    container: DependencyContainer

    def __init__(self, prompt: str = ""):
        # NOTE(review): Block.__init__ is not called; confirm this is intended.
        self.config = {"prompt": prompt}
        self.position = {"x": 200, "y": 0}

    def execute(self, input: str) -> dict:
        reply = f"Response to: {input}"
        return {"output": reply}
# ==================== Fixtures ====================
@pytest.fixture
def app():
    """Build the web app for the workflow API tests.

    Registers the two test block types and one workflow
    (MessageBlock chained into LLMBlock) under TEST_GROUP_ID/TEST_WORKFLOW_ID.
    """
    di = DependencyContainer()

    cfg = GlobalConfig()
    cfg.web = WebConfig(
        secret_key=TEST_SECRET_KEY, password_file="test_password.hash"
    )
    di.register(GlobalConfig, cfg)

    setup_auth_service(di)

    blocks = BlockRegistry()
    for block_name, block_cls in (("message", MessageBlock), ("llm", LLMBlock)):
        blocks.register(block_name, "test", block_cls)
    di.register(BlockRegistry, blocks)

    builder = WorkflowBuilder(TEST_WORKFLOW_NAME).use(MessageBlock, text="Hello")
    builder = builder.chain(LLMBlock, prompt="How are you?")

    workflows = WorkflowRegistry(di)
    workflows.register(TEST_GROUP_ID, TEST_WORKFLOW_ID, builder)
    di.register(WorkflowRegistry, workflows)

    server = WebServer(di)
    di.register(WebServer, server)
    return server.app
@pytest.fixture
def test_client(app):
    """Wrap the ``app`` fixture in a FastAPI ``TestClient``."""
    client = TestClient(app)
    return client
# ==================== 测试用例 ====================
class TestWorkflow:
    """CRUD tests for the workflow API against the fixture-registered workflow."""

    @pytest.mark.asyncio
    async def test_list_workflows(self, test_client, auth_headers):
        """Listing returns exactly the one pre-registered workflow."""
        response = test_client.get(
            "/backend-api/api/workflow", headers=auth_headers
        )
        data = response.json()
        assert "error" not in data
        assert "workflows" in data
        workflows = data["workflows"]
        assert len(workflows) == 1
        workflow = workflows[0]
        assert workflow["workflow_id"] == TEST_WORKFLOW_ID
        assert workflow["group_id"] == TEST_GROUP_ID
        assert workflow["name"] == TEST_WORKFLOW_NAME

    @pytest.mark.asyncio
    async def test_get_workflow(self, test_client, auth_headers):
        """Fetching the pre-registered workflow returns its single wire."""
        response = test_client.get(
            f"/backend-api/api/workflow/{TEST_GROUP_ID}/{TEST_WORKFLOW_ID}",
            headers=auth_headers,
        )
        data = response.json()
        assert "error" not in data
        assert "workflow" in data
        workflow = data["workflow"]
        assert workflow["workflow_id"] == TEST_WORKFLOW_ID
        assert workflow["group_id"] == TEST_GROUP_ID
        assert workflow["name"] == TEST_WORKFLOW_NAME
        # MessageBlock -> LLMBlock
        assert len(workflow["wires"]) == 1

    @pytest.mark.asyncio
    async def test_create_workflow(self, test_client, auth_headers):
        """Creating a new workflow echoes back its id, group, name and blocks."""
        # Use a distinct name so the created workflow cannot be confused with
        # the pre-registered one.
        workflow_data = {
            "workflow_id": TEST_WORKFLOW_ID_NEW,
            "group_id": TEST_GROUP_ID,
            "name": TEST_WORKFLOW_NAME_NEW,
            "description": TEST_WORKFLOW_DESC,
            "blocks": [
                {
                    "block_id": "node1",
                    "type_name": "test:message",
                    "name": "Message Node",
                    "config": {"text": "Hello"},
                    "position": {"x": 0, "y": 0},
                }
            ],
            "wires": [],
        }
        response = test_client.post(
            f"/backend-api/api/workflow/{TEST_GROUP_ID}/{TEST_WORKFLOW_ID_NEW}",
            headers=auth_headers,
            json=workflow_data,
        )
        data = response.json()
        assert "error" not in data
        assert data["workflow_id"] == TEST_WORKFLOW_ID_NEW
        assert data["group_id"] == TEST_GROUP_ID
        assert data["name"] == TEST_WORKFLOW_NAME_NEW
        assert len(data["blocks"]) == 1

    @pytest.mark.asyncio
    async def test_update_workflow(self, test_client, auth_headers):
        """Updating replaces name, description and block configuration."""
        workflow_data = {
            "workflow_id": TEST_WORKFLOW_ID,
            "group_id": TEST_GROUP_ID,
            "name": "Updated Workflow",
            "description": "Updated workflow description",
            "blocks": [
                {
                    "block_id": "node1",
                    "type_name": "test:message",
                    "name": "Message Node",
                    "config": {"text": "Updated text"},
                    "position": {"x": 0, "y": 0},
                }
            ],
            "wires": [],
        }
        response = test_client.put(
            f"/backend-api/api/workflow/{TEST_GROUP_ID}/{TEST_WORKFLOW_ID}",
            headers=auth_headers,
            json=workflow_data,
        )
        data = response.json()
        assert "error" not in data
        assert data["workflow_id"] == TEST_WORKFLOW_ID
        assert data["group_id"] == TEST_GROUP_ID
        assert data["name"] == "Updated Workflow"
        assert data["description"] == "Updated workflow description"
        assert len(data["blocks"]) == 1
        assert data["blocks"][0]["config"]["text"] == "Updated text"

    @pytest.mark.asyncio
    async def test_delete_workflow(self, test_client, auth_headers):
        """Deleting the pre-registered workflow reports success."""
        response = test_client.delete(
            f"/backend-api/api/workflow/{TEST_GROUP_ID}/{TEST_WORKFLOW_ID}",
            headers=auth_headers,
        )
        data = response.json()
        assert "error" not in data
        assert "message" in data
        assert data["message"] == "Workflow deleted successfully"
================================================
FILE: tests/web/auth/test_auth.py
================================================
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from kirara_ai.config.global_config import GlobalConfig, WebConfig
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.web.app import WebServer
from tests.utils.auth_test_utils import TEST_PASSWORD, setup_auth_service
# ==================== 常量区 ====================
# Password used by the change-password tests.
TEST_NEW_PASSWORD: str = "new-password"
# Secret key handed to WebConfig in the ``app`` fixture.
TEST_SECRET_KEY: str = "test-secret-key"
# Fixed token returned by the MockAuthService on successful login.
TEST_TOKEN: str = "mock_token"
# ==================== Fixtures ====================
@pytest.fixture
def app():
    """Build the web app for the auth tests.

    The auth service is set up after the WebServer is constructed, in the
    same order as before.
    """
    di = DependencyContainer()
    web_config = WebConfig(secret_key=TEST_SECRET_KEY)
    di.register(GlobalConfig, GlobalConfig(web=web_config))

    server = WebServer(di)
    di.register(WebServer, server)

    setup_auth_service(di)
    return server.app
@pytest.fixture
def test_client(app):
    """Wrap the ``app`` fixture in a FastAPI ``TestClient``."""
    client = TestClient(app)
    return client
@pytest_asyncio.fixture
async def auth_token(test_client):
    """Log in with TEST_PASSWORD and return the issued access token."""
    credentials = {"password": TEST_PASSWORD}
    body = test_client.post("/backend-api/api/auth/login", json=credentials).json()
    assert "access_token" in body
    return body["access_token"]
# ==================== 测试用例 ====================
class TestAuth:
    """Tests for the login, first-time check and change-password endpoints."""

    @pytest.mark.asyncio
    async def test_check_first_time(self, test_client):
        """``is_first_time`` is true before the first login and false after."""
        response = test_client.get("/backend-api/api/auth/check-first-time")
        assert response.status_code == 200
        data = response.json()
        assert "is_first_time" in data
        assert data["is_first_time"] is True

        # The first login sets the password; it must succeed for the check
        # below to be meaningful.
        response = test_client.post(
            "/backend-api/api/auth/login", json={"password": TEST_PASSWORD}
        )
        assert response.status_code == 200

        response = test_client.get("/backend-api/api/auth/check-first-time")
        assert response.status_code == 200
        data = response.json()
        assert "is_first_time" in data
        assert data["is_first_time"] is False

    @pytest.mark.asyncio
    async def test_normal_login(self, test_client):
        """A second login with the same password returns the mock token."""
        # First login sets the password.
        test_client.post(
            "/backend-api/api/auth/login", json={"password": TEST_PASSWORD}
        )
        response = test_client.post(
            "/backend-api/api/auth/login", json={"password": TEST_PASSWORD}
        )
        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert data["access_token"] == TEST_TOKEN

    @pytest.mark.asyncio
    async def test_login_wrong_password(self, test_client):
        """A wrong password after setup is rejected with 401."""
        # First login sets the password.
        test_client.post(
            "/backend-api/api/auth/login", json={"password": TEST_PASSWORD}
        )
        response = test_client.post(
            "/backend-api/api/auth/login", json={"password": "wrong-password"}
        )
        assert response.status_code == 401
        data = response.json()
        assert "error" in data

    @pytest.mark.asyncio
    async def test_change_password(self, test_client, auth_token):
        """After changing the password, the new password can log in."""
        response = test_client.post(
            "/backend-api/api/auth/change-password",
            json={"old_password": TEST_PASSWORD, "new_password": TEST_NEW_PASSWORD},
            headers={"Authorization": f"Bearer {auth_token}"},
        )
        assert response.status_code == 200
        response = test_client.post(
            "/backend-api/api/auth/login", json={"password": TEST_NEW_PASSWORD}
        )
        assert response.status_code == 200

    @pytest.mark.asyncio
    async def test_change_password_wrong_old(self, test_client, auth_token):
        """Changing the password with a wrong old password returns an error."""
        response = test_client.post(
            "/backend-api/api/auth/change-password",
            json={"old_password": "wrong-password", "new_password": TEST_NEW_PASSWORD},
            headers={"Authorization": f"Bearer {auth_token}"},
        )
        data = response.json()
        assert "error" in data
================================================
FILE: tests/workflow_executor/test_block.py
================================================
from kirara_ai.workflow.core.block import Block
from kirara_ai.workflow.core.block.input_output import Input, Output
# Define test inputs and outputs
# One string input and one string output port, shared by the tests below.
input_data = Input(
    name="input1",
    label="输入1",
    data_type=str,
    description="Input data",
)
output_data = Output(
    name="output1",
    label="输出1",
    data_type=str,
    description="Processed data",
)
# A bare Block instance (no subclass) using the ports above.
block = Block(
    name="TestBlock",
    inputs={"input1": input_data},
    outputs={"output1": output_data},
)
def test_block_creation():
    """A Block keeps its name and the data types of its ports."""
    assert block.name == "TestBlock"
    # Compare types by identity (ruff/pycodestyle E721), not with ``==``.
    assert block.inputs["input1"].data_type is str
    assert block.outputs["output1"].data_type is str
def test_block_execute():
    """Executing the base Block formats its inputs into ``output1``."""
    outputs = block.execute(input1="test_input")
    assert "output1" in outputs
    expected = "Processed {'input1': 'test_input'}"
    assert outputs["output1"] == expected
================================================
FILE: tests/workflow_executor/test_executor.py
================================================
import pytest
from kirara_ai.events.event_bus import EventBus
from kirara_ai.ioc.container import DependencyContainer
from kirara_ai.workflow.core.block import Block, Input, Output
from kirara_ai.workflow.core.block.registry import BlockRegistry
from kirara_ai.workflow.core.execution.exceptions import BlockExecutionFailedException
from kirara_ai.workflow.core.execution.executor import WorkflowExecutor
from kirara_ai.workflow.core.workflow import Wire, Workflow
from tests.utils.test_block_registry import create_test_block_registry
# 创建测试用的 BlockRegistry
# Shared BlockRegistry for every workflow in this module; the test block
# classes defined below are registered on it and it is placed in each
# test's container under the BlockRegistry key.
test_registry: BlockRegistry = create_test_block_registry()
# 创建测试用的 Block 类
class InputBlock(Block):
    """Source block that always emits the string ``"test_input"``."""

    name = "InputBlock"
    outputs = {
        "output1": Output(
            name="output1",
            label="输出1",
            data_type=str,
            description="Test output",
        ),
    }

    def execute(self, **kwargs):
        # Any keyword arguments are ignored; the output is constant.
        produced = "test_input"
        return {"output1": produced}
class ProcessBlock(Block):
    """Transform block that upper-cases its ``input1`` string."""

    name = "ProcessBlock"
    inputs = {
        "input1": Input(
            name="input1",
            label="输入1",
            data_type=str,
            description="Test input",
        ),
    }
    outputs = {
        "output1": Output(
            name="output1",
            label="输出1",
            data_type=str,
            description="Test output",
        ),
    }

    def execute(self, input1: str, **kwargs):
        shouted = input1.upper()
        return {"output1": shouted}
class OutputBlock(Block):
    """Sink block that exposes its ``input1`` under the ``result`` key."""

    name = "OutputBlock"
    inputs = {
        "input1": Input(
            name="input1",
            label="输入1",
            data_type=str,
            description="Test input",
        ),
    }

    def execute(self, input1: str, **kwargs):
        final = input1
        return {"result": final}
class FailingBlock(Block):
    """Block whose execution always fails with ``BlockExecutionFailedException``."""

    name = "FailingBlock"
    inputs = {
        "input1": Input(
            name="input1",
            label="输入1",
            data_type=str,
            description="Test input",
        ),
    }

    def execute(self, input1: str, **kwargs):
        error = BlockExecutionFailedException("Test error")
        raise error
# 注册测试用的 Block 类型
# Register each test block class on the shared registry, in the same order.
for _type_name, _block_cls in (
    ("input", InputBlock),
    ("process", ProcessBlock),
    ("output", OutputBlock),
    ("failing", FailingBlock),
):
    test_registry.register("test", _type_name, _block_cls)
del _type_name, _block_cls

# Linear workflow: input1 -> process1 -> output1.
input_block = InputBlock(name="input1")
process_block = ProcessBlock(name="process1")
output_block = OutputBlock(name="output1")
workflow = Workflow(
    name="test_workflow",
    blocks=[input_block, process_block, output_block],
    wires=[
        Wire(source_block=input_block, source_output="output1",
             target_block=process_block, target_input="input1"),
        Wire(source_block=process_block, source_output="output1",
             target_block=output_block, target_input="input1"),
    ],
)

# Workflow whose second block always raises: input1 -> failing1.
# It reuses the same ``input_block`` instance as ``workflow``.
failing_block = FailingBlock(name="failing1")
failing_workflow = Workflow(
    name="failing_workflow",
    blocks=[input_block, failing_block],
    wires=[
        Wire(source_block=input_block, source_output="output1",
             target_block=failing_block, target_input="input1"),
    ],
)
def _build_container(flow: Workflow) -> DependencyContainer:
    """Return a container prepared for running ``flow`` with WorkflowExecutor.

    Registers the container itself, a fresh EventBus, the module-level
    ``test_registry`` and ``flow`` as the Workflow, in that order.
    """
    container = DependencyContainer()
    container.register(DependencyContainer, container)
    container.register(EventBus, EventBus())
    container.register(BlockRegistry, test_registry)
    container.register(Workflow, flow)
    return container


@pytest.mark.asyncio
async def test_executor_run():
    """Data flows through the linear input -> process -> output workflow."""
    executor = WorkflowExecutor(_build_container(workflow))
    result = await executor.run()
    assert result["input1"]["output1"] == "test_input"
    assert result["process1"]["output1"] == "TEST_INPUT"
    assert result["output1"]["result"] == "TEST_INPUT"


@pytest.mark.asyncio
async def test_executor_with_failing_block():
    """A block's BlockExecutionFailedException propagates out of run()."""
    executor = WorkflowExecutor(_build_container(failing_workflow))
    with pytest.raises(BlockExecutionFailedException, match="Test error"):
        await executor.run()


@pytest.mark.asyncio
async def test_executor_with_no_blocks():
    """Running a workflow without blocks yields an empty result."""
    empty_workflow = Workflow(name="empty_workflow", blocks=[], wires=[])
    executor = WorkflowExecutor(_build_container(empty_workflow))
    result = await executor.run()
    assert result == {}


@pytest.mark.asyncio
async def test_executor_with_multiple_outputs():
    """A block declaring several outputs appears in the run result."""
    multi_output_block = Block(
        name="MultiOutputBlock",
        inputs={
            "input1": Input(
                name="input1", label="输入1", data_type=str, description="Input data"
            )
        },
        outputs={
            "output1": Output(
                name="output1", label="输出1", data_type=str, description="First output"
            ),
            "output2": Output(
                name="output2",
                label="输出2",
                data_type=int,
                description="Second output",
            ),
        },
    )
    multi_output_workflow = Workflow(
        name="multi_output_workflow",
        blocks=[input_block, multi_output_block],
        wires=[
            Wire(
                source_block=input_block,
                source_output="output1",
                target_block=multi_output_block,
                target_input="input1",
            )
        ],
    )
    executor = WorkflowExecutor(_build_container(multi_output_workflow))
    result = await executor.run()
    # NOTE(review): only presence is checked; the values of output1/output2
    # are not asserted. Consider pinning them once the base Block.execute
    # contract for multiple outputs is confirmed.
    assert "MultiOutputBlock" in result
================================================
FILE: tests/workflow_executor/test_input_output.py
================================================
from kirara_ai.workflow.core.block.input_output import Input, Output
def test_input_validation():
    """Input.validate enforces ``data_type`` and honours ``nullable``."""
    strict = Input(
        name="input1", label="输入1", data_type=int, description="An integer input"
    )
    assert strict.validate(10) is True
    assert strict.validate("10") is False
    # Inputs reject None unless nullable=True is requested.
    assert strict.validate(None) is False

    optional = Input(
        name="input2",
        label="输入2",
        data_type=int,
        description="A nullable integer input",
        nullable=True,
    )
    assert optional.validate(None) is True
def test_output_validation():
    """Output.validate accepts its ``data_type`` and rejects other types."""
    string_output = Output(
        name="output1", label="输出1", data_type=str, description="A string output"
    )
    assert string_output.validate("test") is True
    assert string_output.validate(10) is False
================================================
FILE: tests/workflow_executor/test_workflow_basic.py
================================================
from kirara_ai.workflow.core.block import Block, Input, Output
from kirara_ai.workflow.core.workflow import Wire, Workflow
# Define test blocks
# Source block with a single ``str`` output.
input_block = Block(
    name="InputBlock",
    inputs={},
    outputs={
        "output1": Output(
            name="output1",
            label="输出1",
            data_type=str,
            description="Input data",
        ),
    },
)
# Same shape as ``input_block`` but with an ``int`` output.
# NOTE(review): not referenced in the visible tests; presumably meant for a
# type-mismatch wiring test.
bad_block = Block(
    name="InputBlock",
    inputs={},
    outputs={
        "output1": Output(
            name="output1",
            label="输出1",
            data_type=int,
            description="Input data",
        ),
    },
)
# Block with one ``str`` input and one ``str`` output.
process_block = Block(
    name="ProcessBlock",
    inputs={
        "input1": Input(
            name="input1",
            label="输入1",
            data_type=str,
            description="Input data",
        ),
    },
    outputs={
        "output1": Output(
            name="output1",
            label="输出1",
            data_type=str,
            description="Processed data",
        ),
    },
)

# input_block.output1 -> process_block.input1
wire = Wire(
    source_block=input_block,
    source_output="output1",
    target_block=process_block,
    target_input="input1",
)
def test_workflow_creation():
    """A Workflow keeps the blocks and wires it was constructed with."""
    flow = Workflow(
        name="test_workflow", blocks=[input_block, process_block], wires=[wire]
    )
    assert len(flow.blocks) == 2
    assert len(flow.wires) == 1
    first_wire = flow.wires[0]
    assert first_wire.source_block == input_block
    assert first_wire.target_block == process_block