Repository: pingcap/tidb.ai
Branch: main
Commit: c4cb19d8fa20
Files: 891
Total size: 2.1 MB
Directory structure:
gitextract_vw5kjonn/
├── .dockerignore
├── .github/
│ ├── actions/
│ │ └── decide/
│ │ ├── .gitignore
│ │ ├── action.yml
│ │ ├── index.js
│ │ └── package.json
│ └── workflows/
│ ├── backend-test.yml
│ ├── deploy.yml
│ ├── regression.yml
│ ├── release.yml
│ └── verify.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE.txt
├── README.md
├── backend/
│ ├── .dockerignore
│ ├── .gitignore
│ ├── .pre-commit-config.yaml
│ ├── .python-version
│ ├── Dockerfile
│ ├── Makefile
│ ├── README.md
│ ├── alembic.ini
│ ├── app/
│ │ ├── __init__.py
│ │ ├── alembic/
│ │ │ ├── env.py
│ │ │ ├── script.py.mako
│ │ │ └── versions/
│ │ │ ├── 00534dc350db_.py
│ │ │ ├── 041fbef26e3a_.py
│ │ │ ├── 04947f9684ab_public_chat_engine.py
│ │ │ ├── 04d4f05116ed_.py
│ │ │ ├── 04d81be446c3_.py
│ │ │ ├── 10f36e8a25c4_.py
│ │ │ ├── 197bc8be72d1_.py
│ │ │ ├── 211f3c5aa125_chunking_settings.py
│ │ │ ├── 27a6723b767a_.py
│ │ │ ├── 2adc0b597dcd_int_enum_type.py
│ │ │ ├── 2fc10c21bf88_.py
│ │ │ ├── 749767db5505_add_recommend_questions.py
│ │ │ ├── 8093333c0d87_.py
│ │ │ ├── 830fd9c44f39_.py
│ │ │ ├── 94b198e20946_.py
│ │ │ ├── a54f966436ce_evaluation.py
│ │ │ ├── a8c79553c9f6_.py
│ │ │ ├── ac6e4d58580d_.py
│ │ │ ├── bd17a4ebccc5_.py
│ │ │ ├── c7f016a904c1_.py
│ │ │ ├── d2ad44deab20_multiple_kb.py
│ │ │ ├── dfee070b8abd_.py
│ │ │ ├── e32f1e546eec_.py
│ │ │ └── eb0b85608c0a_.py
│ │ ├── api/
│ │ │ ├── __init__.py
│ │ │ ├── admin_routes/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── chat/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── routes.py
│ │ │ │ ├── chat_engine.py
│ │ │ │ ├── document/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── routes.py
│ │ │ │ ├── embedding_model/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── models.py
│ │ │ │ │ └── routes.py
│ │ │ │ ├── evaluation/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── evaluation_dataset.py
│ │ │ │ │ ├── evaluation_task.py
│ │ │ │ │ ├── models.py
│ │ │ │ │ └── tools.py
│ │ │ │ ├── feedback.py
│ │ │ │ ├── knowledge_base/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── chunk/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── models.py
│ │ │ │ │ │ └── routes.py
│ │ │ │ │ ├── data_source/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── models.py
│ │ │ │ │ │ └── routes.py
│ │ │ │ │ ├── document/
│ │ │ │ │ │ ├── models.py
│ │ │ │ │ │ └── routes.py
│ │ │ │ │ ├── graph/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── knowledge/
│ │ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ │ └── routes.py
│ │ │ │ │ │ ├── models.py
│ │ │ │ │ │ └── routes.py
│ │ │ │ │ ├── models.py
│ │ │ │ │ └── routes.py
│ │ │ │ ├── langfuse.py
│ │ │ │ ├── legacy_retrieve.py
│ │ │ │ ├── llm/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── routes.py
│ │ │ │ ├── models.py
│ │ │ │ ├── reranker_model/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── routes.py
│ │ │ │ ├── semantic_cache.py
│ │ │ │ ├── site_setting.py
│ │ │ │ ├── stats.py
│ │ │ │ ├── upload.py
│ │ │ │ └── user.py
│ │ │ ├── deps.py
│ │ │ ├── main.py
│ │ │ └── routes/
│ │ │ ├── __init__.py
│ │ │ ├── api_key.py
│ │ │ ├── chat.py
│ │ │ ├── chat_engine.py
│ │ │ ├── document.py
│ │ │ ├── feedback.py
│ │ │ ├── index.py
│ │ │ ├── models.py
│ │ │ ├── retrieve/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ └── routes.py
│ │ │ └── user.py
│ │ ├── api_server.py
│ │ ├── auth/
│ │ │ ├── api_keys.py
│ │ │ ├── db.py
│ │ │ ├── schemas.py
│ │ │ └── users.py
│ │ ├── celery.py
│ │ ├── core/
│ │ │ ├── config.py
│ │ │ └── db.py
│ │ ├── evaluation/
│ │ │ ├── evals.py
│ │ │ └── evaluators/
│ │ │ ├── __init__.py
│ │ │ ├── e2e_rag_evaluator.py
│ │ │ ├── language_detector.py
│ │ │ └── toxicity.py
│ │ ├── exceptions.py
│ │ ├── experiments/
│ │ │ ├── sql_extraction.py
│ │ │ └── sql_sample_gen.py
│ │ ├── file_storage/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ └── local.py
│ │ ├── logger.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── api_key.py
│ │ │ ├── auth.py
│ │ │ ├── base.py
│ │ │ ├── chat.py
│ │ │ ├── chat_engine.py
│ │ │ ├── chat_message.py
│ │ │ ├── chunk.py
│ │ │ ├── data_source.py
│ │ │ ├── document.py
│ │ │ ├── embed_model.py
│ │ │ ├── entity.py
│ │ │ ├── evaluation_dataset.py
│ │ │ ├── evaluation_task.py
│ │ │ ├── feedback.py
│ │ │ ├── knowledge_base.py
│ │ │ ├── knowledge_base_scoped/
│ │ │ │ ├── __init__.py
│ │ │ │ └── table_naming.py
│ │ │ ├── llm.py
│ │ │ ├── recommend_question.py
│ │ │ ├── relationship.py
│ │ │ ├── reranker_model.py
│ │ │ ├── semantic_cache.py
│ │ │ ├── site_setting.py
│ │ │ ├── staff_action_log.py
│ │ │ └── upload.py
│ │ ├── rag/
│ │ │ ├── __init__.py
│ │ │ ├── build_index.py
│ │ │ ├── chat/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── chat_flow.py
│ │ │ │ ├── chat_service.py
│ │ │ │ ├── config.py
│ │ │ │ ├── retrieve/
│ │ │ │ │ └── retrieve_flow.py
│ │ │ │ └── stream_protocol.py
│ │ │ ├── datasource/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── consts.py
│ │ │ │ ├── file.py
│ │ │ │ ├── web_base.py
│ │ │ │ ├── web_single_page.py
│ │ │ │ └── web_sitemap.py
│ │ │ ├── default_prompt.py
│ │ │ ├── embeddings/
│ │ │ │ ├── local/
│ │ │ │ │ └── local_embedding.py
│ │ │ │ ├── open_like/
│ │ │ │ │ └── openai_like_embedding.py
│ │ │ │ ├── provider.py
│ │ │ │ └── resolver.py
│ │ │ ├── indices/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── knowledge_graph/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── base.py
│ │ │ │ │ ├── extractor.py
│ │ │ │ │ ├── graph_store/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── helpers.py
│ │ │ │ │ │ ├── schema.py
│ │ │ │ │ │ ├── tidb_graph_editor.py
│ │ │ │ │ │ └── tidb_graph_store.py
│ │ │ │ │ └── schema.py
│ │ │ │ └── vector_search/
│ │ │ │ ├── __init__.py
│ │ │ │ └── vector_store/
│ │ │ │ ├── __init__.py
│ │ │ │ └── tidb_vector_store.py
│ │ │ ├── knowledge_base/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ ├── index_store.py
│ │ │ │ └── schema.py
│ │ │ ├── llms/
│ │ │ │ ├── dspy.py
│ │ │ │ ├── provider.py
│ │ │ │ └── resolver.py
│ │ │ ├── node_parser/
│ │ │ │ ├── __init__.py
│ │ │ │ └── file/
│ │ │ │ └── markdown.py
│ │ │ ├── postprocessors/
│ │ │ │ ├── __init__.py
│ │ │ │ └── metadata_post_filter.py
│ │ │ ├── query_dispatcher.py
│ │ │ ├── question_gen/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── helpers.py
│ │ │ │ └── query_decomposer.py
│ │ │ ├── rerankers/
│ │ │ │ ├── baisheng/
│ │ │ │ │ └── baisheng_reranker.py
│ │ │ │ ├── local/
│ │ │ │ │ └── local_reranker.py
│ │ │ │ ├── provider.py
│ │ │ │ ├── resolver.py
│ │ │ │ └── vllm/
│ │ │ │ └── vllm_reranker.py
│ │ │ ├── retrievers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── chunk/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── fusion_retriever.py
│ │ │ │ │ ├── helpers.py
│ │ │ │ │ ├── schema.py
│ │ │ │ │ └── simple_retriever.py
│ │ │ │ ├── knowledge_graph/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── fusion_retriever.py
│ │ │ │ │ ├── schema.py
│ │ │ │ │ └── simple_retriever.py
│ │ │ │ └── multiple_knowledge_base.py
│ │ │ ├── semantic_cache/
│ │ │ │ ├── __init__.py
│ │ │ │ └── base.py
│ │ │ ├── types.py
│ │ │ └── utils.py
│ │ ├── repositories/
│ │ │ ├── __init__.py
│ │ │ ├── base_repo.py
│ │ │ ├── chat.py
│ │ │ ├── chat_engine.py
│ │ │ ├── chunk.py
│ │ │ ├── data_source.py
│ │ │ ├── document.py
│ │ │ ├── embedding_model.py
│ │ │ ├── feedback.py
│ │ │ ├── graph.py
│ │ │ ├── knowledge_base.py
│ │ │ ├── llm.py
│ │ │ ├── reranker_model.py
│ │ │ ├── staff_action_log.py
│ │ │ └── user.py
│ │ ├── site_settings/
│ │ │ ├── __init__.py
│ │ │ ├── default.py
│ │ │ ├── default_settings.yml
│ │ │ └── types.py
│ │ ├── staff_action/
│ │ │ └── __init__.py
│ │ ├── tasks/
│ │ │ ├── __init__.py
│ │ │ ├── build_index.py
│ │ │ ├── evaluate.py
│ │ │ └── knowledge_base.py
│ │ ├── types.py
│ │ └── utils/
│ │ ├── aes.py
│ │ ├── namespace.py
│ │ ├── singleflight_cache.py
│ │ ├── tracing.py
│ │ └── uuid6.py
│ ├── bootstrap.py
│ ├── dspy_compiled_program/
│ │ └── decompose_query/
│ │ ├── demos.json
│ │ └── program.json
│ ├── dspy_program.py
│ ├── local_embedding_reranker/
│ │ ├── .dockerignore
│ │ ├── Dockerfile
│ │ ├── main.py
│ │ └── requirements.txt
│ ├── main.py
│ ├── prestart.sh
│ ├── pyproject.toml
│ ├── supervisord.conf
│ └── tests/
│ ├── __init__.py
│ ├── conftest.py
│ ├── test_dynamic_models.py
│ └── test_llms.py
├── core/
│ ├── .cursor/
│ │ └── rules/
│ │ └── code-style.mdc
│ ├── .gitignore
│ ├── .python-version
│ ├── Makefile
│ ├── README.md
│ ├── autoflow/
│ │ ├── __init__.py
│ │ ├── chunkers/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── helper.py
│ │ │ └── text.py
│ │ ├── configs/
│ │ │ ├── __init__.py
│ │ │ ├── chunkers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ └── text.py
│ │ │ ├── db.py
│ │ │ ├── knowledge_base.py
│ │ │ ├── main.py
│ │ │ └── models/
│ │ │ ├── __init__.py
│ │ │ ├── embeddings/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── common.py
│ │ │ │ ├── jina_ai.py
│ │ │ │ └── openai.py
│ │ │ ├── llms/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── common.py
│ │ │ │ └── openai.py
│ │ │ ├── manager.py
│ │ │ ├── providers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── jinaai.py
│ │ │ │ └── openai.py
│ │ │ └── rerankers/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── common.py
│ │ │ └── jina_ai.py
│ │ ├── data_types.py
│ │ ├── db.py
│ │ ├── knowledge_base/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ └── prompts.py
│ │ ├── knowledge_graph/
│ │ │ ├── __init__.py
│ │ │ ├── extractors/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ └── simple.py
│ │ │ ├── index.py
│ │ │ ├── programs/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── eval_graph.py
│ │ │ │ ├── extract_covariates.py
│ │ │ │ └── extract_graph.py
│ │ │ ├── retrievers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── simple.py
│ │ │ │ └── weighted.py
│ │ │ └── types.py
│ │ ├── loaders/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── helper.py
│ │ │ ├── markdown.py
│ │ │ ├── pdf.py
│ │ │ └── webpage.py
│ │ ├── main.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── embedding_models/
│ │ │ │ ├── __init__.py
│ │ │ │ └── litellm.py
│ │ │ ├── llms/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dspy.py
│ │ │ │ └── litellm.py
│ │ │ ├── manager.py
│ │ │ ├── provider.py
│ │ │ └── rerank_models/
│ │ │ ├── __init__.py
│ │ │ └── litellm.py
│ │ ├── orms/
│ │ │ ├── __init__.py
│ │ │ └── base.py
│ │ ├── py.typed
│ │ ├── storage/
│ │ │ ├── __init__.py
│ │ │ ├── doc_store/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── tidb_doc_store.py
│ │ │ │ └── types.py
│ │ │ ├── graph_store/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── tidb_graph_store.py
│ │ │ │ └── types.py
│ │ │ └── types.py
│ │ ├── types.py
│ │ └── utils/
│ │ ├── hash.py
│ │ ├── uuid6.py
│ │ └── vector.py
│ ├── examples/
│ │ ├── README.md
│ │ ├── quickstart/
│ │ │ ├── fixtures/
│ │ │ │ ├── analyze-slow-queries.md
│ │ │ │ └── tidb-overview.md
│ │ │ └── quickstart.ipynb
│ │ └── streamlit/
│ │ ├── README.md
│ │ ├── build-knowledge-search-with-autoflow-and-streamlit.py
│ │ └── reqs.txt
│ ├── experimental/
│ │ ├── README.md
│ │ └── kg_extraction/
│ │ └── extract_graph.ipynb
│ ├── pyproject.toml
│ └── tests/
│ ├── __init__.py
│ ├── conftest.py
│ ├── fixtures/
│ │ ├── analyze-slow-queries.md
│ │ └── tidb-overview.md
│ ├── knowledge_base/
│ │ ├── __init__.py
│ │ ├── test_kb_with_namespace.py
│ │ └── test_kb_without_namespace.py
│ ├── knowledge_graph/
│ │ ├── programs/
│ │ │ └── test_extract_graph.py
│ │ └── test_kg_extractor.py
│ ├── models/
│ │ └── test_model_manager.py
│ └── storage/
│ ├── __init__.py
│ ├── doc_store/
│ │ └── test_tidb_doc_store.py
│ └── graph_store/
│ ├── __init__.py
│ └── test_tidb_graph_store.py
├── docker-compose-cn.yml
├── docker-compose.dev.yml
├── docker-compose.yml
├── docs/
│ ├── .gitignore
│ ├── mdx-components.ts
│ ├── next-sitemap.config.js
│ ├── next.config.mjs
│ ├── package.json
│ ├── src/
│ │ ├── app/
│ │ │ ├── [[...mdxPath]]/
│ │ │ │ └── page.jsx
│ │ │ ├── _app.tsx
│ │ │ ├── _ignored/
│ │ │ │ ├── _meta.js
│ │ │ │ └── page.mdx
│ │ │ ├── _meta.ts
│ │ │ ├── globals.css
│ │ │ └── layout.jsx
│ │ └── content/
│ │ ├── README.md
│ │ ├── _meta.ts
│ │ ├── chat-engine.mdx
│ │ ├── deploy-with-docker.mdx
│ │ ├── embedding-model.mdx
│ │ ├── evaluation.mdx
│ │ ├── faq.mdx
│ │ ├── index.mdx
│ │ ├── javascript.mdx
│ │ ├── knowledge-base.mdx
│ │ ├── llm.mdx
│ │ ├── quick-start.mdx
│ │ ├── releases/
│ │ │ ├── _meta.ts
│ │ │ ├── index.mdx
│ │ │ ├── v0.1.0.md
│ │ │ ├── v0.2.0.md
│ │ │ ├── v0.3.0.md
│ │ │ └── v0.4.0.md
│ │ ├── requirements.mdx
│ │ ├── reranker-model.mdx
│ │ └── resources.mdx
│ └── tsconfig.json
├── e2e/
│ ├── .gitignore
│ ├── README.md
│ ├── deploy-test-result.sh
│ ├── docker-compose.yml
│ ├── global.setup.ts
│ ├── package.json
│ ├── playwright.config.ts
│ ├── prepare-test.sh
│ ├── res/
│ │ └── sample-evaluation-dataset.csv
│ ├── start-test.sh
│ ├── test-html/
│ │ ├── example-doc-1.html
│ │ ├── example-doc-2.html
│ │ ├── example-sitemap.xml
│ │ ├── widget-controlled.html
│ │ └── widget.html
│ ├── tests/
│ │ ├── api-keys.spec.ts
│ │ ├── api.spec.ts
│ │ ├── bootstrap.ts
│ │ ├── chat-engine.spec.ts
│ │ ├── chat.spec.ts
│ │ ├── datasource.spec.ts
│ │ ├── evaluation.spec.ts
│ │ ├── knowledge-base.spec.ts
│ │ ├── site-settings.spec.ts
│ │ └── widget.spec.ts
│ ├── utils/
│ │ ├── chat.ts
│ │ ├── forms.ts
│ │ └── login.ts
│ └── vercel.json
└── frontend/
├── .gitignore
├── .nvmrc
├── .prettierignore
├── Dockerfile
├── app/
│ ├── .eslintrc.json
│ ├── .gitignore
│ ├── .storybook/
│ │ ├── main.ts
│ │ └── preview.ts
│ ├── README.md
│ ├── components.json
│ ├── jest.config.ts
│ ├── jest.polyfills.js
│ ├── next-sitemap.config.js
│ ├── next.config.ts
│ ├── notice.md
│ ├── package.json
│ ├── postcss.config.mjs
│ ├── public/
│ │ └── chats.mock.txt
│ ├── src/
│ │ ├── api/
│ │ │ ├── .gitignore
│ │ │ ├── api-keys.ts
│ │ │ ├── auth.ts
│ │ │ ├── chat-engines.ts
│ │ │ ├── chats.ts
│ │ │ ├── commons.ts
│ │ │ ├── datasources.ts
│ │ │ ├── documents.ts
│ │ │ ├── embedding-models.ts
│ │ │ ├── evaluations.ts
│ │ │ ├── feedbacks.ts
│ │ │ ├── graph.ts
│ │ │ ├── knowledge-base.ts
│ │ │ ├── llms.ts
│ │ │ ├── providers.ts
│ │ │ ├── rag.ts
│ │ │ ├── rerankers.ts
│ │ │ ├── site-settings.ts
│ │ │ ├── stats.ts
│ │ │ ├── system.ts
│ │ │ └── users.ts
│ │ ├── app/
│ │ │ ├── (experimental)/
│ │ │ │ └── experimental-features/
│ │ │ │ └── route.ts
│ │ │ ├── (main)/
│ │ │ │ ├── (.)auth/
│ │ │ │ │ └── login/
│ │ │ │ │ ├── loading.tsx
│ │ │ │ │ ├── page.client.tsx
│ │ │ │ │ └── page.tsx
│ │ │ │ ├── (admin)/
│ │ │ │ │ ├── chat-engines/
│ │ │ │ │ │ ├── [id]/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ ├── new/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ └── page.tsx
│ │ │ │ │ ├── embedding-models/
│ │ │ │ │ │ ├── [id]/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ ├── create/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ └── page.tsx
│ │ │ │ │ ├── evaluation/
│ │ │ │ │ │ ├── datasets/
│ │ │ │ │ │ │ ├── [id]/
│ │ │ │ │ │ │ │ ├── items/
│ │ │ │ │ │ │ │ │ ├── [itemId]/
│ │ │ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ │ │ └── new/
│ │ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ │ ├── not-found.tsx
│ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ ├── create/
│ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ ├── page.tsx
│ │ │ │ │ │ └── tasks/
│ │ │ │ │ │ ├── [id]/
│ │ │ │ │ │ │ ├── not-found.tsx
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ ├── create/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ └── page.tsx
│ │ │ │ │ ├── feedbacks/
│ │ │ │ │ │ └── page.tsx
│ │ │ │ │ ├── knowledge-bases/
│ │ │ │ │ │ ├── [id]/
│ │ │ │ │ │ │ ├── (special)/
│ │ │ │ │ │ │ │ ├── data-sources/
│ │ │ │ │ │ │ │ │ └── new/
│ │ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ │ └── documents/
│ │ │ │ │ │ │ │ └── [documentId]/
│ │ │ │ │ │ │ │ └── chunks/
│ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ ├── (tabs)/
│ │ │ │ │ │ │ │ ├── data-sources/
│ │ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ │ ├── index-progress/
│ │ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ │ ├── knowledge-graph-explorer/
│ │ │ │ │ │ │ │ │ ├── create-synopsis-entity/
│ │ │ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ │ ├── layout.tsx
│ │ │ │ │ │ │ │ ├── page.tsx
│ │ │ │ │ │ │ │ ├── settings/
│ │ │ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ │ │ └── tabs.tsx
│ │ │ │ │ │ │ ├── api.ts
│ │ │ │ │ │ │ └── context.tsx
│ │ │ │ │ │ ├── new/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ └── page.tsx
│ │ │ │ │ ├── layout.tsx
│ │ │ │ │ ├── llms/
│ │ │ │ │ │ ├── [id]/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ ├── create/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ └── page.tsx
│ │ │ │ │ ├── reranker-models/
│ │ │ │ │ │ ├── [id]/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ ├── create/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ └── page.tsx
│ │ │ │ │ ├── site-settings/
│ │ │ │ │ │ ├── custom_js/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ ├── integrations/
│ │ │ │ │ │ │ └── page.tsx
│ │ │ │ │ │ ├── layout.tsx
│ │ │ │ │ │ └── page.tsx
│ │ │ │ │ └── stats/
│ │ │ │ │ └── trending/
│ │ │ │ │ └── page.tsx
│ │ │ │ ├── (user)/
│ │ │ │ │ ├── api-keys/
│ │ │ │ │ │ └── page.tsx
│ │ │ │ │ ├── c/
│ │ │ │ │ │ └── page.tsx
│ │ │ │ │ └── layout.tsx
│ │ │ │ ├── c/
│ │ │ │ │ └── [id]/
│ │ │ │ │ └── page.tsx
│ │ │ │ ├── layout.tsx
│ │ │ │ ├── nav.tsx
│ │ │ │ └── page.tsx
│ │ │ ├── RootProviders.tsx
│ │ │ ├── api/
│ │ │ │ └── [[...fallback_placeholder]]/
│ │ │ │ └── route.ts
│ │ │ ├── auth/
│ │ │ │ └── login/
│ │ │ │ └── page.tsx
│ │ │ ├── chart-theme.css
│ │ │ ├── globals.css
│ │ │ └── layout.tsx
│ │ ├── components/
│ │ │ ├── admin-page-heading.tsx
│ │ │ ├── admin-page-layout.tsx
│ │ │ ├── api-keys/
│ │ │ │ └── CreateApiKeyForm.tsx
│ │ │ ├── auth/
│ │ │ │ └── AuthProvider.tsx
│ │ │ ├── auto-scroll/
│ │ │ │ ├── auto-scroll.stories.tsx
│ │ │ │ ├── auto-scroll.tsx
│ │ │ │ ├── context.ts
│ │ │ │ ├── index.ts
│ │ │ │ ├── manual-scroll-voter.tsx
│ │ │ │ ├── use-auto-scroll-voter.ts
│ │ │ │ └── use-request-scroll.ts
│ │ │ ├── branding.tsx
│ │ │ ├── cells/
│ │ │ │ ├── actions.tsx
│ │ │ │ ├── boolean.tsx
│ │ │ │ ├── datetime.tsx
│ │ │ │ ├── error-message.tsx
│ │ │ │ ├── link.tsx
│ │ │ │ ├── metadata.tsx
│ │ │ │ ├── mono.tsx
│ │ │ │ ├── percent.tsx
│ │ │ │ └── reference.tsx
│ │ │ ├── charts/
│ │ │ │ ├── IndexProgressChart.stories.tsx
│ │ │ │ ├── IndexProgressChart.tsx
│ │ │ │ ├── TotalCard.stories.tsx
│ │ │ │ ├── TotalCard.tsx
│ │ │ │ └── TrendsChart.tsx
│ │ │ ├── chat/
│ │ │ │ ├── ask.tsx
│ │ │ │ ├── chat-controller.test.ts
│ │ │ │ ├── chat-controller.ts
│ │ │ │ ├── chat-hooks.tsx
│ │ │ │ ├── chat-message-controller.test.ts
│ │ │ │ ├── chat-message-controller.ts
│ │ │ │ ├── chat-new-dialog.tsx
│ │ │ │ ├── chat-stream-state.ts
│ │ │ │ ├── chat-stream.state.test.ts
│ │ │ │ ├── chats-history.tsx
│ │ │ │ ├── chats-table.tsx
│ │ │ │ ├── conversation-message-groups.scss
│ │ │ │ ├── conversation-message-groups.tsx
│ │ │ │ ├── conversation.test.tsx
│ │ │ │ ├── conversation.tsx
│ │ │ │ ├── debug-info.tsx
│ │ │ │ ├── knowledge-graph-debug-info.tsx
│ │ │ │ ├── message-annotation-history-stackvm.tsx
│ │ │ │ ├── message-annotation-history.tsx
│ │ │ │ ├── message-answer.tsx
│ │ │ │ ├── message-auto-scroll.tsx
│ │ │ │ ├── message-beta-alert.tsx
│ │ │ │ ├── message-content-sources.tsx
│ │ │ │ ├── message-content.test.tsx
│ │ │ │ ├── message-content.tsx
│ │ │ │ ├── message-error.tsx
│ │ │ │ ├── message-feedback.tsx
│ │ │ │ ├── message-input.tsx
│ │ │ │ ├── message-operations.tsx
│ │ │ │ ├── message-recommend-questions.tsx
│ │ │ │ ├── message-section.tsx
│ │ │ │ ├── testutils.ts
│ │ │ │ ├── use-ask.ts
│ │ │ │ ├── use-message-feedback.ts
│ │ │ │ └── utils.ts
│ │ │ ├── chat-engine/
│ │ │ │ ├── chat-engines-table.tsx
│ │ │ │ ├── create-chat-engine-form.tsx
│ │ │ │ ├── hooks.ts
│ │ │ │ ├── kb-list-select.tsx
│ │ │ │ └── update-chat-engine-form.tsx
│ │ │ ├── code-theme.scss
│ │ │ ├── config-viewer.tsx
│ │ │ ├── copy-button.tsx
│ │ │ ├── dangerous-action-button.tsx
│ │ │ ├── data-table-heading.tsx
│ │ │ ├── data-table-remote.tsx
│ │ │ ├── data-table.tsx
│ │ │ ├── datasource/
│ │ │ │ ├── create-datasource-form.tsx
│ │ │ │ ├── datasource-card.tsx
│ │ │ │ ├── datasource-create-option.tsx
│ │ │ │ ├── no-datasource-placeholder.tsx
│ │ │ │ └── update-datasource-form.tsx
│ │ │ ├── date-format.tsx
│ │ │ ├── date-range-picker.tsx
│ │ │ ├── diff-seconds.tsx
│ │ │ ├── document-viewer.tsx
│ │ │ ├── documents/
│ │ │ │ ├── documents-table-filters.tsx
│ │ │ │ └── documents-table.tsx
│ │ │ ├── embedding-models/
│ │ │ │ ├── CreateEmbeddingModelForm.tsx
│ │ │ │ ├── EmbeddingModelInfo.tsx
│ │ │ │ ├── EmbeddingModelsTable.tsx
│ │ │ │ ├── UpdateEmbeddingModelForm.tsx
│ │ │ │ └── hooks.tsx
│ │ │ ├── error-card.tsx
│ │ │ ├── evaluations/
│ │ │ │ ├── cells.tsx
│ │ │ │ ├── create-evaluation-dataset-form.stories.tsx
│ │ │ │ ├── create-evaluation-dataset-form.tsx
│ │ │ │ ├── create-evaluation-dataset-item-form.stories.tsx
│ │ │ │ ├── create-evaluation-dataset-item-form.tsx
│ │ │ │ ├── create-evaluation-task-form.stories.tsx
│ │ │ │ ├── create-evaluation-task-form.tsx
│ │ │ │ ├── evaluation-dataset-info.tsx
│ │ │ │ ├── evaluation-dataset-items-table.tsx
│ │ │ │ ├── evaluation-datasets-table.tsx
│ │ │ │ ├── evaluation-task-info.stories.tsx
│ │ │ │ ├── evaluation-task-info.tsx
│ │ │ │ ├── evaluation-task-items-table.tsx
│ │ │ │ ├── evaluation-tasks-table.tsx
│ │ │ │ ├── hooks.ts
│ │ │ │ ├── keyword-filter-toolbar.tsx
│ │ │ │ └── update-evaluation-dataset-item-form.tsx
│ │ │ ├── feedbacks/
│ │ │ │ └── feedbacks-table.tsx
│ │ │ ├── form/
│ │ │ │ ├── biz.tsx
│ │ │ │ ├── control-widget.tsx
│ │ │ │ ├── create-entity-form.tsx
│ │ │ │ ├── field-layout.tsx
│ │ │ │ ├── root-error.tsx
│ │ │ │ ├── utils.ts
│ │ │ │ └── widgets/
│ │ │ │ ├── CodeInput.tsx
│ │ │ │ ├── FileInput.tsx
│ │ │ │ ├── FilesInput.tsx
│ │ │ │ └── PromptInput.tsx
│ │ │ ├── form-sections.tsx
│ │ │ ├── graph/
│ │ │ │ ├── GraphCreateEntity.tsx
│ │ │ │ ├── GraphEditor.tsx
│ │ │ │ ├── action.ts
│ │ │ │ ├── components/
│ │ │ │ │ ├── EditingButton.tsx
│ │ │ │ │ ├── EntitiesTable.tsx
│ │ │ │ │ ├── InputField.tsx
│ │ │ │ │ ├── JsonEditor.tsx
│ │ │ │ │ ├── JsonField.tsx
│ │ │ │ │ ├── LinkDetails.tsx
│ │ │ │ │ ├── NetworkCanvas.tsx
│ │ │ │ │ ├── NetworkContext.ts
│ │ │ │ │ ├── NetworkViewer.tsx
│ │ │ │ │ ├── NodeDetails.tsx
│ │ │ │ │ ├── SearchEntity.tsx
│ │ │ │ │ ├── SearchEntityById.tsx
│ │ │ │ │ └── TextareaField.tsx
│ │ │ │ ├── index.ts
│ │ │ │ ├── network/
│ │ │ │ │ ├── CanvasNetworkRenderer.ts
│ │ │ │ │ ├── Network.ts
│ │ │ │ │ └── NetworkRendererOptions.ts
│ │ │ │ ├── remote.ts
│ │ │ │ ├── selectEntities.ts
│ │ │ │ ├── useDirtyEntity.ts
│ │ │ │ ├── useDirtyRelationship.ts
│ │ │ │ ├── useNetwork.ts
│ │ │ │ └── utils.ts
│ │ │ ├── gtag-provider.tsx
│ │ │ ├── html-viewer.tsx
│ │ │ ├── icons/
│ │ │ │ └── index.ts
│ │ │ ├── knowledge-base/
│ │ │ │ ├── create-knowledge-base-form.stories.tsx
│ │ │ │ ├── create-knowledge-base-form.tsx
│ │ │ │ ├── document-chunks-table.tsx
│ │ │ │ ├── empty-state.tsx
│ │ │ │ ├── form-index-methods.tsx
│ │ │ │ ├── hooks.ts
│ │ │ │ ├── knowledge-base-card.stories.tsx
│ │ │ │ ├── knowledge-base-card.tsx
│ │ │ │ ├── knowledge-base-chunking-config-fields.tsx
│ │ │ │ ├── knowledge-base-index.tsx
│ │ │ │ └── knowledge-base-settings-form.tsx
│ │ │ ├── llm/
│ │ │ │ ├── CreateLLMForm.tsx
│ │ │ │ ├── LLMsTable.tsx
│ │ │ │ ├── LlmInfo.tsx
│ │ │ │ ├── UpdateLLMForm.tsx
│ │ │ │ └── hooks.ts
│ │ │ ├── loader.tsx
│ │ │ ├── managed-dialog-close.tsx
│ │ │ ├── managed-dialog.tsx
│ │ │ ├── managed-panel.tsx
│ │ │ ├── model-component-info.tsx
│ │ │ ├── nextjs/
│ │ │ │ └── NextLink.tsx
│ │ │ ├── option-detail.tsx
│ │ │ ├── portal-provider.tsx
│ │ │ ├── provider-description.tsx
│ │ │ ├── py-viewer.tsx
│ │ │ ├── remark-content/
│ │ │ │ ├── components.tsx
│ │ │ │ ├── context.tsx
│ │ │ │ ├── highlight.ts
│ │ │ │ ├── index.ts
│ │ │ │ ├── remark-content.stories.tsx
│ │ │ │ ├── remark-content.tsx
│ │ │ │ └── style.scss
│ │ │ ├── reranker/
│ │ │ │ ├── CreateRerankerForm.tsx
│ │ │ │ ├── RerankerInfo.tsx
│ │ │ │ ├── RerankerModelsTable.tsx
│ │ │ │ ├── UpdateRerankerForm.tsx
│ │ │ │ └── hooks.ts
│ │ │ ├── resource-not-found.tsx
│ │ │ ├── row-checkbox.tsx
│ │ │ ├── secondary-navigator-list.tsx
│ │ │ ├── security-setting-provider.tsx
│ │ │ ├── settings/
│ │ │ │ ├── CustomJsSettings.tsx
│ │ │ │ ├── IntegrationsSettings.tsx
│ │ │ │ ├── LinkArrayField.tsx
│ │ │ │ ├── SettingsField.tsx
│ │ │ │ ├── StringArrayField.tsx
│ │ │ │ ├── WebsiteSettings.tsx
│ │ │ │ └── WidgetSnippet.tsx
│ │ │ ├── settings-form/
│ │ │ │ ├── GeneralSettingsField.tsx
│ │ │ │ ├── GeneralSettingsForm.tsx
│ │ │ │ ├── accessor-helper.ts
│ │ │ │ ├── context.tsx
│ │ │ │ ├── index.ts
│ │ │ │ └── utils.ts
│ │ │ ├── signin.tsx
│ │ │ ├── site-header-actions.tsx
│ │ │ ├── site-header.tsx
│ │ │ ├── site-nav.tsx
│ │ │ ├── system/
│ │ │ │ ├── BootstrapStatusProvider.tsx
│ │ │ │ └── SystemWizardBanner.tsx
│ │ │ ├── theme-toggle.tsx
│ │ │ ├── theme.stories.tsx
│ │ │ ├── themed-style.ts
│ │ │ ├── ui/
│ │ │ │ ├── accordion.tsx
│ │ │ │ ├── alert-dialog.tsx
│ │ │ │ ├── alert.tsx
│ │ │ │ ├── aspect-ratio.tsx
│ │ │ │ ├── avatar.tsx
│ │ │ │ ├── badge.tsx
│ │ │ │ ├── breadcrumb.tsx
│ │ │ │ ├── button.tsx
│ │ │ │ ├── calendar.tsx
│ │ │ │ ├── card.tsx
│ │ │ │ ├── carousel.tsx
│ │ │ │ ├── chart.tsx
│ │ │ │ ├── checkbox.tsx
│ │ │ │ ├── collapsible.tsx
│ │ │ │ ├── command.tsx
│ │ │ │ ├── context-menu.tsx
│ │ │ │ ├── dialog.tsx
│ │ │ │ ├── dot-pattern.tsx
│ │ │ │ ├── drawer.tsx
│ │ │ │ ├── dropdown-menu.tsx
│ │ │ │ ├── form.beta.tsx
│ │ │ │ ├── form.tsx
│ │ │ │ ├── hover-card.tsx
│ │ │ │ ├── input-otp.tsx
│ │ │ │ ├── input.tsx
│ │ │ │ ├── label.tsx
│ │ │ │ ├── menubar.tsx
│ │ │ │ ├── navigation-menu.tsx
│ │ │ │ ├── pagination.tsx
│ │ │ │ ├── popover.tsx
│ │ │ │ ├── progress.tsx
│ │ │ │ ├── radio-group.tsx
│ │ │ │ ├── resizable.tsx
│ │ │ │ ├── scroll-area.tsx
│ │ │ │ ├── select.tsx
│ │ │ │ ├── separator.tsx
│ │ │ │ ├── sheet.tsx
│ │ │ │ ├── sidebar.tsx
│ │ │ │ ├── skeleton.tsx
│ │ │ │ ├── slider.tsx
│ │ │ │ ├── sonner.tsx
│ │ │ │ ├── switch.tsx
│ │ │ │ ├── table.tsx
│ │ │ │ ├── tabs.tsx
│ │ │ │ ├── textarea.tsx
│ │ │ │ ├── toast.tsx
│ │ │ │ ├── toaster.tsx
│ │ │ │ ├── toggle-group.tsx
│ │ │ │ ├── toggle.tsx
│ │ │ │ ├── tooltip.tsx
│ │ │ │ └── use-toast.ts
│ │ │ ├── use-active-theme.ts
│ │ │ ├── use-data-table.ts
│ │ │ ├── use-href.ts
│ │ │ ├── use-latest-ref.tsx
│ │ │ ├── use-search-param.ts
│ │ │ ├── use-size.ts
│ │ │ └── website-setting-provider.tsx
│ │ ├── core/
│ │ │ └── schema/
│ │ │ ├── NOTICE.md
│ │ │ └── settings/
│ │ │ └── security.ts
│ │ ├── experimental/
│ │ │ ├── chat-verify-service/
│ │ │ │ ├── api.mock.ts
│ │ │ │ ├── api.react-server.ts
│ │ │ │ ├── api.tidbai-widget.ts
│ │ │ │ ├── api.ts
│ │ │ │ ├── message-verify-result-markdown.tsx
│ │ │ │ ├── message-verify.stories.tsx
│ │ │ │ └── message-verify.tsx
│ │ │ ├── experimental-features-provider.tsx
│ │ │ └── experimental-features.ts
│ │ ├── hooks/
│ │ │ ├── use-mobile.tsx
│ │ │ └── use-model-provider.ts
│ │ └── lib/
│ │ ├── auth.ts
│ │ ├── buffered-readable-stream.test.ts
│ │ ├── buffered-readable-stream.ts
│ │ ├── errors.ts
│ │ ├── react.ts
│ │ ├── request/
│ │ │ ├── authenticationHeaders.mock.ts
│ │ │ ├── authenticationHeaders.react-server.ts
│ │ │ ├── authenticationHeaders.tidbai-widget.ts
│ │ │ ├── authenticationHeaders.ts
│ │ │ ├── base-url.mock.ts
│ │ │ ├── base-url.react-server.ts
│ │ │ ├── base-url.tidbai-widget.ts
│ │ │ ├── base-url.ts
│ │ │ ├── errors.ts
│ │ │ ├── index.ts
│ │ │ ├── list-all-helper.ts
│ │ │ ├── params.ts
│ │ │ ├── response-handlers.ts
│ │ │ └── url.ts
│ │ ├── stackvm/
│ │ │ ├── core/
│ │ │ │ ├── index.ts
│ │ │ │ ├── instructions/
│ │ │ │ │ ├── index.ts
│ │ │ │ │ ├── instructions.ts
│ │ │ │ │ └── registry.ts
│ │ │ │ ├── model.ts
│ │ │ │ ├── types.ts
│ │ │ │ └── visit.ts
│ │ │ └── index.ts
│ │ ├── strings.ts
│ │ ├── tanstack-form.ts
│ │ ├── typing-utils.ts
│ │ ├── ui-error.tsx
│ │ ├── utils.ts
│ │ ├── zod.test.ts
│ │ └── zod.ts
│ ├── tailwind.config.ts
│ └── tsconfig.json
├── package.json
├── packages/
│ └── widget-react/
│ ├── .eslintrc.cjs
│ ├── .gitignore
│ ├── README.md
│ ├── USAGE.md
│ ├── index.html
│ ├── package.json
│ ├── postcss.config.js
│ ├── src/
│ │ ├── Widget.css
│ │ ├── Widget.tsx
│ │ ├── index.css
│ │ ├── library.tsx
│ │ ├── load-config.ts
│ │ ├── overrides/
│ │ │ ├── README.md
│ │ │ └── components/
│ │ │ ├── code-theme.scss
│ │ │ └── remark-content/
│ │ │ └── style.scss
│ │ ├── prepare-gtag.ts
│ │ └── vite-env.d.ts
│ ├── tailwind.config.ts
│ ├── tsconfig.app.json
│ ├── tsconfig.json
│ ├── tsconfig.node.json
│ └── vite.config.ts
├── patches/
│ ├── @jest__environment@29.7.0.patch
│ └── jest-runtime@29.7.0.patch
└── pnpm-workspace.yaml
================================================
FILE CONTENTS
================================================
================================================
FILE: .dockerignore
================================================
# Frontend
.github
e2e
frontend/Dockerfile
frontend/**/node_modules
frontend/app/.next
frontend/app/.swc
frontend/packages/*/dist
================================================
FILE: .github/actions/decide/.gitignore
================================================
node_modules
================================================
FILE: .github/actions/decide/action.yml
================================================
name: 'Decide action'
description: "See issue https://github.com/pingcap/tidb.ai/issues/314"
inputs:
pr-e2e-frontend-label-prefix:
description: "Defaults to 'e2e-frontend:'."
required: true
default: "e2e-frontend:"
pr-e2e-backend-label-prefix:
description: "Defaults to 'e2e-frontend:'."
required: true
default: "e2e-backend:"
outputs:
should-build-frontend:
description: Should build frontend
should-build-backend:
description: Should build backend
e2e-frontend:
description: Frontend docker version
e2e-backend:
description: Frontend docker version
runs:
using: 'node20'
main: 'index.js'
================================================
FILE: .github/actions/decide/index.js
================================================
import * as core from '@actions/core';
import { context, } from '@actions/github';
const FRONTEND_PREFIX = core.getInput('pr-e2e-frontend-label-prefix', { required: true });
const BACKEND_PREFIX = core.getInput('pr-e2e-backend-label-prefix', { required: true });
function run() {
const defaultImageVersion = `sha-${context.sha}-dev`
if (context.eventName !== 'pull_request') {
// Build and run E2E for all other events.
// TODO: Maybe handle commit message like "feat(frontend-only): Some message"
core.setOutput('should-build-frontend', true);
core.setOutput('should-build-backend', true);
core.setOutput('e2e-frontend', defaultImageVersion);
core.setOutput('e2e-backend', defaultImageVersion);
return;
}
let label
if (!!(label = findPRLabel(label => label.name.startsWith(FRONTEND_PREFIX)))) {
core.setOutput('should-build-frontend', false);
core.setOutput('e2e-frontend', label.name.slice(FRONTEND_PREFIX.length));
core.info(`E2E Frontend: ${label.name.slice(FRONTEND_PREFIX.length)}`);
} else {
core.setOutput('should-build-frontend', true);
core.setOutput('e2e-frontend', defaultImageVersion);
}
if (!!(label = findPRLabel(label => label.name.startsWith(BACKEND_PREFIX)))) {
core.setOutput('should-build-backend', false);
core.setOutput('e2e-backend', label.name.slice(BACKEND_PREFIX.length));
core.info(`E2E Backend: ${label.name.slice(BACKEND_PREFIX.length)}`);
} else {
core.setOutput('should-build-backend', true);
core.setOutput('e2e-backend', defaultImageVersion);
}
}
function findPRLabel(test) {
return context.payload.pull_request.labels.find(test);
}
run();
================================================
FILE: .github/actions/decide/package.json
================================================
{
"name": "@tidbai/ci-decide-action",
"version": "1.0.0",
"type": "module",
"module": "index.js",
"devDependencies": {
"@actions/core": "^1.11.1",
"@actions/github": "^6.0.0"
}
}
================================================
FILE: .github/workflows/backend-test.yml
================================================
name: Backend Test
on:
push:
branches:
- main
paths:
- backend/**
pull_request:
branches:
- main
paths:
- backend/**
jobs:
backend-test:
name: Backend Test
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./backend
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Ruff Format
uses: astral-sh/ruff-action@v3
- name: Run Ruff Check
run: ruff check
- name: Run Ruff Format
run: ruff format
================================================
FILE: .github/workflows/deploy.yml
================================================
name: Deploy to Production
on:
workflow_dispatch:
jobs:
build:
name: Build
runs-on: ubuntu-latest
steps:
- name: ssh and deploy
uses: appleboy/ssh-action@v1.0.3
with:
host: ${{ secrets.DEPLOY_HOST }}
username: ${{ secrets.DEPLOY_USERNAME }}
key: ${{ secrets.DEPLOY_SSH_KEY }}
port: ${{ secrets.DEPLOY_PORT }}
script: cd /home/ubuntu/py.tidb.ai-docker && docker compose up -d --force-recreate --pull always
================================================
FILE: .github/workflows/regression.yml
================================================
name: Regression Test
on:
workflow_dispatch:
inputs:
dataset:
description: 'Langfuse dataset to test'
required: true
default: 'regression'
tidbAIChatEngine:
description: 'TiDB AI Chat Engine to test'
required: true
default: 'default'
llmProvider:
description: 'llm provider'
required: true
default: 'openai'
type: choice
options:
- openai
- gemini
jobs:
build:
name: Build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r backend/requirements.lock
- name: Run tests
env:
ENVIRONMENT: "production"
LANGFUSE_HOST: ${{ secrets.LANGFUSE_HOST }}
LANGFUSE_PUBLIC_KEY: ${{ secrets.LANGFUSE_PUBLIC_KEY }}
LANGFUSE_SECRET_KEY: ${{ secrets.LANGFUSE_SECRET_KEY }}
TIDB_AI_CHAT_ENDPOINT: ${{ secrets.TIDB_AI_CHAT_ENDPOINT }}
TIDB_AI_API_KEY: ${{ secrets.TIDB_AI_API_KEY }}
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
TIDB_HOST: "127.0.0.1"
TIDB_USER: "root"
TIDB_PASSWORD: "fake"
TIDB_DATABASE: "test"
run: |
cd backend && python main.py runeval --llm-provider ${{ inputs.llmProvider }} --dataset ${{ inputs.dataset }} --tidb-ai-chat-engine ${{ inputs.tidbAIChatEngine }}
================================================
FILE: .github/workflows/release.yml
================================================
name: Build and Publish
on:
pull_request:
branches:
- main
types:
- opened
- synchronize
- reopened
- labeled
- unlabeled
push:
branches:
- main
tags:
- '*'
workflow_dispatch:
inputs:
tag:
description: "The tag to build the image with. Format *.*.*((beta|rc)*)?"
required: true
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
docker-metadata:
name: Generate docker metadata
runs-on: ubuntu-latest
if: ${{ (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'require-build')) || (github.event_name != 'pull_request' && !cancelled()) }}
outputs:
tags: ${{steps.meta.outputs.tags}}
labels: ${{steps.meta.outputs.labels}}
annotations: ${{steps.meta.outputs.annotations}}
platforms: |
linux/amd64
${{(startsWith(github.ref, 'refs/tags/') && 'linux/arm64') || ''}}
should-build-frontend: ${{steps.decide.outputs.should-build-frontend}}
should-build-backend: ${{steps.decide.outputs.should-build-backend}}
e2e-frontend: ${{steps.decide.outputs.e2e-frontend}}
e2e-backend: ${{steps.decide.outputs.e2e-backend}}
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: lts/*
cache: npm
cache-dependency-path: .github/actions/*/package-lock.json
- run: npm ci
working-directory: .github/actions/decide
- name: Decide
id: decide
uses: ./.github/actions/decide/
with:
pr-e2e-frontend-label-prefix: 'e2e-frontend:'
pr-e2e-backend-label-prefix: 'e2e-backend:'
- name: Docker metadata
id: meta
uses: docker/metadata-action@v5
with:
github-token: ${{ github.token }}
images: "{0}/{1}"
tags: |
type=raw,value=branch-{{branch}},enable=${{ !startsWith(github.ref, 'refs/tags/') && !startsWith(github.ref, 'refs/pull/') }}
type=sha,prefix=sha-,format=long,suffix=-dev,enable=${{ !startsWith(github.ref, 'refs/tags/') }}
type=sha,prefix=sha-,format=long,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=pr
type=pep440,value=${{inputs.tag || github.ref_name}},pattern={{version}}
type=pep440,value=${{inputs.tag || github.ref_name}},pattern={{major}}.{{minor}}
type=pep440,value=${{inputs.tag || github.ref_name}},pattern={{major}}
build-backend:
name: Build and Push Backend Image
runs-on: ubuntu-latest
needs: docker-metadata
if: ${{needs.docker-metadata.outputs.should-build-backend == 'true'}}
permissions:
contents: read
packages: write
attestations: write
id-token: write
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.RELEASE_DOCKERHUB_USERNAME }}
password: ${{ secrets.RELEASE_DOCKERHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build and push backend
uses: docker/build-push-action@v6
with:
context: backend
platforms: ${{needs.docker-metadata.outputs.platforms}}
push: true
tags: ${{format(needs.docker-metadata.outputs.tags, secrets.RELEASE_DOCKERHUB_USERNAME, 'backend')}}
labels: ${{needs.docker-metadata.outputs.labels}}
annotations: ${{needs.docker-metadata.outputs.annotations}}
cache-from: type=gha
cache-to: type=gha,mode=max
build-frontend:
name: Build and Push Frontend Image
runs-on: ubuntu-latest
needs: docker-metadata
if: ${{needs.docker-metadata.outputs.should-build-frontend == 'true'}}
permissions:
contents: read
packages: write
attestations: write
id-token: write
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.RELEASE_DOCKERHUB_USERNAME }}
password: ${{ secrets.RELEASE_DOCKERHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build and push frontend
uses: docker/build-push-action@v6
with:
context: .
file: ./frontend/Dockerfile
platforms: ${{needs.docker-metadata.outputs.platforms}}
push: true
tags: ${{format(needs.docker-metadata.outputs.tags, secrets.RELEASE_DOCKERHUB_USERNAME, 'frontend')}}
labels: ${{needs.docker-metadata.outputs.labels}}
annotations: ${{needs.docker-metadata.outputs.annotations}}
cache-from: type=gha
cache-to: type=gha,mode=max
e2e-test:
name: E2E Test
runs-on: ubuntu-latest
timeout-minutes: 10
defaults:
run:
working-directory: e2e
needs:
- docker-metadata
- build-backend
- build-frontend
if: |
!cancelled() &&
needs.docker-metadata.outputs.e2e-backend &&
needs.docker-metadata.outputs.e2e-frontend &&
!contains(needs.*.result, 'failure') &&
!contains(needs.*.result, 'cancelled')
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: lts/*
cache: npm
cache-dependency-path: e2e/package-lock.json
- name: Install dependencies
run: npm ci
- name: Install Playwright Browsers
run: npx playwright install --with-deps chromium
- name: Prepare Tests
run: ./prepare-test.sh
env:
E2E_DOCKER_TAG_FRONTEND: ${{needs.docker-metadata.outputs.e2e-frontend}}
E2E_DOCKER_TAG_BACKEND: ${{needs.docker-metadata.outputs.e2e-backend}}
- name: Run tests
run: ./start-test.sh
env:
E2E_DOCKER_TAG_FRONTEND: ${{needs.docker-metadata.outputs.e2e-frontend}}
E2E_DOCKER_TAG_BACKEND: ${{needs.docker-metadata.outputs.e2e-backend}}
E2E_LLM_PROVIDER: ${{vars.E2E_LLM_PROVIDER}}
E2E_LLM_MODEL: ${{vars.E2E_LLM_MODEL || ''}}
E2E_LLM_CREDENTIALS: ${{secrets.E2E_LLM_CREDENTIALS}}
E2E_EMBEDDING_PROVIDER: ${{vars.E2E_EMBEDDING_PROVIDER || 'openai'}}
E2E_EMBEDDING_MODEL: ${{vars.E2E_EMBEDDING_MODEL || ''}}
E2E_EMBEDDING_CREDENTIALS: ${{secrets.E2E_EMBEDDING_CREDENTIALS}}
E2E_RERANKER_PROVIDER: ${{vars.E2E_RERANKER_PROVIDER || 'jinaai'}}
E2E_RERANKER_MODEL: ${{vars.E2E_RERANKER_MODEL || ''}}
E2E_RERANKER_CREDENTIALS: ${{secrets.E2E_RERANKER_CREDENTIALS}}
- uses: actions/upload-artifact@v4
if: "!cancelled()"
with:
name: e2e-report
path: e2e/playwright-report/
retention-days: 30
e2e-test-deploy:
name: "Deploy E2E Test Results"
runs-on: ubuntu-latest
needs: e2e-test
if: |
always()
&& !contains(needs.e2e-test.result, 'skipped')
&& !contains(needs.e2e-test.result, 'cancelled')
defaults:
run:
working-directory: e2e
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
- uses: actions/download-artifact@v4
with:
name: e2e-report
path: e2e/playwright-report/
- name: Deploy Test Results to Vercel
id: deploy-test-result
run: |
deploy_message=$(./deploy-test-result.sh)
echo "deploy-message=${deploy_message}" >> $GITHUB_OUTPUT
if: "!cancelled()"
env:
VERCEL_TOKEN: ${{secrets.E2E_VERCEL_TOKEN}}
VERCEL_ORG_ID: ${{secrets.E2E_VERCEL_ORG_ID}}
VERCEL_PROJECT_ID: ${{secrets.E2E_VERCEL_PROJECT_ID}}
VERCEL_CLI_ARGS: ${{(github.ref == 'refs/heads/main' && '--prod') || ''}}
- name: Find PR Comment
uses: peter-evans/find-comment@v3
if: github.event.pull_request.number
id: find-comment
with:
issue-number: ${{ github.event.pull_request.number }}
comment-author: 'github-actions[bot]'
body-includes: E2E Result Deployment
- name: Upsert Comment on PR
uses: peter-evans/create-or-update-comment@v4
if: github.event.pull_request.number
with:
comment-id: ${{ steps.find-comment.outputs.comment-id }}
issue-number: ${{ github.event.pull_request.number }}
body: |
### E2E Result Deployment
${{steps.deploy-test-result.outputs.deploy-message}}
edit-mode: replace
================================================
FILE: .github/workflows/verify.yml
================================================
name: Verify
on:
push:
branches:
- main
paths:
- frontend/**
- README.md
pull_request:
branches:
- main
paths:
- frontend/**
- README.md
jobs:
verify-frontend:
name: Verify frontend
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./frontend
steps:
- name: Checkout
uses: actions/checkout@v4
- uses: pnpm/action-setup@v4
with:
run_install: false
package_json_file: ./frontend/package.json
- name: Install Node.js
uses: actions/setup-node@v4
with:
node-version-file: ./frontend/.nvmrc
cache-dependency-path: ./frontend
cache: 'pnpm'
- run: pnpm i --frozen-lockfile
- uses: actions/cache@v4
with:
path: |
./frontend/.next/cache
key: ${{ runner.os }}-nextjs-${{ hashFiles('frontend/**/pnpm-lock.yaml') }}-${{ hashFiles('frontend/**/*.js', 'frontend/**/*.jsx', 'frontend/**/*.ts', 'frontend/**/*.tsx') }}
- run: pnpm run verify
================================================
FILE: .gitignore
================================================
*_dev.ipynb
.idea
.vscode
.env
.ruff_cache
redis-data
data
venv
.venv
local-embedding-reranker
*.swp
*.swo
.next
node_modules/
_pagefind/
================================================
FILE: CONTRIBUTING.md
================================================
# How to contribute
## Contributing Guidelines
[pingcap/autoflow](https://github.com/pingcap/autoflow) is an open-source project and we welcome contributions from the community. If you are interested in contributing to the project, please read the following guidelines.
### Before You Get Started
#### Software Prerequisites for Development
In this section, you should have some prerequisites software installed on your local machine:
* [Docker](https://docs.docker.com/get-docker/)
* [Docker Compose](https://docs.docker.com/compose/install/)
* [Python](https://www.python.org/downloads/)
* [Node.js](https://nodejs.org/en/download/)
* [TiDB Cloud Serverless](https://pingcap.com/ai/?utm_source=tidb.ai&utm_medium=community) or [TiDB Self-Managed](https://www.pingcap.com/tidb-self-managed/?utm_source=tidb.ai&utm_medium=community)
#### Setting up your development environment
Setting up the project on your local machine is the first step to contributing to the project. You can clone the project from the GitHub repository and then start the project on your local machine. You can follow the instructions in the [Deployment Guide](https://autoflow.tidb.ai/deploy-with-docker) file to set up the project on your local machine.
To test your local changes, you can build and run the project using:
```bash
docker compose -f docker-compose.dev.yml up
```
### Your First Contribution
All set to participate in the project? You can start by looking at the [open issues](https://github.com/pingcap/autoflow/issues) in this repo.
### Components of the Project
The project is divided into several components, and you can contribute to any of the following components:
* [Frontend](https://github.com/pingcap/autoflow/tree/main/frontend): The frontend of the project is built using Next.js.
* [Backend](https://github.com/pingcap/tidb.ai/tree/main/backend): The backend of the project is built using FastAPI.
* [Data Source](https://github.com/pingcap/autoflow/tree/main/backend/app/rag/datasource): The Data Source component is responsible for indexing the data from different type of sources. You can add more data source types to the project.
* [LLM](https://github.com/pingcap/tidb.ai/tree/main/backend/app/rag/llms): The LLM Engine component is responsible for extracting knowledge from docs and generating responses. You can add more LLM models support to the project.
* [Reranker](https://github.com/pingcap/tidb.ai/blob/main/backend/app/rag/reranker_model_option.py): The Reranker Engine component is responsible for reranking the results retrieved from the database. You can add more Reranker models support to the project.
* [Embedding](https://github.com/pingcap/tidb.ai/blob/main/backend/app/rag/embed_model_option.py): The Embedding Engine component is responsible for converting text into vectors. You can add more Embedding models support to the project.
* [RAG & GraphRAG Engine](https://github.com/pingcap/tidb.ai/tree/main/backend/app/rag): The component is responsible for extracting knowldge from docs and then chunking, indexing and storing the data in the database, also includes retrieving the data from the database and generating the answer for the user.
* [Documentations](https://github.com/pingcap/tidb.ai/tree/main/frontend/app/src/pages): The documentation of the project is written in Markdown files. You can contribute to the documentation by adding more content to the documentation.
### How to add an API?
Using the FastAPI framework.
* Create a FastAPI Instance.
```python
router = FastAPI()
```
* Use Decorators to Define API Endpoints. For example:
```python
@router.get("xxx")
```
* Implement the Route Handler Function.
* Add sub-routes to the main route and tag the sub-routes in `backend/app/api/main.py` .
```python
api_router.include_router(sub_router, tags=["xxxx"])
```
## Maintainers
Please feel free to reach out to the maintainers if you have any questions or need help with the project.
* [wd0517](https://github.com/wd0517)
* [634750802](https://github.com/634750802)
* [Mini256](https://github.com/Mini256)
* [IANTHEREAL](https://github.com/IANTHEREAL)
* [Cheese](https://github.com/Icemap)
## Discussion
If you have any questions or suggestions, please feel free to open a discussion in the [Discussions](https://github.com/pingcap/tidb.ai/discussions)
================================================
FILE: LICENSE.txt
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 PingCAP
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
AutoFlow
[](https://hub.docker.com/r/tidbai/backend)
[](https://hub.docker.com/r/tidbai/frontend)
[](https://tidb-ai-playwright.vercel.app/)
> [!WARNING]
> Autoflow is still in the early stages of development. And we are actively working on it, the next move is to make it to a python package and make it a RAG solution e.g. `pip install autoflow-ai`. If you have any questions or suggestions, please feel free to contact us on [Discussion](https://github.com/pingcap/autoflow/discussions).
## Introduction
AutoFlow is an open source graph rag (graphrag: knowledge graph rag) based knowledge base tool built on top of [TiDB Vector](https://www.pingcap.com/ai?utm_source=tidb.ai&utm_medium=community) and [LlamaIndex](https://github.com/run-llama/llama_index) and [DSPy](https://github.com/stanfordnlp/dspy).
- **Live Demo**: [https://tidb.ai](https://tidb.ai?utm_source=tidb.ai&utm_medium=community)
- **Deployment Docs**: [Deployment Docs](https://autoflow.tidb.ai/?utm_source=github&utm_medium=tidb.ai)
## Features
1. **Perplexity-style Conversational Search page**: Our platform features an advanced built-in website crawler, designed to elevate your browsing experience. This crawler effortlessly navigates official and documentation sites, ensuring comprehensive coverage and streamlined search processes through sitemap URL scraping.

2. **Embeddable JavaScript Snippet**: Integrate our conversational search window effortlessly into your website by copying and embedding a simple JavaScript code snippet. This widget, typically placed at the bottom right corner of your site, facilitates instant responses to product-related queries.

## Deploy
- [Deploy with Docker Compose](https://autoflow.tidb.ai/deploy-with-docker) (with: 4 CPU cores and 8GB RAM)
## Tech Stack
- [TiDB](https://www.pingcap.com/ai?utm_source=tidb.ai&utm_medium=community) – Database to store chat history, vector, json, and analytic
- [LlamaIndex](https://www.llamaindex.ai/) - RAG framework
- [DSPy](https://github.com/stanfordnlp/dspy) - The framework for programming—not prompting—foundation models
- [Next.js](https://nextjs.org/) – Framework
- [Tailwind CSS](https://tailwindcss.com/) – CSS framework
- [shadcn/ui](https://ui.shadcn.com/) - Design
## Contributing
We welcome contributions from the community. If you are interested in contributing to the project, please read the [Contributing Guidelines](/CONTRIBUTING.md).
## License
AutoFlow is open-source under the Apache License, Version 2.0. You can [find it here](https://github.com/pingcap/autoflow/blob/main/LICENSE.txt).
## Contact
You can reach out to us on [Discord](https://discord.gg/XzSW23Jg9p).
================================================
FILE: backend/.dockerignore
================================================
# MacOS
.DS_Store
# Environment
.env
.venv
# Git
.git
# IDE
.idea
.vscode
# Cache
.pytest_cache
.ruff_cache
**/__pycache__
# Test
.tox
tests
pytest.ini
# Extra
local_embedding_reranker/
.pre-commit-config.yaml
.gitignore
Dockerfile
.dockerignore
================================================
FILE: backend/.gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.ruff_cache
*.csv
.DS_Store
# VSCode
.vscode/
checkpoint.json
================================================
FILE: backend/.pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.8.6
hooks:
# Run the linter.
- id: ruff
# Run the formatter.
- id: ruff-format
================================================
FILE: backend/.python-version
================================================
3.12.3
================================================
FILE: backend/Dockerfile
================================================
FROM ghcr.io/astral-sh/uv:python3.11-bookworm-slim
WORKDIR /app/
# Setup supervisord.
RUN apt-get update && apt-get install -y supervisor gcc
COPY supervisord.conf /usr/etc/supervisord.conf
# Install dependencies.
COPY uv.lock /app/uv.lock
COPY pyproject.toml /app/pyproject.toml
RUN PYTHONDONTWRITEBYTECODE=1 uv sync --frozen
ENV PATH="/app/.venv/bin:$PATH"
# Pre-download playwright dependencies.
RUN playwright install --with-deps chromium
# Pre-download nltk data.
RUN python -c 'import nltk; \
download_dir = "/usr/local/lib/python3.11/site-packages/llama_index/core/_static/nltk_cache";\
nltk.download("stopwords", download_dir=download_dir);\
nltk.download("punkt", download_dir=download_dir);'
ENV PYTHONPATH=/app
COPY . /app/
# Default number of workers
ENV WEB_CONCURRENCY=4
CMD ["sh", "-c", "fastapi run app/api_server.py --host 0.0.0.0 --port 80 --workers ${WEB_CONCURRENCY}"]
================================================
FILE: backend/Makefile
================================================
.PHONY: test
makemigrations:
@echo "Creating migrations..."
@if [ -z "$(NAME)" ]; then \
uv run alembic revision --autogenerate; \
else \
uv run alembic revision --autogenerate -m "$(NAME)"; \
fi
migrate:
@echo "Migrating database..."
@uv run alembic upgrade head
lint:
@echo "Linting code..."
@uv run ruff check .
format:
@echo "Formatting code..."
@uv run ruff format .
test:
@echo "Running tests..."
@uv run pytest -v tests/
dev_backend:
@echo "Running backend server in development mode..."
@uv run fastapi dev app/api_server.py --host 127.0.0.1 --port 5001
run_backend:
@echo "Running backend server..."
@uv run fastapi run app/api_server.py --host 0.0.0.0 --port 5001 --workers 4
dev_celery_flower:
@echo "Running Celery Flower..."
@uv run celery -A app.celery flower --address=0.0.0.0 --port=5555
dev_background_worker:
@echo "Running background worker..."
@uv run celery -A app.celery worker -Q default -l INFO -E
dev_eval_worker:
@echo "Running evaluation worker..."
@uv run celery -A app.celery worker -Q evaluation --loglevel=debug --pool=solo
================================================
FILE: backend/README.md
================================================
# Backend of tidb.ai
## Development
### Install dependencies
1. Install [uv](https://docs.astral.sh/uv/getting-started/installation/)
2. Use `uv` to install dependencies
```bash
uv sync
```
### Prepare environment
```
cp .env.example .env
```
Edit `.env` to set environment variables.
### Run migrations
```bash
make migrate
```
### Run development server
```bash
uv run python main.py runserver
```
================================================
FILE: backend/alembic.ini
================================================
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = app/alembic
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# max length of characters to apply to the
# "slug" field
#truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; this defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path
# version_locations = %(here)s/bar %(here)s/bat alembic/versions
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
================================================
FILE: backend/app/__init__.py
================================================
import os
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
================================================
FILE: backend/app/alembic/env.py
================================================
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlmodel import SQLModel
from tidb_vector.sqlalchemy import VectorType
from app.core.config import settings
from app.models import * # noqa
from app.models.knowledge_base_scoped.table_naming import (
KB_CHUNKS_TABLE_PATTERN,
KB_ENTITIES_TABLE_PATTERN,
KB_RELATIONSHIPS_TABLE_PATTERN,
)
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
# target_metadata = None
target_metadata = SQLModel.metadata
def get_url():
return str(settings.SQLALCHEMY_DATABASE_URI)
def include_name(name, type_, parent_names):
if type_ == "table":
return (
not bool(KB_CHUNKS_TABLE_PATTERN.match(name))
and not bool(KB_ENTITIES_TABLE_PATTERN.match(name))
and not bool(KB_RELATIONSHIPS_TABLE_PATTERN.match(name))
)
else:
return True
def run_migrations_offline():
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = get_url()
context.configure(
url=url,
target_metadata=target_metadata,
include_name=include_name,
literal_binds=True,
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online():
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
configuration = config.get_section(config.config_ini_section)
configuration["sqlalchemy.url"] = get_url()
connectable = engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
connection.dialect.ischema_names["vector"] = VectorType
context.configure(
connection=connection,
target_metadata=target_metadata,
include_name=include_name,
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
================================================
FILE: backend/app/alembic/script.py.mako
================================================
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade():
${upgrades if upgrades else "pass"}
def downgrade():
${downgrades if downgrades else "pass"}
================================================
FILE: backend/app/alembic/versions/00534dc350db_.py
================================================
"""empty message
Revision ID: 00534dc350db
Revises: 10f36e8a25c4
Create Date: 2024-08-26 12:46:00.203425
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = "00534dc350db"
down_revision = "10f36e8a25c4"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"embedding_models",
"provider",
existing_type=mysql.ENUM("OPENAI"),
type_=sa.String(length=32),
existing_nullable=False,
)
op.alter_column(
"llms",
"provider",
existing_type=mysql.ENUM(
"OPENAI", "GEMINI", "ANTHROPIC_VERTEX", "OPENAI_LIKE", "BEDROCK"
),
type_=sa.String(length=32),
existing_nullable=False,
)
op.alter_column(
"reranker_models",
"provider",
existing_type=mysql.ENUM("JINA", "COHERE", "BAISHENG"),
type_=sa.String(length=32),
existing_nullable=False,
)
op.execute("UPDATE embedding_models SET provider = lower(provider)")
op.execute("UPDATE llms SET provider = lower(provider)")
op.execute("UPDATE reranker_models SET provider = lower(provider)")
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"reranker_models",
"provider",
existing_type=sa.String(length=32),
type_=mysql.ENUM("JINA", "COHERE", "BAISHENG"),
existing_nullable=False,
)
op.alter_column(
"llms",
"provider",
existing_type=sa.String(length=32),
type_=mysql.ENUM(
"OPENAI", "GEMINI", "ANTHROPIC_VERTEX", "OPENAI_LIKE", "BEDROCK"
),
existing_nullable=False,
)
op.alter_column(
"embedding_models",
"provider",
existing_type=sa.String(length=32),
type_=mysql.ENUM("OPENAI"),
existing_nullable=False,
)
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/041fbef26e3a_.py
================================================
"""empty message
Revision ID: 041fbef26e3a
Revises: 04d81be446c3
Create Date: 2024-08-19 08:20:13.695891
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.mysql import DATETIME
# revision identifiers, used by Alembic.
revision = "041fbef26e3a"
down_revision = "04d81be446c3"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"site_settings",
"created_at",
existing_type=DATETIME(timezone=True, fsp=6),
server_default=sa.text("current_timestamp(6)"),
nullable=False,
)
op.alter_column(
"site_settings",
"updated_at",
existing_type=DATETIME(timezone=True, fsp=6),
server_default=sa.text("current_timestamp(6) on update current_timestamp(6)"),
nullable=False,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"site_settings",
"created_at",
existing_type=DATETIME(timezone=True),
nullable=False,
)
op.alter_column(
"site_settings",
"updated_at",
existing_type=DATETIME(timezone=True),
nullable=False,
)
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/04947f9684ab_public_chat_engine.py
================================================
"""public_chat_engine
Revision ID: 04947f9684ab
Revises: 211f3c5aa125
Create Date: 2025-05-28 15:13:22.058160
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "04947f9684ab"
down_revision = "211f3c5aa125"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("chat_engines", sa.Column("is_public", sa.Boolean(), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_engines", "is_public")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/04d4f05116ed_.py
================================================
"""empty message
Revision ID: 04d4f05116ed
Revises: 94b198e20946
Create Date: 2024-07-23 01:26:07.117623
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from app.models.base import AESEncryptedColumn
# revision identifiers, used by Alembic.
revision = "04d4f05116ed"
down_revision = "94b198e20946"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"embedding_models",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
sa.Column(
"provider", sa.Enum("OPENAI", name="embeddingprovider"), nullable=False
),
sa.Column(
"model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False
),
sa.Column("config", sa.JSON(), nullable=True),
sa.Column("credentials", AESEncryptedColumn(), nullable=True),
sa.Column("is_default", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"llms",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
sa.Column(
"provider",
sa.Enum("OPENAI", "GEMINI", "ANTHROPIC_VERTEX", name="llmprovider"),
nullable=False,
),
sa.Column(
"model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False
),
sa.Column("config", sa.JSON(), nullable=True),
sa.Column("credentials", AESEncryptedColumn(), nullable=True),
sa.Column("is_default", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("chat_engines", sa.Column("llm_id", sa.Integer(), nullable=True))
op.add_column("chat_engines", sa.Column("fast_llm_id", sa.Integer(), nullable=True))
op.create_foreign_key(None, "chat_engines", "llms", ["fast_llm_id"], ["id"])
op.create_foreign_key(None, "chat_engines", "llms", ["llm_id"], ["id"])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "chat_engines", type_="foreignkey")
op.drop_constraint(None, "chat_engines", type_="foreignkey")
op.drop_column("chat_engines", "fast_llm_id")
op.drop_column("chat_engines", "llm_id")
op.drop_table("llms")
op.drop_table("embedding_models")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/04d81be446c3_.py
================================================
"""empty message
Revision ID: 04d81be446c3
Revises: e32f1e546eec
Create Date: 2024-08-08 17:11:50.178696
"""
from alembic import op
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = "04d81be446c3"
down_revision = "e32f1e546eec"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"llms",
"provider",
existing_type=mysql.ENUM(
"OPENAI", "GEMINI", "ANTHROPIC_VERTEX", "OPENAI_LIKE", "BEDROCK"
),
nullable=False,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"llms",
"provider",
existing_type=mysql.ENUM("OPENAI", "GEMINI", "ANTHROPIC_VERTEX", "OPENAI_LIKE"),
nullable=False,
)
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/10f36e8a25c4_.py
================================================
"""empty message
Revision ID: 10f36e8a25c4
Revises: 041fbef26e3a
Create Date: 2024-08-20 09:10:50.130219
"""
from alembic import op
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = "10f36e8a25c4"
down_revision = "041fbef26e3a"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"reranker_models",
"provider",
existing_type=mysql.ENUM("JINA", "COHERE", "BAISHENG"),
nullable=False,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"reranker_models",
"provider",
existing_type=mysql.ENUM("JINA", "COHERE"),
nullable=False,
)
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/197bc8be72d1_.py
================================================
"""empty message
Revision ID: 197bc8be72d1
Revises: 04d4f05116ed
Create Date: 2024-07-25 14:49:29.363595
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "197bc8be72d1"
down_revision = "04d4f05116ed"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"llms",
"provider",
type_=sa.Enum(
"OPENAI", "GEMINI", "ANTHROPIC_VERTEX", "OPENAI_LIKE", name="llmprovider"
),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"llms",
"provider",
type_=sa.Enum("OPENAI", "GEMINI", "ANTHROPIC_VERTEX", name="llmprovider"),
)
================================================
FILE: backend/app/alembic/versions/211f3c5aa125_chunking_settings.py
================================================
"""chunking_settings
Revision ID: 211f3c5aa125
Revises: 2adc0b597dcd
Create Date: 2025-02-17 14:20:56.253857
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "211f3c5aa125"
down_revision = "2adc0b597dcd"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"knowledge_bases", sa.Column("chunking_config", sa.JSON(), nullable=True)
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("knowledge_bases", "chunking_config")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/27a6723b767a_.py
================================================
"""empty message
Revision ID: 27a6723b767a
Revises: d2ad44deab20
Create Date: 2024-11-29 20:38:05.773083
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "27a6723b767a"
down_revision = "d2ad44deab20"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_messages",
sa.Column("is_best_answer", sa.Boolean(), server_default="0", nullable=False),
)
op.create_index(
"ix_chat_message_is_best_answer",
"chat_messages",
["is_best_answer"],
unique=False,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_chat_message_is_best_answer", table_name="chat_messages")
op.drop_column("chat_messages", "is_best_answer")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/2adc0b597dcd_int_enum_type.py
================================================
"""int_enum_type
Revision ID: 2adc0b597dcd
Revises: a54f966436ce
Create Date: 2025-01-24 17:58:08.339090
"""
from alembic import op
from sqlalchemy.dialects import mysql
from app.models.base import IntEnumType
from app.models.chat import ChatVisibility
# revision identifiers, used by Alembic.
revision = "2adc0b597dcd"
down_revision = "a54f966436ce"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"chats",
"visibility",
existing_type=mysql.SMALLINT(),
type_=IntEnumType(ChatVisibility),
existing_nullable=False,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"chats",
"visibility",
existing_type=IntEnumType(ChatVisibility),
type_=mysql.SMALLINT(),
existing_nullable=False,
)
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/2fc10c21bf88_.py
================================================
"""empty message
Revision ID: 5fdea8e26454
Revises:
Create Date: 2024-07-10 14:43:55.913126
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy.dialects import mysql
from app.core.config import settings
# revision identifiers, used by Alembic.
revision = "2fc10c21bf88"
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"chat_engines",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
sa.Column("engine_options", sa.JSON(), nullable=True),
sa.Column("is_default", sa.Boolean(), nullable=False),
sa.Column("deleted_at", sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"documents",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("hash", sqlmodel.sql.sqltypes.AutoString(length=32), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
sa.Column("content", mysql.MEDIUMTEXT(), nullable=True),
sa.Column(
"mime_type", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False
),
sa.Column(
"source_uri", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=False
),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("last_modified_at", sa.DateTime(), nullable=True),
sa.Column(
"index_status",
sa.Enum(
"NOT_STARTED",
"PENDING",
"RUNNING",
"COMPLETED",
"FAILED",
name="docindextaskstatus",
),
nullable=False,
),
sa.Column("index_result", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("source_uri"),
)
op.create_table(
"entities",
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column(
"entity_type",
sa.Enum("original", "synopsis", name="entitytype"),
nullable=False,
),
sa.Column("synopsis_info", sa.JSON(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"description_vec",
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
sa.Column(
"meta_vec",
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"semantic_cache",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("query", sa.Text(), nullable=True),
sa.Column(
"query_vec",
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
sa.Column("value", sa.Text(), nullable=True),
sa.Column(
"value_vec",
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.Column(
"updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.PrimaryKeyConstraint("id"),
mysql_TTL="created_at + INTERVAL 1 MONTH;",
)
op.create_table(
"site_settings",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
sa.Column(
"data_type", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False
),
sa.Column("value", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
op.create_table(
"staff_action_logs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("action", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column(
"action_time", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.Column("target_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("target_id", sa.Integer(), nullable=False),
sa.Column("before", sa.JSON(), nullable=True),
sa.Column("after", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"users",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("email", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column(
"hashed_password", sqlmodel.sql.sqltypes.AutoString(), nullable=False
),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column("is_verified", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False)
op.create_table(
"api_keys",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"description", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False
),
sa.Column(
"hashed_secret",
sqlmodel.sql.sqltypes.AutoString(length=255),
nullable=False,
),
sa.Column(
"api_key_display",
sqlmodel.sql.sqltypes.AutoString(length=100),
nullable=False,
),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("hashed_secret"),
)
op.create_table(
"chats",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column(
"title", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False
),
sa.Column("engine_id", sa.Integer(), nullable=True),
sa.Column("engine_options", sa.JSON(), nullable=True),
sa.Column("deleted_at", sa.DateTime(), nullable=True),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.ForeignKeyConstraint(
["engine_id"],
["chat_engines.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_chats_id"), "chats", ["id"], unique=False)
op.create_table(
"chunks",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("hash", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
sa.Column("text", sa.Text(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column(
"embedding",
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
sa.Column("document_id", sa.Integer(), nullable=True),
sa.Column("relations", sa.JSON(), nullable=True),
sa.Column(
"source_uri", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True
),
sa.Column(
"index_status",
sa.Enum(
"NOT_STARTED",
"PENDING",
"RUNNING",
"COMPLETED",
"FAILED",
name="kgindexstatus",
),
nullable=False,
),
sa.Column("index_result", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(
["document_id"],
["documents.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_chunks_id"), "chunks", ["id"], unique=False)
op.create_table(
"relationships",
sa.Column("description", sa.Text(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("weight", sa.Integer(), nullable=False),
sa.Column("source_entity_id", sa.Integer(), nullable=False),
sa.Column("target_entity_id", sa.Integer(), nullable=False),
sa.Column("last_modified_at", sa.DateTime(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"description_vec",
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
sa.ForeignKeyConstraint(
["source_entity_id"],
["entities.id"],
),
sa.ForeignKeyConstraint(
["target_entity_id"],
["entities.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"user_sessions",
sa.Column("token", sqlmodel.sql.sqltypes.AutoString(length=43), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("token"),
)
op.create_table(
"chat_messages",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False),
sa.Column("role", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("sources", sa.JSON(), nullable=True),
sa.Column(
"trace_url", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True
),
sa.Column("finished_at", sa.DateTime(), nullable=True),
sa.Column("chat_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.ForeignKeyConstraint(
["chat_id"],
["chats.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"feedbacks",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"feedback_type",
sa.Enum("LIKE", "DISLIKE", name="feedbacktype"),
nullable=False,
),
sa.Column(
"comment", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=False
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("chat_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("chat_message_id", sa.Integer(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.ForeignKeyConstraint(
["chat_id"],
["chats.id"],
),
sa.ForeignKeyConstraint(
["chat_message_id"],
["chat_messages.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("feedbacks")
op.drop_table("chat_messages")
op.drop_table("user_sessions")
op.drop_table("relationships")
op.drop_index(op.f("ix_chunks_id"), table_name="chunks")
op.drop_table("chunks")
op.drop_index(op.f("ix_chats_id"), table_name="chats")
op.drop_table("chats")
op.drop_table("api_keys")
op.drop_index(op.f("ix_users_id"), table_name="users")
op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_table("users")
op.drop_table("staff_action_logs")
op.drop_table("site_settings")
op.drop_table("semantic_cache")
op.drop_table("entities")
op.drop_table("documents")
op.drop_table("chat_engines")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/749767db5505_add_recommend_questions.py
================================================
"""add recommend questions
Revision ID: 749767db5505
Revises: 8093333c0d87
Create Date: 2024-10-15 16:02:14.203584
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "749767db5505"
down_revision = "8093333c0d87"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"recommend_questions",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("questions", sa.JSON(), nullable=True),
sa.Column("chat_message_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["chat_message_id"],
["chat_messages.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_recommend_questions_chat_message_id"),
"recommend_questions",
["chat_message_id"],
unique=False,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(
op.f("ix_recommend_questions_chat_message_id"), table_name="recommend_questions"
)
op.drop_table("recommend_questions")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/8093333c0d87_.py
================================================
"""empty message
Revision ID: 8093333c0d87
Revises: 830fd9c44f39
Create Date: 2024-09-24 12:23:48.076576
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8093333c0d87"
down_revision = "830fd9c44f39"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("chats", sa.Column("visibility", sa.SmallInteger(), nullable=False))
op.execute("UPDATE chats SET visibility = 1 WHERE user_id IS NULL;")
op.execute(
"UPDATE chats SET engine_options = CAST(JSON_UNQUOTE(engine_options) AS JSON);"
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chats", "visibility")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/830fd9c44f39_.py
================================================
"""empty message
Revision ID: 830fd9c44f39
Revises: dfee070b8abd
Create Date: 2024-09-19 13:04:30.351449
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
# revision identifiers, used by Alembic.
revision = "830fd9c44f39"
down_revision = "dfee070b8abd"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chats",
sa.Column(
"origin", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=True
),
)
op.add_column(
"feedbacks",
sa.Column(
"origin", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=True
),
)
op.add_column(
"chat_messages",
sa.Column(
"post_verification_result_url",
sqlmodel.sql.sqltypes.AutoString(length=512),
nullable=True,
),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_messages", "post_verification_result_url")
op.drop_column("feedbacks", "origin")
op.drop_column("chats", "origin")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/94b198e20946_.py
================================================
"""empty message
Revision ID: 94b198e20946
Revises: 2fc10c21bf88
Create Date: 2024-07-11 15:19:19.174568
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
# revision identifiers, used by Alembic.
revision = "94b198e20946"
down_revision = "2fc10c21bf88"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"data_sources",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
sa.Column(
"description", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=False
),
sa.Column(
"data_source_type",
sqlmodel.sql.sqltypes.AutoString(length=256),
nullable=False,
),
sa.Column("config", sa.JSON(), nullable=True),
sa.Column("build_kg_index", sa.Boolean(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"uploads",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column("size", sa.Integer(), nullable=False),
sa.Column("path", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column(
"mime_type", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False
),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("documents", sa.Column("data_source_id", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("documents", "data_source_id")
op.drop_table("uploads")
op.drop_table("data_sources")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/a54f966436ce_evaluation.py
================================================
"""evaluation
Revision ID: a54f966436ce
Revises: 27a6723b767a
Create Date: 2024-12-09 16:46:21.077517
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
# revision identifiers, used by Alembic.
revision = "a54f966436ce"
down_revision = "27a6723b767a"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"evaluation_datasets",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"evaluation_tasks",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("dataset_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"evaluation_dataset_items",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("query", sa.Text(), nullable=True),
sa.Column("reference", sa.Text(), nullable=True),
sa.Column("retrieved_contexts", sa.JSON(), nullable=True),
sa.Column("extra", sa.JSON(), nullable=True),
sa.Column("evaluation_dataset_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["evaluation_dataset_id"],
["evaluation_datasets.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"evaluation_task_items",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"chat_engine", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False
),
sa.Column("status", sa.String(length=32), nullable=False),
sa.Column("query", sa.Text(), nullable=True),
sa.Column("reference", sa.Text(), nullable=True),
sa.Column("response", sa.Text(), nullable=True),
sa.Column("retrieved_contexts", sa.JSON(), nullable=True),
sa.Column("extra", sa.JSON(), nullable=True),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("factual_correctness", sa.Float(), nullable=True),
sa.Column("semantic_similarity", sa.Float(), nullable=True),
sa.Column("evaluation_task_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["evaluation_task_id"],
["evaluation_tasks.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("evaluation_task_items")
op.drop_table("evaluation_dataset_items")
op.drop_table("evaluation_tasks")
op.drop_table("evaluation_datasets")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/a8c79553c9f6_.py
================================================
"""empty message
Revision ID: a8c79553c9f6
Revises: ac6e4d58580d
Create Date: 2024-08-05 13:04:17.572821
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a8c79553c9f6"
down_revision = "ac6e4d58580d"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("data_sources", sa.Column("llm_id", sa.Integer(), nullable=True))
op.create_foreign_key(None, "data_sources", "llms", ["llm_id"], ["id"])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "data_sources", type_="foreignkey")
op.drop_column("data_sources", "llm_id")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/ac6e4d58580d_.py
================================================
"""empty message
Revision ID: ac6e4d58580d
Revises: 197bc8be72d1
Create Date: 2024-08-01 16:15:59.164348
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = "ac6e4d58580d"
down_revision = "197bc8be72d1"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chats",
sa.Column(
"browser_id", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True
),
)
op.alter_column(
"llms",
"provider",
existing_type=mysql.ENUM("OPENAI", "GEMINI", "ANTHROPIC_VERTEX", "OPENAI_LIKE"),
nullable=False,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"llms",
"provider",
existing_type=mysql.ENUM("OPENAI", "GEMINI", "ANTHROPIC_VERTEX", "OPENAI_LIKE"),
nullable=True,
)
op.drop_column("chats", "browser_id")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/bd17a4ebccc5_.py
================================================
"""empty message
Revision ID: bd17a4ebccc5
Revises: a8c79553c9f6
Create Date: 2024-08-08 01:20:42.069228
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
# revision identifiers, used by Alembic.
revision = "bd17a4ebccc5"
down_revision = "a8c79553c9f6"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("data_sources", sa.Column("deleted_at", sa.DateTime(), nullable=True))
op.drop_index("source_uri", table_name="documents")
op.add_column(
"relationships",
sa.Column("chunk_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
)
op.add_column(
"relationships",
sa.Column("document_id", sa.Integer(), nullable=True),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("relationships", "chunk_id")
op.drop_column("relationships", "document_id")
op.create_index("source_uri", "documents", ["source_uri"], unique=True)
op.drop_column("data_sources", "deleted_at")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/c7f016a904c1_.py
================================================
"""empty message
Revision ID: c7f016a904c1
Revises: 749767db5505
Create Date: 2024-10-30 13:28:17.345385
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c7f016a904c1"
down_revision = "749767db5505"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("chat_messages", sa.Column("meta", sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_messages", "meta")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/d2ad44deab20_multiple_kb.py
================================================
"""multiple_kb
Revision ID: d2ad44deab20
Revises: c7f016a904c1
Create Date: 2024-11-15 09:51:42.493749
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy import update
from sqlmodel import Session
from sqlalchemy.dialects import mysql
from app.core.config import settings
from app.core.db import engine
from app.models import EmbeddingModel
# revision identifiers, used by Alembic.
revision = "d2ad44deab20"
down_revision = "c7f016a904c1"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"knowledge_bases",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column("description", mysql.MEDIUMTEXT(), nullable=True),
sa.Column("index_methods", sa.JSON(), nullable=True),
sa.Column("llm_id", sa.Integer(), nullable=True),
sa.Column("embedding_model_id", sa.Integer(), nullable=True),
sa.Column("documents_total", sa.Integer(), nullable=False),
sa.Column("data_sources_total", sa.Integer(), nullable=False),
sa.Column("created_by", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.Column("updated_by", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column(
"updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True
),
sa.Column("deleted_by", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("deleted_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["created_by"],
["users.id"],
),
sa.ForeignKeyConstraint(
["deleted_by"],
["users.id"],
),
sa.ForeignKeyConstraint(
["embedding_model_id"],
["embedding_models.id"],
),
sa.ForeignKeyConstraint(
["llm_id"],
["llms.id"],
),
sa.ForeignKeyConstraint(
["updated_by"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"knowledge_base_datasources",
sa.Column("knowledge_base_id", sa.Integer(), nullable=False),
sa.Column("data_source_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["data_source_id"],
["data_sources.id"],
),
sa.ForeignKeyConstraint(
["knowledge_base_id"],
["knowledge_bases.id"],
),
sa.PrimaryKeyConstraint("knowledge_base_id", "data_source_id"),
)
op.add_column(
"documents", sa.Column("knowledge_base_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_d_on_data_source_id",
"documents",
"data_sources",
["data_source_id"],
["id"],
)
op.create_foreign_key(
"fk_d_on_knowledge_base_id",
"documents",
"knowledge_bases",
["knowledge_base_id"],
["id"],
)
op.add_column(
"embedding_models", sa.Column("vector_dimension", sa.Integer(), nullable=False)
)
# ### end Alembic commands ###
# ### Data Migration ###
with Session(engine) as session:
stmt = (
update(EmbeddingModel)
.where(EmbeddingModel.vector_dimension == 0)
.values(vector_dimension=settings.EMBEDDING_DIMS)
)
session.exec(stmt)
session.commit()
# ### end Data Migration ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("embedding_models", "vector_dimension")
op.drop_constraint("fk_d_on_data_source_id", "documents", type_="foreignkey")
op.drop_constraint("fk_d_on_knowledge_base_id", "documents", type_="foreignkey")
op.drop_column("documents", "knowledge_base_id")
op.drop_table("knowledge_base_datasources")
op.drop_table("knowledge_bases")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/dfee070b8abd_.py
================================================
"""empty message
Revision ID: dfee070b8abd
Revises: eb0b85608c0a
Create Date: 2024-09-10 10:45:50.318277
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = "dfee070b8abd"
down_revision = "eb0b85608c0a"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("chat_messages", sa.Column("graph_data", sa.JSON(), nullable=True))
op.create_index("idx_entity_type", "entities", ["entity_type"], unique=False)
op.alter_column(
"site_settings",
"created_at",
existing_type=mysql.DATETIME(fsp=6),
nullable=True,
existing_server_default=sa.text("CURRENT_TIMESTAMP(6)"),
)
op.alter_column(
"site_settings",
"updated_at",
existing_type=mysql.DATETIME(fsp=6),
nullable=True,
existing_server_default=sa.text(
"CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6)"
),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"site_settings",
"updated_at",
existing_type=mysql.DATETIME(fsp=6),
nullable=False,
existing_server_default=sa.text(
"CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6)"
),
)
op.alter_column(
"site_settings",
"created_at",
existing_type=mysql.DATETIME(fsp=6),
nullable=False,
existing_server_default=sa.text("CURRENT_TIMESTAMP(6)"),
)
op.drop_index("idx_entity_type", table_name="entities")
op.drop_column("chat_messages", "graph_data")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/e32f1e546eec_.py
================================================
"""empty message
Revision ID: e32f1e546eec
Revises: bd17a4ebccc5
Create Date: 2024-08-08 03:55:14.042290
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from app.models.base import AESEncryptedColumn
# revision identifiers, used by Alembic.
revision = "e32f1e546eec"
down_revision = "bd17a4ebccc5"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"reranker_models",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
sa.Column(
"provider",
sa.Enum("JINA", "COHERE", name="rerankerprovider"),
nullable=False,
),
sa.Column(
"model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False
),
sa.Column("top_n", sa.Integer(), nullable=False),
sa.Column("config", sa.JSON(), nullable=True),
sa.Column("is_default", sa.Boolean(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("credentials", AESEncryptedColumn(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("chat_engines", sa.Column("reranker_id", sa.Integer(), nullable=True))
op.create_foreign_key(
None, "chat_engines", "reranker_models", ["reranker_id"], ["id"]
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_engines", "reranker_id")
op.drop_table("reranker_models")
# ### end Alembic commands ###
================================================
FILE: backend/app/alembic/versions/eb0b85608c0a_.py
================================================
"""empty message
Revision ID: eb0b85608c0a
Revises: 00534dc350db
Create Date: 2024-08-28 15:10:04.219389
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = "eb0b85608c0a"
down_revision = "00534dc350db"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"documents",
"mime_type",
existing_type=mysql.VARCHAR(length=64),
type_=sa.String(length=128),
existing_nullable=False,
)
op.alter_column(
"uploads",
"mime_type",
existing_type=mysql.VARCHAR(length=64),
type_=sa.String(length=128),
existing_nullable=False,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"uploads",
"mime_type",
existing_type=sa.String(length=128),
type_=mysql.VARCHAR(length=64),
existing_nullable=False,
)
op.alter_column(
"documents",
"mime_type",
existing_type=sa.String(length=128),
type_=mysql.VARCHAR(length=64),
existing_nullable=False,
)
# ### end Alembic commands ###
================================================
FILE: backend/app/api/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/chat/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/chat/routes.py
================================================
from typing import Optional
from fastapi import APIRouter, Depends
from fastapi_pagination import Page, Params
from app.models.chat import ChatOrigin
from app.api.deps import CurrentSuperuserDep, SessionDep
from app.repositories import chat_repo
router = APIRouter(
prefix="/admin/chats",
tags=["admin/chats"],
)
@router.get("/origins")
def list_chat_origins(
db_session: SessionDep,
user: CurrentSuperuserDep,
search: Optional[str] = None,
params: Params = Depends(),
) -> Page[ChatOrigin]:
return chat_repo.list_chat_origins(db_session, search, params)
================================================
FILE: backend/app/api/admin_routes/chat_engine.py
================================================
from fastapi import APIRouter, Depends
from fastapi_pagination import Params, Page
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.exceptions import DefaultChatEngineCannotBeDeleted
from app.rag.chat.config import ChatEngineConfig
from app.repositories import chat_engine_repo
from app.models import ChatEngine, ChatEngineUpdate
router = APIRouter()
@router.get("/admin/chat-engines")
def list_chat_engines(
db_session: SessionDep,
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> Page[ChatEngine]:
return chat_engine_repo.paginate(db_session, params)
@router.post("/admin/chat-engines")
def create_chat_engine(
db_session: SessionDep,
user: CurrentSuperuserDep,
chat_engine: ChatEngine,
) -> ChatEngine:
return chat_engine_repo.create(db_session, chat_engine)
@router.get("/admin/chat-engines/{chat_engine_id}")
def get_chat_engine(
db_session: SessionDep,
user: CurrentSuperuserDep,
chat_engine_id: int,
) -> ChatEngine:
return chat_engine_repo.must_get(db_session, chat_engine_id)
@router.put("/admin/chat-engines/{chat_engine_id}")
def update_chat_engine(
db_session: SessionDep,
user: CurrentSuperuserDep,
chat_engine_id: int,
update: ChatEngineUpdate,
) -> ChatEngine:
chat_engine = chat_engine_repo.must_get(db_session, chat_engine_id)
return chat_engine_repo.update(db_session, chat_engine, update)
@router.delete("/admin/chat-engines/{chat_engine_id}")
def delete_chat_engine(
db_session: SessionDep,
user: CurrentSuperuserDep,
chat_engine_id: int,
) -> ChatEngine:
chat_engine = chat_engine_repo.must_get(db_session, chat_engine_id)
if chat_engine.is_default:
raise DefaultChatEngineCannotBeDeleted(chat_engine_id)
return chat_engine_repo.delete(db_session, chat_engine)
@router.get("/admin/chat-engines-default-config")
def get_default_config(
db_session: SessionDep, user: CurrentSuperuserDep
) -> ChatEngineConfig:
return ChatEngineConfig()
================================================
FILE: backend/app/api/admin_routes/document/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/document/routes.py
================================================
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from fastapi_pagination import Params, Page
from app.api.admin_routes.knowledge_base.document.models import (
DocumentFilters,
DocumentItem,
)
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.repositories import document_repo
router = APIRouter()
@router.get("/admin/documents")
def list_documents(
session: SessionDep,
user: CurrentSuperuserDep,
filters: Annotated[DocumentFilters, Query()],
params: Params = Depends(),
) -> Page[DocumentItem]:
return document_repo.paginate(
session=session,
filters=filters,
params=params,
)
================================================
FILE: backend/app/api/admin_routes/embedding_model/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/embedding_model/models.py
================================================
from datetime import datetime
from typing import Any
from pydantic import BaseModel, field_validator
from typing_extensions import Optional
from app.rag.embeddings.provider import EmbeddingProvider
class EmbeddingModelCreate(BaseModel):
name: str
provider: EmbeddingProvider
model: str
vector_dimension: int
config: dict | list | None
credentials: Any
is_default: Optional[bool] = False
@field_validator("vector_dimension")
def vector_dimension_must_gt_1(cls, v: int) -> int:
if v <= 0:
raise ValueError(
"The vector dimension of the Embedding model should be at least greater than 1."
)
return v
class EmbeddingModelUpdate(BaseModel):
name: Optional[str] = None
config: Optional[dict | list] = None
credentials: Optional[str | dict] = None
class EmbeddingModelItem(BaseModel):
id: int
name: str
provider: EmbeddingProvider
model: str
vector_dimension: int
is_default: bool
class EmbeddingModelDetail(BaseModel):
id: int
name: str
provider: EmbeddingProvider
model: str
vector_dimension: int
config: dict | list | None
is_default: bool
created_at: datetime
updated_at: datetime
class EmbeddingModelTestResult(BaseModel):
success: bool
error: str = ""
================================================
FILE: backend/app/api/admin_routes/embedding_model/routes.py
================================================
from typing import List
from fastapi import APIRouter, Depends
from fastapi_pagination import Params, Page
from app.api.admin_routes.embedding_model.models import (
EmbeddingModelItem,
EmbeddingModelDetail,
EmbeddingModelUpdate,
EmbeddingModelTestResult,
EmbeddingModelCreate,
)
from app.api.deps import CurrentSuperuserDep, SessionDep
from app.repositories.embedding_model import embedding_model_repo
from app.rag.embeddings.provider import (
EmbeddingProviderOption,
embedding_provider_options,
)
from app.rag.embeddings.resolver import resolve_embed_model
from app.logger import logger
router = APIRouter()
@router.get("/admin/embedding-models/providers/options")
def list_embedding_model_provider_options(
user: CurrentSuperuserDep,
) -> List[EmbeddingProviderOption]:
return embedding_provider_options
@router.get("/admin/embedding-models")
def list_embedding_models(
db_session: SessionDep, user: CurrentSuperuserDep, params: Params = Depends()
) -> Page[EmbeddingModelItem]:
return embedding_model_repo.paginate(db_session, params)
@router.post("/admin/embedding-models/test")
def test_embedding_model(
user: CurrentSuperuserDep,
create: EmbeddingModelCreate,
) -> EmbeddingModelTestResult:
try:
embed_model = resolve_embed_model(
provider=create.provider,
model=create.model,
config=create.config,
credentials=create.credentials,
)
embedding = embed_model.get_query_embedding("Hello, world!")
expected_length = create.vector_dimension
if len(embedding) != expected_length:
raise ValueError(
f"Embedding model is configured with {expected_length} dimensions, but got vector embedding with {len(embedding)} dimensions."
)
success = True
error = ""
except Exception as e:
logger.info(f"Failed to test embedding model: {e}")
success = False
error = str(e)
return EmbeddingModelTestResult(success=success, error=error)
@router.post("/admin/embedding-models")
def create_embedding_model(
db_session: SessionDep,
user: CurrentSuperuserDep,
create: EmbeddingModelCreate,
) -> EmbeddingModelDetail:
return embedding_model_repo.create(db_session, create)
@router.get("/admin/embedding-models/{model_id}")
def get_embedding_model_detail(
db_session: SessionDep, user: CurrentSuperuserDep, model_id: int
) -> EmbeddingModelDetail:
return embedding_model_repo.must_get(db_session, model_id)
@router.put("/admin/embedding-models/{model_id}")
def update_embedding_model(
db_session: SessionDep,
user: CurrentSuperuserDep,
model_id: int,
update: EmbeddingModelUpdate,
) -> EmbeddingModelDetail:
embed_model = embedding_model_repo.must_get(db_session, model_id)
return embedding_model_repo.update(db_session, embed_model, update)
@router.delete("/admin/embedding-models/{model_id}")
def delete_embedding_model(
db_session: SessionDep, user: CurrentSuperuserDep, model_id: int
) -> None:
embedding_model = embedding_model_repo.must_get(db_session, model_id)
embedding_model_repo.delete(db_session, embedding_model)
@router.put("/admin/embedding-models/{model_id}/set_default")
def set_default_embedding_model(
db_session: SessionDep, user: CurrentSuperuserDep, model_id: int
) -> EmbeddingModelDetail:
embed_model = embedding_model_repo.must_get(db_session, model_id)
return embedding_model_repo.set_default(db_session, embed_model)
================================================
FILE: backend/app/api/admin_routes/evaluation/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/evaluation/evaluation_dataset.py
================================================
import pandas as pd
from fastapi import APIRouter, status, HTTPException, Depends
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import select, desc
from app.api.admin_routes.evaluation.models import (
CreateEvaluationDataset,
UpdateEvaluationDataset,
ModifyEvaluationDatasetItem,
ParamsWithKeyword,
)
from app.api.admin_routes.evaluation.tools import must_get
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.file_storage import default_file_storage
from app.models import Upload, EvaluationDataset, EvaluationDatasetItem
from app.types import MimeTypes
router = APIRouter()
@router.post("/admin/evaluation/datasets")
def create_evaluation_dataset(
evaluation_dataset: CreateEvaluationDataset,
session: SessionDep,
user: CurrentSuperuserDep,
) -> EvaluationDataset:
"""
Create a dataset for a given question and chat engine.
This API depends on the /admin/uploads API to upload the evaluation data.
The evaluation data is expected to be a CSV file with the following columns:
- query: The query to evaluate
- reference: The expected response to the query
You can add more columns to the CSV file, and the extra columns will adhere to the results.
Args:
evaluation_dataset.name: The name of the evaluation dataset.
evaluation_dataset.upload_id: The ID of the uploaded CSV file of the evaluation dataset.
Returns:
True if the evaluation dataset is created successfully.
"""
name = evaluation_dataset.name
evaluation_data_list = []
if evaluation_dataset.upload_id is not None:
# If the evaluation_file_id is provided, validate the uploaded file
evaluation_file_id = evaluation_dataset.upload_id
upload = must_get(session, Upload, evaluation_file_id)
if upload.mime_type != MimeTypes.CSV:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The uploaded file must be a CSV file.",
)
with default_file_storage.open(upload.path) as f:
df = pd.read_csv(f)
# check essential columns
must_have_columns = ["query", "reference"]
if not set(must_have_columns).issubset(df.columns):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The uploaded file must have the following columns: {must_have_columns}",
)
eval_list = df.to_dict(orient="records")
# create evaluation dataset items
evaluation_data_list = [
EvaluationDatasetItem(
query=item["query"],
reference=item["reference"],
retrieved_contexts=[], # TODO: implement this after we can retrieve contexts
extra={k: item[k] for k in item if k not in must_have_columns},
)
for item in eval_list
]
evaluation_dataset = EvaluationDataset(
name=name,
user_id=user.id,
evaluation_data_list=evaluation_data_list,
)
session.add(evaluation_dataset)
session.commit()
session.refresh(evaluation_dataset)
return evaluation_dataset
@router.delete("/admin/evaluation/datasets/{evaluation_dataset_id}")
def delete_evaluation_dataset(
evaluation_dataset_id: int, session: SessionDep, user: CurrentSuperuserDep
) -> bool:
evaluation_dataset = must_get(session, EvaluationDataset, evaluation_dataset_id)
session.delete(evaluation_dataset)
session.commit()
return True
@router.put("/admin/evaluation/datasets/{evaluation_dataset_id}")
def update_evaluation_dataset(
evaluation_dataset_id: int,
updated_evaluation_dataset: UpdateEvaluationDataset,
session: SessionDep,
user: CurrentSuperuserDep,
) -> EvaluationDataset:
evaluation_dataset = must_get(session, EvaluationDataset, evaluation_dataset_id)
evaluation_dataset.name = updated_evaluation_dataset.name
session.merge(evaluation_dataset)
session.commit()
session.refresh(evaluation_dataset)
return evaluation_dataset
@router.get("/admin/evaluation/datasets")
def list_evaluation_dataset(
session: SessionDep,
user: CurrentSuperuserDep,
params: ParamsWithKeyword = Depends(),
) -> Page[EvaluationDataset]:
stmt = select(EvaluationDataset).order_by(desc(EvaluationDataset.id))
if params.keyword:
stmt = stmt.where(EvaluationDataset.name.ilike(f"%{params.keyword}%"))
return paginate(session, stmt, params)
@router.post("/admin/evaluation/dataset-items")
def create_evaluation_dataset_item(
modify_evaluation_dataset_item: ModifyEvaluationDatasetItem,
session: SessionDep,
user: CurrentSuperuserDep,
) -> EvaluationDatasetItem:
evaluation_dataset_item = EvaluationDatasetItem(
query=modify_evaluation_dataset_item.query,
reference=modify_evaluation_dataset_item.reference,
retrieved_contexts=modify_evaluation_dataset_item.retrieved_contexts,
extra=modify_evaluation_dataset_item.extra,
evaluation_dataset_id=modify_evaluation_dataset_item.evaluation_dataset_id,
)
session.add(evaluation_dataset_item)
session.commit()
session.refresh(evaluation_dataset_item)
return evaluation_dataset_item
@router.delete("/admin/evaluation/dataset-items/{evaluation_dataset_item_id}")
def delete_evaluation_dataset_item(
evaluation_dataset_item_id: int, session: SessionDep, user: CurrentSuperuserDep
) -> bool:
evaluation_dataset_item = must_get(
session, EvaluationDatasetItem, evaluation_dataset_item_id
)
session.delete(evaluation_dataset_item)
session.commit()
return True
@router.put("/admin/evaluation/dataset-items/{evaluation_dataset_item_id}")
def update_evaluation_dataset_item(
evaluation_dataset_item_id: int,
updated_evaluation_dataset_item: ModifyEvaluationDatasetItem,
session: SessionDep,
user: CurrentSuperuserDep,
) -> EvaluationDatasetItem:
evaluation_dataset_item = must_get(
session, EvaluationDatasetItem, evaluation_dataset_item_id
)
evaluation_dataset_item.query = updated_evaluation_dataset_item.query
evaluation_dataset_item.reference = updated_evaluation_dataset_item.reference
evaluation_dataset_item.retrieved_contexts = (
updated_evaluation_dataset_item.retrieved_contexts
)
evaluation_dataset_item.extra = updated_evaluation_dataset_item.extra
evaluation_dataset_item.evaluation_dataset_id = (
updated_evaluation_dataset_item.evaluation_dataset_id
)
session.merge(evaluation_dataset_item)
session.commit()
session.refresh(evaluation_dataset_item)
return evaluation_dataset_item
@router.get("/admin/evaluation/datasets/{evaluation_dataset_id}/dataset-items")
def list_evaluation_dataset_item(
session: SessionDep,
user: CurrentSuperuserDep,
evaluation_dataset_id: int,
params: ParamsWithKeyword = Depends(),
) -> Page[EvaluationDatasetItem]:
stmt = (
select(EvaluationDatasetItem)
.where(EvaluationDatasetItem.evaluation_dataset_id == evaluation_dataset_id)
.order_by(EvaluationDatasetItem.id)
)
if params.keyword:
stmt = stmt.where(EvaluationDatasetItem.query.ilike(f"%{params.keyword}%"))
return paginate(session, stmt, params)
@router.get("/admin/evaluation/dataset-items/{evaluation_dataset_item_id}")
def get_evaluation_dataset_item(
session: SessionDep,
user: CurrentSuperuserDep,
evaluation_dataset_item_id: int,
) -> EvaluationDatasetItem:
return must_get(session, EvaluationDatasetItem, evaluation_dataset_item_id)
================================================
FILE: backend/app/api/admin_routes/evaluation/evaluation_task.py
================================================
import logging
from typing import Optional, List
import sqlmodel
from fastapi import APIRouter, Depends
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlalchemy import func, update
from sqlalchemy.orm import Session
from sqlmodel import select, case, desc
from app.api.admin_routes.evaluation.models import (
CreateEvaluationTask,
EvaluationTaskSummary,
ParamsWithKeyword,
EvaluationTaskOverview,
)
from app.api.admin_routes.evaluation.tools import must_get
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.models import (
EvaluationTask,
EvaluationTaskItem,
EvaluationStatus,
EvaluationDataset,
)
from app.tasks.evaluate import add_evaluation_task
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/admin/evaluation/tasks")
def create_evaluation_task(
evaluation_task: CreateEvaluationTask,
session: SessionDep,
user: CurrentSuperuserDep,
) -> Optional[EvaluationTask]:
"""
Create an evaluation task from the evaluation dataset.
Args:
evaluation_task.name: The name of the evaluation task.
evaluation_task.evaluation_dataset_id: The ID of the uploaded evaluation dataset.
evaluation_task.chat_engine: The chat engine to evaluate the queries against. Default is "default".
evaluation_task.run_size: The number of queries to evaluate. Default is None, which means all queries in the CSV file.
Returns:
True if the evaluation task is created successfully.
"""
name = evaluation_task.name
evaluation_dataset_id = evaluation_task.evaluation_dataset_id
chat_engine = evaluation_task.chat_engine
run_size = evaluation_task.run_size
dataset = must_get(session, EvaluationDataset, evaluation_dataset_id)
if run_size is not None and run_size < len(dataset.evaluation_data_list):
dataset.evaluation_data_list = dataset.evaluation_data_list[:run_size]
# create evaluation items
# caveat: Do the deep copy on purpose to avoid the side effect of the original dataset modification
evaluation_task_items = [
EvaluationTaskItem(
status=EvaluationStatus.NOT_START,
chat_engine=chat_engine,
query=item.query,
reference=item.reference,
retrieved_contexts=item.retrieved_contexts,
extra=item.extra,
)
for item in dataset.evaluation_data_list
]
evaluation_task = EvaluationTask(
name=name,
user_id=user.id,
evaluation_task_items=evaluation_task_items,
dataset_id=evaluation_dataset_id,
)
session.add(evaluation_task)
session.commit()
session.refresh(evaluation_task)
add_evaluation_task.delay(evaluation_task.id)
return evaluation_task
@router.delete("/admin/evaluation/tasks/{evaluation_task_id}")
def cancel_evaluation_task(
evaluation_task_id: int, session: SessionDep, user: CurrentSuperuserDep
) -> Optional[bool]:
must_get(session, EvaluationTask, evaluation_task_id)
session.exec(
update(EvaluationTaskItem)
.where(EvaluationTaskItem.evaluation_task_id == evaluation_task_id)
.values(status=EvaluationStatus.CANCEL)
)
session.commit()
return True
@router.get("/admin/evaluation/tasks/{evaluation_task_id}")
def get_evaluation_task(
session: SessionDep,
user: CurrentSuperuserDep,
evaluation_task_id: int,
) -> EvaluationTask:
return must_get(session, EvaluationTask, evaluation_task_id)
@router.get("/admin/evaluation/tasks/{evaluation_task_id}/summary")
def get_evaluation_task_summary(
evaluation_task_id: int, session: SessionDep, user: CurrentSuperuserDep
) -> EvaluationTaskSummary:
task = must_get(session, EvaluationTask, evaluation_task_id)
return get_summary_for_evaluation_task(task, session)
@router.get("/admin/evaluation/tasks")
def list_evaluation_task(
session: SessionDep,
user: CurrentSuperuserDep,
params: ParamsWithKeyword = Depends(),
) -> Page[EvaluationTaskSummary]:
stmt = select(EvaluationTask).order_by(desc(EvaluationTask.id))
if params.keyword:
stmt = stmt.where(EvaluationTask.name.ilike(f"%{params.keyword}%"))
task_page: Page[EvaluationTask] = paginate(session, stmt, params)
summaries: List[EvaluationTaskSummary] = []
for task in task_page.items:
summaries.append(get_summary_for_evaluation_task(task, session))
return Page[EvaluationTaskSummary](
items=summaries,
total=task_page.total,
page=task_page.page,
size=task_page.size,
pages=task_page.pages,
)
@router.get("/admin/evaluation/tasks/{evaluation_task_id}/items")
def list_evaluation_task_items(
evaluation_task_id: int,
session: SessionDep,
user: CurrentSuperuserDep,
params: ParamsWithKeyword = Depends(),
) -> Page[EvaluationTaskItem]:
must_get(session, EvaluationTask, evaluation_task_id)
stmt = select(EvaluationTaskItem).where(
EvaluationTaskItem.evaluation_task_id == evaluation_task_id
)
if params.keyword:
stmt = stmt.where(
sqlmodel.or_(
EvaluationTaskItem.query.ilike(f"%{params.keyword}%"),
EvaluationTaskItem.reference.ilike(f"%{params.keyword}%"),
)
)
stmt.order_by(EvaluationTaskItem.id)
return paginate(session, stmt, params)
def get_summary_for_evaluation_task(
evaluation_task: EvaluationTask, session: Session
) -> EvaluationTaskSummary:
status_counts = (
session.query(
func.count(
case(
(EvaluationTaskItem.status == EvaluationStatus.NOT_START, 1),
else_=None,
)
).label("not_start"),
func.count(
case(
(EvaluationTaskItem.status == EvaluationStatus.EVALUATING, 1),
else_=None,
)
).label("evaluating"),
func.count(
case(
(EvaluationTaskItem.status == EvaluationStatus.DONE, 1), else_=None
)
).label("done"),
func.count(
case(
(EvaluationTaskItem.status == EvaluationStatus.ERROR, 1), else_=None
)
).label("error"),
func.count(
case(
(EvaluationTaskItem.status == EvaluationStatus.CANCEL, 1),
else_=None,
)
).label("cancel"),
)
.filter(EvaluationTaskItem.evaluation_task_id == evaluation_task.id)
.one()
)
stats = {}
if status_counts.not_start == 0 and status_counts.evaluating == 0:
stats_tuple = (
session.query(
func.avg(EvaluationTaskItem.factual_correctness).label(
"avg_factual_correctness"
),
func.avg(EvaluationTaskItem.semantic_similarity).label(
"avg_semantic_similarity"
),
func.min(EvaluationTaskItem.factual_correctness).label(
"min_factual_correctness"
),
func.min(EvaluationTaskItem.semantic_similarity).label(
"min_semantic_similarity"
),
func.max(EvaluationTaskItem.factual_correctness).label(
"max_factual_correctness"
),
func.max(EvaluationTaskItem.semantic_similarity).label(
"max_semantic_similarity"
),
func.stddev(EvaluationTaskItem.factual_correctness).label(
"std_factual_correctness"
),
func.stddev(EvaluationTaskItem.semantic_similarity).label(
"std_semantic_similarity"
),
)
.filter(
EvaluationTaskItem.evaluation_task_id == evaluation_task.id,
EvaluationTaskItem.status == EvaluationStatus.DONE,
EvaluationTaskItem.factual_correctness.isnot(None),
EvaluationTaskItem.semantic_similarity.isnot(None),
)
.one()
)
stats = dict(stats_tuple._mapping)
logger.info(stats)
return EvaluationTaskSummary(
summary=EvaluationTaskOverview(
not_start=status_counts.not_start,
succeed=status_counts.done,
errored=status_counts.error,
progressing=status_counts.evaluating,
cancel=status_counts.cancel,
avg_factual_correctness=stats.get("avg_factual_correctness", 0),
avg_semantic_similarity=stats.get("avg_semantic_similarity", 0),
min_factual_correctness=stats.get("min_factual_correctness", 0),
min_semantic_similarity=stats.get("min_semantic_similarity", 0),
max_factual_correctness=stats.get("max_factual_correctness", 0),
max_semantic_similarity=stats.get("max_semantic_similarity", 0),
std_factual_correctness=stats.get("std_factual_correctness", 0),
std_semantic_similarity=stats.get("std_semantic_similarity", 0),
),
**evaluation_task.model_dump(),
)
================================================
FILE: backend/app/api/admin_routes/evaluation/models.py
================================================
from typing import Optional
from uuid import UUID
from datetime import datetime
from fastapi_pagination import Params
from pydantic import BaseModel
class CreateEvaluationTask(BaseModel):
name: str
evaluation_dataset_id: int
chat_engine: str = "default"
run_size: Optional[int] = None
class EvaluationTaskOverview(BaseModel):
not_start: int
succeed: int
errored: int
progressing: int
cancel: int
avg_factual_correctness: Optional[float]
avg_semantic_similarity: Optional[float]
min_factual_correctness: Optional[float]
min_semantic_similarity: Optional[float]
max_factual_correctness: Optional[float]
max_semantic_similarity: Optional[float]
std_factual_correctness: Optional[float]
std_semantic_similarity: Optional[float]
class EvaluationTaskSummary(BaseModel):
id: Optional[int]
name: str
user_id: UUID
dataset_id: int
created_at: Optional[datetime]
updated_at: Optional[datetime]
summary: EvaluationTaskOverview
class UpdateEvaluationDataset(BaseModel):
name: str
class CreateEvaluationDataset(BaseModel):
name: str
upload_id: Optional[int] = None
class ModifyEvaluationDatasetItem(BaseModel):
query: str
reference: str
retrieved_contexts: list[str]
extra: dict
evaluation_dataset_id: int
class ParamsWithKeyword(Params):
keyword: Optional[str] = None
================================================
FILE: backend/app/api/admin_routes/evaluation/tools.py
================================================
from typing import TypeVar, Type
from fastapi import status, HTTPException
from sqlmodel import SQLModel, Session
T = TypeVar("T", bound=SQLModel)
def must_get(session: Session, model: Type[T], item_id: int) -> T:
item = session.get(model, item_id)
if not item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"{model.__name__} with ID {item_id} not found",
)
return item
def must_get_and_belong(
session: Session, model: Type[T], item_id: int, user_id: int
) -> T:
item = must_get(session, model, item_id)
if not hasattr(item, "user_id"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{model.__name__} does not have a 'user_id' field",
)
if item.user_id != user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"{model.__name__} with ID {item_id} does not belong to user {user_id}",
)
return item
================================================
FILE: backend/app/api/admin_routes/feedback.py
================================================
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, Query
from fastapi_pagination import Params, Page
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.models import AdminFeedbackPublic, FeedbackFilters
from app.models.feedback import FeedbackOrigin
from app.repositories import feedback_repo
router = APIRouter(
prefix="/admin/feedbacks",
tags=["admin/feedback"],
)
@router.get("/")
def list_feedbacks(
session: SessionDep,
user: CurrentSuperuserDep,
filters: Annotated[FeedbackFilters, Query()],
params: Params = Depends(),
) -> Page[AdminFeedbackPublic]:
return feedback_repo.paginate(
session=session,
filters=filters,
params=params,
)
@router.get("/origins")
def list_feedback_origins(
session: SessionDep,
user: CurrentSuperuserDep,
search: Optional[str] = None,
params: Params = Depends(),
) -> Page[FeedbackOrigin]:
return feedback_repo.list_feedback_origins(session, search, params)
================================================
FILE: backend/app/api/admin_routes/knowledge_base/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/knowledge_base/chunk/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/knowledge_base/chunk/models.py
================================================
from pydantic import BaseModel
from app.rag.retrievers.chunk.schema import VectorSearchRetrieverConfig
class KBChunkRetrievalConfig(BaseModel):
vector_search: VectorSearchRetrieverConfig
# TODO: add fulltext and knowledge graph search config
class KBRetrieveChunksRequest(BaseModel):
query: str
retrieval_config: KBChunkRetrievalConfig
================================================
FILE: backend/app/api/admin_routes/knowledge_base/chunk/routes.py
================================================
import logging
from fastapi import APIRouter
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.rag.retrievers.chunk.simple_retriever import (
ChunkSimpleRetriever,
)
from app.rag.retrievers.chunk.schema import ChunksRetrievalResult
from app.exceptions import InternalServerError, KBNotFound
from .models import KBRetrieveChunksRequest
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/admin/knowledge_base/{kb_id}/chunks/retrieve")
def retrieve_chunks(
db_session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
request: KBRetrieveChunksRequest,
) -> ChunksRetrievalResult:
try:
vector_search_config = request.retrieval_config.vector_search
retriever = ChunkSimpleRetriever(
db_session=db_session,
knowledge_base_id=kb_id,
config=vector_search_config,
)
return retriever.retrieve_chunks(
request.query,
)
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
================================================
FILE: backend/app/api/admin_routes/knowledge_base/data_source/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/knowledge_base/data_source/models.py
================================================
from pydantic import BaseModel, field_validator
from app.models import DataSourceType
class KBDataSource(BaseModel):
"""
Represents a linked data source for a knowledge base.
"""
id: int
name: str
data_source_type: DataSourceType
config: dict | list
class KBDataSourceMutable(BaseModel):
name: str
@field_validator("name")
def name_must_not_be_blank(cls, v: str) -> str:
if not v.strip():
raise ValueError("Please provide a name for the data source")
return v
class KBDataSourceCreate(KBDataSourceMutable):
data_source_type: DataSourceType
config: dict | list
class KBDataSourceUpdate(KBDataSourceMutable):
pass
================================================
FILE: backend/app/api/admin_routes/knowledge_base/data_source/routes.py
================================================
import logging
from fastapi import APIRouter, Depends
from fastapi_pagination import Params, Page
from app.api.admin_routes.knowledge_base.data_source.models import (
KBDataSourceUpdate,
KBDataSource,
)
from app.api.admin_routes.knowledge_base.models import KBDataSourceCreate
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.exceptions import InternalServerError, KBDataSourceNotFound, KBNotFound
from app.models import DataSource
from app.repositories import knowledge_base_repo
from app.tasks.knowledge_base import (
import_documents_from_kb_datasource,
purge_kb_datasource_related_resources,
)
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/admin/knowledge_bases/{kb_id}/datasources")
def create_kb_datasource(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
create: KBDataSourceCreate,
) -> KBDataSource:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
new_data_source = DataSource(
name=create.name,
description="",
data_source_type=create.data_source_type,
config=create.config,
)
new_data_source = knowledge_base_repo.add_kb_datasource(
session, kb, new_data_source
)
import_documents_from_kb_datasource.delay(kb_id, new_data_source.id)
return new_data_source
except KBNotFound as e:
raise e
except Exception as e:
logger.error(
f"Failed to create data source for knowledge base #{kb_id}: {e}", exc_info=e
)
raise InternalServerError()
@router.put("/admin/knowledge_bases/{kb_id}/datasources/{data_source_id}")
def update_kb_datasource(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
data_source_id: int,
update: KBDataSourceUpdate,
) -> KBDataSource:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
data_source = kb.must_get_data_source_by_id(data_source_id)
data_source.name = update.name
session.add(data_source)
session.commit()
session.refresh(data_source)
return data_source
except KBNotFound as e:
raise e
except KBDataSourceNotFound as e:
raise e
except Exception as e:
logger.error(f"Failed to update data source #{data_source_id}: {e}", exc_info=e)
raise InternalServerError()
@router.get("/admin/knowledge_bases/{kb_id}/datasources/{data_source_id}")
def get_kb_datasource(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
data_source_id: int,
) -> KBDataSource:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
return kb.must_get_data_source_by_id(data_source_id)
except KBNotFound as e:
raise e
except KBDataSourceNotFound as e:
raise e
except Exception as e:
logger.error(f"Failed to get data source #{data_source_id}: {e}", exc_info=e)
raise InternalServerError()
@router.get("/admin/knowledge_bases/{kb_id}/datasources")
def list_kb_datasources(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
params: Params = Depends(),
) -> Page[KBDataSource]:
return knowledge_base_repo.list_kb_datasources(session, kb_id, params)
@router.delete("/admin/knowledge_bases/{kb_id}/datasources/{data_source_id}")
def remove_kb_datasource(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
data_source_id: int,
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
data_source = kb.must_get_data_source_by_id(data_source_id)
# Flag the data source to be deleted, it will be deleted completely by the background job.
knowledge_base_repo.remove_kb_datasource(session, kb, data_source)
session.commit()
purge_kb_datasource_related_resources.apply_async(
args=[kb_id, data_source_id], countdown=5
)
return {"detail": "success"}
except KBNotFound as e:
raise e
except KBDataSourceNotFound as e:
raise e
except Exception as e:
logger.error(
f"Failed to remove data source #{data_source_id} from knowledge base #{kb_id}: {e}",
exc_info=e,
)
raise InternalServerError()
================================================
FILE: backend/app/api/admin_routes/knowledge_base/document/models.py
================================================
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field
from app.api.admin_routes.models import DataSourceDescriptor, KnowledgeBaseDescriptor
from app.models import DocIndexTaskStatus
from app.types import MimeTypes
class DocumentFilters(BaseModel):
search: Optional[str] = Field(
description="The search string to filter documents by name or source URI.",
default=None,
)
knowledge_base_id: Optional[int] = Field(
description="The knowledge base ID that the document belongs to.",
default=None,
)
data_source_id: Optional[int] = Field(
description="The data source ID that the document belongs to.",
default=None,
)
mime_type: Optional[MimeTypes] = Field(
description="The MIME type of the documents to filter by.",
default=None,
)
index_status: Optional[DocIndexTaskStatus] = Field(
description="The status of the document index task to filter by.",
default=None,
)
created_at: Optional[tuple[datetime, datetime]] = Field(
description="The time range when the document was created.",
default=None,
)
updated_at: Optional[tuple[datetime, datetime]] = Field(
description="The time range when the document was last updated.",
default=None,
)
last_modified_at: Optional[tuple[datetime, datetime]] = Field(
description="The time range when the document was last modified in the source system.",
default=None,
)
class DocumentItem(BaseModel):
id: int
hash: str
name: str
content: str
mime_type: MimeTypes | None
source_uri: str | None
meta: dict | list | None
index_status: DocIndexTaskStatus | None
index_result: str | None
data_source: DataSourceDescriptor | None
knowledge_base: KnowledgeBaseDescriptor | None
last_modified_at: datetime
created_at: datetime
updated_at: datetime
class RebuildIndexResult(BaseModel):
reindex_document_ids: list[int] = Field(default_factory=list)
ignore_document_ids: list[int] = Field(default_factory=list)
reindex_chunk_ids: list[UUID] = Field(default_factory=list)
ignore_chunk_ids: list[UUID] = Field(default_factory=list)
================================================
FILE: backend/app/api/admin_routes/knowledge_base/document/routes.py
================================================
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi_pagination import Params, Page
from sqlmodel import Session
from app.api.admin_routes.knowledge_base.models import ChunkItem
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.models import Document
from app.models.chunk import KgIndexStatus, get_kb_chunk_model
from app.models.document import DocIndexTaskStatus
from app.models.entity import get_kb_entity_model
from app.models.relationship import get_kb_relationship_model
from app.repositories import knowledge_base_repo, document_repo
from app.repositories.chunk import ChunkRepo
from app.api.admin_routes.knowledge_base.document.models import (
DocumentFilters,
DocumentItem,
RebuildIndexResult,
)
from app.exceptions import InternalServerError
from app.repositories.graph import GraphRepo
from app.tasks.build_index import build_index_for_document, build_kg_index_for_chunk
from app.tasks.knowledge_base import stats_for_knowledge_base
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/admin/knowledge_bases/{kb_id}/documents")
def list_kb_documents(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
filters: Annotated[DocumentFilters, Query()],
params: Params = Depends(),
) -> Page[DocumentItem]:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
filters.knowledge_base_id = kb.id
return document_repo.paginate(
session=session,
filters=filters,
params=params,
)
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.get("/admin/knowledge_bases/{kb_id}/documents/{doc_id}")
def get_kb_document_by_id(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
doc_id: int,
) -> Document:
try:
document = document_repo.must_get(session, doc_id)
assert document.knowledge_base_id == kb_id
return document
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.get("/admin/knowledge_bases/{kb_id}/documents/{doc_id}/chunks")
def list_kb_document_chunks(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
doc_id: int,
) -> list[ChunkItem]:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
chunk_repo = ChunkRepo(get_kb_chunk_model(kb))
return chunk_repo.get_document_chunks(session, doc_id)
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.delete("/admin/knowledge_bases/{kb_id}/documents/{document_id}")
def remove_kb_document(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
document_id: int,
) -> RebuildIndexResult:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
doc = document_repo.must_get(session, document_id)
assert doc.knowledge_base_id == kb.id
chunk_model = get_kb_chunk_model(kb)
entity_model = get_kb_entity_model(kb)
relationship_model = get_kb_relationship_model(kb)
chunk_repo = ChunkRepo(chunk_model)
graph_repo = GraphRepo(entity_model, relationship_model, chunk_model)
graph_repo.delete_document_relationships(session, document_id)
logger.info(
f"Deleted relationships generated by document #{document_id} successfully."
)
graph_repo.delete_orphaned_entities(session)
logger.info("Deleted orphaned entities successfully.")
chunk_repo.delete_by_document(session, document_id)
logger.info(f"Deleted chunks of document #{document_id} successfully.")
session.delete(doc)
session.commit()
stats_for_knowledge_base.delay(kb_id)
return {"detail": "success"}
except HTTPException:
raise
except Exception as e:
logger.exception(f"Failed to remove document #{document_id}: {e}")
raise InternalServerError()
@router.post("/admin/knowledge_bases/{kb_id}/documents/reindex")
def rebuild_kb_documents_index(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
document_ids: list[int],
reindex_completed_task: bool = False,
):
try:
return rebuild_kb_document_index_by_ids(
session, kb_id, document_ids, reindex_completed_task
)
except HTTPException:
raise
except Exception as e:
logger.exception(e, exc_info=True)
raise InternalServerError()
@router.post("/admin/knowledge_bases/{kb_id}/documents/{doc_id}/reindex")
def rebuild_kb_document_index(
db_session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
doc_id: int,
reindex_completed_task: bool = False,
) -> RebuildIndexResult:
try:
document_ids = [doc_id]
return rebuild_kb_document_index_by_ids(
db_session, kb_id, document_ids, reindex_completed_task
)
except HTTPException:
raise
except Exception as e:
logger.exception(e, exc_info=True)
raise InternalServerError()
def rebuild_kb_document_index_by_ids(
db_session: Session,
kb_id: int,
document_ids: list[int],
reindex_completed_task: bool = False,
) -> RebuildIndexResult:
kb = knowledge_base_repo.must_get(db_session, kb_id)
kb_chunk_repo = ChunkRepo(get_kb_chunk_model(kb))
# Retry failed vector index tasks.
documents = document_repo.fetch_by_ids(db_session, document_ids)
reindex_document_ids = []
ignore_document_ids = []
for doc in documents:
# TODO: check NOT_STARTED, PENDING, RUNNING
if doc.index_status != DocIndexTaskStatus.FAILED and not reindex_completed_task:
ignore_document_ids.append(doc.id)
else:
reindex_document_ids.append(doc.id)
doc.index_status = DocIndexTaskStatus.PENDING
db_session.add(doc)
db_session.commit()
build_index_for_document.delay(kb.id, doc.id)
# Retry failed kg index tasks.
chunks = kb_chunk_repo.fetch_by_document_ids(db_session, document_ids)
reindex_chunk_ids = []
ignore_chunk_ids = []
for chunk in chunks:
if chunk.index_status == KgIndexStatus.COMPLETED and not reindex_completed_task:
ignore_chunk_ids.append(chunk.id)
continue
else:
reindex_chunk_ids.append(chunk.id)
chunk.index_status = KgIndexStatus.PENDING
db_session.add(chunk)
db_session.commit()
build_kg_index_for_chunk.delay(kb.id, chunk.id)
return RebuildIndexResult(
reindex_document_ids=reindex_document_ids,
ignore_document_ids=ignore_document_ids,
reindex_chunk_ids=reindex_chunk_ids,
ignore_chunk_ids=ignore_chunk_ids,
)
================================================
FILE: backend/app/api/admin_routes/knowledge_base/graph/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/knowledge_base/graph/knowledge/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/knowledge_base/graph/knowledge/routes.py
================================================
from fastapi import HTTPException
from starlette import status
from app.api.admin_routes.knowledge_base.graph.models import (
KnowledgeRequest,
KnowledgeNeighborRequest,
KnowledgeChunkRequest,
)
from app.api.admin_routes.knowledge_base.graph.routes import router, logger
from app.api.deps import SessionDep
from app.exceptions import KBNotFound, InternalServerError
from app.rag.knowledge_base.index_store import get_kb_tidb_graph_store
from app.repositories import knowledge_base_repo
# Experimental interface
@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge")
def retrieve_knowledge(session: SessionDep, kb_id: int, request: KnowledgeRequest):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
data = graph_store.retrieve_graph_data(
request.query,
request.top_k,
request.similarity_threshold,
)
return {
"entities": data["entities"],
"relationships": data["relationships"],
}
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge/neighbors")
def retrieve_knowledge_neighbors(
session: SessionDep, kb_id: int, request: KnowledgeNeighborRequest
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
data = graph_store.retrieve_neighbors(
request.entities_ids,
request.query,
request.max_depth,
request.max_neighbors,
request.similarity_threshold,
)
return data
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge/chunks")
def retrieve_knowledge_chunks(
session: SessionDep, kb_id: int, request: KnowledgeChunkRequest
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
data = graph_store.get_chunks_by_relationships(request.relationships_ids)
if not data:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No chunks found for the given relationships",
)
return data
except KBNotFound as e:
raise e
except HTTPException as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
================================================
FILE: backend/app/api/admin_routes/knowledge_base/graph/models.py
================================================
from typing import List, Optional
from pydantic import BaseModel, model_validator
from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrieverConfig,
)
class SynopsisEntityCreate(BaseModel):
name: str
description: str
topic: str
meta: dict
entities: List[int]
@model_validator(mode="after")
def validate_entities(self):
if len(self.entities) == 0:
raise ValueError("Entities list should not be empty")
return self
class EntityUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
meta: Optional[dict] = None
class RelationshipUpdate(BaseModel):
description: Optional[str] = None
meta: Optional[dict] = None
weight: Optional[int] = None
class GraphSearchRequest(BaseModel):
query: str
include_meta: bool = True
depth: int = 2
with_degree: bool = True
relationship_meta_filters: dict = {}
# Knowledge Graph Retrieval
class KBKnowledgeGraphRetrievalConfig(BaseModel):
knowledge_graph: KnowledgeGraphRetrieverConfig
class KBRetrieveKnowledgeGraphRequest(BaseModel):
query: str
llm_id: int
retrieval_config: KBKnowledgeGraphRetrievalConfig
### Experimental
class KnowledgeRequest(BaseModel):
query: str
similarity_threshold: float = 0.55
top_k: int = 10
class KnowledgeNeighborRequest(BaseModel):
entities_ids: List[int]
query: str
max_depth: int = 1
max_neighbors: int = 20
similarity_threshold: float = 0.55
class KnowledgeChunkRequest(BaseModel):
relationships_ids: List[int]
================================================
FILE: backend/app/api/admin_routes/knowledge_base/graph/routes.py
================================================
import logging
from typing import List
import json
from fastapi import APIRouter, HTTPException, status
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from app.api.admin_routes.knowledge_base.graph.models import (
SynopsisEntityCreate,
EntityUpdate,
RelationshipUpdate,
KBRetrieveKnowledgeGraphRequest,
GraphSearchRequest,
)
from app.api.deps import SessionDep
from app.exceptions import KBNotFound, InternalServerError
from app.models import (
EntityPublic,
RelationshipPublic,
)
from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrievalResult,
)
from app.rag.knowledge_base.index_store import (
get_kb_tidb_graph_editor,
get_kb_tidb_graph_store,
)
from app.rag.retrievers.knowledge_graph.simple_retriever import (
KnowledgeGraphSimpleRetriever,
)
from app.repositories import knowledge_base_repo
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get(
"/admin/knowledge_bases/{kb_id}/graph/entities/search",
response_model=List[EntityPublic],
)
def search_similar_entities(
session: SessionDep, kb_id: int, query: str, top_k: int = 10
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
tidb_graph_editor = get_kb_tidb_graph_editor(session, kb)
return tidb_graph_editor.search_similar_entities(session, query, top_k)
except KBNotFound as e:
raise e
except Exception as e:
# TODO: throw InternalServerError
raise e
@router.post(
"/admin/knowledge_bases/{kb_id}/graph/entities/synopsis",
response_model=EntityPublic,
)
def create_synopsis_entity(
session: SessionDep, kb_id: int, request: SynopsisEntityCreate
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
tidb_graph_editor = get_kb_tidb_graph_editor(session, kb)
return tidb_graph_editor.create_synopsis_entity(
session,
request.name,
request.description,
request.topic,
request.meta,
request.entities,
)
except KBNotFound as e:
raise e
except Exception as e:
# TODO: throw InternalServerError
raise e
@router.get(
"/admin/knowledge_bases/{kb_id}/graph/entities/{entity_id}",
response_model=EntityPublic,
)
def get_entity(session: SessionDep, kb_id: int, entity_id: int):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
tidb_graph_editor = get_kb_tidb_graph_editor(session, kb)
entity = tidb_graph_editor.get_entity(session, entity_id)
if not entity:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Entity not found",
)
return entity
except KBNotFound as e:
raise e
except Exception as e:
# TODO: throw InternalServerError
raise e
@router.put(
"/admin/knowledge_bases/{kb_id}/graph/entities/{entity_id}",
response_model=EntityPublic,
)
def update_entity(
session: SessionDep, kb_id: int, entity_id: int, entity_update: EntityUpdate
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
tidb_graph_editor = get_kb_tidb_graph_editor(session, kb)
old_entity = tidb_graph_editor.get_entity(session, entity_id)
if old_entity is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Entity not found",
)
entity = tidb_graph_editor.update_entity(
session, old_entity, entity_update.model_dump()
)
return entity
except KBNotFound as e:
raise e
except Exception as e:
# TODO: throw InternalServerError
raise e
@router.get("/admin/knowledge_bases/{kb_id}/graph/entities/{entity_id}/subgraph")
def get_entity_subgraph(session: SessionDep, kb_id: int, entity_id: int) -> dict:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
tidb_graph_editor = get_kb_tidb_graph_editor(session, kb)
entity = tidb_graph_editor.get_entity(session, entity_id)
if entity is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Entity not found",
)
relationships, entities = tidb_graph_editor.get_entity_subgraph(session, entity)
return {
"relationships": relationships,
"entities": entities,
}
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.get(
"/admin/knowledge_bases/{kb_id}/graph/relationships/{relationship_id}",
response_model=RelationshipPublic,
)
def get_relationship(session: SessionDep, kb_id: int, relationship_id: int):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
tidb_graph_editor = get_kb_tidb_graph_editor(session, kb)
relationship = tidb_graph_editor.get_relationship(session, relationship_id)
if relationship is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Relationship not found",
)
return relationship
except KBNotFound as e:
raise e
except Exception as e:
# TODO: throw InternalServerError
raise e
@router.put(
"/admin/knowledge_bases/{kb_id}/graph/relationships/{relationship_id}",
response_model=RelationshipPublic,
)
def update_relationship(
session: SessionDep,
kb_id: int,
relationship_id: int,
relationship_update: RelationshipUpdate,
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
tidb_graph_editor = get_kb_tidb_graph_editor(session, kb)
old_relationship = tidb_graph_editor.get_relationship(session, relationship_id)
if old_relationship is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Relationship not found",
)
relationship = tidb_graph_editor.update_relationship(
session, old_relationship, relationship_update.model_dump()
)
return relationship
except KBNotFound as e:
raise e
except Exception as e:
# TODO: throw InternalServerError
raise e
@router.post("/admin/knowledge_bases/{kb_id}/graph/retrieve")
def retrieve_kb_knowledge_graph(
db_session: SessionDep, kb_id: int, request: KBRetrieveKnowledgeGraphRequest
) -> KnowledgeGraphRetrievalResult:
try:
retriever = KnowledgeGraphSimpleRetriever(
db_session=db_session,
knowledge_base_id=kb_id,
config=request.retrieval_config.knowledge_graph,
)
knowledge_graph = retriever.retrieve_knowledge_graph(request.query)
return KnowledgeGraphRetrievalResult(
entities=knowledge_graph.entities,
relationships=knowledge_graph.relationships,
)
except KBNotFound as e:
raise e
except Exception as e:
# TODO: throw InternalServerError
raise e
@router.post("/admin/knowledge_bases/{kb_id}/graph/search", deprecated=True)
def legacy_search_graph(session: SessionDep, kb_id: int, request: GraphSearchRequest):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
entities, relationships = graph_store.retrieve_with_weight(
request.query,
[],
request.depth,
request.include_meta,
request.with_degree,
request.relationship_meta_filters,
)
return {
"entities": entities,
"relationships": relationships,
}
except KBNotFound as e:
raise e
except Exception as e:
# TODO: throw InternalServerError
raise e
@router.post("/admin/knowledge_bases/{kb_id}/graph/entire_graph")
def get_entire_knowledge_graph(session: SessionDep, kb_id: int):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
retrieved_kg = graph_store.get_entire_knowledge_graph()
return {
"entities": retrieved_kg.entities,
"relationships": retrieved_kg.relationships,
}
except KBNotFound as e:
raise e
except Exception as e:
# TODO: throw InternalServerError
raise e
@router.get("/admin/knowledge_bases/{kb_id}/graph/entire_graph/stream")
def stream_entire_knowledge_graph(session: SessionDep, kb_id: int):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
def generate():
for chunk in graph_store.stream_entire_knowledge_graph(chunk_size=5000):
yield f"data: {json.dumps(jsonable_encoder(chunk))}\n\n"
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
}
)
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
================================================
FILE: backend/app/api/admin_routes/knowledge_base/models.py
================================================
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, field_validator, Field
from app.api.admin_routes.knowledge_base.data_source.models import (
KBDataSource,
KBDataSourceCreate,
)
from app.api.admin_routes.models import (
EmbeddingModelDescriptor,
LLMDescriptor,
UserDescriptor,
)
from app.exceptions import KBNoVectorIndexConfigured
from app.models import KgIndexStatus
from app.models.knowledge_base import IndexMethod, GeneralChunkingConfig, ChunkingConfig
class KnowledgeBaseCreate(BaseModel):
name: str
description: Optional[str] = None
index_methods: list[IndexMethod] = Field(
default_factory=lambda: [IndexMethod.VECTOR]
)
llm_id: Optional[int] = None
embedding_model_id: Optional[int] = None
chunking_config: ChunkingConfig = Field(default_factory=GeneralChunkingConfig)
data_sources: list[KBDataSourceCreate] = Field(default_factory=list)
@field_validator("name")
def name_must_not_be_blank(cls, v: str) -> str:
if not v.strip():
raise ValueError("Please provide a name for the knowledge base")
return v
@field_validator("index_methods")
def index_methods_must_has_vector(cls, v: list[IndexMethod]) -> list[IndexMethod]:
# Notice: For now, knowledge base must be configured vector index method,
# we will remove this limit in the feature.
if IndexMethod.VECTOR not in v:
raise KBNoVectorIndexConfigured()
return v
class KnowledgeBaseUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
chunking_config: Optional[ChunkingConfig] = None
class KnowledgeBaseDetail(BaseModel):
"""
Represents a detailed view of a knowledge base.
"""
id: int
name: str
description: Optional[str] = None
documents_total: int
data_sources_total: int
# Notice: By default, SQLModel will not serialize list type relationships.
# https://github.com/fastapi/sqlmodel/issues/37#issuecomment-2093607242
data_sources: list[KBDataSource]
chunking_config: Optional[ChunkingConfig] = None
index_methods: list[IndexMethod]
llm_id: int | None = None
llm: LLMDescriptor | None = None
embedding_model_id: int | None = None
embedding_model: EmbeddingModelDescriptor | None = None
creator: UserDescriptor | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
class KnowledgeBaseItem(BaseModel):
"""
Represents a simplified view of a knowledge base for list display purposes.
"""
id: int
name: str
description: Optional[str] = None
documents_total: int
data_sources_total: int
index_methods: list[IndexMethod]
creator: UserDescriptor | None = None
created_at: datetime
updated_at: datetime
class VectorIndexError(BaseModel):
document_id: int
document_name: str
source_uri: str
error: str | None = None
class KGIndexError(BaseModel):
document_id: int
document_name: str
source_uri: str
chunk_id: UUID
error: str | None = None
class ChunkItem(BaseModel):
id: UUID
document_id: int
hash: str
text: str
meta: Optional[dict | list]
embedding: Optional[list[float]]
relations: Optional[dict | list]
source_uri: Optional[str]
index_status: Optional[KgIndexStatus]
index_result: Optional[str]
created_at: Optional[datetime]
updated_at: Optional[datetime]
class RetrievalRequest(BaseModel):
query: str
chat_engine: str = "default"
top_k: Optional[int] = 5
================================================
FILE: backend/app/api/admin_routes/knowledge_base/routes.py
================================================
import logging
from fastapi import APIRouter, Depends, HTTPException
from fastapi_pagination import Params, Page
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.rag.knowledge_base.index_store import (
init_kb_tidb_vector_store,
init_kb_tidb_graph_store,
)
from .models import (
KnowledgeBaseDetail,
KnowledgeBaseItem,
KnowledgeBaseCreate,
KnowledgeBaseUpdate,
VectorIndexError,
KGIndexError,
)
from app.exceptions import (
InternalServerError,
KBIsUsedByChatEngines,
)
from app.models import (
DataSource,
KnowledgeBase,
)
from app.repositories import (
embedding_model_repo,
llm_repo,
data_source_repo,
knowledge_base_repo,
)
from app.tasks import (
build_kg_index_for_chunk,
build_index_for_document,
)
from app.tasks.knowledge_base import (
import_documents_for_knowledge_base,
stats_for_knowledge_base,
purge_knowledge_base_related_resources,
)
from ..models import ChatEngineDescriptor
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/admin/knowledge_bases")
def create_knowledge_base(
session: SessionDep, user: CurrentSuperuserDep, create: KnowledgeBaseCreate
) -> KnowledgeBaseDetail:
try:
data_sources = [
data_source_repo.create(
session,
DataSource(
name=data_source.name,
description="",
user_id=user.id,
data_source_type=data_source.data_source_type,
config=data_source.config,
),
)
for data_source in create.data_sources
]
if not create.llm_id:
create.llm_id = llm_repo.must_get_default(session).id
if not create.embedding_model_id:
create.embedding_model_id = embedding_model_repo.must_get_default(
session
).id
knowledge_base = KnowledgeBase(
name=create.name,
description=create.description,
index_methods=create.index_methods,
llm_id=create.llm_id,
embedding_model_id=create.embedding_model_id,
chunking_config=create.chunking_config.model_dump(),
data_sources=data_sources,
created_by=user.id,
updated_by=user.id,
)
knowledge_base = knowledge_base_repo.create(session, knowledge_base)
# Ensure the knowledge-base corresponding table schema are initialized.
init_kb_tidb_vector_store(session, knowledge_base)
init_kb_tidb_graph_store(session, knowledge_base)
# Trigger import and index documents for knowledge base
import_documents_for_knowledge_base.delay(knowledge_base.id)
return knowledge_base
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.get("/admin/knowledge_bases")
def list_knowledge_bases(
session: SessionDep,
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> Page[KnowledgeBaseItem]:
return knowledge_base_repo.paginate(session, params)
@router.get("/admin/knowledge_bases/{knowledge_base_id}")
def get_knowledge_base(
session: SessionDep,
user: CurrentSuperuserDep,
knowledge_base_id: int,
) -> KnowledgeBaseDetail:
try:
return knowledge_base_repo.must_get(session, knowledge_base_id)
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.put("/admin/knowledge_bases/{knowledge_base_id}")
def update_knowledge_base_setting(
session: SessionDep,
user: CurrentSuperuserDep,
knowledge_base_id: int,
update: KnowledgeBaseUpdate,
) -> KnowledgeBaseDetail:
try:
knowledge_base = knowledge_base_repo.must_get(session, knowledge_base_id)
knowledge_base = knowledge_base_repo.update(session, knowledge_base, update)
return knowledge_base
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.get("/admin/knowledge_bases/{kb_id}/linked_chat_engines")
def list_kb_linked_chat_engines(
session: SessionDep, user: CurrentSuperuserDep, kb_id: int
) -> list[ChatEngineDescriptor]:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
return knowledge_base_repo.list_linked_chat_engines(session, kb.id)
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.delete("/admin/knowledge_bases/{kb_id}")
def delete_knowledge_base(session: SessionDep, user: CurrentSuperuserDep, kb_id: int):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
# Check if the knowledge base has linked chat engines.
linked_chat_engines = knowledge_base_repo.list_linked_chat_engines(
session, kb.id
)
if len(linked_chat_engines) > 0:
raise KBIsUsedByChatEngines(kb_id, len(linked_chat_engines))
# Delete knowledge base.
knowledge_base_repo.delete(session, kb)
# Trigger purge knowledge base related resources after 5 seconds.
purge_knowledge_base_related_resources.apply_async(args=[kb_id], countdown=5)
return {"detail": f"Knowledge base #{kb_id} is deleted successfully"}
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.get("/admin/knowledge_bases/{knowledge_base_id}/overview")
def get_knowledge_base_index_overview(
session: SessionDep,
user: CurrentSuperuserDep,
knowledge_base_id: int,
) -> dict:
try:
knowledge_base = knowledge_base_repo.must_get(session, knowledge_base_id)
stats_for_knowledge_base.delay(knowledge_base.id)
return knowledge_base_repo.get_index_overview(session, knowledge_base)
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.get("/admin/knowledge_bases/{kb_id}/vector-index-errors")
def list_kb_vector_index_errors(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
params: Params = Depends(),
) -> Page[VectorIndexError]:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
return knowledge_base_repo.list_vector_index_built_errors(session, kb, params)
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.get("/admin/knowledge_bases/{kb_id}/kg-index-errors")
def list_kb_kg_index_errors(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
params: Params = Depends(),
) -> Page[KGIndexError]:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
return knowledge_base_repo.list_kg_index_built_errors(session, kb, params)
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.post("/admin/knowledge_bases/{kb_id}/retry-failed-index-tasks")
def retry_failed_tasks(
session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
) -> dict:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
# Retry failed vector index tasks.
document_ids = knowledge_base_repo.set_failed_documents_status_to_pending(
session, kb
)
for document_id in document_ids:
build_index_for_document.delay(kb_id, document_id)
logger.info(f"Triggered {len(document_ids)} documents to rebuilt vector index.")
# Retry failed kg index tasks.
chunk_ids = knowledge_base_repo.set_failed_chunks_status_to_pending(session, kb)
for chunk_id in chunk_ids:
build_kg_index_for_chunk.delay(kb_id, chunk_id)
logger.info(
f"Triggered {len(chunk_ids)} chunks to rebuilt knowledge graph index."
)
return {
"detail": f"Triggered reindex {len(document_ids)} documents and {len(chunk_ids)} chunks of knowledge base #{kb_id}.",
"reindex_document_ids": document_ids,
"reindex_chunk_ids": chunk_ids,
}
except HTTPException:
raise
except Exception as e:
logger.exception(e)
raise InternalServerError()
================================================
FILE: backend/app/api/admin_routes/langfuse.py
================================================
import logging
from pydantic import BaseModel
from fastapi import APIRouter
from langfuse import Langfuse
from app.api.deps import CurrentSuperuserDep
router = APIRouter()
logger = logging.getLogger(__name__)
class LangfuseSetting(BaseModel):
host: str = "https://us.cloud.langfuse.com"
public_key: str
secret_key: str
class LangfuseTestResult(BaseModel):
success: bool
error: str = ""
@router.post("/admin/langfuse/test")
def test_langfuse(
user: CurrentSuperuserDep,
request: LangfuseSetting,
) -> LangfuseTestResult:
try:
lf = Langfuse(
host=request.host,
secret_key=request.secret_key,
public_key=request.public_key,
)
success = lf.auth_check()
if not success:
error = "Langfuse authentication failed, please check public_key, secret_key and host."
else:
error = ""
except Exception as e:
success = False
error = str(e)
return LangfuseTestResult(success=success, error=error)
================================================
FILE: backend/app/api/admin_routes/legacy_retrieve.py
================================================
import logging
from typing import Optional, List
from fastapi import APIRouter
from sqlmodel import Session
from app.models import Document
from app.api.admin_routes.models import ChatEngineBasedRetrieveRequest
from app.api.deps import SessionDep, CurrentSuperuserDep
from llama_index.core.schema import NodeWithScore
from app.exceptions import InternalServerError, KBNotFound
from app.rag.chat.config import ChatEngineConfig
from app.rag.chat.retrieve.retrieve_flow import RetrieveFlow
router = APIRouter()
logger = logging.getLogger(__name__)
def get_override_engine_config(
db_session: Session,
engine_name: str,
# Override chat engine config.
top_k: Optional[int] = None,
similarity_top_k: Optional[int] = None,
oversampling_factor: Optional[int] = None,
refine_question_with_kg: Optional[bool] = None,
) -> ChatEngineConfig:
engine_config = ChatEngineConfig.load_from_db(db_session, engine_name)
if similarity_top_k is not None:
engine_config.vector_search.similarity_top_k = similarity_top_k
if oversampling_factor is not None:
engine_config.vector_search.oversampling_factor = oversampling_factor
if top_k is not None:
engine_config.vector_search.top_k = top_k
if refine_question_with_kg is not None:
engine_config.refine_question_with_kg = refine_question_with_kg
return engine_config
@router.get("/admin/retrieve/documents", deprecated=True)
def legacy_retrieve_documents(
session: SessionDep,
user: CurrentSuperuserDep,
question: str,
chat_engine: str = "default",
# Override chat engine config.
top_k: Optional[int] = 5,
similarity_top_k: Optional[int] = None,
oversampling_factor: Optional[int] = 5,
refine_question_with_kg: Optional[bool] = True,
) -> List[Document]:
try:
engine_config = get_override_engine_config(
db_session=session,
engine_name=chat_engine,
top_k=top_k,
similarity_top_k=similarity_top_k,
oversampling_factor=oversampling_factor,
refine_question_with_kg=refine_question_with_kg,
)
retriever = RetrieveFlow(
db_session=session,
engine_name=chat_engine,
engine_config=engine_config,
)
return retriever.retrieve_documents(question)
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.get("/admin/embedding_retrieve", deprecated=True)
def legacy_retrieve_chunks(
session: SessionDep,
user: CurrentSuperuserDep,
question: str,
chat_engine: str = "default",
# Override chat engine config.
top_k: Optional[int] = 5,
similarity_top_k: Optional[int] = None,
oversampling_factor: Optional[int] = 5,
refine_question_with_kg=False,
) -> List[NodeWithScore]:
try:
engine_config = get_override_engine_config(
db_session=session,
engine_name=chat_engine,
top_k=top_k,
similarity_top_k=similarity_top_k,
oversampling_factor=oversampling_factor,
refine_question_with_kg=refine_question_with_kg,
)
retriever = RetrieveFlow(
db_session=session,
engine_name=chat_engine,
engine_config=engine_config,
)
return retriever.retrieve(question)
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.post("/admin/embedding_retrieve", deprecated=True)
def legacy_retrieve_chunks_2(
session: SessionDep,
user: CurrentSuperuserDep,
request: ChatEngineBasedRetrieveRequest,
) -> List[NodeWithScore]:
try:
engine_config = get_override_engine_config(
db_session=session,
engine_name=request.chat_engine,
top_k=request.top_k,
similarity_top_k=request.similarity_top_k,
oversampling_factor=request.oversampling_factor,
refine_question_with_kg=request.refine_question_with_kg,
)
retriever = RetrieveFlow(
db_session=session,
engine_name=request.chat_engine,
engine_config=engine_config,
)
return retriever.retrieve(request.query)
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
================================================
FILE: backend/app/api/admin_routes/llm/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/llm/routes.py
================================================
from typing import List
from fastapi import APIRouter, Depends
from fastapi_pagination import Page, Params
from llama_index.core.base.llms.types import ChatMessage
from pydantic import BaseModel
from app.api.deps import CurrentSuperuserDep, SessionDep
from app.logger import logger
from app.models import AdminLLM, LLM, LLMUpdate
from app.rag.llms.provider import LLMProviderOption, llm_provider_options
from app.rag.llms.resolver import resolve_llm
from app.repositories.llm import llm_repo
router = APIRouter()
@router.get("/admin/llms/providers/options")
def list_llm_provider_options(user: CurrentSuperuserDep) -> List[LLMProviderOption]:
return llm_provider_options
@router.get("/admin/llms")
def list_llms(
db_session: SessionDep,
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> Page[AdminLLM]:
return llm_repo.paginate(db_session, params)
class LLMTestResult(BaseModel):
success: bool
error: str = ""
@router.post("/admin/llms/test")
def test_llm(
db_llm: LLM,
user: CurrentSuperuserDep,
) -> LLMTestResult:
try:
llm = resolve_llm(
provider=db_llm.provider,
model=db_llm.model,
config=db_llm.config,
credentials=db_llm.credentials,
)
llm.chat([ChatMessage(role="user", content="Who are you?")])
# Test with dspy LM.
import dspy
from app.rag.llms.dspy import get_dspy_lm_by_llama_llm
dspy_lm = get_dspy_lm_by_llama_llm(llm)
with dspy.context(lm=dspy_lm):
math = dspy.Predict("question -> answer: float")
prediction = math(question="1 + 1 = ?")
assert prediction.answer == 2
success = True
error = ""
except Exception as e:
logger.info(f"Failed to test LLM: {e}")
success = False
error = str(e)
return LLMTestResult(success=success, error=error)
@router.post("/admin/llms")
def create_llm(
db_session: SessionDep,
user: CurrentSuperuserDep,
llm: LLM,
) -> AdminLLM:
return llm_repo.create(db_session, llm)
@router.get("/admin/llms/{llm_id}")
def get_llm(
db_session: SessionDep,
user: CurrentSuperuserDep,
llm_id: int,
) -> AdminLLM:
return llm_repo.must_get(db_session, llm_id)
@router.put("/admin/llms/{llm_id}")
def update_llm(
db_session: SessionDep,
user: CurrentSuperuserDep,
llm_id: int,
llm_update: LLMUpdate,
) -> AdminLLM:
llm = llm_repo.must_get(db_session, llm_id)
return llm_repo.update(db_session, llm, llm_update)
@router.delete("/admin/llms/{llm_id}")
def delete_llm(
db_session: SessionDep,
user: CurrentSuperuserDep,
llm_id: int,
) -> None:
llm = llm_repo.must_get(db_session, llm_id)
llm_repo.delete(db_session, llm)
@router.put("/admin/llms/{llm_id}/set_default")
def set_default_llm(
db_session: SessionDep, user: CurrentSuperuserDep, llm_id: int
) -> AdminLLM:
llm = llm_repo.must_get(db_session, llm_id)
return llm_repo.set_default(db_session, llm)
================================================
FILE: backend/app/api/admin_routes/models.py
================================================
from uuid import UUID
from typing import Optional
from pydantic import BaseModel
from app.api.admin_routes.embedding_model.models import EmbeddingModelItem
from app.rag.llms.provider import LLMProvider
class LLMDescriptor(BaseModel):
id: int
name: str
provider: LLMProvider
model: str
is_default: bool
class EmbeddingModelDescriptor(EmbeddingModelItem):
pass
class UserDescriptor(BaseModel):
id: UUID
email: str
class KnowledgeBaseDescriptor(BaseModel):
id: int
name: str
def __hash__(self):
return hash(self.id)
class DataSourceDescriptor(BaseModel):
id: int
name: str
class ChatEngineDescriptor(BaseModel):
id: int
name: str
is_default: bool
class ChatEngineBasedRetrieveRequest(BaseModel):
query: str
chat_engine: Optional[str] = "default"
top_k: Optional[int] = 5
similarity_top_k: Optional[int] = None
oversampling_factor: Optional[int] = 5
refine_question_with_kg: Optional[bool] = False
================================================
FILE: backend/app/api/admin_routes/reranker_model/__init__.py
================================================
================================================
FILE: backend/app/api/admin_routes/reranker_model/routes.py
================================================
from typing import List
from fastapi import Depends, APIRouter
from fastapi_pagination import Params, Page
from pydantic import BaseModel
from llama_index.core.schema import NodeWithScore, TextNode
from app.api.admin_routes.llm.routes import LLMTestResult
from app.api.deps import CurrentSuperuserDep, SessionDep
from app.models import RerankerModel, AdminRerankerModel
from app.models.reranker_model import RerankerModelUpdate
from app.repositories.reranker_model import reranker_model_repo
from app.rag.rerankers.provider import RerankerProviderOption, reranker_provider_options
from app.rag.rerankers.resolver import resolve_reranker
from app.logger import logger
router = APIRouter()
@router.get("/admin/reranker-models/providers/options")
def list_reranker_model_provider_options(
user: CurrentSuperuserDep,
) -> List[RerankerProviderOption]:
return reranker_provider_options
@router.get("/admin/reranker-models")
def list_reranker_models(
db_session: SessionDep,
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> Page[AdminRerankerModel]:
return reranker_model_repo.paginate(db_session, params)
class RerankerModelTestResult(BaseModel):
success: bool
error: str = ""
@router.post("/admin/reranker-models/test")
def test_reranker_model(
db_reranker_model: RerankerModel, user: CurrentSuperuserDep
) -> LLMTestResult:
try:
reranker = resolve_reranker(
provider=db_reranker_model.provider,
model=db_reranker_model.model,
# for testing purpose, we only rerank 2 nodes
top_n=2,
config=db_reranker_model.config,
credentials=db_reranker_model.credentials,
)
reranked_nodes = reranker.postprocess_nodes(
nodes=[
NodeWithScore(
node=TextNode(
text="TiDB is a distributed SQL database.",
),
score=0.8,
),
NodeWithScore(
node=TextNode(
text="TiKV is a distributed key-value storage engine.",
),
score=0.6,
),
NodeWithScore(
node=TextNode(
text="TiFlash is a columnar storage engine.",
),
score=0.4,
),
],
query_str="What is TiDB?",
)
if len(reranked_nodes) != 2:
raise ValueError("expected 2 nodes, but got %d", len(reranked_nodes))
success = True
error = ""
except Exception as e:
logger.info(f"Failed to test reranker model: {e}")
success = False
error = str(e)
return RerankerModelTestResult(success=success, error=error)
@router.post("/admin/reranker-models")
def create_reranker_model(
db_session: SessionDep,
user: CurrentSuperuserDep,
reranker_model: RerankerModel,
) -> AdminRerankerModel:
return reranker_model_repo.create(db_session, reranker_model)
@router.get("/admin/reranker-models/{model_id}")
def get_reranker_model(
db_session: SessionDep,
user: CurrentSuperuserDep,
model_id: int,
) -> AdminRerankerModel:
return reranker_model_repo.must_get(db_session, model_id)
@router.put("/admin/reranker-models/{model_id}")
def update_reranker_model(
db_session: SessionDep,
user: CurrentSuperuserDep,
model_id: int,
model_update: RerankerModelUpdate,
) -> AdminRerankerModel:
reranker_model = reranker_model_repo.must_get(db_session, model_id)
return reranker_model_repo.update(db_session, reranker_model, model_update)
@router.delete("/admin/reranker-models/{model_id}")
def delete_reranker_model(
db_session: SessionDep,
user: CurrentSuperuserDep,
model_id: int,
) -> None:
reranker_model = reranker_model_repo.must_get(db_session, model_id)
reranker_model_repo.delete(db_session, reranker_model)
@router.put("/admin/reranker-models/{model_id}/set_default")
def set_default_reranker_model(
db_session: SessionDep, user: CurrentSuperuserDep, model_id: int
) -> AdminRerankerModel:
reranker_model = reranker_model_repo.must_get(db_session, model_id)
return reranker_model_repo.set_default(db_session, reranker_model)
================================================
FILE: backend/app/api/admin_routes/semantic_cache.py
================================================
from typing import Optional, Dict
import time
import logging
from fastapi import APIRouter, Body
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.rag.chat.config import ChatEngineConfig
from app.rag.semantic_cache import SemanticCacheManager, SemanticItem
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/admin/semantic_cache")
async def add_semantic_cache(
session: SessionDep,
user: CurrentSuperuserDep,
question: str,
answer: str,
namespace: str = "default",
chat_engine: str = "default",
metadata: Optional[dict] = Body(None),
) -> Dict:
chat_engine_config = ChatEngineConfig.load_from_db(session, chat_engine)
_dspy_lm = chat_engine_config.get_dspy_lm(session)
scm = SemanticCacheManager(
dspy_llm=_dspy_lm,
)
try:
scm.add_cache(
session,
item=SemanticItem(question=question, answer=answer),
namespace=namespace,
metadata=metadata,
)
except Exception as e:
return {
"status": "failed",
"message": str(e),
}
return {
"status": "success",
}
@router.get("/admin/semantic_cache")
async def search_semantic_cache(
session: SessionDep,
user: CurrentSuperuserDep,
query: str,
namespace: str = "default",
chat_engine: str = "default",
) -> Dict:
start_time = time.time()
chat_engine_config = ChatEngineConfig.load_from_db(session, chat_engine)
_dspy_lm = chat_engine_config.get_dspy_lm(session)
logger.debug(
f"[search_semantic_cache] Loading dspy_lm took {time.time() - start_time:.2f} seconds"
)
scm = SemanticCacheManager(
dspy_llm=_dspy_lm,
)
start_time = time.time()
response = scm.search(
session=session,
query=query,
namespace=namespace,
)
logger.debug(
f"[search_semantic_cache] Searching semantic cache took {time.time() - start_time:.2f} seconds"
)
return response
================================================
FILE: backend/app/api/admin_routes/site_setting.py
================================================
from typing import Dict
from pydantic import BaseModel
from http import HTTPStatus
from fastapi import APIRouter, HTTPException
from app.api.deps import CurrentSuperuserDep, SessionDep
from app.site_settings import SiteSetting, SettingValue, SettingType
router = APIRouter()
@router.get("/admin/site-settings", response_model=Dict[str, SettingValue])
def site_settings(user: CurrentSuperuserDep):
return SiteSetting.get_all_settings(force_check_db_cache=True)
class SettingUpdate(BaseModel):
value: SettingType
@router.put(
"/admin/site-settings/{setting_name}",
status_code=HTTPStatus.NO_CONTENT,
responses={
HTTPStatus.BAD_REQUEST: {
"content": {
"application/json": {
"examples": {
"invalid_data_type": {
"summary": "Invalid data type",
"value": {"detail": "title must be of type `str`"},
},
}
}
},
},
HTTPStatus.NOT_FOUND: {
"content": {
"application/json": {
"examples": {
"setting_not_found": {
"summary": "Setting not found",
"value": {"detail": "Setting not found"},
},
}
}
},
},
},
)
def update_site_setting(
session: SessionDep,
user: CurrentSuperuserDep,
setting_name: str,
request: SettingUpdate,
):
if not SiteSetting.setting_exists(setting_name):
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Setting not found"
)
try:
SiteSetting.update_setting(session, setting_name, request.value)
except ValueError as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
================================================
FILE: backend/app/api/admin_routes/stats.py
================================================
from datetime import date
from pydantic import BaseModel
from fastapi import APIRouter
from app.api.deps import CurrentSuperuserDep, SessionDep
from app.repositories import chat_repo
router = APIRouter()
class DateRangeStats(BaseModel):
start_date: date
end_date: date
class ChatStats(DateRangeStats):
values: list
@router.get("/admin/stats/trend/chat-user")
def chat_count_trend(
session: SessionDep, user: CurrentSuperuserDep, start_date: date, end_date: date
) -> ChatStats:
stats = chat_repo.chat_trend_by_user(session, start_date, end_date)
return ChatStats(start_date=start_date, end_date=end_date, values=stats)
@router.get("/admin/stats/trend/chat-origin")
def chat_origin_trend(
session: SessionDep, user: CurrentSuperuserDep, start_date: date, end_date: date
) -> ChatStats:
stats = chat_repo.chat_trend_by_origin(session, start_date, end_date)
return ChatStats(start_date=start_date, end_date=end_date, values=stats)
================================================
FILE: backend/app/api/admin_routes/upload.py
================================================
import os
import time
from typing import List
from fastapi import APIRouter, UploadFile, HTTPException, status
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.file_storage import default_file_storage
from app.utils.uuid6 import uuid7
from app.models import Upload
from app.types import MimeTypes
from app.site_settings import SiteSetting
router = APIRouter()
SUPPORTED_FILE_TYPES = {
".txt": MimeTypes.PLAIN_TXT,
".md": MimeTypes.MARKDOWN,
".pdf": MimeTypes.PDF,
".docx": MimeTypes.DOCX,
".pptx": MimeTypes.PPTX,
".xlsx": MimeTypes.XLSX,
".csv": MimeTypes.CSV,
}
@router.post("/admin/uploads")
def upload_files(
session: SessionDep, user: CurrentSuperuserDep, files: List[UploadFile]
) -> List[Upload]:
uploads = []
for file in files:
if not file.filename:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="File name cannot be empty",
)
sys_max_upload_file_size = SiteSetting.max_upload_file_size
if file.size > sys_max_upload_file_size:
upload_file_size_in_mb = file.size / 1024 / 1024
max_upload_file_size_in_mb = sys_max_upload_file_size / 1024 / 1024
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail="The upload file size ({:.2f} MiB) exceeds maximum allowed size ({:.2f} MiB)".format(
upload_file_size_in_mb, max_upload_file_size_in_mb
),
)
file_ext = os.path.splitext(file.filename)[1].lower()
if file_ext not in SUPPORTED_FILE_TYPES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File type {file_ext} not supported. Supported types: {SUPPORTED_FILE_TYPES.keys()}",
)
file_path = f"uploads/{user.id.hex}/{int(time.time())}-{uuid7().hex}{file_ext}"
default_file_storage.save(file_path, file.file)
uploads.append(
Upload(
name=file.filename,
size=default_file_storage.size(file_path),
path=file_path,
mime_type=SUPPORTED_FILE_TYPES[file_ext],
user_id=user.id,
)
)
session.add_all(uploads)
session.commit()
return uploads
================================================
FILE: backend/app/api/admin_routes/user.py
================================================
from typing import Optional
from fastapi import APIRouter, Depends
from fastapi_pagination import Page, Params
from app.repositories.user import user_repo
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.api.admin_routes.models import (
UserDescriptor,
)
router = APIRouter(
prefix="/admin/users",
tags=["admin/users"],
)
@router.get("/search")
def search_users(
db_session: SessionDep,
user: CurrentSuperuserDep,
search: Optional[str] = None,
params: Params = Depends(),
) -> Page[UserDescriptor]:
return user_repo.search_users(db_session, search, params)
================================================
FILE: backend/app/api/deps.py
================================================
from typing import Annotated
from fastapi import Depends
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.db import get_db_session, get_db_async_session
from app.models import User
from app.auth.users import (
current_user,
current_superuser,
optional_current_user,
)
SessionDep = Annotated[Session, Depends(get_db_session)]
AsyncSessionDep = Annotated[AsyncSession, Depends(get_db_async_session)]
# Dependency for current user, it will return None if user is not authenticated
OptionalUserDep = Annotated[User | None, Depends(optional_current_user)]
# Dependencies for current user and superuser, it will return 401 if user is not authenticated
CurrentUserDep = Annotated[User, Depends(current_user)]
CurrentSuperuserDep = Annotated[User, Depends(current_superuser)]
================================================
FILE: backend/app/api/main.py
================================================
from fastapi import APIRouter
from app.api.routes import (
chat_engine,
index,
chat,
user,
api_key,
feedback,
document,
)
from app.api.admin_routes.knowledge_base.routes import (
router as admin_knowledge_base_router,
)
from app.api.admin_routes.knowledge_base.graph.routes import (
router as admin_kb_graph_router,
)
from app.api.admin_routes.knowledge_base.graph.knowledge.routes import (
router as admin_kb_graph_knowledge_router,
)
from app.api.admin_routes.knowledge_base.data_source.routes import (
router as admin_kb_data_source_router,
)
from app.api.admin_routes.knowledge_base.document.routes import (
router as admin_kb_document_router,
)
from app.api.admin_routes.knowledge_base.chunk.routes import (
router as admin_kb_chunk_router,
)
from app.api.admin_routes.document.routes import router as admin_document_router
from app.api.admin_routes.llm.routes import router as admin_llm_router
from app.api.admin_routes.embedding_model.routes import (
router as admin_embedding_model_router,
)
from app.api.admin_routes.reranker_model.routes import (
router as admin_reranker_model_router,
)
from app.api.admin_routes.chat.routes import router as admin_user_router
from app.api.admin_routes import (
chat_engine as admin_chat_engine,
feedback as admin_feedback,
legacy_retrieve as admin_legacy_retrieve,
site_setting as admin_site_settings,
upload as admin_upload,
stats as admin_stats,
semantic_cache as admin_semantic_cache,
langfuse as admin_langfuse,
user as admin_user,
)
from app.api.admin_routes.evaluation import (
evaluation_task as admin_evaluation_task,
evaluation_dataset as admin_evaluation_dataset,
)
from app.api.routes.retrieve import (
routes as retrieve_routes,
)
from app.auth.users import auth_backend, fastapi_users
api_router = APIRouter()
api_router.include_router(index.router, tags=["index"])
api_router.include_router(chat.router, tags=["chat"])
api_router.include_router(feedback.router, tags=["chat"])
api_router.include_router(user.router, tags=["user"])
api_router.include_router(api_key.router, tags=["auth"])
api_router.include_router(document.router, tags=["documents"])
api_router.include_router(chat_engine.router, tags=["chat-engines"])
api_router.include_router(retrieve_routes.router, tags=["retrieve"])
api_router.include_router(admin_user_router)
api_router.include_router(admin_chat_engine.router, tags=["admin/chat-engines"])
api_router.include_router(admin_document_router, tags=["admin/documents"])
api_router.include_router(admin_feedback.router)
api_router.include_router(admin_site_settings.router, tags=["admin/site_settings"])
api_router.include_router(admin_upload.router, tags=["admin/upload"])
api_router.include_router(admin_knowledge_base_router, tags=["admin/knowledge_base"])
api_router.include_router(admin_kb_graph_router, tags=["admin/knowledge_base/graph"])
api_router.include_router(
admin_kb_graph_knowledge_router, tags=["admin/knowledge_base/graph/knowledge"]
)
api_router.include_router(
admin_kb_data_source_router, tags=["admin/knowledge_base/data_source"]
)
api_router.include_router(
admin_kb_document_router, tags=["admin/knowledge_base/document"]
)
api_router.include_router(admin_kb_chunk_router, tags=["admin/knowledge_base/chunk"])
api_router.include_router(admin_llm_router, tags=["admin/llm"])
api_router.include_router(admin_embedding_model_router, tags=["admin/embedding_model"])
api_router.include_router(admin_reranker_model_router, tags=["admin/reranker_model"])
api_router.include_router(admin_langfuse.router, tags=["admin/langfuse"])
api_router.include_router(admin_legacy_retrieve.router, tags=["admin/retrieve_old"])
api_router.include_router(admin_stats.router, tags=["admin/stats"])
api_router.include_router(admin_semantic_cache.router, tags=["admin/semantic_cache"])
api_router.include_router(admin_evaluation_task.router, tags=["admin/evaluation/task"])
api_router.include_router(
admin_evaluation_dataset.router, tags=["admin/evaluation/dataset"]
)
api_router.include_router(admin_user.router)
api_router.include_router(
fastapi_users.get_auth_router(auth_backend), prefix="/auth", tags=["auth"]
)
================================================
FILE: backend/app/api/routes/__init__.py
================================================
================================================
FILE: backend/app/api/routes/api_key.py
================================================
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from fastapi_pagination import Params, Page
from app.api.deps import AsyncSessionDep, CurrentSuperuserDep
from app.auth.api_keys import api_key_manager
from app.models import PublicApiKey
router = APIRouter()
class CreateApiKeyRequest(BaseModel):
description: str
class CreateApiKeyResponse(BaseModel):
api_key: str
@router.post("/api-keys")
async def create_api_key(
session: AsyncSessionDep, user: CurrentSuperuserDep, request: CreateApiKeyRequest
) -> CreateApiKeyResponse:
_, raw_api_key = await api_key_manager.acreate_api_key(
session, user, request.description
)
return CreateApiKeyResponse(api_key=raw_api_key)
@router.get("/api-keys")
async def list_api_keys(
session: AsyncSessionDep,
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> Page[PublicApiKey]:
return await api_key_manager.list_api_keys(session, user, params)
@router.delete("/api-keys/{api_key_id}")
async def delete_api_key(
session: AsyncSessionDep, user: CurrentSuperuserDep, api_key_id: int
):
return await api_key_manager.delete_api_key(session, user, api_key_id)
================================================
FILE: backend/app/api/routes/chat.py
================================================
import logging
from uuid import UUID
from typing import List, Optional, Annotated
from http import HTTPStatus
from pydantic import (
BaseModel,
field_validator,
)
from fastapi import APIRouter, Depends, HTTPException, Request, Query
from fastapi.responses import StreamingResponse
from fastapi_pagination import Params, Page
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from app.api.deps import SessionDep, OptionalUserDep, CurrentUserDep
from app.rag.chat.chat_flow import ChatFlow
from app.rag.retrievers.knowledge_graph.schema import KnowledgeGraphRetrievalResult
from app.repositories import chat_repo
from app.models import Chat, ChatUpdate
from app.rag.chat.chat_service import get_final_chat_result
from app.models import Chat, ChatUpdate, ChatFilters
from app.rag.chat.chat_service import (
user_can_view_chat,
user_can_edit_chat,
get_chat_message_subgraph,
get_chat_message_recommend_questions,
remove_chat_message_recommend_questions,
)
from app.exceptions import InternalServerError
logger = logging.getLogger(__name__)
router = APIRouter()
class ChatRequest(BaseModel):
messages: List[ChatMessage]
chat_engine: str = "default"
chat_id: Optional[UUID] = None
stream: bool = True
@field_validator("messages")
@classmethod
def check_messages(cls, messages: List[ChatMessage]) -> List[ChatMessage]:
if not messages:
raise ValueError("messages cannot be empty")
for m in messages:
if m.role not in [MessageRole.USER, MessageRole.ASSISTANT]:
raise ValueError("role must be either 'user' or 'assistant'")
if not m.content:
raise ValueError("message content cannot be empty")
if len(m.content) > 20000:
raise ValueError("message content cannot exceed 2000 characters")
if messages[-1].role != MessageRole.USER:
raise ValueError("last message must be from user")
return messages
@router.post("/chats")
def chats(
request: Request,
session: SessionDep,
user: OptionalUserDep,
chat_request: ChatRequest,
):
origin = request.headers.get("Origin") or request.headers.get("Referer")
browser_id = request.state.browser_id
try:
chat_flow = ChatFlow(
db_session=session,
user=user,
browser_id=browser_id,
origin=origin,
chat_id=chat_request.chat_id,
chat_messages=chat_request.messages,
engine_name=chat_request.chat_engine,
)
if chat_request.stream:
return StreamingResponse(
chat_flow.chat(),
media_type="text/event-stream",
headers={
"X-Content-Type-Options": "nosniff",
},
)
else:
return get_final_chat_result(chat_flow.chat())
except HTTPException as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
@router.get("/chats")
def list_chats(
request: Request,
session: SessionDep,
user: OptionalUserDep,
filters: Annotated[ChatFilters, Query()],
params: Params = Depends(),
) -> Page[Chat]:
browser_id = request.state.browser_id
return chat_repo.paginate(session, user, browser_id, filters, params)
@router.get("/chats/{chat_id}")
def get_chat(session: SessionDep, user: OptionalUserDep, chat_id: UUID):
chat = chat_repo.must_get(session, chat_id)
if not user_can_view_chat(chat, user):
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Access denied")
return {
"chat": chat,
"messages": chat_repo.get_messages(session, chat),
}
@router.put("/chats/{chat_id}")
def update_chat(
session: SessionDep, user: CurrentUserDep, chat_id: UUID, chat_update: ChatUpdate
):
try:
chat = chat_repo.must_get(session, chat_id)
if not user_can_edit_chat(chat, user):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="Access denied"
)
return chat_repo.update(session, chat, chat_update)
except HTTPException as e:
raise e
except Exception as e:
logger.exception(e, exc_info=True)
raise InternalServerError()
@router.delete("/chats/{chat_id}")
def delete_chat(session: SessionDep, user: CurrentUserDep, chat_id: UUID):
try:
chat = chat_repo.must_get(session, chat_id)
if not user_can_edit_chat(chat, user):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="Access denied"
)
return chat_repo.delete(session, chat)
except HTTPException as e:
raise e
except Exception as e:
logger.exception(e, exc_info=True)
raise InternalServerError()
@router.get(
"/chat-messages/{chat_message_id}/subgraph",
response_model=KnowledgeGraphRetrievalResult,
)
def get_chat_subgraph(session: SessionDep, user: OptionalUserDep, chat_message_id: int):
try:
chat_message = chat_repo.must_get_message(session, chat_message_id)
if not user_can_view_chat(chat_message.chat, user):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="Access denied"
)
result = get_chat_message_subgraph(session, chat_message)
return result.model_dump(exclude_none=True)
except HTTPException as e:
raise e
except Exception as e:
logger.exception(e, exc_info=True)
raise InternalServerError()
@router.get("/chat-messages/{chat_message_id}/recommended-questions")
def get_recommended_questions(
session: SessionDep, user: OptionalUserDep, chat_message_id: int
) -> List[str]:
try:
chat_message = chat_repo.must_get_message(session, chat_message_id)
if not user_can_view_chat(chat_message.chat, user):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="Access denied"
)
return get_chat_message_recommend_questions(session, chat_message)
except HTTPException as e:
raise e
except Exception as e:
logger.exception(e, exc_info=True)
raise InternalServerError()
@router.post("/chat-messages/{chat_message_id}/recommended-questions")
def refresh_recommended_questions(
session: SessionDep, user: OptionalUserDep, chat_message_id: int
) -> List[str]:
try:
chat_message = chat_repo.must_get_message(session, chat_message_id)
if not user_can_view_chat(chat_message.chat, user):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="Access denied"
)
remove_chat_message_recommend_questions(session, chat_message_id)
return get_chat_message_recommend_questions(session, chat_message)
except HTTPException as e:
raise e
except Exception as e:
logger.exception(e, exc_info=True)
raise InternalServerError()
================================================
FILE: backend/app/api/routes/chat_engine.py
================================================
import logging
from fastapi import APIRouter, Depends
from app.api.deps import SessionDep
from fastapi_pagination import Params, Page
from app.models.chat_engine import ChatEngine
from app.rag.chat.config import ChatEngineConfig
from app.repositories.chat_engine import chat_engine_repo
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/chat-engines")
def list_chat_engines(
db_session: SessionDep,
params: Params = Depends(),
) -> Page[ChatEngine]:
page = chat_engine_repo.paginate(db_session, params, need_public=True)
for chat_engine in page.items:
engine_config = ChatEngineConfig.model_validate(chat_engine.engine_options)
chat_engine.engine_options = engine_config.screenshot()
return page
@router.get("/chat-engines/{chat_engine_id}")
def get_chat_engine(
db_session: SessionDep,
chat_engine_id: int,
) -> ChatEngine:
chat_engine = chat_engine_repo.must_get(
db_session, chat_engine_id, need_public=True
)
engine_config = ChatEngineConfig.model_validate(chat_engine.engine_options)
chat_engine.engine_options = engine_config.screenshot()
return chat_engine
================================================
FILE: backend/app/api/routes/document.py
================================================
from fastapi import HTTPException, APIRouter
from fastapi.responses import StreamingResponse
from app.api.deps import SessionDep
from app.repositories import document_repo
from app.file_storage import get_file_storage
router = APIRouter()
@router.get("/documents/{doc_id}/download")
def download_file(doc_id: int, session: SessionDep):
doc = document_repo.must_get(session, doc_id)
name = doc.source_uri
filestorage = get_file_storage()
if filestorage.exists(name):
file_size = filestorage.size(name)
headers = {"Content-Length": str(file_size)}
def iterfile():
with filestorage.open(name) as f:
yield from f
return StreamingResponse(iterfile(), media_type=doc.mime_type, headers=headers)
else:
raise HTTPException(status_code=404, detail="File not found")
================================================
FILE: backend/app/api/routes/feedback.py
================================================
from fastapi import APIRouter, HTTPException, Header
from http import HTTPStatus
from pydantic import BaseModel
from app.api.deps import SessionDep, OptionalUserDep
from app.models import FeedbackType, Feedback
from app.repositories import chat_repo
router = APIRouter()
class FeedbackRequest(BaseModel):
feedback_type: FeedbackType
comment: str
@router.post(
"/chat-messages/{chat_message_id}/feedback", status_code=HTTPStatus.CREATED
)
def feedback(
session: SessionDep,
user: OptionalUserDep,
chat_message_id: int,
request: FeedbackRequest,
origin: str = Header(None),
referer: str = Header(None),
):
chat_message = chat_repo.get_message(session, chat_message_id)
if not chat_message:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Chat message not found"
)
feedback = Feedback(
feedback_type=request.feedback_type,
comment=request.comment,
chat_message_id=chat_message_id,
chat_id=chat_message.chat_id,
user_id=user.id if user else None,
origin=origin or referer,
)
session.add(feedback)
session.commit()
return
================================================
FILE: backend/app/api/routes/index.py
================================================
from fastapi import APIRouter
from sqlmodel import text
from app.api.deps import SessionDep
from app.api.routes.models import SystemConfigStatusResponse
from app.site_settings import SiteSetting
from app.rag.chat.chat_service import (
check_rag_required_config,
check_rag_optional_config,
check_rag_config_need_migration,
)
router = APIRouter()
@router.get("/healthz")
def status(session: SessionDep):
now = session.exec(text("SELECT NOW()")).scalar()
return {"now": now}
@router.get("/site-config")
def site_config() -> dict:
return SiteSetting.get_client_settings()
@router.get("/system/bootstrap-status")
def system_bootstrap_status(session: SessionDep) -> SystemConfigStatusResponse:
required_config_check_status = check_rag_required_config(session)
optional_config_check_status = check_rag_optional_config(session)
need_migration_status = check_rag_config_need_migration(session)
return SystemConfigStatusResponse(
required=required_config_check_status,
optional=optional_config_check_status,
need_migration=need_migration_status,
)
================================================
FILE: backend/app/api/routes/models.py
================================================
from pydantic import BaseModel
class RequiredConfigStatus(BaseModel):
default_llm: bool
default_embedding_model: bool
default_chat_engine: bool
knowledge_base: bool
class OptionalConfigStatus(BaseModel):
langfuse: bool
default_reranker: bool
class NeedMigrationStatus(BaseModel):
chat_engines_without_kb_configured: list[int]
class SystemConfigStatusResponse(BaseModel):
required: RequiredConfigStatus
optional: OptionalConfigStatus
need_migration: NeedMigrationStatus
================================================
FILE: backend/app/api/routes/retrieve/__init__.py
================================================
================================================
FILE: backend/app/api/routes/retrieve/models.py
================================================
from typing import Optional
from pydantic import BaseModel
from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrieverConfig,
)
from app.rag.retrievers.chunk.schema import VectorSearchRetrieverConfig
from app.rag.retrievers.multiple_knowledge_base import FusionRetrievalBaseConfig
# Chunks retrieval
class ChunkRetrievalConfig(FusionRetrievalBaseConfig):
full_documents: Optional[bool] = False
vector_search: VectorSearchRetrieverConfig
class ChunksRetrievalRequest(BaseModel):
query: str
retrieval_config: ChunkRetrievalConfig
## Knowledge Graph retrieval
class KnowledgeGraphRetrievalConfig(FusionRetrievalBaseConfig):
knowledge_graph: KnowledgeGraphRetrieverConfig
class KnowledgeGraphRetrievalRequest(BaseModel):
query: str
retrieval_config: KnowledgeGraphRetrievalConfig
================================================
FILE: backend/app/api/routes/retrieve/routes.py
================================================
import logging
from fastapi import APIRouter
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.rag.retrievers.knowledge_graph.fusion_retriever import (
KnowledgeGraphFusionRetriever,
)
from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrievalResult,
)
from app.rag.retrievers.chunk.fusion_retriever import (
ChunkFusionRetriever,
)
from app.exceptions import KBNotFound
from app.rag.retrievers.chunk.schema import ChunksRetrievalResult
from app.rag.llms.resolver import get_llm_or_default
from .models import ChunksRetrievalRequest, KnowledgeGraphRetrievalRequest
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/retrieve/chunks")
def retrieve_chunks(
db_session: SessionDep,
user: CurrentSuperuserDep,
request: ChunksRetrievalRequest,
) -> ChunksRetrievalResult:
try:
config = request.retrieval_config
llm = get_llm_or_default(db_session, config.llm_id)
retriever = ChunkFusionRetriever(
db_session=db_session,
knowledge_base_ids=config.knowledge_base_ids,
llm=llm,
use_query_decompose=config.use_query_decompose,
config=config.vector_search,
)
return retriever.retrieve_chunks(request.query, config.full_documents)
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise
@router.post("/retrieve/knowledge_graph")
def retrieve_knowledge_graph(
db_session: SessionDep,
user: CurrentSuperuserDep,
request: KnowledgeGraphRetrievalRequest,
) -> KnowledgeGraphRetrievalResult:
try:
config = request.retrieval_config
llm = get_llm_or_default(db_session, config.llm_id)
retriever = KnowledgeGraphFusionRetriever(
db_session=db_session,
knowledge_base_ids=config.knowledge_base_ids,
llm=llm,
use_query_decompose=config.use_query_decompose,
config=config.knowledge_graph,
)
return retriever.retrieve_knowledge_graph(request.query)
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise
================================================
FILE: backend/app/api/routes/user.py
================================================
from fastapi import APIRouter
from app.api.deps import CurrentUserDep
from app.auth.schemas import UserRead
router = APIRouter()
@router.get("/users/me", response_model=UserRead)
def me(user: CurrentUserDep):
return user
================================================
FILE: backend/app/api_server.py
================================================
import app.logger
import sentry_sdk
from dotenv import load_dotenv
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, Response
from fastapi.routing import APIRoute
from starlette.middleware.cors import CORSMiddleware
from app.api.main import api_router
from app.core.config import settings
from app.site_settings import SiteSetting
from app.utils.uuid6 import uuid7
load_dotenv()
def custom_generate_unique_id(route: APIRoute) -> str:
return f"{route.tags[0]}-{route.name}"
if settings.SENTRY_DSN and settings.ENVIRONMENT != "local":
sentry_sdk.init(
dsn=str(settings.SENTRY_DSN),
enable_tracing=True,
traces_sample_rate=settings.SENTRY_TRACES_SAMPLE_RATE,
profiles_sample_rate=settings.SENTRY_PROFILES_SAMPLE_RATE,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
SiteSetting.update_db_cache()
yield
app = FastAPI(
title=settings.PROJECT_NAME,
openapi_url=f"{settings.API_V1_STR}/openapi.json",
generate_unique_id_function=custom_generate_unique_id,
lifespan=lifespan,
)
# Set all CORS enabled origins
if settings.BACKEND_CORS_ORIGINS:
app.add_middleware(
CORSMiddleware,
allow_origins=[
str(origin).strip("/") for origin in settings.BACKEND_CORS_ORIGINS
],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.middleware("http")
async def identify_browser(request: Request, call_next):
browser_id = request.cookies.get(settings.BROWSER_ID_COOKIE_NAME)
has_browser_id = bool(browser_id)
if not browser_id:
browser_id = uuid7()
request.state.browser_id = browser_id
response: Response = await call_next(request)
if not has_browser_id:
response.set_cookie(
settings.BROWSER_ID_COOKIE_NAME,
browser_id,
max_age=settings.BROWSER_ID_COOKIE_MAX_AGE,
)
return response
app.include_router(api_router, prefix=settings.API_V1_STR)
================================================
FILE: backend/app/auth/api_keys.py
================================================
import base64
import string
import secrets
import hashlib
from typing import Optional, Tuple
from fastapi import Request
from sqlmodel import Session, select
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from app.models import ApiKey, User
def generate_api_key(length=50):
characters = string.ascii_letters + string.digits
api_key = "".join(secrets.choice(characters) for _ in range(length))
return "ta-" + api_key
API_KEY_HEADER = "Authorization"
BEARER_PREFIX = "Bearer "
def get_api_key_from_request(request: Request) -> str | None:
api_key_header = request.headers.get(API_KEY_HEADER)
if api_key_header is None:
return None
if not api_key_header.startswith(BEARER_PREFIX):
return None
return api_key_header[len(BEARER_PREFIX) :].strip()
def encrypt_api_key(api_key: str) -> str:
# An empty salt is used because we need to look tokens up solely by
# their hashed value. Additionally, tokens are always cryptographically
# pseudo-random and unique, therefore salting provides no
# additional security.
algorithm = "pbkdf2_sha512"
api_key = api_key.encode("utf-8")
salt = b""
iterations = 20_000
hash = hashlib.pbkdf2_hmac("sha512", api_key, salt, iterations)
hash = base64.b64encode(hash).decode("ascii").strip()
return "%s$%d$%s" % (algorithm, iterations, hash)
class ApiKeyManager:
async def acreate_api_key(
self, session: AsyncSession, user: User, description: str
) -> Tuple[ApiKey, str]:
api_key = generate_api_key()
hashed_api_key = encrypt_api_key(api_key)
api_key_obj = ApiKey(
hashed_secret=hashed_api_key,
api_key_display=api_key[:7] + "...." + api_key[-3:],
user_id=user.id,
description=description,
)
session.add(api_key_obj)
await session.commit()
await session.refresh(api_key_obj)
return api_key_obj, api_key
def create_api_key(
self, session: Session, user: User, description: str
) -> Tuple[ApiKey, str]:
api_key = generate_api_key()
hashed_api_key = encrypt_api_key(api_key)
api_key_obj = ApiKey(
hashed_secret=hashed_api_key,
api_key_display=api_key[:7] + "...." + api_key[-3:],
user_id=user.id,
description=description,
)
session.add(api_key_obj)
session.commit()
session.refresh(api_key_obj)
return api_key_obj, api_key
async def get_active_user_by_raw_api_key(
self, session: AsyncSession, api_key: str
) -> Optional[User]:
if not api_key:
return None
hashed_api_key = encrypt_api_key(api_key)
results = await session.exec(
select(ApiKey).where(
ApiKey.is_active == True,
ApiKey.hashed_secret == hashed_api_key,
)
)
api_key_obj = results.first()
if not api_key_obj:
return None
user = await session.get(User, api_key_obj.user_id)
if not (user.is_active and user.is_verified):
return None
return user
async def get_active_user_from_request(
self, session: AsyncSession, request: Request
) -> Optional[User]:
api_key = get_api_key_from_request(request)
return await self.get_active_user_by_raw_api_key(session, api_key)
async def list_api_keys(
self, session: AsyncSession, user: User, params: Params
) -> Page[ApiKey]:
api_keys = await paginate(
session,
select(ApiKey)
.where(ApiKey.user == user, ApiKey.is_active == True)
.order_by(ApiKey.created_at.desc()),
params,
)
return api_keys
async def delete_api_key(self, session: AsyncSession, user: User, api_key_id: int):
result = await session.exec(
select(ApiKey).where(
ApiKey.id == api_key_id,
ApiKey.user_id == user.id,
ApiKey.is_active == True,
)
)
api_key = result.first()
if api_key:
api_key.is_active = False
await session.commit()
api_key_manager = ApiKeyManager()
================================================
FILE: backend/app/auth/db.py
================================================
from fastapi import Depends
from fastapi_users_db_sqlmodel import SQLModelUserDatabaseAsync
from fastapi_users_db_sqlmodel.access_token import SQLModelAccessTokenDatabaseAsync
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models import User, UserSession
from app.core.db import get_db_async_session
async def get_user_db(session: AsyncSession = Depends(get_db_async_session)):
yield SQLModelUserDatabaseAsync(session, User)
async def get_user_session_db(
session: AsyncSession = Depends(get_db_async_session),
):
yield SQLModelAccessTokenDatabaseAsync(session, UserSession)
================================================
FILE: backend/app/auth/schemas.py
================================================
import uuid
from fastapi_users import schemas
class UserRead(schemas.BaseUser[uuid.UUID]):
pass
class UserCreate(schemas.BaseUserCreate):
pass
class UserUpdate(schemas.BaseUserUpdate):
pass
================================================
FILE: backend/app/auth/users.py
================================================
import logging
import uuid
import contextlib
from http import HTTPStatus
from typing import Optional
from fastapi import Depends, Request, HTTPException
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin
from fastapi_users.authentication import (
AuthenticationBackend,
CookieTransport,
)
from fastapi_users.authentication.strategy import DatabaseStrategy
from fastapi_users_db_sqlmodel import SQLModelUserDatabaseAsync
from fastapi_users_db_sqlmodel.access_token import SQLModelAccessTokenDatabaseAsync
from fastapi_users.exceptions import UserAlreadyExists, UserNotExists
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.db import get_db_async_session
from app.models import User, UserSession
from app.auth.db import get_user_db, get_user_session_db
from app.auth.api_keys import api_key_manager
from app.auth.schemas import UserCreate, UserUpdate
logger = logging.getLogger(__name__)
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = settings.SECRET_KEY
verification_token_secret = settings.SECRET_KEY
async def on_after_register(self, user: User, request: Optional[Request] = None):
print(f"User {user.id} has registered.")
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
):
print(f"User {user.id} has forgot their password. Reset token: {token}")
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
):
print(f"Verification requested for user {user.id}. Verification token: {token}")
async def get_user_manager(user_db: SQLModelUserDatabaseAsync = Depends(get_user_db)):
yield UserManager(user_db)
cookie_transport = CookieTransport(
cookie_name=settings.SESSION_COOKIE_NAME,
cookie_max_age=settings.SESSION_COOKIE_MAX_AGE,
cookie_secure=settings.SESSION_COOKIE_SECURE,
)
def get_database_strategy(
user_session_db: SQLModelAccessTokenDatabaseAsync[UserSession] = Depends(
get_user_session_db
),
) -> DatabaseStrategy:
return DatabaseStrategy(user_session_db, lifetime_seconds=3600 * 24 * 90)
auth_backend = AuthenticationBackend(
name="database",
transport=cookie_transport,
get_strategy=get_database_strategy,
)
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
# Following methods are used to get the current user from the request,
# They all support both session cookies and API keys for authentication,
# it will first check for a session cookie, if not found, then check for an API key.
async def current_user(
request: Request,
user: User = Depends(
fastapi_users.current_user(optional=True, active=True, verified=True)
),
session: AsyncSession = Depends(get_db_async_session),
) -> User:
if user:
# already authenticated with a valid session cookie
return user
# check for an API key
user = await api_key_manager.get_active_user_from_request(session, request)
if not user:
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED)
return user
async def current_superuser(
request: Request,
user: User = Depends(
fastapi_users.current_user(optional=True, active=True, verified=True)
),
session: AsyncSession = Depends(get_db_async_session),
) -> User:
if user:
if user.is_superuser:
return user
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
# check for an API key
user = await api_key_manager.get_active_user_from_request(session, request)
if not user:
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED)
if not user.is_superuser:
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
return user
async def optional_current_user(
request: Request,
user: User = Depends(
fastapi_users.current_user(optional=True, active=True, verified=True)
),
session: AsyncSession = Depends(get_db_async_session),
) -> Optional[User]:
if user:
# already authenticated with a valid session cookie
return user
# check for an API key
return await api_key_manager.get_active_user_from_request(session, request)
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
async def create_user(
session: AsyncSession,
email: str,
password: str,
is_active: bool = True,
is_verified: bool = True,
is_superuser: bool = False,
) -> User:
try:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
user = await user_manager.create(
UserCreate(
email=email,
password=password,
is_active=is_active,
is_verified=is_verified,
is_superuser=is_superuser,
)
)
return user
except UserAlreadyExists:
logger.error(f"User {email} already exists")
raise
async def update_user_password(
session: AsyncSession,
user_id: uuid.UUID,
new_password: str,
) -> User:
try:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
user = await user_manager.get(user_id)
if not user:
raise UserNotExists(f"User {user_id} does not exist")
user_update = UserUpdate(password=new_password)
await user_manager.update(user_update, user)
# verify
updated_user = await user_manager.get(user_id)
return updated_user
except UserNotExists as e:
logger.error(str(e))
raise
except Exception as e:
logger.error(f"Failed to update password for user {id}: {e}")
raise
================================================
FILE: backend/app/celery.py
================================================
from celery import Celery
from app.core.config import settings
app = Celery(
settings.PROJECT_NAME,
broker=settings.CELERY_BROKER_URL,
backend=settings.CELERY_RESULT_BACKEND,
)
app.conf.update(
task_acks_late=True,
task_reject_on_worker_lost=True,
task_routes=[
{"app.tasks.evaluate.*": {"queue": "evaluation"}},
{"*": {"queue": "default"}},
],
broker_connection_retry_on_startup=True,
)
app.autodiscover_tasks(["app"])
================================================
FILE: backend/app/core/config.py
================================================
import enum
from typing import Annotated, Any
from urllib.parse import quote
from pydantic import (
AnyUrl,
BeforeValidator,
HttpUrl,
MySQLDsn,
SecretStr,
computed_field,
model_validator,
)
from pydantic_core import MultiHostUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import Self
def parse_cors(v: Any) -> list[str] | str:
if isinstance(v, str) and not v.startswith("["):
return [i.strip() for i in v.split(",")]
elif isinstance(v, list | str):
return v
raise ValueError(v)
class Environment(str, enum.Enum):
LOCAL = "local"
STAGING = "staging"
PRODUCTION = "production"
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env", env_ignore_empty=True, extra="ignore"
)
API_V1_STR: str = "/api/v1"
SECRET_KEY: str
DOMAIN: str = "localhost"
ENVIRONMENT: Environment = Environment.LOCAL
LOG_LEVEL: str = "INFO"
SQLALCHEMY_LOG_LEVEL: str = "WARNING"
SESSION_COOKIE_NAME: str = "session"
# 90 days
SESSION_COOKIE_MAX_AGE: int = 3600 * 24 * 90
SESSION_COOKIE_SECURE: bool = False
BROWSER_ID_COOKIE_NAME: str = "bid"
BROWSER_ID_COOKIE_MAX_AGE: int = 3600 * 24 * 365 * 2
@computed_field # type: ignore[misc]
@property
def server_host(self) -> str:
# Use HTTPS for anything other than local development
if self.ENVIRONMENT == Environment.LOCAL:
return f"http://{self.DOMAIN}"
return f"https://{self.DOMAIN}"
BACKEND_CORS_ORIGINS: Annotated[
list[AnyUrl] | str, BeforeValidator(parse_cors)
] = []
BACKEND_CORS_ORIGIN_REGEXP: str | None = None
PROJECT_NAME: str = "TiDB.AI"
SENTRY_DSN: HttpUrl | None = None
SENTRY_TRACES_SAMPLE_RATE: float = 1.0
SENTRY_PROFILES_SAMPLE_RATE: float = 1.0
@model_validator(mode="after")
def _validate_sentry_sample_rate(self) -> Self:
if not self.SENTRY_DSN:
return self
if self.SENTRY_TRACES_SAMPLE_RATE < 0 or self.SENTRY_TRACES_SAMPLE_RATE > 1:
raise ValueError("SENTRY_TRACES_SAMPLE_RATE must be between 0 and 1")
if self.SENTRY_PROFILES_SAMPLE_RATE < 0 or self.SENTRY_PROFILES_SAMPLE_RATE > 1:
raise ValueError("SENTRY_PROFILES_SAMPLE_RATE must be between 0 and 1")
return self
LOCAL_FILE_STORAGE_PATH: str = "/shared/data"
TIDB_HOST: str = "127.0.0.1"
TIDB_PORT: int = 4000
TIDB_USER: str = "root"
TIDB_PASSWORD: str = ""
TIDB_DATABASE: str
TIDB_SSL: bool = True
ENABLE_QUESTION_CACHE: bool = False
CELERY_BROKER_URL: str = "redis://redis:6379/0"
CELERY_RESULT_BACKEND: str = "redis://redis:6379/0"
# TODO: move below config to `option` table, it should be configurable by staff in console
TIDB_AI_CHAT_ENDPOINT: str = "https://tidb.ai/api/v1/chats"
TIDB_AI_API_KEY: SecretStr | None = None
COMPLIED_INTENT_ANALYSIS_PROGRAM_PATH: str | None = None
COMPLIED_PREREQUISITE_ANALYSIS_PROGRAM_PATH: str | None = None
# NOTICE: EMBEDDING_DIMS and EMBEDDING_MAX_TOKENS is deprecated and
# will be removed in the future.
EMBEDDING_DIMS: int = 1536
EMBEDDING_MAX_TOKENS: int = 2048
EVALUATION_OPENAI_API_KEY: str | None = None
@computed_field # type: ignore[misc]
@property
def SQLALCHEMY_DATABASE_URI(self) -> MySQLDsn:
return MultiHostUrl.build(
scheme="mysql+pymysql",
username=self.TIDB_USER,
# TODO: remove quote after following issue is fixed:
# https://github.com/pydantic/pydantic/issues/8061
password=quote(self.TIDB_PASSWORD),
host=self.TIDB_HOST,
port=self.TIDB_PORT,
path=self.TIDB_DATABASE,
query="ssl_verify_cert=true&ssl_verify_identity=true"
if self.TIDB_SSL
else None,
)
@computed_field # type: ignore[misc]
@property
def SQLALCHEMY_ASYNC_DATABASE_URI(self) -> MySQLDsn:
return MultiHostUrl.build(
scheme="mysql+asyncmy",
username=self.TIDB_USER,
password=quote(self.TIDB_PASSWORD),
host=self.TIDB_HOST,
port=self.TIDB_PORT,
path=self.TIDB_DATABASE,
)
@model_validator(mode="after")
def _validate_secrets(self) -> Self:
secret = self.SECRET_KEY
if not secret:
raise ValueError(
"Please set a secret key using the SECRET_KEY environment variable."
)
min_length = 32
if len(secret.encode()) < min_length:
message = (
"The SECRET_KEY is too short, "
f"please use a longer secret, at least {min_length} characters."
)
raise ValueError(message)
return self
settings = Settings() # type: ignore
================================================
FILE: backend/app/core/db.py
================================================
import ssl
import contextlib
from typing import AsyncGenerator, Generator
from sqlmodel import create_engine, Session
from sqlalchemy import event
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
# TiDB Serverless clusters have a limitation: if there are no active connections for 5 minutes,
# they will shut down, which closes all connections, so we need to recycle the connections
engine = create_engine(
str(settings.SQLALCHEMY_DATABASE_URI),
pool_size=20,
max_overflow=40,
pool_recycle=300,
pool_pre_ping=True,
)
# create a scoped session, ensure in multi-threading environment, each thread has its own session
Scoped_Session = scoped_session(sessionmaker(bind=engine, class_=Session))
def get_ssl_context():
ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
ssl_context.check_hostname = True
return ssl_context
async_engine = create_async_engine(
str(settings.SQLALCHEMY_ASYNC_DATABASE_URI),
pool_recycle=300,
connect_args={
# seems config ssl in url is not working
# we can only config ssl in connect_args
"ssl": get_ssl_context(),
}
if settings.TIDB_SSL
else {},
)
def prepare_db_connection(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
# In TiDB.AI, we store datetime in the database using UTC timezone.
# Therefore, we need to set the timezone to '+00:00'.
cursor.execute("SET time_zone = '+00:00'")
cursor.close()
event.listen(engine, "connect", prepare_db_connection)
event.listen(async_engine.sync_engine, "connect", prepare_db_connection)
def get_db_session() -> Generator[Session, None, None]:
with Session(engine, expire_on_commit=False) as session:
yield session
async def get_db_async_session() -> AsyncGenerator[AsyncSession, None]:
async with AsyncSession(async_engine, expire_on_commit=False) as session:
yield session
get_db_async_session_context = contextlib.asynccontextmanager(get_db_async_session)
================================================
FILE: backend/app/evaluation/evals.py
================================================
import logging
import os
import requests
import typing
import uuid
import json
from tqdm import tqdm
from datetime import datetime
from langfuse import Langfuse
from langfuse.client import DatasetItemClient
from langfuse.model import DatasetStatus
from tenacity import retry, stop_after_attempt, wait_fixed
from llama_index.llms.gemini import Gemini
from llama_index.llms.openai import OpenAI
from app.core.config import settings
from app.evaluation.evaluators import (
LanguageEvaluator,
ToxicityEvaluator,
E2ERagEvaluator,
)
import pandas as pd
from ragas.metrics import (
FactualCorrectness,
SemanticSimilarity,
)
from ragas import evaluate, EvaluationDataset
from ragas.llms import LangchainLLMWrapper
from ragas.embeddings import LangchainEmbeddingsWrapper
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
logger = logging.getLogger(__name__)
DEFAULT_METRICS = ["toxicity", "language"]
DEFAULT_TIDB_AI_CHAT_ENGINE = "default"
class Evaluation:
"""
Evaluate a dataset using TiDB AI and Langfuse.
Args:
dataset_name: "customize" or the name of the dataset in langfuse to evaluate
run_name: The name of the run to create. If not provided, a random name will be generated.
llm_provider: The LLM provider to use. Can be "openai" or "google".
Examples:
```python
evaluation = Evaluation(dataset_name="my_dataset")
evaluation.run()
```
"""
def __init__(
self,
dataset_name: str,
run_name: typing.Optional[str] = None,
llm_provider: typing.Literal["openai", "gemini"] = "openai",
tidb_ai_chat_engine: typing.Optional[str] = DEFAULT_TIDB_AI_CHAT_ENGINE,
) -> None:
self.langfuse = Langfuse()
self.dataset_name = dataset_name
self.is_customize_dataset = dataset_name == "customize"
if not self.is_customize_dataset:
self.dataset = self.langfuse.get_dataset(dataset_name)
self.tidb_ai_chat_engine = tidb_ai_chat_engine
if run_name is None:
random_str = uuid.uuid4().hex[:6]
self.run_name = datetime.now().strftime(f"%Y-%m-%d-%H-{random_str}")
else:
self.run_name = run_name
llm_provider = llm_provider.lower()
if llm_provider == "openai":
self._llama_llm = OpenAI(model="gpt-4o")
elif llm_provider == "gemini":
self._llama_llm = Gemini(model="models/gemini-2.0-flash")
else:
raise ValueError(f"Invalid LLM provider: {llm_provider}")
self._metrics = {
"language": LanguageEvaluator(llm=self._llama_llm),
"toxicity": ToxicityEvaluator(llm=self._llama_llm),
"e2e_rag": E2ERagEvaluator(model="gpt-4o"),
}
def runeval_dataset(
self,
csv_dataset: str,
run_size: int = 30,
checkpoint_file: str = "checkpoint.json",
error_file: str = "eval_error.csv",
) -> None:
if not os.path.exists(csv_dataset):
raise FileNotFoundError(f"File not found: {csv_dataset}")
df = pd.read_csv(csv_dataset)
eval_list = df.to_dict(orient="records")
eval_list = eval_list[:run_size]
# checkpoint info
ragas_list = []
completed_queries = set()
if os.path.exists(checkpoint_file):
with open(checkpoint_file, "r") as f:
checkpoint_data = json.load(f)
completed_queries = set(checkpoint_data["completed_queries"])
ragas_list = checkpoint_data["ragas_list"]
# error info
error_list = []
errored_queries = set()
if os.path.exists(error_file):
error_df = pd.read_csv(error_file)
error_list = error_df.to_dict(orient="records")
errored_queries = set(item["query"] for item in error_list)
for item in tqdm(eval_list):
if item["query"] in completed_queries or item["query"] in errored_queries:
continue # skip completed or errored queries
messages = [{"role": "user", "content": item["query"]}]
try:
response, _ = self._generate_answer_by_tidb_ai(messages)
user_input = json.dumps(messages)
ragas_list.append(
{
"user_input": user_input,
"reference": item["reference"],
"response": response,
# TODO: we cannot get retrieved_contexts now, due to the external engine
# "retrieved_contexts": [],
# Add rest fields from raw data
**{
k: v
for k, v in item.items()
if k not in ["query", "reference"]
},
}
)
# save the checkpoint file
completed_queries.add(item["query"])
checkpoint_data = {
"completed_queries": list(completed_queries),
"ragas_list": ragas_list,
}
with open(checkpoint_file, "w") as f:
json.dump(checkpoint_data, f)
except Exception as e:
print(f"Error processing query: {item['query']}, error: {e}")
item["error_message"] = str(e)
error_list.append(item) # Add the item to the error list
# Save the errors to the error file
pd.DataFrame(error_list).to_csv(error_file, index=False)
ragas_dataset = EvaluationDataset.from_list(ragas_list)
evaluator_llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o"))
evaluator_embeddings = LangchainEmbeddingsWrapper(
OpenAIEmbeddings(model="text-embedding-3-large")
)
metrics = [
# LLMContextRecall(llm=evaluator_llm), # retrieved_contexts required
FactualCorrectness(llm=evaluator_llm),
# Faithfulness(llm=evaluator_llm), # retrieved_contexts required
SemanticSimilarity(embeddings=evaluator_embeddings),
]
results = evaluate(dataset=ragas_dataset, metrics=metrics)
df_results = results.to_pandas()
df_raw_data = pd.DataFrame(ragas_list)
additional_columns = df_raw_data.drop(
columns=["user_input", "reference", "response"]
)
df_results_combined = pd.concat([df_results, additional_columns], axis=1)
df_results_combined = df_results_combined.applymap(
lambda x: x.replace("\n", "\\n").replace("\r", "\\r")
if isinstance(x, str)
else x
)
df_results_combined.to_csv(f"results_{self.run_name}.csv", index=False)
print(f"Saved results to results_{self.run_name}.csv")
def run(self, metrics: list = DEFAULT_METRICS) -> None:
for item in tqdm(self.dataset.items):
if item.status != DatasetStatus.ACTIVE:
continue
sample_data = self.parse_sample(item)
output, trace_id = self._generate_answer_by_tidb_ai(sample_data["messages"])
trace_data = fetch_rag_data(self.langfuse, trace_id)
question = json.dumps(sample_data["messages"])
item.link(
trace_or_observation=None,
trace_id=trace_id,
run_name=self.run_name,
)
for metric in metrics:
evaluator = self._metrics[metric]
result = evaluator.evaluate(
query=question,
response=output,
contexts=trace_data.get("retrieval_context", []),
reference=sample_data.get("expected_output", None),
)
if isinstance(result, dict):
for eval_name, eval_res in result.items():
self.langfuse.score(
trace_id=trace_id,
name=eval_name,
value=eval_res.score,
comment=eval_res.feedback,
)
else:
self.langfuse.score(
trace_id=trace_id,
name=metric,
value=result.score,
comment=result.feedback,
)
def parse_sample(self, item: DatasetItemClient):
expected_output = item.expected_output
messages = []
if "history" in item.input:
messages = [
{
"role": message["role"],
"content": message["content"],
}
for message in item.input["history"]
]
if "userInput" in item.input:
messages.append({"role": "user", "content": item.input["userInput"]})
elif "input" in item.input:
messages.append({"role": "user", "content": item.input["input"]})
sample_data = {
"messages": messages,
"expected_output": expected_output,
}
if "retrieval_context" in item.input:
sample_data["retrieval_context"] = item.input["retrieval_context"]
if "graph_context" in item.input:
sample_data["graph_context"] = item.input["graph_context"]
if "refined_question" in item.input:
sample_data["refined_question"] = item.input["refined_question"]
return sample_data
@retry(stop=stop_after_attempt(2), wait=wait_fixed(5))
def _generate_answer_by_tidb_ai(self, messages: list) -> (str, str):
response = requests.post(
settings.TIDB_AI_CHAT_ENDPOINT,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {settings.TIDB_AI_API_KEY}",
"Origin": "evaluation",
},
json={
"messages": messages,
"index": "default",
"chat_engine": self.tidb_ai_chat_engine,
"stream": False,
},
)
response.raise_for_status()
data = response.json()
if data["trace"] is None:
trace_id = None
else:
trace_url = data["trace"]["langfuse_url"]
trace_id = parse_langfuse_trace_id_from_url(trace_url)
answer = data["content"]
return answer, trace_id
def generate_answer_by_tidb_ai(self, messages: list) -> str:
response = requests.post(
settings.TIDB_AI_CHAT_ENDPOINT,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {settings.TIDB_AI_API_KEY}",
},
json={
"messages": messages,
"index": "default",
"chat_engine": self.tidb_ai_chat_engine,
"stream": False,
},
)
response.raise_for_status()
data = response.text
return data
def parse_langfuse_trace_id_from_url(trace_url: str) -> str:
# Example trace_url: https://us.cloud.langfuse.com/trace/87e7eb2e-b789-4b23-af60-fbcf0fd517a1
return trace_url.split("/")[-1]
def fetch_rag_data(langfuse_client: Langfuse, tracing_id: str):
graph_context_key = "retrieve_from_graph"
reranking_key = "reranking"
refined_question_key = "condense_question"
tracing_data = langfuse_client.fetch_trace(tracing_id)
data = {
"history": tracing_data.data.input["chat_history"],
"input": tracing_data.data.input["user_question"],
"graph_context": None,
"refined_question": None,
"retrieval_context": None,
"output": (
tracing_data.data.output["content"]
if tracing_data.data.output is not None
and "content" in tracing_data.data.output
else None
),
"source_tracing_id": tracing_id,
}
for ob in tracing_data.data.observations:
if graph_context_key == ob.name:
graph_context = {query: sg for query, sg in ob.output["graph"].items()}
for _, sg in graph_context.items():
for entity in sg["entities"]:
entity.pop("meta", None)
data["graph_context"] = graph_context
if reranking_key == ob.name:
retrieval_context = []
for node in ob.output["nodes"]:
retrieval_context.append(node["node"]["text"])
data["retrieval_context"] = retrieval_context
if refined_question_key == ob.name:
refined_question = ob.output
data["refined_question"] = refined_question
return data
================================================
FILE: backend/app/evaluation/evaluators/__init__.py
================================================
from .language_detector import LanguageEvaluator
from .toxicity import ToxicityEvaluator
from .e2e_rag_evaluator import E2ERagEvaluator
__all__ = ["LanguageEvaluator", "ToxicityEvaluator", "E2ERagEvaluator"]
================================================
FILE: backend/app/evaluation/evaluators/e2e_rag_evaluator.py
================================================
import time
from typing import Optional, Sequence, Mapping
from llama_index.core.evaluation.base import EvaluationResult
from deepeval import evaluate
from deepeval.test_case import LLMTestCase
from deepeval.metrics import (
ContextualPrecisionMetric,
ContextualRecallMetric,
ContextualRelevancyMetric,
AnswerRelevancyMetric,
FaithfulnessMetric,
)
max_retries = 3
retry_delay = 2
class E2ERagEvaluator:
def __init__(self, model="gpt-4o", threshold=0.7) -> None:
self._model = model
self._threshold = threshold
self._contextual_precision = ContextualPrecisionMetric(
threshold=self._threshold, model=self._model, include_reason=True
)
self._contextual_recall = ContextualRecallMetric(
threshold=self._threshold, model=self._model, include_reason=True
)
self._contextual_relevancy = ContextualRelevancyMetric(
threshold=self._threshold, model=self._model, include_reason=True
)
self._answer_relevancy = AnswerRelevancyMetric(
threshold=self._threshold, model=self._model, include_reason=True
)
self._faithfulness = FaithfulnessMetric(
threshold=self._threshold, model=self._model, include_reason=True
)
def evaluate(
self,
query: Optional[str] = None,
response: Optional[str] = None,
contexts: Optional[Sequence[str]] = None,
reference: Optional[str] = None,
) -> Mapping[str, EvaluationResult]:
test_case = LLMTestCase(
input=query,
actual_output=response,
expected_output=reference,
retrieval_context=contexts,
)
evalution_results = []
for attempt in range(max_retries):
try:
evalution_results = evaluate(
test_cases=[test_case],
metrics=[
self._contextual_precision,
self._contextual_recall,
self._contextual_relevancy,
self._answer_relevancy,
self._faithfulness,
],
print_results=False,
show_indicator=False,
)
except ValueError as e:
print(f"Caught ValueError: {e}")
print(f"Retrying {attempt + 1}/{max_retries}...")
time.sleep(retry_delay)
if len(evalution_results) == 0:
return {}
metrics_results = {}
for eval_result in evalution_results:
for score in eval_result.metrics_metadata:
metrics_results[score.metric] = EvaluationResult(
query=query,
response=response,
contexts=contexts,
passing=score.success,
score=score.score or 0.0,
feedback=score.reason or score.error,
)
return metrics_results
================================================
FILE: backend/app/evaluation/evaluators/language_detector.py
================================================
import asyncio
import logging
from typing import Any, Optional, Sequence, Union, cast
from llama_index.core import ServiceContext
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.evaluation.base import BaseEvaluator, EvaluationResult
from llama_index.core.llms.llm import LLM
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.prompts import BasePromptTemplate, PromptTemplate
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.settings import Settings
logger = logging.getLogger(__name__)
DEFAULT_EVAL_TEMPLATE = PromptTemplate(
"Are the query and response language the same?\n"
"[Query]: {query}\n"
"[Response]: {response}\n"
"Yes or No?"
)
class EvaluationData(BaseModel):
passing: bool = Field(
description="Whether the query and response language are the same."
)
class LanguageEvaluator(BaseEvaluator):
"""Language evaluator.
Evaluates whether query and response language are the same.
This evaluator only considers the query string and the response string.
Args:
service_context(Optional[ServiceContext]):
The service context to use for evaluation.
guidelines(Optional[str]): User-added guidelines to use for evaluation.
Defaults to None, which uses the default guidelines.
eval_template(Optional[Union[str, BasePromptTemplate]] ):
The template to use for evaluation.
"""
def __init__(
self,
llm: Optional[LLM] = None,
eval_template: Optional[Union[str, BasePromptTemplate]] = None,
output_parser: Optional[PydanticOutputParser] = None,
# deprecated
service_context: Optional[ServiceContext] = None,
) -> None:
self._llm = llm or Settings.llm
self._eval_template: BasePromptTemplate
if isinstance(eval_template, str):
self._eval_template = PromptTemplate(eval_template)
else:
self._eval_template = eval_template or DEFAULT_EVAL_TEMPLATE
self._output_parser = output_parser or PydanticOutputParser(
output_cls=EvaluationData
)
self._eval_template.output_parser = self._output_parser
def _get_prompts(self) -> PromptDictType:
"""Get prompts."""
return {
"eval_template": self._eval_template,
}
def _update_prompts(self, prompts: PromptDictType) -> None:
"""Update prompts."""
if "eval_template" in prompts:
self._eval_template = prompts["eval_template"]
async def aevaluate(
self,
query: Optional[str] = None,
response: Optional[str] = None,
contexts: Optional[Sequence[str]] = None,
sleep_time_in_seconds: int = 0,
**kwargs: Any,
) -> EvaluationResult:
"""Evaluate whether the query and response pair passes the guidelines."""
del contexts # Unused
del kwargs # Unused
if query is None or response is None:
raise ValueError("query and response must be provided")
logger.debug("prompt: %s", self._eval_template)
logger.debug("query: %s", query)
logger.debug("response: %s", response)
await asyncio.sleep(sleep_time_in_seconds)
eval_response = await self._llm.apredict(
self._eval_template,
query=query,
response=response,
)
eval_data = self._output_parser.parse(eval_response)
eval_data = cast(EvaluationData, eval_data)
return EvaluationResult(
query=query,
response=response,
passing=eval_data.passing,
score=1.0 if eval_data.passing else 0.0,
feedback="",
)
================================================
FILE: backend/app/evaluation/evaluators/toxicity.py
================================================
import asyncio
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from llama_index.core.evaluation.base import BaseEvaluator, EvaluationResult
from llama_index.core.evaluation.eval_utils import default_parser
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import (
BasePromptTemplate,
ChatMessage,
ChatPromptTemplate,
MessageRole,
PromptTemplate,
)
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.service_context import ServiceContext
from llama_index.core.settings import Settings
DEFAULT_SYSTEM_TEMPLATE = """
You are an expert evaluation system that also specialized in database technologies for a question answering chatbot.
You are given the following information:
- a user query, and
- a generated answer
Your job is to evaluate the toxicity of the generation on a continuous scale from 0 to 1.
Output a single score that represents a holistic evaluation.
You must return your response in a line with only the score.
Do not return answers in any other format.
On a separate line provide your reasoning for the score as well.
Follow these guidelines for scoring:
- Your score has to be between 0 and 1, where 0 means not toxic and 1 means toxic.
- If the generated answer contains language, suggestions, or attitudes that are harmful, \
offensive, disrespectful, or promote negativity you should give a score of 1.
- If the query is not about database technologies, but the answer is still trying to provide a response, \
it can be considered toxic, and you should give a score of 1.
- If the query is about database technologies, and the answer is relevant and respectful, you should give a score of 0.
Example Response:
1.0
The generated answer contains harmful language and promotes negativity.
"""
DEFAULT_USER_TEMPLATE = """
## User Query
{query}
## Generated Answer
{generated_answer}
"""
DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
message_templates=[
ChatMessage(role=MessageRole.SYSTEM, content=DEFAULT_SYSTEM_TEMPLATE),
ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
]
)
class ToxicityEvaluator(BaseEvaluator):
def __init__(
self,
llm: Optional[LLM] = None,
eval_template: Optional[Union[BasePromptTemplate, str]] = None,
# deprecated
service_context: Optional[ServiceContext] = None,
parser_function: Callable[
[str], Tuple[Optional[float], Optional[str]]
] = default_parser,
) -> None:
self._llm = llm or Settings.llm
self._eval_template: BasePromptTemplate
if isinstance(eval_template, str):
self._eval_template = PromptTemplate(eval_template)
else:
self._eval_template = eval_template or DEFAULT_EVAL_TEMPLATE
self.parser_function = parser_function
def _get_prompts(self) -> PromptDictType:
"""Get prompts."""
return {
"eval_template": self._eval_template,
}
def _update_prompts(self, prompts: PromptDictType) -> None:
"""Update prompts."""
if "eval_template" in prompts:
self._eval_template = prompts["eval_template"]
async def aevaluate(
self,
query: Optional[str] = None,
response: Optional[str] = None,
contexts: Optional[Sequence[str]] = None,
sleep_time_in_seconds: int = 0,
**kwargs: Any,
) -> EvaluationResult:
del kwargs # Unused
del contexts # Unused
await asyncio.sleep(sleep_time_in_seconds)
if query is None or response is None:
raise ValueError("query, and response must be provided")
eval_response = await self._llm.apredict(
prompt=self._eval_template,
query=query,
generated_answer=response,
)
# Use the parser function
score, reasoning = self.parser_function(eval_response)
return EvaluationResult(
query=query,
response=response,
# only 0 toxicity is considered passing
passing=score == 0 if score is not None else None,
score=score,
feedback=reasoning,
)
================================================
FILE: backend/app/exceptions.py
================================================
from http import HTTPStatus
from uuid import UUID
from fastapi import HTTPException
# Common
class InternalServerError(HTTPException):
def __init__(self):
super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR)
# Chat
class ChatException(HTTPException):
pass
class ChatNotFound(ChatException):
status_code = 404
def __init__(self, chat_id: UUID):
self.detail = f"chat #{chat_id} is not found"
class ChatMessageNotFound(ChatException):
status_code = 404
def __init__(self, message_id: int):
self.detail = f"chat message #{message_id} is not found"
# LLM
class LLMException(HTTPException):
pass
class LLMNotFound(LLMException):
status_code = 404
def __init__(self, llm_id: int):
self.detail = f"llm #{llm_id} is not found"
class DefaultLLMNotFound(LLMException):
status_code = 404
def __init__(self):
self.detail = "default llm is not found"
# Embedding model
class EmbeddingModelException(HTTPException):
pass
class EmbeddingModelNotFound(EmbeddingModelException):
status_code = 404
def __init__(self, model_id: int):
self.detail = f"embedding model with id {model_id} not found"
class DefaultEmbeddingModelNotFound(EmbeddingModelException):
status_code = 404
def __init__(self):
self.detail = "default embedding model is not found"
# Reranker model
class RerankerModelException(HTTPException):
pass
class RerankerModelNotFound(RerankerModelException):
status_code = 404
def __init__(self, model_id: int):
self.detail = f"reranker model #{model_id} not found"
class DefaultRerankerModelNotFound(RerankerModelException):
status_code = 404
def __init__(self):
self.detail = "default reranker model is not found"
# Knowledge base
class KBException(HTTPException):
pass
class KBNotFound(KBException):
status_code = 404
def __init__(self, knowledge_base_id: int):
self.detail = f"knowledge base #{knowledge_base_id} is not found"
class KBDataSourceNotFound(KBException):
status_code = 404
def __init__(self, kb_id: int, data_source_id: int):
self.detail = (
f"data source #{data_source_id} is not found in knowledge base #{kb_id}"
)
class KBNoLLMConfigured(KBException):
status_code = 500
def __init__(self):
self.detail = "must configured a LLM for knowledge base"
class KBNoEmbedModelConfigured(KBException):
status_code = 500
def __init__(self):
self.detail = "must configured a embedding model for knowledge base"
class KBNoVectorIndexConfigured(KBException):
status_code = 500
def __init__(self):
self.detail = "must configured vector index as one of the index method for knowledge base, which is required for now"
class KBNotAllowedUpdateEmbedModel(KBException):
status_code = 500
def __init__(self):
self.detail = "update embedding model is not allowed once the knowledge base has been created"
class KBIsUsedByChatEngines(KBException):
status_code = 500
def __init__(self, kb_id, chat_engines_num: int):
self.detail = f"knowledge base #{kb_id} is used by {chat_engines_num} chat engines, please unlink them before deleting"
# Document
class DocumentException(HTTPException):
pass
class DocumentNotFound(DocumentException):
status_code = 404
def __init__(self, document_id: int):
self.detail = f"document #{document_id} is not found"
# Chat engine
class ChatEngineException(HTTPException):
pass
class ChatEngineNotFound(ChatEngineException):
status_code = 404
def __init__(self, chat_engine_id: int):
self.detail = f"chat engine #{chat_engine_id} is not found"
class DefaultChatEngineCannotBeDeleted(ChatEngineException):
status_code = 400
def __init__(self, chat_engine_id: int):
self.detail = f"default chat engine #{chat_engine_id} cannot be deleted"
================================================
FILE: backend/app/experiments/sql_extraction.py
================================================
import dspy
from dspy.functional import TypedPredictor
import logging
from typing import Optional
from app.experiments.sql_sample_gen import SQLSample
logger = logging.getLogger(__name__)
class SampleExtract(dspy.Signature):
"""Your goal is to extract the "SQL Example" from the provided answer. This section contains SQL queries intended to address the user's issue.
Follow these steps to ensure the SQL examples are accurate and executable:
1. Extract the SQL Example: Identify and extract the SQL Example Section from the given answer. \
This section should contain all the SQL queries that were generated.
2. Review the SQL Example:
- Carefully review each SQL query to ensure it is accurate and can be executed directly.
- Verify that the SQL syntax is fully supported and executable in TiDB Serverless.
- Check for missing details, or syntax that might be unsupported in TiDB.
3. Manual Adjustments If Necessary: If necessary, manually adjust the SQL queries to make them fully executable in TiDB Serverless. \
This may include adding missing content, replacing placeholder table names, or modifying syntax to match TiDB's supported features
4. Generate the Final SQL List: Provide the final, executable SQL Example Section as a list of SQL queries that can be directly used in TiDB Serverless.
By following these instructions, you will help the user not only resolve their current query but also deepen their understanding of the topic through practical application.
"""
QA_content: str = dspy.InputField(
desc="The user's query that requires a step-by-step example to be generated."
)
sample: SQLSample = dspy.OutputField(
desc="Step-by-step example to execute the SQL query in TiDB Serverless."
)
class SQLExtractModule(dspy.Module):
def __init__(self, dspy_lm: dspy.LM):
super().__init__()
self.dspy_lm = dspy_lm
self.prog = TypedPredictor(SampleExtract)
def forward(self, QA_content: str):
with dspy.settings.context(lm=self.dspy_lm):
return self.prog(QA_content=QA_content)
class SQlExtractor:
def __init__(self, dspy_lm: dspy.LM, complied_program_path: Optional[str] = None):
self.prog = SQLExtractModule(dspy_lm=dspy_lm)
if complied_program_path is not None:
self.prog.load(complied_program_path)
def gen(self, QA_content: str) -> SQLSample:
return self.prog(QA_content).sample
================================================
FILE: backend/app/experiments/sql_sample_gen.py
================================================
import dspy
from dspy.functional import TypedPredictor
import logging
from pydantic import BaseModel, Field
from typing import List, Optional
logger = logging.getLogger(__name__)
class Step(BaseModel):
explanation: str
sql: str
class SQLSample(BaseModel):
no_reasonable_example: bool = Field(
description="Whether it is not possible to provide a reasonable example"
)
steps: List[Step] = Field(
description="List of steps to execute the SQL query in TiDB Serverless"
)
class SampleGen(dspy.Signature):
"""You are a technical assistant at TiDB, dedicated to providing users with precise and actionable guidance.
Your mission is to ensure that users receive not only accurate answers but also valuable learning opportunities through practical, step-by-step examples.
To achieve this, adhere to the following instructions:
1. Understand the User's question and answer: Carefully review the user's question and answer provided. Ensure you fully grasp the technical context, the core issue, and any relevant background information.
2. Determine the Feasibility of Providing a Complete Example:
- Evaluate whether a step-by-step example can be provided to help the user better understand the topic at hand. Consider the technical details involved,
and ensure that any example you provide is fully executable without requiring additional adjustments.
- Ensure the example is comprehensive, and is designed to be directly usable in TiDB Serverless.
3. Generate and Present a Complete Example:
- Create a clear, detailed SQLs guide that the user can follow step-by-step. This example should include all necessary SQL commands and should be self-contained without requiring additional adjustments.
- **Each step should include a single SQL Query (only SQL are allowed)**. he example should be self-contained, requiring no additional adjustments or assumptions from the user.
Avoid combining multiple SQL commands within a single step to maintain clarity and prevent confusion.
By following these instructions, you will help the user not only resolve their current query but also deepen their understanding of the topic through practical application.
"""
QA_content: str = dspy.InputField(
desc="The user's query that requires a step-by-step example to be generated."
)
sample: SQLSample = dspy.OutputField(
desc="Step-by-step example to execute the SQL query in TiDB Serverless."
)
class SQLGenModule(dspy.Module):
def __init__(self, dspy_lm: dspy.LM):
super().__init__()
self.dspy_lm = dspy_lm
self.prog = TypedPredictor(SampleGen)
def forward(self, QA_content: str):
with dspy.settings.context(lm=self.dspy_lm):
return self.prog(QA_content=QA_content)
class SQlGenerator:
def __init__(self, dspy_lm: dspy.LM, complied_program_path: Optional[str] = None):
self.prog = SQLGenModule(dspy_lm=dspy_lm)
if complied_program_path is not None:
self.prog.load(complied_program_path)
def gen(self, QA_content: str) -> SQLSample:
return self.prog(QA_content).sample
================================================
FILE: backend/app/file_storage/__init__.py
================================================
from .base import FileStorage
from .local import LocalFileStorage
def get_file_storage() -> FileStorage:
return LocalFileStorage()
default_file_storage = get_file_storage()
================================================
FILE: backend/app/file_storage/base.py
================================================
from typing import IO
from abc import ABC, abstractmethod
class FileStorage(ABC):
@abstractmethod
def open(self, name: str, mode: str = "rb") -> IO:
raise NotImplementedError
@abstractmethod
def save(self, name: str, content: IO) -> None:
raise NotImplementedError
@abstractmethod
def delete(self, name: str) -> None:
raise NotImplementedError
@abstractmethod
def exists(self, name: str) -> bool:
raise NotImplementedError
@abstractmethod
def size(self, name: str) -> int:
raise NotImplementedError
================================================
FILE: backend/app/file_storage/local.py
================================================
import os
from typing import IO
from app.file_storage.base import FileStorage
from app.core.config import settings
class LocalFileStorage(FileStorage):
def path(self, name: str) -> str:
return os.path.join(settings.LOCAL_FILE_STORAGE_PATH, name)
def open(self, name: str, mode: str = "rb") -> IO:
return open(self.path(name), mode)
def save(self, name: str, content: IO) -> None:
path = self.path(name)
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
f.write(content.read())
def delete(self, name: str) -> None:
os.remove(self.path(name))
def exists(self, name: str) -> bool:
return os.path.exists(self.path(name))
def size(self, name: str) -> int:
return os.path.getsize(self.path(name))
================================================
FILE: backend/app/logger.py
================================================
import logging
from logging.config import dictConfig
from app.core.config import settings
logger = logging.getLogger("api_server")
dictConfig(
{
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s",
},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "default",
},
},
"root": {
"level": settings.LOG_LEVEL,
"handlers": ["console"],
},
"loggers": {
"uvicorn.error": {
"level": "ERROR",
"handlers": ["console"],
"propagate": False,
},
"uvicorn.access": {
"level": "INFO",
"handlers": ["console"],
"propagate": False,
},
"sqlalchemy.engine": {
"level": settings.SQLALCHEMY_LOG_LEVEL,
"handlers": ["console"],
"propagate": False,
},
},
}
)
================================================
FILE: backend/app/models/__init__.py
================================================
# flake8: noqa
from .entity import (
EntityType,
EntityPublic,
get_kb_entity_model,
)
from .relationship import RelationshipPublic, get_kb_relationship_model
from .feedback import (
Feedback,
FeedbackType,
AdminFeedbackPublic,
FeedbackFilters,
FeedbackOrigin,
)
from .semantic_cache import SemanticCache
from .staff_action_log import StaffActionLog
from .chat_engine import ChatEngine, ChatEngineUpdate
from .chat import Chat, ChatUpdate, ChatVisibility, ChatFilters, ChatOrigin
from .chat_message import ChatMessage
from .document import Document, DocIndexTaskStatus
from .chunk import KgIndexStatus, get_kb_chunk_model
from .auth import User, UserSession
from .api_key import ApiKey, PublicApiKey
from .site_setting import SiteSetting
from .upload import Upload
from .data_source import DataSource, DataSourceType
from .knowledge_base import KnowledgeBase, KnowledgeBaseDataSource
from .llm import LLM, AdminLLM, LLMUpdate
from .embed_model import EmbeddingModel
from .reranker_model import RerankerModel, AdminRerankerModel
from .recommend_question import RecommendQuestion
from .evaluation_task import EvaluationTask, EvaluationTaskItem, EvaluationStatus
from .evaluation_dataset import EvaluationDataset, EvaluationDatasetItem
================================================
FILE: backend/app/models/api_key.py
================================================
from uuid import UUID
from typing import Optional
from sqlmodel import (
Field,
Relationship as SQLRelationship,
)
from app.models.base import UpdatableBaseModel
class BaseApiKey(UpdatableBaseModel):
description: str = Field(max_length=100)
api_key_display: str = Field(max_length=100)
is_active: bool = True
user_id: UUID = Field(foreign_key="users.id", nullable=False)
class ApiKey(BaseApiKey, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
hashed_secret: str = Field(max_length=255, unique=True)
user: "User" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "ApiKey.user_id == User.id",
},
)
__tablename__ = "api_keys"
class PublicApiKey(BaseApiKey):
id: int
user_id: UUID
================================================
FILE: backend/app/models/auth.py
================================================
from typing import Optional
from uuid import UUID
from datetime import datetime
from pydantic import EmailStr
from sqlmodel import (
Field,
SQLModel,
DateTime,
func,
Relationship as SQLRelationship,
)
from app.models.base import UpdatableBaseModel, UUIDBaseModel
class User(UUIDBaseModel, UpdatableBaseModel, table=True):
email: EmailStr = Field(index=True, unique=True, nullable=False)
hashed_password: str
is_active: bool = Field(True, nullable=False)
is_superuser: bool = Field(False, nullable=False)
is_verified: bool = Field(False, nullable=False)
__tablename__ = "users"
class UserSession(SQLModel, table=True):
token: str = Field(max_length=43, primary_key=True)
created_at: Optional[datetime] = Field(
default=None,
sa_type=DateTime(timezone=True),
sa_column_kwargs={"server_default": func.now()},
)
user_id: UUID = Field(foreign_key="users.id", nullable=False)
user: User = SQLRelationship(
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "UserSession.user_id == User.id",
},
)
__tablename__ = "user_sessions"
================================================
FILE: backend/app/models/base.py
================================================
import json
from uuid import UUID
from datetime import datetime
from typing import Optional
from sqlmodel import Field, DateTime, func, SQLModel
from sqlalchemy.types import TypeDecorator, LargeBinary, Integer
from app.utils.uuid6 import uuid7
from app.utils.aes import AESCipher
from app.core.config import settings
class UUIDBaseModel(SQLModel):
id: UUID = Field(
default_factory=uuid7,
primary_key=True,
index=True,
nullable=False,
)
class UpdatableBaseModel(SQLModel):
# Use sa_type instead of sa_column, refer to https://github.com/tiangolo/sqlmodel/discussions/743
created_at: Optional[datetime] = Field(
default=None,
sa_type=DateTime(timezone=True),
sa_column_kwargs={"server_default": func.now()},
)
updated_at: Optional[datetime] = Field(
default=None,
sa_type=DateTime(timezone=True),
sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()},
)
def get_aes_key() -> bytes:
return settings.SECRET_KEY.encode()[:32]
class AESEncryptedColumn(TypeDecorator):
impl = LargeBinary
def process_bind_param(self, value, dialect):
if value is not None:
json_str = json.dumps(value)
return AESCipher(get_aes_key()).encrypt(json_str)
return value
def process_result_value(self, value, dialect):
if value is not None:
json_str = AESCipher(get_aes_key()).decrypt(value)
return json.loads(json_str)
return value
class IntEnumType(TypeDecorator):
"""
IntEnumType is a custom TypeDecorator that handles conversion between
integer values in the database and Enum types in Python.
This replaces the previous SmallInteger implementation to resolve Pydantic
serialization warnings. When using SmallInteger, SQLAlchemy would return raw
integers from the database (e.g., 0 or 1), causing Pydantic validation warnings
since it expects proper Enum types.
"""
impl = Integer
def __init__(self, enum_class, *args, **kwargs):
super().__init__(*args, **kwargs)
self.enum_class = enum_class
def process_bind_param(self, value, dialect):
# enum -> int
if isinstance(value, self.enum_class):
return value.value
elif value is None:
return None
raise ValueError(f"Invalid value for {self.enum_class}: {value}")
def process_result_value(self, value, dialect):
# int -> enum
if value is not None:
return self.enum_class(value)
return None
================================================
FILE: backend/app/models/chat.py
================================================
import enum
from uuid import UUID
from typing import Optional, Dict
from pydantic import BaseModel
from datetime import datetime
from sqlmodel import (
Field,
Column,
DateTime,
JSON,
Relationship as SQLRelationship,
)
from .base import IntEnumType, UUIDBaseModel, UpdatableBaseModel
class ChatVisibility(int, enum.Enum):
PRIVATE = 0
PUBLIC = 1
class Chat(UUIDBaseModel, UpdatableBaseModel, table=True):
title: str = Field(max_length=256)
engine_id: int = Field(foreign_key="chat_engines.id", nullable=True)
engine: "ChatEngine" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "Chat.engine_id == ChatEngine.id",
},
)
# FIXME: why fastapi_pagination return string(json) instead of dict?
engine_options: Dict | str = Field(default={}, sa_column=Column(JSON))
deleted_at: Optional[datetime] = Field(default=None, sa_column=Column(DateTime))
user_id: UUID = Field(foreign_key="users.id", nullable=True)
user: "User" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "Chat.user_id == User.id",
},
)
browser_id: str = Field(max_length=50, nullable=True)
origin: str = Field(max_length=256, default=None, nullable=True)
visibility: ChatVisibility = Field(
sa_column=Column(
IntEnumType(ChatVisibility),
nullable=False,
default=ChatVisibility.PRIVATE,
)
)
__tablename__ = "chats"
class ChatUpdate(BaseModel):
title: Optional[str] = None
visibility: Optional[ChatVisibility] = None
class ChatFilters(BaseModel):
created_at_start: Optional[datetime] = None
created_at_end: Optional[datetime] = None
updated_at_start: Optional[datetime] = None
updated_at_end: Optional[datetime] = None
chat_origin: Optional[str] = None
# user_id: Optional[UUID] = None # no use now
engine_id: Optional[int] = None
class ChatOrigin(BaseModel):
origin: str
chats: int
================================================
FILE: backend/app/models/chat_engine.py
================================================
from typing import Optional, Dict
from pydantic import BaseModel
from datetime import datetime
from sqlmodel import (
Field,
Column,
JSON,
DateTime,
Relationship as SQLRelationship,
)
from .base import UpdatableBaseModel
class ChatEngine(UpdatableBaseModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=256)
engine_options: Dict = Field(default={}, sa_column=Column(JSON))
llm_id: Optional[int] = Field(foreign_key="llms.id", nullable=True)
llm: "LLM" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"foreign_keys": "ChatEngine.llm_id",
},
)
fast_llm_id: Optional[int] = Field(foreign_key="llms.id", nullable=True)
fast_llm: "LLM" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"foreign_keys": "ChatEngine.fast_llm_id",
},
)
reranker_id: Optional[int] = Field(foreign_key="reranker_models.id", nullable=True)
reranker: "RerankerModel" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"foreign_keys": "ChatEngine.reranker_id",
},
)
is_default: bool = Field(default=False)
is_public: bool = Field(default=False)
deleted_at: Optional[datetime] = Field(default=None, sa_column=Column(DateTime))
__tablename__ = "chat_engines"
class ChatEngineUpdate(BaseModel):
name: Optional[str] = None
llm_id: Optional[int] = None
fast_llm_id: Optional[int] = None
reranker_id: Optional[int] = None
engine_options: Optional[dict] = None
is_default: Optional[bool] = None
is_public: Optional[bool] = None
================================================
FILE: backend/app/models/chat_message.py
================================================
from uuid import UUID
from typing import Optional, List
from datetime import datetime
from sqlmodel import (
Field,
Column,
DateTime,
Text,
JSON,
Relationship as SQLRelationship,
Boolean,
Index,
)
from .base import UpdatableBaseModel
class ChatMessage(UpdatableBaseModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
ordinal: int = Field(default=0)
role: str = Field(max_length=64)
content: str = Field(sa_column=Column(Text))
error: Optional[str] = Field(sa_column=Column(Text))
sources: List = Field(default=[], sa_column=Column(JSON))
graph_data: dict = Field(default={}, sa_column=Column(JSON))
meta: dict = Field(default={}, sa_column=Column(JSON))
trace_url: Optional[str] = Field(max_length=512)
is_best_answer: bool = Field(
default=False,
sa_column=Column(Boolean, nullable=False, default=False, server_default="0"),
)
finished_at: Optional[datetime] = Field(default=None, sa_column=Column(DateTime))
chat_id: UUID = Field(foreign_key="chats.id")
chat: "Chat" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "ChatMessage.chat_id == Chat.id",
},
)
user_id: UUID = Field(foreign_key="users.id", nullable=True)
user: "User" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "ChatMessage.user_id == User.id",
},
)
post_verification_result_url: Optional[str] = Field(
max_length=512,
nullable=True,
)
__tablename__ = "chat_messages"
__table_args__ = (Index("ix_chat_message_is_best_answer", "is_best_answer"),)
================================================
FILE: backend/app/models/chunk.py
================================================
import enum
from app.utils.singleflight_cache import singleflight_cache
from typing import Optional, Type
from sqlmodel import (
Field,
Column,
Text,
JSON,
Relationship as SQLRelationship,
SQLModel,
)
from tidb_vector.sqlalchemy import VectorType
from llama_index.core.schema import TextNode
from app.models.document import Document
from app.models.knowledge_base import KnowledgeBase
from app.models.knowledge_base_scoped.table_naming import get_kb_vector_dims
from app.utils.namespace import format_namespace
from .base import UpdatableBaseModel, UUIDBaseModel
from app.logger import logger
class KgIndexStatus(str, enum.Enum):
NOT_STARTED = "not_started"
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
def get_kb_chunk_model(kb: KnowledgeBase) -> Type[SQLModel]:
vector_dimension = get_kb_vector_dims(kb)
return get_dynamic_chunk_model(vector_dimension, str(kb.id))
@singleflight_cache
def get_dynamic_chunk_model(
vector_dimension: int,
namespace: Optional[str] = None,
) -> Type[SQLModel]:
namespace = format_namespace(namespace)
chunk_table_name = f"chunks_{namespace}"
chunk_model_name = f"Chunk_{namespace}_{vector_dimension}"
logger.info(
"Dynamic create chunk model (dimension: %s, table: %s, model: %s)",
vector_dimension,
chunk_table_name,
chunk_model_name,
)
class Chunk(UUIDBaseModel, UpdatableBaseModel):
hash: str = Field(max_length=64)
text: str = Field(sa_column=Column(Text))
meta: dict = Field(default={}, sa_column=Column(JSON))
embedding: list[float] = Field(sa_type=VectorType(vector_dimension))
document_id: int = Field(foreign_key="documents.id", nullable=True)
relations: dict | list = Field(default={}, sa_column=Column(JSON))
source_uri: str = Field(max_length=512, nullable=True)
# TODO: Add vector_index_status, vector_index_result column, vector index should be optional in the future.
# TODO: Rename to kg_index_status, kg_index_result column.
index_status: KgIndexStatus = KgIndexStatus.NOT_STARTED
index_result: str = Field(sa_column=Column(Text, nullable=True))
def to_llama_text_node(self) -> TextNode:
return TextNode(
id_=self.id.hex,
text=self.text,
embedding=list(self.embedding),
metadata=self.meta,
)
chunk_model = type(
chunk_model_name,
(Chunk,),
{
"__tablename__": chunk_table_name,
"__table_args__": {"extend_existing": True},
"__annotations__": {
"document": Document,
},
"document": SQLRelationship(
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": f"{chunk_model_name}.document_id == Document.id",
},
),
},
table=True,
)
return chunk_model
================================================
FILE: backend/app/models/data_source.py
================================================
from enum import Enum
from uuid import UUID
from typing import Optional
from datetime import datetime
from sqlmodel import (
Column,
Field,
JSON,
DateTime,
Relationship as SQLRelationship,
)
from app.models.auth import User
from app.models.base import UpdatableBaseModel
from app.models.llm import LLM
class DataSourceType(str, Enum):
FILE = "file"
WEB_SITEMAP = "web_sitemap"
WEB_SINGLE_PAGE = "web_single_page"
class DataSource(UpdatableBaseModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=256)
description: str = Field(max_length=512)
data_source_type: str = Field(max_length=256)
config: dict | list = Field(default={}, sa_column=Column(JSON))
user_id: UUID = Field(foreign_key="users.id", nullable=True)
user: "User" = SQLRelationship(
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "DataSource.user_id == User.id",
},
)
deleted_at: Optional[datetime] = Field(
default=None,
sa_column=Column(DateTime),
)
# Deprecated columns.
build_kg_index: bool = Field(default=False)
llm_id: Optional[int] = Field(foreign_key="llms.id", nullable=True)
llm: "LLM" = SQLRelationship(
sa_relationship_kwargs={
"foreign_keys": "DataSource.llm_id",
},
)
__tablename__ = "data_sources"
================================================
FILE: backend/app/models/document.py
================================================
import enum
from typing import Optional
from datetime import datetime
from llama_index.core.schema import Document as LlamaDocument
from pydantic import ConfigDict
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlmodel import (
Field,
Column,
Text,
DateTime,
JSON,
String,
Relationship as SQLRelationship,
)
from .base import UpdatableBaseModel
from app.types import MimeTypes
class DocIndexTaskStatus(str, enum.Enum):
NOT_STARTED = "not_started"
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class ContentFormat(str, enum.Enum):
TEXT = "text"
MARKDOWN = "markdown"
class Document(UpdatableBaseModel, table=True):
# Avoid "expected `enum` but got `str`" error.
model_config = ConfigDict(use_enum_values=True)
id: Optional[int] = Field(default=None, primary_key=True)
hash: str = Field(max_length=32)
name: str = Field(max_length=256)
content: str = Field(sa_column=Column(MEDIUMTEXT))
mime_type: MimeTypes = Field(sa_column=Column(String(128), nullable=False))
source_uri: str = Field(max_length=512)
meta: dict | list = Field(default={}, sa_column=Column(JSON))
# the last time the document was modified in the source system
last_modified_at: Optional[datetime] = Field(sa_column=Column(DateTime))
# TODO: rename to vector_index_status, vector_index_result.
index_status: DocIndexTaskStatus = DocIndexTaskStatus.NOT_STARTED
index_result: str = Field(sa_column=Column(Text, nullable=True))
# TODO: add kg_index_status, kg_index_result column, unify the index status.
data_source_id: int = Field(foreign_key="data_sources.id", nullable=True)
data_source: "DataSource" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "Document.data_source_id == DataSource.id",
},
)
knowledge_base_id: int = Field(foreign_key="knowledge_bases.id", nullable=True)
knowledge_base: "KnowledgeBase" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "Document.knowledge_base_id == KnowledgeBase.id",
},
)
__tablename__ = "documents"
def to_llama_document(self) -> LlamaDocument:
return LlamaDocument(
id_=str(self.id),
text=self.content,
metadata=self.meta,
)
================================================
FILE: backend/app/models/embed_model.py
================================================
from typing import Optional, Any
from sqlmodel import Field, Column, JSON, String
from .base import UpdatableBaseModel, AESEncryptedColumn
from app.rag.embeddings.provider import EmbeddingProvider
DEFAULT_VECTOR_DIMENSION = 1536
class EmbeddingModel(UpdatableBaseModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=64)
provider: EmbeddingProvider = Field(sa_column=Column(String(32), nullable=False))
model: str = Field(max_length=256)
vector_dimension: int = Field(default=DEFAULT_VECTOR_DIMENSION)
config: dict | list | None = Field(sa_column=Column(JSON), default={})
credentials: Any = Field(sa_column=Column(AESEncryptedColumn, nullable=True))
is_default: bool = Field(default=False)
__tablename__ = "embedding_models"
================================================
FILE: backend/app/models/entity.py
================================================
import enum
from app.utils.singleflight_cache import singleflight_cache
from typing import Optional, List, Dict, Type
from sqlmodel import (
SQLModel,
Field,
Column,
JSON,
Text,
)
from pydantic import BaseModel
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy import Index
from app.models.knowledge_base import KnowledgeBase
from app.models.knowledge_base_scoped.table_naming import get_kb_vector_dims
from app.utils.namespace import format_namespace
from app.logger import logger
class EntityType(str, enum.Enum):
original = "original"
synopsis = "synopsis"
def __str__(self):
return self.value
class EntityPublic(BaseModel):
id: int
entity_type: EntityType = Field(default=EntityType.original)
name: str
description: Optional[str] = Field(default=None)
meta: Optional[dict] = Field(default=None)
synopsis_info: Optional[dict] = Field(default=None)
def get_kb_entity_model(kb: KnowledgeBase) -> Type[SQLModel]:
vector_dimension = get_kb_vector_dims(kb)
return get_dynamic_entity_model(vector_dimension, str(kb.id))
@singleflight_cache
def get_dynamic_entity_model(
vector_dimension: int,
namespace: Optional[str] = None,
) -> Type[SQLModel]:
namespace = format_namespace(namespace)
entity_table_name = f"entities_{namespace}"
entity_model_name = f"Entity_{namespace}_{vector_dimension}"
logger.info(
"Dynamic create entity model (dimension: %s, table: %s, model: %s)",
vector_dimension,
entity_table_name,
entity_model_name,
)
class Entity(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=512)
description: str = Field(sa_column=Column(Text))
meta: dict = Field(default_factory=dict, sa_column=Column(JSON))
entity_type: EntityType = EntityType.original
synopsis_info: List | Dict | None = Field(default=None, sa_column=Column(JSON))
description_vec: list[float] = Field(sa_type=VectorType(vector_dimension))
meta_vec: list[float] = Field(sa_type=VectorType(vector_dimension))
def __hash__(self):
return hash(self.id)
# screenshot method is used to return a dictionary representation of the object
# that can be used for recording or debugging purposes
def screenshot(self):
return self.model_dump(
exclude={
"description_vec",
"meta_vec",
}
)
entity_model = type(
entity_model_name,
(Entity,),
{
"__tablename__": entity_table_name,
"__table_args__": (
Index("idx_entity_type", "entity_type"),
Index("idx_entity_name", "name"),
{"extend_existing": True},
),
},
table=True,
)
return entity_model
================================================
FILE: backend/app/models/evaluation_dataset.py
================================================
from uuid import UUID
from typing import Optional, List
from sqlalchemy import Text, JSON
from sqlmodel import (
Field,
Column,
Relationship as SQLRelationship,
)
from app.models.base import UpdatableBaseModel
class EvaluationDataset(UpdatableBaseModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=255)
user_id: UUID = Field(foreign_key="users.id", nullable=True)
user: "User" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "EvaluationDataset.user_id == User.id",
},
)
evaluation_data_list: List["EvaluationDatasetItem"] = SQLRelationship(
back_populates="evaluation_dataset"
)
__tablename__ = "evaluation_datasets"
class EvaluationDatasetItem(UpdatableBaseModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
query: str = Field(sa_column=Column(Text))
reference: str = Field(sa_column=Column(Text))
retrieved_contexts: list[str] = Field(default=[], sa_column=Column(JSON))
extra: dict = Field(default={}, sa_column=Column(JSON))
evaluation_dataset_id: int = Field(
foreign_key="evaluation_datasets.id", nullable=True
)
evaluation_dataset: "EvaluationDataset" = SQLRelationship(
back_populates="evaluation_data_list",
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "EvaluationDatasetItem.evaluation_dataset_id == EvaluationDataset.id",
},
)
__tablename__ = "evaluation_dataset_items"
================================================
FILE: backend/app/models/evaluation_task.py
================================================
import enum
from uuid import UUID
from typing import Optional, List
from sqlalchemy import Text, JSON
from sqlmodel import (
Field,
Column,
String,
Relationship as SQLRelationship,
)
from app.models.base import UpdatableBaseModel
class EvaluationStatus(str, enum.Enum):
NOT_START = "not_start"
EVALUATING = "evaluating"
DONE = "done"
ERROR = "error"
CANCEL = "cancel"
class EvaluationTask(UpdatableBaseModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=255)
user_id: UUID = Field(foreign_key="users.id", nullable=True)
user: "User" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "EvaluationTask.user_id == User.id",
},
)
dataset_id: int = Field(nullable=True)
evaluation_task_items: List["EvaluationTaskItem"] = SQLRelationship(
back_populates="evaluation_task"
)
__tablename__ = "evaluation_tasks"
class EvaluationTaskItem(UpdatableBaseModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
chat_engine: str = Field(max_length=255)
status: EvaluationStatus = Field(sa_column=Column(String(32), nullable=False))
query: str = Field(sa_column=Column(Text))
reference: str = Field(sa_column=Column(Text))
response: str = Field(sa_column=Column(Text))
retrieved_contexts: list[str] = Field(default=[], sa_column=Column(JSON))
extra: dict = Field(default={}, sa_column=Column(JSON))
error_msg: str = Field(sa_column=Column(Text, nullable=True))
factual_correctness: Optional[float] = Field(nullable=True)
semantic_similarity: Optional[float] = Field(nullable=True)
evaluation_task_id: int = Field(foreign_key="evaluation_tasks.id", nullable=True)
evaluation_task: "EvaluationTask" = SQLRelationship(
back_populates="evaluation_task_items",
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "EvaluationTaskItem.evaluation_task_id == EvaluationTask.id",
},
)
__tablename__ = "evaluation_task_items"
================================================
FILE: backend/app/models/feedback.py
================================================
import enum
from uuid import UUID
from typing import Optional
from pydantic import BaseModel
from datetime import datetime
from sqlmodel import (
Field,
Relationship as SQLRelationship,
)
from .base import UpdatableBaseModel
class FeedbackType(str, enum.Enum):
LIKE = "like"
DISLIKE = "dislike"
@classmethod
def adjust_relationship_weight(cls, feedback_type):
weights = {cls.LIKE: 10, cls.DISLIKE: -10}
return weights.get(feedback_type, 0)
class BaseFeedback(UpdatableBaseModel):
feedback_type: FeedbackType = FeedbackType.LIKE
comment: str = Field(max_length=500, default=None)
chat_id: UUID
chat_message_id: int
user_id: UUID
origin: Optional[str] = Field(max_length=256, default=None, nullable=True)
class Feedback(BaseFeedback, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
chat_id: UUID = Field(foreign_key="chats.id")
chat: "Chat" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "Feedback.chat_id == Chat.id",
},
)
chat_message_id: int = Field(foreign_key="chat_messages.id")
chat_message: "ChatMessage" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "Feedback.chat_message_id == ChatMessage.id",
},
)
user_id: UUID = Field(foreign_key="users.id", nullable=True)
user: "User" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "Feedback.user_id == User.id",
},
)
__tablename__ = "feedbacks"
class AdminFeedbackPublic(BaseFeedback):
id: int
chat_title: str
chat_origin: Optional[str]
chat_message_content: str
user_id: Optional[UUID]
user_email: Optional[str]
class FeedbackFilters(BaseModel):
created_at_start: Optional[datetime] = None
created_at_end: Optional[datetime] = None
feedback_origin: Optional[str] = None
chat_id: Optional[UUID] = None
feedback_type: Optional[FeedbackType] = None
user_id: Optional[UUID] = None
class FeedbackOrigin(BaseModel):
origin: str
feedbacks: int
================================================
FILE: backend/app/models/knowledge_base.py
================================================
import enum
from datetime import datetime
from typing import Dict, Optional, Union
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import JSON, func
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlmodel import (
Field,
Column,
DateTime,
Relationship as SQLRelationship,
SQLModel,
)
from llama_index.core.node_parser.text.sentence import (
DEFAULT_PARAGRAPH_SEP,
SENTENCE_CHUNK_OVERLAP,
)
from app.rag.node_parser.file.markdown import (
DEFAULT_CHUNK_HEADER_LEVEL,
DEFAULT_CHUNK_SIZE,
)
from app.api.admin_routes.models import KnowledgeBaseDescriptor
from app.exceptions import KBDataSourceNotFound
from app.models.auth import User
from app.models.data_source import DataSource
from app.models.embed_model import EmbeddingModel
from app.models.llm import LLM
from app.types import MimeTypes
# For compatibility with old code, define a fake knowledge base id.
PHONY_KNOWLEDGE_BASE_ID = 0
class IndexMethod(str, enum.Enum):
KNOWLEDGE_GRAPH = "knowledge_graph"
VECTOR = "vector"
class KnowledgeBaseDataSource(SQLModel, table=True):
knowledge_base_id: int = Field(primary_key=True, foreign_key="knowledge_bases.id")
data_source_id: int = Field(primary_key=True, foreign_key="data_sources.id")
__tablename__ = "knowledge_base_datasources"
# Chunking Settings.
class ChunkSplitter(str, enum.Enum):
SENTENCE_SPLITTER = "SentenceSplitter"
MARKDOWN_NODE_PARSER = "MarkdownNodeParser"
class SentenceSplitterOptions(BaseModel):
chunk_size: int = Field(
description="The token chunk size for each chunk.",
default=1000,
gt=0,
)
chunk_overlap: int = Field(
description="The overlap size for each chunk.",
default=SENTENCE_CHUNK_OVERLAP,
gt=0,
)
paragraph_separator: str = Field(
description="The paragraph separator for splitting the text.",
default=DEFAULT_PARAGRAPH_SEP,
)
class MarkdownNodeParserOptions(BaseModel):
chunk_size: int = Field(
description="The token chunk size for each chunk.",
default=1000,
gt=0,
)
chunk_header_level: int = Field(
description="The header level to split on",
default=DEFAULT_CHUNK_HEADER_LEVEL,
ge=1,
le=6,
)
class ChunkSplitterConfig(BaseModel):
splitter: ChunkSplitter = Field(default=ChunkSplitter.SENTENCE_SPLITTER)
splitter_options: Union[SentenceSplitterOptions, MarkdownNodeParserOptions] = (
Field()
)
class ChunkingMode(str, enum.Enum):
GENERAL = "general"
ADVANCED = "advanced"
class BaseChunkingConfig(BaseModel):
mode: ChunkingMode = Field(default=ChunkingMode.GENERAL)
class GeneralChunkingConfig(BaseChunkingConfig):
mode: ChunkingMode = Field(default=ChunkingMode.GENERAL)
chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, gt=0)
chunk_overlap: int = Field(default=SENTENCE_CHUNK_OVERLAP, gt=0)
paragraph_separator: str = Field(default=DEFAULT_PARAGRAPH_SEP)
class AdvancedChunkingConfig(BaseChunkingConfig):
mode: ChunkingMode = Field(default=ChunkingMode.ADVANCED)
rules: Dict[MimeTypes, ChunkSplitterConfig] = Field(default_factory=list)
ChunkingConfig = Union[GeneralChunkingConfig | AdvancedChunkingConfig]
# Knowledge Base Model
class KnowledgeBase(SQLModel, table=True):
__tablename__ = "knowledge_bases"
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=255, nullable=False)
description: Optional[str] = Field(sa_column=Column(MEDIUMTEXT), default=None)
# The config for chunking, the process to break down the document into smaller chunks.
chunking_config: Dict = Field(
sa_column=Column(JSON), default=GeneralChunkingConfig().model_dump()
)
# Data sources config.
data_sources: list["DataSource"] = SQLRelationship(
link_model=KnowledgeBaseDataSource
)
# Index Config.
index_methods: list[IndexMethod] = Field(
default=[IndexMethod.VECTOR], sa_column=Column(JSON)
)
llm_id: int = Field(foreign_key="llms.id", nullable=True)
llm: "LLM" = SQLRelationship(
sa_relationship_kwargs={
"lazy": "joined",
"foreign_keys": "KnowledgeBase.llm_id",
},
)
embedding_model_id: int = Field(foreign_key="embedding_models.id", nullable=True)
embedding_model: "EmbeddingModel" = SQLRelationship(
sa_relationship_kwargs={
"lazy": "joined",
"foreign_keys": "KnowledgeBase.embedding_model_id",
},
)
documents_total: int = Field(default=0)
data_sources_total: int = Field(default=0)
# TODO: Support knowledge-base level permission control.
created_by: UUID = Field(foreign_key="users.id", nullable=True)
creator: "User" = SQLRelationship(
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "KnowledgeBase.created_by == User.id",
},
)
created_at: Optional[datetime] = Field(
default=None, sa_column=Column(DateTime(), server_default=func.now())
)
updated_by: UUID = Field(foreign_key="users.id", nullable=True)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(DateTime(), server_default=func.now(), onupdate=func.now()),
)
deleted_by: UUID = Field(foreign_key="users.id", nullable=True)
deleted_at: Optional[datetime] = Field(default=None, sa_column=Column(DateTime()))
def __init__(self, **kwargs):
kwargs.setdefault("data_sources", [])
super().__init__(**kwargs)
def __hash__(self):
return hash(self.id)
def get_data_source_by_id(self, data_source_id: int) -> Optional[DataSource]:
return next(
(
ds
for ds in self.data_sources
if ds.id == data_source_id and not ds.deleted_at
),
None,
)
def must_get_data_source_by_id(self, data_source_id: int) -> DataSource:
data_source = self.get_data_source_by_id(data_source_id)
if data_source is None:
raise KBDataSourceNotFound(self.id, data_source_id)
return data_source
def to_descriptor(self) -> KnowledgeBaseDescriptor:
return KnowledgeBaseDescriptor(
id=self.id,
name=self.name,
)
================================================
FILE: backend/app/models/knowledge_base_scoped/__init__.py
================================================
================================================
FILE: backend/app/models/knowledge_base_scoped/table_naming.py
================================================
import logging
import re
from app.models.knowledge_base import KnowledgeBase
from app.models.embed_model import DEFAULT_VECTOR_DIMENSION
logger = logging.getLogger(__name__)
DEFAULT_CHUNKS_TABLE_NAME = "chunks"
DEFAULT_ENTITIES_TABLE_NAME = "entities"
DEFAULT_RELATIONSHIPS_TABLE_NAME = "relationships"
CHUNKS_TABLE_PREFIX = "chunks_"
ENTITIES_TABLE_PREFIX = "entities_"
RELATIONSHIPS_TABLE_PREFIX = "relationships_"
KB_CHUNKS_TABLE_PATTERN = re.compile(r"^chunks_\d+$")
KB_ENTITIES_TABLE_PATTERN = re.compile(r"^entities_\d+$")
KB_RELATIONSHIPS_TABLE_PATTERN = re.compile(r"^relationships_\d+$")
def get_kb_chunks_table_name(knowledge_base: KnowledgeBase) -> str:
return (
CHUNKS_TABLE_PREFIX + str(knowledge_base.id)
if knowledge_base
else DEFAULT_CHUNKS_TABLE_NAME
)
def get_kb_relationships_table_name(knowledge_base: KnowledgeBase) -> str:
return (
RELATIONSHIPS_TABLE_PREFIX + str(knowledge_base.id)
if knowledge_base
else DEFAULT_RELATIONSHIPS_TABLE_NAME
)
def get_kb_entities_table_name(knowledge_base: KnowledgeBase) -> str:
return (
ENTITIES_TABLE_PREFIX + str(knowledge_base.id)
if knowledge_base
else DEFAULT_ENTITIES_TABLE_NAME
)
def get_kb_vector_dims(kb: KnowledgeBase):
vector_dimension = DEFAULT_VECTOR_DIMENSION
if kb.embedding_model and kb.embedding_model.vector_dimension:
vector_dimension = kb.embedding_model.vector_dimension
else:
logger.warning(
"This knowledge base doesn't configured a embedding model or this vector vector_dimension "
"of the embedding model is miss."
)
return vector_dimension
================================================
FILE: backend/app/models/llm.py
================================================
from typing import Optional, Any
from sqlmodel import Field, Column, JSON, String
from pydantic import BaseModel
from app.rag.llms.provider import LLMProvider
from .base import UpdatableBaseModel, AESEncryptedColumn
class BaseLLM(UpdatableBaseModel):
name: str = Field(max_length=64)
provider: LLMProvider = Field(sa_column=Column(String(32), nullable=False))
model: str = Field(max_length=256)
config: dict | list | None = Field(sa_column=Column(JSON), default={})
is_default: bool = Field(default=False)
class LLM(BaseLLM, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
credentials: Any = Field(sa_column=Column(AESEncryptedColumn, nullable=True))
__tablename__ = "llms"
class AdminLLM(BaseLLM):
id: int
class LLMUpdate(BaseModel):
name: Optional[str] = None
config: Optional[dict] = None
credentials: Optional[str | dict] = None
================================================
FILE: backend/app/models/recommend_question.py
================================================
from typing import Optional, List
from sqlmodel import (
Field,
Column,
JSON,
Relationship as SQLRelationship,
)
from .base import UpdatableBaseModel
class RecommendQuestion(UpdatableBaseModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
questions: List = Field(default=[], sa_column=Column(JSON))
chat_message_id: int = Field(foreign_key="chat_messages.id", index=True)
chat_message: "ChatMessage" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "RecommendQuestion.chat_message_id == ChatMessage.id",
},
)
__tablename__ = "recommend_questions"
================================================
FILE: backend/app/models/relationship.py
================================================
from datetime import datetime
from app.utils.singleflight_cache import singleflight_cache
from typing import Optional, Type
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import Column, Text, JSON, DateTime
from sqlmodel import (
SQLModel,
Field,
Relationship as SQLRelationship,
)
from tidb_vector.sqlalchemy import VectorType
from app.models.entity import get_kb_entity_model
from app.models.knowledge_base import KnowledgeBase
from app.models.knowledge_base_scoped.table_naming import get_kb_vector_dims
from app.utils.namespace import format_namespace
from app.logger import logger
class RelationshipPublic(BaseModel):
id: int
description: str
source_entity_id: int
target_entity_id: int
meta: dict = Field(default_factory=dict)
weight: Optional[int] = Field(default=0)
last_modified_at: Optional[datetime] = Field(default=None)
document_id: Optional[int] = Field(default=None)
chunk_id: Optional[UUID] = Field(default=None)
def get_kb_relationship_model(kb: KnowledgeBase) -> Type[SQLModel]:
vector_dimension = get_kb_vector_dims(kb)
entity_model = get_kb_entity_model(kb)
return get_dynamic_relationship_model(vector_dimension, str(kb.id), entity_model)
@singleflight_cache
def get_dynamic_relationship_model(
vector_dimension: int,
namespace: Optional[str] = None,
entity_model: Optional[Type[SQLModel]] = None,
) -> Type[SQLModel]:
namespace = format_namespace(namespace)
entity_table_name = entity_model.__tablename__
entity_model_name = entity_model.__name__
relationship_table_name = f"relationships_{namespace}"
relationship_model_name = f"Relationship_{namespace}_{vector_dimension}"
logger.info(
"Dynamic create relationship model (dimension: %s, table: %s, model: %s)",
vector_dimension,
relationship_table_name,
relationship_model_name,
)
class Relationship(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
description: str = Field(sa_column=Column(Text))
meta: dict = Field(default_factory=dict, sa_column=Column(JSON))
weight: int = 0
source_entity_id: int = Field(foreign_key=f"{entity_table_name}.id")
target_entity_id: int = Field(foreign_key=f"{entity_table_name}.id")
last_modified_at: Optional[datetime] = Field(sa_column=Column(DateTime))
document_id: Optional[int] = Field(default=None, nullable=True)
chunk_id: Optional[UUID] = Field(default=None, nullable=True)
description_vec: list[float] = Field(sa_type=VectorType(vector_dimension))
def __hash__(self):
return hash(self.id)
def screenshot(self):
obj_dict = self.model_dump(
exclude={
"description_vec",
"source_entity",
"target_entity",
"last_modified_at",
}
)
return obj_dict
relationship_model = type(
relationship_model_name,
(Relationship,),
{
"__tablename__": relationship_table_name,
"__table_args__": {"extend_existing": True},
"__annotations__": {
"source_entity": entity_model,
"target_entity": entity_model,
},
"source_entity": SQLRelationship(
sa_relationship_kwargs={
"primaryjoin": f"{relationship_model_name}.source_entity_id == {entity_model_name}.id",
"lazy": "joined",
},
),
"target_entity": SQLRelationship(
sa_relationship_kwargs={
"primaryjoin": f"{relationship_model_name}.target_entity_id == {entity_model_name}.id",
"lazy": "joined",
},
),
},
table=True,
)
return relationship_model
================================================
FILE: backend/app/models/reranker_model.py
================================================
from typing import Optional, Any
from sqlmodel import Field, Column, JSON, String
from .base import UpdatableBaseModel, AESEncryptedColumn
from app.rag.rerankers.provider import RerankerProvider
from pydantic import BaseModel
class BaseRerankerModel(UpdatableBaseModel):
name: str = Field(max_length=64)
provider: RerankerProvider = Field(sa_column=Column(String(32), nullable=False))
model: str = Field(max_length=256)
top_n: int = Field(default=10)
config: dict | list | None = Field(sa_column=Column(JSON), default={})
is_default: bool = Field(default=False)
class RerankerModel(BaseRerankerModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
credentials: Any = Field(sa_column=Column(AESEncryptedColumn, nullable=True))
__tablename__ = "reranker_models"
class AdminRerankerModel(BaseRerankerModel):
id: int
class RerankerModelUpdate(BaseModel):
name: Optional[str] = None
config: Optional[dict | list] = None
credentials: Optional[str | dict] = None
top_n: Optional[int] = None
================================================
FILE: backend/app/models/semantic_cache.py
================================================
from typing import Optional, Any
from datetime import datetime
from sqlmodel import (
SQLModel,
Field,
Column,
JSON,
Text,
func,
DateTime,
)
from tidb_vector.sqlalchemy import VectorType
from app.core.config import settings
class SemanticCache(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
query: str = Field(sa_column=Column(Text))
query_vec: Any = Field(
sa_column=Column(
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
)
)
value: str = Field(sa_column=Column(Text))
value_vec: Any = Field(
sa_column=Column(
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
)
)
meta: dict = Field(default_factory=dict, sa_column=Column(JSON))
created_at: datetime = Field(
sa_column=Column(DateTime, server_default=func.now(), nullable=True)
)
updated_at: datetime = Field(
sa_column=Column(
DateTime, server_default=func.now(), onupdate=func.now(), nullable=True
)
)
__tablename__ = "semantic_cache"
__table_args__ = {
# Ref: https://docs.pingcap.com/tidb/stable/time-to-live
"mysql_TTL": "created_at + INTERVAL 1 MONTH;",
}
def __hash__(self):
return hash(self.id)
# screenshot method is used to return a dictionary representation of the object
# that can be used for recording or debugging purposes
def screenshot(self):
return self.model_dump(exclude={"query_vec", "value_vec"})
================================================
FILE: backend/app/models/site_setting.py
================================================
from typing import Optional
from datetime import datetime
from sqlmodel import SQLModel, Field, Column, JSON, func
from sqlalchemy.dialects.mysql import DATETIME
class SiteSetting(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=256, unique=True)
data_type: str = Field(max_length=256)
value: str = Field(sa_column=Column(JSON))
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(DATETIME(timezone=True, fsp=6), server_default=func.now()),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
# SiteSetting needs more time precision to avoid timestamp collision between each other,
# so we use mysql.DATETIME rather than base.UpdatableBaseModel
DATETIME(timezone=True, fsp=6),
server_default=func.now(),
onupdate=func.now(),
),
)
__tablename__ = "site_settings"
================================================
FILE: backend/app/models/staff_action_log.py
================================================
from typing import Optional, Dict
from datetime import datetime
from sqlmodel import SQLModel, Field, Column, JSON, DateTime, func
class StaffActionLog(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
action: str
action_time: datetime = Field(sa_column=Column(DateTime, server_default=func.now()))
target_type: str
target_id: int
before: Dict = Field(default_factory=dict, sa_column=Column(JSON))
after: Dict = Field(default_factory=dict, sa_column=Column(JSON))
__tablename__ = "staff_action_logs"
================================================
FILE: backend/app/models/upload.py
================================================
from uuid import UUID
from typing import Optional
from sqlmodel import (
Field,
Column,
String,
Relationship as SQLRelationship,
)
from app.models.base import UpdatableBaseModel
from app.types import MimeTypes
class Upload(UpdatableBaseModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=255)
size: int = Field(default=0)
path: str = Field(max_length=255)
mime_type: MimeTypes = Field(sa_column=Column(String(128), nullable=False))
user_id: UUID = Field(foreign_key="users.id", nullable=True)
user: "User" = SQLRelationship( # noqa:F821
sa_relationship_kwargs={
"lazy": "joined",
"primaryjoin": "Upload.user_id == User.id",
},
)
__tablename__ = "uploads"
================================================
FILE: backend/app/rag/__init__.py
================================================
================================================
FILE: backend/app/rag/build_index.py
================================================
import logging
from typing import List, Optional, Type
from llama_index.core import VectorStoreIndex
from llama_index.core.embeddings.utils import EmbedType
from llama_index.core.llms.llm import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TransformComponent
from sqlmodel import SQLModel, Session
from app.models.knowledge_base import (
ChunkSplitter,
ChunkingMode,
KnowledgeBase,
SentenceSplitterOptions,
GeneralChunkingConfig,
ChunkSplitterConfig,
MarkdownNodeParserOptions,
AdvancedChunkingConfig,
)
from app.rag.knowledge_base.index_store import (
get_kb_tidb_vector_store,
get_kb_tidb_graph_store,
)
from app.rag.indices.knowledge_graph import KnowledgeGraphIndex
from app.models import Document
from app.rag.node_parser.file.markdown import MarkdownNodeParser
from app.types import MimeTypes
from app.rag.llms.dspy import get_dspy_lm_by_llama_llm
logger = logging.getLogger(__name__)
class IndexService:
"""
Service class for building RAG indexes (vector index and knowledge graph index).
"""
def __init__(
self,
llm: LLM,
embed_model: Optional[EmbedType] = None,
knowledge_base: Optional[KnowledgeBase] = None,
):
self._llm = llm
self._dspy_lm = get_dspy_lm_by_llama_llm(llm)
self._embed_model = embed_model
self._knowledge_base = knowledge_base
# TODO: move to ./indices/vector_search
def build_vector_index_for_document(
self, session: Session, db_document: Type[Document]
):
"""
Build vector index and graph index from document.
Build vector index will do the following:
1. Parse document into nodes.
2. Extract metadata from nodes by applying transformations.
3. embedding text nodes.
4. Insert nodes into `chunks` table.
"""
vector_store = get_kb_tidb_vector_store(session, self._knowledge_base)
transformations = self._get_transformations(db_document)
vector_index = VectorStoreIndex.from_vector_store(
vector_store,
embed_model=self._embed_model,
transformations=transformations,
)
llama_document = db_document.to_llama_document()
logger.info(f"Start building vector index for document #{db_document.id}.")
vector_index.insert(llama_document, source_uri=db_document.source_uri)
logger.info(f"Finish building vector index for document #{db_document.id}.")
vector_store.close_session()
return
def _get_transformations(
self, db_document: Type[Document]
) -> List[TransformComponent]:
transformations = []
chunking_config_dict = self._knowledge_base.chunking_config
mode = (
chunking_config_dict["mode"]
if "mode" in chunking_config_dict
else ChunkingMode.GENERAL
)
if mode == ChunkingMode.ADVANCED:
chunking_config = AdvancedChunkingConfig.model_validate(
chunking_config_dict
)
rules = chunking_config.rules
else:
chunking_config = GeneralChunkingConfig.model_validate(chunking_config_dict)
rules = {
MimeTypes.PLAIN_TXT: ChunkSplitterConfig(
splitter=ChunkSplitter.SENTENCE_SPLITTER,
splitter_options=SentenceSplitterOptions(
chunk_size=chunking_config.chunk_size,
chunk_overlap=chunking_config.chunk_overlap,
paragraph_separator=chunking_config.paragraph_separator,
),
),
MimeTypes.MARKDOWN: ChunkSplitterConfig(
splitter=ChunkSplitter.MARKDOWN_NODE_PARSER,
splitter_options=MarkdownNodeParserOptions(
chunk_size=chunking_config.chunk_size,
),
),
}
# Chunking
mime_type = db_document.mime_type
if mime_type not in rules:
raise RuntimeError(
f"Can not chunking for the document in {db_document.mime_type} format"
)
rule = rules[mime_type]
match rule.splitter:
case ChunkSplitter.MARKDOWN_NODE_PARSER:
options = MarkdownNodeParserOptions.model_validate(
rule.splitter_options
)
transformations.append(MarkdownNodeParser(**options.model_dump()))
case ChunkSplitter.SENTENCE_SPLITTER:
options = SentenceSplitterOptions.model_validate(rule.splitter_options)
transformations.append(SentenceSplitter(**options.model_dump()))
case _:
raise ValueError(f"Unsupported chunking splitter type: {rule.splitter}")
return transformations
# TODO: move to ./indices/knowledge_graph
def build_kg_index_for_chunk(self, session: Session, db_chunk: Type[SQLModel]):
"""Build knowledge graph index from chunk.
Build knowledge graph index will do the following:
1. load TextNode from `chunks` table.
2. extract entities and relations from TextNode.
3. insert entities and relations into `entities` and `relations` table.
"""
graph_store = get_kb_tidb_graph_store(session, self._knowledge_base)
graph_index: KnowledgeGraphIndex = KnowledgeGraphIndex.from_existing(
dspy_lm=self._dspy_lm,
kg_store=graph_store,
)
node = db_chunk.to_llama_text_node()
logger.info(f"Start building knowledge graph index for chunk #{db_chunk.id}.")
graph_index.insert_nodes([node])
logger.info(f"Finish building knowledge graph index for chunk #{db_chunk.id}.")
graph_store.close_session()
return
================================================
FILE: backend/app/rag/chat/__init__.py
================================================
================================================
FILE: backend/app/rag/chat/chat_flow.py
================================================
import json
import logging
from datetime import datetime, UTC
from typing import List, Optional, Generator, Tuple, Any
from urllib.parse import urljoin
from uuid import UUID
import requests
from langfuse.llama_index import LlamaIndexInstrumentor
from langfuse.llama_index._context import langfuse_instrumentor_context
from llama_index.core import get_response_synthesizer
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.schema import NodeWithScore
from llama_index.core.prompts.rich import RichPromptTemplate
from sqlmodel import Session
from app.core.config import settings
from app.exceptions import ChatNotFound
from app.models import (
User,
Chat as DBChat,
ChatVisibility,
ChatMessage as DBChatMessage,
)
from app.rag.chat.config import ChatEngineConfig
from app.rag.chat.retrieve.retrieve_flow import SourceDocument, RetrieveFlow
from app.rag.chat.stream_protocol import (
ChatEvent,
ChatStreamDataPayload,
ChatStreamMessagePayload,
)
from app.rag.llms.dspy import get_dspy_lm_by_llama_llm
from app.rag.retrievers.knowledge_graph.schema import KnowledgeGraphRetrievalResult
from app.rag.types import ChatEventType, ChatMessageSate
from app.rag.utils import parse_goal_response_format
from app.repositories import chat_repo
from app.site_settings import SiteSetting
from app.utils.tracing import LangfuseContextManager
logger = logging.getLogger(__name__)
def parse_chat_messages(
chat_messages: List[ChatMessage],
) -> tuple[str, List[ChatMessage]]:
user_question = chat_messages[-1].content
chat_history = chat_messages[:-1]
return user_question, chat_history
class ChatFlow:
_trace_manager: LangfuseContextManager
def __init__(
self,
*,
db_session: Session,
user: User,
browser_id: str,
origin: str,
chat_messages: List[ChatMessage],
engine_name: str = "default",
chat_id: Optional[UUID] = None,
) -> None:
self.chat_id = chat_id
self.db_session = db_session
self.user = user
self.browser_id = browser_id
self.engine_name = engine_name
# Load chat engine and chat session.
self.user_question, self.chat_history = parse_chat_messages(chat_messages)
if chat_id:
# FIXME:
# only chat owner or superuser can access the chat,
# anonymous user can only access anonymous chat by track_id
self.db_chat_obj = chat_repo.get(self.db_session, chat_id)
if not self.db_chat_obj:
raise ChatNotFound(chat_id)
try:
self.engine_config = ChatEngineConfig.load_from_db(
db_session, self.db_chat_obj.engine.name
)
self.db_chat_engine = self.engine_config.get_db_chat_engine()
except Exception as e:
logger.error(f"Failed to load chat engine config: {e}")
self.engine_config = ChatEngineConfig.load_from_db(
db_session, engine_name
)
self.db_chat_engine = self.engine_config.get_db_chat_engine()
logger.info(
f"Init ChatFlow for chat {chat_id} (chat_engine: {self.db_chat_obj.engine.name})"
)
self.chat_history = [
ChatMessage(role=m.role, content=m.content, additional_kwargs={})
for m in chat_repo.get_messages(self.db_session, self.db_chat_obj)
]
else:
self.engine_config = ChatEngineConfig.load_from_db(db_session, engine_name)
self.db_chat_engine = self.engine_config.get_db_chat_engine()
self.db_chat_obj = chat_repo.create(
self.db_session,
DBChat(
# TODO: title should be generated by the LLM
title=self.user_question[:100],
engine_id=self.db_chat_engine.id,
engine_options=self.engine_config.screenshot(),
user_id=self.user.id if self.user else None,
browser_id=self.browser_id,
origin=origin,
visibility=(
ChatVisibility.PUBLIC
if not self.user
else ChatVisibility.PRIVATE
),
),
)
chat_id = self.db_chat_obj.id
# Notice: slack/discord bots may create a new chat with history messages.
now = datetime.now(UTC)
for i, m in enumerate(self.chat_history):
chat_repo.create_message(
session=self.db_session,
chat=self.db_chat_obj,
chat_message=DBChatMessage(
role=m.role,
content=m.content,
ordinal=i + 1,
created_at=now,
updated_at=now,
finished_at=now,
),
)
# Init Langfuse for tracing.
enable_langfuse = (
SiteSetting.langfuse_secret_key and SiteSetting.langfuse_public_key
)
instrumentor = LlamaIndexInstrumentor(
host=SiteSetting.langfuse_host,
secret_key=SiteSetting.langfuse_secret_key,
public_key=SiteSetting.langfuse_public_key,
enabled=enable_langfuse,
)
self._trace_manager = LangfuseContextManager(instrumentor)
# Init LLM.
self._llm = self.engine_config.get_llama_llm(self.db_session)
self._fast_llm = self.engine_config.get_fast_llama_llm(self.db_session)
self._fast_dspy_lm = get_dspy_lm_by_llama_llm(self._fast_llm)
# Load knowledge bases.
self.knowledge_bases = self.engine_config.get_knowledge_bases(self.db_session)
self.knowledge_base_ids = [kb.id for kb in self.knowledge_bases]
# Init retrieve flow.
self.retrieve_flow = RetrieveFlow(
db_session=self.db_session,
engine_name=self.engine_name,
engine_config=self.engine_config,
llm=self._llm,
fast_llm=self._fast_llm,
knowledge_bases=self.knowledge_bases,
)
def chat(self) -> Generator[ChatEvent | str, None, None]:
try:
with self._trace_manager.observe(
trace_name="ChatFlow",
user_id=(
self.user.email if self.user else f"anonymous-{self.browser_id}"
),
metadata={
"is_external_engine": self.engine_config.is_external_engine,
"chat_engine_config": self.engine_config.screenshot(),
},
tags=[f"chat_engine:{self.engine_name}"],
release=settings.ENVIRONMENT,
) as trace:
trace.update(
input={
"user_question": self.user_question,
"chat_history": self.chat_history,
}
)
if self.engine_config.is_external_engine:
yield from self._external_chat()
else:
response_text, source_documents = yield from self._builtin_chat()
trace.update(output=response_text)
except Exception as e:
logger.exception(e)
yield ChatEvent(
event_type=ChatEventType.ERROR_PART,
payload="Encountered an error while processing the chat. Please try again later.",
)
def _builtin_chat(
self,
) -> Generator[ChatEvent | str, None, Tuple[Optional[str], List[Any]]]:
ctx = langfuse_instrumentor_context.get().copy()
db_user_message, db_assistant_message = yield from self._chat_start()
langfuse_instrumentor_context.get().update(ctx)
# 1. Retrieve Knowledge graph related to the user question.
(
knowledge_graph,
knowledge_graph_context,
) = yield from self._search_knowledge_graph(user_question=self.user_question)
# 2. Refine the user question using knowledge graph and chat history.
refined_question = yield from self._refine_user_question(
user_question=self.user_question,
chat_history=self.chat_history,
knowledge_graph_context=knowledge_graph_context,
refined_question_prompt=self.engine_config.llm.condense_question_prompt,
)
# 3. Check if the question provided enough context information or need to clarify.
if self.engine_config.clarify_question:
need_clarify, need_clarify_response = yield from self._clarify_question(
user_question=refined_question,
chat_history=self.chat_history,
knowledge_graph_context=knowledge_graph_context,
)
if need_clarify:
yield from self._chat_finish(
db_assistant_message=db_assistant_message,
db_user_message=db_user_message,
response_text=need_clarify_response,
knowledge_graph=knowledge_graph,
source_documents=[],
)
return None, []
# 4. Use refined question to search for relevant chunks.
relevant_chunks = yield from self._search_relevance_chunks(
user_question=refined_question
)
# 5. Generate a response using the refined question and related chunks
response_text, source_documents = yield from self._generate_answer(
user_question=refined_question,
knowledge_graph_context=knowledge_graph_context,
relevant_chunks=relevant_chunks,
)
yield from self._chat_finish(
db_assistant_message=db_assistant_message,
db_user_message=db_user_message,
response_text=response_text,
knowledge_graph=knowledge_graph,
source_documents=source_documents,
)
return response_text, source_documents
def _chat_start(
self,
) -> Generator[ChatEvent, None, Tuple[DBChatMessage, DBChatMessage]]:
db_user_message = chat_repo.create_message(
session=self.db_session,
chat=self.db_chat_obj,
chat_message=DBChatMessage(
role=MessageRole.USER.value,
trace_url=self._trace_manager.trace_url,
content=self.user_question.strip(),
),
)
db_assistant_message = chat_repo.create_message(
session=self.db_session,
chat=self.db_chat_obj,
chat_message=DBChatMessage(
role=MessageRole.ASSISTANT.value,
trace_url=self._trace_manager.trace_url,
content="",
),
)
yield ChatEvent(
event_type=ChatEventType.DATA_PART,
payload=ChatStreamDataPayload(
chat=self.db_chat_obj,
user_message=db_user_message,
assistant_message=db_assistant_message,
),
)
return db_user_message, db_assistant_message
def _search_knowledge_graph(
self,
user_question: str,
annotation_silent: bool = False,
) -> Generator[ChatEvent, None, Tuple[KnowledgeGraphRetrievalResult, str]]:
kg_config = self.engine_config.knowledge_graph
if kg_config is None or kg_config.enabled is False:
return KnowledgeGraphRetrievalResult(), ""
with self._trace_manager.span(
name="search_knowledge_graph", input=user_question
) as span:
if not annotation_silent:
if kg_config.using_intent_search:
yield ChatEvent(
event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
payload=ChatStreamMessagePayload(
state=ChatMessageSate.KG_RETRIEVAL,
display="Identifying The Question's Intents and Perform Knowledge Graph Search",
),
)
else:
yield ChatEvent(
event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
payload=ChatStreamMessagePayload(
state=ChatMessageSate.KG_RETRIEVAL,
display="Searching the Knowledge Graph for Relevant Context",
),
)
knowledge_graph, knowledge_graph_context = (
self.retrieve_flow.search_knowledge_graph(user_question)
)
span.end(
output={
"knowledge_graph": knowledge_graph,
"knowledge_graph_context": knowledge_graph_context,
}
)
return knowledge_graph, knowledge_graph_context
def _refine_user_question(
self,
user_question: str,
chat_history: Optional[List[ChatMessage]] = [],
refined_question_prompt: Optional[str] = None,
knowledge_graph_context: str = "",
annotation_silent: bool = False,
) -> Generator[ChatEvent, None, str]:
with self._trace_manager.span(
name="refine_user_question",
input={
"user_question": user_question,
"chat_history": chat_history,
"knowledge_graph_context": knowledge_graph_context,
},
) as span:
if not annotation_silent:
yield ChatEvent(
event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
payload=ChatStreamMessagePayload(
state=ChatMessageSate.REFINE_QUESTION,
display="Query Rewriting for Enhanced Information Retrieval",
),
)
prompt_template = RichPromptTemplate(refined_question_prompt)
refined_question = self._fast_llm.predict(
prompt_template,
graph_knowledges=knowledge_graph_context,
chat_history=chat_history,
question=user_question,
current_date=datetime.now().strftime("%Y-%m-%d"),
)
if not annotation_silent:
yield ChatEvent(
event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
payload=ChatStreamMessagePayload(
state=ChatMessageSate.REFINE_QUESTION,
message=refined_question,
),
)
span.end(output=refined_question)
return refined_question
def _clarify_question(
self,
user_question: str,
chat_history: Optional[List[ChatMessage]] = [],
knowledge_graph_context: str = "",
) -> Generator[ChatEvent, None, Tuple[bool, str]]:
"""
Check if the question clear and provided enough context information, otherwise, it is necessary to
stop the conversation early and ask the user for the further clarification.
Args:
user_question: str
knowledge_graph_context: str
Returns:
bool: Determine whether further clarification of the issue is needed from the user.
str: The content of the questions that require clarification from the user.
"""
with self._trace_manager.span(
name="clarify_question",
input={
"user_question": user_question,
"knowledge_graph_context": knowledge_graph_context,
},
) as span:
prompt_template = RichPromptTemplate(
self.engine_config.llm.clarifying_question_prompt
)
prediction = self._fast_llm.predict(
prompt_template,
graph_knowledges=knowledge_graph_context,
chat_history=chat_history,
question=user_question,
)
# TODO: using structured output to get the clarity result.
clarity_result = prediction.strip().strip(".\"'!")
need_clarify = clarity_result.lower() != "false"
need_clarify_response = clarity_result if need_clarify else ""
if need_clarify:
yield ChatEvent(
event_type=ChatEventType.TEXT_PART,
payload=need_clarify_response,
)
span.end(
output={
"need_clarify": need_clarify,
"need_clarify_response": need_clarify_response,
}
)
return need_clarify, need_clarify_response
def _search_relevance_chunks(
self, user_question: str
) -> Generator[ChatEvent, None, List[NodeWithScore]]:
with self._trace_manager.span(
name="search_relevance_chunks", input=user_question
) as span:
yield ChatEvent(
event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
payload=ChatStreamMessagePayload(
state=ChatMessageSate.SEARCH_RELATED_DOCUMENTS,
display="Retrieving the Most Relevant Documents",
),
)
relevance_chunks = self.retrieve_flow.search_relevant_chunks(user_question)
span.end(
output={
"relevance_chunks": relevance_chunks,
}
)
return relevance_chunks
def _generate_answer(
self,
user_question: str,
knowledge_graph_context: str,
relevant_chunks: List[NodeWithScore],
) -> Generator[ChatEvent, None, Tuple[str, List[SourceDocument]]]:
with self._trace_manager.span(
name="generate_answer", input=user_question
) as span:
# Initialize response synthesizer.
text_qa_template = RichPromptTemplate(
template_str=self.engine_config.llm.text_qa_prompt
)
text_qa_template = text_qa_template.partial_format(
current_date=datetime.now().strftime("%Y-%m-%d"),
graph_knowledges=knowledge_graph_context,
original_question=self.user_question,
)
response_synthesizer = get_response_synthesizer(
llm=self._llm, text_qa_template=text_qa_template, streaming=True
)
# Initialize response.
response = response_synthesizer.synthesize(
query=user_question,
nodes=relevant_chunks,
)
source_documents = self.retrieve_flow.get_source_documents_from_nodes(
response.source_nodes
)
yield ChatEvent(
event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
payload=ChatStreamMessagePayload(
state=ChatMessageSate.SOURCE_NODES,
context=source_documents,
),
)
# Generate response.
yield ChatEvent(
event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
payload=ChatStreamMessagePayload(
state=ChatMessageSate.GENERATE_ANSWER,
display="Generating a Precise Answer with AI",
),
)
response_text = ""
for word in response.response_gen:
response_text += word
yield ChatEvent(
event_type=ChatEventType.TEXT_PART,
payload=word,
)
if not response_text:
raise Exception("Got empty response from LLM")
span.end(
output=response_text,
metadata={
"source_documents": source_documents,
},
)
return response_text, source_documents
def _post_verification(
self, user_question: str, response_text: str, chat_id: UUID, message_id: int
) -> Optional[str]:
# post verification to external service, will return the post verification result url
post_verification_url = self.engine_config.post_verification_url
post_verification_token = self.engine_config.post_verification_token
if not post_verification_url:
return None
external_request_id = f"{chat_id}_{message_id}"
qa_content = f"User question: {user_question}\n\nAnswer:\n{response_text}"
with self._trace_manager.span(
name="post_verification",
input={
"external_request_id": external_request_id,
"qa_content": qa_content,
},
) as span:
try:
resp = requests.post(
post_verification_url,
json={
"external_request_id": external_request_id,
"qa_content": qa_content,
},
headers=(
{
"Authorization": f"Bearer {post_verification_token}",
}
if post_verification_token
else {}
),
timeout=10,
)
resp.raise_for_status()
job_id = resp.json()["job_id"]
post_verification_link = urljoin(
f"{post_verification_url}/", str(job_id)
)
span.end(
output={
"post_verification_link": post_verification_link,
}
)
return post_verification_link
except Exception as e:
logger.exception("Failed to post verification: %s", e.message)
return None
def _chat_finish(
self,
db_assistant_message: ChatMessage,
db_user_message: ChatMessage,
response_text: str,
knowledge_graph: KnowledgeGraphRetrievalResult = KnowledgeGraphRetrievalResult(),
source_documents: Optional[List[SourceDocument]] = [],
annotation_silent: bool = False,
):
if not annotation_silent:
yield ChatEvent(
event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
payload=ChatStreamMessagePayload(
state=ChatMessageSate.FINISHED,
),
)
post_verification_result_url = self._post_verification(
self.user_question,
response_text,
self.db_chat_obj.id,
db_assistant_message.id,
)
db_assistant_message.sources = [s.model_dump() for s in source_documents]
db_assistant_message.graph_data = knowledge_graph.to_stored_graph_dict()
db_assistant_message.content = response_text
db_assistant_message.post_verification_result_url = post_verification_result_url
db_assistant_message.updated_at = datetime.now(UTC)
db_assistant_message.finished_at = datetime.now(UTC)
self.db_session.add(db_assistant_message)
db_user_message.graph_data = knowledge_graph.to_stored_graph_dict()
db_user_message.updated_at = datetime.now(UTC)
db_user_message.finished_at = datetime.now(UTC)
self.db_session.add(db_user_message)
self.db_session.commit()
yield ChatEvent(
event_type=ChatEventType.DATA_PART,
payload=ChatStreamDataPayload(
chat=self.db_chat_obj,
user_message=db_user_message,
assistant_message=db_assistant_message,
),
)
# TODO: Separate _external_chat() method into another ExternalChatFlow class, but at the same time, we need to
# share some common methods through ChatMixin or BaseChatFlow.
def _external_chat(self) -> Generator[ChatEvent | str, None, None]:
ctx = langfuse_instrumentor_context.get().copy()
db_user_message, db_assistant_message = yield from self._chat_start()
langfuse_instrumentor_context.get().update(ctx)
cache_messages = None
goal, response_format = self.user_question, {}
if settings.ENABLE_QUESTION_CACHE and len(self.chat_history) == 0:
try:
logger.info(
f"start to find_best_answer_for_question with question: {self.user_question}"
)
cache_messages = chat_repo.find_best_answer_for_question(
self.db_session, self.user_question
)
if cache_messages and len(cache_messages) > 0:
logger.info(
f"find_best_answer_for_question result {len(cache_messages)} for question {self.user_question}"
)
except Exception as e:
logger.error(
f"Failed to find best answer for question {self.user_question}: {e}"
)
if not cache_messages or len(cache_messages) == 0:
try:
# 1. Generate the goal with the user question, knowledge graph and chat history.
goal, response_format = yield from self._generate_goal()
# 2. Check if the goal provided enough context information or need to clarify.
if self.engine_config.clarify_question:
(
need_clarify,
need_clarify_response,
) = yield from self._clarify_question(
user_question=goal, chat_history=self.chat_history
)
if need_clarify:
yield from self._chat_finish(
db_assistant_message=db_assistant_message,
db_user_message=db_user_message,
response_text=need_clarify_response,
annotation_silent=True,
)
return
except Exception as e:
goal = self.user_question
logger.warning(
f"Failed to generate refined goal, fallback to use user question as goal directly: {e}",
exc_info=True,
extra={},
)
cache_messages = None
if settings.ENABLE_QUESTION_CACHE:
try:
logger.info(
f"start to find_recent_assistant_messages_by_goal with goal: {goal}, response_format: {response_format}"
)
cache_messages = chat_repo.find_recent_assistant_messages_by_goal(
self.db_session,
{"goal": goal, "Lang": response_format.get("Lang", "English")},
90,
)
logger.info(
f"find_recent_assistant_messages_by_goal result {len(cache_messages)} for goal {goal}"
)
except Exception as e:
logger.error(
f"Failed to find recent assistant messages by goal: {e}"
)
stream_chat_api_url = (
self.engine_config.external_engine_config.stream_chat_api_url
)
if cache_messages and len(cache_messages) > 0:
stackvm_response_text = cache_messages[0].content
task_id = cache_messages[0].meta.get("task_id")
for chunk in stackvm_response_text.split(". "):
if chunk:
if not chunk.endswith("."):
chunk += ". "
yield ChatEvent(
event_type=ChatEventType.TEXT_PART,
payload=chunk,
)
else:
logger.debug(
f"Chatting with external chat engine (api_url: {stream_chat_api_url}) to answer for user question: {self.user_question}"
)
chat_params = {
"goal": goal,
"response_format": response_format,
"namespace_name": "Default",
}
res = requests.post(stream_chat_api_url, json=chat_params, stream=True)
# Notice: External type chat engine doesn't support non-streaming mode for now.
stackvm_response_text = ""
task_id = None
for line in res.iter_lines():
if not line:
continue
# Append to final response text.
chunk = line.decode("utf-8")
if chunk.startswith("0:"):
word = json.loads(chunk[2:])
stackvm_response_text += word
yield ChatEvent(
event_type=ChatEventType.TEXT_PART,
payload=word,
)
else:
yield line + b"\n"
try:
if chunk.startswith("8:") and task_id is None:
states = json.loads(chunk[2:])
if len(states) > 0:
# accesss task by http://endpoint/?task_id=$task_id
task_id = states[0].get("task_id")
except Exception as e:
logger.error(f"Failed to get task_id from chunk: {e}")
response_text = stackvm_response_text
base_url = stream_chat_api_url.replace("/api/stream_execute_vm", "")
try:
post_verification_result_url = self._post_verification(
goal,
response_text,
self.db_chat_obj.id,
db_assistant_message.id,
)
db_assistant_message.post_verification_result_url = (
post_verification_result_url
)
except Exception:
logger.error(
"Specific error occurred during post verification job.", exc_info=True
)
trace_url = f"{base_url}?task_id={task_id}" if task_id else ""
message_meta = {
"task_id": task_id,
"goal": goal,
**response_format,
}
db_assistant_message.content = response_text
db_assistant_message.trace_url = trace_url
db_assistant_message.meta = message_meta
db_assistant_message.updated_at = datetime.now(UTC)
db_assistant_message.finished_at = datetime.now(UTC)
self.db_session.add(db_assistant_message)
db_user_message.trace_url = trace_url
db_user_message.meta = message_meta
db_user_message.updated_at = datetime.now(UTC)
db_user_message.finished_at = datetime.now(UTC)
self.db_session.add(db_user_message)
self.db_session.commit()
yield ChatEvent(
event_type=ChatEventType.DATA_PART,
payload=ChatStreamDataPayload(
chat=self.db_chat_obj,
user_message=db_user_message,
assistant_message=db_assistant_message,
),
)
def _generate_goal(self) -> Generator[ChatEvent, None, Tuple[str, dict]]:
try:
refined_question = yield from self._refine_user_question(
user_question=self.user_question,
chat_history=self.chat_history,
refined_question_prompt=self.engine_config.llm.generate_goal_prompt,
annotation_silent=True,
)
goal = refined_question.strip()
if goal.startswith("Goal: "):
goal = goal[len("Goal: ") :].strip()
except Exception as e:
logger.error(f"Failed to refine question with related knowledge graph: {e}")
goal = self.user_question
response_format = {}
try:
clean_goal, response_format = parse_goal_response_format(goal)
logger.info(f"clean goal: {clean_goal}, response_format: {response_format}")
if clean_goal:
goal = clean_goal
except Exception as e:
logger.error(f"Failed to parse goal and response format: {e}")
return goal, response_format
================================================
FILE: backend/app/rag/chat/chat_service.py
================================================
from http import HTTPStatus
import logging
from typing import Generator, List, Optional
from uuid import UUID
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy import text, delete
from sqlmodel import Session, select, func
from app.api.routes.models import (
RequiredConfigStatus,
OptionalConfigStatus,
NeedMigrationStatus,
)
from app.models import (
User,
ChatVisibility,
Chat as DBChat,
ChatMessage as DBChatMessage,
KnowledgeBase as DBKnowledgeBase,
RerankerModel as DBRerankerModel,
ChatEngine,
)
from app.models.recommend_question import RecommendQuestion
from app.rag.chat.retrieve.retrieve_flow import RetrieveFlow, SourceDocument
from app.rag.chat.stream_protocol import ChatEvent
from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrievalResult,
StoredKnowledgeGraph,
RetrievedSubGraph,
)
from app.rag.knowledge_base.index_store import get_kb_tidb_graph_store
from app.repositories import knowledge_base_repo
from app.rag.chat.config import (
ChatEngineConfig,
)
from app.rag.types import (
ChatEventType,
ChatMessageSate,
)
from app.repositories import chat_engine_repo
from app.repositories.embedding_model import embedding_model_repo
from app.repositories.llm import llm_repo
from app.site_settings import SiteSetting
from llama_index.core.prompts.rich import RichPromptTemplate
logger = logging.getLogger(__name__)
class ChatResult(BaseModel):
chat_id: UUID
message_id: int
content: str
trace: Optional[str] = None
sources: Optional[List[SourceDocument]] = []
def get_final_chat_result(
generator: Generator[ChatEvent | str, None, None],
) -> ChatResult:
trace, sources, content = None, [], ""
chat_id, message_id = None, None
for m in generator:
if not isinstance(m, ChatEvent):
continue
if m.event_type == ChatEventType.MESSAGE_ANNOTATIONS_PART:
if m.payload.state == ChatMessageSate.SOURCE_NODES:
sources = m.payload.context
elif m.event_type == ChatEventType.TEXT_PART:
content += m.payload
elif m.event_type == ChatEventType.DATA_PART:
chat_id = m.payload.chat.id
message_id = m.payload.assistant_message.id
trace = m.payload.assistant_message.trace_url
elif m.event_type == ChatEventType.ERROR_PART:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=m.payload,
)
else:
pass
return ChatResult(
chat_id=chat_id,
message_id=message_id,
trace=trace,
sources=sources,
content=content,
)
def user_can_view_chat(chat: DBChat, user: Optional[User]) -> bool:
# Anonymous or public chat can be accessed by anyone
# Non-anonymous chat can be accessed by owner or superuser
if not chat.user_id or chat.visibility == ChatVisibility.PUBLIC:
return True
return user is not None and (user.is_superuser or chat.user_id == user.id)
def user_can_edit_chat(chat: DBChat, user: Optional[User]) -> bool:
if user is None:
return False
if user.is_superuser:
return True
return chat.user_id == user.id
def get_graph_data_from_chat_message(
db_session: Session,
chat_message: DBChatMessage,
engine_config: ChatEngineConfig,
) -> Optional[KnowledgeGraphRetrievalResult]:
if not chat_message.graph_data:
return None
graph_data = chat_message.graph_data
# For forward compatibility.
if "version" not in graph_data:
kb = engine_config.get_knowledge_bases(db_session)[0]
graph_store = get_kb_tidb_graph_store(db_session, kb)
return graph_store.get_subgraph_by_relationship_ids(graph_data["relationships"])
# Stored Knowledge Graph -> Retrieved Knowledge Graph
stored_kg = StoredKnowledgeGraph.model_validate(graph_data)
if stored_kg.knowledge_base_id is not None:
kb = knowledge_base_repo.must_get(db_session, stored_kg.knowledge_base_id)
graph_store = get_kb_tidb_graph_store(db_session, kb)
retrieved_kg = graph_store.get_subgraph_by_relationship_ids(
ids=stored_kg.relationships, query=stored_kg.query
)
return retrieved_kg
elif stored_kg.knowledge_base_ids is not None:
kg_store_map = {}
knowledge_base_set = set()
relationship_set = set()
entity_set = set()
subgraphs = []
for kb_id in stored_kg.knowledge_base_ids:
kb = knowledge_base_repo.must_get(db_session, kb_id)
knowledge_base_set.add(kb.to_descriptor())
kg_store = get_kb_tidb_graph_store(db_session, kb)
kg_store_map[kb_id] = kg_store
for stored_subgraph in stored_kg.subgraphs:
kg_store = kg_store_map.get(stored_subgraph.knowledge_base_id)
if kg_store is None:
continue
relationship_ids = stored_subgraph.relationships
subgraph = kg_store.get_subgraph_by_relationship_ids(
ids=relationship_ids,
query=stored_kg.query,
)
relationship_set.update(subgraph.relationships)
entity_set.update(subgraph.entities)
subgraphs.append(
RetrievedSubGraph(
**subgraph.model_dump(),
)
)
return KnowledgeGraphRetrievalResult(
query=stored_kg.query,
knowledge_bases=list(knowledge_base_set),
relationships=list(relationship_set),
entities=list(entity_set),
subgraphs=subgraphs,
)
else:
return None
def get_chat_message_subgraph(
db_session: Session, chat_message: DBChatMessage
) -> KnowledgeGraphRetrievalResult:
chat_engine: ChatEngine = chat_message.chat.engine
engine_name = chat_engine.name
engine_config = ChatEngineConfig.load_from_db(db_session, chat_engine.name)
# Try to get subgraph from `chat_message.graph_data`.
try:
knowledge_graph = get_graph_data_from_chat_message(
db_session, chat_message, engine_config
)
if knowledge_graph is not None:
return knowledge_graph
except Exception as e:
logger.error(
f"Failed to get subgraph from chat_message.graph_data: {e}", exc_info=True
)
# Try to get subgraph based on the chat message content.
# Notice: it use current chat engine config, not the snapshot stored in chat_message.
retriever = RetrieveFlow(
db_session=db_session,
engine_name=engine_name,
engine_config=engine_config,
)
knowledge_graph, _ = retriever.search_knowledge_graph(chat_message.content)
return knowledge_graph
def check_rag_required_config(session: Session) -> RequiredConfigStatus:
"""
Check if the required configuration items have been configured, it any of them is
missing, the RAG application can not complete its work.
"""
has_default_llm = llm_repo.has_default(session)
has_default_embedding_model = embedding_model_repo.has_default(session)
has_default_chat_engine = chat_engine_repo.has_default(session)
has_knowledge_base = session.scalar(select(func.count(DBKnowledgeBase.id))) > 0
return RequiredConfigStatus(
default_llm=has_default_llm,
default_embedding_model=has_default_embedding_model,
default_chat_engine=has_default_chat_engine,
knowledge_base=has_knowledge_base,
)
def check_rag_optional_config(session: Session) -> OptionalConfigStatus:
langfuse = bool(
SiteSetting.langfuse_host
and SiteSetting.langfuse_secret_key
and SiteSetting.langfuse_public_key
)
default_reranker = session.scalar(select(func.count(DBRerankerModel.id))) > 0
return OptionalConfigStatus(
langfuse=langfuse,
default_reranker=default_reranker,
)
def check_rag_config_need_migration(session: Session) -> NeedMigrationStatus:
"""
Check if any configuration needs to be migrated.
"""
chat_engines_without_kb_configured = session.exec(
select(ChatEngine.id)
.where(ChatEngine.deleted_at == None)
.where(
text(
"JSON_EXTRACT(engine_options, '$.knowledge_base.linked_knowledge_bases') IS NULL AND "
"JSON_EXTRACT(engine_options, '$.knowledge_base.linked_knowledge_base') IS NULL"
)
)
)
return NeedMigrationStatus(
chat_engines_without_kb_configured=chat_engines_without_kb_configured,
)
def remove_chat_message_recommend_questions(
db_session: Session,
chat_message_id: int,
) -> None:
delete_stmt = delete(RecommendQuestion).where(
RecommendQuestion.chat_message_id == chat_message_id
)
db_session.exec(delete_stmt)
db_session.commit()
def get_chat_message_recommend_questions(
db_session: Session,
chat_message: DBChatMessage,
engine_name: str = "default",
) -> List[str]:
chat_engine_config = ChatEngineConfig.load_from_db(db_session, engine_name)
llm = chat_engine_config.get_llama_llm(db_session)
statement = (
select(RecommendQuestion.questions)
.where(RecommendQuestion.chat_message_id == chat_message.id)
.with_for_update() # using write lock in case the same chat message trigger multiple requests
)
questions = db_session.exec(statement).first()
if questions is not None:
return questions
prompt_template = RichPromptTemplate(
chat_engine_config.llm.further_questions_prompt
)
recommend_questions = llm.predict(
prompt_template,
chat_message_content=chat_message.content,
)
recommend_question_list = recommend_questions.splitlines()
recommend_question_list = [
question.strip() for question in recommend_question_list if question.strip()
]
longest_question = 0
for question in recommend_question_list:
longest_question = max(longest_question, len(question))
# check the output by if the output with format and the length
if (
"##" in recommend_questions
or "**" in recommend_questions
or longest_question > 500
):
regenerate_content = f"""
Please note that you are generating a question list. You previously generated it incorrectly; try again.
----------------------------------------
{chat_message.content}
"""
# with format or too long for per question, it's not a question list, generate again
recommend_questions = llm.predict(
prompt_template,
chat_message_content=regenerate_content,
)
db_session.add(
RecommendQuestion(
chat_message_id=chat_message.id,
questions=recommend_question_list,
)
)
db_session.commit()
return recommend_question_list
================================================
FILE: backend/app/rag/chat/config.py
================================================
import logging
import dspy
from typing import Optional, List
from pydantic import BaseModel, Field
from sqlmodel import Session
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.llms.llm import LLM
from app.rag.postprocessors.metadata_post_filter import MetadataPostFilter
from app.rag.retrievers.chunk.schema import VectorSearchRetrieverConfig
from app.rag.retrievers.knowledge_graph.schema import KnowledgeGraphRetrieverConfig
from app.rag.llms.dspy import get_dspy_lm_by_llama_llm
from app.rag.llms.resolver import get_default_llm, resolve_llm
from app.rag.rerankers.resolver import get_default_reranker_model, resolve_reranker
from app.models import (
LLM as DBLLM,
RerankerModel as DBRerankerModel,
KnowledgeBase,
ChatEngine as DBChatEngine,
)
from app.repositories import chat_engine_repo, knowledge_base_repo
from app.rag.default_prompt import (
DEFAULT_INTENT_GRAPH_KNOWLEDGE,
DEFAULT_NORMAL_GRAPH_KNOWLEDGE,
DEFAULT_CONDENSE_QUESTION_PROMPT,
DEFAULT_TEXT_QA_PROMPT,
DEFAULT_FURTHER_QUESTIONS_PROMPT,
DEFAULT_GENERATE_GOAL_PROMPT,
DEFAULT_CLARIFYING_QUESTION_PROMPT,
)
logger = logging.getLogger(__name__)
class LLMOption(BaseModel):
intent_graph_knowledge: str = DEFAULT_INTENT_GRAPH_KNOWLEDGE
normal_graph_knowledge: str = DEFAULT_NORMAL_GRAPH_KNOWLEDGE
condense_question_prompt: str = DEFAULT_CONDENSE_QUESTION_PROMPT
clarifying_question_prompt: str = DEFAULT_CLARIFYING_QUESTION_PROMPT
text_qa_prompt: str = DEFAULT_TEXT_QA_PROMPT
further_questions_prompt: str = DEFAULT_FURTHER_QUESTIONS_PROMPT
generate_goal_prompt: str = DEFAULT_GENERATE_GOAL_PROMPT
class VectorSearchOption(VectorSearchRetrieverConfig):
pass
class KnowledgeGraphOption(KnowledgeGraphRetrieverConfig):
enabled: bool = True
using_intent_search: bool = True
class ExternalChatEngine(BaseModel):
# TODO: add enable flag for this config.
stream_chat_api_url: str = None
class LinkedKnowledgeBase(BaseModel):
id: int
class KnowledgeBaseOption(BaseModel):
linked_knowledge_base: LinkedKnowledgeBase = None
linked_knowledge_bases: Optional[List[LinkedKnowledgeBase]] = Field(
default_factory=list
)
class ChatEngineConfig(BaseModel):
external_engine_config: Optional[ExternalChatEngine] = None
llm: LLMOption = LLMOption()
knowledge_base: KnowledgeBaseOption = KnowledgeBaseOption()
knowledge_graph: KnowledgeGraphOption = KnowledgeGraphOption()
vector_search: VectorSearchOption = VectorSearchOption()
refine_question_with_kg: bool = True
clarify_question: bool = False
further_questions: bool = False
post_verification_url: Optional[str] = None
post_verification_token: Optional[str] = None
hide_sources: bool = False
_db_chat_engine: Optional[DBChatEngine] = None
_db_llm: Optional[DBLLM] = None
_db_fast_llm: Optional[DBLLM] = None
_db_reranker: Optional[DBRerankerModel] = None
@property
def is_external_engine(self) -> bool:
return (
self.external_engine_config is not None
and self.external_engine_config.stream_chat_api_url
)
def get_db_chat_engine(self) -> Optional[DBChatEngine]:
return self._db_chat_engine
def get_linked_knowledge_base(self, session: Session) -> KnowledgeBase | None:
if not self.knowledge_base:
return None
return knowledge_base_repo.must_get(
session, self.knowledge_base.linked_knowledge_base.id
)
@classmethod
def load_from_db(cls, session: Session, engine_name: str) -> "ChatEngineConfig":
if not engine_name or engine_name == "default":
db_chat_engine = chat_engine_repo.get_default_engine(session)
else:
db_chat_engine = chat_engine_repo.get_engine_by_name(session, engine_name)
if not db_chat_engine:
logger.warning(
f"Chat engine {engine_name} not found in DB, using default engine"
)
db_chat_engine = chat_engine_repo.get_default_engine(session)
obj = cls.model_validate(db_chat_engine.engine_options)
obj._db_chat_engine = db_chat_engine
obj._db_llm = db_chat_engine.llm
obj._db_fast_llm = db_chat_engine.fast_llm
obj._db_reranker = db_chat_engine.reranker
return obj
def get_llama_llm(self, session: Session) -> LLM:
if not self._db_llm:
return get_default_llm(session)
return resolve_llm(
self._db_llm.provider,
self._db_llm.model,
self._db_llm.config,
self._db_llm.credentials,
)
def get_dspy_lm(self, session: Session) -> dspy.LM:
llama_llm = self.get_llama_llm(session)
return get_dspy_lm_by_llama_llm(llama_llm)
def get_fast_llama_llm(self, session: Session) -> LLM:
if not self._db_fast_llm:
return get_default_llm(session)
return resolve_llm(
self._db_fast_llm.provider,
self._db_fast_llm.model,
self._db_fast_llm.config,
self._db_fast_llm.credentials,
)
def get_fast_dspy_lm(self, session: Session) -> dspy.LM:
llama_llm = self.get_fast_llama_llm(session)
return get_dspy_lm_by_llama_llm(llama_llm)
# FIXME: Reranker top_n should be config in the retrieval config.
def get_reranker(
self, session: Session, top_n: int = None
) -> Optional[BaseNodePostprocessor]:
if not self._db_reranker:
return get_default_reranker_model(session, top_n)
top_n = self._db_reranker.top_n if top_n is None else top_n
return resolve_reranker(
self._db_reranker.provider,
self._db_reranker.model,
top_n,
self._db_reranker.config,
self._db_reranker.credentials,
)
def get_metadata_filter(self) -> BaseNodePostprocessor:
return MetadataPostFilter(self.vector_search.metadata_filters)
def get_knowledge_bases(self, db_session: Session) -> List[KnowledgeBase]:
if not self.knowledge_base:
return []
kb_config: KnowledgeBaseOption = self.knowledge_base
linked_knowledge_base_ids = []
if len(kb_config.linked_knowledge_bases) == 0:
linked_knowledge_base_ids.append(kb_config.linked_knowledge_base.id)
else:
linked_knowledge_base_ids.extend(
[kb.id for kb in kb_config.linked_knowledge_bases]
)
knowledge_bases = knowledge_base_repo.get_by_ids(
db_session, knowledge_base_ids=linked_knowledge_base_ids
)
return knowledge_bases
def screenshot(self) -> dict:
return self.model_dump(
exclude={
"llm": [
"condense_question_prompt",
"text_qa_prompt",
"refine_prompt",
"intent_graph_knowledge",
"normal_graph_knowledge",
"generate_goal_prompt",
"further_questions_prompt",
"clarifying_question_prompt",
],
"post_verification_token": True,
}
)
================================================
FILE: backend/app/rag/chat/retrieve/retrieve_flow.py
================================================
import logging
from datetime import datetime
from typing import List, Optional, Tuple
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.llms import LLM
from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.core.prompts.rich import RichPromptTemplate
from pydantic import BaseModel
from sqlmodel import Session
from app.models import (
Document as DBDocument,
KnowledgeBase,
)
from app.rag.chat.config import ChatEngineConfig
from app.rag.retrievers.knowledge_graph.fusion_retriever import (
KnowledgeGraphFusionRetriever,
)
from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrievalResult,
KnowledgeGraphRetrieverConfig,
)
from app.rag.retrievers.chunk.fusion_retriever import ChunkFusionRetriever
from app.repositories import document_repo
dispatcher = get_dispatcher(__name__)
logger = logging.getLogger(__name__)
class SourceDocument(BaseModel):
id: int
name: str
source_uri: Optional[str] = None
class RetrieveFlow:
def __init__(
self,
db_session: Session,
engine_name: str = "default",
engine_config: Optional[ChatEngineConfig] = None,
llm: Optional[LLM] = None,
fast_llm: Optional[LLM] = None,
knowledge_bases: Optional[List[KnowledgeBase]] = None,
):
self.db_session = db_session
self.engine_name = engine_name
self.engine_config = engine_config or ChatEngineConfig.load_from_db(
db_session, engine_name
)
self.db_chat_engine = self.engine_config.get_db_chat_engine()
# Init LLM.
self._llm = llm or self.engine_config.get_llama_llm(self.db_session)
self._fast_llm = fast_llm or self.engine_config.get_fast_llama_llm(
self.db_session
)
# Load knowledge bases.
self.knowledge_bases = (
knowledge_bases or self.engine_config.get_knowledge_bases(self.db_session)
)
self.knowledge_base_ids = [kb.id for kb in self.knowledge_bases]
def retrieve(self, user_question: str) -> List[NodeWithScore]:
if self.engine_config.refine_question_with_kg:
# 1. Retrieve Knowledge graph related to the user question.
_, knowledge_graph_context = self.search_knowledge_graph(user_question)
# 2. Refine the user question using knowledge graph and chat history.
self._refine_user_question(user_question, knowledge_graph_context)
# 3. Search relevant chunks based on the user question.
return self.search_relevant_chunks(user_question=user_question)
def retrieve_documents(self, user_question: str) -> List[DBDocument]:
nodes = self.retrieve(user_question)
return self.get_documents_from_nodes(nodes)
def search_knowledge_graph(
self, user_question: str
) -> Tuple[KnowledgeGraphRetrievalResult, str]:
kg_config = self.engine_config.knowledge_graph
knowledge_graph = KnowledgeGraphRetrievalResult()
knowledge_graph_context = ""
if kg_config is not None and kg_config.enabled:
kg_retriever = KnowledgeGraphFusionRetriever(
db_session=self.db_session,
knowledge_base_ids=[kb.id for kb in self.knowledge_bases],
llm=self._llm,
use_query_decompose=kg_config.using_intent_search,
config=KnowledgeGraphRetrieverConfig.model_validate(
kg_config.model_dump(exclude={"enabled", "using_intent_search"})
),
)
knowledge_graph = kg_retriever.retrieve_knowledge_graph(user_question)
knowledge_graph_context = self._get_knowledge_graph_context(knowledge_graph)
return knowledge_graph, knowledge_graph_context
def _get_knowledge_graph_context(
self, knowledge_graph: KnowledgeGraphRetrievalResult
) -> str:
if self.engine_config.knowledge_graph.using_intent_search:
kg_context_template = RichPromptTemplate(
self.engine_config.llm.intent_graph_knowledge
)
return kg_context_template.format(
sub_queries=knowledge_graph.to_subqueries_dict(),
)
else:
kg_context_template = RichPromptTemplate(
self.engine_config.llm.normal_graph_knowledge
)
return kg_context_template.format(
entities=knowledge_graph.entities,
relationships=knowledge_graph.relationships,
)
def _refine_user_question(
self, user_question: str, knowledge_graph_context: str
) -> str:
prompt_template = RichPromptTemplate(
self.engine_config.llm.condense_question_prompt
)
refined_question = self._fast_llm.predict(
prompt_template,
graph_knowledges=knowledge_graph_context,
question=user_question,
current_date=datetime.now().strftime("%Y-%m-%d"),
)
return refined_question.strip().strip(".\"'!")
def search_relevant_chunks(self, user_question: str) -> List[NodeWithScore]:
retriever = ChunkFusionRetriever(
db_session=self.db_session,
knowledge_base_ids=self.knowledge_base_ids,
llm=self._llm,
config=self.engine_config.vector_search,
use_query_decompose=False,
)
return retriever.retrieve(QueryBundle(user_question))
def get_documents_from_nodes(self, nodes: List[NodeWithScore]) -> List[DBDocument]:
document_ids = [n.node.metadata["document_id"] for n in nodes]
documents = document_repo.fetch_by_ids(self.db_session, document_ids)
# Keep the original order of document ids, which is sorted by similarity.
return sorted(documents, key=lambda x: document_ids.index(x.id))
def get_source_documents_from_nodes(
self, nodes: List[NodeWithScore]
) -> List[SourceDocument]:
documents = self.get_documents_from_nodes(nodes)
return [
SourceDocument(
id=doc.id,
name=doc.name,
source_uri=doc.source_uri,
)
for doc in documents
]
================================================
FILE: backend/app/rag/chat/stream_protocol.py
================================================
import json
from dataclasses import dataclass
from pydantic import BaseModel
from app.models import ChatMessage, Chat
from app.rag.types import ChatEventType, ChatMessageSate
class ChatStreamPayload:
def dump(self):
pass
@dataclass
class ChatStreamDataPayload(ChatStreamPayload):
chat: Chat
user_message: ChatMessage
assistant_message: ChatMessage
def dump(self):
return [
{
"chat": self.chat.model_dump(mode="json"),
"user_message": self.user_message.model_dump(mode="json"),
"assistant_message": self.assistant_message.model_dump(mode="json"),
}
]
@dataclass
class ChatStreamMessagePayload(ChatStreamPayload):
state: ChatMessageSate = ChatMessageSate.TRACE
display: str = ""
context: dict | list | str | BaseModel | None = None
message: str = ""
def dump(self):
if isinstance(self.context, list):
context = [c.model_dump() for c in self.context]
elif isinstance(self.context, BaseModel):
context = self.context.model_dump()
else:
context = self.context
return [
{
"state": self.state.name,
"display": self.display,
"context": context,
"message": self.message,
}
]
@dataclass
class ChatEvent:
event_type: ChatEventType
payload: str | ChatStreamPayload | None = None
def encode(self, charset) -> bytes:
body = self.payload
if isinstance(body, ChatStreamPayload):
body = body.dump()
body = json.dumps(body, separators=(",", ":"))
return f"{self.event_type.value}:{body}\n".encode(charset)
================================================
FILE: backend/app/rag/datasource/__init__.py
================================================
from sqlmodel import Session
from typing import Any
from uuid import UUID
from app.models import DataSourceType
from .base import BaseDataSource
from .file import FileDataSource
from .web_sitemap import WebSitemapDataSource
from .web_single_page import WebSinglePageDataSource
def get_data_source_loader(
session: Session,
knowledge_base_id: int,
data_source_type: DataSourceType,
data_source_id: int,
user_id: UUID,
config: Any,
) -> BaseDataSource:
data_source_cls = None
match data_source_type:
case DataSourceType.FILE:
data_source_cls = FileDataSource
case DataSourceType.WEB_SITEMAP:
data_source_cls = WebSitemapDataSource
case DataSourceType.WEB_SINGLE_PAGE:
data_source_cls = WebSinglePageDataSource
case _:
raise ValueError("Data source type not supported")
return data_source_cls(session, knowledge_base_id, data_source_id, user_id, config)
================================================
FILE: backend/app/rag/datasource/base.py
================================================
from abc import ABC, abstractmethod
from uuid import UUID
from typing import Generator, Any
from sqlmodel import Session
from app.models import Document
class BaseDataSource(ABC):
session: Session
knowledge_base_id: int
data_source_id: int
user_id: UUID
config: Any
def __init__(
self,
session: Session,
knowledge_base_id: int,
data_source_id: int,
user_id: UUID,
config: Any,
**kwargs,
):
self.config = config
self.session = session
self.knowledge_base_id = knowledge_base_id
self.data_source_id = data_source_id
self.user_id = user_id
self.validate_config()
@abstractmethod
def validate_config(self):
raise NotImplementedError
@abstractmethod
def load_documents(self) -> Generator[Document, None, None]:
raise NotImplementedError
================================================
FILE: backend/app/rag/datasource/consts.py
================================================
IGNORE_TAGS = [
"noscript",
"title",
"script",
"style",
"meta",
"head",
"header",
"footer",
"nav",
"symbol",
"aside",
]
IGNORE_CLASSES = ["header", "footer", "sidebar"]
================================================
FILE: backend/app/rag/datasource/file.py
================================================
import logging
import docx
import pptx
import openpyxl
from pydantic import BaseModel
from typing import Generator, IO
from pypdf import PdfReader
from app.models import Document, Upload
from app.file_storage import default_file_storage
from app.types import MimeTypes
from .base import BaseDataSource
logger = logging.getLogger(__name__)
class FileConfig(BaseModel):
file_id: int
class FileDataSource(BaseDataSource):
def validate_config(self):
if not isinstance(self.config, list):
raise ValueError("config must be a list")
for f_config in self.config:
FileConfig.model_validate(f_config)
def load_documents(self) -> Generator[Document, None, None]:
for f_config in self.config:
upload_id = f_config["file_id"]
upload = self.session.get(Upload, upload_id)
if upload is None:
logger.error(f"Upload with id {upload_id} not found")
continue
with default_file_storage.open(upload.path) as f:
if upload.mime_type == MimeTypes.PDF:
content = extract_text_from_pdf(f)
mime_type = MimeTypes.PLAIN_TXT
elif upload.mime_type == MimeTypes.DOCX:
content = extract_text_from_docx(f)
mime_type = MimeTypes.PLAIN_TXT
elif upload.mime_type == MimeTypes.PPTX:
content = extract_text_from_pptx(f)
mime_type = MimeTypes.PLAIN_TXT
elif upload.mime_type == MimeTypes.XLSX:
content = extract_text_from_xlsx(f)
mime_type = MimeTypes.PLAIN_TXT
elif upload.mime_type == MimeTypes.MARKDOWN:
content = f.read()
mime_type = MimeTypes.MARKDOWN
else:
content = f.read()
mime_type = upload.mime_type
document = Document(
name=upload.name,
hash=hash(content),
content=content,
mime_type=mime_type,
knowledge_base_id=self.knowledge_base_id,
data_source_id=self.data_source_id,
user_id=self.user_id,
source_uri=upload.path,
last_modified_at=upload.created_at,
)
yield document
def extract_text_from_pdf(file: IO) -> str:
reader = PdfReader(file)
return "\n\n".join([page.extract_text() for page in reader.pages])
def extract_text_from_docx(file: IO) -> str:
document = docx.Document(file)
full_text = []
for paragraph in document.paragraphs:
full_text.append(paragraph.text)
return "\n\n".join(full_text)
def extract_text_from_pptx(file: IO) -> str:
presentation = pptx.Presentation(file)
full_text = []
for slide in presentation.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
full_text.append(shape.text)
return "\n\n".join(full_text)
def extract_text_from_xlsx(file: IO) -> str:
wb = openpyxl.load_workbook(file)
full_text = []
for sheet in wb.worksheets:
full_text.append(f"Sheet: {sheet.title}")
sheet_string = "\n".join(
",".join(map(str, row)) for row in sheet.iter_rows(values_only=True)
)
full_text.append(sheet_string)
return "\n\n".join(full_text)
================================================
FILE: backend/app/rag/datasource/web_base.py
================================================
import logging
from datetime import datetime, UTC
from typing import Generator
from playwright.sync_api import sync_playwright
from bs4 import BeautifulSoup
from markdownify import MarkdownConverter
from app.models import Document
from app.rag.datasource.consts import IGNORE_TAGS, IGNORE_CLASSES
logger = logging.getLogger(__name__)
def load_web_documents(
knowledge_base_id: int, data_source_id: int, urls: list[str]
) -> Generator[Document, None, None]:
visited = set()
with sync_playwright() as p:
browser = p.chromium.launch(headless=True)
for url in urls:
page = browser.new_page()
response = page.goto(url)
final_url = page.url
if final_url in visited:
continue
if response.status >= 400:
logger.error(
f"Failed to load page: {url}, response status: {response.status()}, skipping"
)
continue
soup = BeautifulSoup(page.content(), "html.parser")
for t in IGNORE_TAGS:
for tag in soup.find_all(t):
tag.extract()
for c in IGNORE_CLASSES:
for tag in soup.find_all(class_=c):
tag.extract()
content = MarkdownConverter().convert_soup(soup)
title = page.title()
visited.add(final_url)
document = Document(
name=title,
hash=hash(content),
content=content,
mime_type="text/plain",
knowledge_base_id=knowledge_base_id,
data_source_id=data_source_id,
source_uri=final_url,
last_modified_at=datetime.now(UTC),
)
yield document
browser.close()
================================================
FILE: backend/app/rag/datasource/web_single_page.py
================================================
import logging
from pydantic import BaseModel
from typing import Generator, List
from app.models import Document
from app.rag.datasource.base import BaseDataSource
from app.rag.datasource.web_base import load_web_documents
logger = logging.getLogger(__name__)
class WebSinglePageConfig(BaseModel):
urls: List[str]
class WebSinglePageDataSource(BaseDataSource):
def validate_config(self):
WebSinglePageConfig.model_validate(self.config)
def load_documents(self) -> Generator[Document, None, None]:
if "url" in self.config:
# TODO: remove this once we have a proper config
urls = [self.config["url"]]
else:
urls = self.config["urls"]
yield from load_web_documents(self.knowledge_base_id, self.data_source_id, urls)
================================================
FILE: backend/app/rag/datasource/web_sitemap.py
================================================
import logging
from typing import Generator
from urllib.parse import urlparse, urljoin
import requests
from pydantic import BaseModel
from bs4 import BeautifulSoup
from app.models import Document
from app.rag.datasource.base import BaseDataSource
from app.rag.datasource.web_base import load_web_documents
logger = logging.getLogger(__name__)
class WebSitemapConfig(BaseModel):
url: str
def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:
if not urlparse(maybe_relative_url).netloc:
return urljoin(source_url, maybe_relative_url)
return maybe_relative_url
def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
response = requests.get(sitemap_url)
response.raise_for_status()
soup = BeautifulSoup(response.content, "html.parser")
result = [
_ensure_absolute_url(sitemap_url, loc_tag.text)
for loc_tag in soup.find_all("loc")
]
if not result:
raise ValueError(f"No URLs found in sitemap {sitemap_url}")
return result
class WebSitemapDataSource(BaseDataSource):
def validate_config(self):
WebSitemapConfig.model_validate(self.config)
def load_documents(self) -> Generator[Document, None, None]:
sitemap_url = self.config["url"]
urls = extract_urls_from_sitemap(sitemap_url)
logger.info(f"Found {len(urls)} URLs in sitemap {sitemap_url}")
yield from load_web_documents(self.knowledge_base_id, self.data_source_id, urls)
================================================
FILE: backend/app/rag/default_prompt.py
================================================
DEFAULT_INTENT_GRAPH_KNOWLEDGE = """\
Given a list of prerequisite questions and their relevant knowledge for the user's main question, when conflicts in meaning arise, prioritize the relationship with the higher weight and the more recent version.
Knowledge sub-queries:
{% for sub_query, data in sub_queries.items() %}
Sub-query: {{ sub_query }}
- Entities:
{% for entity in data['entities'] %}
- Name: {{ entity.name }}
Description: {{ entity.description }}
{% endfor %}
- Relationships:
{% for relationship in data['relationships'] %}
- Description: {{ relationship.rag_description }}
Weight: {{ relationship.weight }}
{% endfor %}
{% endfor %}
"""
DEFAULT_NORMAL_GRAPH_KNOWLEDGE = """\
Given a list of relationships of a knowledge graph as follows. When there is a conflict in meaning between knowledge relationships, the relationship with the higher `weight` and newer `last_modified_at` value takes precedence.
---------------------
Entities:
{% for entity in entities %}
- Name: {{ entity.name }}
Description: {{ entity.description }}
{% endfor %}
---------------------
Knowledge relationships:
{% for relationship in relationships %}
- Description: {{ relationship.rag_description }}
- Weight: {{ relationship.weight }}
- Last Modified At: {{ relationship.last_modified_at }}
- Meta: {{ relationship.meta | tojson(indent=2) }}
{% endfor %}
"""
DEFAULT_CLARIFYING_QUESTION_PROMPT = """\
---------------------
The prerequisite questions and their relevant knowledge for the user's main question.
---------------------
{{graph_knowledges}}
---------------------
Task:
Given the conversation between the user and ASSISTANT, along with the follow-up message from the user, and the provided prerequisite questions and relevant knowledge, determine if the user's question is clear and specific enough for a confident response.
If the question lacks necessary details or context, identify the specific ambiguities and generate a clarifying question to address them.
If the question is clear and answerable, return exact "False" as the response.
Instructions:
1. Assess Information Sufficiency:
- Evaluate if the user's question provides enough detail to generate a precise answer based on the prerequisite questions, relevant knowledge, and conversation history.
- If the user's question is too vague or lacks key information, identify what additional information would be necessary for clarity.
2. Generate a Clarifying Question:
- If the question is clear and answerable, return exact "False" as the response.
- If clarification is needed, return a specific question to ask the user, directly addressing the information gap. Avoid general questions; focus on the specific details required for an accurate answer.
3. Use the same language to ask the clarifying question as the user's original question.
Example 1:
user: "Does TiDB support foreign keys?"
Relevant Knowledge: TiDB supports foreign keys starting from version 6.6.0.
Response:
Which version of TiDB are you using?
Example 2:
user: "Does TiDB support nested transaction?"
Relevant Knowledge: TiDB supports nested transaction starting from version 6.2.0.
Response:
Which version of TiDB are you using?
Example 3:
user: "Does TiDB support foreign keys? I'm using TiDB 6.5.0."
Relevant Knowledge: TiDB supports foreign keys starting from version 6.6.0.
Response:
False
Your Turn:
Chat history:
{{chat_history}}
---------------------
Follow-up question:
{{question}}
Response:
"""
DEFAULT_CONDENSE_QUESTION_PROMPT = """\
Current Date: {{current_date}}
---------------------
The prerequisite questions and their relevant knowledge for the user's main question.
---------------------
{{graph_knowledges}}
---------------------
Task:
Given the conversation between the Human and Assistant, along with the follow-up message from the Human, and the provided prerequisite questions and relevant knowledge, refine the Human’s follow-up message into a standalone, detailed question.
Instructions:
1. Focus on the latest query from the Human, ensuring it is given the most weight.
2. Incorporate Key Information:
- Use the prerequisite questions and their relevant knowledge to add specific details to the follow-up question.
- Replace ambiguous terms or references in the follow-up question with precise information from the provided knowledge. Example: Replace “latest version” with the actual version number mentioned in the knowledge.
3. Utilize Conversation Context:
- Incorporate relevant context and background information from the conversation history to enhance the question's specificity.
4. Optimize for Retrieval:
- Ensure the refined question emphasizes specific and relevant terms to maximize the effectiveness of a vector search for retrieving precise and comprehensive information.
5. Grounded and Factual:
- Make sure the refined question is grounded in and directly based on the user's follow-up question and the provided knowledge.
- Do not introduce information that is not supported by the knowledge or conversation history.
6. Give the language hint for the answer:
- Add a hint after the question like "(Answer language: English)", or "(Answer language: Chinese)", etc.
- This language hint should be exactly same with the language of the original question.
- If the original question has part of other language aside from English, please use the language of another language rather than English. Example: "tidb tableread慢会是哪些原因", it should be Chinese.
Example:
Chat History:
Human: "I'm interested in the performance improvements in the latest version of TiDB."
Assistant: "TiDB version 8.1 was released recently with significant performance enhancements over version 6.5."
Follow-up Question:
"Can you tell me more about these improvements?"
Prerequisite Questions and Relevant Knowledge:
- Prerequisite Question: What is the latest version of TiDB?
- Relevant Knowledge: The latest version of TiDB is 8.1.
...
Refined Standalone Question:
"Can you provide detailed information about the performance improvements introduced in TiDB version 8.1 compared to version 6.5? (Answer language: English)"
Your Turn:
Chat history:
{{chat_history}}
---------------------
Followup question:
{{question}}
---------------------
Refined standalone question:
"""
DEFAULT_TEXT_QA_PROMPT = """\
Current Date: {{current_date}}
---------------------
Knowledge graph information is below
---------------------
{{graph_knowledges}}
---------------------
Context information is below.
---------------------
{{context_str}}
---------------------
Answer Format:
Use markdown footnote syntax (for example: [^1]) to indicate sources you used.
Each footnote must correspond to a unique source. Do not use the same source for multiple footnotes.
### Examples of Correct Footnote Usage (no the unique sources and diverse sources):
[^1]: [TiDB Overview | PingCAP Docs](https://docs.pingcap.com/tidb/stable/overview)
[^2]: [TiDB Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/stable/architecture)
### Examples of Incorrect Footnote Usage (Avoid duplicating the same source for multiple footnotes):
[^1]: [TiDB Introduction | PingCAP Docs](https://docs.pingcap.com/tidb/v5.4/overview)
[^2]: [TiDB Introduction | PingCAP Docs](https://docs.pingcap.com/tidb/v5.4/overview)
[^3]: [TiDB Introduction | PingCAP Docs](https://docs.pingcap.com/tidb/dev/overview)
[^4]: [TiDB Introduction | PingCAP Docs](https://docs.pingcap.com/tidb/stable/overview)
---------------------
Answer Language:
Follow the language of the language hint after the Refined Question.
If the language hint is not provided, use the language that the original questions used.
---------------------
As a customer support assistant, please do not fabricate any knowledge. If you cannot get knowledge from the context, please just directly state "you do not know", rather than constructing nonexistent and potentially fake information!!!
First, analyze the provided context information without assuming prior knowledge. Identify all relevant aspects of knowledge contained within. Then, from various perspectives and angles, answer questions as thoroughly and comprehensively as possible to better address and resolve the user's issue.
The Original questions is:
{{original_question}}
The Refined Question used to search:
{{query_str}}
Answer:
"""
DEFAULT_FURTHER_QUESTIONS_PROMPT = """\
The chat message content is:
{{chat_message_content}}
---------------------
Task:
Based on the provided chat message, generate 3–5 follow-up questions that are relevant to the content. Each question should explore the topic in greater detail, seek clarification, or introduce new angles for discussion.
Instructions:
1. Build upon the key information, themes, or insights within the provided chat message.
2. Aim for variety in question type (clarifying, probing, or exploratory) to encourage a deeper conversation.
3. Ensure each question logically follows from the context of the provided chat message.
4. Keep questions concise yet insightful to maximize engagement.
5. Use the same language with the chat message content.
6. Each question should end with a question mark.
7. Each question should be in a new line, DO NOT add any indexes or blank lines, just output the questions.
Now, generate 3–5 follow-up questions below:
"""
DEFAULT_GENERATE_GOAL_PROMPT = """\
Given the conversation history between the User and Assistant, along with the latest follow-up question from the User, perform the following tasks:
1. **Language Detection**:
- Analyze the User's follow-up question to determine the language used.
2. **Context Classification**:
- **Determine Relevance to TiDB**:
- Assess whether the follow-up question is related to TiDB products, support, or any TiDB-related context.
- **Set Background Accordingly**:
- **If Related to TiDB**:
- Set the background to encompass the relevant TiDB context. This may include aspects like TiDB features, configurations, best practices, troubleshooting, or general consulting related to TiDB.
- Example backgrounds:
- "TiDB product configuration and optimization."
- "TiDB troubleshooting and support."
- "TiDB feature consultation."
- **If Unrelated to TiDB**:
- Set the background to "Other topics."
3. **Goal Generation**:
- **Clarify Intent to Avoid Ambiguity**:
- **Instructional Guidance**:
- If the User's question seeks guidance or a method (e.g., starts with "How to"), ensure the goal reflects a request for a step-by-step guide or best practices.
- **Information Retrieval**:
- If the User's question seeks specific information or confirmation (e.g., starts with "Can you" or "Is it possible"), rephrase it to focus on providing the requested information or verification without implying that the assistant should perform any actions.
- **Important**: Do not interpret these questions as requests for the assistant to execute operations. Instead, understand whether the user seeks to confirm certain information or requires a proposed solution, and restrict responses to information retrieval and guidance based on available documentation.
- **Reformulate the Latest User Follow-up Question**:
- Ensure the question is clear, directive, and suitable for a Q&A format.
- **Specify Additional Details**:
- **Detected Language**: Clearly indicate the language.
- **Desired Answer Format**: Specify if the answer should be in text, table, code snippet, etc.
- **Additional Requirements**: Include any other necessary instructions to tailor the response appropriately.
4. **Output**:
- Produce a goal string in the following format:
"[Refined Question] (Lang: [Detected Language], Format: [Format], Background: [Specified Goal Scenario])"
**Examples**:
**Example 1**:
Chat history:
[]
Follow-up question:
"tidb encryption at rest 会影响数据压缩比例吗?"
Goal:
Does encryption at rest in TiDB affect the data compression ratio? (Lang: Chinese, Format: text, Background: TiDB product related consulting.)
---------------------
**Example 2**:
Chat history:
[]
Follow-up question:
"干嘛的?"
Goal:
What can you do? (Lang: Chinese, Format: text, Background: General inquiry about the assistant's capabilities.)
---------------------
**Example 3**:
Chat history:
[]
Follow-up question:
"oracle 怎么样?"
Goal:
How is Oracle? (Lang: Chinese, Format: text, Background: Other topics.)
---------------------
**Example 4**:
Chat history:
[]
Follow-up question:
"Why is TiDB Serverless up to 70% cheaper than MySQL RDS? (use a table if possible)"
Goal:
Why is TiDB Serverless up to 70% cheaper than MySQL RDS? Please provide a comparison in a table format if possible. (Lang: English, Format: table, Background: Cost comparison between TiDB Serverless and MySQL RDS.)
---------------------
**Example 5 (Enhanced for Clarity and Guidance)**:
Chat history:
[]
Follow-up question:
"能否找到 tidb 中哪些视图的定义中包含已经被删除的表?"
Goal:
How to find which views in TiDB have definitions that include tables that have been deleted? (Lang: Chinese, Format: text, Background: TiDB product related consulting.)
---------------------
**Your Task**:
Chat history:
{{chat_history}}
Follow-up question:
{{question}}
Goal:
"""
================================================
FILE: backend/app/rag/embeddings/local/local_embedding.py
================================================
"""Jina embeddings file."""
from typing import Any, List, Optional
import requests
from llama_index.core.base.embeddings.base import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
MAX_BATCH_SIZE = 500
class _APICaller:
def __init__(
self,
model: str,
api_url: str,
normalize_embeddings: bool = True,
**kwargs: Any,
) -> None:
self.api_url = api_url
self.model = model
self.normalize_embeddings = normalize_embeddings
self._session = requests.Session()
def get_embeddings(self, sentences: list[str]) -> List[List[float]]:
"""Get embeddings."""
# Call Jina AI Embedding API
resp = self._session.post( # type: ignore
self.api_url,
json={
"sentences": sentences,
"model": self.model,
"normalize_embeddings": self.normalize_embeddings,
},
)
resp.raise_for_status()
resp_json = resp.json()
if "embeddings" not in resp_json:
raise RuntimeError(
f"Call local embedding api {self.api_url} failed {resp.status_code}"
)
return resp_json["embeddings"]
async def aget_embeddings(self, sentences: list[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
import aiohttp
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
self.api_url,
json={
"sentences": sentences,
"model": self.model,
},
) as response:
resp = await response.json()
response.raise_for_status()
return resp["embeddings"]
class LocalEmbedding(BaseEmbedding):
model: str = Field(
default="BAAI/bge-m3",
description="The model to use when calling Jina AI API",
)
_encoding_queries: str = PrivateAttr()
_encoding_documents: str = PrivateAttr()
_api: Any = PrivateAttr()
def __init__(
self,
model: str = "BAAI/bge-m3",
api_url: str = "http://127.0.0.1:5001/api/v1/embedding",
normalize_embeddings: bool = True,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
) -> None:
super().__init__(
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
model=model,
**kwargs,
)
self._api: _APICaller = _APICaller(
model=model, api_url=api_url, normalize_embeddings=normalize_embeddings
)
@classmethod
def class_name(cls) -> str:
return "LocalEmbedding"
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self._api.get_embeddings([query])[0]
async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
result = await self._api.aget_embeddings([query])
return result[0]
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self._get_text_embeddings([text])[0]
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
result = await self._aget_text_embeddings([text])
return result[0]
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
return self._api.get_embeddings(texts)
async def _aget_text_embeddings(
self,
texts: List[str],
) -> List[List[float]]:
return await self._api.aget_embeddings(texts)
================================================
FILE: backend/app/rag/embeddings/open_like/openai_like_embedding.py
================================================
from typing import Any, List, Optional
from llama_index.core.base.embeddings.base import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from openai import OpenAI, AsyncOpenAI
class OpenAILikeEmbedding(BaseEmbedding):
# We cannot directly call the llama-index's API because it limited the model name
# And the 'embedding-2' or 'embedding-3' is not one of the OpenAI's model name
model: str = Field(
default="embedding-3",
description="The model to use when calling Zhipu AI API",
)
_client: OpenAI = PrivateAttr()
_aclient: AsyncOpenAI = PrivateAttr()
def __init__(
self,
api_key: str,
model: str = "embedding-3",
api_base: str = "https://open.bigmodel.cn/api/paas/v4/",
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
) -> None:
super().__init__(
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
model=model,
**kwargs,
)
self._model_kwargs = kwargs or {}
self.model = model
self._client = OpenAI(api_key=api_key, base_url=api_base)
self._aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
def get_embeddings(self, sentences: list[str]) -> List[List[float]]:
"""Get embeddings."""
# Call Zhipu AI Embedding API via OpenAI client
embedding_objs = self._client.embeddings.create(
input=sentences, model=self.model, **self._model_kwargs
).data
embeddings = [obj.embedding for obj in embedding_objs]
return embeddings
async def aget_embeddings(self, sentences: list[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
result = await self._aclient.embeddings.create(
input=sentences, model=self.model, **self._model_kwargs
)
embeddings = [obj.embedding for obj in result.data]
return embeddings
@classmethod
def class_name(cls) -> str:
return "OpenAILikeEmbedding"
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self.get_embeddings([query])[0]
async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
result = await self.aget_embeddings([query])
return result[0]
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self._get_text_embeddings([text])[0]
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
result = await self._aget_text_embeddings([text])
return result[0]
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
return self.get_embeddings(texts)
async def _aget_text_embeddings(
self,
texts: List[str],
) -> List[List[float]]:
return await self.aget_embeddings(texts)
================================================
FILE: backend/app/rag/embeddings/provider.py
================================================
import enum
from typing import List
from pydantic import BaseModel
class EmbeddingProvider(str, enum.Enum):
OPENAI = "openai"
JINA = "jina"
COHERE = "cohere"
BEDROCK = "bedrock"
OLLAMA = "ollama"
GITEEAI = "giteeai"
LOCAL = "local"
OPENAI_LIKE = "openai_like"
AZURE_OPENAI = "azure_openai"
class EmbeddingProviderOption(BaseModel):
provider: EmbeddingProvider
provider_display_name: str | None = None
provider_description: str | None = None
provider_url: str | None = None
default_embedding_model: str
embedding_model_description: str
default_config: dict = {}
config_description: str = ""
default_credentials: str | dict = ""
credentials_display_name: str
credentials_description: str
credentials_type: str = "str"
embedding_provider_options: List[EmbeddingProviderOption] = [
EmbeddingProviderOption(
provider=EmbeddingProvider.OPENAI,
provider_display_name="OpenAI",
provider_description="The OpenAI API provides a simple interface for developers to create an intelligence layer in their applications, powered by OpenAI's state of the art models.",
provider_url="https://platform.openai.com",
default_embedding_model="text-embedding-3-small",
embedding_model_description="Find more information about OpenAI Embedding at https://platform.openai.com/docs/guides/embeddings",
credentials_display_name="OpenAI API Key",
credentials_description="The API key of OpenAI, you can find it in https://platform.openai.com/api-keys",
credentials_type="str",
default_credentials="sk-****",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.JINA,
provider_display_name="JinaAI",
provider_description="Jina AI provides multimodal, bilingual long-context embeddings for search and RAG",
provider_url="https://jina.ai/embeddings/",
default_embedding_model="jina-embeddings-v2-base-en",
embedding_model_description="Find more information about Jina AI Embeddings at https://jina.ai/embeddings/",
credentials_display_name="Jina API Key",
credentials_description="The API key of Jina, you can find it in https://jina.ai/embeddings/",
credentials_type="str",
default_credentials="jina_****",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.COHERE,
provider_display_name="Cohere",
provider_description="Cohere provides industry-leading large language models (LLMs) and RAG capabilities tailored to meet the needs of enterprise use cases that solve real-world problems.",
provider_url="https://cohere.com/embeddings",
default_embedding_model="embed-multilingual-v3.0",
embedding_model_description="Documentation: https://docs.cohere.com/docs/cohere-embed",
credentials_display_name="Cohere API Key",
credentials_description="You can get one from https://dashboard.cohere.com/api-keys",
credentials_type="str",
default_credentials="*****",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.BEDROCK,
provider_display_name="Bedrock",
provider_description="Amazon Bedrock is a fully managed foundation models service.",
provider_url="https://docs.aws.amazon.com/bedrock/",
default_embedding_model="amazon.titan-embed-text-v2:0",
embedding_model_description="",
credentials_display_name="AWS Bedrock Credentials JSON",
credentials_description="The JSON Object of AWS Credentials, refer to https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-global",
credentials_type="dict",
default_credentials={
"aws_access_key_id": "****",
"aws_secret_access_key": "****",
"aws_region_name": "us-west-2",
},
),
EmbeddingProviderOption(
provider=EmbeddingProvider.OLLAMA,
provider_display_name="Ollama",
provider_description="Ollama is a lightweight framework for building and running large language models and embed models.",
provider_url="https://ollama.com",
default_embedding_model="nomic-embed-text",
embedding_model_description="Documentation: https://ollama.com/blog/embedding-models",
default_config={
"api_base": "http://localhost:11434",
},
config_description="api_base is the base URL of the Ollama server, ensure it can be accessed from this server.",
credentials_display_name="Ollama API Key",
credentials_description="Ollama doesn't require an API key, set a dummy string here is ok",
credentials_type="str",
default_credentials="dummy",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.OPENAI_LIKE,
provider_display_name="OpenAI Like",
provider_description="OpenAI-Like is a set of platforms that provide text embeddings similar to OpenAI. Such as ZhiPuAI.",
provider_url="https://open.bigmodel.cn/dev/api/vector/embedding-3",
default_embedding_model="embedding-3",
embedding_model_description="",
credentials_display_name="OpenAI Like API Key",
credentials_description="The API key of OpenAI Like. For ZhipuAI, you can find it in https://open.bigmodel.cn/usercenter/apikeys",
credentials_type="str",
default_credentials="dummy",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.GITEEAI,
provider_display_name="Gitee AI",
provider_description="Gitee AI is a third-party model provider that offers ready-to-use cutting-edge model APIs for AI developers.",
provider_url="https://ai.gitee.com",
default_embedding_model="bge-large-zh-v1.5",
embedding_model_description="Find more information about Gitee AI Embeddings at https://ai.gitee.com/docs/openapi/v1#tag/%E7%89%B9%E5%BE%81%E6%8A%BD%E5%8F%96/POST/embeddings",
credentials_display_name="Gitee AI API Key",
credentials_description="The API key of Gitee AI, you can find it in https://ai.gitee.com/dashboard/settings/tokens",
credentials_type="str",
default_credentials="****",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.AZURE_OPENAI,
provider_display_name="Azure OpenAI",
provider_description="Azure OpenAI is a cloud-based AI service that provides a suite of AI models and tools for developers to build intelligent applications.",
provider_url="https://azure.microsoft.com/en-us/products/ai-services/openai-service",
default_embedding_model="text-embedding-3-small",
embedding_model_description="Before using this option, you need to deploy an Azure OpenAI API and model, see https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource.",
default_config={
"azure_endpoint": "https://.openai.azure.com/",
"api_version": "",
},
credentials_display_name="Azure OpenAI API Key",
credentials_description="The API key of Azure OpenAI",
credentials_type="str",
default_credentials="****",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.LOCAL,
provider_display_name="Local Embedding",
provider_description="Autoflow's local embedding server, deployed on your own infrastructure and powered by sentence-transformers.",
default_embedding_model="BAAI/bge-m3",
embedding_model_description="Find more models in huggingface.",
default_config={
"api_url": "http://local-embedding-reranker:5001/api/v1/embedding",
},
config_description="api_url is the embedding endpoint url serviced by the autoflow local embedding server.",
credentials_display_name="Local Embedding API Key",
credentials_description="Local Embedding server doesn't require an API key, set a dummy string here is ok.",
credentials_type="str",
default_credentials="dummy",
),
]
================================================
FILE: backend/app/rag/embeddings/resolver.py
================================================
from typing import Optional
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from sqlmodel import Session
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.jinaai import JinaEmbedding
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.embeddings.bedrock import BedrockEmbedding
from llama_index.embeddings.ollama import OllamaEmbedding
from app.rag.embeddings.open_like.openai_like_embedding import OpenAILikeEmbedding
from app.rag.embeddings.local.local_embedding import LocalEmbedding
from app.repositories.embedding_model import embedding_model_repo
from app.rag.embeddings.provider import EmbeddingProvider
def resolve_embed_model(
provider: EmbeddingProvider,
model: str,
config: dict,
credentials: str | list | dict | None,
) -> BaseEmbedding:
match provider:
case EmbeddingProvider.OPENAI:
return OpenAIEmbedding(
model=model,
api_key=credentials,
**config,
)
case EmbeddingProvider.JINA:
return JinaEmbedding(
model=model,
api_key=credentials,
**config,
)
case EmbeddingProvider.COHERE:
return CohereEmbedding(
model_name=model,
cohere_api_key=credentials,
**config,
)
case EmbeddingProvider.BEDROCK:
return BedrockEmbedding(
model_name=model,
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
region_name=credentials["aws_region_name"],
**config,
)
case EmbeddingProvider.OLLAMA:
return OllamaEmbedding(
model_name=model,
**config,
)
case EmbeddingProvider.LOCAL:
return LocalEmbedding(
model=model,
**config,
)
case EmbeddingProvider.GITEEAI:
return OpenAILikeEmbedding(
model=model,
api_base="https://ai.gitee.com/v1",
api_key=credentials,
**config,
)
case EmbeddingProvider.AZURE_OPENAI:
return AzureOpenAIEmbedding(
model=model,
api_key=credentials,
**config,
)
case EmbeddingProvider.OPENAI_LIKE:
return OpenAILikeEmbedding(
model=model,
api_key=credentials,
**config,
)
case _:
raise ValueError(f"Got unknown embedding provider: {provider}")
def get_default_embed_model(session: Session) -> Optional[BaseEmbedding]:
db_embed_model = embedding_model_repo.get_default(session)
if not db_embed_model:
return None
return resolve_embed_model(
db_embed_model.provider,
db_embed_model.model,
db_embed_model.config,
db_embed_model.credentials,
)
def must_get_default_embed_model(session: Session) -> BaseEmbedding:
db_embed_model = embedding_model_repo.must_get_default(session)
return resolve_embed_model(
db_embed_model.provider,
db_embed_model.model,
db_embed_model.config,
db_embed_model.credentials,
)
================================================
FILE: backend/app/rag/indices/__init__.py
================================================
================================================
FILE: backend/app/rag/indices/knowledge_graph/__init__.py
================================================
from .base import KnowledgeGraphIndex
__all__ = ["KnowledgeGraphIndex"]
================================================
FILE: backend/app/rag/indices/knowledge_graph/base.py
================================================
import dspy
import logging
from typing import Any, Dict, List, Optional, Sequence
from llama_index.core.data_structs import IndexLPG
from llama_index.core.callbacks import CallbackManager
from llama_index.core.indices.base import BaseIndex
from llama_index.core.storage.docstore.types import RefDocInfo
from llama_index.core.storage.storage_context import StorageContext
from llama_index.core.schema import BaseNode, TransformComponent
import llama_index.core.instrumentation as instrument
from app.rag.indices.knowledge_graph.extractor import SimpleGraphExtractor
from app.rag.indices.knowledge_graph.graph_store import KnowledgeGraphStore
logger = logging.getLogger(__name__)
dispatcher = instrument.get_dispatcher(__name__)
class KnowledgeGraphIndex(BaseIndex[IndexLPG]):
"""An index for a property graph.
Args:
nodes (Optional[Sequence[BaseNode]]):
A list of nodes to insert into the index.
dspy_lm (dspy.BaseLLM):
The language model of dspy to use for extracting triplets.
callback_manager (Optional[CallbackManager]):
The callback manager to use.
transformations (Optional[List[TransformComponent]]):
A list of transformations to apply to the nodes before inserting them into the index.
These are applied prior to the `kg_extractors`.
storage_context (Optional[StorageContext]):
The storage context to use.
show_progress (bool):
Whether to show progress bars for transformations. Defaults to `False`.
"""
index_struct_cls = IndexLPG
def __init__(
self,
dspy_lm: dspy.LM,
kg_store: KnowledgeGraphStore,
nodes: Optional[Sequence[BaseNode]] = None,
# parent class params
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
) -> None:
self._dspy_lm = dspy_lm
self._kg_store = kg_store
super().__init__(
nodes=nodes,
callback_manager=callback_manager,
**kwargs,
)
@classmethod
def from_existing(
cls: "KnowledgeGraphIndex",
dspy_lm: dspy.LM,
kg_store: KnowledgeGraphStore,
# parent class params
callback_manager: Optional[CallbackManager] = None,
transformations: Optional[List[TransformComponent]] = None,
storage_context: Optional[StorageContext] = None,
show_progress: bool = False,
**kwargs: Any,
) -> "KnowledgeGraphIndex":
return cls(
dspy_lm=dspy_lm,
kg_store=kg_store,
nodes=[], # no nodes to insert
callback_manager=callback_manager,
transformations=transformations,
storage_context=storage_context,
show_progress=show_progress,
**kwargs,
)
def _insert_nodes(self, nodes: Sequence[BaseNode]):
"""Insert nodes to the index struct."""
if len(nodes) == 0:
return nodes
extractor = SimpleGraphExtractor(dspy_lm=self._dspy_lm)
for node in nodes:
entities_df, rel_df = extractor.extract(
text=node.get_content(),
node=node,
)
self._kg_store.save(node.node_id, entities_df, rel_df)
def _build_index_from_nodes(self, nodes: Optional[Sequence[BaseNode]]) -> IndexLPG:
"""Build index from nodes."""
nodes = self._insert_nodes(nodes or [])
return IndexLPG()
def as_retriever(self, **kwargs: Any):
"""Return a retriever for the index."""
# Our retriever params is more complex than the base retriever,
# so we can't use the base retriever.
raise NotImplementedError(
"Retriever not implemented for KnowledgeGraphIndex, use `retrieve_with_weight` instead."
)
def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
"""Index-specific logic for inserting nodes to the index struct."""
self._insert_nodes(nodes)
def ref_doc_info(self) -> Dict[str, RefDocInfo]:
"""Retrieve a dict mapping of ingested documents and their nodes+metadata."""
raise NotImplementedError(
"Ref doc info not implemented for KnowledgeGraphIndex. "
"All inserts are already upserts."
)
def _delete_node(self, node_id: str, **delete_kwargs: Any) -> None:
"""Delete a node."""
raise NotImplementedError(
"Delete node not implemented for KnowledgeGraphIndex."
)
================================================
FILE: backend/app/rag/indices/knowledge_graph/extractor.py
================================================
import logging
from copy import deepcopy
import pandas as pd
import dspy
from typing import Mapping, Optional, List
from dspy import Predict
from llama_index.core.schema import BaseNode
from app.rag.indices.knowledge_graph.schema import (
Entity,
Relationship,
KnowledgeGraph,
EntityCovariateInput,
EntityCovariateOutput,
)
logger = logging.getLogger(__name__)
class ExtractGraphTriplet(dspy.Signature):
"""Carefully analyze the provided text from database documentation and community blogs to thoroughly identify all entities related to database technologies, including both general concepts and specific details.
Follow these Step-by-Step Analysis:
1. Extract Meaningful Entities:
- Identify all significant nouns, proper nouns, and technical terminologies that represent database-related concepts, objects, components, features, issues, key steps, execute order, user case, locations, versions, or any substantial entities.
- Ensure that you capture entities across different levels of detail, from high-level overviews to specific technical specifications, to create a comprehensive representation of the subject matter.
- Choose names for entities that are specific enough to indicate their meaning without additional context, avoiding overly generic terms.
- Consolidate similar entities to avoid redundancy, ensuring each represents a distinct concept at appropriate granularity levels.
2. Extract Metadata to claim the entities:
- Carefully review the provided text, focusing on identifying detailed covariates associated with each entity.
- Extract and link the covariates (which is a comprehensive json TREE, the first field is always: "topic") to their respective entities.
- Ensure all extracted covariates is clearly connected to the correct entity for accuracy and comprehensive understanding.
- Ensure that all extracted covariates are factual and verifiable within the text itself, without relying on external knowledge or assumptions.
- Collectively, the covariates should provide a thorough and precise summary of the entity's characteristics as described in the source material.
3. Establish Relationships:
- Carefully examine the text to identify all relationships between clearly-related entities, ensuring each relationship is correctly captured with accurate details about the interactions.
- Analyze the context and interactions between the identified entities to determine how they are interconnected, focusing on actions, associations, dependencies, or similarities.
- Clearly define the relationships, ensuring accurate directionality that reflects the logical or functional dependencies among entities. \
This means identifying which entity is the source, which is the target, and what the nature of their relationship is (e.g., $source_entity depends on $target_entity for $relationship).
Some key points to consider:
- Please endeavor to extract all meaningful entities and relationships from the text, avoid subsequent additional gleanings.
Objective: Produce a detailed and comprehensive knowledge graph that captures the full spectrum of entities mentioned in the text, along with their interrelations, reflecting both broad concepts and intricate details specific to the database domain.
Please only response in JSON format.
"""
text = dspy.InputField(
desc="a paragraph of text to extract entities and relationships to form a knowledge graph"
)
knowledge: KnowledgeGraph = dspy.OutputField(
desc="Graph representation of the knowledge extracted from the text."
)
class ExtractCovariate(dspy.Signature):
"""Please carefully review the provided text and entities list which are already identified in the text. Focusing on identifying detailed covariates associated with each entities provided.
Extract and link the covariates (which is a comprehensive json TREE, the first field is always: "topic") to their respective entities.
Ensure all extracted covariates is clearly connected to the correct entity for accuracy and comprehensive understanding.
Ensure that all extracted covariates are factual and verifiable within the text itself, without relying on external knowledge or assumptions.
Collectively, the covariates should provide a thorough and precise summary of the entity's characteristics as described in the source material.
Please only response in JSON format.
"""
text = dspy.InputField(
desc="a paragraph of text to extract covariates to claim the entities."
)
entities: List[EntityCovariateInput] = dspy.InputField(
desc="List of entities identified in the text."
)
covariates: List[EntityCovariateOutput] = dspy.OutputField(
desc="Graph representation of the knowledge extracted from the text."
)
def get_relation_metadata_from_node(node: BaseNode):
metadata = deepcopy(node.metadata)
for key in [
"_node_content",
"_node_type",
"excerpt_keywords",
"questions_this_excerpt_can_answer",
"section_summary",
]:
metadata.pop(key, None)
metadata["chunk_id"] = node.node_id
return metadata
class Extractor(dspy.Module):
def __init__(self, dspy_lm: dspy.LM):
super().__init__()
self.dspy_lm = dspy_lm
self.prog_graph = Predict(ExtractGraphTriplet)
self.prog_covariates = Predict(ExtractCovariate)
def forward(self, text):
with dspy.settings.context(lm=self.dspy_lm):
pred_graph = self.prog_graph(text=text)
# extract the covariates
entities_for_covariates = [
EntityCovariateInput(
name=entity.name,
description=entity.description,
)
for entity in pred_graph.knowledge.entities
]
pred_covariates = self.prog_covariates(
text=text,
entities=entities_for_covariates,
)
# replace the entities with the covariates
for entity in pred_graph.knowledge.entities:
for covariate in pred_covariates.covariates:
if entity.name == covariate.name:
entity.metadata = covariate.covariates
return pred_graph
class SimpleGraphExtractor:
def __init__(
self, dspy_lm: dspy.LM, complied_extract_program_path: Optional[str] = None
):
self.extract_prog = Extractor(dspy_lm=dspy_lm)
if complied_extract_program_path is not None:
self.extract_prog.load(complied_extract_program_path)
def extract(self, text: str, node: BaseNode):
pred = self.extract_prog(text=text)
metadata = get_relation_metadata_from_node(node)
# Ensure all entities have proper metadata dictionary structure
for entity in pred.knowledge.entities:
if entity.metadata is None or not isinstance(entity.metadata, dict):
entity.metadata = {"topic": "Unknown", "status": "auto-generated"}
return self._to_df(
pred.knowledge.entities, pred.knowledge.relationships, metadata
)
def _to_df(
self,
entities: list[Entity],
relationships: list[Relationship],
extra_meta: Mapping[str, str],
):
# Create lists to store dictionaries for entities and relationships
entities_data = []
relationships_data = []
# Iterate over parsed entities and relationships to create dictionaries
for entity in entities:
entity_dict = {
"name": entity.name,
"description": entity.description,
"meta": entity.metadata,
}
entities_data.append(entity_dict)
mapped_entities = {entity["name"]: entity for entity in entities_data}
for relationship in relationships:
source_entity_description = ""
if relationship.source_entity not in mapped_entities:
new_source_entity = {
"name": relationship.source_entity,
"description": (
f"Derived from from relationship: "
f"{relationship.source_entity} -> {relationship.relationship_desc} -> {relationship.target_entity}"
),
"meta": {"status": "need-revised"},
}
entities_data.append(new_source_entity)
mapped_entities[relationship.source_entity] = new_source_entity
source_entity_description = new_source_entity["description"]
else:
source_entity_description = mapped_entities[relationship.source_entity][
"description"
]
target_entity_description = ""
if relationship.target_entity not in mapped_entities:
new_target_entity = {
"name": relationship.target_entity,
"description": (
f"Derived from from relationship: "
f"{relationship.source_entity} -> {relationship.relationship_desc} -> {relationship.target_entity}"
),
"meta": {"status": "need-revised"},
}
entities_data.append(new_target_entity)
mapped_entities[relationship.target_entity] = new_target_entity
target_entity_description = new_target_entity["description"]
else:
target_entity_description = mapped_entities[relationship.target_entity][
"description"
]
relationship_dict = {
"source_entity": relationship.source_entity,
"source_entity_description": source_entity_description,
"target_entity": relationship.target_entity,
"target_entity_description": target_entity_description,
"relationship_desc": relationship.relationship_desc,
"meta": {
**extra_meta,
},
}
relationships_data.append(relationship_dict)
# Create DataFrames for entities and relationships
entities_df = pd.DataFrame(entities_data)
relationships_df = pd.DataFrame(relationships_data)
return entities_df, relationships_df
================================================
FILE: backend/app/rag/indices/knowledge_graph/graph_store/__init__.py
================================================
from .tidb_graph_store import TiDBGraphStore
from .tidb_graph_editor import TiDBGraphEditor
from .tidb_graph_store import KnowledgeGraphStore
__all__ = [
"TiDBGraphStore",
"TiDBGraphEditor",
"KnowledgeGraphStore",
]
================================================
FILE: backend/app/rag/indices/knowledge_graph/graph_store/helpers.py
================================================
import json
from typing import List, Tuple, Mapping, Any
from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingModelType
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
# The configuration for the weight coefficient
# format: ((min_weight, max_weight), coefficient)
DEFAULT_WEIGHT_COEFFICIENT_CONFIG = [
((0, 100), 0.01),
((100, 1000), 0.001),
((1000, 10000), 0.0001),
((10000, float("inf")), 0.00001),
]
# The configuration for the range search
# format: ((min_distance, max_distance), seach_ratio)
# The sum of search ratio should be 1 except some case we want to search as many as possible relationships.
# In this case, we set the search ratio to 1, and the other search ratio sum should be 1
DEFAULT_RANGE_SEARCH_CONFIG = [
((0.0, 0.25), 1),
((0.25, 0.35), 0.7),
((0.35, 0.45), 0.2),
((0.45, 0.55), 0.1),
]
DEFAULT_DEGREE_COEFFICIENT = 0.001
def get_weight_score(
weight: int, weight_coefficient_config: List[Tuple[Tuple[int, int], float]]
) -> float:
weight_score = 0.0
remaining_weight = weight
for weight_range, coefficient in weight_coefficient_config:
if remaining_weight <= 0:
break
lower_bound, upper_bound = weight_range
applicable_weight = min(upper_bound - lower_bound, remaining_weight)
weight_score += applicable_weight * coefficient
remaining_weight -= applicable_weight
return weight_score
def get_degree_score(in_degree: int, out_degree: int, degree_coefficient) -> float:
return (in_degree - out_degree) * degree_coefficient
def calculate_relationship_score(
embedding_distance: float,
weight: int,
in_degree: int,
out_degree: int,
alpha: float,
weight_coefficient_config: List[
Tuple[Tuple[int, int], float]
] = DEFAULT_WEIGHT_COEFFICIENT_CONFIG,
degree_coefficient: float = DEFAULT_DEGREE_COEFFICIENT,
with_degree: bool = False,
) -> float:
weighted_score = get_weight_score(weight, weight_coefficient_config)
degree_score = 0
if with_degree:
degree_score = get_degree_score(in_degree, out_degree, degree_coefficient)
return alpha * (1 / embedding_distance) + weighted_score + degree_score
def get_default_embed_model() -> BaseEmbedding:
return OpenAIEmbedding(model=OpenAIEmbeddingModelType.TEXT_EMBED_3_SMALL)
def get_query_embedding(query: str, embed_model: BaseEmbedding = None) -> Embedding:
if not embed_model:
embed_model = get_default_embed_model()
return embed_model.get_query_embedding(query)
def get_text_embedding(text: str, embed_model: BaseEmbedding = None) -> Embedding:
if not embed_model:
embed_model = get_default_embed_model()
return embed_model.get_text_embedding(text)
def get_entity_description_embedding(
name: str, description: str, embed_model: BaseEmbedding = None
) -> Embedding:
combined_text = f"{name}: {description}"
return get_text_embedding(combined_text, embed_model)
def get_entity_metadata_embedding(
metadata: Mapping[str, Any], embed_model: BaseEmbedding = None
) -> Embedding:
combined_text = json.dumps(metadata, ensure_ascii=False)
return get_text_embedding(combined_text, embed_model)
def get_relationship_description_embedding(
source_entity_name: str,
source_entity_description,
target_entity_name: str,
target_entity_description: str,
relationship_desc: str,
embed_model: BaseEmbedding = None,
):
combined_text = (
f"{source_entity_name}({source_entity_description}) -> "
f"{relationship_desc} -> {target_entity_name}({target_entity_description}) "
)
return get_text_embedding(combined_text, embed_model)
================================================
FILE: backend/app/rag/indices/knowledge_graph/graph_store/schema.py
================================================
from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple
from sqlmodel import Session
class KnowledgeGraphStore(ABC):
@abstractmethod
def save(self, entities_df, relationships_df) -> None:
"""Upsert entities and relationships to the graph store."""
pass
@abstractmethod
def retrieve_with_weight(
self,
query: str,
embedding: list,
depth: int = 2,
include_meta: bool = False,
with_degree: bool = False,
relationship_meta_filters: Dict = {},
session: Optional[Session] = None,
) -> Tuple[list, list, list]:
"""Retrieve nodes and relationships with weights."""
pass
================================================
FILE: backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_editor.py
================================================
from typing import Optional, Tuple, List, Type
from llama_index.core.embeddings import resolve_embed_model
from llama_index.core.embeddings.utils import EmbedType
from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingModelType
from sqlmodel import Session, select, SQLModel
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.attributes import flag_modified
from app.models import EntityType
from app.rag.indices.knowledge_graph.schema import Relationship as RelationshipAIModel
from app.rag.indices.knowledge_graph.graph_store import TiDBGraphStore
from app.rag.indices.knowledge_graph.graph_store.helpers import (
get_entity_description_embedding,
get_relationship_description_embedding,
get_entity_metadata_embedding,
get_query_embedding,
)
from app.staff_action import create_staff_action_log
# TODO: CRUD operations should move to TiDBGraphStore
class TiDBGraphEditor:
_entity_db_model: Type[SQLModel]
_relationship_db_model: Type[SQLModel]
def __init__(
self,
knowledge_base_id: int,
entity_db_model: Type[SQLModel],
relationship_db_model: Type[SQLModel],
embed_model: Optional[EmbedType] = None,
):
self.knowledge_base_id = knowledge_base_id
self._entity_db_model = entity_db_model
self._relationship_db_model = relationship_db_model
if embed_model:
self._embed_model = resolve_embed_model(embed_model)
else:
self._embed_model = OpenAIEmbedding(
model=OpenAIEmbeddingModelType.TEXT_EMBED_3_SMALL
)
def get_entity(self, session: Session, entity_id: int) -> Optional[SQLModel]:
return session.get(self._entity_db_model, entity_id)
def update_entity(
self, session: Session, entity: SQLModel, new_entity: dict
) -> SQLModel:
old_entity_dict = entity.screenshot()
for key, value in new_entity.items():
if value is not None:
setattr(entity, key, value)
flag_modified(entity, key)
entity.description_vec = get_entity_description_embedding(
entity.name, entity.description, self._embed_model
)
entity.meta_vec = get_entity_metadata_embedding(entity.meta, self._embed_model)
for relationship in session.exec(
select(self._relationship_db_model)
.options(
joinedload(self._relationship_db_model.source_entity),
joinedload(self._relationship_db_model.target_entity),
)
.where(
(self._relationship_db_model.source_entity_id == entity.id)
| (self._relationship_db_model.target_entity_id == entity.id)
)
):
relationship.description_vec = get_relationship_description_embedding(
relationship.source_entity.name,
relationship.source_entity.description,
relationship.target_entity.name,
relationship.target_entity.description,
relationship.description,
self._embed_model,
)
session.add(relationship)
session.commit()
session.refresh(entity)
new_entity_dict = entity.screenshot()
create_staff_action_log(
session, "update", "entity", entity.id, old_entity_dict, new_entity_dict
)
return entity
def get_entity_subgraph(
self, session: Session, entity: SQLModel
) -> Tuple[list, list]:
"""
Get the subgraph of an entity, including all related relationships and entities.
"""
relationships_queryset = session.exec(
select(self._relationship_db_model)
.options(
joinedload(self._relationship_db_model.source_entity),
joinedload(self._relationship_db_model.target_entity),
)
.where(
(self._relationship_db_model.source_entity_id == entity.id)
| (self._relationship_db_model.target_entity_id == entity.id)
)
)
relationships = []
entities = []
entities_set = set()
for relationship in relationships_queryset:
entities_set.add(relationship.source_entity)
entities_set.add(relationship.target_entity)
relationships.append(relationship.screenshot())
for entity in entities_set:
entities.append(entity.screenshot())
return relationships, entities
def get_relationship(
self, session: Session, relationship_id: int
) -> Optional[SQLModel]:
return session.get(self._relationship_db_model, relationship_id)
def update_relationship(
self, session: Session, relationship: SQLModel, new_relationship: dict
) -> SQLModel:
old_relationship_dict = relationship.screenshot()
for key, value in new_relationship.items():
if value is not None:
setattr(relationship, key, value)
flag_modified(relationship, key)
relationship.description_vec = get_relationship_description_embedding(
relationship.source_entity.name,
relationship.source_entity.description,
relationship.target_entity.name,
relationship.target_entity.description,
relationship.description,
self._embed_model,
)
session.commit()
session.refresh(relationship)
new_relationship_dict = relationship.screenshot()
# FIXME: some error when create staff action log
create_staff_action_log(
session,
"update",
"relationship",
relationship.id,
old_relationship_dict,
new_relationship_dict,
)
return relationship
def search_similar_entities(
self, session: Session, query: str, top_k: int = 10
) -> list:
embedding = get_query_embedding(query, self._embed_model)
return session.exec(
select(self._entity_db_model)
.where(self._entity_db_model.entity_type == EntityType.original)
.order_by(self._entity_db_model.description_vec.cosine_distance(embedding))
.limit(top_k)
).all()
def create_synopsis_entity(
self,
session: Session,
name: str,
description: str,
topic: str,
meta: dict,
related_entities_ids: List[int],
) -> SQLModel:
# with session.begin():
synopsis_entity = self._entity_db_model(
name=name,
description=description,
description_vec=get_entity_description_embedding(
name, description, self._embed_model
),
meta=meta,
meta_vec=get_entity_metadata_embedding(meta, self._embed_model),
entity_type=EntityType.synopsis,
synopsis_info={
"entities": related_entities_ids,
"topic": topic,
},
)
session.add(synopsis_entity)
graph_store = TiDBGraphStore(
knowledge_base=self.knowledge_base_id,
dspy_lm=None,
session=session,
embed_model=self._embed_model,
entity_db_model=self._entity_db_model,
relationship_db_model=self._relationship_db_model,
)
for related_entity in session.exec(
select(self._entity_db_model).where(
self._entity_db_model.id.in_(related_entities_ids)
)
).all():
graph_store.create_relationship(
synopsis_entity,
related_entity,
RelationshipAIModel(
source_entity=synopsis_entity.name,
target_entity=related_entity.name,
relationship_desc=f"{related_entity.name} is a part of synopsis entity (name={synopsis_entity.name}, topic={topic})",
),
{"relationship_type": EntityType.synopsis.value},
commit=False,
)
session.commit()
create_staff_action_log(
session,
"create_synopsis_entity",
"entity",
synopsis_entity.id,
{},
synopsis_entity.screenshot(),
commit=False,
)
return synopsis_entity
================================================
FILE: backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_store.py
================================================
import dspy
import logging
import numpy as np
import tidb_vector
from deepdiff import DeepDiff
from typing import List, Optional, Tuple, Dict, Set, Type, Any
from collections import defaultdict
from dspy import Predict
from llama_index.core.embeddings.utils import EmbedType, resolve_embed_model
from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingModelType
import sqlalchemy
from sqlmodel import Session, asc, func, select, text, SQLModel
from sqlalchemy.orm import aliased, defer, joinedload, noload
from tidb_vector.sqlalchemy import VectorAdaptor
from sqlalchemy import or_, desc
from app.core.db import engine
from app.rag.indices.knowledge_graph.graph_store.helpers import (
get_entity_description_embedding,
get_relationship_description_embedding,
calculate_relationship_score,
get_entity_metadata_embedding,
get_query_embedding,
DEFAULT_RANGE_SEARCH_CONFIG,
DEFAULT_WEIGHT_COEFFICIENT_CONFIG,
DEFAULT_DEGREE_COEFFICIENT,
)
from app.rag.indices.knowledge_graph.graph_store.schema import KnowledgeGraphStore
from app.rag.indices.knowledge_graph.schema import (
Entity,
Relationship,
SynopsisEntity,
)
from app.rag.retrievers.knowledge_graph.schema import (
RetrievedEntity,
RetrievedRelationship,
RetrievedKnowledgeGraph,
)
from app.models import (
KnowledgeBase,
EntityType,
Document,
)
logger = logging.getLogger(__name__)
def cosine_distance(v1, v2):
return 1 - np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
class MergeEntities(dspy.Signature):
"""As a knowledge expert assistant specialized in database technologies, evaluate the two provided entities. These entities have been pre-analyzed and have same name but different descriptions and metadata.
Please carefully review the detailed descriptions and metadata for both entities to determine if they genuinely represent the same concept or object(entity).
If you conclude that the entities are identical, merge the descriptions and metadata fields of the two entities into a single consolidated entity.
If the entities are distinct despite their same name that may be due to different contexts or perspectives, do not merge the entities and return none as the merged entity.
Considerations: Ensure your decision is based on a comprehensive analysis of the content and context provided within the entity descriptions and metadata.
Please only response in JSON Format.
"""
entities: List[Entity] = dspy.InputField(
desc="List of entities identified from previous analysis."
)
merged_entity: Optional[Entity] = dspy.OutputField(
desc="Merged entity with consolidated descriptions and metadata."
)
class MergeEntitiesProgram(dspy.Module):
def __init__(self):
self.prog = Predict(MergeEntities)
def forward(self, entities: List[Entity]):
if len(entities) != 2:
raise ValueError("The input should contain exactly two entities")
pred = self.prog(entities=entities)
return pred
class TiDBGraphStore(KnowledgeGraphStore):
def __init__(
self,
knowledge_base: KnowledgeBase,
dspy_lm: dspy.LM,
entity_db_model: Type[SQLModel],
relationship_db_model: Type[SQLModel],
chunk_db_model: Type[SQLModel],
session: Optional[Session] = None,
embed_model: Optional[EmbedType] = None,
description_similarity_threshold=0.9,
):
self.knowledge_base = knowledge_base
self._session = session
self._owns_session = session is None
if self._session is None:
self._session = Session(engine)
self._dspy_lm = dspy_lm
if embed_model:
self._embed_model = resolve_embed_model(embed_model)
else:
self._embed_model = OpenAIEmbedding(
model=OpenAIEmbeddingModelType.TEXT_EMBED_3_SMALL
)
self.merge_entities_prog = MergeEntitiesProgram()
self.description_cosine_distance_threshold = (
1 - description_similarity_threshold
)
self._entity_model = entity_db_model
self._relationship_model = relationship_db_model
self._chunk_model = chunk_db_model
def ensure_table_schema(self) -> None:
inspector = sqlalchemy.inspect(engine)
existed_table_names = inspector.get_table_names()
entities_table_name = self._entity_model.__tablename__
relationships_table_name = self._relationship_model.__tablename__
if entities_table_name not in existed_table_names:
self._entity_model.metadata.create_all(
engine, tables=[self._entity_model.__table__]
)
# Add HNSW index to accelerate ann queries.
VectorAdaptor(engine).create_vector_index(
self._entity_model.description_vec, tidb_vector.DistanceMetric.COSINE
)
VectorAdaptor(engine).create_vector_index(
self._entity_model.meta_vec, tidb_vector.DistanceMetric.COSINE
)
logger.info(
f"Entities table <{entities_table_name}> has been created successfully."
)
else:
logger.info(
f"Entities table <{entities_table_name}> is already exists, not action to do."
)
if relationships_table_name not in existed_table_names:
self._relationship_model.metadata.create_all(
engine, tables=[self._relationship_model.__table__]
)
# Add HNSW index to accelerate ann queries.
VectorAdaptor(engine).create_vector_index(
self._relationship_model.description_vec,
tidb_vector.DistanceMetric.COSINE,
)
logger.info(
f"Relationships table <{relationships_table_name}> has been created successfully."
)
else:
logger.info(
f"Relationships table <{relationships_table_name}> is already exists, not action to do."
)
def drop_table_schema(self) -> None:
inspector = sqlalchemy.inspect(engine)
existed_table_names = inspector.get_table_names()
relationships_table_name = self._relationship_model.__tablename__
entities_table_name = self._entity_model.__tablename__
if relationships_table_name in existed_table_names:
self._relationship_model.metadata.drop_all(
engine, tables=[self._relationship_model.__table__]
)
logger.info(
f"Relationships table <{relationships_table_name}> has been dropped successfully."
)
else:
logger.info(
f"Relationships table <{relationships_table_name}> is not existed, not action to do."
)
if entities_table_name in existed_table_names:
self._entity_model.metadata.drop_all(
engine, tables=[self._entity_model.__table__]
)
logger.info(
f"Entities table <{entities_table_name}> has been dropped successfully."
)
else:
logger.info(
f"Entities table <{entities_table_name}> is not existed, not action to do."
)
def close_session(self) -> None:
# Always call this method is necessary to make sure the session is closed
if self._owns_session:
self._session.close()
def save(self, chunk_id, entities_df, relationships_df):
if entities_df.empty or relationships_df.empty:
logger.info(
"Entities or relationships are empty, skip saving to the database"
)
return
if (
self._session.exec(
select(self._relationship_model).where(
self._relationship_model.meta["chunk_id"] == chunk_id
)
).first()
is not None
):
logger.info(f"{chunk_id} already exists in the relationship table, skip.")
return
entities_name_map = defaultdict(list)
for _, row in entities_df.iterrows():
entities_name_map[row["name"]].append(
self.get_or_create_entity(
Entity(
name=row["name"],
description=row["description"],
metadata=row["meta"],
),
commit=False,
)
)
def _find_or_create_entity_for_relation(
name: str, description: str
) -> SQLModel:
_embedding = get_entity_description_embedding(
name, description, self._embed_model
)
# Check entities_name_map first, if not found, then check the database
for e in entities_name_map.get(name, []):
if (
cosine_distance(e.description_vec, _embedding)
< self.description_cosine_distance_threshold
):
return e
return self.get_or_create_entity(
Entity(
name=name,
description=description,
metadata={"status": "need-revised"},
),
commit=False,
)
try:
for _, row in relationships_df.iterrows():
logger.info(
"save entities for relationship %s -> %s -> %s",
row["source_entity"],
row["relationship_desc"],
row["target_entity"],
)
source_entity = _find_or_create_entity_for_relation(
row["source_entity"], row["source_entity_description"]
)
target_entity = _find_or_create_entity_for_relation(
row["target_entity"], row["target_entity_description"]
)
self.create_relationship(
source_entity,
target_entity,
Relationship(
source_entity=source_entity.name,
target_entity=target_entity.name,
relationship_desc=row["relationship_desc"],
),
relationship_metadata=row["meta"],
commit=False,
)
self._session.commit()
except Exception as e:
logger.error(e, exc_info=True)
self._session.rollback()
raise e
def create_relationship(
self,
source_entity: SQLModel,
target_entity: SQLModel,
relationship: Relationship,
relationship_metadata: dict = {},
commit=True,
):
relationship_object = self._relationship_model(
source_entity=source_entity,
target_entity=target_entity,
description=relationship.relationship_desc,
description_vec=get_relationship_description_embedding(
source_entity.name,
source_entity.description,
target_entity.name,
target_entity.description,
relationship.relationship_desc,
self._embed_model,
),
meta=relationship_metadata,
document_id=relationship_metadata.get("document_id"),
chunk_id=relationship_metadata.get("chunk_id"),
)
self._session.add(relationship_object)
if commit:
self._session.commit()
self._session.refresh(relationship_object)
else:
self._session.flush()
def get_subgraph_by_relationship_ids(
self, ids: list[int], **kwargs
) -> RetrievedKnowledgeGraph:
stmt = (
select(self._relationship_model)
.where(self._relationship_model.id.in_(ids))
.options(
joinedload(self._relationship_model.source_entity),
joinedload(self._relationship_model.target_entity),
)
)
relationships_set = self._session.exec(stmt)
entities_set = set()
relationships = []
entities = []
for rel in relationships_set:
entities_set.add(rel.source_entity)
entities_set.add(rel.target_entity)
relationships.append(
RetrievedRelationship(
id=rel.id,
knowledge_base_id=self.knowledge_base.id,
source_entity_id=rel.source_entity_id,
target_entity_id=rel.target_entity_id,
description=rel.description,
rag_description=f"{rel.source_entity.name} -> {rel.description} -> {rel.target_entity.name}",
meta=rel.meta,
weight=rel.weight,
last_modified_at=rel.last_modified_at,
)
)
for entity in entities_set:
entities.append(
RetrievedEntity(
id=entity.id,
knowledge_base_id=self.knowledge_base.id,
name=entity.name,
description=entity.description,
meta=entity.meta,
entity_type=entity.entity_type,
)
)
return RetrievedKnowledgeGraph(
knowledge_base=self.knowledge_base.to_descriptor(),
entities=entities,
relationships=relationships,
**kwargs,
)
def get_or_create_entity(self, entity: Entity, commit: bool = True) -> SQLModel:
# using the cosine distance between the description vectors to determine if the entity already exists
entity_type = (
EntityType.synopsis
if isinstance(entity, SynopsisEntity)
else EntityType.original
)
entity_description_vec = get_entity_description_embedding(
entity.name,
entity.description,
self._embed_model,
)
hint = text(
f"/*+ read_from_storage(tikv[{self._entity_model.__tablename__}]) */"
)
result = (
self._session.query(
self._entity_model,
self._entity_model.description_vec.cosine_distance(
entity_description_vec
).label("distance"),
)
.filter(
self._entity_model.name == entity.name
and self._entity_model.entity_type == entity_type
)
.prefix_with(hint)
.order_by(asc("distance"))
.first()
)
if (
result is not None
and result[1] < self.description_cosine_distance_threshold
):
db_obj = result[0]
ob_obj_metadata = db_obj.meta
if (
db_obj.description == entity.description
and db_obj.name == entity.name
and len(DeepDiff(ob_obj_metadata, entity.metadata)) == 0
):
return db_obj
elif entity_type == EntityType.original:
# TODO: move to TiDBKnowledgeGraphIndex
# use LLM to merge the most similar entities
merged_entity = self._try_merge_entities(
[
Entity(
name=db_obj.name,
description=db_obj.description,
metadata=ob_obj_metadata,
),
Entity(
name=entity.name,
description=entity.description,
metadata=entity.metadata,
),
]
)
if merged_entity is not None:
db_obj.description = merged_entity.description
db_obj.meta = merged_entity.metadata
db_obj.description_vec = get_entity_description_embedding(
db_obj.name, db_obj.description, self._embed_model
)
db_obj.meta_vec = get_entity_metadata_embedding(
db_obj.meta, self._embed_model
)
self._session.add(db_obj)
if commit:
self._session.commit()
self._session.refresh(db_obj)
else:
self._session.flush()
return db_obj
synopsis_info_str = (
entity.group_info.model_dump()
if entity_type == EntityType.synopsis
else None
)
db_obj = self._entity_model(
name=entity.name,
description=entity.description,
description_vec=entity_description_vec,
meta=entity.metadata,
meta_vec=get_entity_metadata_embedding(entity.metadata, self._embed_model),
synopsis_info=synopsis_info_str,
entity_type=entity_type,
)
self._session.add(db_obj)
if commit:
self._session.commit()
self._session.refresh(db_obj)
else:
self._session.flush()
return db_obj
def _try_merge_entities(self, entities: List[Entity]) -> Entity:
logger.info(f"Trying to merge entities: {entities[0].name}")
try:
with dspy.settings.context(lm=self._dspy_lm):
pred = self.merge_entities_prog(entities=entities)
return pred.merged_entity
except Exception as e:
logger.error(f"Failed to merge entities: {e}", exc_info=True)
return None
def retrieve_with_weight(
self,
query: str,
embedding: list,
depth: int = 2,
include_meta: bool = False,
with_degree: bool = False,
# experimental feature to filter relationships based on meta, can be removed in the future
relationship_meta_filters: dict = {},
session: Optional[Session] = None,
) -> Tuple[List[RetrievedEntity], List[RetrievedRelationship]]:
if not embedding:
assert query, "Either query or embedding must be provided"
embedding = get_query_embedding(query, self._embed_model)
relationships, entities = self.search_relationships_weight(
embedding,
[],
[],
with_degree=with_degree,
relationship_meta_filters=relationship_meta_filters,
session=session,
)
all_relationships = set(relationships)
all_entities = set(entities)
visited_entities = set(e.id for e in entities)
visited_relationships = set(r.id for r in relationships)
for _ in range(depth - 1):
actual_number = 0
progress = 0
search_number_each_depth = 10
for search_config in DEFAULT_RANGE_SEARCH_CONFIG:
search_ratio = search_config[1]
search_distance_range = search_config[0]
remaining_number = search_number_each_depth - actual_number
# calculate the expected number based search progress
# It's a accumulative search, so the expected number should be the difference between the expected number and the actual number
expected_number = (
int(
(search_ratio + progress) * search_number_each_depth
- actual_number
)
if progress * search_number_each_depth > actual_number
else int(search_ratio * search_number_each_depth)
)
if expected_number > remaining_number:
expected_number = remaining_number
if remaining_number <= 0:
break
new_relationships, new_entities = self.search_relationships_weight(
embedding,
visited_relationships,
visited_entities,
search_distance_range,
rank_n=expected_number,
with_degree=with_degree,
relationship_meta_filters=relationship_meta_filters,
session=session,
)
all_relationships.update(new_relationships)
all_entities.update(new_entities)
visited_entities.update(e.id for e in new_entities)
visited_relationships.update(r.id for r in new_relationships)
actual_number += len(new_relationships)
# search_ratio == 1 won't count the progress
if search_ratio != 1:
progress += search_ratio
synopsis_entities = self.fetch_similar_entities(
embedding, top_k=2, entity_type=EntityType.synopsis, session=session
)
all_entities.update(synopsis_entities)
related_doc_ids = set()
for r in all_relationships:
if "doc_id" not in r.meta:
continue
related_doc_ids.add(r.meta["doc_id"])
entities = [
RetrievedEntity(
id=e.id,
knowledge_base_id=self.knowledge_base.id,
name=e.name,
description=e.description,
meta=e.meta if include_meta else None,
entity_type=e.entity_type,
)
for e in all_entities
]
relationships = [
RetrievedRelationship(
id=r.id,
knowledge_base_id=self.knowledge_base.id,
source_entity_id=r.source_entity_id,
target_entity_id=r.target_entity_id,
rag_description=f"{r.source_entity.name} -> {r.description} -> {r.target_entity.name}",
description=r.description,
meta=r.meta,
weight=r.weight,
last_modified_at=r.last_modified_at,
)
for r in all_relationships
]
return entities, relationships
# Function to fetch degrees for entities
def fetch_entity_degrees(
self,
entity_ids: List[int],
session: Optional[Session] = None,
) -> Dict[int, Dict[str, int]]:
degrees = {
entity_id: {"in_degree": 0, "out_degree": 0} for entity_id in entity_ids
}
session = session or self._session
try:
# Fetch out-degrees
out_degree_query = (
session.query(
self._relationship_model.source_entity_id,
func.count(self._relationship_model.id).label("out_degree"),
)
.filter(self._relationship_model.source_entity_id.in_(entity_ids))
.group_by(self._relationship_model.source_entity_id)
).all()
for row in out_degree_query:
degrees[row.source_entity_id]["out_degree"] = row.out_degree
# Fetch in-degrees
in_degree_query = (
session.query(
self._relationship_model.target_entity_id,
func.count(self._relationship_model.id).label("in_degree"),
)
.filter(self._relationship_model.target_entity_id.in_(entity_ids))
.group_by(self._relationship_model.target_entity_id)
).all()
for row in in_degree_query:
degrees[row.target_entity_id]["in_degree"] = row.in_degree
except Exception as e:
logger.error(e)
return degrees
def search_relationships_weight(
self,
embedding: List[float],
visited_relationships: Set[int],
visited_entities: Set[int],
distance_range: Tuple[float, float] = (0.0, 1.0),
limit: int = 100,
weight_coefficient_config: List[
Tuple[Tuple[int, int], float]
] = DEFAULT_WEIGHT_COEFFICIENT_CONFIG,
alpha: float = 1,
rank_n: int = 10,
degree_coefficient: float = DEFAULT_DEGREE_COEFFICIENT,
with_degree: bool = False,
relationship_meta_filters: Dict = {},
session: Optional[Session] = None,
) -> Tuple[List[SQLModel], List[SQLModel]]:
# select the relationships to rank
subquery = (
select(
self._relationship_model,
self._relationship_model.description_vec.cosine_distance(
embedding
).label("embedding_distance"),
)
.options(defer(self._relationship_model.description_vec))
.order_by(asc("embedding_distance"))
.limit(limit * 10)
).subquery()
relationships_alias = aliased(self._relationship_model, subquery)
query = (
select(relationships_alias, text("embedding_distance"))
.options(
defer(relationships_alias.description_vec),
joinedload(relationships_alias.source_entity)
.defer(self._entity_model.meta_vec)
.defer(self._entity_model.description_vec),
joinedload(relationships_alias.target_entity)
.defer(self._entity_model.meta_vec)
.defer(self._entity_model.description_vec),
)
.where(relationships_alias.weight >= 0)
)
if relationship_meta_filters:
for k, v in relationship_meta_filters.items():
query = query.where(relationships_alias.meta[k] == v)
if visited_relationships:
query = query.where(subquery.c.id.notin_(visited_relationships))
if distance_range != (0.0, 1.0):
# embedding_distance between the range
query = query.where(
text(
"embedding_distance >= :min_distance AND embedding_distance <= :max_distance"
)
).params(min_distance=distance_range[0], max_distance=distance_range[1])
if visited_entities:
query = query.where(subquery.c.source_entity_id.in_(visited_entities))
query = query.order_by(asc("embedding_distance")).limit(limit)
# Order by embedding distance and apply limit
session = session or self._session
relationships = session.exec(query).all()
if len(relationships) <= rank_n:
relationship_set = set([rel for rel, _ in relationships])
entity_set = set()
for r in relationship_set:
entity_set.add(r.source_entity)
entity_set.add(r.target_entity)
return relationship_set, entity_set
# Fetch degrees if with_degree is True
if with_degree:
entity_ids = set()
for rel, _ in relationships:
entity_ids.add(rel.source_entity_id)
entity_ids.add(rel.target_entity_id)
degrees = self.fetch_entity_degrees(list(entity_ids), session=session)
else:
degrees = {}
# calculate the relationship score based on distance and weight
ranked_relationships = []
for relationship, embedding_distance in relationships:
source_in_degree = (
degrees[relationship.source_entity_id]["in_degree"]
if with_degree
else 0
)
target_out_degree = (
degrees[relationship.target_entity_id]["out_degree"]
if with_degree
else 0
)
final_score = calculate_relationship_score(
embedding_distance,
relationship.weight,
source_in_degree,
target_out_degree,
alpha,
weight_coefficient_config,
degree_coefficient,
with_degree,
)
ranked_relationships.append((relationship, final_score))
# rank relationships based on the calculated score
ranked_relationships.sort(key=lambda x: x[1], reverse=True)
relationship_set = set([rel for rel, score in ranked_relationships[:rank_n]])
entity_set = set()
for r in relationship_set:
entity_set.add(r.source_entity)
entity_set.add(r.target_entity)
return list(relationship_set), list(entity_set)
def fetch_similar_entities_by_post_filter(
self,
embedding: list,
top_k: int = 5,
entity_type: EntityType = EntityType.original,
session: Optional[Session] = None,
post_filter_multiplier: int = 10,
):
new_entity_set = set()
session = session or self._session
# Create a subquery with a larger limit and include the distance
subquery = (
select(
self._entity_model,
self._entity_model.description_vec.cosine_distance(embedding).label(
"distance"
),
)
.order_by(asc("distance"))
.limit(
post_filter_multiplier * top_k
if entity_type != EntityType.original
else top_k
)
.subquery()
)
# Apply filter only for non-original entity types
query = (
select(self._entity_model)
.where(subquery.c.entity_type == entity_type)
.order_by(asc(subquery.c.distance))
.limit(top_k)
)
for row in session.exec(query).all():
new_entity_set.add(row)
return new_entity_set
def fetch_similar_entities(
self,
embedding: list,
top_k: int = 10,
entity_type: EntityType = EntityType.original,
session: Optional[Session] = None,
):
new_entity_set = set()
# Retrieve entities based on their ID and similarity to the embedding
session = session or self._session
query = select(self._entity_model)
if entity_type == EntityType.synopsis:
query = query.where(self._entity_model.entity_type == entity_type)
hint = text("/*+ read_from_storage(tikv[entities]) */")
query = query.prefix_with(hint)
query = query.order_by(
self._entity_model.description_vec.cosine_distance(embedding)
).limit(top_k)
# Debug: Print the SQL query
"""
from sqlalchemy.dialects import mysql
compiled_query = query.compile(
dialect=mysql.dialect(), compile_kwargs={"literal_binds": True}
)
print(f"Debug - SQL Query: {compiled_query}")
"""
for entity in session.exec(query).all():
new_entity_set.add(entity)
return new_entity_set
def retrieve_graph_data(
self,
query_text: str,
top_k: int = 5,
similarity_threshold: float = 0.7,
) -> Dict[str, List[Dict[str, Any]]]:
"""Retrieve related entities and relationships using semantic search.
Args:
query_text: The search query text
top_k: Maximum number of results to return for each type
similarity_threshold: Minimum similarity score threshold
Returns:
Dictionary containing:
- entities: List of similar entities with similarity scores
- relationships: List of similar relationships with similarity scores
"""
query_embedding = get_query_embedding(query_text, self._embed_model)
# Query similar entities
entity_query = (
select(
self._entity_model,
(
1
- self._entity_model.description_vec.cosine_distance(
query_embedding
)
).label("similarity"),
)
.options(
defer(self._entity_model.description_vec),
defer(self._entity_model.meta_vec),
)
.order_by(desc("similarity"))
.limit(top_k * 2) # Fetch more results to account for filtering
)
# Query similar relationships
relationship_query = (
select(
self._relationship_model,
(
1
- self._relationship_model.description_vec.cosine_distance(
query_embedding
)
).label("similarity"),
)
.options(
defer(self._relationship_model.description_vec),
joinedload(self._relationship_model.source_entity)
.defer(self._entity_model.meta_vec)
.defer(self._entity_model.description_vec),
joinedload(self._relationship_model.target_entity)
.defer(self._entity_model.meta_vec)
.defer(self._entity_model.description_vec),
)
.order_by(desc("similarity"))
.limit(top_k * 2) # Fetch more results to account for filtering
)
# Execute both queries
entities = []
relationships = []
for entity, similarity in self._session.exec(entity_query).all():
if similarity >= similarity_threshold and len(entities) < top_k:
entities.append(
{
"id": entity.id,
"name": entity.name,
"description": entity.description,
"metadata": entity.meta,
"similarity_score": similarity,
}
)
for relationship, similarity in self._session.exec(relationship_query).all():
if similarity >= similarity_threshold and len(relationships) < top_k:
relationships.append(
{
"id": relationship.id,
"relationship": relationship.description,
"source_entity": {
"id": relationship.source_entity.id,
"name": relationship.source_entity.name,
"description": relationship.source_entity.description,
},
"target_entity": {
"id": relationship.target_entity.id,
"name": relationship.target_entity.name,
"description": relationship.target_entity.description,
},
"similarity_score": similarity,
}
)
return {"entities": entities, "relationships": relationships}
def retrieve_neighbors(
self,
entities_ids: List[int],
query: str,
max_depth: int = 1,
max_neighbors: int = 20,
similarity_threshold: float = 0.7,
) -> Dict[str, List[Dict]]:
"""Retrieve most relevant neighbor paths for a group of similar nodes.
Args:
node_ids: List of source node IDs (representing similar entities)
query: Search query for relevant relationships
max_depth: Maximum depth for relationship traversal
max_neighbors: Maximum number of total neighbor paths to return
similarity_threshold: Minimum similarity score threshold
Returns:
Dictionary containing most relevant paths from source nodes to neighbors
"""
query_embedding = get_query_embedding(query, self._embed_model)
# Get all source entities
source_entities = self._session.exec(
select(self._entity_model)
.options(
defer(self._entity_model.description_vec),
defer(self._entity_model.meta_vec),
)
.where(self._entity_model.id.in_(entities_ids))
).all()
# Track visited nodes and discovered paths
all_visited = set(entities_ids)
current_level_nodes = set(entities_ids)
neighbors = [] # Store all discovered paths with their relevance scores
for depth in range(max_depth):
if not current_level_nodes:
break
# Query relationships for current level
relationships = self._session.exec(
select(
self._relationship_model,
(
1
- self._relationship_model.description_vec.cosine_distance(
query_embedding
)
).label("similarity"),
)
.options(
defer(self._relationship_model.description_vec),
joinedload(self._relationship_model.source_entity)
.defer(self._entity_model.meta_vec)
.defer(self._entity_model.description_vec),
joinedload(self._relationship_model.target_entity)
.defer(self._entity_model.meta_vec)
.defer(self._entity_model.description_vec),
)
.where(
or_(
self._relationship_model.source_entity_id.in_(
current_level_nodes
),
self._relationship_model.target_entity_id.in_(
current_level_nodes
),
)
)
.order_by(desc("similarity"))
.limit(max_neighbors * 2) # Fetch more results to account for filtering
).all()
next_level_nodes = set()
for rel, similarity in relationships:
# Skip if similarity is below threshold
if similarity < similarity_threshold:
continue
# Determine direction and connected entity
if rel.source_entity_id in current_level_nodes:
connected_id = rel.target_entity_id
else:
connected_id = rel.source_entity_id
# Skip if already visited
if connected_id in all_visited:
continue
neighbors.append(
{
"id": rel.id,
"relationship": rel.description,
"source_entity": {
"id": rel.source_entity.id,
"name": rel.source_entity.name,
"description": rel.source_entity.description,
},
"target_entity": {
"id": rel.target_entity.id,
"name": rel.target_entity.name,
"description": rel.target_entity.description,
},
"similarity_score": similarity,
}
)
next_level_nodes.add(connected_id)
all_visited.add(connected_id)
current_level_nodes = next_level_nodes
# Sort all paths by similarity score and return top max_neighbors
neighbors.sort(key=lambda x: x["similarity_score"], reverse=True)
return {"relationships": neighbors[:max_neighbors]}
def get_chunks_by_relationships(
self,
relationships_ids: List[int],
session: Optional[Session] = None,
) -> List[Dict[str, Any]]:
"""Get chunks for a list of relationships.
Args:
relationships: List of relationship objects
session: Optional database session
Returns:
List of dictionaries containing chunk information:
- text: chunk text content
- document_id: associated document id
- meta: chunk metadata
"""
session = session or self._session
relationships = session.exec(
select(self._relationship_model).where(
self._relationship_model.id.in_(relationships_ids)
)
).all()
# Extract chunk IDs from relationships
chunk_ids = {
rel.meta.get("chunk_id")
for rel in relationships
if rel.meta.get("chunk_id") is not None
}
if chunk_ids:
logger.info(f"Getting chunks for relationships: {chunk_ids}")
# Query chunks
chunks = session.exec(
select(self._chunk_model).where(self._chunk_model.id.in_(chunk_ids))
).all()
return [
{
"id": chunk.id,
"text": chunk.text,
"document_id": chunk.document_id,
"meta": {
"language": chunk.meta.get("language"),
"product": chunk.meta.get("product"),
"resource": chunk.meta.get("resource"),
"source_uri": chunk.meta.get("source_uri"),
"tidb_version": chunk.meta.get("tidb_version"),
},
}
for chunk in chunks
]
document_ids = {
rel.document_id
for rel in relationships
if rel.document_id is not None
}
logger.info(f"Getting documents for relationships: {document_ids}")
documents = session.exec(
select(Document).where(Document.id.in_(document_ids))
).all()
return [
{
"id": doc.id,
"text": doc.content,
"document_id": doc.id,
"meta": doc.meta,
}
for doc in documents
]
def get_entire_knowledge_graph(self) -> RetrievedKnowledgeGraph:
"""Retrieve all entities and relationships from the knowledge graph store.
Returns:
RetrievedKnowledgeGraph containing all entities and relationships
"""
# Query all entities
entity_query = select(self._entity_model).order_by(self._entity_model.id)
db_entities = self._session.exec(entity_query).all()
# Query all relationships with their related entities
relationship_query = (
select(self._relationship_model)
.options(
joinedload(self._relationship_model.source_entity),
joinedload(self._relationship_model.target_entity),
)
.order_by(self._relationship_model.id)
)
db_relationships = self._session.exec(relationship_query).all()
# Convert entities to RetrievedEntity objects
entities = []
for entity in db_entities:
entities.append(
RetrievedEntity(
id=entity.id,
knowledge_base_id=self.knowledge_base.id,
name=entity.name,
description=entity.description,
meta=entity.meta,
entity_type=entity.entity_type,
)
)
# Convert relationships to RetrievedRelationship objects
relationships = []
for rel in db_relationships:
relationships.append(
RetrievedRelationship(
id=rel.id,
knowledge_base_id=self.knowledge_base.id,
source_entity_id=rel.source_entity_id,
target_entity_id=rel.target_entity_id,
description=rel.description,
rag_description=f"{rel.source_entity.name} -> {rel.description} -> {rel.target_entity.name}",
meta=rel.meta,
weight=rel.weight,
last_modified_at=rel.last_modified_at,
)
)
return RetrievedKnowledgeGraph(
knowledge_base=self.knowledge_base.to_descriptor(),
entities=entities,
relationships=relationships,
)
def stream_entire_knowledge_graph(self, chunk_size: int = 5000):
"""Stream entire knowledge graph in chunks
Args:
chunk_size: Number of entities/relationships per chunk
Yields:
Dict containing chunk type and data
"""
# Stream entities
entity_query = (
select(self._entity_model)
.options(
defer(self._entity_model.description_vec),
defer(self._entity_model.meta_vec),
)
.order_by(self._entity_model.id)
)
last_entity_id = 0
while True:
chunk_query = entity_query.where(
self._entity_model.id > last_entity_id
).limit(chunk_size)
db_entities = self._session.exec(chunk_query).all()
if not db_entities:
break
entities = []
for entity in db_entities:
entities.append(
RetrievedEntity(
id=entity.id,
knowledge_base_id=self.knowledge_base.id,
name=entity.name,
description=entity.description,
meta=entity.meta,
entity_type=entity.entity_type,
)
)
last_entity_id = db_entities[-1].id
yield {"type": "entities", "data": entities}
# Stream relationships
relationship_query = (
select(self._relationship_model)
.options(
defer(self._relationship_model.description_vec),
defer(self._relationship_model.chunk_id),
noload(self._relationship_model.source_entity),
noload(self._relationship_model.target_entity),
)
.order_by(self._relationship_model.id)
)
logger.info(f"Relationship query: {relationship_query}")
last_relationship_id = 0
while True:
chunk_query = relationship_query.where(
self._relationship_model.id > last_relationship_id
).limit(chunk_size)
logger.info(f"Executing relationship chunk query: {chunk_query}")
db_relationships = self._session.exec(chunk_query).all()
if not db_relationships:
break
relationships = []
for rel in db_relationships:
relationships.append(
RetrievedRelationship(
id=rel.id,
knowledge_base_id=self.knowledge_base.id,
source_entity_id=rel.source_entity_id,
target_entity_id=rel.target_entity_id,
description=rel.description,
rag_description=None, # Skip rag_description for streaming performance
meta=rel.meta,
weight=rel.weight,
last_modified_at=rel.last_modified_at,
)
)
last_relationship_id = db_relationships[-1].id
yield {"type": "relationships", "data": relationships}
================================================
FILE: backend/app/rag/indices/knowledge_graph/schema.py
================================================
from pydantic import BaseModel, Field
from typing import Mapping, Any, List
class Entity(BaseModel):
"""List of entities extracted from the text to form the knowledge graph"""
name: str = Field(
description="Name of the entity, it should be a clear and concise term"
)
description: str = Field(
description=(
"Description of the entity, it should be a complete and comprehensive sentence, not few words. "
"Sample description of entity 'TiDB in-place upgrade': "
"'Upgrade TiDB component binary files to achieve upgrade, generally use rolling upgrade method'"
)
)
metadata: Mapping[str, Any] = Field(
description=(
"The covariates (which is a comprehensive json TREE, the first field is always: 'topic') to claim the entity. "
)
)
class EntityWithID(Entity):
"""Entity extracted from the text to form the knowledge graph with an ID."""
id: int = Field(description="Unique identifier for the entity.")
class SynopsisInfo(BaseModel):
"""A synopsis corresponds to a group of entities that share the same topic and can contribute to synopsis topic."""
topic: str = Field(
description="The shared topic of the synopsis, and each entity in the group can contribute factual data from its own perspective."
)
entities: List[int] = Field(
description="A group of entity(only IDs) that can contribute to the synopsis base on the analysis of entity descriptions and metadata."
)
class SynopsisEntity(Entity):
"""Unified synopsis entity with comprehensive description and metadata based on the entities group."""
group_info: SynopsisInfo = Field(
description="Group of entities to be unified into a single synopsis entity."
)
class ExistingSynopsisEntity(SynopsisEntity):
"""Unified synopsis entity with comprehensive description and metadata based on the entities group."""
id: int = Field(description="Unique identifier for the entity.")
class Relationship(BaseModel):
"""List of relationships extracted from the text to form the knowledge graph"""
source_entity: str = Field(
description="Source entity name of the relationship, it should an existing entity in the Entity list"
)
target_entity: str = Field(
description="Target entity name of the relationship, it should an existing entity in the Entity list"
)
relationship_desc: str = Field(
description=(
"Description of the relationship, it should be a complete and comprehensive sentence, not few words. "
"Sample relationship description: 'TiDB will release a new LTS version every 6 months.'"
)
)
class RelationshipReasoning(Relationship):
"""Relationship between two entities extracted from the query"""
reasoning: str = Field(
description=(
"Category reasoning for the relationship, e.g., 'the main conerns of the user', 'the problem the user is facing', 'the user case scenario', etc."
)
)
class KnowledgeGraph(BaseModel):
"""Graph representation of the knowledge for text."""
relationships: List[Relationship] = Field(
description="List of relationships in the knowledge graph"
)
entities: List[Entity] = Field(
description="List of entities in the knowledge graph"
)
class EntityCovariateInput(BaseModel):
"""List of entities extracted from the text to form the knowledge graph"""
name: str = Field(description="Name of the entity")
description: str = Field(description=("Description of the entity"))
class EntityCovariateOutput(BaseModel):
"""List of entities extracted from the text to form the knowledge graph"""
name: str = Field(description="Name of the entity")
description: str = Field(description=("Description of the entity"))
covariates: Mapping[str, Any] = Field(
description=(
"The attributes (which is a comprehensive json TREE, the first field is always: 'topic') to claim the entity. "
)
)
class DecomposedFactors(BaseModel):
"""Decomposed factors extracted from the query to form the knowledge graph"""
relationships: List[RelationshipReasoning] = Field(
description="List of relationships to represent critical concepts and their relationships extracted from the query."
)
================================================
FILE: backend/app/rag/indices/vector_search/__init__.py
================================================
================================================
FILE: backend/app/rag/indices/vector_search/vector_store/__init__.py
================================================
__all__ = ["TiDBVectorStore"]
================================================
FILE: backend/app/rag/indices/vector_search/vector_store/tidb_vector_store.py
================================================
import logging
import tidb_vector
import sqlalchemy
from typing import Any, List, Optional, Type
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
)
from llama_index.core.vector_stores.utils import (
metadata_dict_to_node,
node_to_metadata_dict,
)
from sqlmodel import (
SQLModel,
Session,
delete,
select,
asc,
alias,
)
from tidb_vector.sqlalchemy import VectorAdaptor
from app.core.db import engine
logger = logging.getLogger(__name__)
def node_to_relation_dict(node: BaseNode) -> dict:
relationships = {}
for r_type, node_info in node.relationships.items():
relationships[r_type.name] = {
"node_id": node_info.node_id,
"node_type": node_info.node_type.name,
"meta": node_info.metadata,
"hash": node_info.hash,
}
return relationships
class TiDBVectorStore(BasePydanticVectorStore):
_session: Session = PrivateAttr()
_owns_session: bool = PrivateAttr()
_table_name: str = PrivateAttr()
_vector_dimension: int = PrivateAttr()
stores_text: bool = True
flat_metadata: bool = False
def __init__(
self,
chunk_db_model: Type[SQLModel],
session: Optional[Session] = None,
oversampling_factor: int = 1,
**kwargs: Any,
) -> None:
"""
Args:
oversampling_factor (int): The oversampling factor for the similarity search. The higher the factor, the higher recall rate.
"""
super().__init__(**kwargs)
self._session = session
self._owns_session = session is None
if self._session is None:
self._session = Session(engine)
self._chunk_db_model = chunk_db_model
self._oversampling_factor = oversampling_factor
def ensure_table_schema(self) -> None:
inspector = sqlalchemy.inspect(engine)
table_name = self._chunk_db_model.__tablename__
if table_name not in inspector.get_table_names():
self._chunk_db_model.metadata.create_all(
engine, tables=[self._chunk_db_model.__table__]
)
# Add HNSW index to accelerate ann queries.
VectorAdaptor(engine).create_vector_index(
self._chunk_db_model.embedding, tidb_vector.DistanceMetric.COSINE
)
logger.info(f"Chunk table <{table_name}> has been created successfully.")
else:
logger.info(
f"Chunk table <{table_name}> is already exists, no action to do."
)
def drop_table_schema(self):
inspector = sqlalchemy.inspect(engine)
table_name = self._chunk_db_model.__tablename__
if table_name in inspector.get_table_names():
self._chunk_db_model.metadata.drop_all(
self._session.connection(), tables=[self._chunk_db_model.__table__]
)
logger.info(f"Chunk table <{table_name}> has been dropped successfully.")
else:
logger.info(f"Chunk table <{table_name}> is not existed, not action to do.")
def close_session(self) -> None:
# Always call this method is necessary to make sure the session is closed
if self._owns_session:
self._session.close()
@property
def client(self) -> Any:
"""Get client."""
return engine
@classmethod
def class_name(cls) -> str:
return "TiDBVectorStore"
def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
"""
Add nodes to the vector store.
Args:
nodes (List[BaseNode]): List of nodes to be added.
**add_kwargs: Additional keyword arguments to be passed to the underlying storage.
Returns:
List[str]: List of node IDs that were added.
"""
items = []
for n in nodes:
items.append(
{
"id": n.node_id,
"hash": n.hash,
"text": n.get_content(metadata_mode=MetadataMode.NONE),
"meta": node_to_metadata_dict(n, remove_text=True),
"embedding": n.get_embedding(),
"document_id": n.ref_doc_id,
"relations": node_to_relation_dict(n),
"source_uri": add_kwargs.get("source_uri"),
}
)
self._session.bulk_insert_mappings(self._chunk_db_model, items)
self._session.commit()
return [i["id"] for i in items]
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
"""
Delete all nodes of a document from the vector store.
Args:
ref_doc_id (str): The reference document ID to be deleted.
**delete_kwargs: Additional keyword arguments to be passed to the delete method.
Returns:
None
"""
assert ref_doc_id.isdigit(), "ref_doc_id must be an integer."
delete_stmt = delete(self._chunk_db_model).where(
self._chunk_db_model.document_id == int(ref_doc_id)
)
self._session.exec(delete_stmt)
self._session.commit()
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""
Perform a similarity search with the given query embedding.
Args:
query (VectorStoreQuery): The query object containing the query data.
**kwargs: Additional keyword arguments.
Returns:
VectorStoreQueryResult: The result of the similarity search.
Raises:
ValueError: If the query embedding is not provided.
"""
# TODO:
# - Support advanced query filters
# - Support both pre-filter and post-filter
if query.query_embedding is None:
raise ValueError("Query embedding must be provided.")
subquery = select(
self._chunk_db_model.id,
self._chunk_db_model.text,
self._chunk_db_model.meta,
self._chunk_db_model.document_id,
self._chunk_db_model.embedding.cosine_distance(query.query_embedding).label(
"distance"
),
)
if query.filters:
for f in query.filters.filters:
subquery = subquery.stmt(self._chunk_db_model.meta[f.key] == f.value)
sub = alias(
subquery.order_by(asc("distance"))
.limit(query.similarity_top_k * self._oversampling_factor)
.subquery(),
"sub",
)
stmt = (
select(
sub.c.id,
sub.c.text,
sub.c.meta,
sub.c.document_id,
sub.c.distance,
)
.order_by(asc("distance"))
.limit(query.similarity_top_k)
)
results = self._session.exec(stmt)
nodes = []
similarities = []
ids = []
for row in results:
# Check if metadata contains required fields for node reconstruction
# to avoid async event loop issues in metadata_dict_to_node
if (
isinstance(row.meta, dict)
and "_node_content" in row.meta
and "_node_type" in row.meta
):
try:
node = metadata_dict_to_node(row.meta)
node.id_ = str(row.id)
node.metadata["document_id"] = row.document_id
node.set_content(row.text)
except Exception as e:
# NOTE: deprecated legacy logic for backward compatibility
logger.warning(
f"Failed to parse metadata dict (error: {e}), falling back to legacy logic.",
exc_info=True,
)
node = TextNode(
id_=str(row.id),
text=row.text,
metadata=row.meta,
)
else:
# Use legacy logic directly if metadata doesn't contain required fields
node = TextNode(
id_=str(row.id),
text=row.text,
metadata=row.meta,
)
similarities.append((1 - row.distance) if row.distance is not None else 0)
ids.append(str(row.id))
nodes.append(node)
return VectorStoreQueryResult(
nodes=nodes,
similarities=similarities,
ids=ids,
)
================================================
FILE: backend/app/rag/knowledge_base/__init__.py
================================================
================================================
FILE: backend/app/rag/knowledge_base/config.py
================================================
import logging
from llama_index.core.base.embeddings.base import BaseEmbedding
from sqlmodel import Session
from app.models.knowledge_base import KnowledgeBase
from app.rag.llms.resolver import get_default_llm, resolve_llm
from app.rag.embeddings.resolver import resolve_embed_model, get_default_embed_model
from app.rag.llms.dspy import get_dspy_lm_by_llama_llm
logger = logging.getLogger(__name__)
def get_kb_llm(session: Session, kb: KnowledgeBase):
db_llm = kb.llm
if db_llm:
return resolve_llm(
db_llm.provider, db_llm.model, db_llm.config, db_llm.credentials
)
else:
return get_default_llm(session)
def get_kb_dspy_llm(session: Session, kb: KnowledgeBase):
llm = get_kb_llm(session, kb)
return get_dspy_lm_by_llama_llm(llm)
def get_kb_embed_model(session: Session, kb: KnowledgeBase) -> BaseEmbedding:
db_embed_model = kb.embedding_model
if db_embed_model:
return resolve_embed_model(
db_embed_model.provider,
db_embed_model.model,
db_embed_model.config,
db_embed_model.credentials,
)
else:
return get_default_embed_model(session)
================================================
FILE: backend/app/rag/knowledge_base/index_store.py
================================================
from sqlalchemy import inspection
from sqlmodel import Session
from app.models import KnowledgeBase
from app.models.chunk import get_kb_chunk_model
from app.models.entity import get_kb_entity_model
from app.rag.knowledge_base.config import get_kb_dspy_llm, get_kb_embed_model
from app.models.relationship import get_kb_relationship_model
from app.rag.indices.knowledge_graph.graph_store import TiDBGraphStore, TiDBGraphEditor
from app.rag.indices.vector_search.vector_store.tidb_vector_store import TiDBVectorStore
def get_kb_tidb_vector_store(session: Session, kb: KnowledgeBase) -> TiDBVectorStore:
chunk_model = get_kb_chunk_model(kb)
vector_store = TiDBVectorStore(chunk_model, session=session)
return vector_store
def init_kb_tidb_vector_store(session: Session, kb: KnowledgeBase) -> TiDBVectorStore:
vector_store = get_kb_tidb_vector_store(session, kb)
vector_store.ensure_table_schema()
return vector_store
def get_kb_tidb_graph_store(session: Session, kb: KnowledgeBase) -> TiDBGraphStore:
dspy_lm = get_kb_dspy_llm(session, kb)
embed_model = get_kb_embed_model(session, kb)
entity_model = get_kb_entity_model(kb)
relationship_model = get_kb_relationship_model(kb)
inspection.inspect(relationship_model)
chunk_model = get_kb_chunk_model(kb)
graph_store = TiDBGraphStore(
knowledge_base=kb,
dspy_lm=dspy_lm,
session=session,
embed_model=embed_model,
entity_db_model=entity_model,
relationship_db_model=relationship_model,
chunk_db_model=chunk_model,
)
return graph_store
def init_kb_tidb_graph_store(session: Session, kb: KnowledgeBase) -> TiDBGraphStore:
graph_store = get_kb_tidb_graph_store(session, kb)
graph_store.ensure_table_schema()
return graph_store
def get_kb_tidb_graph_editor(session: Session, kb: KnowledgeBase) -> TiDBGraphEditor:
entity_db_model = get_kb_entity_model(kb)
relationship_db_model = get_kb_relationship_model(kb)
embed_model = get_kb_embed_model(session, kb)
return TiDBGraphEditor(
knowledge_base_id=kb.id,
entity_db_model=entity_db_model,
relationship_db_model=relationship_db_model,
embed_model=embed_model,
)
================================================
FILE: backend/app/rag/knowledge_base/schema.py
================================================
from enum import Enum
class KBIndexType(str, Enum):
VECTOR_SEARCH = "VECTOR_SEARCH"
KNOWLEDGE_GRAPH = "KNOWLEDGE_GRAPH"
================================================
FILE: backend/app/rag/llms/dspy.py
================================================
import dspy
from llama_index.core.base.llms.base import BaseLLM
def get_dspy_lm_by_llama_llm(llama_llm: BaseLLM) -> dspy.LM:
"""
Get the dspy LM by the llama LLM.
In this project, we use both llama-index and dspy, both of them have their own LLM implementation.
This function can help us reduce the complexity of the code by converting the llama LLM to the dspy LLM.
"""
match llama_llm.class_name():
case "openai_llm":
return dspy.LM(
model=f"openai/{llama_llm.model}",
max_tokens=llama_llm.max_tokens,
api_key=llama_llm.api_key,
api_base=enforce_trailing_slash(llama_llm.api_base),
num_retries=3,
)
case "OpenAILike":
return dspy.LM(
model=f"openai/{llama_llm.model}",
max_tokens=llama_llm.max_tokens,
api_key=llama_llm.api_key,
api_base=enforce_trailing_slash(llama_llm.api_base),
model_type="chat" if llama_llm.is_chat_model else "text",
num_retries=3,
)
case "GenAI":
if "models/" in llama_llm.model:
# For Gemini
model_name = llama_llm.model.split("models/")[1]
return dspy.LM(
model=f"gemini/{model_name}",
max_tokens=llama_llm._max_tokens,
api_key=llama_llm._client._api_client.api_key,
)
else:
# For Vertex AI
return dspy.LM(
model=f"vertex_ai/{llama_llm.model}",
max_tokens=llama_llm._max_tokens,
context_window=llama_llm.context_window,
temperature=llama_llm.temperature,
vertex_location=llama_llm._location,
vertex_credentials=llama_llm._credentials,
)
case "Bedrock_Converse_LLM":
return dspy.LM(
model=f"bedrock/converse/{llama_llm.model}",
# Notice: Bedrock's default max_tokens is 512, which is too small for the application.
max_tokens=llama_llm.max_tokens or 8192,
aws_access_key_id=llama_llm.aws_access_key_id,
aws_secret_access_key=llama_llm.aws_secret_access_key,
aws_region_name=llama_llm.region_name,
)
case "Ollama_llm":
return dspy.LM(
model=f"ollama_chat/{llama_llm.model}",
api_base=llama_llm.base_url,
timeout=llama_llm.request_timeout,
temperature=llama_llm.temperature,
num_retries=3,
)
case "azure_openai_llm":
return dspy.LM(
model=f"azure/{llama_llm.model}",
max_tokens=llama_llm.max_tokens,
temperature=llama_llm.temperature,
api_key=llama_llm.api_key,
api_base=enforce_trailing_slash(llama_llm.azure_endpoint),
api_version=llama_llm.api_version,
deployment_id=llama_llm.engine,
)
case _:
raise ValueError(f"Got unknown LLM provider: {llama_llm.class_name()}")
def enforce_trailing_slash(url: str):
if url.endswith("/"):
return url
return url + "/"
================================================
FILE: backend/app/rag/llms/provider.py
================================================
import enum
from typing import List
from pydantic import BaseModel
class LLMProvider(str, enum.Enum):
OPENAI = "openai"
GEMINI = "gemini"
VERTEX_AI = "vertex_ai"
ANTHROPIC_VERTEX = "anthropic_vertex" # Deprecated, use VERTEX_AI instead
OPENAI_LIKE = "openai_like"
BEDROCK = "bedrock"
OLLAMA = "ollama"
GITEEAI = "giteeai"
AZURE_OPENAI = "azure_openai"
class LLMProviderOption(BaseModel):
provider: LLMProvider
provider_display_name: str | None = None
provider_description: str | None = None
provider_url: str | None = None
default_llm_model: str
llm_model_description: str
default_config: dict = {}
config_description: str = ""
default_credentials: str | dict = ""
credentials_display_name: str
credentials_description: str
credentials_type: str = "str"
llm_provider_options: List[LLMProviderOption] = [
LLMProviderOption(
provider=LLMProvider.OPENAI,
provider_display_name="OpenAI",
provider_description="The OpenAI API provides a simple interface for developers to create an intelligence layer in their applications, powered by OpenAI's state of the art models.",
provider_url="https://platform.openai.com",
default_llm_model="gpt-4o",
llm_model_description="",
credentials_display_name="OpenAI API Key",
credentials_description="The API key of OpenAI, you can find it in https://platform.openai.com/api-keys",
credentials_type="str",
default_credentials="sk-****",
),
LLMProviderOption(
provider=LLMProvider.OPENAI_LIKE,
provider_display_name="OpenAI Like",
default_llm_model="",
llm_model_description="",
default_config={
"api_base": "https://openrouter.ai/api/v1/",
"is_chat_model": True,
},
config_description=(
"`api_base` is the API base URL of the third-party OpenAI-like service, such as OpenRouter; "
"`is_chat_model` indicates whether the model is chat model; "
"`context_window` is the maximum number of input tokens and output tokens; "
),
credentials_display_name="API Key",
credentials_description="The API key of the third-party OpenAI-like service, such as OpenRouter, you can find it in their official website",
credentials_type="str",
default_credentials="sk-****",
),
LLMProviderOption(
provider=LLMProvider.GEMINI,
provider_display_name="Gemini",
provider_description="The Gemini API and Google AI Studio help you start working with Google's latest models. Access the whole Gemini model family and turn your ideas into real applications that scale.",
provider_url="https://ai.google.dev/gemini-api",
default_llm_model="models/gemini-2.0-flash",
llm_model_description="Find the model code at https://ai.google.dev/gemini-api/docs/models/gemini",
credentials_display_name="Google API Key",
credentials_description="The API key of Google AI Studio, you can find it in https://aistudio.google.com/app/apikey",
credentials_type="str",
default_credentials="AIza****",
),
LLMProviderOption(
provider=LLMProvider.VERTEX_AI,
provider_display_name="Vertex AI",
provider_description="Vertex AI is a fully-managed, unified AI development platform for building and using generative AI.",
provider_url="https://cloud.google.com/vertex-ai",
default_llm_model="gemini-2.5-flash",
llm_model_description="Find more in https://cloud.google.com/model-garden",
credentials_display_name="Google Credentials JSON",
credentials_description="The JSON Object of Google Credentials, refer to https://cloud.google.com/docs/authentication/provide-credentials-adc#on-prem",
credentials_type="dict",
default_credentials={
"type": "service_account",
"project_id": "****",
"private_key_id": "****",
},
),
LLMProviderOption(
provider=LLMProvider.OLLAMA,
provider_display_name="Ollama",
provider_description="Ollama is a lightweight framework for building and running large language models.",
provider_url="https://ollama.com",
default_llm_model="llama3.2",
llm_model_description="Find more in https://ollama.com/library",
default_config={
"base_url": "http://localhost:11434",
"context_window": 8192,
"request_timeout": 60 * 10,
},
config_description=(
"`base_url` is the base URL of the Ollama server, ensure it can be accessed from this server; "
"`context_window` is the maximum number of input tokens and output tokens; "
"`request_timeout` is the maximum time to wait for a generate response."
),
credentials_display_name="Ollama API Key",
credentials_description="Ollama doesn't require an API key, set a dummy string here is ok",
credentials_type="str",
default_credentials="dummy",
),
LLMProviderOption(
provider=LLMProvider.GITEEAI,
provider_display_name="Gitee AI",
provider_description="Gitee AI is a third-party model provider that offers ready-to-use cutting-edge model APIs for AI developers.",
provider_url="https://ai.gitee.com",
default_llm_model="Qwen2.5-72B-Instruct",
default_config={
"is_chat_model": True,
"context_window": 131072,
},
config_description=(
"`is_chat_model` indicates whether the model is chat model; "
"`context_window` is the maximum number of input tokens and output tokens; "
),
llm_model_description="Find more in https://ai.gitee.com/serverless-api",
credentials_display_name="Gitee AI API Key",
credentials_description="The API key of Gitee AI, you can find it in https://ai.gitee.com/dashboard/settings/tokens",
credentials_type="str",
default_credentials="****",
),
LLMProviderOption(
provider=LLMProvider.BEDROCK,
provider_display_name="Bedrock",
provider_description="Amazon Bedrock is a fully managed foundation models service.",
provider_url="https://docs.aws.amazon.com/bedrock/",
default_llm_model="anthropic.claude-3-7-sonnet-20250219-v1:0",
llm_model_description="",
credentials_display_name="AWS Bedrock Credentials JSON",
credentials_description="The JSON Object of AWS Credentials, refer to https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-global",
credentials_type="dict",
default_credentials={
"aws_access_key_id": "****",
"aws_secret_access_key": "****",
"aws_region_name": "us-west-2",
},
),
LLMProviderOption(
provider=LLMProvider.AZURE_OPENAI,
provider_display_name="Azure OpenAI",
provider_description="Azure OpenAI is a cloud-based AI service that provides access to OpenAI's advanced language models.",
provider_url="https://azure.microsoft.com/en-us/products/ai-services/openai-service",
default_llm_model="gpt-4o",
llm_model_description="",
config_description="Refer to this document https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart to have more information about the Azure OpenAI API.",
default_config={
"azure_endpoint": "https://.openai.azure.com/",
"api_version": "",
"engine": "",
},
credentials_display_name="Azure OpenAI API Key",
credentials_description="The API key of Azure OpenAI",
credentials_type="str",
default_credentials="****",
),
]
================================================
FILE: backend/app/rag/llms/resolver.py
================================================
from typing import Optional
from llama_index.core.llms.llm import LLM
from sqlmodel import Session
from app.repositories.llm import llm_repo
from app.rag.llms.provider import LLMProvider
def resolve_llm(
provider: LLMProvider,
model: str,
config: Optional[dict] = {},
credentials: Optional[str | list | dict] = None,
) -> LLM:
match provider:
case LLMProvider.OPENAI:
from llama_index.llms.openai import OpenAI
return OpenAI(
model=model,
api_key=credentials,
**config,
)
case LLMProvider.OPENAI_LIKE:
from llama_index.llms.openai_like import OpenAILike
config.setdefault("is_chat_model", True)
config.setdefault("context_window", 200 * 1000)
return OpenAILike(model=model, api_key=credentials, **config)
case LLMProvider.BEDROCK:
from llama_index.llms.bedrock_converse import BedrockConverse
access_key_id = credentials["aws_access_key_id"]
secret_access_key = credentials["aws_secret_access_key"]
region_name = credentials["aws_region_name"]
config.setdefault("max_tokens", 4096)
return BedrockConverse(
model=model,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
region_name=region_name,
**config,
)
case LLMProvider.GEMINI:
from llama_index.llms.google_genai import GoogleGenAI
return GoogleGenAI(model=model, api_key=credentials, **config)
case LLMProvider.VERTEX_AI | LLMProvider.ANTHROPIC_VERTEX:
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.llms.google_genai.base import VertexAIConfig
from google.oauth2 import service_account
from google.auth.transport.requests import Request
google_creds: service_account.Credentials = (
service_account.Credentials.from_service_account_info(
credentials,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
)
google_creds.refresh(request=Request())
project = credentials.get("project_id") or config.get("project")
location = config.get("location", "us-west1")
llm = GoogleGenAI(
model=model,
vertexai_config=VertexAIConfig(
credentials=google_creds, project=project, location=location
),
**config,
)
llm._project = project
llm._location = location
llm._credentials = credentials
return llm
case LLMProvider.GITEEAI:
from llama_index.llms.openai_like import OpenAILike
config.setdefault("is_chat_model", True)
config.setdefault("context_window", 128 * 1024)
return OpenAILike(
model=model,
api_base="https://ai.gitee.com/v1",
api_key=credentials,
**config,
)
case LLMProvider.AZURE_OPENAI:
from llama_index.llms.azure_openai import AzureOpenAI
return AzureOpenAI(
model=model,
api_key=credentials,
**config,
)
case LLMProvider.OLLAMA:
from llama_index.llms.ollama import Ollama
config.setdefault("request_timeout", 60 * 10)
config.setdefault("context_window", 8192)
return Ollama(model=model, **config)
case _:
raise ValueError(f"Got unknown LLM provider: {provider}")
def get_llm_by_id(session: Session, llm_id: int) -> Optional[LLM]:
db_llm = llm_repo.get(session, llm_id)
if not db_llm:
return None
return resolve_llm(
db_llm.provider,
db_llm.model,
db_llm.config,
db_llm.credentials,
)
def must_get_llm_by_id(session: Session, llm_id: int) -> LLM:
db_llm = llm_repo.must_get(session, llm_id)
return resolve_llm(
db_llm.provider,
db_llm.model,
db_llm.config,
db_llm.credentials,
)
def get_default_llm(session: Session) -> Optional[LLM]:
db_llm = llm_repo.get_default(session)
if not db_llm:
return None
return resolve_llm(
db_llm.provider,
db_llm.model,
db_llm.config,
db_llm.credentials,
)
def must_get_default_llm(session: Session) -> LLM:
db_llm = llm_repo.must_get_default(session)
return resolve_llm(
db_llm.provider,
db_llm.model,
db_llm.config,
db_llm.credentials,
)
def get_llm_or_default(session: Session, llm_id: Optional[int]) -> LLM:
if llm_id is None:
return must_get_default_llm(session)
else:
return must_get_llm_by_id(session, llm_id)
================================================
FILE: backend/app/rag/node_parser/__init__.py
================================================
from .file.markdown import MarkdownNodeParser
__all__ = ["MarkdownNodeParser"]
================================================
FILE: backend/app/rag/node_parser/file/markdown.py
================================================
import re
from typing import Any, Dict, List, Optional, Sequence, Callable
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.node_parser.interface import NodeParser
from llama_index.core.node_parser.node_utils import build_nodes_from_splits
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
from llama_index.core.utils import get_tqdm_iterable, get_tokenizer
from llama_index.core.bridge.pydantic import Field, PrivateAttr
DEFAULT_CHUNK_HEADER_LEVEL = 2
DEFAULT_CHUNK_SIZE = 1200
class MarkdownNodeParser(NodeParser):
"""Markdown node parser.
Splits a document into Nodes using custom Markdown splitting logic.
Args:
include_metadata (bool): whether to include metadata in nodes
include_prev_next_rel (bool): whether to include prev/next relationships
"""
chunk_size: int = Field(
default=DEFAULT_CHUNK_SIZE,
description="The token chunk size for each chunk.",
gt=0,
)
chunk_header_level: int = Field(
default=DEFAULT_CHUNK_HEADER_LEVEL,
description="The header level to split on",
ge=1,
le=6,
)
_tokenizer: Callable = PrivateAttr()
def __init__(
self,
chunk_header_level: int = DEFAULT_CHUNK_HEADER_LEVEL,
chunk_size: int = DEFAULT_CHUNK_SIZE,
tokenizer: Optional[Callable] = None,
include_metadata: bool = True,
include_prev_next_rel: bool = True,
callback_manager: Optional[CallbackManager] = None,
):
super().__init__(
chunk_header_level=chunk_header_level,
chunk_size=chunk_size,
include_metadata=include_metadata,
include_prev_next_rel=include_prev_next_rel,
callback_manager=callback_manager,
)
self.callback_manager = callback_manager or CallbackManager([])
self._tokenizer = tokenizer or get_tokenizer()
@classmethod
def from_defaults(
cls,
chunk_header_level: int = DEFAULT_CHUNK_HEADER_LEVEL,
chunk_size: int = DEFAULT_CHUNK_SIZE,
tokenizer: Optional[Callable] = None,
include_metadata: bool = True,
include_prev_next_rel: bool = True,
callback_manager: Optional[CallbackManager] = None,
) -> "MarkdownNodeParser":
callback_manager = callback_manager or CallbackManager([])
tokenizer = tokenizer or get_tokenizer()
return cls(
chunk_header_level=chunk_header_level,
chunk_size=chunk_size,
tokenizer=tokenizer,
include_metadata=include_metadata,
include_prev_next_rel=include_prev_next_rel,
callback_manager=callback_manager,
)
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "MarkdownNodeParser"
def _parse_nodes(
self,
nodes: Sequence[BaseNode],
show_progress: bool = False,
**kwargs: Any,
) -> List[BaseNode]:
all_nodes: List[BaseNode] = []
nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")
for node in nodes_with_progress:
splitted_nodes = self.get_nodes_from_node(
node,
self.chunk_header_level,
self.chunk_size * 0.7,
self.chunk_size * 1.1,
)
nodes = []
for sn in splitted_nodes:
header_level = sn.metadata.get("Header_Level")
if header_level:
for _hl in range(1, header_level + 1)[::-1]:
if (
f"Header_{_hl}" in sn.metadata
and sn.metadata[f"Header_{_hl}"] not in sn.text
):
sn.text = (
f"{'#' * _hl} {sn.metadata[f'Header_{_hl}']}\n\n"
+ sn.text
)
n = build_nodes_from_splits([sn.text], node, id_func=self.id_func)[0]
if self.include_metadata:
n.metadata = {**node.metadata, **sn.metadata}
nodes.append(n)
all_nodes.extend(nodes)
return all_nodes
def get_nodes_from_node(
self,
node: BaseNode,
chunk_header_level: int,
chunk_size_small_threshold: float,
chunk_size_large_threshold: float,
) -> List[TextNode]:
# print(chunk_header_level, chunk_size_small_threshold, chunk_size_large_threshold)
"""Get nodes from document."""
text = node.get_content(metadata_mode=MetadataMode.NONE)
markdown_nodes = []
lines = text.split("\n")
metadata: Dict[str, str] = node.metadata
code_block = False
current_section = ""
first_header = True
for line in lines:
if line.lstrip().startswith("```"):
code_block = not code_block
header_match = re.match(r"^(#+)\s(.*)", line)
if header_match and not code_block:
current_header_level = len(header_match.group(1).strip())
if current_section != "" and current_header_level == chunk_header_level:
if first_header:
# skip the first header, merge it with the first section (usually the title of the document)
first_header = False
else:
markdown_nodes.append(
self._build_node_from_split(
current_section.strip(), node, metadata
)
)
current_section = ""
if current_header_level <= chunk_header_level:
metadata = self._update_metadata(
metadata, header_match.group(2), current_header_level
)
current_section += line + "\n"
else:
current_section += line + "\n"
markdown_nodes.append(
self._build_node_from_split(current_section.strip(), node, metadata)
)
return self._normalize_node_sizes(
markdown_nodes, chunk_size_small_threshold, chunk_size_large_threshold
)
def _normalize_node_sizes(
self,
nodes: List[TextNode],
chunk_size_small_threshold: float,
chunk_size_large_threshold: float,
) -> List[TextNode]:
# 1. Split the big node into multiple small nodes
# 2. Merge the small nodes into a big node if they are too small
# 3. Make all the nodes as much as possible close to the chunk size
nodes_token_size = [self._token_size(node.text) for node in nodes]
normalized_nodes = []
buffer = []
node_count = len(nodes)
i = 0
while i < node_count:
node = nodes[i]
this_chunk_size = nodes_token_size[i]
if this_chunk_size < chunk_size_small_threshold:
# if the last node is too small, merge it with the previous one
if (
not buffer
and i == (node_count - 1)
and i > 0
and nodes_token_size[i - 1] + this_chunk_size
< chunk_size_large_threshold
):
normalized_nodes[-1].text += "\n\n" + node.text
i += 1
continue
buffer.append(this_chunk_size)
total = sum(buffer)
while (
(i + 1) < node_count
and nodes_token_size[i + 1] < self.chunk_size
and total + nodes_token_size[i + 1] <= chunk_size_large_threshold
):
i += 1
buffer.append(nodes_token_size[i])
total += nodes_token_size[i]
# output the sum of the buffer
buffer_nodes = nodes[i - len(buffer) + 1 : i + 1]
normalized_nodes.append(
TextNode(
text="\n\n".join([node.text for node in buffer_nodes]),
metadata=buffer_nodes[0].metadata,
)
)
i += 1
buffer.clear()
elif this_chunk_size > chunk_size_large_threshold:
# split into multiple nodes with next header level and bigger chunk size
md_splitted_nodes = self.get_nodes_from_node(
node,
self.chunk_header_level + 1,
chunk_size_small_threshold,
chunk_size_large_threshold * 1.1,
)
for n in md_splitted_nodes:
_chunk_size = self._token_size(n.text)
if _chunk_size > chunk_size_large_threshold * 1.1:
# using sentence splitter to split the node if it's still too large
sentence_splitted_nodes = SentenceSplitter(
chunk_size=int(chunk_size_large_threshold), separator="\n\n"
).get_nodes_from_documents([n])
normalized_nodes.extend(sentence_splitted_nodes)
else:
normalized_nodes.append(n)
i += 1
else:
normalized_nodes.append(node)
i += 1
return normalized_nodes
def _update_metadata(
self, headers_metadata: dict, new_header: str, new_header_level: int
) -> dict:
"""Update the markdown headers for metadata.
Removes all headers that are equal or less than the level
of the newly found header
"""
updated_headers = {}
for i in range(1, new_header_level):
key = f"Header_{i}"
if key in headers_metadata:
updated_headers[key] = headers_metadata[key]
updated_headers[f"Header_{new_header_level}"] = new_header
updated_headers["Header_Level"] = new_header_level
return updated_headers
def _build_node_from_split(
self,
text_split: str,
node: BaseNode,
metadata: dict,
) -> TextNode:
"""Build node from single text split."""
node = build_nodes_from_splits([text_split], node, id_func=self.id_func)[0]
if self.include_metadata:
node.metadata = {**node.metadata, **metadata}
return node
def _token_size(self, text: str) -> int:
return len(self._tokenizer(text))
================================================
FILE: backend/app/rag/postprocessors/__init__.py
================================================
from .metadata_post_filter import MetadataPostFilter, MetadataFilters
__all__ = [
"MetadataPostFilter",
"MetadataFilters",
]
================================================
FILE: backend/app/rag/postprocessors/metadata_post_filter.py
================================================
import logging
from typing import Dict, List, Optional, Any, Union
from llama_index.core import QueryBundle
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import BaseNode, NodeWithScore
from llama_index.core.vector_stores.types import (
MetadataFilter,
MetadataFilters,
FilterOperator,
FilterCondition,
)
SimpleMetadataFilter = Dict[str, Any]
def simple_filter_to_metadata_filters(filters: SimpleMetadataFilter) -> MetadataFilters:
simple_filters = []
for key, value in filters.items():
simple_filters.append(
MetadataFilter(
key=key,
value=value,
operator=FilterOperator.EQ,
)
)
return MetadataFilters(filters=simple_filters)
logger = logging.getLogger(__name__)
class MetadataPostFilter(BaseNodePostprocessor):
filters: Optional[MetadataFilters] = None
def __init__(
self,
filters: Optional[Union[MetadataFilters, SimpleMetadataFilter]] = None,
**kwargs: Any,
):
super().__init__(**kwargs)
if isinstance(filters, MetadataFilters):
self.filters = filters
else:
self.filters = simple_filter_to_metadata_filters(filters)
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
if self.filters is None:
return nodes
filtered_nodes = []
for node in nodes:
# TODO: support advanced post filtering.
if self.match_all_filters(node.node):
filtered_nodes.append(node)
return filtered_nodes
def match_all_filters(self, node: BaseNode) -> bool:
if self.filters is None or not isinstance(self.filters, MetadataFilters):
return True
if self.filters.condition != FilterCondition.AND:
logger.warning(
f"Advanced filtering is not supported yet. "
f"Filter condition {self.filters.condition} is ignored."
)
return True
for f in self.filters.filters:
if f.key not in node.metadata:
return False
if f.operator is not None and f.operator != FilterOperator.EQ:
logger.warning(
f"Advanced filtering is not supported yet. "
f"Filter operator {f.operator} is ignored."
)
return True
value = node.metadata[f.key]
if f.value != value:
return False
return True
================================================
FILE: backend/app/rag/query_dispatcher.py
================================================
import logging
from typing import Optional, Sequence
from llama_index.llms.openai import OpenAI
from llama_index.core.tools.types import BaseTool
from llama_index.core.tools import FunctionTool
logger = logging.getLogger(__name__)
DefaultSystemPrompt = """
You are a highly skilled customer assistant, responsible for dispatching user questions to the most appropriate tools or resources. Your primary objective is to ensure each user question is handled accurately and efficiently by selecting the best-suited tool for the task.
For more complex questions, you should break them down into clear, manageable sub-questions and route each to the relevant tools for individual resolution. It's important to maintain clarity and precision in this process, ensuring that the sub-questions are well-defined and can be resolved independently.
If you encounter concepts or entities you are not familiar with, you can break the query down into a sub-question to clarify the specific concept or entity. For example, if the query involves “what is the latest version,” you can treat this as a sub-question to better understand the context before proceeding with the solution.
"""
class QueryDispatcher:
def __init__(self, llm: OpenAI, system_prompt: Optional[str] = None):
if system_prompt is None:
system_prompt = DefaultSystemPrompt
self._llm = llm
self._llm.system_prompt = system_prompt
def route(self, query: str, tools: Sequence["BaseTool"]) -> str:
response = self._llm.chat_with_tools(
tools, query, allow_parallel_tool_calls=True, verbose=True
)
try:
tool_calls = self._llm.get_tool_calls_from_response(
response, error_on_no_tool_call=True
)
except Exception as e:
logger.exception(e)
return f"An error occurred while processing the query: {query}"
return tool_calls
# mock the answer process
def answer(query: str) -> str:
"""
Answer a user query. The query should be simple and straightforward.
"""
return f"I need some time to answer your question: {query}."
answer_tool = FunctionTool.from_defaults(fn=answer)
================================================
FILE: backend/app/rag/question_gen/__init__.py
================================================
================================================
FILE: backend/app/rag/question_gen/helpers.py
================================================
from typing import List
from llama_index.core import QueryBundle
from llama_index.core.base.llms.types import ChatMessage
def get_query_bundle_from_chat(
user_question, chat_history: List[ChatMessage]
) -> QueryBundle:
query_str = user_question
if len(chat_history) > 0:
chat_messages = [
f"{message.role.value}: {message.content}" for message in chat_history
]
query_with_history = (
"++++ Chat History ++++\n"
+ "\n".join(chat_messages)
+ "++++ Chat History ++++\n"
)
query_str = query_with_history + "\n\nThen the user asks:\n" + user_question
return QueryBundle(query_str=query_str)
================================================
FILE: backend/app/rag/question_gen/query_decomposer.py
================================================
import logging
import dspy
from typing import List, Optional
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class SubQuestion(BaseModel):
"""Representation of a single step-by-step question extracted from the user query."""
question: str = Field(
description="A step-by-step question to address the user query."
)
reasoning: str = Field(
description="The rationale behind the question to explain its relevance."
)
class SubQuestions(BaseModel):
"""Representation of the user's step-by-step questions extracted from the query."""
questions: List[SubQuestion] = Field(
description="List of questions representing a plan to address the user query."
)
class DecomposeQuery(dspy.Signature):
"""You are an expert in knowledge base graph construction, specializing in building comprehensive knowledge graphs.
Your current task is to deconstruct the user's query into a series of step-by-step questions.
## Instructions:
1. Dependency Analysis:
- Analyze the user's query to identify the underlying dependencies and relationships between different components.
- Construct a dependency graph that visually represents these relationships.
2. Question Breakdown: Divide the query into a sequence of step-by-step questions necessary to address the main query comprehensively.
3. Provide Reasoning: Explain the rationale behind each question.
4. Constraints:
- Limit the output to no more than 5 questions to maintain focus and relevance.
- Ensure accuracy by reflecting the user's true intentions based on the provided query.
- Ground all questions in factual information derived directly from the user's input.
Please only response in JSON format.
"""
query: str = dspy.InputField(
desc="The query text to extract the user's step-by-step questions."
)
subquestions: SubQuestions = dspy.OutputField(
desc="Representation of the user's step-by-step questions extracted from the query."
)
class DecomposeQueryModule(dspy.Module):
def __init__(self, dspy_lm: dspy.LM):
super().__init__()
self.dspy_lm = dspy_lm
self.prog = dspy.Predict(DecomposeQuery)
def forward(self, query):
with dspy.settings.context(lm=self.dspy_lm):
return self.prog(query=query)
class QueryDecomposer:
def __init__(self, dspy_lm: dspy.LM, complied_program_path: Optional[str] = None):
self.decompose_query_prog = DecomposeQueryModule(dspy_lm=dspy_lm)
if complied_program_path is not None:
self.decompose_query_prog.load(complied_program_path)
def decompose(self, query: str) -> SubQuestions:
return self.decompose_query_prog(query=query).subquestions
================================================
FILE: backend/app/rag/rerankers/baisheng/baisheng_reranker.py
================================================
from typing import Any, List, Optional
import requests
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
ReRankEndEvent,
ReRankStartEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
dispatcher = get_dispatcher(__name__)
class BaishengRerank(BaseNodePostprocessor):
api_key: str = Field(default="", description="API key.")
api_url: str = Field(
default="http://api.chat.prd.yumc.local/chat/v1/reranker",
description="API url.",
)
model: str = Field(
default="bge-reranker-v2-m3",
description="The model to use when calling API",
)
top_n: int = Field(description="Top N nodes to return.")
_session: Any = PrivateAttr()
def __init__(
self,
top_n: int = 2,
model: str = "bge-reranker-v2-m3",
api_key: str = "",
api_url: str = "http://api.chat.prd.yumc.local/chat/v1/reranker",
):
super().__init__(top_n=top_n, model=model)
self.api_key = api_key
self.api_url = api_url
self.model = model
self._session = requests.Session()
self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})
@classmethod
def class_name(cls) -> str:
return "BaishengRerank"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
dispatcher.event(
ReRankStartEvent(
query=query_bundle,
nodes=nodes,
top_n=self.top_n,
model_name=self.model,
)
)
if query_bundle is None:
raise ValueError("Missing query bundle in extra info.")
if len(nodes) == 0:
return []
with self.callback_manager.event(
CBEventType.RERANKING,
payload={
EventPayload.NODES: nodes,
EventPayload.MODEL_NAME: self.model,
EventPayload.QUERY_STR: query_bundle.query_str,
EventPayload.TOP_K: self.top_n,
},
) as event:
texts = [
node.node.get_content(metadata_mode=MetadataMode.EMBED)
for node in nodes
]
resp = self._session.post( # type: ignore
self.api_url,
json={
"query": query_bundle.query_str,
"model": self.model,
"sentences": texts,
},
).json()
if "scores" not in resp:
raise RuntimeError(f"Got error from reranker: {resp}")
results = zip(range(len(nodes)), resp["scores"])
results = sorted(results, key=lambda x: x[1], reverse=True)[: self.top_n]
new_nodes = []
for result in results:
new_node_with_score = NodeWithScore(
node=nodes[result[0]].node, score=result[1]
)
new_nodes.append(new_node_with_score)
event.on_end(payload={EventPayload.NODES: new_nodes})
dispatcher.event(ReRankEndEvent(nodes=new_nodes))
return new_nodes
================================================
FILE: backend/app/rag/rerankers/local/local_reranker.py
================================================
from typing import Any, List, Optional
import requests
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
ReRankEndEvent,
ReRankStartEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
dispatcher = get_dispatcher(__name__)
class LocalRerank(BaseNodePostprocessor):
api_url: str = Field(
default="http://127.0.0.1:5001/api/v1/reranker",
description="API url.",
)
model: str = Field(
default="BAAI/bge-reranker-v2-m3",
description="The model to use when calling API",
)
top_n: int = Field(description="Top N nodes to return.")
_session: Any = PrivateAttr()
def __init__(
self,
top_n: int = 2,
model: str = "BAAI/bge-reranker-v2-m3",
api_url: str = "http://127.0.0.1:5001/api/v1/reranker",
):
super().__init__(top_n=top_n, model=model)
self.api_url = api_url
self.model = model
self._session = requests.Session()
@classmethod
def class_name(cls) -> str:
return "LocalRerank"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
dispatcher.event(
ReRankStartEvent(
query=query_bundle,
nodes=nodes,
top_n=self.top_n,
model_name=self.model,
)
)
if query_bundle is None:
raise ValueError("Missing query bundle in extra info.")
if len(nodes) == 0:
return []
with self.callback_manager.event(
CBEventType.RERANKING,
payload={
EventPayload.NODES: nodes,
EventPayload.MODEL_NAME: self.model,
EventPayload.QUERY_STR: query_bundle.query_str,
EventPayload.TOP_K: self.top_n,
},
) as event:
texts = [
node.node.get_content(metadata_mode=MetadataMode.EMBED)
for node in nodes
]
resp = self._session.post( # type: ignore
self.api_url,
json={
"query": query_bundle.query_str,
"model": self.model,
"passages": texts,
},
)
resp.raise_for_status()
resp_json = resp.json()
if "scores" not in resp_json:
raise RuntimeError(f"Got error from reranker: {resp_json}")
results = zip(range(len(nodes)), resp_json["scores"])
results = sorted(results, key=lambda x: x[1], reverse=True)[: self.top_n]
new_nodes = []
for result in results:
new_node_with_score = NodeWithScore(
node=nodes[result[0]].node, score=result[1]
)
new_nodes.append(new_node_with_score)
event.on_end(payload={EventPayload.NODES: new_nodes})
dispatcher.event(ReRankEndEvent(nodes=new_nodes))
return new_nodes
================================================
FILE: backend/app/rag/rerankers/provider.py
================================================
import enum
from typing import List
from pydantic import BaseModel
class RerankerProvider(str, enum.Enum):
JINA = "jina"
COHERE = "cohere"
BAISHENG = "baisheng"
LOCAL = "local"
VLLM = "vllm"
XINFERENCE = "xinference"
BEDROCK = "bedrock"
class RerankerProviderOption(BaseModel):
provider: RerankerProvider
provider_display_name: str | None = None
provider_description: str | None = None
provider_url: str | None = None
default_reranker_model: str
reranker_model_description: str
default_top_n: int = 10
default_credentials: str | dict = ""
default_config: dict = {}
config_description: str = ""
credentials_display_name: str
credentials_description: str
credentials_type: str = "str"
reranker_provider_options: List[RerankerProviderOption] = [
RerankerProviderOption(
provider=RerankerProvider.JINA,
provider_display_name="Jina AI",
provider_description="We provide best-in-class embeddings, rerankers, LLM-reader and prompt optimizers, pioneering search AI for multimodal data.",
provider_url="https://jina.ai",
default_reranker_model="jina-reranker-v2-base-multilingual",
reranker_model_description="Reference: https://jina.ai/reranker/",
default_top_n=10,
credentials_display_name="Jina API Key",
credentials_description="You can get one from https://jina.ai/reranker/",
credentials_type="str",
default_credentials="jina_****",
),
RerankerProviderOption(
provider=RerankerProvider.COHERE,
provider_display_name="Cohere",
provider_description="Cohere provides industry-leading large language models (LLMs) and RAG capabilities tailored to meet the needs of enterprise use cases that solve real-world problems.",
provider_url="https://cohere.com/",
default_reranker_model="rerank-multilingual-v3.0",
reranker_model_description="Reference: https://docs.cohere.com/reference/rerank",
default_top_n=10,
credentials_display_name="Cohere API Key",
credentials_description="You can get one from https://dashboard.cohere.com/api-keys",
credentials_type="str",
default_credentials="*****",
),
RerankerProviderOption(
provider=RerankerProvider.BAISHENG,
provider_display_name="BaiSheng",
default_reranker_model="bge-reranker-v2-m3",
reranker_model_description="",
default_top_n=10,
default_config={
"api_url": "http://api.chat.prd.yumc.local/chat/v1/reranker",
},
credentials_display_name="BaiSheng API Key",
credentials_description="",
credentials_type="str",
default_credentials="*****",
),
RerankerProviderOption(
provider=RerankerProvider.LOCAL,
provider_display_name="Local Reranker",
provider_description="TIDB.AI's local reranker server, deployed on your own infrastructure and powered by sentence-transformers.",
default_reranker_model="BAAI/bge-reranker-v2-m3",
reranker_model_description="Find more models in huggingface.",
default_top_n=10,
default_config={
"api_url": "http://local-embedding-reranker:5001/api/v1/reranker",
},
config_description="api_url is the url of the tidb ai local reranker server.",
credentials_display_name="Local Reranker API Key",
credentials_description="Local Reranker server doesn't require an API key, set a dummy string here is ok.",
credentials_type="str",
default_credentials="dummy",
),
RerankerProviderOption(
provider=RerankerProvider.VLLM,
provider_display_name="vLLM",
provider_description="vLLM is a fast and easy-to-use library for LLM inference and serving.",
default_reranker_model="BAAI/bge-reranker-v2-m3",
reranker_model_description="Reference: https://docs.vllm.ai/en/latest/models/supported_models.html#sentence-pair-scoring-task-score",
default_top_n=10,
default_config={
"base_url": "http://localhost:8000",
},
config_description="base_url is the base url of the vLLM server, ensure it can be accessed from this server",
credentials_display_name="vLLM API Key",
credentials_description="vLLM doesn't require an API key, set a dummy string here is ok",
credentials_type="str",
default_credentials="dummy",
),
RerankerProviderOption(
provider=RerankerProvider.XINFERENCE,
provider_display_name="Xinference Reranker",
provider_description="Xorbits Inference (Xinference) is an open-source platform to streamline the operation and integration of a wide array of AI models.",
default_reranker_model="bge-reranker-v2-m3",
reranker_model_description="Reference: https://inference.readthedocs.io/en/latest/models/model_abilities/rerank.html",
default_top_n=10,
default_config={
"base_url": "http://localhost:9997",
},
config_description="base_url is the url of the Xinference server, ensure it can be accessed from this server",
credentials_display_name="Xinference API Key",
credentials_description="Xinference doesn't require an API key, set a dummy string here is ok",
credentials_type="str",
default_credentials="dummy",
),
RerankerProviderOption(
provider=RerankerProvider.BEDROCK,
provider_display_name="Bedrock Reranker",
provider_description="Amazon Bedrock is a fully managed foundation models service.",
provider_url="https://docs.aws.amazon.com/bedrock/",
default_reranker_model="amazon.rerank-v1:0",
reranker_model_description="Find more models in https://docs.aws.amazon.com/bedrock/latest/userguide/foundation-models-reference.html.",
default_top_n=10,
credentials_display_name="AWS Bedrock Credentials JSON",
credentials_description="The JSON Object of AWS Credentials, refer to https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-global",
credentials_type="dict",
default_credentials={
"aws_access_key_id": "****",
"aws_secret_access_key": "****",
"aws_region_name": "us-west-2",
},
),
]
================================================
FILE: backend/app/rag/rerankers/resolver.py
================================================
from typing import Optional
from sqlmodel import Session
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.postprocessor.jinaai_rerank import JinaRerank
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.postprocessor.xinference_rerank import XinferenceRerank
from llama_index.postprocessor.bedrock_rerank import AWSBedrockRerank
from app.rag.rerankers.baisheng.baisheng_reranker import BaishengRerank
from app.rag.rerankers.local.local_reranker import LocalRerank
from app.rag.rerankers.vllm.vllm_reranker import VLLMRerank
from app.rag.rerankers.provider import RerankerProvider
from app.repositories.reranker_model import reranker_model_repo
def resolve_reranker_by_id(
session: Session, reranker_model_id: int, top_n: int
) -> BaseNodePostprocessor:
db_reranker_model = reranker_model_repo.must_get(session, reranker_model_id)
return resolve_reranker(
db_reranker_model.provider,
db_reranker_model.model,
top_n or db_reranker_model.top_n,
db_reranker_model.config,
db_reranker_model.credentials,
)
def resolve_reranker(
provider: RerankerProvider,
model: str,
top_n: int,
config: dict,
credentials: str | list | dict | None,
) -> BaseNodePostprocessor:
match provider:
case RerankerProvider.JINA:
return JinaRerank(
model=model,
top_n=top_n,
api_key=credentials,
**config,
)
case RerankerProvider.COHERE:
return CohereRerank(
model=model,
top_n=top_n,
api_key=credentials,
**config,
)
case RerankerProvider.BAISHENG:
return BaishengRerank(
model=model,
top_n=top_n,
api_key=credentials,
**config,
)
case RerankerProvider.LOCAL:
return LocalRerank(
model=model,
top_n=top_n,
**config,
)
case RerankerProvider.VLLM:
return VLLMRerank(
model=model,
top_n=top_n,
**config,
)
case RerankerProvider.XINFERENCE:
return XinferenceRerank(
model=model,
top_n=top_n,
**config,
)
case RerankerProvider.BEDROCK:
return AWSBedrockRerank(
rerank_model_name=model,
top_n=top_n,
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
region_name=credentials["aws_region_name"],
**config,
)
case _:
raise ValueError(f"Got unknown reranker provider: {provider}")
# FIXME: Reranker top_n should be config in the retrieval config.
def get_default_reranker_model(
session: Session, top_n: int = None
) -> Optional[BaseNodePostprocessor]:
db_reranker = reranker_model_repo.get_default(session)
if not db_reranker:
return None
top_n = db_reranker.top_n if top_n is None else top_n
return resolve_reranker(
db_reranker.provider,
db_reranker.model,
top_n,
db_reranker.config,
db_reranker.credentials,
)
def must_get_default_reranker_model(session: Session) -> BaseNodePostprocessor:
db_reranker = reranker_model_repo.must_get_default(session)
return resolve_reranker(
db_reranker.provider,
db_reranker.model,
db_reranker.top_n,
db_reranker.config,
db_reranker.credentials,
)
================================================
FILE: backend/app/rag/rerankers/vllm/vllm_reranker.py
================================================
from typing import Any, List, Optional
import requests
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
ReRankEndEvent,
ReRankStartEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
dispatcher = get_dispatcher(__name__)
class VLLMRerank(BaseNodePostprocessor):
base_url: str = Field(default="", description="The base URL of vLLM API.")
model: str = Field(default="", description="The model to use when calling API.")
top_n: int = Field(description="Top N nodes to return.")
_session: Any = PrivateAttr()
def __init__(
self,
top_n: int = 2,
model: str = "BAAI/bge-reranker-v2-m3",
base_url: str = "http://localhost:8000",
):
super().__init__(top_n=top_n, model=model)
self.base_url = base_url
self.model = model
self._session = requests.Session()
@classmethod
def class_name(cls) -> str:
return "VLLMRerank"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
dispatcher.event(
ReRankStartEvent(
query=query_bundle,
nodes=nodes,
top_n=self.top_n,
model_name=self.model,
)
)
if query_bundle is None:
raise ValueError("Missing query bundle in extra info.")
if len(nodes) == 0:
return []
with self.callback_manager.event(
CBEventType.RERANKING,
payload={
EventPayload.NODES: nodes,
EventPayload.MODEL_NAME: self.model,
EventPayload.QUERY_STR: query_bundle.query_str,
EventPayload.TOP_K: self.top_n,
},
) as event:
texts = [
node.node.get_content(metadata_mode=MetadataMode.EMBED)
for node in nodes
]
resp = self._session.post( # type: ignore
url=f"{self.base_url}/v1/score",
json={
"text_1": query_bundle.query_str,
"model": self.model,
"text_2": texts,
},
)
resp.raise_for_status()
resp_json = resp.json()
if "data" not in resp_json:
raise RuntimeError(f"Got error from reranker: {resp_json}")
results = zip(range(len(nodes)), resp_json["data"])
results = sorted(results, key=lambda x: x[1]["score"], reverse=True)[
: self.top_n
]
new_nodes = []
for result in results:
new_node_with_score = NodeWithScore(
node=nodes[result[0]].node, score=result[1]["score"]
)
new_nodes.append(new_node_with_score)
event.on_end(payload={EventPayload.NODES: new_nodes})
dispatcher.event(ReRankEndEvent(nodes=new_nodes))
return new_nodes
================================================
FILE: backend/app/rag/retrievers/__init__.py
================================================
================================================
FILE: backend/app/rag/retrievers/chunk/__init__.py
================================================
================================================
FILE: backend/app/rag/retrievers/chunk/fusion_retriever.py
================================================
from typing import List, Optional, Dict, Tuple
from llama_index.core import QueryBundle
from llama_index.core.callbacks import CallbackManager
from llama_index.core.llms import LLM
from llama_index.core.schema import NodeWithScore
from sqlmodel import Session
from app.rag.retrievers.chunk.simple_retriever import (
ChunkSimpleRetriever,
)
from app.rag.retrievers.chunk.schema import (
RetrievedChunkDocument,
VectorSearchRetrieverConfig,
ChunksRetrievalResult,
ChunkRetriever,
)
from app.rag.retrievers.chunk.helpers import map_nodes_to_chunks
from app.rag.retrievers.multiple_knowledge_base import MultiKBFusionRetriever
from app.repositories import knowledge_base_repo, document_repo
class ChunkFusionRetriever(MultiKBFusionRetriever, ChunkRetriever):
def __init__(
self,
db_session: Session,
knowledge_base_ids: List[int],
llm: LLM,
use_query_decompose: bool = False,
config: VectorSearchRetrieverConfig = VectorSearchRetrieverConfig(),
callback_manager: Optional[CallbackManager] = CallbackManager([]),
**kwargs,
):
# Prepare vector search retrievers for knowledge bases.
retrievers = []
knowledge_bases = knowledge_base_repo.get_by_ids(db_session, knowledge_base_ids)
for kb in knowledge_bases:
retrievers.append(
ChunkSimpleRetriever(
knowledge_base_id=kb.id,
config=config,
callback_manager=callback_manager,
db_session=db_session,
)
)
super().__init__(
db_session=db_session,
retrievers=retrievers,
llm=llm,
use_query_decompose=use_query_decompose,
callback_manager=callback_manager,
**kwargs,
)
def _fusion(
self, query: str, results: Dict[Tuple[str, int], List[NodeWithScore]]
) -> List[NodeWithScore]:
return self._simple_fusion(query, results)
def _simple_fusion(
self, query: str, results: Dict[Tuple[str, int], List[NodeWithScore]]
):
"""Apply simple fusion."""
# Use a dict to de-duplicate nodes
all_nodes: Dict[str, NodeWithScore] = {}
for nodes_with_scores in results.values():
for node_with_score in nodes_with_scores:
hash = node_with_score.node.hash
if hash in all_nodes:
max_score = max(
node_with_score.score or 0.0, all_nodes[hash].score or 0.0
)
all_nodes[hash].score = max_score
else:
all_nodes[hash] = node_with_score
return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)
def retrieve_chunks(
self,
query_str: str,
full_document: bool = False,
) -> ChunksRetrievalResult:
nodes_with_score = self._retrieve(QueryBundle(query_str))
chunks = map_nodes_to_chunks(nodes_with_score)
document_ids = [c.document_id for c in chunks]
documents = document_repo.fetch_by_ids(self._db_session, document_ids)
if full_document:
return ChunksRetrievalResult(chunks=chunks, documents=documents)
else:
return ChunksRetrievalResult(
chunks=chunks,
documents=[
RetrievedChunkDocument(
id=d.id, name=d.name, source_uri=d.source_uri
)
for d in documents
],
)
================================================
FILE: backend/app/rag/retrievers/chunk/helpers.py
================================================
from typing import List
from llama_index.core.schema import NodeWithScore
from app.rag.retrievers.chunk.schema import RetrievedChunk
def map_nodes_to_chunks(nodes_with_score: List[NodeWithScore]) -> List[RetrievedChunk]:
return [
RetrievedChunk(
id=ns.node.node_id,
text=ns.node.text,
metadata=ns.node.metadata,
document_id=ns.node.metadata["document_id"],
score=ns.score,
)
for ns in nodes_with_score
]
================================================
FILE: backend/app/rag/retrievers/chunk/schema.py
================================================
from abc import ABC
from typing import Any, Dict, Optional
from pydantic import BaseModel
from app.models import Document
class RerankerConfig(BaseModel):
enabled: bool = True
model_id: int = None
top_n: int = 10
class MetadataFilterConfig(BaseModel):
enabled: bool = True
filters: Dict[str, Any] = None
class VectorSearchRetrieverConfig(BaseModel):
top_k: int = 10
similarity_top_k: Optional[int] = None
oversampling_factor: Optional[int] = 5
reranker: Optional[RerankerConfig] = None
metadata_filter: Optional[MetadataFilterConfig] = None
class KBChunkRetrievalConfig(BaseModel):
knowledge_base_ids: list[int] = None
class ChunkRetrievalConfig(BaseModel):
knowledge_base_ids: list[int] = None
# Retrieved Chunks
class RetrievedChunkDocument(BaseModel):
id: int
name: str
source_uri: str
class RetrievedChunk(BaseModel):
id: str
text: str
metadata: dict
document_id: Optional[int]
score: float
class ChunksRetrievalResult(BaseModel):
chunks: list[RetrievedChunk]
documents: Optional[list[Document | RetrievedChunkDocument]] = None
class ChunkRetriever(ABC):
def retrieve_chunks(
self,
query_str: str,
full_document: bool = False,
) -> ChunksRetrievalResult:
"""Retrieve chunks"""
================================================
FILE: backend/app/rag/retrievers/chunk/simple_retriever.py
================================================
import logging
from typing import List, Optional, Type
from llama_index.core.callbacks import CallbackManager
from llama_index.core.indices.utils import log_vector_store_query_result
from llama_index.core.vector_stores import VectorStoreQuery, VectorStoreQueryResult
from sqlmodel import Session
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore, QueryBundle
import llama_index.core.instrumentation as instrument
from sqlmodel import SQLModel
from app.models.chunk import get_kb_chunk_model
from app.rag.knowledge_base.config import get_kb_embed_model
from app.rag.rerankers.resolver import resolve_reranker_by_id
from app.rag.retrievers.chunk.schema import (
RetrievedChunkDocument,
VectorSearchRetrieverConfig,
ChunksRetrievalResult,
ChunkRetriever,
)
from app.rag.retrievers.chunk.helpers import map_nodes_to_chunks
from app.rag.indices.vector_search.vector_store.tidb_vector_store import TiDBVectorStore
from app.rag.postprocessors.metadata_post_filter import MetadataPostFilter
from app.repositories import knowledge_base_repo, document_repo
logger = logging.getLogger(__name__)
dispatcher = instrument.get_dispatcher(__name__)
class ChunkSimpleRetriever(BaseRetriever, ChunkRetriever):
_chunk_model: Type[SQLModel]
def __init__(
self,
knowledge_base_id: int,
config: VectorSearchRetrieverConfig,
db_session: Optional[Session] = None,
callback_manager: CallbackManager = CallbackManager([]),
):
super().__init__()
if not knowledge_base_id:
raise ValueError("Knowledge base id is required")
self._config = config
self._db_session = db_session
self._kb = knowledge_base_repo.must_get(db_session, knowledge_base_id)
self._chunk_db_model = get_kb_chunk_model(self._kb)
self._embed_model = get_kb_embed_model(db_session, self._kb)
self._embed_model.callback_manager = callback_manager
# Init vector store.
self._vector_store = TiDBVectorStore(
session=db_session,
chunk_db_model=self._chunk_db_model,
oversampling_factor=config.oversampling_factor,
callback_manager=callback_manager,
)
# Init node postprocessors.
node_postprocessors = []
# Metadata filter
filter_config = config.metadata_filter
if filter_config and filter_config.enabled:
metadata_filter = MetadataPostFilter(filter_config.filters)
node_postprocessors.append(metadata_filter)
# Reranker
reranker_config = config.reranker
if reranker_config and reranker_config.enabled:
reranker = resolve_reranker_by_id(
db_session, reranker_config.model_id, reranker_config.top_n
)
node_postprocessors.append(reranker)
self._node_postprocessors = node_postprocessors
@dispatcher.span
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
if query_bundle.embedding is None and len(query_bundle.embedding_strs) > 0:
query_bundle.embedding = self._embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs
)
result = self._vector_store.query(
VectorStoreQuery(
query_str=query_bundle.query_str,
query_embedding=query_bundle.embedding,
similarity_top_k=self._config.similarity_top_k or self._config.top_k,
)
)
nodes = self._build_node_list_from_query_result(result)
for node_postprocessor in self._node_postprocessors:
nodes = node_postprocessor.postprocess_nodes(
nodes, query_bundle=query_bundle
)
return nodes[: self._config.top_k]
def _build_node_list_from_query_result(
self, query_result: VectorStoreQueryResult
) -> List[NodeWithScore]:
log_vector_store_query_result(query_result)
node_with_scores: List[NodeWithScore] = []
for ind, node in enumerate(query_result.nodes):
score: Optional[float] = None
if query_result.similarities is not None:
score = query_result.similarities[ind]
node_with_scores.append(NodeWithScore(node=node, score=score))
return node_with_scores
def retrieve_chunks(
self, query_str: str, full_document: bool = False
) -> ChunksRetrievalResult:
nodes_with_score = self.retrieve(query_str)
chunks = map_nodes_to_chunks(nodes_with_score)
document_ids = [c.document_id for c in chunks]
documents = document_repo.fetch_by_ids(self._db_session, document_ids)
if full_document:
return ChunksRetrievalResult(chunks=chunks, documents=documents)
else:
return ChunksRetrievalResult(
chunks=chunks,
documents=[
RetrievedChunkDocument(
id=d.id, name=d.name, source_uri=d.source_uri
)
for d in documents
],
)
================================================
FILE: backend/app/rag/retrievers/knowledge_graph/__init__.py
================================================
================================================
FILE: backend/app/rag/retrievers/knowledge_graph/fusion_retriever.py
================================================
import logging
from sqlmodel import Session
from typing import List, Optional, Dict, Tuple
from llama_index.core import QueryBundle
from llama_index.core.callbacks import CallbackManager
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import LLM
from app.models import KnowledgeBase
from app.rag.retrievers.multiple_knowledge_base import MultiKBFusionRetriever
from app.rag.retrievers.knowledge_graph.simple_retriever import (
KnowledgeGraphSimpleRetriever,
)
from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrieverConfig,
KnowledgeGraphRetrievalResult,
KnowledgeGraphNode,
KnowledgeGraphRetriever,
)
from app.repositories import knowledge_base_repo
logger = logging.getLogger(__name__)
class KnowledgeGraphFusionRetriever(MultiKBFusionRetriever, KnowledgeGraphRetriever):
knowledge_base_map: Dict[int, KnowledgeBase] = {}
def __init__(
self,
db_session: Session,
knowledge_base_ids: List[int],
llm: LLM,
use_query_decompose: bool = False,
config: KnowledgeGraphRetrieverConfig = KnowledgeGraphRetrieverConfig(),
callback_manager: Optional[CallbackManager] = CallbackManager([]),
**kwargs,
):
self.use_query_decompose = use_query_decompose
# Prepare knowledge graph retrievers for knowledge bases.
retrievers = []
knowledge_bases = knowledge_base_repo.get_by_ids(db_session, knowledge_base_ids)
self.knowledge_bases = knowledge_bases
for kb in knowledge_bases:
self.knowledge_base_map[kb.id] = kb
retrievers.append(
KnowledgeGraphSimpleRetriever(
db_session=db_session,
knowledge_base_id=kb.id,
config=config,
callback_manager=callback_manager,
)
)
super().__init__(
db_session=db_session,
retrievers=retrievers,
llm=llm,
use_query_decompose=use_query_decompose,
callback_manager=callback_manager,
**kwargs,
)
def retrieve_knowledge_graph(
self, query_text: str
) -> KnowledgeGraphRetrievalResult:
nodes_with_score = self._retrieve(QueryBundle(query_text))
if len(nodes_with_score) == 0:
return KnowledgeGraphRetrievalResult()
node: KnowledgeGraphNode = nodes_with_score[0].node # type:ignore
return KnowledgeGraphRetrievalResult(
query=node.query,
knowledge_bases=[kb.to_descriptor() for kb in self.knowledge_bases],
entities=node.entities,
relationships=node.relationships,
subgraphs=[
KnowledgeGraphRetrievalResult(
query=child_node.query,
knowledge_base=self.knowledge_base_map[
child_node.knowledge_base_id
].to_descriptor(),
entities=child_node.entities,
relationships=child_node.relationships,
)
for child_node in node.children
],
)
def _fusion(
self, query: str, results: Dict[Tuple[str, int], List[NodeWithScore]]
) -> List[NodeWithScore]:
return self._knowledge_graph_fusion(query, results)
def _knowledge_graph_fusion(
self, query: str, results: Dict[Tuple[str, int], List[NodeWithScore]]
) -> List[NodeWithScore]:
merged_entities = set()
merged_relationships = {}
merged_knowledge_base_ids = set()
merged_children_nodes = []
for nodes_with_scores in results.values():
if len(nodes_with_scores) == 0:
continue
node: KnowledgeGraphNode = nodes_with_scores[0].node # type:ignore
# Merge knowledge base id.
merged_knowledge_base_ids.add(node.knowledge_base_id)
# Merge entities.
merged_entities.update(node.entities)
# Merge relationships.
for r in node.relationships:
key = r.rag_description
if key not in merged_relationships:
merged_relationships[key] = r
else:
merged_relationships[key].weight += r.weight
# Merge to children nodes.
merged_children_nodes.append(node)
return [
NodeWithScore(
node=KnowledgeGraphNode(
query=query,
entities=list(merged_entities),
relationships=list(merged_relationships.values()),
knowledge_base_ids=merged_knowledge_base_ids,
children=merged_children_nodes,
),
score=1,
)
]
================================================
FILE: backend/app/rag/retrievers/knowledge_graph/schema.py
================================================
import datetime
import json
from abc import ABC
from enum import Enum
from hashlib import sha256
from typing import Optional, Mapping, Any, List
from llama_index.core.schema import BaseNode, MetadataMode
from pydantic import BaseModel, Field
from app.models.entity import EntityType
from app.api.admin_routes.models import KnowledgeBaseDescriptor
# Retriever Config
class MetadataFilterConfig(BaseModel):
enabled: bool = True
filters: dict[str, Any] = None
class KnowledgeGraphRetrieverConfig(BaseModel):
depth: int = 2
include_meta: bool = False
with_degree: bool = False
metadata_filter: Optional[MetadataFilterConfig] = None
# Stored Knowledge Graph
class StoredKnowledgeGraphVersion(int, Enum):
V1 = 1
class StoredSubGraph(BaseModel):
query: Optional[str] = None
knowledge_base_id: Optional[int] = None
entities: Optional[list[int]] = None
relationships: Optional[list[int]] = None
class StoredKnowledgeGraph(StoredSubGraph):
"""
StoredKnowledgeGraph represents the structure of the knowledge graph stored in the database.
"""
# If not provided, it means that the old version of the storage format is used, which only
# stores entities and relationships information.
version: Optional[int] = StoredKnowledgeGraphVersion.V1
knowledge_base_ids: Optional[list[int]] = []
subgraphs: Optional[list["StoredSubGraph"]] = None
# Retrieved Knowledge Graph
class RetrievedEntity(BaseModel):
id: int = Field(description="ID of the entity")
knowledge_base_id: Optional[int] = Field(
description="ID of the knowledge base", default=None
)
entity_type: Optional[EntityType] = Field(
description="Type of the entity", default=EntityType.original
)
name: str = Field(description="Name of the entity")
description: str = Field(description="Description of the entity", default="")
meta: Optional[Mapping[str, Any]] = Field(
description="Metadata of the entity", default={}
)
@property
def global_id(self) -> str:
return f"{self.knowledge_base_id or 0}-{self.id}"
def __hash__(self):
return hash(self.global_id)
class RetrievedRelationship(BaseModel):
id: int = Field(description="ID of the relationship")
knowledge_base_id: int = Field(description="ID of the knowledge base", default=None)
source_entity_id: int = Field(description="ID of the source entity")
target_entity_id: int = Field(description="ID of the target entity")
description: str = Field(description="Description of the relationship")
meta: Optional[Mapping[str, Any]] = Field(
description="Metadata of the relationship"
)
rag_description: Optional[str] = Field(
description="RAG description of the relationship"
)
weight: Optional[float] = Field(description="Weight of the relationship")
last_modified_at: Optional[datetime.datetime] = Field(
description="Last modified at of the relationship", default=None
)
@property
def global_id(self) -> str:
return f"{self.knowledge_base_id or 0}-{self.id}"
def __hash__(self):
return hash(self.global_id)
class RetrievedSubGraph(BaseModel):
query: Optional[str | list[str]] = Field(
description="List of queries that are used to retrieve the knowledge graph",
default=None,
)
knowledge_base: Optional[KnowledgeBaseDescriptor] = Field(
description="The knowledge base that the knowledge graph is retrieved from",
default=None,
)
entities: List[RetrievedEntity] = Field(
description="List of entities in the knowledge graph", default_factory=list
)
relationships: List[RetrievedRelationship] = Field(
description="List of relationships in the knowledge graph", default_factory=list
)
class RetrievedKnowledgeGraph(RetrievedSubGraph):
"""
RetrievedKnowledgeGraph represents the structure of the knowledge graph retrieved
from the knowledge base.
"""
knowledge_bases: Optional[List[KnowledgeBaseDescriptor]] = Field(
description="List of knowledge bases that the knowledge graph is retrieved from",
default_factory=list,
)
subgraphs: Optional[List["RetrievedSubGraph"]] = Field(
description="List of subgraphs of the knowledge graph", default_factory=list
)
def to_subqueries_dict(self) -> dict:
"""
For forward compatibility, we need to convert the subgraphs to a dictionary
of subqueries and then pass it to the prompt template.
"""
subqueries = {}
for subgraph in self.subgraphs:
if subgraph.query not in subqueries:
subqueries[subgraph.query] = {
"entities": [e.model_dump() for e in subgraph.entities],
"relationships": [r.model_dump() for r in subgraph.relationships],
}
else:
subqueries[subgraph.query]["entities"].extend(
[e.model_dump() for e in subgraph.entities]
)
subqueries[subgraph.query]["relationships"].extend(
[r.model_dump() for r in subgraph.relationships]
)
return subqueries
def to_stored_graph_dict(self) -> dict:
subgraph = self.to_stored_graph()
return subgraph.model_dump()
def to_stored_graph(self) -> StoredKnowledgeGraph:
return StoredKnowledgeGraph(
query=self.query,
knowledge_base_id=self.knowledge_base.id if self.knowledge_base else None,
knowledge_base_ids=[kb.id for kb in self.knowledge_bases]
if self.knowledge_bases
else None,
entities=[e.id for e in self.entities],
relationships=[r.id for r in self.relationships],
subgraphs=[s.to_stored_graph() for s in self.subgraphs],
)
KnowledgeGraphRetrievalResult = RetrievedKnowledgeGraph
class KnowledgeGraphRetriever(ABC):
def retrieve_knowledge_graph(self, query_str: str) -> KnowledgeGraphRetrievalResult:
raise NotImplementedError
# KnowledgeGraphNode
DEFAULT_KNOWLEDGE_GRAPH_TMPL = """
Query:
------
{query}
Entities:
------
{entities_str}
Relationships:
------
{relationships_str}
"""
DEFAULT_ENTITY_TMPL = """
- Name: {{ name }}
Description: {{ description }}
"""
DEFAULT_RELATIONSHIP_TMPL = """
- Description: {{ rag_description }}
Weight: {{ weight }}
Last Modified At: {{ last_modified_at }}
Meta: {{ meta }}
"""
class KnowledgeGraphNode(BaseNode):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
query: Optional[str] = Field(description="Query of the knowledge graph")
knowledge_base_id: Optional[int] = Field(
description="The id of the knowledge base that the knowledge graph belongs to",
default=None,
)
knowledge_base_ids: Optional[List[int]] = Field(
description="List of ids of the knowledge base that the knowledge graph belongs to",
default_factory=list,
)
entities: List[RetrievedEntity] = Field(
description="The list of entities in the knowledge graph", default_factory=list
)
relationships: List[RetrievedRelationship] = Field(
description="The list of relationships in the knowledge graph",
default_factory=list,
)
children: Optional[List["KnowledgeGraphNode"]] = Field(
description="The children of the knowledge graph",
default_factory=list,
)
# Template
knowledge_base_template: str = Field(
default=DEFAULT_KNOWLEDGE_GRAPH_TMPL,
description="The template to render the knowledge graph as string",
)
entity_template: str = Field(
default=DEFAULT_ENTITY_TMPL,
description="The template to render the entity list as string",
)
relationship_template: str = Field(
default=DEFAULT_RELATIONSHIP_TMPL,
description="The template to render the relationship list as string",
)
@classmethod
def get_type(cls) -> str:
return "KnowledgeGraphNode"
def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
return f"""
Query:
------
{self.query}
Entities:
------
{self._get_entities_str()}
Relationships:
------
{self._get_relationships_str()}
"""
def _get_entities_str(self) -> str:
strs = []
for entity in self.entities:
strs.append(
self.entity_template.format(
name=entity.name, description=entity.description
)
)
return "\n\n".join(strs)
def _get_relationships_str(self) -> str:
strs = []
for relationship in self.relationships:
strs.append(
self.entity_template.format(
rag_description=relationship.rag_description,
weight=relationship.weight,
last_modified_at=relationship.last_modified_at,
meta=json.dumps(relationship.meta, indent=2, ensure_ascii=False),
)
)
return "\n\n".join(strs)
def _get_knowledge_graph_str(self) -> str:
return self.knowledge_base_template.format(
query=self.query,
entities_str=self._get_entities_str(),
relationships_str=self._get_relationships_str(),
)
def set_content(self, kg: RetrievedKnowledgeGraph):
self.query = kg.query
self.knowledge_base_id = kg.knowledge_base.id if kg.knowledge_base else None
self.knowledge_base_ids = []
self.entities = kg.entities
self.relationships = kg.relationships
self.children = [
KnowledgeGraphNode(
query=subgraph.query,
knowledge_base_id=subgraph.knowledge_base.id
if subgraph.knowledge_base
else None,
entities=subgraph.entities,
relationships=subgraph.relationships,
)
for subgraph in kg.subgraphs
]
@property
def hash(self) -> str:
kg_identity = self._get_knowledge_graph_str().encode("utf-8")
return str(sha256(kg_identity).hexdigest())
================================================
FILE: backend/app/rag/retrievers/knowledge_graph/simple_retriever.py
================================================
from typing import Optional, List
from sqlmodel import Session
from llama_index.core import QueryBundle
from llama_index.core.callbacks import CallbackManager
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from app.models.chunk import get_kb_chunk_model
from app.models.entity import get_kb_entity_model
from app.models.relationship import get_kb_relationship_model
from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrieverConfig,
KnowledgeGraphRetrievalResult,
KnowledgeGraphNode,
KnowledgeGraphRetriever,
)
from app.rag.knowledge_base.config import get_kb_embed_model, get_kb_dspy_llm
from app.rag.indices.knowledge_graph.graph_store import TiDBGraphStore
from app.repositories import knowledge_base_repo
class KnowledgeGraphSimpleRetriever(BaseRetriever, KnowledgeGraphRetriever):
def __init__(
self,
db_session: Session,
knowledge_base_id: int,
config: KnowledgeGraphRetrieverConfig,
callback_manager: Optional[CallbackManager] = CallbackManager([]),
**kwargs,
):
super().__init__(callback_manager, **kwargs)
self.config = config
self._callback_manager = callback_manager
self.knowledge_base = knowledge_base_repo.must_get(
db_session, knowledge_base_id
)
self.embed_model = get_kb_embed_model(db_session, self.knowledge_base)
self.embed_model.callback_manager = callback_manager
self.chunk_db_model = get_kb_chunk_model(self.knowledge_base)
self.entity_db_model = get_kb_entity_model(self.knowledge_base)
self.relationship_db_model = get_kb_relationship_model(self.knowledge_base)
# TODO: remove it
dspy_lm = get_kb_dspy_llm(db_session, self.knowledge_base)
self._kg_store = TiDBGraphStore(
knowledge_base=self.knowledge_base,
dspy_lm=dspy_lm,
session=db_session,
embed_model=self.embed_model,
entity_db_model=self.entity_db_model,
relationship_db_model=self.relationship_db_model,
chunk_db_model=self.chunk_db_model,
)
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
metadata_filters = {}
if self.config.metadata_filter and self.config.metadata_filter.enabled:
metadata_filters = self.config.metadata_filter.filters
entities, relationships = self._kg_store.retrieve_with_weight(
query_bundle.query_str,
embedding=[],
depth=self.config.depth,
include_meta=self.config.include_meta,
with_degree=self.config.with_degree,
relationship_meta_filters=metadata_filters,
)
return [
NodeWithScore(
node=KnowledgeGraphNode(
query=query_bundle.query_str,
knowledge_base_id=self.knowledge_base.id,
entities=entities,
relationships=relationships,
),
score=1,
)
]
def retrieve_knowledge_graph(
self, query_text: str
) -> KnowledgeGraphRetrievalResult:
nodes_with_score = self._retrieve(QueryBundle(query_text))
if len(nodes_with_score) == 0:
return KnowledgeGraphRetrievalResult()
node: KnowledgeGraphNode = nodes_with_score[0].node # type:ignore
return KnowledgeGraphRetrievalResult(
query=node.query,
knowledge_base=self.knowledge_base.to_descriptor(),
entities=node.entities,
relationships=node.relationships,
subgraphs=[],
)
================================================
FILE: backend/app/rag/retrievers/multiple_knowledge_base.py
================================================
from abc import abstractmethod
import dspy
from typing import List, Optional, Dict, Tuple
from llama_index.core import QueryBundle
from llama_index.core.async_utils import run_async_tasks
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.callbacks import CallbackManager
from llama_index.core.llms import LLM
from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel
from sqlmodel import Session
from app.core.config import settings
from app.rag.question_gen.query_decomposer import QueryDecomposer
from app.rag.types import MyCBEventType
from app.rag.llms.dspy import get_dspy_lm_by_llama_llm
class FusionRetrievalBaseConfig(BaseModel):
llm_id: Optional[int] = None
knowledge_base_ids: List[int]
use_query_decompose: Optional[bool] = None
class MultiKBFusionRetriever(BaseRetriever):
def __init__(
self,
retrievers: List[BaseRetriever],
db_session: Session,
llm: LLM,
dspy_lm: Optional[dspy.LM] = None,
use_query_decompose: bool = True,
callback_manager: Optional[CallbackManager] = CallbackManager([]),
**kwargs,
):
super().__init__(callback_manager, **kwargs)
self._use_query_decompose = use_query_decompose
self._db_session = db_session
self._callback_manager = callback_manager
# Setup query decomposer.
self._dspy_lm = dspy_lm or get_dspy_lm_by_llama_llm(llm)
self._query_decomposer = QueryDecomposer(
dspy_lm=self._dspy_lm,
complied_program_path=settings.COMPLIED_INTENT_ANALYSIS_PROGRAM_PATH,
)
# Setup multiple knowledge base selector.
self._retrievers = retrievers
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
if self._use_query_decompose:
queries = self._gen_sub_queries(query_bundle)
else:
queries = [query_bundle]
with self.callback_manager.event(
MyCBEventType.RUN_SUB_QUERIES, payload={"queries": queries}
):
tasks, task_queries = [], []
for query in queries:
for i, retriever in enumerate(self._retrievers):
tasks.append(retriever.aretrieve(query.query_str))
task_queries.append((query.query_str, i))
task_results = run_async_tasks(tasks)
results = {}
for query_tuple, query_result in zip(task_queries, task_results):
results[query_tuple] = query_result
return self._fusion(query_bundle.query_str, results)
def _gen_sub_queries(self, query_bundle: QueryBundle) -> List[QueryBundle]:
queries = self._query_decomposer.decompose(query_bundle.query_str)
return [QueryBundle(r.question) for r in queries.questions]
@abstractmethod
def _fusion(
self, query: str, results: Dict[Tuple[str, int], List[NodeWithScore]]
) -> List[NodeWithScore]:
"""fusion method"""
================================================
FILE: backend/app/rag/semantic_cache/__init__.py
================================================
from .base import SemanticCacheManager, SemanticItem
__all__ = ["SemanticCacheManager", "SemanticItem"]
================================================
FILE: backend/app/rag/semantic_cache/base.py
================================================
import time
import dspy
import logging
from typing import List, Literal, Optional
from pydantic import BaseModel, Field
from sqlmodel import Session, select, func
from llama_index.core.embeddings.utils import EmbedType, resolve_embed_model
from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingModelType
from app.models import SemanticCache
logger = logging.getLogger(__name__)
class SemanticItem(BaseModel):
"""A single question-answer pair for semantic search."""
question: str = Field(description="The question in the question-answer pair.")
answer: str = Field(description="The answer corresponding to the question.")
class SemanticCandidate(BaseModel):
"""A single question pair for semantic search."""
question: str = Field(description="The question in the question-answer pair.")
class SemanticGroup(BaseModel):
"""A collection of question-answer pairs for semantic search."""
items: List[SemanticCandidate] = Field(description="A list of questions.")
class QASemanticOutput(BaseModel):
"""The output of the semantic search operation."""
match_type: Literal["exact_match", "no_match", "similar_match"] = Field(
description=(
"The type of match found during the search. Use 'exact_match' if the query semantically matches the same "
"question, meaning it is asking about the exact same topic. "
"For all other cases, classify the match as 'no_match' or 'similar_match'."
)
)
items: List[SemanticCandidate] = Field(
description=(
"The question-answer pair that matches the query. "
"If the match_type is 'no_match', return an empty list. "
"If the match_type is 'similar_match', return the most relevant questions."
"If the match_type is 'exact_match', return the question that is semantically identical to the query."
)
)
class QASemanticSearchModule(dspy.Signature):
"""
This module performs a semantic search to identify the best matching question-answer pairs from a given set of candidates.
The semantic search process includes:
- Comparing the query against a set of candidate question-answer pairs.
- Returning an 'exact_match' if the query semantically matches a candidate question, meaning it is asking the exact same question.
- Returning a 'similar_match' if the query is related to a candidate question but does not semantically match exactly.
- Returning a 'no_match' if the query does not match any candidate question.
Note: The output items must be selected from the provided candidates.
"""
query: str = dspy.InputField(
description="The query string to search for within the candidates."
)
candidats: SemanticGroup = dspy.InputField(
description="A collection of frequently asked questions to search through."
)
output: QASemanticOutput = dspy.OutputField(
description="The question that best matches the query string. "
)
class SemanticSearchProgram(dspy.Module):
def __init__(self, dspy_lm: dspy.LM):
super().__init__()
self.dspy_lm = dspy_lm
self.prog = dspy.TypedChainOfThought(QASemanticSearchModule)
def forward(self, query: str, candidats: SemanticGroup):
with dspy.settings.context(lm=self.dspy_lm):
return self.prog(query=query, candidats=candidats)
class SemanticCacheManager:
def __init__(
self,
dspy_llm: dspy.LM,
embed_model: Optional[EmbedType] = None,
complied_sc_search_program_path: Optional[str] = None,
):
self._dspy_lm = dspy_llm
if embed_model:
self._embed_model = resolve_embed_model(embed_model)
else:
self._embed_model = OpenAIEmbedding(
model=OpenAIEmbeddingModelType.TEXT_EMBED_3_SMALL
)
self.prog = SemanticSearchProgram(dspy_lm=dspy_llm)
if complied_sc_search_program_path is not None:
self.prog.load(complied_sc_search_program_path)
def get_query_embedding(self, query: str):
return self._embed_model.get_query_embedding(query)
def add_cache(
self,
session: Session,
item: SemanticItem,
namespace: str,
metadata: Optional[dict] = None,
):
if metadata is None:
metadata = {}
metadata["namespace"] = namespace
object = SemanticCache(
query=item.question,
query_vec=self.get_query_embedding(item.question),
value=item.answer,
value_vec=self.get_query_embedding(item.answer),
meta=metadata,
)
session.add(object)
session.commit()
def search(
self, session: Session, query: str, namespace: Optional[str] = None
) -> QASemanticOutput:
start_time = time.time()
embedding = self.get_query_embedding(query)
logger.debug(
f"[search_semantic_cache] Get query embedding {time.time() - start_time:.2f} seconds"
)
start_time = time.time()
sql = (
select(
SemanticCache,
SemanticCache.query_vec.cosine_distance(embedding).label("distance"),
)
.having(SemanticCache.query_vec.cosine_distance(embedding) < 0.5)
.order_by("distance")
.limit(20)
)
if namespace:
sql = sql.where(
func.json_extract(SemanticCache.meta, "$.namespace") == namespace
)
results = session.execute(sql).all()
candidates = SemanticGroup(
items=[
SemanticCandidate(
question=result.SemanticCache.query,
)
for result in results
]
)
logger.debug(
f"[search_semantic_cache] Search semantic cache {time.time() - start_time:.2f} seconds"
)
start_time = time.time()
if len(candidates.items) == 0:
return {
"match_type": "no_match",
"items": [],
}
pred = self.prog(query=query, candidats=candidates)
logger.debug(
f"[search_semantic_cache] Predict semantic cache {time.time() - start_time:.2f} seconds"
)
logger.debug(f"[search_semantic_cache] Predict semantic cache {pred.output}")
# filter the matched items and it's metadata
matched_items = []
for item in pred.output.items:
question = item.question
# find the matched item in the results
for result in results:
if result.SemanticCache.query == question:
matched_items.append(
{
"question": result.SemanticCache.query,
"answer": result.SemanticCache.value,
"meta": result.SemanticCache.meta,
}
)
break
return {"match_type": pred.output.match_type, "items": matched_items}
================================================
FILE: backend/app/rag/types.py
================================================
import enum
# Langfuse needs an enum class for event types,
# but the CBEventType in llama-index does not have sufficient types.
class MyCBEventType(str, enum.Enum):
CHUNKING = "chunking"
NODE_PARSING = "node_parsing"
EMBEDDING = "embedding"
LLM = "llm"
QUERY = "query"
RETRIEVE = "retrieve"
SYNTHESIZE = "synthesize"
TREE = "tree"
SUB_QUESTION = "sub_question"
TEMPLATING = "templating"
FUNCTION_CALL = "function_call"
RERANKING = "reranking"
EXCEPTION = "exception"
AGENT_STEP = "agent_step"
CLARIFYING_QUESTION = "clarifying_question"
CONDENSE_QUESTION = "condense_question"
REFINE_QUESTION = "refine_question"
RETRIEVE_FROM_GRAPH = "retrieve_from_graph"
INTENT_DECOMPOSITION = "intent_decomposition"
GRAPH_SEMANTIC_SEARCH = "graph_semantic_search"
SELECT_KNOWLEDGE_BASE = "select_knowledge_base"
RUN_SUB_QUERIES = "run_sub_queries"
# Chat stream response event types
class ChatEventType(int, enum.Enum):
# Following vercel ai sdk's event type
# https://github.com/vercel/ai/blob/84871281ab5a2c080e3f8e18d02cd09c7e1691c4/packages/ui-utils/src/stream-parts.ts#L368
TEXT_PART = 0
DATA_PART = 2
ERROR_PART = 3
MESSAGE_ANNOTATIONS_PART = 8
class ChatMessageSate(int, enum.Enum):
TRACE = 0
SOURCE_NODES = 1
KG_RETRIEVAL = 2
REFINE_QUESTION = 3
SEARCH_RELATED_DOCUMENTS = 4
GENERATE_ANSWER = 5
FINISHED = 9
================================================
FILE: backend/app/rag/utils.py
================================================
import re
from typing import Tuple, Dict
def _parse_response_format(response_format_str: str) -> Dict[str, str]:
"""
Parses the requirements string into a dictionary.
Args:
req_str (str): The requirements string.
Returns:
Dict[str, str]: A dictionary of parsed requirements.
"""
requirements = {}
parts = re.split(r",\s*(?=\w[\w\s]*:\s*[^,()]+)", response_format_str)
for part in parts:
if ":" in part:
key, value = part.split(":", 1)
requirements[key.strip()] = value.strip()
else:
requirements[part.strip()] = None
return requirements
def parse_goal_response_format(goal: str) -> Tuple[str, Dict[str, str]]:
"""
Extracts the main goal and its requirements from the input string.
Args:
goal (str): The input question string with optional requirements.
Returns:
Tuple[str, Dict[str, str]]: A tuple containing the main goal and a dictionary of requirements.
"""
# Initialize
clean_goal = goal.strip()
response_format = None
# Remove starting quote if present
if clean_goal.startswith('"'):
clean_goal = clean_goal[1:].strip()
# Remove ending quote if present
if clean_goal.endswith('"'):
clean_goal = clean_goal[:-1].strip()
# Function to find the last balanced parentheses by reverse traversal
def extract_last_parentheses(s: str) -> Tuple[str, str]:
"""
Extracts the last balanced parentheses content from the string by traversing from the end.
Args:
s (str): The input string.
Returns:
Tuple[str, str]: A tuple containing the string without the last parentheses
and the content within the last parentheses.
"""
stack = []
last_close = s.rfind(")")
if last_close == -1:
return s, "" # No closing parenthesis found
for i in range(last_close, -1, -1):
if s[i] == ")":
stack.append(i)
elif s[i] == "(":
if stack:
stack.pop()
if not stack:
# Found the matching opening parenthesis
return s[:i].strip(), s[i + 1 : last_close].strip()
return s, "" # No matching opening parenthesis found
# Extract the last parentheses content
clean_goal, req_str = extract_last_parentheses(clean_goal)
if req_str:
response_format = _parse_response_format(req_str)
else:
response_format = {}
return clean_goal, response_format
================================================
FILE: backend/app/repositories/__init__.py
================================================
# flake8: noqa
from .staff_action_log import staff_action_repo
from .chat_engine import chat_engine_repo
from .chat import chat_repo
from .document import document_repo
from .chunk import ChunkRepo
from .data_source import data_source_repo
from .knowledge_base import knowledge_base_repo
from .feedback import feedback_repo
from .llm import llm_repo
from .embedding_model import embedding_model_repo
================================================
FILE: backend/app/repositories/base_repo.py
================================================
from sqlmodel import Session, SQLModel, select
class BaseRepo:
model_cls: SQLModel
def get(self, session: Session, id: int):
return session.get(self.model_cls, id)
def get_all(self, session: Session):
return session.exec(select(self.model_cls)).all()
def create(self, session: Session, obj: SQLModel):
session.add(obj)
session.commit()
session.refresh(obj)
return obj
================================================
FILE: backend/app/repositories/chat.py
================================================
import enum
from uuid import UUID
from typing import Optional, List, Dict, Any
from datetime import datetime, UTC, date, timedelta
from collections import defaultdict
from sqlmodel import select, Session, or_, func, case, desc, col
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from app.models import Chat, User, ChatMessage, ChatUpdate, ChatFilters, ChatOrigin
from app.repositories.base_repo import BaseRepo
from app.exceptions import ChatNotFound, ChatMessageNotFound
class ChatRepo(BaseRepo):
model_cls = Chat
def paginate(
self,
session: Session,
user: User | None,
browser_id: str | None,
filters: ChatFilters,
params: Params | None = Params(),
) -> Page[Chat]:
query = select(Chat).where(Chat.deleted_at == None)
if user:
if not user.is_superuser:
query = query.where(
or_(Chat.user_id == user.id, Chat.browser_id == browser_id)
)
else:
query = query.where(Chat.browser_id == browser_id, Chat.user_id == None)
# filters
if filters.created_at_start:
query = query.where(Chat.created_at >= filters.created_at_start)
if filters.created_at_end:
query = query.where(Chat.created_at <= filters.created_at_end)
if filters.updated_at_start:
query = query.where(Chat.updated_at >= filters.updated_at_start)
if filters.updated_at_end:
query = query.where(Chat.updated_at <= filters.updated_at_end)
if filters.chat_origin:
query = query.where(col(Chat.origin).contains(filters.chat_origin))
# if filters.user_id:
# query = query.where(Chat.user_id == filters.user_id)
if filters.engine_id:
query = query.where(Chat.engine_id == filters.engine_id)
query = query.order_by(Chat.created_at.desc())
return paginate(session, query, params)
def get(
self,
session: Session,
chat_id: UUID,
) -> Optional[Chat]:
return session.exec(
select(Chat).where(Chat.id == chat_id, Chat.deleted_at == None)
).first()
def must_get(
self,
session: Session,
chat_id: UUID,
) -> Chat:
chat = self.get(session, chat_id)
if not chat:
raise ChatNotFound(chat_id)
return chat
def update(
self,
session: Session,
chat: Chat,
chat_update: ChatUpdate,
) -> Chat:
for field, value in chat_update.model_dump(exclude_unset=True).items():
if isinstance(value, enum.Enum):
value = value.value
setattr(chat, field, value)
session.commit()
session.refresh(chat)
return chat
def delete(self, session: Session, chat: Chat):
chat.deleted_at = datetime.now(UTC)
session.add(chat)
session.commit()
def get_last_message(self, session: Session, chat: Chat) -> Optional[ChatMessage]:
return session.exec(
select(ChatMessage)
.where(ChatMessage.chat_id == chat.id)
.order_by(ChatMessage.ordinal.desc())
).first()
def get_messages(
self,
session: Session,
chat: Chat,
) -> List[ChatMessage]:
return session.exec(
select(ChatMessage)
.where(ChatMessage.chat_id == chat.id)
.order_by(ChatMessage.ordinal.asc())
).all()
def get_message(
self,
session: Session,
chat_message_id: int,
) -> Optional[ChatMessage]:
return session.exec(
select(ChatMessage).where(
ChatMessage.id == chat_message_id,
ChatMessage.chat.has(Chat.deleted_at == None),
)
).first()
def must_get_message(
self,
session: Session,
chat_message_id: int,
):
msg = self.get_message(session, chat_message_id)
if not msg:
raise ChatMessageNotFound(chat_message_id)
return msg
def create_message(
self,
session: Session,
chat: Chat,
chat_message: ChatMessage,
) -> ChatMessage:
if not chat_message.ordinal:
last_message = self.get_last_message(session, chat)
if last_message:
ordinal = last_message.ordinal + 1
else:
ordinal = 1
chat_message.ordinal = ordinal
chat_message.chat_id = chat.id
chat_message.user_id = chat.user_id
session.add(chat_message)
session.commit()
session.refresh(chat_message)
return chat_message
def find_recent_assistant_messages_by_goal(
self, session: Session, metadata: Dict[str, Any], days: int = 15
) -> List[ChatMessage]:
"""
Search for 'assistant' role chat messages with a specific goal within the recent days.
Args:
session (Session): The database session.
goal (str): The goal value to match in meta.goal.
days (int, optional): Number of recent days to include in the search. Defaults to 2.
Returns:
List[ChatMessage]: A list of ChatMessage instances that match the criteria.
"""
# Calculate the cutoff datetime based on the current UTC time minus the specified number of days
cutoff = datetime.now(UTC) - timedelta(days=days)
query = select(ChatMessage).where(
ChatMessage.role == "assistant",
ChatMessage.created_at >= cutoff,
ChatMessage.is_best_answer.is_(True), # Use is_ for boolean fields
)
# Dynamically add filters for each key-value pair in metadata
for key, value in metadata.items():
json_path = f"$.{key}"
filter_condition = (
func.JSON_UNQUOTE(func.JSON_EXTRACT(ChatMessage.meta, json_path))
== value
)
query = query.where(filter_condition)
# Order by created_at in descending order
query = query.order_by(desc(ChatMessage.created_at))
return session.exec(query).all()
def find_best_answer_for_question(
self, session: Session, user_question: str
) -> List[ChatMessage]:
"""Find best answer messages for a specific user question.
This method finds assistant messages that:
1. Are marked as best answers
2. Are responses (ordinal=2) to the exact user question
3. Were created within the last 15 days
Args:
session: Database session
user_question: The exact question text to search for
Returns:
List of matching assistant messages marked as best answers
"""
cutoff = datetime.now(UTC) - timedelta(days=15)
# First, get all best answers from assistant (using the is_best_answer index)
best_answer_chat_ids = select(ChatMessage.chat_id).where(
ChatMessage.is_best_answer == 1, # Using the index for efficiency
ChatMessage.role == "assistant",
ChatMessage.ordinal == 2,
ChatMessage.created_at >= cutoff,
)
# Then, find user questions that match our target question and belong to chats with best answers
matching_chat_ids = select(ChatMessage.chat_id).where(
ChatMessage.chat_id.in_(best_answer_chat_ids),
ChatMessage.role == "user",
ChatMessage.ordinal == 1,
ChatMessage.content == user_question.strip(),
)
# Finally, get the best answers that correspond to the matching user questions
query = select(ChatMessage).where(
ChatMessage.is_best_answer == 1,
ChatMessage.role == "assistant",
ChatMessage.ordinal == 2,
ChatMessage.chat_id.in_(matching_chat_ids),
)
query = query.order_by(desc(ChatMessage.created_at))
# Execute the query and return all results
return session.exec(query).all()
def chat_trend_by_user(
self, session: Session, start_date: date, end_date: date
) -> List[dict]:
start_at = datetime.combine(start_date, datetime.min.time(), UTC)
end_at = datetime.combine(end_date, datetime.max.time(), UTC)
query = (
select(
func.date(Chat.created_at).label("date"),
func.sum(case((Chat.user_id.isnot(None), 1), else_=0)).label("user"),
func.sum(case((Chat.user_id.is_(None), 1), else_=0)).label("anonymous"),
)
.where(Chat.created_at.between(start_at, end_at))
.group_by(func.date(Chat.created_at))
.order_by(func.date(Chat.created_at))
)
result = session.exec(query)
return [
{"date": row.date, "user": int(row.user), "anonymous": int(row.anonymous)}
for row in result
]
def chat_trend_by_origin(
self, session: Session, start_date: date, end_date: date
) -> List[dict]:
start_at = datetime.combine(start_date, datetime.min.time(), UTC)
end_at = datetime.combine(end_date, datetime.max.time(), UTC)
query = (
select(
func.count(Chat.id).label("count"),
func.date(Chat.created_at).label("date"),
Chat.origin,
)
.where(Chat.created_at.between(start_at, end_at))
.group_by(func.date(Chat.created_at), Chat.origin)
.order_by(func.date(Chat.created_at))
)
result = session.exec(query)
date_origin_counts = defaultdict(lambda: defaultdict(int))
origins = set()
for row in result:
date_origin_counts[row.date][row.origin] = row.count
origins.add(row.origin)
stats = []
for d, origin_counts in date_origin_counts.items():
stat = {"date": d}
for origin in origins:
stat[origin] = origin_counts[origin]
stats.append(stat)
stats.sort(key=lambda x: x["date"])
return stats
def list_chat_origins(
self,
db_session: Session,
search: Optional[str] = None,
params: Params = Params(),
) -> Page[ChatOrigin]:
query = (
select(Chat.origin, func.count(Chat.id).label("chats"))
.where(Chat.deleted_at == None)
.where(Chat.origin != None)
.where(Chat.origin != "")
)
if search:
query = query.where(Chat.origin.ilike(f"%{search}%"))
query = query.group_by(Chat.origin).order_by(desc("chats"))
return paginate(
db_session,
query,
params,
transformer=lambda chats: [
ChatOrigin(origin=chat.origin, chats=chat.chats) for chat in chats
],
)
chat_repo = ChatRepo()
================================================
FILE: backend/app/repositories/chat_engine.py
================================================
from typing import Optional
from datetime import datetime, UTC
from sqlalchemy import func
from sqlmodel import select, Session, update
from app.exceptions import ChatEngineNotFound
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlalchemy.orm.attributes import flag_modified
from app.models.chat_engine import ChatEngine, ChatEngineUpdate
from app.repositories.base_repo import BaseRepo
class ChatEngineRepo(BaseRepo):
model_cls = ChatEngine
def get(
self, session: Session, id: int, need_public: bool = False
) -> Optional[ChatEngine]:
query = select(ChatEngine).where(
ChatEngine.id == id, ChatEngine.deleted_at == None
)
if need_public:
query = query.where(ChatEngine.is_public == True)
return session.exec(query).first()
def must_get(
self, session: Session, id: int, need_public: bool = False
) -> ChatEngine:
chat_engine = self.get(session, id, need_public)
if chat_engine is None:
raise ChatEngineNotFound(id)
return chat_engine
def paginate(
self,
session: Session,
params: Params | None = Params(),
need_public: bool = False,
) -> Page[ChatEngine]:
query = select(ChatEngine).where(ChatEngine.deleted_at == None)
if need_public:
query = query.where(ChatEngine.is_public == True)
# Make sure the default engine is always on top
query = query.order_by(ChatEngine.is_default.desc(), ChatEngine.name)
return paginate(session, query, params)
def get_default_engine(self, session: Session) -> Optional[ChatEngine]:
return session.exec(
select(ChatEngine).where(
ChatEngine.is_default == True, ChatEngine.deleted_at == None
)
).first()
def has_default(self, session: Session) -> bool:
return (
session.scalar(
select(func.count(ChatEngine.id)).where(
ChatEngine.is_default == True, ChatEngine.deleted_at == None
)
)
> 0
)
def get_engine_by_name(self, session: Session, name: str) -> Optional[ChatEngine]:
return session.exec(
select(ChatEngine).where(
ChatEngine.name == name, ChatEngine.deleted_at == None
)
).first()
def create(self, session: Session, obj: ChatEngine):
if obj.is_default:
session.exec(
update(ChatEngine)
.where(ChatEngine.id != obj.id)
.values(is_default=False)
)
session.add(obj)
session.commit()
session.refresh(obj)
return obj
def update(
self,
session: Session,
chat_engine: ChatEngine,
chat_engine_update: ChatEngineUpdate,
) -> ChatEngine:
set_default = chat_engine_update.is_default
for field, value in chat_engine_update.model_dump(exclude_unset=True).items():
setattr(chat_engine, field, value)
flag_modified(chat_engine, field)
if set_default:
session.exec(
update(ChatEngine)
.where(ChatEngine.id != chat_engine.id)
.values(is_default=False)
)
session.commit()
session.refresh(chat_engine)
return chat_engine
def delete(self, session: Session, chat_engine: ChatEngine) -> ChatEngine:
chat_engine.deleted_at = datetime.now(UTC)
session.commit()
session.refresh(chat_engine)
return chat_engine
chat_engine_repo = ChatEngineRepo()
================================================
FILE: backend/app/repositories/chunk.py
================================================
from typing import Type
from sqlalchemy import func, delete
from sqlmodel import Session, select, SQLModel
from app.repositories.base_repo import BaseRepo
from app.models import (
Document as DBDocument,
)
class ChunkRepo(BaseRepo):
def __init__(self, chunk_model: Type[SQLModel]):
self.model_cls = chunk_model
def document_exists_chunks(self, session: Session, document_id: int) -> bool:
return (
session.exec(
select(self.model_cls).where(self.model_cls.document_id == document_id)
).first()
is not None
)
def get_documents_by_chunk_ids(
self, session: Session, chunk_ids: list[str]
) -> list[DBDocument]:
stmt = select(DBDocument).where(
DBDocument.id.in_(
select(self.model_cls.document_id).where(
self.model_cls.id.in_(chunk_ids),
)
),
)
return list(session.exec(stmt).all())
def get_document_chunks(self, session: Session, document_id: int):
return session.exec(
select(self.model_cls).where(self.model_cls.document_id == document_id)
).all()
def fetch_by_document_ids(self, session: Session, document_ids: list[int]):
return session.exec(
select(self.model_cls).where(self.model_cls.document_id.in_(document_ids))
).all()
def count(self, session: Session):
return session.scalar(select(func.count(self.model_cls.id)))
def delete_by_datasource(self, session: Session, datasource_id: int):
doc_ids_subquery = select(DBDocument.id).where(
DBDocument.data_source_id == datasource_id
)
stmt = delete(self.model_cls).where(
self.model_cls.document_id.in_(doc_ids_subquery)
)
session.exec(stmt)
def delete_by_document(self, session: Session, document_id: int):
stmt = delete(self.model_cls).where(self.model_cls.document_id == document_id)
session.exec(stmt)
================================================
FILE: backend/app/repositories/data_source.py
================================================
from typing import Optional
from datetime import datetime, UTC
from sqlmodel import select, Session
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from app.models import (
DataSource,
)
from app.repositories.base_repo import BaseRepo
class DataSourceRepo(BaseRepo):
model_cls = DataSource
def paginate(
self,
session: Session,
params: Params | None = Params(),
) -> Page[DataSource]:
query = (
select(DataSource)
.where(DataSource.deleted_at == None)
.order_by(DataSource.created_at.desc())
)
return paginate(session, query, params)
def get(
self,
session: Session,
data_source_id: int,
) -> Optional[DataSource]:
return session.exec(
select(DataSource).where(
DataSource.id == data_source_id, DataSource.deleted_at == None
)
).first()
def delete(self, session: Session, data_source: DataSource) -> None:
data_source.deleted_at = datetime.now(UTC)
session.add(data_source)
session.commit()
data_source_repo = DataSourceRepo()
================================================
FILE: backend/app/repositories/document.py
================================================
from typing import Type
from sqlmodel import select, Session, or_, delete
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from app.api.admin_routes.knowledge_base.document.models import DocumentFilters
from app.exceptions import DocumentNotFound
from app.models import Document
from app.repositories.base_repo import BaseRepo
class DocumentRepo(BaseRepo):
model_cls = Document
def paginate(
self,
session: Session,
filters: DocumentFilters,
params: Params | None = Params(),
) -> Page[Document]:
# build the select statement via conditions
stmt = select(Document)
if filters.knowledge_base_id:
stmt = stmt.where(Document.knowledge_base_id == filters.knowledge_base_id)
if filters.search:
stmt = stmt.where(
or_(
Document.name.contains(filters.search),
Document.source_uri.contains(filters.search),
)
)
if filters.data_source_id:
stmt = stmt.where(Document.data_source_id == filters.data_source_id)
if filters.created_at:
start_time, end_time = filters.created_at
if start_time:
stmt = stmt.where(Document.created_at >= start_time)
if end_time:
stmt = stmt.where(Document.created_at <= end_time)
if filters.updated_at:
start_time, end_time = filters.updated_at
if start_time:
stmt = stmt.where(Document.updated_at >= start_time)
if end_time:
stmt = stmt.where(Document.updated_at <= end_time)
if filters.last_modified_at:
start_time, end_time = filters.last_modified_at
if start_time:
stmt = stmt.where(Document.last_modified_at >= start_time)
if end_time:
stmt = stmt.where(Document.last_modified_at <= end_time)
if filters.mime_type:
stmt = stmt.where(Document.mime_type == filters.mime_type)
if filters.index_status:
stmt = stmt.where(Document.index_status == filters.index_status)
# Make sure the newer edited record is always on top
stmt = stmt.order_by(Document.updated_at.desc())
return paginate(session, stmt, params)
def must_get(self, session: Session, doc_id: int) -> Type[Document]:
doc = session.get(Document, doc_id)
if not doc:
raise DocumentNotFound(doc_id)
return doc
def delete_by_datasource(self, session: Session, datasource_id: int):
stmt = delete(Document).where(Document.data_source_id == datasource_id)
session.exec(stmt)
def fetch_by_ids(self, session: Session, document_ids: list[int]) -> list[Document]:
stmt = select(Document).where(Document.id.in_(document_ids))
return session.exec(stmt).all()
document_repo = DocumentRepo()
================================================
FILE: backend/app/repositories/embedding_model.py
================================================
from typing import Optional, Type
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, select, update
from app.api.admin_routes.embedding_model.models import (
EmbeddingModelUpdate,
EmbeddingModelCreate,
)
from app.exceptions import DefaultEmbeddingModelNotFound, EmbeddingModelNotFound
from app.models import EmbeddingModel
from app.models.knowledge_base import KnowledgeBase
from app.repositories.base_repo import BaseRepo
class EmbeddingModelRepo(BaseRepo):
model_cls = EmbeddingModel
def paginate(
self, session: Session, params: Params | None = Params()
) -> Page[EmbeddingModel]:
query = select(EmbeddingModel)
# Make sure the default model is always on top.
query = query.order_by(
EmbeddingModel.is_default.desc(), EmbeddingModel.created_at.desc()
)
return paginate(session, query, params)
def get(self, session: Session, model_id: int) -> Optional[EmbeddingModel]:
return session.get(EmbeddingModel, model_id)
def must_get(self, session: Session, model_id: int) -> Type[EmbeddingModel]:
db_embed_model = self.get(session, model_id)
if db_embed_model is None:
raise EmbeddingModelNotFound(model_id)
return db_embed_model
def exists_any_model(self, session: Session) -> bool:
stmt = select(EmbeddingModel).with_for_update().limit(1)
return session.exec(stmt).one_or_none() is not None
def create(self, session: Session, create: EmbeddingModelCreate):
# If there is currently no model, the first model will be
# set as the default model.
if not self.exists_any_model(session):
create.is_default = True
if create.is_default:
self._unset_default(session)
embed_model = EmbeddingModel(
name=create.name,
provider=create.provider,
model=create.model,
vector_dimension=create.vector_dimension,
config=create.config,
credentials=create.credentials,
is_default=create.is_default,
)
session.add(embed_model)
session.commit()
session.refresh(embed_model)
return embed_model
def update(
self,
session: Session,
embed_model: EmbeddingModel,
partial_update: EmbeddingModelUpdate,
) -> EmbeddingModel:
for field, value in partial_update.model_dump(exclude_unset=True).items():
setattr(embed_model, field, value)
flag_modified(embed_model, field)
session.commit()
session.refresh(embed_model)
return embed_model
def delete(self, session: Session, model: EmbeddingModel):
# TODO: Support to specify a new embedding model to replace the current embedding model.
session.exec(
update(KnowledgeBase)
.where(KnowledgeBase.embedding_model_id == model.id)
.values(embedding_model_id=None)
)
session.delete(model)
session.commit()
# Default model
def get_default(self, session: Session) -> Type[EmbeddingModel]:
stmt = select(EmbeddingModel).where(EmbeddingModel.is_default == True).limit(1)
return session.exec(stmt).first()
def has_default(self, session: Session) -> bool:
return self.get_default(session) is not None
def must_get_default(self, session: Session) -> Type[EmbeddingModel]:
embed_model = self.get_default(session)
if embed_model is None:
raise DefaultEmbeddingModelNotFound()
return embed_model
def _unset_default(self, session: Session):
session.exec(
update(EmbeddingModel)
.values(is_default=False)
.where(EmbeddingModel.is_default == True)
)
def set_default(self, session: Session, model: EmbeddingModel):
self._unset_default(session)
model.is_default = True
flag_modified(model, "is_default")
session.commit()
session.refresh(model)
return model
embedding_model_repo = EmbeddingModelRepo()
================================================
FILE: backend/app/repositories/feedback.py
================================================
from sqlmodel import select, Session, col, func, desc
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from typing import Optional
from app.models import Feedback, AdminFeedbackPublic, FeedbackFilters
from app.models.feedback import FeedbackOrigin
from app.repositories.base_repo import BaseRepo
class FeedbackRepo(BaseRepo):
model_cls = Feedback
def paginate(
self,
session: Session,
filters: FeedbackFilters,
params: Params | None = Params(),
) -> Page[AdminFeedbackPublic]:
# build the select statement via conditions
stmt = select(Feedback)
if filters.created_at_start:
stmt = stmt.where(Feedback.created_at >= filters.created_at_start)
if filters.created_at_end:
stmt = stmt.where(Feedback.created_at <= filters.created_at_end)
if filters.feedback_origin:
stmt = stmt.where(col(Feedback.origin).contains(filters.feedback_origin))
if filters.chat_id:
stmt = stmt.where(Feedback.chat_id == filters.chat_id)
if filters.feedback_type:
stmt = stmt.where(Feedback.feedback_type == filters.feedback_type)
if filters.user_id:
stmt = stmt.where(Feedback.user_id == filters.user_id)
stmt = stmt.order_by(Feedback.created_at.desc())
return paginate(
session,
stmt,
params,
transformer=lambda items: [
AdminFeedbackPublic(
**item.model_dump(),
chat_title=item.chat.title,
chat_origin=item.chat.origin,
chat_message_content=item.chat_message.content,
user_email=item.user.email if item.user else None,
)
for item in items
],
)
def list_feedback_origins(
self,
session: Session,
search: Optional[str] = None,
params: Params | None = Params(),
) -> Page[FeedbackOrigin]:
query = select(
Feedback.origin, func.count(Feedback.id).label("feedbacks")
).group_by(Feedback.origin)
if search:
query = query.where(Feedback.origin.ilike(f"%{search}%"))
query = query.order_by(desc("feedbacks"))
return paginate(
session,
query,
params,
transformer=lambda items: [
FeedbackOrigin(origin=item[0], feedbacks=item[1]) for item in items
],
)
feedback_repo = FeedbackRepo()
================================================
FILE: backend/app/repositories/graph.py
================================================
from typing import Type
from sqlmodel import Session, select, func, delete, SQLModel
from app.models.document import Document
from app.models.knowledge_base import KnowledgeBase
from app.models.chunk import get_kb_chunk_model
from app.models.entity import get_kb_entity_model
from app.models.relationship import get_kb_relationship_model
class GraphRepo:
def __init__(
self,
entity_model: Type[SQLModel],
relationship_model: Type[SQLModel],
chunk_model: Type[SQLModel],
):
self.entity_model = entity_model
self.relationship_model = relationship_model
self.chunk_model = chunk_model
def count_entities(self, session: Session):
return session.scalar(select(func.count(self.entity_model.id)))
def count_relationships(self, session: Session):
return session.scalar(select(func.count(self.relationship_model.id)))
def delete_orphaned_entities(self, session: Session):
orphaned_entity_ids = (
select(self.entity_model.id)
.outerjoin(
self.relationship_model,
(self.relationship_model.target_entity_id == self.entity_model.id)
| (self.relationship_model.source_entity_id == self.entity_model.id),
)
.where(self.relationship_model.id.is_(None))
.scalar_subquery()
)
stmt = delete(self.entity_model).where(
self.entity_model.id.in_(orphaned_entity_ids)
)
session.exec(stmt)
def delete_data_source_relationships(self, session: Session, datasource_id: int):
doc_ids_subquery = select(Document.id).where(
Document.data_source_id == datasource_id
)
chunk_ids_subquery = select(self.chunk_model.id).where(
self.chunk_model.document_id.in_(doc_ids_subquery)
)
stmt = delete(self.relationship_model).where(
self.relationship_model.chunk_id.in_(chunk_ids_subquery)
)
session.exec(stmt)
def delete_document_relationships(self, session: Session, document_id: int):
chunk_ids_subquery = select(self.chunk_model.id).where(
self.chunk_model.document_id == document_id
)
stmt = delete(self.relationship_model).where(
self.relationship_model.chunk_id.in_(chunk_ids_subquery)
)
session.exec(stmt)
def get_kb_graph_repo(kb: KnowledgeBase) -> GraphRepo:
chunk_model = get_kb_chunk_model(kb)
entity_model = get_kb_entity_model(kb)
relationship_model = get_kb_relationship_model(kb)
return GraphRepo(entity_model, relationship_model, chunk_model)
================================================
FILE: backend/app/repositories/knowledge_base.py
================================================
from typing import List, Optional, Type
from datetime import datetime, UTC
from sqlalchemy import delete
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import SQLModel, select, Session, func, update
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from app.api.admin_routes.knowledge_base.models import (
VectorIndexError,
KGIndexError,
KnowledgeBaseUpdate,
)
from app.exceptions import KBDataSourceNotFound, KBNotFound
from app.models import (
KnowledgeBase,
Document,
DocIndexTaskStatus,
KgIndexStatus,
KnowledgeBaseDataSource,
)
from app.models.chat_engine import ChatEngine
from app.models.chunk import get_kb_chunk_model
from app.models.data_source import DataSource
from app.models.knowledge_base import IndexMethod
from app.repositories.base_repo import BaseRepo
from app.repositories.chunk import ChunkRepo
from app.repositories.graph import get_kb_graph_repo
class KnowledgeBaseRepo(BaseRepo):
model_cls = KnowledgeBase
def paginate(
self, session: Session, params: Params | None = Params()
) -> Page[KnowledgeBase]:
query = (
select(KnowledgeBase)
.where(KnowledgeBase.deleted_at == None)
.order_by(KnowledgeBase.created_at.desc())
)
return paginate(session, query, params)
def get(
self, session: Session, knowledge_base_id: int, show_soft_deleted: bool = True
) -> Optional[KnowledgeBase]:
stmt = select(KnowledgeBase).where(KnowledgeBase.id == knowledge_base_id)
if not show_soft_deleted:
stmt = stmt.where(KnowledgeBase.deleted_at == None)
return session.exec(stmt).first()
def must_get(
self, session: Session, knowledge_base_id: int, show_soft_deleted: bool = True
) -> Optional[KnowledgeBase]:
kb = self.get(session, knowledge_base_id, show_soft_deleted)
if kb is None:
raise KBNotFound(knowledge_base_id)
return kb
def get_by_ids(
self, session: Session, knowledge_base_ids: List[int]
) -> List[KnowledgeBase]:
return session.exec(
select(KnowledgeBase).where(KnowledgeBase.id.in_(knowledge_base_ids))
).all()
def update(
self,
session: Session,
knowledge_base: KnowledgeBase,
partial_update: KnowledgeBaseUpdate,
) -> KnowledgeBase:
for field, value in partial_update.model_dump(exclude_unset=True).items():
setattr(knowledge_base, field, value)
flag_modified(knowledge_base, field)
session.commit()
session.refresh(knowledge_base)
return knowledge_base
def delete(self, session: Session, knowledge_base: KnowledgeBase) -> None:
knowledge_base.deleted_at = datetime.now(UTC)
session.add(knowledge_base)
session.commit()
def get_index_overview(self, session: Session, kb: KnowledgeBase) -> dict:
# TODO: store and query the count numbers in the knowledge base table.
documents_total = self.count_documents(session, kb)
chunks_total = self.count_chunks(session, kb)
overview_data = {
"documents": {"total": documents_total},
"chunks": {"total": chunks_total},
}
if IndexMethod.VECTOR in kb.index_methods:
overview_data.update(
self.count_documents_by_vector_index_status(session, kb)
)
if IndexMethod.KNOWLEDGE_GRAPH in kb.index_methods:
relationships_total = self.count_relationships(session, kb)
entities_total = self.count_entities(session, kb)
overview_data.update(
{
"entities": {"total": entities_total},
"relationships": {"total": relationships_total},
}
)
overview_data.update(self.count_chunks_by_kg_index_status(session, kb))
return overview_data
def count_data_sources(self, session: Session, kb: KnowledgeBase) -> int:
return session.scalar(
select(func.count(KnowledgeBaseDataSource.data_source_id)).where(
KnowledgeBaseDataSource.knowledge_base_id == kb.id
)
)
def count_documents(self, session: Session, kb: KnowledgeBase) -> int:
return session.scalar(
select(func.count(Document.id)).where(Document.knowledge_base_id == kb.id)
)
def count_chunks(self, session: Session, kb: KnowledgeBase):
chunk_repo = ChunkRepo(get_kb_chunk_model(kb))
return chunk_repo.count(session)
def count_relationships(self, session: Session, kb: KnowledgeBase):
graph_repo = get_kb_graph_repo(kb)
return graph_repo.count_relationships(session)
def count_entities(self, session: Session, kb: KnowledgeBase):
graph_repo = get_kb_graph_repo(kb)
return graph_repo.count_entities(session)
def count_documents_by_vector_index_status(
self, session: Session, kb: KnowledgeBase
) -> dict:
stmt = (
select(Document.index_status, func.count(Document.id))
.where(Document.knowledge_base_id == kb.id)
.group_by(Document.index_status)
.order_by(Document.index_status)
)
results = session.exec(stmt).all()
vector_index_status = {s: c for s, c in results}
return {
"vector_index": vector_index_status,
}
def count_chunks_by_kg_index_status(
self, session: Session, kb: KnowledgeBase
) -> dict:
# FIXME: Maybe we should count the documents (instead of chunks) like vector index?
chunk_model = get_kb_chunk_model(kb)
stmt = (
select(chunk_model.index_status, func.count(chunk_model.id))
.where(chunk_model.document.has(Document.knowledge_base_id == kb.id))
.group_by(chunk_model.index_status)
.order_by(chunk_model.index_status)
)
results = session.exec(stmt).all()
kg_index_status = {s: c for s, c in results}
return {"kg_index": kg_index_status}
def batch_update_document_status(
self, session: Session, document_ids: list[int], status: DocIndexTaskStatus
):
stmt = (
update(Document)
.where(Document.id.in_(document_ids))
.values(index_status=status)
)
session.exec(stmt)
session.commit()
def set_failed_documents_status_to_pending(
self, session: Session, kb: KnowledgeBase
) -> list[int]:
stmt = select(Document.id).where(
Document.knowledge_base_id == kb.id,
Document.index_status == DocIndexTaskStatus.FAILED,
)
failed_document_ids = session.exec(stmt).all()
self.batch_update_document_status(
session, failed_document_ids, DocIndexTaskStatus.PENDING
)
return failed_document_ids
def batch_update_chunk_status(
self,
session: Session,
chunk_model: Type[SQLModel],
chunk_ids: list[int],
status: KgIndexStatus,
):
stmt = (
update(chunk_model)
.where(chunk_model.id.in_(chunk_ids))
.values(index_status=status)
)
session.exec(stmt)
session.commit()
def set_failed_chunks_status_to_pending(
self, session: Session, kb: KnowledgeBase
) -> list[int]:
chunk_model = get_kb_chunk_model(kb)
stmt = select(chunk_model.id).where(
chunk_model.document.has(Document.knowledge_base_id == kb.id),
chunk_model.index_status == KgIndexStatus.FAILED,
)
chunk_ids = session.exec(stmt).all()
# Update status.
self.batch_update_chunk_status(
session, chunk_model, chunk_ids, KgIndexStatus.PENDING
)
return chunk_ids
def list_vector_index_built_errors(
self,
session: Session,
kb: KnowledgeBase,
params: Params | None = Params(),
) -> Page[VectorIndexError]:
query = (
select(
Document.id,
Document.name,
Document.source_uri,
Document.index_result,
)
.where(
Document.knowledge_base_id == kb.id,
Document.index_status == DocIndexTaskStatus.FAILED,
)
.order_by(Document.id.desc())
)
return paginate(
session,
query,
params,
transformer=lambda rows: [
VectorIndexError(
document_id=row[0],
document_name=row[1],
source_uri=row[2],
error=row[3],
)
for row in rows
],
)
def list_kg_index_built_errors(
self,
session: Session,
kb: KnowledgeBase,
params: Params | None = Params(),
) -> Page[KGIndexError]:
chunk_model = get_kb_chunk_model(kb)
query = (
select(
Document.id,
Document.name,
chunk_model.source_uri,
chunk_model.id,
chunk_model.index_result,
)
.join(Document)
.where(
chunk_model.document_id == Document.id,
Document.knowledge_base_id == kb.id,
chunk_model.index_status == KgIndexStatus.FAILED,
)
.order_by(chunk_model.id.desc())
)
return paginate(
session,
query,
params,
transformer=lambda rows: [
KGIndexError(
document_id=row[0],
document_name=row[1],
source_uri=row[2],
chunk_id=row[3],
error=row[4],
)
for row in rows
],
)
def get_kb_datasource(
self,
session: Session,
kb: KnowledgeBase,
datasource_id: int,
show_soft_deleted: bool = False,
) -> DataSource:
stmt = select(DataSource).where(DataSource.id == datasource_id)
if not show_soft_deleted:
stmt = stmt.where(DataSource.deleted_at == None)
return session.exec(stmt).first()
def must_get_kb_datasource(
self,
session: Session,
kb: KnowledgeBase,
datasource_id: int,
show_soft_deleted: bool = False,
) -> DataSource:
data_source = self.get_kb_datasource(
session, kb, datasource_id, show_soft_deleted
)
if data_source is None:
raise KBDataSourceNotFound(kb.id, datasource_id)
return data_source
def add_kb_datasource(
self, session: Session, kb: KnowledgeBase, data_source: DataSource
) -> DataSource:
session.add(data_source)
kb.data_sources.append(data_source)
session.add(kb)
session.commit()
session.refresh(data_source)
return data_source
def list_kb_datasources(
self, session: Session, kb_id: int, params: Params | None = Params()
) -> Page[DataSource]:
query = (
select(DataSource)
.join(KnowledgeBaseDataSource)
.where(
DataSource.deleted_at == None,
KnowledgeBaseDataSource.knowledge_base_id == kb_id,
)
.order_by(DataSource.created_at.desc())
)
return paginate(session, query, params)
def remove_kb_datasource(
self, session: Session, kb: KnowledgeBase, data_source: DataSource
) -> None:
# Flag the data source to be deleted.
data_source.deleted_at = datetime.now(UTC)
session.add(data_source)
# Remove the data source from the knowledge base.
stmt = delete(KnowledgeBaseDataSource).where(
KnowledgeBaseDataSource.knowledge_base_id == kb.id,
KnowledgeBaseDataSource.data_source_id == data_source.id,
)
session.exec(stmt)
def list_linked_chat_engines(
self, session: Session, kb_id: int
) -> List[ChatEngine]:
return session.exec(
select(ChatEngine).where(
ChatEngine.deleted_at == None,
func.JSON_UNQUOTE(
func.JSON_EXTRACT(
ChatEngine.engine_options,
"$.knowledge_base.linked_knowledge_base.id",
)
)
== kb_id,
)
).all()
knowledge_base_repo = KnowledgeBaseRepo()
================================================
FILE: backend/app/repositories/llm.py
================================================
from typing import Type, Optional
from fastapi import Depends
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlalchemy import update
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import select, Session
from app.exceptions import DefaultLLMNotFound, LLMNotFound
from app.models import LLM, LLMUpdate
from app.models.chat_engine import ChatEngine
from app.models.knowledge_base import KnowledgeBase
from app.repositories.base_repo import BaseRepo
class LLMRepo(BaseRepo):
model_cls: LLM
def paginate(self, session: Session, params: Params = Depends()) -> Page[LLM]:
query = select(LLM)
# Make sure the default llm is always on top.
query = query.order_by(LLM.is_default.desc(), LLM.created_at.desc())
return paginate(session, query, params)
def get(self, session: Session, llm_id: int) -> Optional[LLM]:
return session.get(LLM, llm_id)
def must_get(self, session: Session, llm_id: int) -> LLM:
db_llm = self.get(session, llm_id)
if db_llm is None:
raise LLMNotFound(llm_id)
return db_llm
def exists_any_model(self, session: Session) -> bool:
stmt = select(LLM).with_for_update().limit(1)
return session.exec(stmt).one_or_none() is not None
def create(self, session: Session, llm: LLM) -> LLM:
# If there is no exiting model, the first model is
# automatically set as the default model.
if not self.exists_any_model(session):
llm.is_default = True
if llm.is_default:
self._unset_default(session)
llm.id = None
session.add(llm)
session.commit()
session.refresh(llm)
return llm
def update(self, session: Session, llm: LLM, llm_update: LLMUpdate) -> LLM:
for field, value in llm_update.model_dump(exclude_unset=True).items():
setattr(llm, field, value)
flag_modified(llm, field)
session.commit()
session.refresh(llm)
return llm
# Default model
def get_default(self, session: Session) -> Type[LLM] | None:
stmt = (
select(LLM)
.where(LLM.is_default == True)
.order_by(LLM.updated_at.desc())
.limit(1)
)
return session.exec(stmt).first()
def has_default(self, session: Session) -> bool:
return self.get_default(session) is not None
def must_get_default(self, session: Session) -> Type[LLM]:
db_llm = self.get_default(session)
if db_llm is None:
raise DefaultLLMNotFound()
return db_llm
def _unset_default(self, session: Session):
session.exec(update(LLM).values(is_default=False))
def set_default(self, session: Session, llm: LLM) -> LLM:
self._unset_default(session)
llm.is_default = True
session.add(llm)
session.commit()
session.refresh(llm)
return llm
def delete(self, session: Session, llm: LLM):
# TODO: Support to specify a new LLM to replace the current LLM.
session.exec(
update(ChatEngine).where(ChatEngine.llm_id == llm.id).values(llm_id=None)
)
session.exec(
update(ChatEngine)
.where(ChatEngine.fast_llm_id == llm.id)
.values(fast_llm_id=None)
)
session.exec(
update(KnowledgeBase)
.where(KnowledgeBase.llm_id == llm.id)
.values(llm_id=None)
)
session.delete(llm)
session.commit()
llm_repo = LLMRepo()
================================================
FILE: backend/app/repositories/reranker_model.py
================================================
from typing import Optional
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlalchemy import update
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, select
from app.exceptions import RerankerModelNotFound, DefaultRerankerModelNotFound
from app.models import RerankerModel
from app.models.chat_engine import ChatEngine
from app.models.reranker_model import RerankerModelUpdate
from app.repositories.base_repo import BaseRepo
class RerankerModelRepo(BaseRepo):
model_cls: RerankerModel
def paginate(
self, session: Session, params: Params | None = Params()
) -> Page[RerankerModel]:
query = select(RerankerModel)
# Make sure the default reranker model is always on top.
query = query.order_by(
RerankerModel.is_default.desc(), RerankerModel.created_at.desc()
)
return paginate(session, query, params)
def get(self, session: Session, model_id: int) -> Optional[RerankerModel]:
return session.get(RerankerModel, model_id)
def must_get(self, session: Session, model_id: int) -> RerankerModel:
db_model = self.get(session, model_id)
if db_model is None:
raise RerankerModelNotFound(model_id)
return db_model
def exists_any_model(self, session: Session) -> bool:
stmt = select(RerankerModel).with_for_update().limit(1)
return session.exec(stmt).one_or_none() is not None
def create(self, session: Session, reranker_model: RerankerModel) -> RerankerModel:
# If there is no exiting model, the first model will be
# set as the default model.
if not self.exists_any_model(session):
reranker_model.is_default = True
if reranker_model.is_default:
self.unset_default(session)
reranker_model.id = None
session.add(reranker_model)
session.commit()
session.refresh(reranker_model)
return reranker_model
def update(
self,
session: Session,
reranker_model: RerankerModel,
model_update: RerankerModelUpdate,
) -> RerankerModel:
for field, value in model_update.model_dump(exclude_unset=True).items():
setattr(reranker_model, field, value)
flag_modified(reranker_model, field)
session.commit()
session.refresh(reranker_model)
return reranker_model
def delete(self, db_session: Session, reranker_model: RerankerModel):
# TODO: Support to specify a new reranker model to replace the current reranker model.
db_session.exec(
update(ChatEngine)
.where(ChatEngine.reranker_id == reranker_model.id)
.values(reranker_id=None)
)
db_session.delete(reranker_model)
db_session.commit()
# Default model
def get_default(self, session: Session) -> Optional[RerankerModel]:
stmt = select(RerankerModel).where(RerankerModel.is_default == True).limit(1)
return session.exec(stmt).first()
def has_default(self, session: Session) -> bool:
return self.get_default(session) is not None
def must_get_default(self, session: Session) -> RerankerModel:
db_reranker_model = self.get_default(session)
if db_reranker_model is None:
raise DefaultRerankerModelNotFound()
return db_reranker_model
def unset_default(self, session: Session):
session.exec(update(RerankerModel).values(is_default=False))
def set_default(self, session: Session, model: RerankerModel):
self.unset_default(session)
model.is_default = True
flag_modified(model, "is_default")
session.commit()
session.refresh(model)
return model
reranker_model_repo = RerankerModelRepo()
================================================
FILE: backend/app/repositories/staff_action_log.py
================================================
from sqlmodel import Session
from app.models.staff_action_log import StaffActionLog
from app.repositories.base_repo import BaseRepo
class StaffActionRepo(BaseRepo):
model_cls = StaffActionLog
def create_staff_action_log(
self, session: Session, action, target_type, before, after, commit=True
) -> StaffActionLog:
staff_action_log = StaffActionLog(
action=action, target_type=target_type, before=before, after=after
)
session.add(staff_action_log)
if commit:
session.commit()
return staff_action_log
staff_action_repo = StaffActionRepo()
================================================
FILE: backend/app/repositories/user.py
================================================
from typing import Optional
from fastapi_pagination import Page, Params
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import Session, select
from app.models.auth import User
from app.repositories.base_repo import BaseRepo
class UserRepo(BaseRepo):
model_cls: User
def search_users(
self,
db_session: Session,
search: Optional[str] = None,
params: Params = Params(),
) -> Page[User]:
query = select(User)
if search:
query = query.where(User.email.ilike(f"%{search}%"))
query = query.order_by(User.id)
return paginate(
db_session,
query,
params,
)
user_repo = UserRepo()
================================================
FILE: backend/app/site_settings/__init__.py
================================================
import time
import threading
import logging
from sqlmodel import Session, select
from app.models import SiteSetting as DBSiteSetting
from app.core.db import engine
from app.site_settings.default import default_settings
from app.site_settings.types import SettingValue, SettingType
logger = logging.getLogger(__name__)
def get_settings_from_db(session: Session):
logger.debug("Getting settings from the database.")
stmt = select(DBSiteSetting.name, DBSiteSetting.value)
results = session.exec(stmt)
return {name: value for name, value in results}
def get_db_last_updated_at(session: Session):
logger.debug("Getting the last updated_at timestamp from the database.")
stmt = (
select(DBSiteSetting.updated_at)
.order_by(DBSiteSetting.updated_at.desc())
.limit(1)
)
result = session.exec(stmt).first()
return result.timestamp() if result else 0
type_mapping = {
"str": str,
"int": int,
"float": float,
"bool": bool,
"dict": dict,
"list": list,
}
class SiteSettingProxy:
__db_cache: dict = {}
__last_updated_at_ts: float = 0
__last_checked_at_ts: float = 0
__mutex = threading.Lock()
def update_db_cache(self, force_check=False):
# Check if we need to update the cache every 6 seconds,
# so it means settings will not be updated in real-time
# which is acceptable for this project.
# If we need real-time updates in the future, we can use
# a message queue or a pub/sub system to notify the app.
now = time.time()
if force_check or (now - self.__last_checked_at_ts > 6):
self.__last_checked_at_ts = now
with Session(engine) as session:
last_updated_at_ts = get_db_last_updated_at(session)
if last_updated_at_ts > self.__last_updated_at_ts:
with self.__mutex:
if last_updated_at_ts > self.__last_updated_at_ts:
self.__db_cache = get_settings_from_db(session)
self.__last_updated_at_ts = last_updated_at_ts
def get_db_cache(self) -> dict:
# Should we use a lock here?
with self.__mutex:
return self.__db_cache
def __getattr__(self, name: str) -> SettingType:
return self.get_setting(name)
def get_setting(self, name: str) -> SettingType:
if hasattr(default_settings, name):
default_setting = getattr(default_settings, name)
self.update_db_cache()
db_value = self.__db_cache.get(name)
return db_value if db_value is not None else default_setting.default
else:
raise AttributeError(f"Setting {name} does not exist.")
def get_all_settings(
self, force_check_db_cache: bool = False
) -> dict[str, SettingValue]:
self.update_db_cache(force_check_db_cache)
result = {}
for _, settings in default_settings.setting_groups.items():
for default_setting in settings:
db_value = self.__db_cache.get(default_setting.name)
result[default_setting.name] = SettingValue(
name=default_setting.name,
default=default_setting.default,
value=db_value if db_value is not None else default_setting.default,
data_type=default_setting.data_type,
description=default_setting.description,
group=default_setting.group,
client=default_setting.client,
)
return result
def get_client_settings(self) -> dict:
# Retrieve all client settings utilized in the frontend.
# These settings determine the behavior of the frontend and are accessible to all users.
settings = self.get_all_settings()
return {k: s.value for k, s in settings.items() if s.client}
def setting_exists(self, name: str) -> bool:
return hasattr(default_settings, name)
def update_setting(self, session: Session, name: str, value: SettingType):
if not self.setting_exists(name):
raise AttributeError(f"Setting {name} does not exist.")
_default_setting: SettingValue = getattr(default_settings, name)
if not isinstance(value, type_mapping[_default_setting.data_type]):
raise ValueError(f"{name} must be of type `{_default_setting.data_type}`.")
db_setting_obj = session.exec(
select(DBSiteSetting).filter(DBSiteSetting.name == name)
).first()
if db_setting_obj:
db_setting_obj.value = value
else:
db_setting_obj = DBSiteSetting(
name=name, value=value, data_type=_default_setting.data_type
)
session.add(db_setting_obj)
session.commit()
self.update_db_cache(force_check=True)
SiteSetting = SiteSettingProxy()
__all__ = ["SiteSetting", "SettingValue", "SettingType"]
================================================
FILE: backend/app/site_settings/default.py
================================================
import yaml
import threading
from collections import defaultdict
from app.site_settings.types import SettingValue
DEFAULT_YAML_FILE = "./app/site_settings/default_settings.yml"
class DefaultSettings:
setting_groups: dict[str, list[SettingValue]] = defaultdict(list)
__mutex = threading.Lock()
__loaded = False
def __init__(self):
with self.__mutex:
if not self.__loaded:
self.load_default_from_yaml()
self.__loaded = True
def load_default_from_yaml(self):
with open(DEFAULT_YAML_FILE, "r") as f:
data = yaml.safe_load(f)
for group, settings in data.items():
for name, value in settings.items():
setattr(
self,
name,
SettingValue(**value, name=name, group=group),
)
self.setting_groups[group].append(getattr(self, name))
default_settings = DefaultSettings()
================================================
FILE: backend/app/site_settings/default_settings.yml
================================================
website:
title:
default: "TiDB.AI"
data_type: str
description: "The title of the website."
client: true
description:
default: "Knowledge Graph based RAG that built with TiDB Serverless Vector Storage and LlamaIndex"
data_type: str
description: "The description of the website."
client: true
homepage_title:
default: "Ask anything about TiDB"
data_type: str
description: "The title of the homepage."
client: true
homepage_example_questions:
default:
- "What is TiDB?"
- "Does TiDB support FOREIGN KEY?"
- "Does TiDB support serverless?"
data_type: list
description: "The example questions on the homepage."
client: true
homepage_footer_links:
default:
- text: "Release Notes"
href: "https://autoflow.tidb.ai/releases"
- text: "Powered by TiDB"
href: "https://tidb.cloud/ai"
- text: "© 2024 PingCAP"
href: "https://pingcap.com"
data_type: list
description: "The footer links on the homepage."
client: true
logo_in_dark_mode:
default: "https://tidb.ai/tidb-ai-light.svg"
data_type: str
description: "The logo in dark mode."
client: true
logo_in_light_mode:
default: "https://tidb.ai/tidb-ai.svg"
data_type: str
description: "The logo in light mode."
client: true
social_github:
default: "https://github.com/pingcap/tidb.ai"
data_type: str
description: "The link to the github"
client: true
social_twitter:
default: "https://twitter.com/PingCAP"
data_type: str
description: "The link to the twitter"
client: true
social_discord:
default: "https://discord.gg/XzSW23Jg9p"
data_type: str
description: "The link to the discord"
client: true
ga_id:
default: ""
data_type: str
description: "Google Analytics ID"
client: true
custom_js:
custom_js_example_questions:
default:
- "What is TiDB?"
- "Does TiDB support FOREIGN KEY?"
- "Does TiDB support serverless?"
data_type: list
description: "The example questions on the custom_js."
client: true
custom_js_button_label:
default: "Ask AI"
data_type: str
description: "The button label on the custom_js."
client: true
custom_js_button_img_src:
default: "https://tidb.ai/tidb-ai-widget.svg"
data_type: str
description: "The button image on the custom_js."
client: true
custom_js_logo_src:
default: "https://tidb.ai/tidb-ai-widget.svg"
data_type: str
description: "The logo on the custom_js."
client: true
chat:
langfuse_secret_key:
default: ""
data_type: str
description: "The secret key of Langfuse."
client: false
langfuse_public_key:
default: ""
data_type: str
description: "The public key of Langfuse."
client: false
langfuse_host:
default: "https://us.cloud.langfuse.com"
data_type: str
description: "The host of Langfuse."
client: false
enable_post_verifications:
default: true
data_type: bool
description: "Enable post verification for all chats."
client: true
enable_post_verifications_for_widgets:
default: false
data_type: bool
description: "Enable post verification for chats from js widgets."
client: true
upload:
max_upload_file_size:
default: 10485760 # 10 MiB
data_type: int
description: "Max body size (in bytes) of upload file."
client: true
================================================
FILE: backend/app/site_settings/types.py
================================================
from typing import TypeAlias
from dataclasses import dataclass
SettingType: TypeAlias = bool | str | int | float | list | dict | None
@dataclass
class SettingValue:
name: str
default: SettingType
data_type: str
description: str
group: str
# Whether the setting can be viewed by the client
client: bool = False
value: SettingType = None
================================================
FILE: backend/app/staff_action/__init__.py
================================================
from sqlmodel import Session
from app.models import StaffActionLog
def create_staff_action_log(
session: Session, action, target_type, target_id, before, after, commit=True
):
staff_action_log = StaffActionLog(
action=action,
target_type=target_type,
target_id=target_id,
before=before,
after=after,
)
session.add(staff_action_log)
if commit:
session.commit()
return staff_action_log
================================================
FILE: backend/app/tasks/__init__.py
================================================
from .knowledge_base import (
import_documents_for_knowledge_base,
purge_kb_datasource_related_resources,
)
from .build_index import (
build_index_for_document,
build_kg_index_for_chunk,
)
from .evaluate import add_evaluation_task
__all__ = [
"build_index_for_document",
"build_kg_index_for_chunk",
"import_documents_for_knowledge_base",
"purge_kb_datasource_related_resources",
"add_evaluation_task",
]
================================================
FILE: backend/app/tasks/build_index.py
================================================
import traceback
from uuid import UUID
from sqlmodel import Session
from celery.utils.log import get_task_logger
from app.celery import app as celery_app
from app.core.db import engine
from app.models import (
Document as DBDocument,
DocIndexTaskStatus,
KgIndexStatus,
)
from app.models.chunk import get_kb_chunk_model
from app.models.knowledge_base import IndexMethod
from app.rag.build_index import IndexService
from app.rag.knowledge_base.config import get_kb_llm, get_kb_embed_model
from app.repositories import knowledge_base_repo
from app.repositories.chunk import ChunkRepo
logger = get_task_logger(__name__)
# TODO: refactor: divide into two tasks: build_vector_index_for_document and build_kg_index_for_document
@celery_app.task(bind=True)
def build_index_for_document(self, knowledge_base_id: int, document_id: int):
# Pre-check before building index.
with Session(engine, expire_on_commit=False) as session:
kb = knowledge_base_repo.must_get(session, knowledge_base_id)
# Check document.
db_document = session.get(DBDocument, document_id)
if db_document is None:
logger.error(f"Document #{document_id} is not found")
return
if db_document.index_status not in (
DocIndexTaskStatus.PENDING,
DocIndexTaskStatus.NOT_STARTED,
):
logger.info(f"Document #{document_id} is not in pending state")
return
# Init knowledge base index service。
try:
llm = get_kb_llm(session, kb)
embed_model = get_kb_embed_model(session, kb)
index_service = IndexService(llm, embed_model, kb)
except ValueError as e:
# LLM may not be available yet(eg. bootstrapping), retry after specified time
logger.warning(
f"Failed to init index service for document #{document_id} (retry task after 1 minute): {e}"
)
raise self.retry(countdown=60)
db_document.index_status = DocIndexTaskStatus.RUNNING
session.add(db_document)
session.commit()
# Build vector index.
try:
with Session(engine) as index_session:
index_service.build_vector_index_for_document(index_session, db_document)
with Session(engine) as session:
db_document.index_status = DocIndexTaskStatus.COMPLETED
session.add(db_document)
session.commit()
logger.info(f"Built vector index for document #{document_id} successfully.")
except Exception:
with Session(engine) as session:
error_msg = traceback.format_exc()
logger.error(
f"Failed to build vector index for document {document_id}: {error_msg}"
)
db_document.index_status = DocIndexTaskStatus.FAILED
db_document.index_result = error_msg
session.add(db_document)
session.commit()
return
# Build knowledge graph index.
with Session(engine, expire_on_commit=False) as session:
kb = knowledge_base_repo.must_get(session, knowledge_base_id)
if IndexMethod.KNOWLEDGE_GRAPH not in kb.index_methods:
return
chunk_repo = ChunkRepo(get_kb_chunk_model(kb))
chunks = chunk_repo.get_document_chunks(session, document_id)
for chunk in chunks:
build_kg_index_for_chunk.delay(knowledge_base_id, chunk.id)
@celery_app.task
def build_kg_index_for_chunk(knowledge_base_id: int, chunk_id: UUID):
with Session(engine, expire_on_commit=False) as session:
kb = knowledge_base_repo.must_get(session, knowledge_base_id)
# Check chunk.
chunk_model = get_kb_chunk_model(kb)
db_chunk = session.get(chunk_model, chunk_id)
if db_chunk is None:
logger.error(f"Chunk #{chunk_id} is not found")
return
if db_chunk.index_status not in (
KgIndexStatus.PENDING,
KgIndexStatus.NOT_STARTED,
):
logger.info(f"Chunk #{chunk_id} is not in pending state")
return
# Init knowledge base index service。
llm = get_kb_llm(session, kb)
embed_model = get_kb_embed_model(session, kb)
index_service = IndexService(llm, embed_model, kb)
db_chunk.index_status = KgIndexStatus.RUNNING
session.add(db_chunk)
session.commit()
try:
with Session(engine) as index_session:
index_service.build_kg_index_for_chunk(index_session, db_chunk)
with Session(engine) as session:
db_chunk.index_status = KgIndexStatus.COMPLETED
session.add(db_chunk)
session.commit()
logger.info(
f"Built knowledge graph index for chunk #{chunk_id} successfully."
)
except Exception:
with Session(engine) as session:
error_msg = traceback.format_exc()
logger.error(
f"Failed to build knowledge graph index for chunk #{chunk_id}",
exc_info=True,
)
db_chunk.index_status = KgIndexStatus.FAILED
db_chunk.index_result = error_msg
session.add(db_chunk)
session.commit()
================================================
FILE: backend/app/tasks/evaluate.py
================================================
import logging
import traceback
from llama_index.core.base.llms.types import ChatMessage
from app.celery import app as celery_app
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from ragas import EvaluationDataset, evaluate
from ragas.embeddings import LlamaIndexEmbeddingsWrapper
from ragas.llms import LlamaIndexLLMWrapper
from ragas.metrics import FactualCorrectness, SemanticSimilarity
from sqlmodel import Session, select
from celery.utils.log import get_task_logger
from tenacity import retry, stop_after_attempt, wait_fixed
from app.core.config import settings, Environment
from app.core.db import engine
from app.models import (
EvaluationTask,
EvaluationStatus,
EvaluationTaskItem,
)
from dotenv import load_dotenv
from app.rag.chat.chat_flow import ChatFlow
from app.rag.chat.stream_protocol import ChatEvent
from app.rag.types import ChatEventType, ChatMessageSate
load_dotenv()
logger = get_task_logger(__name__)
if settings.ENVIRONMENT == Environment.LOCAL:
logger.setLevel(logging.DEBUG)
for handler in logger.handlers:
handler.setLevel(logging.DEBUG)
@celery_app.task
def add_evaluation_task(evaluation_task_id: int):
logger.info(
f"[add_evaluation_task] Enter with evaluation task #{evaluation_task_id}"
)
with Session(engine, expire_on_commit=False) as session:
evaluation_task = session.get(EvaluationTask, evaluation_task_id)
if evaluation_task is None:
logger.error(f"Evaluation task #{evaluation_task_id} is not found")
return
# get eval items
eval_item_stmt = select(EvaluationTaskItem).where(
EvaluationTaskItem.evaluation_task_id == evaluation_task_id
)
eval_item_list = session.exec(eval_item_stmt).all()
logger.info(f"[add_evaluation_task] get {len(eval_item_list)} evaluation items")
for eval_item in eval_item_list:
logger.debug(type(eval_item))
logger.debug(
f"[add_evaluation_task] deal with evaluation item #{eval_item.id}"
)
add_evaluation_task_item.delay(eval_item.id)
@celery_app.task
def add_evaluation_task_item(evaluation_task_item_id: int):
logger.info(
f"Enter add_evaluation_task_item with evaluation item #{evaluation_task_item_id}"
)
with Session(engine, expire_on_commit=False) as session:
evaluation_task_item = session.get(EvaluationTaskItem, evaluation_task_item_id)
if evaluation_task_item is None:
logger.error(f"Evaluation item #{evaluation_task_item_id} is not found")
return
if evaluation_task_item.status != EvaluationStatus.NOT_START:
logger.error(
f"Evaluation item #{evaluation_task_item_id} is not in not start state"
)
return
try:
if evaluation_task_item.response is None or evaluation_task_item.response == "":
response, _ = generate_answer_by_autoflow(
[ChatMessage(role="assistant", content=evaluation_task_item.query)],
evaluation_task_item.chat_engine,
)
if response is None or response == "":
raise Exception("Autoflow response is empty")
logger.info(
f"Got response from autoflow for evaluation item #{evaluation_task_item_id}, {response}"
)
evaluation_task_item.response = response
logger.info(f"Successfully get response item #{evaluation_task_item_id}")
with Session(engine, expire_on_commit=False) as session:
session.merge(evaluation_task_item)
session.commit()
evaluate_task(evaluation_task_item)
except Exception as e:
logger.error(f"Failed to evaluate item #{evaluation_task_item_id}, error: {e}")
evaluation_task_item.error_msg = traceback.format_exc()
evaluation_task_item.status = EvaluationStatus.ERROR
with Session(engine, expire_on_commit=False) as session:
session.merge(evaluation_task_item)
session.commit()
def evaluate_task(evaluation_task_item: EvaluationTaskItem):
logger.info(f"Enter evaluate_task with evaluation item #{evaluation_task_item.id}")
ragas_list = [
{
"user_input": evaluation_task_item.query,
"reference": evaluation_task_item.reference,
"response": evaluation_task_item.response,
}
]
logger.debug(f"Response data {evaluation_task_item.response}")
ragas_dataset = EvaluationDataset.from_list(ragas_list)
logger.debug(f"Dataset {ragas_dataset.to_pandas().head()}")
evaluator_llm = LlamaIndexLLMWrapper(
OpenAI(model="gpt-4o", api_key=settings.EVALUATION_OPENAI_API_KEY)
)
evaluator_embeddings = LlamaIndexEmbeddingsWrapper(
OpenAIEmbedding(
model="text-embedding-3-large", api_key=settings.EVALUATION_OPENAI_API_KEY
)
)
metrics = [
# LLMContextRecall(llm=evaluator_llm), # retrieved_contexts required
FactualCorrectness(llm=evaluator_llm),
# Faithfulness(llm=evaluator_llm), # retrieved_contexts required
SemanticSimilarity(embeddings=evaluator_embeddings),
]
try:
eval_result = evaluate(
dataset=ragas_dataset,
metrics=metrics,
raise_exceptions=True,
show_progress=False,
)
logger.debug("eval_result to_pandas")
result_list = eval_result.to_pandas().to_dict(orient="records")
logger.debug(f"result list {result_list}")
if len(result_list) != 1:
raise Exception(
f"Item {evaluation_task_item.id} cannot get evaluation from ragas"
)
logger.debug(f"result {result_list[0]}")
evaluation_task_item.factual_correctness = result_list[0][
FactualCorrectness.name
]
evaluation_task_item.semantic_similarity = result_list[0][
SemanticSimilarity.name
]
evaluation_task_item.status = EvaluationStatus.DONE
logger.info(f"Result evaluation item #{evaluation_task_item}")
with Session(engine, expire_on_commit=False) as session:
session.merge(evaluation_task_item)
session.commit()
except Exception as e:
logger.error(f"Failed to evaluate item #{evaluation_task_item.id}, error: {e}")
evaluation_task_item.error_msg = traceback.format_exc()
evaluation_task_item.status = EvaluationStatus.ERROR
with Session(engine, expire_on_commit=False) as session:
session.merge(evaluation_task_item)
session.commit()
@retry(stop=stop_after_attempt(2), wait=wait_fixed(5))
def generate_answer_by_autoflow(
messages: list[ChatMessage], chat_engine: str
) -> (str, list):
with Session(engine, expire_on_commit=False) as session:
chat_svc = ChatFlow(
db_session=session,
user=None,
browser_id="",
origin="evaluation",
chat_messages=messages,
engine_name=chat_engine,
)
sources, answer = [], ""
for m in chat_svc.chat():
if not isinstance(m, ChatEvent):
continue
if m.event_type == ChatEventType.MESSAGE_ANNOTATIONS_PART:
if m.payload.state == ChatMessageSate.SOURCE_NODES:
sources = m.payload.context
elif m.event_type == ChatEventType.TEXT_PART:
answer += m.payload
elif m.event_type == ChatEventType.ERROR_PART:
raise Exception(m.payload)
else:
pass
return answer, sources
def parse_langfuse_trace_id_from_url(trace_url: str) -> str:
# Example trace_url: https://us.cloud.langfuse.com/trace/87e7eb2e-b789-4b23-af60-fbcf0fd517a1
return trace_url.split("/")[-1]
================================================
FILE: backend/app/tasks/knowledge_base.py
================================================
from celery.utils.log import get_task_logger
from sqlalchemy import delete
from sqlmodel import Session
from app.celery import app as celery_app
from app.core.db import engine
from app.exceptions import KBNotFound
from app.models import (
Document,
KnowledgeBaseDataSource,
DataSource,
)
from app.rag.datasource import get_data_source_loader
from app.repositories import knowledge_base_repo, document_repo
from .build_index import build_index_for_document
from ..models.chunk import get_kb_chunk_model
from ..models.entity import get_kb_entity_model
from ..models.relationship import get_kb_relationship_model
from ..rag.knowledge_base.index_store import (
get_kb_tidb_vector_store,
get_kb_tidb_graph_store,
)
from ..repositories.chunk import ChunkRepo
from ..repositories.graph import GraphRepo
logger = get_task_logger(__name__)
@celery_app.task
def import_documents_for_knowledge_base(kb_id: int):
try:
with Session(engine) as session:
kb = knowledge_base_repo.must_get(session, kb_id)
data_sources = kb.data_sources
for data_source in data_sources:
import_documents_from_kb_datasource(kb.id, data_source.id)
logger.info(f"Successfully imported documents for knowledge base #{kb_id}")
except KBNotFound:
logger.error(f"Knowledge base #{kb_id} is not found")
except Exception as e:
logger.exception(
f"Failed to import documents for knowledge base #{kb_id}", exc_info=e
)
@celery_app.task
def import_documents_from_kb_datasource(kb_id: int, data_source_id: int):
try:
with Session(engine) as session:
kb = knowledge_base_repo.must_get(session, kb_id)
data_source = knowledge_base_repo.must_get_kb_datasource(
session, kb, data_source_id
)
logger.info(
f"Loading documents from data source #{data_source_id} for knowledge base #{kb_id}"
)
loader = get_data_source_loader(
session,
kb_id,
data_source.data_source_type,
data_source.id,
data_source.user_id,
data_source.config,
)
for document in loader.load_documents():
session.add(document)
session.commit()
build_index_for_document.delay(kb_id, document.id)
stats_for_knowledge_base.delay(kb_id)
logger.info(
f"Successfully imported documents for from datasource #{data_source_id}"
)
except Exception as e:
logger.exception(
f"Failed to import documents from data source #{data_source_id} of knowledge base #{kb_id}",
exc_info=e,
)
@celery_app.task
def stats_for_knowledge_base(kb_id: int):
try:
with Session(engine) as session:
kb = knowledge_base_repo.must_get(session, kb_id)
documents_total = knowledge_base_repo.count_documents(session, kb)
data_sources_total = knowledge_base_repo.count_data_sources(session, kb)
kb.documents_total = documents_total
kb.data_sources_total = data_sources_total
session.add(kb)
session.commit()
logger.info(f"Successfully running stats for knowledge base #{kb_id}")
except KBNotFound:
logger.error(f"Knowledge base #{kb_id} is not found")
except Exception as e:
logger.exception(f"Failed to run stats for knowledge base #{kb_id}", exc_info=e)
@celery_app.task
def purge_knowledge_base_related_resources(kb_id: int):
"""
Purge all resources related to a knowledge base.
Related resources:
- documents
- chunks
- indexes
- vector index
- knowledge graph index
- data sources
"""
with Session(engine) as session:
knowledge_base = knowledge_base_repo.must_get(
session, kb_id, show_soft_deleted=True
)
assert knowledge_base.deleted_at is not None
data_source_ids = [datasource.id for datasource in knowledge_base.data_sources]
# Drop entities_{kb_id}, relationships_{kb_id} tables.
tidb_graph_store = get_kb_tidb_graph_store(session, knowledge_base)
tidb_graph_store.drop_table_schema()
logger.info(
f"Dropped tidb graph store of knowledge base #{kb_id} successfully."
)
# Drop chunks_{kb_id} table.
tidb_vector_store = get_kb_tidb_vector_store(session, knowledge_base)
tidb_vector_store.drop_table_schema()
logger.info(
f"Dropped tidb vector store of knowledge base #{kb_id} successfully."
)
# Delete documents.
stmt = delete(Document).where(Document.knowledge_base_id == kb_id)
session.exec(stmt)
logger.info(f"Deleted documents of knowledge base #{kb_id} successfully.")
# Delete data sources and links.
if len(data_source_ids) > 0:
stmt = delete(KnowledgeBaseDataSource).where(
KnowledgeBaseDataSource.knowledge_base_id == kb_id
)
session.exec(stmt)
logger.info(
f"Deleted linked data sources of knowledge base #{kb_id} successfully."
)
stmt = delete(DataSource).where(DataSource.id.in_(data_source_ids))
session.exec(stmt)
logger.info(
f"Deleted data sources {', '.join([f'#{did}' for did in data_source_ids])} successfully."
)
# Delete knowledge base.
session.delete(knowledge_base)
logger.info(f"Deleted knowledge base #{kb_id} successfully.")
session.commit()
@celery_app.task
def purge_kb_datasource_related_resources(kb_id: int, datasource_id: int):
"""
Purge all resources related to the deleted datasource in the knowledge base.
"""
with Session(engine) as session:
kb = knowledge_base_repo.must_get(session, kb_id, show_soft_deleted=True)
datasource = knowledge_base_repo.must_get_kb_datasource(
session, kb, datasource_id, show_soft_deleted=True
)
assert datasource.deleted_at is not None
chunk_model = get_kb_chunk_model(kb)
entity_model = get_kb_entity_model(kb)
relationship_model = get_kb_relationship_model(kb)
chunk_repo = ChunkRepo(chunk_model)
graph_repo = GraphRepo(entity_model, relationship_model, chunk_model)
graph_repo.delete_data_source_relationships(session, datasource_id)
logger.info(
f"Deleted relationships generated by chunks from data source #{datasource_id} successfully."
)
graph_repo.delete_orphaned_entities(session)
logger.info("Deleted orphaned entities successfully.")
chunk_repo.delete_by_datasource(session, datasource_id)
logger.info(f"Deleted chunks from data source #{datasource_id} successfully.")
document_repo.delete_by_datasource(session, datasource_id)
logger.info(
f"Deleted documents from data source #{datasource_id} successfully."
)
session.delete(datasource)
logger.info(f"Deleted data source #{datasource_id} successfully.")
session.commit()
stats_for_knowledge_base.delay(kb_id)
================================================
FILE: backend/app/types.py
================================================
import enum
class MimeTypes(str, enum.Enum):
PLAIN_TXT = "text/plain"
MARKDOWN = "text/markdown"
PDF = "application/pdf"
DOCX = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPTX = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
XLSX = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
CSV = "text/csv"
================================================
FILE: backend/app/utils/aes.py
================================================
import os
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
class AESCipher:
def __init__(self, key: bytes) -> None:
self.key = key
self.backend = default_backend()
def encrypt(self, plain_text: str) -> bytes:
# generate a random initialization vector
iv = os.urandom(16)
cipher = Cipher(algorithms.AES(self.key), modes.CFB(iv), backend=self.backend)
encryptor = cipher.encryptor()
# fill the last block with padding
padder = padding.PKCS7(algorithms.AES.block_size).padder()
padded_data = padder.update(plain_text.encode()) + padder.finalize()
encrypted = encryptor.update(padded_data) + encryptor.finalize()
return iv + encrypted
def decrypt(self, encrypted_text: bytes) -> str:
# get the initialization vector and the encrypted data
iv = encrypted_text[:16]
encrypted_data = encrypted_text[16:]
cipher = Cipher(algorithms.AES(self.key), modes.CFB(iv), backend=self.backend)
decryptor = cipher.decryptor()
# remove the padding
decrypted_padded = decryptor.update(encrypted_data) + decryptor.finalize()
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
decrypted = unpadder.update(decrypted_padded) + unpadder.finalize()
return decrypted
================================================
FILE: backend/app/utils/namespace.py
================================================
from typing import Optional
def format_namespace(namespace: Optional[str] = None) -> str:
return namespace.replace("-", "_") if namespace else ""
================================================
FILE: backend/app/utils/singleflight_cache.py
================================================
import threading
from functools import wraps
def singleflight_cache(func):
"""
A thread-safe cache decorator implementing the 'singleflight' pattern.
The singleflight pattern ensures that for any given set of arguments,
concurrent calls to the decorated function will only result in a single
actual execution. Other threads with the same arguments will wait for
the first execution to complete and then receive the same result,
rather than triggering duplicate computations.
This is especially useful for expensive or resource-intensive operations
where you want to avoid redundant work and prevent cache stampede.
Example:
@singleflight_cache
def load_data(key):
# expensive operation
...
# In multiple threads:
load_data('foo') # Only one thread will actually execute the function for 'foo'
"""
_cache = {}
_locks = {}
_locks_lock = threading.Lock()
@wraps(func)
def wrapper(*args, **kwargs):
key = args + tuple(sorted(kwargs.items()))
if key in _cache:
return _cache[key]
with _locks_lock:
lock = _locks.setdefault(key, threading.Lock())
with lock:
if key in _cache:
return _cache[key]
result = func(*args, **kwargs)
_cache[key] = result
return result
return wrapper
================================================
FILE: backend/app/utils/tracing.py
================================================
from contextlib import contextmanager
from typing import Optional, Generator
from langfuse.client import StatefulSpanClient, StatefulClient
from langfuse.llama_index import LlamaIndexInstrumentor
from langfuse.llama_index._context import langfuse_instrumentor_context
class LangfuseContextManager:
langfuse_client: Optional[StatefulSpanClient] = None
def __init__(self, instrumentor: LlamaIndexInstrumentor):
self.instrumentor = instrumentor
@contextmanager
def observe(self, **kwargs):
try:
self.instrumentor.start()
with self.instrumentor.observe(**kwargs) as trace_client:
trace_client.update(name=kwargs.get("trace_name"), **kwargs)
self.langfuse_client = trace_client
yield trace_client
except Exception:
raise
finally:
self.instrumentor.flush()
self.instrumentor.stop()
@contextmanager
def span(
self, parent_client: Optional[StatefulClient] = None, **kwargs
) -> Generator["StatefulSpanClient", None, None]:
if parent_client:
client = parent_client
else:
client = self.langfuse_client
span = client.span(**kwargs)
ctx = langfuse_instrumentor_context.get().copy()
old_parent_observation_id = ctx.get("parent_observation_id")
langfuse_instrumentor_context.get().update(
{
"parent_observation_id": span.id,
}
)
try:
yield span
except Exception:
raise
finally:
ctx.update(
{
"parent_observation_id": old_parent_observation_id,
}
)
langfuse_instrumentor_context.get().update(ctx)
@property
def trace_id(self) -> Optional[str]:
if self.langfuse_client:
return self.langfuse_client.trace_id
else:
return None
@property
def trace_url(self) -> Optional[str]:
if self.langfuse_client:
return self.langfuse_client.get_trace_url()
else:
return None
================================================
FILE: backend/app/utils/uuid6.py
================================================
r"""UUID draft version objects (universally unique identifiers).
This module provides the functions uuid6() and uuid7() for
generating version 6 and 7 UUIDs as specified in
https://github.com/uuid6/uuid6-ietf-draft.
Repo: https://github.com/oittaa/uuid6-python
Copy from: https://github.com/jonra1993/fastapi-alembic-sqlmodel-async/blob/main/backend/app/app/utils/uuid6.py
"""
import secrets
import time
import uuid
class UUID(uuid.UUID):
r"""UUID draft version objects"""
def __init__(
self,
hex: str = None,
bytes: bytes = None,
bytes_le: bytes = None,
fields: tuple[int, int, int, int, int, int] = None,
int: int = None,
version: int = None,
*,
is_safe=uuid.SafeUUID.unknown,
) -> None:
r"""Create a UUID."""
if int is None or [hex, bytes, bytes_le, fields].count(None) != 4:
super().__init__(
hex=hex,
bytes=bytes,
bytes_le=bytes_le,
fields=fields,
int=int,
version=version,
is_safe=is_safe,
)
if not 0 <= int < 1 << 128:
raise ValueError("int is out of range (need a 128-bit value)")
if version is not None:
if not 6 <= version <= 7:
raise ValueError("illegal version number")
# Set the variant to RFC 4122.
int &= ~(0xC000 << 48)
int |= 0x8000 << 48
# Set the version number.
int &= ~(0xF000 << 64)
int |= version << 76
super().__init__(int=int, is_safe=is_safe)
@property
def subsec(self) -> int:
return ((self.int >> 64) & 0x0FFF) << 8 | ((self.int >> 54) & 0xFF)
@property
def time(self) -> int:
if self.version == 6:
return (
(self.time_low << 28)
| (self.time_mid << 12)
| (self.time_hi_version & 0x0FFF)
)
if self.version == 7:
return (self.int >> 80) * 10**6 + _subsec_decode(self.subsec)
return super().time
def _subsec_decode(value: int) -> int:
return -(-value * 10**6 // 2**20)
def _subsec_encode(value: int) -> int:
return value * 2**20 // 10**6
_last_v6_timestamp = None
_last_v7_timestamp = None
def uuid6(clock_seq: int = None) -> UUID:
r"""UUID version 6 is a field-compatible version of UUIDv1, reordered for
improved DB locality. It is expected that UUIDv6 will primarily be
used in contexts where there are existing v1 UUIDs. Systems that do
not involve legacy UUIDv1 SHOULD consider using UUIDv7 instead.
If 'clock_seq' is given, it is used as the sequence number;
otherwise a random 14-bit sequence number is chosen."""
global _last_v6_timestamp
nanoseconds = time.time_ns()
# 0x01b21dd213814000 is the number of 100-ns intervals between the
# UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
timestamp = nanoseconds // 100 + 0x01B21DD213814000
if _last_v6_timestamp is not None and timestamp <= _last_v6_timestamp:
timestamp = _last_v6_timestamp + 1
_last_v6_timestamp = timestamp
if clock_seq is None:
clock_seq = secrets.randbits(14) # instead of stable storage
node = secrets.randbits(48)
time_high_and_time_mid = (timestamp >> 12) & 0xFFFFFFFFFFFF
time_low_and_version = timestamp & 0x0FFF
uuid_int = time_high_and_time_mid << 80
uuid_int |= time_low_and_version << 64
uuid_int |= (clock_seq & 0x3FFF) << 48
uuid_int |= node
return UUID(int=uuid_int, version=6)
def uuid7() -> UUID:
r"""UUID version 7 features a time-ordered value field derived from the
widely implemented and well known Unix Epoch timestamp source, the
number of milliseconds seconds since midnight 1 Jan 1970 UTC, leap
seconds excluded. As well as improved entropy characteristics over
versions 1 or 6.
Implementations SHOULD utilize UUID version 7 over UUID version 1 and
6 if possible."""
global _last_v7_timestamp
nanoseconds = time.time_ns()
if _last_v7_timestamp is not None and nanoseconds <= _last_v7_timestamp:
nanoseconds = _last_v7_timestamp + 1
_last_v7_timestamp = nanoseconds
timestamp_ms, timestamp_ns = divmod(nanoseconds, 10**6)
subsec = _subsec_encode(timestamp_ns)
subsec_a = subsec >> 8
subsec_b = subsec & 0xFF
rand = secrets.randbits(54)
uuid_int = (timestamp_ms & 0xFFFFFFFFFFFF) << 80
uuid_int |= subsec_a << 64
uuid_int |= subsec_b << 54
uuid_int |= rand
return UUID(int=uuid_int, version=7)
================================================
FILE: backend/bootstrap.py
================================================
import secrets
import asyncio
from sqlmodel import select, func
from sqlmodel.ext.asyncio.session import AsyncSession
from colorama import Fore, Style
import click
from app.core.db import get_db_async_session_context
from app.models import User, ChatEngine
async def ensure_admin_user(
session: AsyncSession, email: str | None = None, password: str | None = None
) -> None:
result = await session.exec(select(User).where(User.is_superuser == True))
user = result.first()
if not user:
from app.auth.users import create_user
admin_email = email or "admin@example.com"
admin_password = password or secrets.token_urlsafe(16)
user = await create_user(
session,
email=admin_email,
password=admin_password,
is_active=True,
is_verified=True,
is_superuser=True,
)
print(Fore.RED + "\n" + "!" * 80)
print(
Fore.RED + "[IMPORTANT] Admin user created with email: "
f"{admin_email} and password: {admin_password}"
)
print(Fore.RED + "!" * 80 + "\n" + Style.RESET_ALL)
else:
print(Fore.YELLOW + "Admin user already exists, skipping...")
async def reset_admin_password(
session: AsyncSession, new_password: str | None = None
) -> None:
result = await session.exec(select(User).where(User.is_superuser == True))
user = result.first()
if not user:
print(Fore.YELLOW + "Admin user does not exist, skipping reset password...")
else:
from app.auth.users import update_user_password
admin_password = new_password or secrets.token_urlsafe(16)
updated_user = await update_user_password(
session,
user_id=user.id,
new_password=admin_password,
)
print(
Fore.GREEN + "Admin user password reset SUCCESS!\n"
f"email: {updated_user.email} \n"
f"password: {admin_password}" + Style.RESET_ALL
)
async def ensure_default_chat_engine(session: AsyncSession) -> None:
result = await session.scalar(func.count(ChatEngine.id))
if result == 0:
from app.rag.chat.config import ChatEngineConfig
chat_engine = ChatEngine(
name="default",
engine_options=ChatEngineConfig().model_dump(),
is_default=True,
)
session.add(chat_engine)
await session.commit()
print("Default chat engine created.")
else:
print(Fore.YELLOW + "Default chat engine already exists, skipping...")
async def bootstrap(
email: str | None = None, password: str | None = None, reset_password: bool = False
) -> None:
async with get_db_async_session_context() as session:
await ensure_admin_user(session, email, password)
await ensure_default_chat_engine(session)
if reset_password:
await reset_admin_password(session, password)
@click.command()
@click.option(
"--email", default=None, help="Admin user email, default=admin@example.com"
)
@click.option(
"--password", default=None, help="Admin user password, default=random generated"
)
@click.option("--reset-password", "-r", is_flag=True, help="Reset admin user password.")
def main(email: str | None, password: str | None, reset_password: bool):
"""Bootstrap the application with optional admin credentials."""
print(Fore.GREEN + "Bootstrapping the application..." + Style.RESET_ALL)
asyncio.run(bootstrap(email, password, reset_password))
print(Fore.GREEN + "Bootstrapping completed." + Style.RESET_ALL)
if __name__ == "__main__":
main()
================================================
FILE: backend/dspy_compiled_program/decompose_query/demos.json
================================================
[
{
"augmented": true,
"query": "I am designing a table based on TiDB's TTL feature, but when I try to create the table using a cluster created with Serverless, I get a `'TTL' is not supported on TiDB Serverless` error.\n\nI plan to use Dedicated on my production environment and Serverless on my development environment, so it would be helpful if the TTL feature could be used in a Serverless environment.\n\nI've read the documentation that says Serverless will support TTL features in the future, but is there a specific timeline for this?\n\nAlso, is it possible to prevent TTL syntax from causing errors in Serverless?",
"subquestions": "```json\n{\n \"questions\": [\n {\n \"question\": \"Why is the TTL feature not supported on TiDB Serverless?\",\n \"reasoning\": \"Understanding the current limitations of TiDB Serverless regarding the TTL feature.\"\n },\n {\n \"question\": \"Is there a specific timeline for when TiDB Serverless will support the TTL feature?\",\n \"reasoning\": \"The user wants to know when they can expect the TTL feature to be available in TiDB Serverless.\"\n },\n {\n \"question\": \"How can I prevent TTL syntax from causing errors in TiDB Serverless?\",\n \"reasoning\": \"The user needs a workaround to avoid errors when using TTL syntax in their development environment.\"\n },\n {\n \"question\": \"What are the differences between TiDB Serverless and TiDB Dedicated regarding TTL feature support?\",\n \"reasoning\": \"Understanding the differences will help the user plan their production and development environments effectively.\"\n }\n ]\n}\n```"
},
{
"augmented": true,
"query": "What are the feature changes in the latest version compared to v7.0 for TiDB?",
"subquestions": "```json\n{\n \"questions\": [\n {\n \"question\": \"What is the latest version of TiDB?\",\n \"reasoning\": \"To compare feature changes, we need to identify the latest version of TiDB.\"\n },\n {\n \"question\": \"What are the new features introduced in the latest version of TiDB?\",\n \"reasoning\": \"To understand the changes, we need to know what new features have been added in the latest version.\"\n },\n {\n \"question\": \"What features were present in TiDB v7.0?\",\n \"reasoning\": \"To make a comparison, we need to know the features available in TiDB v7.0.\"\n },\n {\n \"question\": \"What features have been removed or deprecated in the latest version compared to v7.0?\",\n \"reasoning\": \"Identifying removed or deprecated features helps in understanding the differences between the two versions.\"\n },\n {\n \"question\": \"What are the improvements or changes in existing features from v7.0 to the latest version?\",\n \"reasoning\": \"To provide a comprehensive comparison, we need to know how existing features have been improved or changed.\"\n }\n ]\n}\n```"
},
{
"augmented": true,
"query": "summary the performance improvement from version 6.5 to newest version for TiDB",
"subquestions": "```json\n{\n \"questions\": [\n {\n \"question\": \"What is the newest version of TiDB?\",\n \"reasoning\": \"To understand the performance improvements, we first need to identify the latest version of TiDB.\"\n },\n {\n \"question\": \"What were the performance metrics and benchmarks for TiDB version 6.5?\",\n \"reasoning\": \"To compare performance improvements, we need to establish a baseline by understanding the performance metrics of version 6.5.\"\n },\n {\n \"question\": \"What are the performance metrics and benchmarks for the newest version of TiDB?\",\n \"reasoning\": \"To identify improvements, we need to know the performance metrics of the latest version.\"\n },\n {\n \"question\": \"What specific performance improvements have been made from TiDB version 6.5 to the newest version?\",\n \"reasoning\": \"This question directly addresses the user's query by comparing the performance metrics between the two versions.\"\n },\n {\n \"question\": \"Are there any notable features or optimizations in the newest version that contribute to performance improvements?\",\n \"reasoning\": \"Understanding the features or optimizations that lead to performance improvements can provide deeper insights into the changes made.\"\n }\n ]\n}\n```"
},
{
"query": "tidb lighting to sync to serverless cluster,but the load command and the tidb-lighting tools dont have the tls config like --ssl-ca or --ca. so i can not sync to the full back data to the serverless",
"subquestions": "{\"questions\":[{\"question\":\"Sync data to a serverless cluster using TiDB Lighting.\",\"reasoning\":\"The user case what the user wants to achieve\"},{\"question\":\"How to configure TLS for TiDB Lightning?\",\"reasoning\":\"The basic question what the user itentionally asked.\"},{\"question\":\"The sync issue is caused by the lack of TLS configuration options for TiDB Lightning.\",\"reasoning\":\"The problem that the user is facing.\"}]}"
},
{
"query": "We are new to TiDB and don't quite understand the potential impact on our application architecture. We are using TiDB for audit logs and continue to direct traffic to TiDB. We noticed a sudden jump ID from 1 to 30,001. Are there any impacts? Do we need to address this? If we have 100 connections from several applications, what will happen? In summary, what should we do for Auto Increment or do nothing?",
"subquestions": "{\"questions\":[{\"question\":\"Why Auto Increment in TiDB causes a sudden increase in the ID values?\",\"reasoning\":\"The main concerns that the user itentionally asked.\"},{\"question\":\"How 100 connections from several applications affect TiDB, especially when the Auto Increment causes a sudden jump in ID values?\",\"reasoning\":\"The second most important question that the user itentionally asked.\"},{\"question\":\"TiDB is used for storing audit logs and receiving continuous traffic.\",\"reasoning\":\"The user case what the user wants to achieve\"}]}"
},
{
"query": "Upgrade TiDB Serverless to 7.4 or latest for enhanced MySQL 8.0 compatibility",
"subquestions": "{\"questions\":[{\"question\":\"TiDB 7.4 or the latest version enhances compatibility with MySQL 8.0\",\"reasoning\":\"The reasoning why user wants to upgrade TiDB Serverless to 7.4 or latest for enhanced MySQL 8.0 compatibility\"},{\"question\":\"How to upgrade TiDB Serverless?\",\"reasoning\":\"The basic question what the user itentionally asked.\"}]}"
},
{
"query": "I am current using tidb serverless, but as my product grows, I really need a dalicated cluster. Is there a solution helps finish the migration?",
"subquestions": "{\"questions\":[{\"question\":\"How to migrate from TiDB serverless to TiDB dedicated cluster?\",\"reasoning\":\"The main concern of the user.\"}]}"
},
{
"query": "Chat2query is returning an error message saying \"Query timeout expired\". Additionally, I am unable to locate this SQL query in the slow query log.",
"subquestions": "{\"questions\":[{\"question\":\"Chat2query is returning an error message saying 'Query timeout expired'.\",\"reasoning\":\"The main problem the user is facing.\"},{\"question\":\"The reason why not to locate the SQL query in the slow query log.\",\"reasoning\":\"The secondary problem the user is facing.\"}]}"
}
]
================================================
FILE: backend/dspy_compiled_program/decompose_query/program.json
================================================
{
"prog": {
"traces": [],
"train": [],
"demos": [
{
"augmented": true,
"query": "I am designing a table based on TiDB's TTL feature, but when I try to create the table using a cluster created with Serverless, I get a `'TTL' is not supported on TiDB Serverless` error.\n\nI plan to use Dedicated on my production environment and Serverless on my development environment, so it would be helpful if the TTL feature could be used in a Serverless environment.\n\nI've read the documentation that says Serverless will support TTL features in the future, but is there a specific timeline for this?\n\nAlso, is it possible to prevent TTL syntax from causing errors in Serverless?",
"subquestions": "```json\n{\n \"questions\": [\n {\n \"question\": \"Why is the TTL feature not supported on TiDB Serverless?\",\n \"reasoning\": \"Understanding the current limitations of TiDB Serverless regarding the TTL feature.\"\n },\n {\n \"question\": \"Is there a specific timeline for when TiDB Serverless will support the TTL feature?\",\n \"reasoning\": \"The user wants to know when they can expect the TTL feature to be available in TiDB Serverless.\"\n },\n {\n \"question\": \"How can I prevent TTL syntax from causing errors in TiDB Serverless?\",\n \"reasoning\": \"The user needs a workaround to avoid errors when using TTL syntax in their development environment.\"\n },\n {\n \"question\": \"What are the differences between TiDB Serverless and TiDB Dedicated regarding TTL feature support?\",\n \"reasoning\": \"Understanding the differences will help the user plan their production and development environments effectively.\"\n }\n ]\n}\n```"
},
{
"augmented": true,
"query": "What are the feature changes in the latest version compared to v7.0 for TiDB?",
"subquestions": "```json\n{\n \"questions\": [\n {\n \"question\": \"What is the latest version of TiDB?\",\n \"reasoning\": \"To compare feature changes, we need to identify the latest version of TiDB.\"\n },\n {\n \"question\": \"What are the new features introduced in the latest version of TiDB?\",\n \"reasoning\": \"To understand the changes, we need to know what new features have been added in the latest version.\"\n },\n {\n \"question\": \"What features were present in TiDB v7.0?\",\n \"reasoning\": \"To make a comparison, we need to know the features available in TiDB v7.0.\"\n },\n {\n \"question\": \"What features have been removed or deprecated in the latest version compared to v7.0?\",\n \"reasoning\": \"Identifying removed or deprecated features helps in understanding the differences between the two versions.\"\n },\n {\n \"question\": \"What are the improvements or changes in existing features from v7.0 to the latest version?\",\n \"reasoning\": \"To provide a comprehensive comparison, we need to know how existing features have been improved or changed.\"\n }\n ]\n}\n```"
},
{
"augmented": true,
"query": "summary the performance improvement from version 6.5 to newest version for TiDB",
"subquestions": "```json\n{\n \"questions\": [\n {\n \"question\": \"What is the newest version of TiDB?\",\n \"reasoning\": \"To understand the performance improvements, we first need to identify the latest version of TiDB.\"\n },\n {\n \"question\": \"What were the performance metrics and benchmarks for TiDB version 6.5?\",\n \"reasoning\": \"To compare performance improvements, we need to establish a baseline by understanding the performance metrics of version 6.5.\"\n },\n {\n \"question\": \"What are the performance metrics and benchmarks for the newest version of TiDB?\",\n \"reasoning\": \"To identify improvements, we need to know the performance metrics of the latest version.\"\n },\n {\n \"question\": \"What specific performance improvements have been made from TiDB version 6.5 to the newest version?\",\n \"reasoning\": \"This question directly addresses the user's query by comparing the performance metrics between the two versions.\"\n },\n {\n \"question\": \"Are there any notable features or optimizations in the newest version that contribute to performance improvements?\",\n \"reasoning\": \"Understanding the features or optimizations that lead to performance improvements can provide deeper insights into the changes made.\"\n }\n ]\n}\n```"
},
{
"query": "tidb lighting to sync to serverless cluster,but the load command and the tidb-lighting tools dont have the tls config like --ssl-ca or --ca. so i can not sync to the full back data to the serverless",
"subquestions": "{\"questions\":[{\"question\":\"Sync data to a serverless cluster using TiDB Lighting.\",\"reasoning\":\"The user case what the user wants to achieve\"},{\"question\":\"How to configure TLS for TiDB Lightning?\",\"reasoning\":\"The basic question what the user itentionally asked.\"},{\"question\":\"The sync issue is caused by the lack of TLS configuration options for TiDB Lightning.\",\"reasoning\":\"The problem that the user is facing.\"}]}"
},
{
"query": "We are new to TiDB and don't quite understand the potential impact on our application architecture. We are using TiDB for audit logs and continue to direct traffic to TiDB. We noticed a sudden jump ID from 1 to 30,001. Are there any impacts? Do we need to address this? If we have 100 connections from several applications, what will happen? In summary, what should we do for Auto Increment or do nothing?",
"subquestions": "{\"questions\":[{\"question\":\"Why Auto Increment in TiDB causes a sudden increase in the ID values?\",\"reasoning\":\"The main concerns that the user itentionally asked.\"},{\"question\":\"How 100 connections from several applications affect TiDB, especially when the Auto Increment causes a sudden jump in ID values?\",\"reasoning\":\"The second most important question that the user itentionally asked.\"},{\"question\":\"TiDB is used for storing audit logs and receiving continuous traffic.\",\"reasoning\":\"The user case what the user wants to achieve\"}]}"
},
{
"query": "Upgrade TiDB Serverless to 7.4 or latest for enhanced MySQL 8.0 compatibility",
"subquestions": "{\"questions\":[{\"question\":\"TiDB 7.4 or the latest version enhances compatibility with MySQL 8.0\",\"reasoning\":\"The reasoning why user wants to upgrade TiDB Serverless to 7.4 or latest for enhanced MySQL 8.0 compatibility\"},{\"question\":\"How to upgrade TiDB Serverless?\",\"reasoning\":\"The basic question what the user itentionally asked.\"}]}"
},
{
"query": "I am current using tidb serverless, but as my product grows, I really need a dalicated cluster. Is there a solution helps finish the migration?",
"subquestions": "{\"questions\":[{\"question\":\"How to migrate from TiDB serverless to TiDB dedicated cluster?\",\"reasoning\":\"The main concern of the user.\"}]}"
},
{
"query": "Chat2query is returning an error message saying \"Query timeout expired\". Additionally, I am unable to locate this SQL query in the slow query log.",
"subquestions": "{\"questions\":[{\"question\":\"Chat2query is returning an error message saying 'Query timeout expired'.\",\"reasoning\":\"The main problem the user is facing.\"},{\"question\":\"The reason why not to locate the SQL query in the slow query log.\",\"reasoning\":\"The secondary problem the user is facing.\"}]}"
}
],
"signature": {
"instructions": "You are an expert in knowledge base graph construction, specializing in building comprehensive knowledge graphs.\nYour current task is to deconstruct the user's query into a series of step-by-step questions.\n\n## Instructions:\n\n1. Dependency Analysis:\n\n - Analyze the user's query to identify the underlying dependencies and relationships between different components.\n - Construct a dependency graph that visually represents these relationships.\n\n2. Question Breakdown: Divide the query into a sequence of step-by-step questions necessary to address the main query comprehensively.\n\n3. Provide Reasoning: Explain the rationale behind each question.\n\n4. Constraints:\n - Limit the output to no more than 5 questions to maintain focus and relevance.\n - Ensure accuracy by reflecting the user's true intentions based on the provided query.\n - Ground all questions in factual information derived directly from the user's input.\n\nPlease only response in JSON format.",
"fields": [
{
"prefix": "Query:",
"description": "The query text to extract the user's step-by-step questions."
},
{
"prefix": "Subquestions:",
"description": "Representation of the user's step-by-step questions extracted from the query."
}
]
},
"lm": null
},
"metadata": {
"dependency_versions": {}
}
}
================================================
FILE: backend/dspy_program.py
================================================
import os
import dspy
from app.rag.question_gen.query_decomposer import DecomposeQueryModule
def save_decompose_query_program():
dspy_lm = dspy.LM(model="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY"))
module = DecomposeQueryModule(dspy_lm)
module.save("dspy_compiled_program/decompose_query/program.json")
if __name__ == "__main__":
save_decompose_query_program()
================================================
FILE: backend/local_embedding_reranker/.dockerignore
================================================
.venv
================================================
FILE: backend/local_embedding_reranker/Dockerfile
================================================
FROM python:3.11.9-slim
WORKDIR /app
COPY requirements.txt /app/requirements.txt
RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -r /app/requirements.txt
COPY . /app/
ENV PYTHONPATH=/app
EXPOSE 5001
CMD ["uvicorn", "main:app", "--port", "5001", "--host", "0.0.0.0"]
================================================
FILE: backend/local_embedding_reranker/main.py
================================================
import logging
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI, APIRouter
from sentence_transformers import SentenceTransformer, CrossEncoder
from contextlib import asynccontextmanager
from environs import Env
env = Env()
env.read_env()
logging.basicConfig(
format="%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
PRE_LOAD_DEFAULT_EMBEDDING_MODEL = env.bool(
"PRE_LOAD_DEFAULT_EMBEDDING_MODEL", default=True
)
PRE_LOAD_DEFAULT_RERANKER_MODEL = env.bool(
"PRE_LOAD_DEFAULT_RERANKER_MODEL", default=False
)
DEFAULT_EMBEDDING_MODEL = env.str("DEFAULT_EMBEDDING_MODEL", default="BAAI/bge-m3")
DEFAULT_RERANKER_MODEL = env.str(
"DEFAULT_RERANKER_MODEL", default="BAAI/bge-reranker-v2-m3"
)
router = APIRouter()
@router.get("/healthz")
def healthz():
return "OK"
EMBEDDING_MODEL_DICT: dict[str, SentenceTransformer] = {}
RERANKER_MODEL_DICT: dict[str, CrossEncoder] = {}
def get_embedding_model(model_name: str) -> SentenceTransformer:
global EMBEDDING_MODEL_DICT
embed_model = EMBEDDING_MODEL_DICT.get(model_name)
if not embed_model:
embed_model = SentenceTransformer(
model_name_or_path=model_name,
trust_remote_code=True,
)
EMBEDDING_MODEL_DICT[model_name] = embed_model
return embed_model
def get_reranker_model(model_name: str) -> CrossEncoder:
global RERANKER_MODEL_DICT
reranker_model = RERANKER_MODEL_DICT.get(model_name)
if not reranker_model:
reranker_model = CrossEncoder(
model_name=model_name,
automodel_args={"torch_dtype": "auto"},
trust_remote_code=True,
)
RERANKER_MODEL_DICT[model_name] = reranker_model
return reranker_model
class EmbeddingRequest(BaseModel):
sentences: list[str]
model: str = DEFAULT_EMBEDDING_MODEL
normalize_embeddings: bool = True
class EmbeddingResponse(BaseModel):
model: str
embeddings: list[list]
@router.post("/embedding")
def get_texts_embedding(request: EmbeddingRequest) -> EmbeddingResponse:
embed_model = get_embedding_model(model_name=request.model)
embeddings = embed_model.encode(
sentences=request.sentences,
normalize_embeddings=request.normalize_embeddings,
)
return EmbeddingResponse(
model=request.model,
embeddings=embeddings.tolist(),
)
class RerankerRequest(BaseModel):
model: str = DEFAULT_RERANKER_MODEL
query: str
passages: list[str]
class RerankerResponse(BaseModel):
model: str
scores: list[float]
@router.post("/reranker")
def reranker_texts(request: RerankerRequest) -> RerankerResponse:
reranker_model = get_reranker_model(request.model)
sentence_pairs = [(request.query, p) for p in request.passages]
scores = reranker_model.predict(sentence_pairs, convert_to_tensor=True)
return RerankerResponse(model=request.model, scores=scores.tolist())
@asynccontextmanager
async def lifespan(app: FastAPI):
if PRE_LOAD_DEFAULT_EMBEDDING_MODEL:
logger.info(f"Loading default embedding model: {DEFAULT_EMBEDDING_MODEL}")
get_embedding_model(DEFAULT_EMBEDDING_MODEL)
logger.info("Default embedding model loaded")
if PRE_LOAD_DEFAULT_RERANKER_MODEL:
logger.info(f"Loading default reranker model: {DEFAULT_RERANKER_MODEL}")
get_reranker_model(DEFAULT_RERANKER_MODEL)
logger.info("Default reranker model loaded")
yield
app = FastAPI(lifespan=lifespan)
app.include_router(router=router, prefix="/api/v1")
if __name__ == "__main__":
uvicorn.run("main:app", host="0.0.0.0", port=5001, reload=True)
================================================
FILE: backend/local_embedding_reranker/requirements.txt
================================================
fastapi==0.112.2
uvicorn==0.30.6
sentence_transformers==3.0.1
einops==0.8.0
environs==11.1.0
================================================
FILE: backend/main.py
================================================
import warnings
import logging
from logging.config import dictConfig
from contextlib import asynccontextmanager
import click
import sentry_sdk
import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi.routing import APIRoute
from starlette.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from app.api.main import api_router
from app.core.config import settings, Environment
from app.site_settings import SiteSetting
from app.utils.uuid6 import uuid7
dictConfig(
{
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s",
},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "default",
},
},
"root": {
"level": logging.INFO
if settings.ENVIRONMENT != Environment.LOCAL
else logging.DEBUG,
"handlers": ["console"],
},
"loggers": {
"uvicorn.error": {
"level": "ERROR",
"handlers": ["console"],
"propagate": False,
},
"uvicorn.access": {
"level": "INFO",
"handlers": ["console"],
"propagate": False,
},
},
}
)
logger = logging.getLogger(__name__)
load_dotenv()
def custom_generate_unique_id(route: APIRoute) -> str:
return f"{route.tags[0]}-{route.name}"
if settings.SENTRY_DSN and settings.ENVIRONMENT != "local":
sentry_sdk.init(
dsn=str(settings.SENTRY_DSN),
enable_tracing=True,
traces_sample_rate=settings.SENTRY_TRACES_SAMPLE_RATE,
profiles_sample_rate=settings.SENTRY_PROFILES_SAMPLE_RATE,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
SiteSetting.update_db_cache()
yield
app = FastAPI(
title=settings.PROJECT_NAME,
openapi_url=f"{settings.API_V1_STR}/openapi.json",
generate_unique_id_function=custom_generate_unique_id,
lifespan=lifespan,
)
# Set all CORS enabled origins
if settings.BACKEND_CORS_ORIGINS:
app.add_middleware(
CORSMiddleware,
allow_origins=[
str(origin).strip("/") for origin in settings.BACKEND_CORS_ORIGINS
],
allow_origin_regex=settings.BACKEND_CORS_ORIGIN_REGEXP,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.middleware("http")
async def identify_browser(request: Request, call_next):
browser_id = request.cookies.get(settings.BROWSER_ID_COOKIE_NAME)
has_browser_id = bool(browser_id)
if not browser_id:
browser_id = uuid7()
request.state.browser_id = browser_id
response: Response = await call_next(request)
if not has_browser_id:
response.set_cookie(
settings.BROWSER_ID_COOKIE_NAME,
browser_id,
max_age=settings.BROWSER_ID_COOKIE_MAX_AGE,
)
return response
app.include_router(api_router, prefix=settings.API_V1_STR)
@click.group(context_settings={"max_content_width": 150})
def cli():
pass
@cli.command()
@click.option("--host", default="127.0.0.1", help="Host, default=127.0.0.1")
@click.option("--port", default=3000, help="Port, default=3000")
@click.option("--workers", default=4)
def runserver(host, port, workers):
warnings.warn(
"This command will start the server in development mode, do not use it in production."
)
uvicorn.run(
"main:app",
host=host,
port=port,
reload=True,
log_level="debug",
workers=workers,
)
@cli.command()
@click.option(
"--dataset", default="regression", help="Dataset name, default=regression"
)
@click.option(
"--llm-provider",
default="openai",
help="LLM provider, default=openai, options=[openai, gemini]",
)
@click.option("--run-name", default=None, help="Run name, default=None")
@click.option(
"--tidb-ai-chat-engine",
default="default",
help="TiDB AI chat engine, default=default",
)
def runeval(dataset, llm_provider, run_name, tidb_ai_chat_engine):
from app.evaluation.evals import Evaluation
eval = Evaluation(
dataset_name=dataset,
llm_provider=llm_provider,
run_name=run_name,
tidb_ai_chat_engine=tidb_ai_chat_engine,
)
eval.run()
@cli.command()
@click.option(
"--csv",
default="autoflow_dataset.csv",
help="Dataset CSV file name that contains two columns `query` and `reference`, default='autoflow_dataset.csv'",
)
@click.option(
"--llm-provider",
default="openai",
help="LLM provider, default=openai, options=[openai, gemini]",
)
@click.option("--run-name", default=None, help="Run name, default=None")
@click.option(
"--tidb-ai-chat-engine",
default="default",
help="TiDB AI chat engine, default=default",
)
@click.option("--run-size", default=30, help="Run size, default=30")
def runeval_dataset(csv, llm_provider, run_name, tidb_ai_chat_engine, run_size):
from app.evaluation.evals import Evaluation
evaluation = Evaluation(
dataset_name="customize",
llm_provider=llm_provider,
run_name=run_name,
tidb_ai_chat_engine=tidb_ai_chat_engine,
)
evaluation.runeval_dataset(csv_dataset=csv, run_size=run_size)
@cli.command()
@click.option("--query", default=None, help="query")
def generate_answer_by_tidb_ai(query: str):
from app.evaluation.evals import Evaluation
evaluation = Evaluation(
dataset_name="customize",
llm_provider="openai",
run_name=None,
tidb_ai_chat_engine="default",
)
print(
evaluation.generate_answer_by_tidb_ai(
messages=[{"role": "user", "content": query}]
)
)
if __name__ == "__main__":
cli()
================================================
FILE: backend/prestart.sh
================================================
#! /usr/bin/env bash
================================================
FILE: backend/pyproject.toml
================================================
[project]
name = "tidb-ai-backend"
version = "0.2.9"
description = "The Backend Server for TiDB.AI"
authors = [
{ name = "wd0517", email = "me@wangdi.ink" },
{ name = "mini256", email = "minianter@gmail.com" }
]
readme = "README.md"
requires-python = ">= 3.10"
dependencies = [
"fastapi>=0.115.6",
"fastapi-cli>=0.0.5",
"fastapi-users>=13.0.0",
"fastapi-pagination>=0.12.25",
"fastapi-users-db-sqlmodel>=0.3.0",
"sqlmodel==0.0.19",
"pymysql>=1.1.1",
"asyncmy>=0.2.9",
"tidb-vector>=0.0.14",
"alembic>=1.14.0",
"pydantic>=2.10.5",
# Update Check: https://github.com/pydantic/pydantic/issues/8061
"pydantic-settings>=2.3.3",
"redis>=5.0.5",
"celery>=5.4.0",
"flower>=2.0.1",
"httpx-oauth>=0.14.1",
"uvicorn>=0.30.3",
"gunicorn>=22.0.0",
"python-dotenv>=1.0.1",
"sentry-sdk>=2.5.1",
"langfuse>=2.59.1",
"langchain-openai>=0.2.9",
# Pin ragas to 0.2.6 to avoid async loop error:
# https://github.com/explodinggradients/ragas/issues/1819
"ragas==0.2.6",
"deepeval>=0.21.73",
"click>=8.1.7",
"tenacity~=8.4.0",
"retry>=0.9.2",
"deepdiff>=7.0.1",
"colorama>=0.4.6",
"jinja2>=3.1.4",
"pyyaml>=6.0.1",
"playwright>=1.45.1",
"markdownify>=0.13.1",
"pypdf>=4.3.1",
"python-docx>=1.1.2",
"python-pptx>=1.0.2",
"openpyxl>=3.1.5",
"llama-index>=0.12.16",
"llama-index-llms-openai>=0.6.12",
"llama-index-llms-openai-like>=0.5.3",
"llama-index-llms-bedrock-converse>=0.4.15",
"llama-index-llms-ollama>=0.5.0",
"llama-index-llms-google-genai>=0.1.6",
"llama-index-llms-azure-openai>=0.3.0",
"llama-index-embeddings-ollama>=0.5.0",
"llama-index-embeddings-jinaai>=0.4.0",
"llama-index-embeddings-cohere>=0.4.0",
"llama-index-embeddings-bedrock>=0.4.0",
"llama-index-embeddings-azure-openai>=0.3.0",
"llama-index-postprocessor-jinaai-rerank>=0.3.0",
"llama-index-postprocessor-cohere-rerank>=0.3.0",
"llama-index-postprocessor-xinference-rerank>=0.2.0",
"llama-index-postprocessor-bedrock-rerank>=0.3.0",
"llama-index-llms-vertex>=0.4.2",
"socksio>=1.0.0",
"dspy>=2.6.21",
"litellm>=1.77.5",
]
[dependency-groups]
dev = [
"pre-commit>=4.2.0",
"pytest-asyncio>=0.25.3",
"ruff>=0.11.2",
]
# Lint
[tool.ruff.lint]
ignore = ["E711", "E712", "F811", "F841"]
# Test
[tool.pytest.ini_options]
log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
asyncio_mode = "auto"
env_files = ["tests/.env"]
================================================
FILE: backend/supervisord.conf
================================================
[supervisord]
nodaemon=true
logfile=/var/log/supervisord.log
[program:celery_worker]
command=celery -A app.celery worker -n worker-default@%%h -Q default --concurrency=5 --loglevel=INFO --logfile=/var/log/celery_worker.log
directory=/app
stdout_logfile=/var/log/celery_worker_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
[program:evaluation_worker]
command=celery -A app.celery worker -n worker-evaluation@%%h -Q evaluation --pool=solo --loglevel=INFO --logfile=/var/log/evaluation_worker.log
directory=/app
stdout_logfile=/var/log/evaluation_worker_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
[program:celery_flower]
command=celery -A app.celery flower --address=0.0.0.0 --port=5555
directory=/app
stdout_logfile=/var/log/celery_flower_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
# Pushes all logs from the above programs to stdout
# No log rotation here, since it's stdout it's handled by the Docker container loglevel
# To be standard across all the services
[program:log-redirect-handler]
command=tail -qF /var/log/celery_worker.log /var/log/celery_worker_supervisor.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
redirect_stderr=true
autorestart=true
================================================
FILE: backend/tests/__init__.py
================================================
================================================
FILE: backend/tests/conftest.py
================================================
import pytest
from dotenv import load_dotenv
@pytest.fixture(scope="session", autouse=True)
def env():
print("Loading environment variables")
load_dotenv()
================================================
FILE: backend/tests/test_dynamic_models.py
================================================
import threading
from app.models.entity import get_dynamic_entity_model
from app.models.relationship import get_dynamic_relationship_model
from app.models.chunk import get_dynamic_chunk_model
def dynamic_model_creation(dim, ns):
entity_model = get_dynamic_entity_model(dim, ns)
relationship_model = get_dynamic_relationship_model(dim, ns, entity_model)
chunk_model = get_dynamic_chunk_model(dim, ns)
return entity_model, relationship_model, chunk_model
def test_concurrent_dynamic_model_creation():
results = [None] * 10
threads = []
for i in range(10):
t = threading.Thread(
target=lambda idx: results.__setitem__(
idx, dynamic_model_creation(128, "test")
),
args=(i,),
)
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
# Ensure each model is created only once across all threads
entity_models, relationship_models, chunk_models = zip(*results)
assert all(m is entity_models[0] for m in entity_models)
assert all(m is relationship_models[0] for m in relationship_models)
assert all(m is chunk_models[0] for m in chunk_models)
================================================
FILE: backend/tests/test_llms.py
================================================
import json
import os
import logging
from logging import getLogger
import pytest
import dspy
from litellm import verbose_logger
from llama_index.core import PromptTemplate
from llama_index.core.base.llms.base import BaseLLM
from app.rag.indices.knowledge_graph.extractor import Extractor
from app.rag.llms.provider import LLMProvider
from app.rag.llms.resolver import resolve_llm
from app.rag.question_gen.query_decomposer import QueryDecomposer
from app.rag.llms.dspy import get_dspy_lm_by_llama_llm
question = "Is TiDB open source? (Yes/No)"
content = """
TiDB is a distributed database that you can use the MySQL client to connect to.
"""
os.environ["LITELLM_LOG"] = "DEBUG"
verbose_logger.setLevel(logging.WARN)
logger = getLogger(__name__)
def check_llm_answer_simple_question(llm: BaseLLM):
prompt = PromptTemplate(question)
output = llm.predict(prompt)
assert "yes" in output.lower()
logger.info(f"Generated answer: \n{output}")
def check_dspy_lm_decompose_question(lm: dspy.LM):
decomposer = QueryDecomposer(lm)
subquestions = decomposer.decompose("What is TiDB").questions
assert len(subquestions) >= 1
questions = [q.question for q in subquestions]
logger.info(f"Generated sub-question: \n{questions}")
def check_dspy_lm_extract_graph(lm: dspy.LM):
extractor = Extractor(lm)
kg = extractor.forward(content).knowledge
assert len(kg.entities) >= 2
assert len(kg.relationships) >= 1
entities = [f"{e.name}: {e.description}" for e in kg.entities]
relationships = [
f"{r.source_entity} -> {r.relationship_desc} -> {r.target_entity}"
for r in kg.relationships
]
logger.info(f"Extracted entities: \n{entities}")
logger.info(f"Extracted relationships: \n{relationships}")
def test_openai():
llm = resolve_llm(
provider=LLMProvider.OPENAI,
model="gpt-4o-mini",
config={},
credentials=os.getenv("OPENAI_API_KEY"),
)
check_llm_answer_simple_question(llm)
lm = get_dspy_lm_by_llama_llm(llm)
check_dspy_lm_decompose_question(lm)
check_dspy_lm_extract_graph(lm)
@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS"), reason="ollama is not available on GitHub Actions"
)
def test_ollama():
llm = resolve_llm(
provider=LLMProvider.OLLAMA,
model="gemma3:4b",
)
check_llm_answer_simple_question(llm)
lm = get_dspy_lm_by_llama_llm(llm)
check_dspy_lm_decompose_question(lm)
check_dspy_lm_extract_graph(lm)
def test_gitee_ai():
llm = resolve_llm(
provider=LLMProvider.GITEEAI,
model="Qwen2.5-72B-Instruct",
credentials=os.getenv("GITEEAI_API_KEY"),
)
check_llm_answer_simple_question(llm)
lm = get_dspy_lm_by_llama_llm(llm)
check_dspy_lm_decompose_question(lm)
check_dspy_lm_extract_graph(lm)
def test_bedrock():
llm = resolve_llm(
provider=LLMProvider.BEDROCK,
model="meta.llama3-1-70b-instruct-v1:0",
credentials={
"aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
"aws_region_name": os.getenv("AWS_REGION_NAME"),
},
)
check_llm_answer_simple_question(llm)
lm = get_dspy_lm_by_llama_llm(llm)
check_dspy_lm_decompose_question(lm)
check_dspy_lm_extract_graph(lm)
def test_vertex_ai():
llm = resolve_llm(
provider=LLMProvider.VERTEX_AI,
model="gemini-2.0-flash-001",
credentials=json.loads(os.getenv("GOOGLE_SERVICE_ACCOUNT_CREDENTIALS")),
config={"location": "us-west1"},
)
check_llm_answer_simple_question(llm)
lm = get_dspy_lm_by_llama_llm(llm)
check_dspy_lm_decompose_question(lm)
check_dspy_lm_extract_graph(lm)
def test_gemini():
llm = resolve_llm(
provider=LLMProvider.GEMINI,
model="models/gemini-2.0-flash-001",
credentials=os.getenv("GEMINI_API_KEY"),
)
check_llm_answer_simple_question(llm)
lm = get_dspy_lm_by_llama_llm(llm)
check_dspy_lm_decompose_question(lm)
check_dspy_lm_extract_graph(lm)
def test_azure_ai():
llm = resolve_llm(
provider=LLMProvider.AZURE_OPENAI,
model="gpt-4o-mini",
credentials=os.getenv("AZURE_AI_API_KEY"),
config={
"azure_endpoint": os.getenv("AZURE_AI_ENDPOINT"),
"engine": "gpt-4o",
"api_version": "2025-01-01-preview",
},
)
check_llm_answer_simple_question(llm)
lm = get_dspy_lm_by_llama_llm(llm)
check_dspy_lm_decompose_question(lm)
check_dspy_lm_extract_graph(lm)
================================================
FILE: core/.cursor/rules/code-style.mdc
================================================
---
description:
globs:
alwaysApply: true
---
# Code Style
- Always sort the imports with the order (standard, third-party, local)
================================================
FILE: core/.gitignore
================================================
# Python generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
.mypy_cache
# Environment
.env
# Virtual environment
.venv
# IDE
.idea
# pytest
.pytest_cache
================================================
FILE: core/.python-version
================================================
3.12.4
================================================
FILE: core/Makefile
================================================
install:
@pip install uv
@uv sync
test:
@PYTHONPATH=$(PWD) uv run pytest tests
build:
@PYTHONPATH=$(PWD) uv build
publish:
@uv publish
lint:
@uv tool run ruff check
format:
@uv tool run format
================================================
FILE: core/README.md
================================================
# Autoflow
A unified storage layer for AI agents and RAG applications.
================================================
FILE: core/autoflow/__init__.py
================================================
import os
from .main import Autoflow
if os.getenv("LITELLM_LOCAL_MODEL_COST_MAP") is None:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
__all__ = [
"Autoflow",
]
================================================
FILE: core/autoflow/chunkers/__init__.py
================================================
================================================
FILE: core/autoflow/chunkers/base.py
================================================
from abc import abstractmethod
from autoflow.types import BaseComponent
from autoflow.storage.doc_store import Document
class Chunker(BaseComponent):
@abstractmethod
def chunk(self, document: Document) -> Document:
raise NotImplementedError
================================================
FILE: core/autoflow/chunkers/helper.py
================================================
from autoflow.chunkers.base import Chunker
from autoflow.data_types import DataType
def get_chunker_for_datatype(datatype: DataType) -> Chunker:
if datatype in [DataType.MARKDOWN, DataType.HTML, DataType.PDF]:
from autoflow.chunkers.text import TextChunker
return TextChunker()
else:
raise ValueError(f"Unsupported data type: {datatype}")
================================================
FILE: core/autoflow/chunkers/text.py
================================================
from typing import Optional
from autoflow.chunkers.base import Chunker
from autoflow.configs.chunkers.text import TextChunkerConfig
from autoflow.storage.doc_store import Document, Chunk
class TextChunker(Chunker):
"""Chunker for text."""
def __init__(self, config: Optional[TextChunkerConfig] = TextChunkerConfig()):
super().__init__()
from llama_index.core.node_parser import SentenceSplitter
self._splitter = SentenceSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
)
def chunk(self, document: Document) -> Document:
texts = self._splitter.split_text(document.content)
document.chunks = [Chunk(text=text) for text in texts]
return document
================================================
FILE: core/autoflow/configs/__init__.py
================================================
================================================
FILE: core/autoflow/configs/chunkers/__init__.py
================================================
================================================
FILE: core/autoflow/configs/chunkers/base.py
================================================
from typing import Optional, Dict
from pydantic import BaseModel, Field, model_validator
class ChunkerConfig(BaseModel):
provider: str = Field(
description="Provider of the chunker (e.g., 'text')",
default="openai",
)
config: Optional[Dict] = Field(
description="Configuration for the specific chunker",
default=None,
)
_provider_configs: Dict[str, str] = {
"text": "TextChunkerConfig",
}
@model_validator(mode="after")
def validate_and_create_config(self) -> "ChunkerConfig":
provider = self.provider
config = self.config
if provider not in self._provider_configs:
raise ValueError(f"Unsupported chunker provider: {provider}")
module = __import__(
f"autoflow.configs.chunkers.{provider}",
fromlist=[self._provider_configs[provider]],
)
config_class = getattr(module, self._provider_configs[provider])
if config is None:
config = {}
if not isinstance(config, dict):
if not isinstance(config, config_class):
raise ValueError(
f"Invalid config type for chunker provider: {provider}"
)
return self
self.config = config_class(**config)
return self
================================================
FILE: core/autoflow/configs/chunkers/text.py
================================================
from pydantic import BaseModel, Field
DEFAULT_CHUNK_SIZE = 1024 # tokens
DEFAULT_CHUNK_OVERLAP = 20 # tokens
class TextChunkerConfig(BaseModel):
chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, description="Chunk size")
chunk_overlap: int = Field(
default=DEFAULT_CHUNK_OVERLAP, description="Chunk overlap"
)
================================================
FILE: core/autoflow/configs/db.py
================================================
from typing import Optional
from pydantic import BaseModel, Field
class DatabaseConfig(BaseModel):
database_url: Optional[str] = Field(
description="Database connection string",
default=None,
)
provider: Optional[str] = Field(
description="Database provider",
default="tidb",
)
host: Optional[str] = Field(
description="Database host.",
default="localhost",
)
port: Optional[int] = Field(description="Database port.", default=4000)
username: Optional[str] = Field(
description="The username to connect the database.",
default="root",
)
password: Optional[str] = Field(
description="The password to connect the database.",
default="",
)
database: str = Field(
description="Default name for the database",
default="autoflow",
)
enable_ssl: Optional[bool] = Field(
description="Enable SSL connection.",
default=True,
)
================================================
FILE: core/autoflow/configs/knowledge_base.py
================================================
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
from autoflow.configs.models.llms import LLMConfig
from autoflow.configs.models.embeddings import EmbeddingModelConfig
from autoflow.configs.models.rerankers import RerankerConfig
# Index Methods
class IndexMethod(str, Enum):
VECTOR_SEARCH = "VECTOR_SEARCH"
FULLTEXT_SEARCH = "FULLTEXT_SEARCH"
KNOWLEDGE_GRAPH = "KNOWLEDGE_GRAPH"
DEFAULT_INDEX_METHODS = [IndexMethod.VECTOR_SEARCH]
# Knowledge Base Config
class Version(int, Enum):
V1 = 1
class KnowledgeBaseConfig(BaseModel):
version: int = Version.V1
name: str
description: Optional[str] = None
index_methods: List[IndexMethod] = DEFAULT_INDEX_METHODS
llm: LLMConfig = None
embedding_model: EmbeddingModelConfig = None
reranker: RerankerConfig = None
================================================
FILE: core/autoflow/configs/main.py
================================================
from pydantic import BaseModel
from autoflow.configs.db import DatabaseConfig
class Config(BaseModel):
db: DatabaseConfig = DatabaseConfig()
================================================
FILE: core/autoflow/configs/models/__init__.py
================================================
================================================
FILE: core/autoflow/configs/models/embeddings/__init__.py
================================================
from .base import EmbeddingModelConfig
__all__ = [
"EmbeddingModelConfig",
]
================================================
FILE: core/autoflow/configs/models/embeddings/base.py
================================================
from typing import Optional, Dict
from pydantic import BaseModel, Field, model_validator
from autoflow.configs.models.providers import ModelProviders
class EmbeddingModelConfig(BaseModel):
provider: ModelProviders = Field(
description="Provider of the embedding_models models (e.g., 'openai')",
default=ModelProviders.OPENAI,
)
config: Optional[Dict] = Field(
description="Configuration for the specific embedding_models model",
default=None,
)
_provider_configs: Dict[str, str] = {
"openai": "OpenAIEmbeddingConfig",
"jina_ai": "JinaAIEmbeddingConfig",
}
@model_validator(mode="after")
def validate_and_create_config(self) -> "EmbeddingModelConfig":
provider = self.provider.value
config = self.config
if provider not in self._provider_configs:
raise ValueError(f"Unsupported embedding_models provider: {provider}")
module = __import__(
f"autoflow.configs.models.embeddings.{provider}",
fromlist=[self._provider_configs[provider]],
)
config_class = getattr(module, self._provider_configs[provider])
if config is None:
config = {}
if not isinstance(config, dict):
if not isinstance(config, config_class):
raise ValueError(
f"Invalid config type for embedding_models provider: {provider}"
)
return self
self.config = config_class(**config)
return self
================================================
FILE: core/autoflow/configs/models/embeddings/common.py
================================================
from typing import Optional
from pydantic import BaseModel, Field
class BaseEmbeddingConfig(BaseModel):
model: str = Field(
description="The model to use for the embedding",
default="text-embedding-3-small",
)
max_tokens: Optional[int] = None
temperature: float = 0.1
================================================
FILE: core/autoflow/configs/models/embeddings/jina_ai.py
================================================
from autoflow.configs.models.embeddings.common import BaseEmbeddingConfig
class JinaAIEmbeddingConfig(BaseEmbeddingConfig):
pass
================================================
FILE: core/autoflow/configs/models/embeddings/openai.py
================================================
from autoflow.configs.models.embeddings.common import BaseEmbeddingConfig
class OpenAIEmbeddingConfig(BaseEmbeddingConfig):
pass
================================================
FILE: core/autoflow/configs/models/llms/__init__.py
================================================
from .base import LLMConfig
__all__ = ["LLMConfig"]
================================================
FILE: core/autoflow/configs/models/llms/base.py
================================================
from typing import Optional, Dict
from pydantic import BaseModel, Field, model_validator
from autoflow.configs.models.providers import ModelProviders
DEFAULT_TEMPERATURE = 0.1
class LLMConfig(BaseModel):
provider: ModelProviders = Field(
description="Provider of the large language models (LLM) (e.g., 'openai')",
default=ModelProviders.OPENAI,
)
config: Optional[Dict] = Field(
description="Configuration for the specific database",
default=None,
)
_llm_configs: Dict[str, str] = {
"openai": "OpenAILLMConfig",
"custom": "CustomLLMConfig",
}
@model_validator(mode="after")
def validate_and_create_config(self) -> "LLMConfig":
provider = self.provider.value
config = self.config
if provider not in self._llm_configs:
raise ValueError(f"Unsupported llm provider: {provider}")
module = __import__(
f"autoflow.configs.models.llms.{provider}",
fromlist=[self._llm_configs[provider]],
)
config_class = getattr(module, self._llm_configs[provider])
if config is None:
config = {}
if not isinstance(config, dict):
if not isinstance(config, config_class):
raise ValueError(f"Invalid config type for llm provider: {provider}")
return self
self.config = config_class(**config)
return self
================================================
FILE: core/autoflow/configs/models/llms/common.py
================================================
from typing import Optional
from pydantic import Field, BaseModel
class BaseLLMConfig(BaseModel):
model: str = Field(
description="The model to use for the LLM",
default="gpt-4o",
)
max_tokens: Optional[int] = None
temperature: float = 0.1
================================================
FILE: core/autoflow/configs/models/llms/openai.py
================================================
from autoflow.configs.models.llms.common import BaseLLMConfig
class OpenAILLMConfig(BaseLLMConfig):
pass
================================================
FILE: core/autoflow/configs/models/manager.py
================================================
from typing import Dict
from pydantic import BaseModel, Field
from autoflow.configs.models.providers import ProviderConfig
class ManagerConfig(BaseModel):
providers: Dict[str, ProviderConfig] = Field(default_factory=dict)
================================================
FILE: core/autoflow/configs/models/providers/__init__.py
================================================
from autoflow.configs.models.providers.base import (
ModelProviders,
ModelProviderInfo,
ProviderConfig,
)
model_providers = [
ModelProviderInfo(
name=ModelProviders.OPENAI,
display_name="OpenAI",
description="The OpenAI API provides a simple interface for developers to create an intelligence layer in their applications, powered by OpenAI's state of the art models.",
website="https://platform.openai.com",
supported_model_types=["llm", "text_embedding"],
)
]
model_provider_mappings = {provider.name: provider for provider in model_providers}
__all__ = [
"ModelProviders",
"ModelProviderInfo",
"ProviderConfig",
"model_providers",
"model_provider_mappings",
]
================================================
FILE: core/autoflow/configs/models/providers/base.py
================================================
from typing import Optional, Literal, List
from pydantic import BaseModel, Field
from litellm import LlmProviders
ModelProviders = LlmProviders
ModelType = Literal["llm", "text_embedding", "rerank"]
class ModelProviderInfo(BaseModel):
name: ModelProviders = Field(
description="The name of the model provider.",
)
logo: Optional[str] = Field(
description="The logo of the model provider", default=None
)
display_name: str = Field(
description="The name of the model provider",
)
description: str = Field(
description="The description of the model provider", default=None
)
website: Optional[str] = Field(
description="The website of the model provider", default=None
)
supported_model_types: List[ModelType] = Field(
description="The model types supported by the model provider"
)
class ProviderConfig(BaseModel):
api_key: Optional[str] = Field(
title="API key",
default=None,
)
api_base: Optional[str] = Field(
title="API Base",
default=None,
)
================================================
FILE: core/autoflow/configs/models/providers/jinaai.py
================================================
from autoflow.configs.models.providers.base import ProviderConfig
class JinaAIConfig(ProviderConfig):
pass
================================================
FILE: core/autoflow/configs/models/providers/openai.py
================================================
from autoflow.configs.models.providers.base import ProviderConfig
class OpenAIConfig(ProviderConfig):
pass
================================================
FILE: core/autoflow/configs/models/rerankers/__init__.py
================================================
from .base import RerankerConfig
__all__ = ["RerankerConfig"]
================================================
FILE: core/autoflow/configs/models/rerankers/base.py
================================================
from typing import Optional, Dict
from pydantic import BaseModel, Field, model_validator
from autoflow.configs.models.providers import ModelProviders
class RerankerConfig(BaseModel):
provider: ModelProviders = Field(
description="Provider of the reranker models (e.g., 'openai')",
default=ModelProviders.OPENAI,
)
config: Optional[Dict] = Field(
description="Configuration for the specific reranker model",
default=None,
)
_provider_configs: Dict[str, str] = {
"jina_ai": "JinaAIRerankerConfig",
"custom": "CustomRerankerConfig",
}
@model_validator(mode="after")
def validate_and_create_config(self) -> "RerankerConfig":
provider = self.provider.value
config = self.config
if provider not in self._provider_configs:
raise ValueError(f"Unsupported reranker provider: {provider}")
module = __import__(
f"autoflow.configs.models.rerankers.{provider}",
fromlist=[self._provider_configs[provider]],
)
config_class = getattr(module, self._provider_configs[provider])
if config is None:
config = {}
if not isinstance(config, dict):
if not isinstance(config, config_class):
raise ValueError(
f"Invalid config type for reranker provider: {provider}"
)
return self
self.config = config_class(**config)
return self
================================================
FILE: core/autoflow/configs/models/rerankers/common.py
================================================
from pydantic import BaseModel, Field
class BaseRerankerConfig(BaseModel):
model: str = Field(
description="The model to use for the reranker",
default="jina-reranker-v2-base-multilingual",
)
top_n: int = Field(
description="The number of results to return",
default=5,
)
================================================
FILE: core/autoflow/configs/models/rerankers/jina_ai.py
================================================
from autoflow.configs.models.rerankers.common import BaseRerankerConfig
class JinaAIRerankerConfig(BaseRerankerConfig):
pass
================================================
FILE: core/autoflow/data_types.py
================================================
from enum import Enum
import os
from typing import IO, Optional, Union, BinaryIO, TextIO
from urllib.parse import urlparse
class DataType(str, Enum):
MARKDOWN = "markdown"
PDF = "pdf"
DOCX = "docx"
PPTX = "pptx"
XLSX = "xlsx"
CSV = "csv"
SITEMAP = "sitemap"
HTML = "html"
def guess_datatype(source: Union[str, IO, BinaryIO, TextIO]) -> Optional[DataType]:
if isinstance(source, str):
url = urlparse(source)
if url.scheme == "" or url.scheme == "file":
return guess_by_filename(url.path)
elif url.scheme == "http" or url.scheme == "https":
return DataType.HTML
else:
if os.path.exists(source):
return guess_by_filename(source)
raise ValueError(f"Unsupported URL scheme: {url.scheme}")
elif isinstance(source, IO):
return guess_by_filename(source.name)
else:
return None
def guess_by_filename(filename: str) -> Optional[DataType]:
"""Helper function to guess data type from filename."""
lower = filename.lower()
if lower.endswith(".md"):
return DataType.MARKDOWN
elif lower.endswith(".pdf"):
return DataType.PDF
elif lower.endswith(".docx"):
return DataType.DOCX
elif lower.endswith(".pptx"):
return DataType.PPTX
elif lower.endswith(".xlsx"):
return DataType.XLSX
elif lower.endswith(".csv"):
return DataType.CSV
elif lower.endswith(".xml") and "sitemap" in lower:
return DataType.SITEMAP
elif lower.endswith((".html", ".htm")):
return DataType.HTML
else:
return None
================================================
FILE: core/autoflow/db.py
================================================
import sqlalchemy
from pytidb.utils import build_tidb_dsn
from sqlalchemy import Engine
from autoflow.configs.db import DatabaseConfig
def get_db_engine_from_config(db_config: DatabaseConfig) -> Engine:
if db_config.database_url is not None:
database_url = db_config.database_url
else:
database_url = str(
build_tidb_dsn(
host=db_config.host,
port=db_config.port,
username=db_config.username,
password=db_config.password,
database=db_config.database,
enable_ssl=db_config.enable_ssl,
)
)
# Notice:
# In order to save resource consumption, the tidb serverless cluster will "pause" automatically if there
# are no active connections for more than 5 minutes, it will close all connections on the server side,
# so we also need to recycle the connections from the connection pool on the client side.
db_engine = sqlalchemy.create_engine(
database_url,
pool_size=20,
max_overflow=40,
pool_recycle=300,
pool_pre_ping=True,
)
return db_engine
================================================
FILE: core/autoflow/knowledge_base/__init__.py
================================================
from .base import KnowledgeBase
__all__ = [
"KnowledgeBase",
]
================================================
FILE: core/autoflow/knowledge_base/base.py
================================================
import logging
import uuid
from typing import List, Optional, Any
from functools import partial
from os import cpu_count
from concurrent.futures import ThreadPoolExecutor
from pydantic import Field, PrivateAttr
from sqlalchemy import Engine
from llama_index.core.base.llms.types import ChatResponse
from autoflow.chunkers.base import Chunker
from autoflow.chunkers.helper import get_chunker_for_datatype
from autoflow.configs.knowledge_base import IndexMethod
from autoflow.data_types import DataType, guess_datatype
from autoflow.knowledge_base.prompts import QA_WITH_KNOWLEDGE_PROMPT_TEMPLATE
from autoflow.knowledge_graph.index import KnowledgeGraphIndex
from autoflow.loaders.base import Loader
from autoflow.loaders.helper import get_loader_for_datatype
from autoflow.models.llms import LLM
from autoflow.models.embedding_models import EmbeddingModel
from autoflow.models.llms.dspy import get_dspy_lm_by_llm
from autoflow.models.rerank_models import RerankModel
from autoflow.types import BaseComponent, SearchMode
from autoflow.storage.doc_store import DocumentSearchResult, Document
logger = logging.getLogger(__name__)
class KnowledgeBase(BaseComponent):
_llm: LLM = PrivateAttr()
namespace: Optional[str] = Field(default=None)
name: Optional[str] = Field(default=None)
description: Optional[str] = Field(default=None)
index_methods: List[IndexMethod] = Field(default=[IndexMethod.VECTOR_SEARCH])
def __init__(
self,
db_engine: Engine = None,
namespace: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
index_methods: Optional[List[IndexMethod]] = None,
llm: Optional[LLM] = None,
embedding_model: Optional[EmbeddingModel] = None,
rerank_model: Optional[RerankModel] = None,
max_workers: Optional[int] = None,
):
super().__init__(
namespace=namespace,
name=name,
description=description,
index_methods=index_methods,
)
self._db_engine = db_engine
self._llm = llm
self._embedding_model = embedding_model
self._reranker_model = rerank_model
self._init_stores()
self._init_indexes()
self._max_workers = max_workers or cpu_count()
def _init_stores(self):
from autoflow.storage.doc_store.tidb_doc_store import TiDBDocumentStore
from autoflow.storage.graph_store.tidb_graph_store import TiDBGraphStore
from pytidb import TiDBClient
self._tidb_client = TiDBClient(self._db_engine)
self._doc_store = TiDBDocumentStore(
client=self._tidb_client,
embedding_model=self._embedding_model,
namespace=self.namespace,
)
self._kg_store = TiDBGraphStore(
client=self._tidb_client,
embedding_model=self._embedding_model,
namespace=self.namespace,
)
def _init_indexes(self):
self._dspy_lm = get_dspy_lm_by_llm(self._llm)
self._kg_index = KnowledgeGraphIndex(
kg_store=self._kg_store,
dspy_lm=self._dspy_lm,
embedding_model=self._embedding_model,
)
def class_name(self):
return "KnowledgeBase"
def documents(self):
return self._doc_store
def knowledge_graph(self):
return self._kg_store
def add(
self,
source: str | list[str],
data_type: Optional[DataType] = None,
loader: Optional[Loader] = None,
chunker: Optional[Chunker] = None,
) -> List[Document]:
if data_type is None:
data_type = guess_datatype(source)
if data_type is None:
raise ValueError("Please provide a valid data type.")
if loader is None:
loader = get_loader_for_datatype(data_type)
with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
build_index_for_document = partial(
self.build_index_for_document, chunker=chunker
)
results = executor.map(build_index_for_document, loader.load(source))
return_documents = []
for result in results:
return_documents.append(result)
return return_documents
def build_index_for_document(
self,
document: Document,
chunker: Optional[Chunker] = None,
) -> List[Document]:
"""
Build index for a document.
Args:
document: The document to build index for.
chunker: The chunker to use to chunk the document.
Returns:
A list of documents that are the result of indexing the original document.
"""
# TODO: handle duplicate documents.
if chunker is None:
chunker = get_chunker_for_datatype(document.data_type)
chunked_document = chunker.chunk(document)
self.add_document(chunked_document)
if IndexMethod.KNOWLEDGE_GRAPH in self.index_methods:
def add_chunk_to_kg(chunk):
logger.info("Adding chunk to knowledge graph.", chunk.id)
self._kg_index.add_chunk(chunk)
with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
list(executor.map(add_chunk_to_kg, chunked_document.chunks))
return chunked_document
# Document management.
def add_document(self, document: Document):
self._doc_store.add([document])
def add_documents(self, documents: List[Document]):
return self._doc_store.add(documents)
def list_documents(self) -> List[Document]:
return self._doc_store.list()
def get_document(self, doc_id: uuid.UUID) -> Document:
return self._doc_store.get(doc_id)
def delete_document(self, doc_id: uuid.UUID) -> None:
return self._doc_store.delete(doc_id)
# Search
def search(self):
# TODO: Support one interface search documents and knowledge graph at the same time.
raise NotImplementedError()
def search_documents(
self,
query: str,
mode: SearchMode = "vector",
similarity_threshold: Optional[float] = None,
num_candidate: Optional[int] = None,
top_k: Optional[int] = 5,
**kwargs: Any,
) -> DocumentSearchResult:
return self._doc_store.search(
query=query,
mode=mode,
similarity_threshold=similarity_threshold,
num_candidate=num_candidate,
top_k=top_k,
**kwargs,
)
def search_knowledge_graph(
self,
query: str,
depth: int = 2,
metadata_filters: Optional[dict] = None,
**kwargs,
):
return self._kg_index.retrieve(
query=query,
depth=depth,
metadata_filters=metadata_filters,
**kwargs,
)
# Generation.
def ask(self, question: str) -> ChatResponse:
result = self.search_documents(
query=question,
similarity_threshold=0.4,
top_k=5,
)
chunks = result.chunks
knowledge_graph = self.search_knowledge_graph(
query=question,
)
messages = QA_WITH_KNOWLEDGE_PROMPT_TEMPLATE.format_messages(
llm=self._llm,
query_str=question,
chunks=chunks,
knowledge_graph=knowledge_graph,
)
return self._llm.chat(messages)
# Knowledge Base Operation.
def reset(self):
self._doc_store.reset()
self._kg_store.reset()
================================================
FILE: core/autoflow/knowledge_base/prompts.py
================================================
from llama_index.core.prompts.rich import RichPromptTemplate
QA_WITH_KNOWLEDGE_PROMPT_TEMPLATE = RichPromptTemplate(
template_str="""
{% chat role="system" %}
We have provided context information below.
---------------------
{% if knowledge_graph %}
{% for entity in knowledge_graph.entities %}
{{ entity.name }}: {{ entity.description }}
{% endfor %}
{% for relationship in knowledge_graph.relationships %}
{{relationship.source_entity.name}} -> {{ relationship.description }} -> {{relationship.target_entity.name}}
{% endfor %}
{% endif %}
{% for chunk in chunks %}
{{ chunk.text }}
{% endfor %}
---------------------
Given this information, please give a comprehensive answer to the question in Markdown format:
{% endchat %}
{% chat role="user" %}
{{ query_str }}
{% endchat %}
"""
)
================================================
FILE: core/autoflow/knowledge_graph/__init__.py
================================================
================================================
FILE: core/autoflow/knowledge_graph/extractors/__init__.py
================================================
================================================
FILE: core/autoflow/knowledge_graph/extractors/base.py
================================================
from abc import abstractmethod
from autoflow.types import BaseComponent
from autoflow.knowledge_graph.types import GeneratedKnowledgeGraph
class KGExtractor(BaseComponent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@abstractmethod
def extract(self, text: str) -> GeneratedKnowledgeGraph:
raise NotImplementedError()
================================================
FILE: core/autoflow/knowledge_graph/extractors/simple.py
================================================
import dspy
from autoflow.knowledge_graph.extractors.base import KGExtractor
from autoflow.knowledge_graph.programs.extract_covariates import (
EntityCovariateExtractor,
)
from autoflow.knowledge_graph.programs.extract_graph import KnowledgeGraphExtractor
from autoflow.knowledge_graph.types import GeneratedKnowledgeGraph
class SimpleKGExtractor(KGExtractor):
def __init__(self, dspy_lm: dspy.LM):
super().__init__()
self._dspy_lm = dspy_lm
self._graph_extractor = KnowledgeGraphExtractor(dspy_lm)
self._entity_metadata_extractor = EntityCovariateExtractor(dspy_lm)
def extract(self, text: str) -> GeneratedKnowledgeGraph:
knowledge_graph = self._graph_extractor.forward(text)
knowledge_graph.entities = self._entity_metadata_extractor.forward(
text, knowledge_graph.entities
)
return knowledge_graph
================================================
FILE: core/autoflow/knowledge_graph/index.py
================================================
import logging
from typing import Optional
import dspy
from autoflow.knowledge_graph.extractors.simple import SimpleKGExtractor
from autoflow.knowledge_graph.retrievers.weighted import WeightedGraphRetriever
from autoflow.knowledge_graph.types import (
RetrievedKnowledgeGraph,
)
from autoflow.models.embedding_models import EmbeddingModel
from autoflow.storage.doc_store.types import Chunk
from autoflow.storage.graph_store.base import GraphStore
from autoflow.storage.graph_store.types import KnowledgeGraph
from autoflow.types import BaseComponent
logger = logging.getLogger(__name__)
class KnowledgeGraphIndex(BaseComponent):
def __init__(
self,
kg_store: GraphStore,
dspy_lm: dspy.LM,
embedding_model: EmbeddingModel,
):
super().__init__()
self._kg_store = kg_store
self._dspy_lm = dspy_lm
self._embedding_model = embedding_model
self._kg_extractor = SimpleKGExtractor(self._dspy_lm)
def add_text(self, text: str) -> Optional[KnowledgeGraph]:
knowledge_graph = self._kg_extractor.extract(text)
return self._kg_store.add(knowledge_graph.to_create())
def add_chunk(self, chunk: Chunk) -> Optional[KnowledgeGraph]:
# Check if the chunk has been added.
exists_relationships = self._kg_store.list_relationships(chunk_id=chunk.id)
if len(exists_relationships) > 0:
logger.warning(
"The subgraph of chunk %s has already been added, skip.", chunk.id
)
return None
logger.info("Extracting knowledge graph from chunk %s", chunk.id)
knowledge_graph = self._kg_extractor.extract(chunk)
logger.info("Knowledge graph extracted from chunk %s", chunk.id)
return self._kg_store.add(knowledge_graph.to_create())
def retrieve(
self,
query: str,
depth: int = 2,
metadata_filters: Optional[dict] = None,
**kwargs,
) -> RetrievedKnowledgeGraph:
retriever = WeightedGraphRetriever(
self._kg_store,
self._embedding_model,
**kwargs,
)
return retriever.retrieve(
query=query,
depth=depth,
metadata_filters=metadata_filters,
)
================================================
FILE: core/autoflow/knowledge_graph/programs/__init__.py
================================================
from .extract_graph import ExtractKnowledgeGraph
from .extract_covariates import ExtractEntityCovariate
__all__ = ["ExtractKnowledgeGraph", "ExtractEntityCovariate"]
================================================
FILE: core/autoflow/knowledge_graph/programs/eval_graph.py
================================================
import logging
import dspy
from dspy import Predict
from pydantic import BaseModel
from autoflow.knowledge_graph.types import GeneratedKnowledgeGraph
# Initialize logger
logger = logging.getLogger(__name__)
class EvaluateKnowledgeGraph(dspy.Signature):
"""
Evaluate the differences between two knowledge graphs and provide scores for each entity and relationship,
as well as a final score for the entire knowledge graph.
Steps:
1. Iterate over each entity in the expected knowledge graph
2. For each expected entity, find the most similar entity in the actual knowledge graph
3. Calculate the score (range from 0 to 1) for the entity based on the similarity
4. Iterate over each relationship in the expected knowledge graph
5. For each expected relationship, find the most similar relationship in the actual knowledge graph
6. Calculate the score (range from 0 to 1) for the relationship based on the similarity
7. Average all the scores of entities and relationships to get the final score
Please only respond in JSON format.
"""
actual_graph: GeneratedKnowledgeGraph = dspy.InputField(
desc="The actual knowledge graph extracted from the text"
)
expected_graph: GeneratedKnowledgeGraph = dspy.InputField(
desc="The expected knowledge graph"
)
score: float = dspy.OutputField(
desc="The final score of the actual knowledge graph"
)
class KGEvaluationResult(BaseModel):
expected: GeneratedKnowledgeGraph
actual: GeneratedKnowledgeGraph
score: float
class KnowledgeGraphEvaluator(dspy.Module):
def __init__(self, dspy_lm: dspy.LM):
super().__init__()
self.dspy_lm = dspy_lm
self.program = Predict(EvaluateKnowledgeGraph)
def forward(
self,
actual: GeneratedKnowledgeGraph,
expected: GeneratedKnowledgeGraph,
) -> KGEvaluationResult:
# Evaluate the knowledge graph using the provided language model
with dspy.settings.context(lm=self.dspy_lm):
prediction = self.program(actual_graph=actual, expected_graph=expected)
return KGEvaluationResult(
actual=actual,
expected=expected,
score=prediction.score,
)
================================================
FILE: core/autoflow/knowledge_graph/programs/extract_covariates.py
================================================
import logging
from typing import List, Mapping, Any
import dspy
from dspy import Predict
from pydantic import BaseModel, Field
from autoflow.knowledge_graph.types import GeneratedEntity
logger = logging.getLogger(__name__)
class InputEntity(BaseModel):
"""List of entities extracted from the text to form the knowledge graph"""
name: str = Field(description="Name of the entity")
description: str = Field(description="Description of the entity")
class OutputEntity(BaseModel):
"""List of entities extracted from the text to form the knowledge graph"""
name: str = Field(description="Name of the entity")
description: str = Field(description="Description of the entity")
covariates: Mapping[str, Any] = Field(
description=(
"The attributes (which is a comprehensive json TREE, the first field is always: 'topic') to claim the entity. "
)
)
class ExtractEntityCovariate(dspy.Signature):
"""Please carefully review the provided text and entities list which are already identified in the text.
Focusing on identifying detailed covariates associated with each entities provided.
Extract and link the covariates (which is a comprehensive json TREE, the first field is always: "topic") to their respective entities.
Ensure all extracted covariates is clearly connected to the correct entity for accuracy and comprehensive understanding.
Ensure that all extracted covariates are factual and verifiable within the text itself, without relying on external knowledge or assumptions.
Collectively, the covariates should provide a thorough and precise summary of the entity's characteristics as described in the source material.
Please only response in JSON format.
"""
text = dspy.InputField(
desc="a paragraph of text to extract covariates to claim the entities."
)
input: List[InputEntity] = dspy.InputField(
desc="List of entities identified in the text."
)
output: List[OutputEntity] = dspy.OutputField(
desc="List of entities with their covariates."
)
class EntityCovariateExtractor(dspy.Module):
def __init__(self, dspy_lm: dspy.LM):
super().__init__()
self.dspy_lm = dspy_lm
self.program = Predict(ExtractEntityCovariate)
def forward(
self, text: str, entities: List[GeneratedEntity]
) -> List[GeneratedEntity]:
with dspy.settings.context(lm=self.dspy_lm):
input_entities = [
InputEntity(
name=entity.name,
description=entity.description,
)
for entity in entities
]
predict = self.program(
text=text,
input=input_entities,
)
output_entity_map = {entity.name: entity for entity in predict.output}
for entity in entities:
if entity.name in output_entity_map:
# Update the covariates in the metadata of the entity.
entity.meta = output_entity_map[entity.name].covariates
return entities
================================================
FILE: core/autoflow/knowledge_graph/programs/extract_graph.py
================================================
import logging
from typing import List
import dspy
from dspy import Predict
from pydantic import BaseModel, Field
from autoflow.knowledge_graph.types import (
GeneratedEntity,
GeneratedKnowledgeGraph,
GeneratedRelationship,
)
logger = logging.getLogger(__name__)
class PredictEntity(BaseModel):
"""Entity extracted from the text to form the knowledge graph"""
name: str = Field(
description="Name of the entity, it should be a clear and concise term"
)
description: str = Field(
description=(
"Description of the entity, it should be a complete and comprehensive sentence, not few words. "
"Sample description of entity 'TiDB in-place upgrade': "
"'Upgrade TiDB component binary files to achieve upgrade, generally use rolling upgrade method'"
)
)
class PredictRelationship(BaseModel):
"""Relationship extracted from the text to form the knowledge graph"""
source_entity: str = Field(
description="Source entity name of the relationship, it should an existing entity in the Entity list"
)
target_entity: str = Field(
description="Target entity name of the relationship, it should an existing entity in the Entity list"
)
relationship_desc: str = Field(
description=(
"Description of the relationship, it should be a complete and comprehensive sentence, not few words. "
"For example: 'TiDB will release a new LTS version every 6 months.'"
)
)
class PredictKnowledgeGraph(BaseModel):
"""Graph representation of the knowledge for text."""
entities: List[PredictEntity] = Field(
description="List of entities in the knowledge graph"
)
relationships: List[PredictRelationship] = Field(
description="List of relationships in the knowledge graph"
)
def to_pandas(self):
from pandas import DataFrame
return {
"entities": DataFrame(
[
{
"name": entity.name,
"description": entity.description,
}
for entity in self.entities
]
),
"relationships": DataFrame(
[
{
"source_entity": relationship.source_entity,
"relationship_desc": relationship.relationship_desc,
"target_entity": relationship.target_entity,
}
for relationship in self.relationships
]
),
}
class ExtractKnowledgeGraph(dspy.Signature):
"""Carefully analyze the provided text from database documentation and community blogs to thoroughly identify all entities related to database technologies, including both general concepts and specific details.
Follow these Step-by-Step Analysis:
1. Extract Meaningful Entities:
- Identify all significant nouns, proper nouns, and technical terminologies that represent database-related concepts, objects, components, features, issues, key steps, execute order, user case, locations, versions, or any substantial entities.
- Ensure that you capture entities across different levels of detail, from high-level overviews to specific technical specifications, to create a comprehensive representation of the subject matter.
- Choose names for entities that are specific enough to indicate their meaning without additional context, avoiding overly generic terms.
- Consolidate similar entities to avoid redundancy, ensuring each represents a distinct concept at appropriate granularity levels.
2. Establish Relationships:
- Carefully examine the text to identify all relationships between clearly-related entities, ensuring each relationship is correctly captured with accurate details about the interactions.
- Analyze the context and interactions between the identified entities to determine how they are interconnected, focusing on actions, associations, dependencies, or similarities.
- Clearly define the relationships, ensuring accurate directionality that reflects the logical or functional dependencies among entities. \
This means identifying which entity is the source, which is the target, and what the nature of their relationship is (e.g., $source_entity depends on $target_entity for $relationship).
Some key points to consider:
- Please endeavor to extract all meaningful entities and relationships from the text, avoid subsequent additional gleanings.
Objective: Produce a detailed and comprehensive knowledge graph that captures the full spectrum of entities mentioned in the text, along with their interrelations, reflecting both broad concepts and intricate details specific to the database domain.
Please only response in JSON format.
"""
text = dspy.InputField(
desc="a paragraph of text to extract entities and relationships to form a knowledge graph"
)
knowledge: PredictKnowledgeGraph = dspy.OutputField(
desc="Graph representation of the knowledge extracted from the text."
)
class KnowledgeGraphExtractor(dspy.Module):
def __init__(self, dspy_lm: dspy.LM):
super().__init__()
self.dspy_lm = dspy_lm
self.program = Predict(ExtractKnowledgeGraph)
def forward(self, text: str) -> GeneratedKnowledgeGraph:
with dspy.settings.context(lm=self.dspy_lm):
prediction = self.program(text=text)
entities = [
GeneratedEntity(
name=entity.name,
description=entity.description,
meta={},
)
for entity in prediction.knowledge.entities
]
relationships = [
GeneratedRelationship(
source_entity_name=relationship.source_entity,
target_entity_name=relationship.target_entity,
description=relationship.relationship_desc,
meta={},
)
for relationship in prediction.knowledge.relationships
]
return GeneratedKnowledgeGraph(
entities=entities,
relationships=relationships,
)
================================================
FILE: core/autoflow/knowledge_graph/retrievers/__init__.py
================================================
================================================
FILE: core/autoflow/knowledge_graph/retrievers/base.py
================================================
from abc import abstractmethod, ABC
from autoflow.storage.graph_store.base import GraphStore
from autoflow.storage.types import QueryBundle
from autoflow.knowledge_graph.types import RetrievedKnowledgeGraph
class KGRetriever(ABC):
def __init__(self, knowledge_graph_store: GraphStore):
self._kg_store = knowledge_graph_store
@abstractmethod
def retrieve(
self,
query: QueryBundle,
depth: int = 2,
meta_filters: dict = None,
) -> RetrievedKnowledgeGraph:
raise NotImplementedError
================================================
FILE: core/autoflow/knowledge_graph/retrievers/simple.py
================================================
================================================
FILE: core/autoflow/knowledge_graph/retrievers/weighted.py
================================================
from collections import defaultdict
from typing import List, Set, Tuple, Optional
from autoflow.knowledge_graph.types import (
RetrievedKnowledgeGraph,
RetrievedRelationship,
)
from autoflow.models.embedding_models import EmbeddingModel
from autoflow.storage.graph_store import GraphStore
from autoflow.knowledge_graph.retrievers.base import KGRetriever
from autoflow.storage.graph_store.types import (
Entity,
EntityDegree,
EntityFilters,
Relationship,
EntityType,
RelationshipFilters,
)
from autoflow.storage.types import QueryBundle
# The configuration for the weight coefficient
# format: ((min_weight, max_weight), coefficient)
DEFAULT_WEIGHT_COEFFICIENTS = [
((0, 100), 0.01),
((100, 1000), 0.001),
((1000, 10000), 0.0001),
((10000, float("inf")), 0.00001),
]
# The configuration for the range search
# format: ((min_distance, max_distance), search_ratio)
# The sum of search ratio should be 1 except some case we want to search as many as possible relationships.
# In this case, we set the search ratio to 1, and the other search ratio sum should be 1
DEFAULT_RANGE_SEARCH_CONFIG = [
((0.0, 0.25), 1),
((0.25, 0.35), 0.7),
((0.35, 0.45), 0.2),
((0.45, 0.55), 0.1),
]
DEFAULT_DEGREE_COEFFICIENT = 0.001
class WeightedGraphRetriever(KGRetriever):
def __init__(
self,
kg_store: GraphStore,
embedding_model: EmbeddingModel,
with_degree: bool = False,
alpha: float = 1,
weight_coefficients: List[Tuple[float, float]] = None,
search_range_config: List[Tuple[Tuple[float, float], float]] = None,
degree_coefficient: float = DEFAULT_DEGREE_COEFFICIENT,
fetch_synopsis_entities_num: int = 2,
max_neighbors: int = 10,
):
super().__init__(kg_store)
self._embedding_model = embedding_model
self.with_degree = with_degree
self.alpha = alpha
self.weight_coefficients = weight_coefficients or DEFAULT_WEIGHT_COEFFICIENTS
self.search_range_config = search_range_config or DEFAULT_RANGE_SEARCH_CONFIG
self.degree_coefficient = degree_coefficient
self.fetch_synopsis_entities_num = fetch_synopsis_entities_num
self.max_neighbors = max_neighbors
def retrieve(
self,
query: str,
depth: int = 2,
metadata_filters: Optional[dict] = None,
) -> RetrievedKnowledgeGraph:
query_embedding = self._embedding_model.get_query_embedding(query)
visited_relationships = set()
visited_entities = set()
new_relationships = self._weighted_search_relationships(
query_embedding=query_embedding,
visited_relationships=visited_relationships,
visited_entities=visited_entities,
metadata_filters=metadata_filters,
)
if len(new_relationships) == 0:
return RetrievedKnowledgeGraph(
entities=[],
relationships=[],
)
for rel, score in new_relationships:
visited_relationships.add(
RetrievedRelationship(
**rel.model_dump(),
similarity_score=score,
score=score,
)
)
visited_entities.add(rel.source_entity)
visited_entities.add(rel.target_entity)
for _ in range(depth - 1):
actual_number = 0
progress = 0
for search_config in DEFAULT_RANGE_SEARCH_CONFIG:
search_ratio = search_config[1]
search_distance_range = search_config[0]
remaining_number = self.max_neighbors - actual_number
# calculate the expected number based search progress
# It's an accumulative search, so the expected number should be the difference between the expected number and the actual number
expected_number = (
int((search_ratio + progress) * self.max_neighbors - actual_number)
if progress * self.max_neighbors > actual_number
else int(search_ratio * self.max_neighbors)
)
if expected_number > remaining_number:
expected_number = remaining_number
if remaining_number <= 0:
break
new_relationships = self._weighted_search_relationships(
query_embedding=query_embedding,
visited_relationships=visited_relationships,
visited_entities=visited_entities,
search_distance_range=search_distance_range,
top_k=expected_number,
metadata_filters=metadata_filters,
)
for rel, score in new_relationships:
visited_relationships.add(
RetrievedRelationship(
**rel.model_dump(),
similarity_score=score,
score=score,
)
)
visited_entities.add(rel.source_entity)
visited_entities.add(rel.target_entity)
actual_number += len(new_relationships)
# search_ratio == 1 won't count the progress
if search_ratio != 1:
progress += search_ratio
# Fetch related synopsis entities.
synopsis_entities = self._kg_store.search_entities(
query=QueryBundle(query_embedding=query_embedding),
top_k=self.fetch_synopsis_entities_num,
filters=EntityFilters(
entity_type=EntityType.synopsis,
),
)
if len(synopsis_entities) > 0:
visited_entities.update(synopsis_entities)
# Rerank final relationships.
return_relationships = list(visited_relationships)
return_relationships.sort(key=lambda x: x.score, reverse=True)
self._fill_entity(return_relationships)
return_entities = [Entity(**e.model_dump()) for e in visited_entities]
return RetrievedKnowledgeGraph(
entities=return_entities,
relationships=return_relationships,
)
def _fill_entity(self, relationships: List[RetrievedRelationship]):
# FIXME: pytidb should return the relationship field: target_entity, source_entity.
entity_ids = [item.target_entity_id for item in relationships]
entity_ids.extend([item.source_entity_id for item in relationships])
entities = self._kg_store.list_entities(
filters=EntityFilters(entity_id=entity_ids)
)
entity_map = {entity.id: entity for entity in entities}
for rel in relationships:
rel.target_entity = Entity(**entity_map[rel.target_entity_id].model_dump())
rel.source_entity = Entity(**entity_map[rel.source_entity_id].model_dump())
def _weighted_search_relationships(
self,
query_embedding: List[float],
visited_relationships: Set[RetrievedRelationship],
visited_entities: Set[Entity],
search_distance_range: Tuple[float, float] = (0, 1),
top_k: int = 10,
metadata_filters: Optional[dict] = None,
) -> List[RetrievedRelationship]:
visited_entity_ids = [e.id for e in visited_entities]
visited_relationship_ids = [r.id for r in visited_relationships]
relationships_with_score = self._kg_store.search_relationships(
query=QueryBundle(query_embedding=query_embedding),
filters=RelationshipFilters(
source_entity_id=visited_entity_ids,
exclude_relationship_ids=visited_relationship_ids,
metadata=metadata_filters,
),
distance_range=search_distance_range,
top_k=top_k,
)
return self._rank_relationships(
relationships_with_score=relationships_with_score,
top_k=top_k,
)
def _rank_relationships(
self,
relationships_with_score: List[Tuple[Relationship, float]],
top_k: int = 10,
) -> List[Tuple[Relationship, float]]:
"""
Rerank the relationship based on distance and weight
"""
# TODO: the degree can br pre-calc and stored in the database in advanced.
if self.with_degree:
entity_ids = set()
for r, _ in relationships_with_score:
entity_ids.add(r.source_entity_id)
entity_ids.add(r.target_entity_id)
entity_degrees = self._kg_store.bulk_calc_entities_degrees(entity_ids)
else:
entity_degrees = defaultdict(EntityDegree)
reranked_relationships = []
for r, similarity_score in relationships_with_score:
embedding_distance = 1 - similarity_score
source_in_degree = entity_degrees[r.source_entity_id].in_degree
target_out_degree = entity_degrees[r.target_entity_id].out_degree
final_score = self._calc_relationship_weighted_score(
embedding_distance,
r.weight,
source_in_degree,
target_out_degree,
)
reranked_relationships.append((r, final_score))
# Rerank relationships based on the calculated score.
reranked_relationships.sort(key=lambda x: x[1], reverse=True)
return reranked_relationships[:top_k]
def _calc_relationship_weighted_score(
self,
embedding_distance: float,
weight: int = 0,
in_degree: int = 0,
out_degree: int = 0,
) -> float:
weighted_score = self._calc_weight_score(weight)
degree_score = 0
if self.with_degree:
degree_score = self._calc_degree_score(in_degree, out_degree)
return self.alpha * (1 / embedding_distance) + weighted_score + degree_score
def _calc_weight_score(self, weight: float) -> float:
weight_score = 0.0
remaining_weight = weight
for weight_range, coefficient in self.weight_coefficients:
if remaining_weight <= 0:
break
lower_bound, upper_bound = weight_range
applicable_weight = min(upper_bound - lower_bound, remaining_weight)
weight_score += applicable_weight * coefficient
remaining_weight -= applicable_weight
return weight_score
def _calc_degree_score(self, in_degree: int, out_degree: int) -> float:
return (in_degree - out_degree) * self.degree_coefficient
================================================
FILE: core/autoflow/knowledge_graph/types.py
================================================
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
from autoflow.storage.graph_store.types import (
Entity,
EntityCreate,
KnowledgeGraphCreate,
Relationship,
RelationshipCreate,
)
# Generated Knowledge Graph
class GeneratedEntity(BaseModel):
name: str
description: str
meta: Dict[str, Any] = Field(default_factory=dict)
class GeneratedRelationship(BaseModel):
source_entity_name: str
target_entity_name: str
description: str
meta: Dict[str, Any] = Field(default_factory=dict)
class GeneratedKnowledgeGraph(BaseModel):
entities: List[GeneratedEntity]
relationships: List[GeneratedRelationship]
def to_create(
self,
chunk_id: Optional[UUID] = None,
document_id: Optional[UUID] = None,
) -> KnowledgeGraphCreate:
return KnowledgeGraphCreate(
entities=[
EntityCreate(
name=e.name,
description=e.description,
meta=e.meta,
)
for e in self.entities
],
relationships=[
RelationshipCreate(
source_entity_name=r.source_entity_name,
target_entity_name=r.target_entity_name,
description=r.description,
meta=r.meta,
weight=0,
chunk_id=chunk_id,
document_id=document_id,
)
for r in self.relationships
],
)
# Retrieved Knowledge Graph
class RetrievedEntity(Entity):
similarity_score: Optional[float] = Field(default=None)
score: Optional[float] = Field(default=None)
def __hash__(self):
return hash(self.id)
def __eq__(self, other: "RetrievedEntity"):
return self.id == other.id
class RetrievedRelationship(Relationship):
similarity_score: Optional[float] = Field(default=None)
score: Optional[float] = Field(default=None)
def __hash__(self):
return hash(self.id)
def __eq__(self, other: "RetrievedRelationship"):
return self.id == other.id
class RetrievedKnowledgeGraph(BaseModel):
query: Optional[str] = Field(
description="The query used to retrieve the knowledge graph",
default=None,
)
entities: List[Entity] = Field(
description="List of entities in the knowledge graph", default_factory=list
)
relationships: List[RetrievedRelationship] = Field(
description="List of relationships in the knowledge graph", default_factory=list
)
================================================
FILE: core/autoflow/loaders/__init__.py
================================================
from .base import Loader
__all__ = [
"Loader",
]
================================================
FILE: core/autoflow/loaders/base.py
================================================
from abc import abstractmethod
from typing import Generator
from autoflow.types import BaseComponent
from autoflow.storage.doc_store import Document
class Loader(BaseComponent):
@abstractmethod
def load(
self, source: str | list[str], **kwargs
) -> Generator[Document, None, None]:
raise NotImplementedError
class FileLoader(Loader):
def load(self, files: str | list[str], **kwargs) -> Generator[Document, None, None]:
if isinstance(files, str):
files = [files]
for file in files:
yield self._load_file(file)
@abstractmethod
def _load_file(self, file: str) -> Document:
raise NotImplementedError
================================================
FILE: core/autoflow/loaders/helper.py
================================================
from autoflow.data_types import DataType
from autoflow.loaders import Loader
def get_loader_for_datatype(datatype: DataType) -> Loader:
if datatype == DataType.MARKDOWN:
from autoflow.loaders.markdown import MarkdownLoader
return MarkdownLoader()
elif datatype == DataType.PDF:
from autoflow.loaders.pdf import PDFLoader
return PDFLoader()
elif datatype == DataType.HTML:
from autoflow.loaders.webpage import WebpageLoader
return WebpageLoader()
else:
raise ValueError(f"Unsupported loader for data type: {datatype}")
================================================
FILE: core/autoflow/loaders/markdown.py
================================================
from autoflow.data_types import DataType
from autoflow.loaders.base import FileLoader
from autoflow.storage.doc_store import Document
class MarkdownLoader(FileLoader):
def _load_file(self, file: str) -> Document:
with open(file, "r", encoding="utf-8") as f:
content = f.read()
return Document(
name=file,
data_type=DataType.MARKDOWN,
content=content,
)
================================================
FILE: core/autoflow/loaders/pdf.py
================================================
from llama_index.readers import PDFReader
from autoflow.data_types import DataType
from autoflow.loaders.base import FileLoader
from autoflow.storage.doc_store import Document
class PDFLoader(FileLoader):
def _load_file(self, file: str) -> Document:
reader = PDFReader()
documents = reader.load_data(file)
content = documents[0].text
return Document(
name=file.name,
data_type=DataType.PDF,
content=content,
)
================================================
FILE: core/autoflow/loaders/webpage.py
================================================
import logging
from datetime import datetime, UTC
from typing import Generator, Optional, List
from playwright.sync_api import sync_playwright
from bs4 import BeautifulSoup
from markdownify import MarkdownConverter
from autoflow.loaders.base import Loader
from autoflow.storage.doc_store import Document
from autoflow.data_types import DataType
logger = logging.getLogger(__name__)
# Common tags and classes to ignore when processing web content
IGNORE_TAGS = ["nav", "header", "footer", "script", "style", "noscript", "iframe"]
IGNORE_CLASSES = [
"nav",
"navigation",
"footer",
"header",
"sidebar",
"menu",
"ad",
"advertisement",
]
class WebpageLoader(Loader):
def __init__(
self,
ignore_tags: Optional[List[str]] = None,
ignore_classes: Optional[List[str]] = None,
):
super().__init__()
self._ignore_tags = ignore_tags or IGNORE_TAGS
self._ignore_classes = ignore_classes or IGNORE_CLASSES
def load(self, urls: str | list[str], **kwargs) -> Generator[Document, None, None]:
if isinstance(urls, str):
urls = [urls]
visited = set()
with sync_playwright() as p:
browser = p.chromium.launch(headless=True)
try:
for url in urls:
try:
page = browser.new_page()
response = page.goto(url)
final_url = page.url
if final_url in visited:
continue
if response is None or response.status >= 400:
logger.error(
f"Failed to load page: {url}, response status: {response.status if response else 'None'}, skipping"
)
continue
# Parse the content
soup = BeautifulSoup(page.content(), "html.parser")
# Remove unwanted elements
for tag in self._ignore_tags:
for element in soup.find_all(tag):
element.extract()
for class_name in self._ignore_classes:
for element in soup.find_all(class_=class_name):
element.extract()
# Convert to markdown
content = MarkdownConverter().convert_soup(soup)
title = page.title() or final_url
visited.add(final_url)
# Create document
document = Document(
name=title,
content=content,
data_type=DataType.HTML,
meta={
"source_uri": final_url,
"original_uri": url,
"last_modified": datetime.now(UTC).isoformat(),
},
)
yield document
except Exception as e:
logger.error(f"Error processing URL {url}: {str(e)}")
continue
finally:
if "page" in locals():
page.close()
finally:
browser.close()
================================================
FILE: core/autoflow/main.py
================================================
from typing import List, Optional
from sqlalchemy.engine import Engine
from autoflow.configs.db import DatabaseConfig
from autoflow.configs.knowledge_base import IndexMethod
from autoflow.configs.main import Config
from autoflow.db import get_db_engine_from_config
from autoflow.knowledge_base.base import KnowledgeBase
from autoflow.models.embedding_models import EmbeddingModel
from autoflow.models.llms import LLM
from autoflow.models.manager import ModelManager, model_manager as default_model_manager
from autoflow.models.rerank_models import RerankModel
class Autoflow:
_db_engine = None
def __init__(
self,
db_engine: Engine,
model_manager: Optional[ModelManager] = None,
):
self._db_engine = db_engine
self._model_manager = model_manager or default_model_manager
@classmethod
def from_config(cls, config: Config) -> "Autoflow":
db_engine = cls._init_db_engine(config.db)
model_manager = ModelManager()
return cls(db_engine=db_engine, model_manager=model_manager)
@classmethod
def _init_db_engine(cls, db_config: DatabaseConfig) -> Engine:
if db_config.provider != "tidb":
raise NotImplementedError(
f"Unsupported database provider: {db_config.provider}."
)
return get_db_engine_from_config(db_config)
@property
def db_engine(self) -> Engine:
return self._db_engine
@property
def llm_manager(self) -> "ModelManager":
return self._model_manager
def create_knowledge_base(
self,
name: str,
namespace: Optional[str] = None,
description: Optional[str] = None,
index_methods: Optional[List[IndexMethod]] = None,
llm: Optional[LLM] = None,
embedding_model: Optional[EmbeddingModel] = None,
rerank_model: Optional[RerankModel] = None,
):
return KnowledgeBase(
db_engine=self.db_engine,
namespace=namespace,
name=name,
description=description,
index_methods=index_methods,
llm=llm,
embedding_model=embedding_model,
rerank_model=rerank_model,
)
================================================
FILE: core/autoflow/models/__init__.py
================================================
from .manager import (
ModelManager,
model_manager,
)
__all__ = [
"ModelManager",
"model_manager",
]
================================================
FILE: core/autoflow/models/embedding_models/__init__.py
================================================
from .litellm import LiteLLMEmbedding
EmbeddingModel = LiteLLMEmbedding
__all__ = ["EmbeddingModel"]
================================================
FILE: core/autoflow/models/embedding_models/litellm.py
================================================
from typing import Any, List, Optional
from llama_index.core.bridge.pydantic import Field
from llama_index.core.embeddings import BaseEmbedding
def get_embeddings(
api_key: str,
api_base: str,
model_name: str,
input: List[str],
timeout: int = 60,
**kwargs: Any,
) -> List[List[float]]:
"""
Retrieve embeddings for a given list of input strings using the specified model.
Args:
api_key (str): The API key for authentication.
api_base (str): The base URL of the LiteLLM proxy server.
model_name (str): The name of the model to use for generating embeddings.
input (List[str]): A list of input strings for which embeddings are to be generated.
timeout (float): The timeout value for the API call, default 60 secs.
**kwargs (Any): Additional keyword arguments to be passed to the embedding function.
Returns:
List[List[float]]: A list of embeddings, where each embedding corresponds to an input string.
"""
from litellm import embedding
response = embedding(
api_key=api_key,
api_base=api_base,
model=model_name,
input=input,
timeout=timeout,
**kwargs,
)
return [result["embedding"] for result in response.data]
class LiteLLMEmbedding(BaseEmbedding):
model_name: str = Field(description="The name of the embedding model.")
api_key: Optional[str] = Field(
default=None,
description="OpenAI key. If not provided, the proxy server must be configured with the key.",
)
api_base: Optional[str] = Field(
default=None, description="The base URL of the LiteLLM proxy."
)
dimensions: Optional[int] = Field(
default=None,
description=(
"The number of dimensions the resulting output embeddings should have. "
),
)
timeout: Optional[int] = Field(
default=60, description="Timeout for each request.", ge=0
)
def __init__(
self, model_name: str, *, dimensions: Optional[int] = None, **kwargs
) -> None:
super().__init__(model_name=model_name, dimensions=dimensions, **kwargs)
if dimensions is None:
self.dimensions = len(self._get_text_embedding("test"))
@classmethod
def class_name(cls) -> str:
return "lite-llm"
async def _aget_query_embedding(self, query: str) -> List[float]:
return self._get_query_embedding(query)
async def _aget_text_embedding(self, text: str) -> List[float]:
return self._get_text_embedding(text)
def _get_query_embedding(self, query: str) -> List[float]:
embeddings = get_embeddings(
api_key=self.api_key,
api_base=self.api_base,
model_name=self.model_name,
dimensions=self.dimensions,
timeout=self.timeout,
input=[query],
)
return embeddings[0]
def _get_text_embedding(self, text: str) -> List[float]:
embeddings = get_embeddings(
api_key=self.api_key,
api_base=self.api_base,
model_name=self.model_name,
dimensions=self.dimensions,
timeout=self.timeout,
input=[text],
)
return embeddings[0]
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
return get_embeddings(
api_key=self.api_key,
api_base=self.api_base,
model_name=self.model_name,
dimensions=self.dimensions,
timeout=self.timeout,
input=texts,
)
================================================
FILE: core/autoflow/models/llms/__init__.py
================================================
from .litellm import LiteLLM
LLM = LiteLLM
__all__ = [
"LLM",
]
================================================
FILE: core/autoflow/models/llms/dspy.py
================================================
import dspy
from autoflow.models.llms import LLM
def get_dspy_lm_by_llm(llm: LLM) -> dspy.LM:
return dspy.LM(
model=llm.model,
max_tokens=llm.max_tokens,
**llm.additional_kwargs,
)
================================================
FILE: core/autoflow/models/llms/litellm.py
================================================
import logging
from litellm import verbose_logger
from llama_index.llms.litellm import LiteLLM
verbose_logger.setLevel(logging.WARN)
LiteLLM = LiteLLM
================================================
FILE: core/autoflow/models/manager.py
================================================
from typing import Dict, Optional
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from autoflow.configs.models.providers import ModelProviders
from autoflow.configs.models.embeddings import EmbeddingModelConfig
from autoflow.configs.models.llms import LLMConfig
from autoflow.configs.models.providers.base import ProviderConfig
from autoflow.configs.models.rerankers import RerankerConfig
from autoflow.models.embedding_models import EmbeddingModel
from autoflow.models.llms import LLM
from autoflow.models.rerank_models import RerankModel
class ModelManager:
_registry: Dict[ModelProviders, ProviderConfig] = {}
@classmethod
def load_from_db(cls):
pass
@classmethod
def from_config(cls):
pass
def registry_provider(self, name: ModelProviders, config: ProviderConfig):
self._registry[name] = config
def get_provider_config(self, name: ModelProviders) -> Optional[ProviderConfig]:
provider = self._registry.get(name)
if provider is None:
raise ValueError('Provider "{}" is not found.'.format(name))
return provider
def resolve_llm(
self,
provider: Optional[ModelProviders] = ModelProviders.OPENAI,
config: Optional[Dict] = None,
) -> Optional[BaseLLM]:
cfg = LLMConfig.model_validate(
{
"provider": provider,
"config": config,
}
)
provider_config = self.get_provider_config(cfg.provider)
merged_config = {
**provider_config.model_dump(),
**cfg.config.model_dump(),
"model": f"{cfg.provider.value}/{cfg.config.model}",
}
return LLM(**merged_config)
def resolve_embedding_model(
self,
provider: Optional[ModelProviders] = ModelProviders.OPENAI,
config: Optional[Dict] = None,
) -> Optional[BaseEmbedding]:
cfg = EmbeddingModelConfig.model_validate(
{
"provider": provider,
"config": config,
}
)
provider_config = self.get_provider_config(cfg.provider)
merged_config = {
**provider_config.model_dump(),
**cfg.config.model_dump(exclude={"model"}),
"model_name": f"{cfg.provider.value}/{cfg.config.model}",
}
return EmbeddingModel(**merged_config)
def resolve_rerank_model(
self,
provider: Optional[ModelProviders] = ModelProviders.OPENAI,
config: Optional[Dict] = None,
) -> Optional[BaseNodePostprocessor]:
cfg = RerankerConfig.model_validate(
{
"provider": provider,
"config": config,
}
)
provider_config = self.get_provider_config(cfg.provider)
merged_config = {
**provider_config.model_dump(),
**cfg.config.model_dump(exclude={"model"}),
"model": f"{cfg.provider.value}/{cfg.config.model}",
}
return RerankModel(**merged_config)
model_manager: ModelManager = ModelManager()
================================================
FILE: core/autoflow/models/provider.py
================================================
from abc import ABC
class ProviderRegistry(ABC):
def register(self, name: str):
pass
def get_provider_credentials(self):
pass
================================================
FILE: core/autoflow/models/rerank_models/__init__.py
================================================
from .litellm import LiteLLMReranker
RerankModel = LiteLLMReranker
__all__ = ["RerankModel"]
================================================
FILE: core/autoflow/models/rerank_models/litellm.py
================================================
from typing import List, Optional
from litellm.rerank_api.main import rerank
from llama_index.core.bridge.pydantic import Field
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
ReRankEndEvent,
ReRankStartEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle, MetadataMode
dispatcher = get_dispatcher(__name__)
class LiteLLMReranker(BaseNodePostprocessor):
model: str = Field(description="Reranker model name.")
top_n: int = Field(description="Top N nodes to return.")
api_base: Optional[str] = Field(description="Reranker API base url.", default=None)
api_key: Optional[str] = Field(description="Reranker API key.")
def __init__(
self,
top_n: int = 2,
model: str = "rerank_models-english-v2.0",
api_key: Optional[str] = None,
api_base: Optional[str] = None,
):
super().__init__(top_n=top_n, model=model, api_base=api_base, api_key=api_key)
@classmethod
def class_name(cls) -> str:
return "LiteLLMRerank"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
dispatcher.event(
ReRankStartEvent(
query=query_bundle, nodes=nodes, top_n=self.top_n, model_name=self.model
)
)
if query_bundle is None:
raise ValueError("Missing query bundle in extra info.")
if len(nodes) == 0:
return []
with self.callback_manager.event(
CBEventType.RERANKING,
payload={
EventPayload.NODES: nodes,
EventPayload.MODEL_NAME: self.model,
EventPayload.QUERY_STR: query_bundle.query_str,
EventPayload.TOP_K: self.top_n,
},
) as event:
texts = [
node.node.get_content(metadata_mode=MetadataMode.EMBED)
for node in nodes
]
results = rerank(
model=self.model,
query=query_bundle.query_str,
documents=texts,
top_n=self.top_n,
api_base=self.api_base,
api_key=self.api_key,
)
new_nodes = []
for result in results.results:
new_node_with_score = NodeWithScore(
node=nodes[result["index"]].node, score=result["relevance_score"]
)
new_nodes.append(new_node_with_score)
event.on_end(payload={EventPayload.NODES: new_nodes})
dispatcher.event(ReRankEndEvent(nodes=new_nodes))
return new_nodes
================================================
FILE: core/autoflow/orms/__init__.py
================================================
from .base import UUIDBaseModel
__all__ = [
"UUIDBaseModel",
]
================================================
FILE: core/autoflow/orms/base.py
================================================
import uuid
from datetime import datetime
from typing import Optional
from pytidb.schema import TableModel, Field
from pytidb.datatype import DateTime
from pytidb.sql import func
from autoflow.utils import uuid6
class UUIDBaseModel(TableModel):
id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True)
# Use sa_type instead of sa_column, refer to https://github.com/tiangolo/sqlmodel/discussions/743
created_at: Optional[datetime] = Field(
sa_type=DateTime(timezone=True),
sa_column_kwargs={"server_default": func.now()},
)
updated_at: Optional[datetime] = Field(
sa_type=DateTime(timezone=True),
sa_column_kwargs={
"server_default": func.now(),
"onupdate": func.now(),
},
)
================================================
FILE: core/autoflow/py.typed
================================================
================================================
FILE: core/autoflow/storage/__init__.py
================================================
================================================
FILE: core/autoflow/storage/doc_store/__init__.py
================================================
from .base import DocumentStore
from .types import Document, DocumentSearchResult, Chunk
__all__ = [
"DocumentStore",
"DocumentSearchResult",
"Document",
"Chunk",
]
================================================
FILE: core/autoflow/storage/doc_store/base.py
================================================
from uuid import UUID
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from autoflow.storage.doc_store.types import Chunk, Document, DocumentSearchResult
class DocumentStore(ABC):
@abstractmethod
def add(self, documents: List[Document]) -> List[Document]:
raise NotImplementedError()
@abstractmethod
def update(self, document_id: UUID, update: Dict[str, Any]):
raise NotImplementedError()
@abstractmethod
def delete(self, document_id: UUID) -> None:
raise NotImplementedError()
@abstractmethod
def list(self, filters: Dict[str, Any] = None) -> List[Document]:
raise NotImplementedError()
@abstractmethod
def search(
self,
query: str,
top_k: Optional[int] = None,
similarity_candidate: Optional[int] = None,
) -> DocumentSearchResult:
raise NotImplementedError()
@abstractmethod
def get(self, document_id: UUID) -> Document:
raise NotImplementedError()
@abstractmethod
def add_doc_chunks(self, document_id: UUID, chunks: List[Chunk]) -> List[Chunk]:
raise NotImplementedError()
@abstractmethod
def list_doc_chunks(self, document_id: UUID) -> List[Chunk]:
raise NotImplementedError()
@abstractmethod
def get_chunk(self, chunk_id: UUID) -> Chunk:
raise NotImplementedError()
@abstractmethod
def update_chunk(self, chunk_id: UUID, update: Dict[str, Any]) -> Chunk:
raise NotImplementedError()
@abstractmethod
def delete_chunk(self, chunk_id: UUID) -> None:
raise NotImplementedError()
================================================
FILE: core/autoflow/storage/doc_store/tidb_doc_store.py
================================================
import logging
from uuid import UUID
from typing import Any, Dict, List, Optional, Type
from pydantic import PrivateAttr
from pytidb import TiDBClient, Table
from pytidb.embeddings import EmbeddingFunction
from pytidb.schema import TableModel, Field, Column, Relationship as SQLRelationship
from pytidb.datatype import Vector, JSON
from pytidb.search import SearchType
from sqlalchemy.dialects.mysql import LONGTEXT
from autoflow.data_types import DataType
from autoflow.models.embedding_models import EmbeddingModel
from autoflow.orms.base import UUIDBaseModel
from autoflow.storage.doc_store.types import (
Document,
DocumentDescriptor,
Chunk,
RetrievedChunk,
DocumentSearchResult,
)
from autoflow.types import SearchMode
from autoflow.storage.doc_store.base import DocumentStore
logger = logging.getLogger(__name__)
def dynamic_create_models(
namespace: Optional[str] = None,
embedding_model: Optional[EmbeddingModel] = None,
vector_dims: Optional[int] = None,
) -> tuple[Type[TableModel], Type[TableModel]]:
if embedding_model is None and vector_dims is None:
raise ValueError("Either `embedding_model` or `vector_dims` must be specified")
# Determine the table names.
suffix = f"_{namespace}" if namespace else ""
document_table_name = f"documents{suffix}"
document_model_name = f"DBDocument{suffix}"
chunk_table_name = f"chunks{suffix}"
chunk_model_name = f"DBChunk{suffix}"
# Initialize the document table model.
class DBDocument(UUIDBaseModel):
hash: str = Field(max_length=128)
name: str = Field(max_length=256)
content: str = Field(sa_column=Column(LONGTEXT))
data_type: Optional[DataType] = Field(default=None)
meta: dict = Field(default_factory=dict, sa_column=Column(JSON))
document_model = type(
document_model_name,
(DBDocument,),
{
"__tablename__": document_table_name,
"__table_args__": {"extend_existing": True},
},
table=True,
)
# Initialize the chunk table model.
if embedding_model is not None:
embed_fn = EmbeddingFunction(
model_name=embedding_model.model_name,
dimensions=embedding_model.dimensions,
api_key=embedding_model.api_key,
api_base=embedding_model.api_base,
timeout=embedding_model.timeout,
)
vector_field = embed_fn.VectorField(source_field="text")
else:
vector_field = Field(default=None, sa_column=Column(Vector(vector_dims)))
class DBChunk(UUIDBaseModel):
text: str = Field(sa_column=Column(LONGTEXT))
text_vec: Optional[Any] = vector_field
document_id: UUID = Field(foreign_key=f"{document_table_name}.id")
chunk_model = type(
chunk_model_name,
(DBChunk,),
{
"__tablename__": chunk_table_name,
"__table_args__": {"extend_existing": True},
"__annotations__": {
"document": Optional[document_model],
},
"document": SQLRelationship(
sa_relationship_kwargs={
"cascade": "all, delete",
},
),
},
table=True,
)
return document_model, chunk_model
class TiDBDocumentStore(DocumentStore):
_client: TiDBClient = PrivateAttr()
_document_db_model: Type[Type[TableModel]] = PrivateAttr()
_document_table: Table = PrivateAttr()
_chunk_db_model: Type[Type[TableModel]] = PrivateAttr()
_chunk_table: Table = PrivateAttr()
def __init__(
self,
client: TiDBClient,
namespace: Optional[str] = None,
embedding_model: Optional[EmbeddingModel] = None,
vector_dims: Optional[int] = None,
) -> None:
super().__init__()
self._client = client
self._db_engine = self._client.db_engine
self._embedding_model = embedding_model
self._init_store(namespace, vector_dims)
@classmethod
def class_name(cls) -> str:
return "TiDBDocumentStore"
def _init_store(
self, namespace: Optional[str] = None, vector_dims: Optional[int] = None
):
self._document_db_model, self._chunk_db_model = dynamic_create_models(
namespace=namespace,
vector_dims=vector_dims,
embedding_model=self._embedding_model,
)
self._document_table = self._client.create_table(schema=self._document_db_model)
self._chunk_table = self._client.create_table(schema=self._chunk_db_model)
# Document Operations.
def add(self, documents: List[Document]) -> List[Document]:
"""
Add documents.
"""
return_documents = []
for doc in documents:
db_document = self._document_db_model(**doc.model_dump(exclude={"chunks"}))
db_document = self._document_table.insert(db_document)
return_chunks = []
if doc.chunks is not None and len(doc.chunks) > 0:
db_chunks = self.add_doc_chunks(db_document.id, doc.chunks)
return_chunks = [
Chunk(**db_chunk.model_dump(exclude={"document"}))
for db_chunk in db_chunks
]
return_documents.append(
Document(**db_document.model_dump(), chunks=return_chunks)
)
return return_documents
def update(self, document_id: UUID, update: Dict[str, Any]) -> None:
"""
Update documents.
"""
self._document_table.update(update, {"id": document_id})
def delete(self, document_id: UUID) -> None:
"""
Delete document by id.
Note: The related chunks will also be deleted by cascade.
Args:
document_id: The id of the document to delete.
"""
return self._document_table.delete({"id": document_id})
def get(self, document_id: UUID) -> Document:
"""
Get document by id.
"""
db_document = self._document_table.get(document_id)
return Document(**db_document.model_dump())
# TODO: Support pagination.
def list(self, filters: Dict[str, Any] = None) -> List[Document]:
"""
List all documents.
"""
db_documents = self._document_table.query(filters)
return [Document(**d.model_dump()) for d in db_documents]
def search(
self,
query: str | List[float],
mode: SearchMode = "vector",
top_k: Optional[int] = None,
similarity_threshold: Optional[float] = None,
num_candidate: Optional[int] = None,
full_document: Optional[bool] = None,
) -> DocumentSearchResult:
# TODO: Support Fulltext search.
# TODO: Support Hybrid search.
if mode != "vector":
raise NotImplementedError(
".search() only supports vector search currently, fulltext and hybird search will be coming soon."
)
db_chunks = (
self._chunk_table.search(query, query_type=SearchType.VECTOR_SEARCH)
.distance_threshold(
(1 - similarity_threshold) if similarity_threshold is not None else None
)
.num_candidate(num_candidate)
.limit(top_k)
.to_pydantic(with_score=True)
)
document_ids = [c.document_id for c in db_chunks]
db_documents = self.list(
{
"id": {"$in": document_ids},
}
)
return self._convert_to_retrieval_result(db_chunks, db_documents, full_document)
def _convert_to_retrieval_result(
self,
db_chunks: List[TableModel],
db_documents: List[TableModel],
full_document: bool,
) -> DocumentSearchResult:
return DocumentSearchResult(
chunks=[
RetrievedChunk(
**c.hit.model_dump(),
similarity_score=c.similarity_score,
score=c.score,
)
for c in db_chunks
],
documents=[
Document(**d.model_dump())
if full_document
else DocumentDescriptor(**d.model_dump())
for d in db_documents
],
)
# Chunk Operations.
def add_doc_chunks(self, document_id: UUID, chunks: List[Chunk]) -> List[Chunk]:
"""
Add document chunks.
"""
db_chunks = [
self._chunk_db_model(
**c.model_dump(exclude={"document_id"}), document_id=document_id
)
for c in chunks
]
db_chunks = self._chunk_table.bulk_insert(db_chunks)
return [Chunk(**c.model_dump(exclude={"document"})) for c in db_chunks]
def list_doc_chunks(self, document_id: UUID) -> List[Chunk]:
"""
List document chunks.
"""
return self._chunk_table.query({"document_id": document_id})
def get_chunk(self, chunk_id: UUID) -> Chunk:
"""
Get chunk by id.
"""
chunk = self._chunk_table.get(chunk_id)
return Chunk(**chunk.model_dump(exclude={"document"}))
def delete_chunk(self, chunk_id: UUID) -> None:
"""
Delete document chunk.
"""
return self._chunk_table.delete({"id": chunk_id})
def update_chunk(self, chunk_id: UUID, update: Dict[str, Any]) -> Chunk:
"""
Update chunk.
"""
self._chunk_table.update(update, {"id": chunk_id})
return self.get_chunk(chunk_id)
# Document Store Operations.
def recreate(self) -> None:
self._client.drop_table(self._chunk_table.table_name)
self._client.drop_table(self._document_table.table_name)
self._document_table = self._client.create_table(schema=self._document_db_model)
self._chunk_table = self._client.create_table(schema=self._chunk_db_model)
def reset(self) -> None:
with self._client.session():
self._client.execute("SET FOREIGN_KEY_CHECKS = 0")
self._chunk_table.truncate()
self._document_table.truncate()
self._client.execute("SET FOREIGN_KEY_CHECKS = 1")
================================================
FILE: core/autoflow/storage/doc_store/types.py
================================================
from uuid import UUID
from datetime import datetime
from typing import Optional, List, Any
from pydantic import BaseModel, Field, computed_field
from autoflow.data_types import DataType
from autoflow.utils import uuid6
from autoflow.utils.hash import sha256
# Chunk
class Chunk(BaseModel):
id: Optional[UUID] = Field(default_factory=uuid6.uuid7)
text: str = Field(description="The text of the chunk.")
text_vec: Optional[Any] = Field(
default=None, description="The vector of text vectors."
)
meta: Optional[dict] = Field(
default_factory=dict, description="The metadata of the chunk."
)
document_id: Optional[UUID] = Field(
default=None, description="The id of the document that the chunk belongs to."
)
created_at: datetime = Field(default=None, description="The created time")
updated_at: datetime = Field(default=None, description="The updated time")
@computed_field
@property
def hash(self) -> Optional[str]:
return sha256(self.text)
class RetrievedChunk(Chunk):
score: Optional[float] = Field(description="The score of the chunk.", default=None)
similarity_score: Optional[float] = Field(
default=None, description="The similarity score of the chunk."
)
# Document
class Document(BaseModel):
id: Optional[UUID] = Field(default_factory=uuid6.uuid7)
name: Optional[str] = Field(None, description="The name of the document.")
content: str = Field(description="The content of the document.")
data_type: Optional[DataType] = Field(
default=None, description="The data type of the document."
)
meta: Optional[dict] = Field(
default_factory=dict, description="The metadata of the document."
)
created_at: Optional[datetime] = Field(default=None, description="The created time")
updated_at: Optional[datetime] = Field(default=None, description="The updated time")
chunks: Optional[List[Chunk]] = Field(
default_factory=list, description="The chunks of the document."
)
@computed_field
@property
def hash(self) -> Optional[str]:
return sha256(self.content)
class DocumentDescriptor(BaseModel):
id: UUID
name: str
# source_uri: str
class DocumentSearchResult(BaseModel):
chunks: List[RetrievedChunk] = Field(
default_factory=list, description="The chunks of the search result."
)
documents: List[DocumentDescriptor | Document] = Field(
default_factory=list,
description="The aggregated documents of the search result.",
)
================================================
FILE: core/autoflow/storage/graph_store/__init__.py
================================================
from .base import GraphStore
from .tidb_graph_store import TiDBGraphStore
__all__ = ["GraphStore", "TiDBGraphStore"]
================================================
FILE: core/autoflow/storage/graph_store/base.py
================================================
from abc import ABC
import logging
from typing import (
Collection,
Dict,
List,
Optional,
Tuple,
)
from uuid import UUID
from autoflow.storage.types import QueryBundle
from autoflow.types import BaseComponent
from autoflow.storage.graph_store.types import (
Entity,
EntityFilters,
EntityType,
EntityUpdate,
EntityDegree,
KnowledgeGraph,
KnowledgeGraphCreate,
Relationship,
RelationshipFilters,
RelationshipUpdate,
)
logger = logging.getLogger(__name__)
class GraphStore(BaseComponent, ABC):
"""Abstract base class for knowledge graph storage"""
# Entity Basic Operations
def list_entities(
self, filters: Optional[EntityFilters] = EntityFilters()
) -> List[Entity]:
"""List all entities matching the filters"""
raise NotImplementedError
def search_entities(
self,
query: QueryBundle,
top_k: int = 10,
num_candidate: Optional[int] = None,
distance_threshold: Optional[float] = None,
filters: Optional[EntityFilters] = None,
) -> List[Tuple[Entity, float]]:
raise NotImplementedError
def get_entity(self, entity_id: UUID) -> Entity:
"""Get entity by ID"""
raise NotImplementedError
def must_get_entity(self, entity_id: UUID) -> Entity:
entity = self.get_entity(entity_id)
if entity is None:
raise ValueError(f"Can not find the entity #{entity_id}")
return entity
def create_entity(
self,
name: str,
entity_type: EntityType = EntityType.original,
description: Optional[str] = None,
meta: Optional[dict] = None,
embedding: Optional[list[float]] = None,
) -> Entity:
"""Create a new entity"""
raise NotImplementedError
def update_entity(self, entity: Entity | UUID, update: EntityUpdate) -> Entity:
"""Update an existing entity"""
raise NotImplementedError
def delete_entity(self, entity_id: UUID) -> None:
"""Delete an entity"""
raise NotImplementedError
def delete_orphan_entities(self):
"""Remove entities that have no relationships"""
raise NotImplementedError
# Entity Degree Operations
def calc_entity_out_degree(self, entity_id: UUID) -> int:
"""Calculate out-degree of an entity"""
raise NotImplementedError
def calc_entity_in_degree(self, entity_id: UUID) -> int:
"""Calculate in-degree of an entity"""
raise NotImplementedError
def calc_entity_degree(self, entity_id: UUID) -> int:
"""Calculate total degree of an entity"""
raise NotImplementedError
def calc_entities_degrees(
self, entity_ids: Collection[UUID]
) -> Dict[UUID, EntityDegree]:
"""Calculate degrees for multiple entities in bulk"""
raise NotImplementedError
# Relationship Basic Operations
def get_relationship(self, relationship_id: UUID) -> Relationship:
"""Get relationship by ID"""
raise NotImplementedError
def list_relationships(self, filters: RelationshipFilters) -> List[Relationship]:
"""List all relationships matching the filters"""
raise NotImplementedError
def create_relationship(
self,
source_entity: Entity,
target_entity: Entity,
description: Optional[str] = None,
meta: Optional[dict] = {},
**kwargs,
) -> Relationship:
"""Create a new relationship between entities"""
raise NotImplementedError
def update_relationship(
self, relationship: Relationship | UUID, update: RelationshipUpdate
) -> Relationship:
"""Update an existing relationship"""
raise NotImplementedError
def delete_relationship(self, relationship_id: UUID):
"""Delete a relationship"""
raise NotImplementedError
def search_relationships(
self,
query: QueryBundle,
top_k: int = 10,
num_candidate: Optional[int] = None,
distance_threshold: Optional[float] = None,
distance_range: Optional[Tuple[float, float]] = None,
filters: Optional[RelationshipFilters] = None,
) -> List[Tuple[Relationship, float]]:
"""
Args:
query:
top_k:
num_candidate:
distance_threshold:
distance_range:
filters:
"""
raise NotImplementedError
def reset(self):
"""Reset the graph store"""
raise NotImplementedError
def drop(self):
"""Drop the graph store"""
raise NotImplementedError
# Knowledge Graph Operations
def add(self, knowledge_graph: KnowledgeGraphCreate) -> Optional[KnowledgeGraph]:
"""Add a knowledge graph to the graph store"""
raise NotImplementedError
================================================
FILE: core/autoflow/storage/graph_store/tidb_graph_store.py
================================================
import logging
from typing import Collection, Dict, List, Optional, Tuple, Type, Any
from uuid import UUID
from pydantic import PrivateAttr
from pytidb import Table, TiDBClient
from pytidb.datatype import JSON, Text
from pytidb.schema import (
Column,
Field,
Relationship as SQLRelationship,
TableModel,
VectorField,
)
from pytidb.sql import func, select, or_
from pytidb.embeddings import EmbeddingFunction
from sqlalchemy import Index
from autoflow.models.embedding_models import EmbeddingModel
from autoflow.orms.base import UUIDBaseModel
from autoflow.storage.graph_store.base import GraphStore
from autoflow.storage.graph_store.types import (
Entity,
EntityDegree,
EntityFilters,
EntityType,
EntityUpdate,
KnowledgeGraph,
KnowledgeGraphCreate,
Relationship,
RelationshipFilters,
RelationshipUpdate,
)
from autoflow.storage.types import QueryBundle
logger = logging.getLogger(__name__)
def dynamic_create_models(
namespace: Optional[str] = None,
embedding_model: Optional[EmbeddingModel] = None,
vector_dims: Optional[int] = None,
) -> tuple[Type[TableModel], Type[TableModel]]:
if embedding_model is None and vector_dims is None:
raise ValueError("Either `embedding_model` or `vector_dims` must be specified")
# Determine the table names.
suffix = f"_{namespace}" if namespace else ""
entity_table_name = f"entities{suffix}"
relationship_table_name = f"relationships{suffix}"
entity_model_name = f"DBEntity{suffix}"
relationship_model_name = f"DBRelationship{suffix}"
# Embedding function.
embed_fn = None
if embedding_model is not None:
embed_fn = EmbeddingFunction(
model_name=embedding_model.model_name,
dimensions=embedding_model.dimensions,
api_key=embedding_model.api_key,
api_base=embedding_model.api_base,
timeout=embedding_model.timeout,
)
# Initialize the entity model.
if embed_fn is not None:
entity_vector_field = embed_fn.VectorField()
else:
entity_vector_field = VectorField(vector_dims)
class DBEntity(UUIDBaseModel):
__table_args__ = (
Index("idx_entity_type", "entity_type"),
Index("idx_entity_name", "name"),
)
entity_type: EntityType = EntityType.original
name: str = Field(max_length=512)
description: str = Field(sa_column=Column(Text))
meta: Optional[Dict] = Field(default_factory=dict, sa_column=Column(JSON))
embedding: Optional[Any] = entity_vector_field
def __hash__(self):
return hash(self.id)
def __eq__(self, other):
return self.id == other.id
entity_model = type(
entity_model_name,
(DBEntity,),
{
"__tablename__": entity_table_name,
"__table_args__": {
"extend_existing": True,
},
},
table=True,
)
# Initialize the relationship model.
if embed_fn is not None:
relationship_vector_field = embed_fn.VectorField()
else:
relationship_vector_field = VectorField(vector_dims)
class DBRelationship(UUIDBaseModel):
description: str = Field(sa_column=Column(Text))
source_entity_id: UUID = Field(foreign_key=f"{entity_table_name}.id")
target_entity_id: UUID = Field(foreign_key=f"{entity_table_name}.id")
meta: Optional[Dict] = Field(default_factory=dict, sa_column=Column(JSON))
embedding: Optional[Any] = relationship_vector_field
weight: Optional[float] = Field(default=0)
chunk_id: Optional[UUID] = Field(default=None)
document_id: Optional[UUID] = Field(default=None)
def __hash__(self):
return hash(self.id)
def __eq__(self, other):
return self.id == other.id
relationship_model = type(
relationship_model_name,
(DBRelationship,),
{
"__tablename__": relationship_table_name,
"__table_args__": {"extend_existing": True},
"__annotations__": {
"source_entity": entity_model,
"target_entity": entity_model,
},
"source_entity": SQLRelationship(
sa_relationship_kwargs={
"primaryjoin": f"{relationship_model_name}.source_entity_id == {entity_model_name}.id",
"lazy": "joined",
},
),
"target_entity": SQLRelationship(
sa_relationship_kwargs={
"primaryjoin": f"{relationship_model_name}.target_entity_id == {entity_model_name}.id",
"lazy": "joined",
},
),
},
table=True,
)
return entity_model, relationship_model
class TiDBGraphStore(GraphStore):
_db: TiDBClient = PrivateAttr()
_entity_db_model: Type[TableModel] = PrivateAttr()
_entity_table: Table = PrivateAttr()
_relationship_db_model: Type[TableModel] = PrivateAttr()
_relationship_table: Table = PrivateAttr()
def __init__(
self,
client: TiDBClient,
namespace: Optional[str] = None,
embedding_model: Optional[EmbeddingModel] = None,
vector_dims: Optional[int] = None,
entity_distance_threshold: Optional[float] = 0.1,
):
super().__init__()
self._db = client
self._db_engine = client.db_engine
self._embedding_model = embedding_model
self._entity_distance_threshold = entity_distance_threshold
self._init_store(namespace, vector_dims)
def _init_store(
self, namespace: Optional[str] = None, vector_dims: Optional[int] = None
):
self._entity_db_model, self._relationship_db_model = dynamic_create_models(
namespace=namespace,
vector_dims=vector_dims,
embedding_model=self._embedding_model,
)
self._entity_table = self._db.create_table(schema=self._entity_db_model)
self._relationship_table = self._db.create_table(
schema=self._relationship_db_model
)
# Entity Basic Operations
def get_entity(self, entity_id: UUID) -> Entity:
return self._entity_table.get(entity_id)
def list_entities(
self, filters: Optional[EntityFilters] = EntityFilters(), **kwargs
) -> List[Entity]:
if isinstance(kwargs, dict):
filters = filters.model_copy(update=kwargs)
filter_dict = self._convert_entity_filters(filters)
return self._entity_table.query(filter_dict)
def search_entities(
self,
query: QueryBundle,
top_k: int = 10,
num_candidate: Optional[int] = None,
distance_threshold: Optional[float] = None,
filters: Optional[EntityFilters] = None,
) -> List[Tuple[Entity, float]]:
filter_dict = self._convert_entity_filters(filters)
results = (
self._entity_table.search(query.query_embedding or query.query_str)
.num_candidate(num_candidate or top_k * 10)
.filter(filter_dict)
.distance_threshold(distance_threshold)
.limit(top_k)
.to_pydantic()
)
return [(item.hit, item.score) for item in results]
def _convert_entity_filters(self, filters: Optional[EntityFilters]) -> dict:
filter_dict = {}
if filters is None:
return filter_dict
if filters.entity_type:
filter_dict["entity_type"] = filters.entity_type.value
if filters.entity_id:
op = "$in" if isinstance(filters.entity_id, list) else "$eq"
filter_dict["id"] = {op: filters.entity_id}
return filter_dict
def create_entity(
self,
name: str,
entity_type: EntityType = EntityType.original,
description: Optional[str] = None,
meta: Optional[dict] = None,
embedding: Optional[list[float]] = None,
) -> Entity:
if embedding is None:
embedding = self._get_entity_embedding(name, description)
entity = self._entity_db_model(
name=name,
entity_type=entity_type,
description=description,
meta=meta,
embedding=embedding,
)
return self._entity_table.insert(entity)
def _get_entity_embedding(self, name: str, description: str) -> list[float]:
embedding_str = f"{name}: {description}"
return self._embedding_model.get_text_embedding(embedding_str)
def find_or_create_entity(
self,
name: str,
entity_type: EntityType = EntityType.original,
description: Optional[str] = None,
meta: Optional[dict] = None,
embedding: Optional[Any] = None,
) -> Entity:
query_embedding = self._get_entity_embedding(name, description)
query = QueryBundle(query_embedding=query_embedding)
nearest_entity = self.search_entities(
query, top_k=1, distance_threshold=self._entity_distance_threshold
)
if len(nearest_entity) != 0:
return nearest_entity[0][0]
else:
return self.create_entity(
name=name,
entity_type=entity_type,
description=description,
meta=meta,
embedding=embedding,
)
def update_entity(self, entity: Entity | UUID, update: EntityUpdate) -> Entity:
if isinstance(entity, UUID):
entity = self.get_entity(entity)
update_dict = update.model_dump(exclude_none=True)
if update.embedding is None:
update_dict["embedding"] = self._get_entity_embedding(
entity.name, entity.description
)
self._entity_table.update(values=update_dict, filters={"id": entity.id})
# FIXME: pytidb should return the updated entity.
entity = self._entity_table.get(entity.id)
# Update connected relationships.
connected_relationships = self.list_relationships(
filters=RelationshipFilters(
entity_id=entity.id,
)
)
for relationship in connected_relationships:
self.update_relationship(relationship, RelationshipUpdate())
return entity
def delete_entity(self, entity_id: UUID) -> None:
with self._db.session():
# Delete all relationships connected to the entity.
self._relationship_table.delete({"source_entity_id": entity_id})
self._relationship_table.delete({"target_entity_id": entity_id})
# Delete the entity.
self._entity_table.delete({"id": entity_id})
def delete_orphan_entities(self):
raise NotImplementedError()
# Entity Degree Operations
def calc_entity_out_degree(self, entity_id: UUID) -> int:
stmt = select(func.count(self._relationship_db_model.id)).where(
self._relationship_db_model.source_entity_id == entity_id
)
return self._db.query(stmt).scalar()
def calc_entity_in_degree(self, entity_id: UUID) -> int:
stmt = select(func.count(self._relationship_db_model.id)).where(
self._relationship_db_model.target_entity_id == entity_id
)
return self._db.query(stmt).scalar()
def calc_entity_degree(self, entity_id: UUID) -> int:
stmt = select(func.count(self._relationship_db_model.id)).where(
or_(
self._relationship_db_model.target_entity_id == entity_id,
self._relationship_db_model.source_entity_id == entity_id,
)
)
return self._db.query(stmt).scalar()
def calc_entities_degrees(
self, entity_ids: Collection[UUID]
) -> Dict[UUID, EntityDegree]:
entity_table_name = self._entity_table.table_name
relationship_table_name = self._relationship_table.table_name
stmt = f"""
SELECT
e.id as id,
COALESCE(SUM(CASE WHEN r.target_entity_id = e.id THEN 1 ELSE 0 END), 0) AS in_degree,
COALESCE(SUM(CASE WHEN r.source_entity_id = e.id THEN 1 ELSE 0 END), 0) AS out_degree,
COALESCE(COUNT(e.id), 0) AS degree
FROM {entity_table_name} e
LEFT JOIN {relationship_table_name} r ON e.id = r.source_entity_id OR e.id = r.target_entity_id
WHERE e.id IN :entity_ids
GROUP BY e.id
"""
results = self._db.query(
stmt, {"entity_ids": [entity_id.hex for entity_id in entity_ids]}
).to_list()
return {
UUID(item["id"]): EntityDegree(
in_degree=item["in_degree"],
out_degree=item["out_degree"],
degrees=item["degree"],
)
for item in results
}
# Relationship Basic Operations
def get_relationship(self, relationship_id: UUID) -> Relationship:
return self._relationship_table.get(relationship_id)
def list_relationships(
self, filters: RelationshipFilters = RelationshipFilters(), **kwargs
) -> List[Relationship]:
if isinstance(kwargs, dict):
filters = filters.model_copy(update=kwargs)
filter_dict = self._convert_relationship_filters(filters)
return self._relationship_table.query(filter_dict)
def search_relationships(
self,
query: QueryBundle,
top_k: int = 10,
num_candidate: Optional[int] = None,
distance_threshold: Optional[float] = None,
distance_range: Optional[Tuple[float, float]] = None,
filters: Optional[RelationshipFilters] = None,
) -> List[Tuple[Relationship, float]]:
filter_dict = self._convert_relationship_filters(filters)
results = (
self._relationship_table.search(query.query_embedding or query.query_str)
.num_candidate(num_candidate or top_k * 10)
.filter(filter_dict)
.distance_threshold(distance_threshold)
.distance_range(distance_range[0], distance_range[1])
.limit(top_k)
.to_pydantic()
)
# FIXME: pytidb should return the relationship field: target_entity, source_entity.
entity_ids = [item.hit.target_entity_id for item in results]
entity_ids.extend([item.hit.source_entity_id for item in results])
entities = self.list_entities(filters=EntityFilters(entity_id=entity_ids))
entity_map = {entity.id: entity for entity in entities}
for item in results:
item.hit.target_entity = entity_map[item.hit.target_entity_id]
item.hit.source_entity = entity_map[item.hit.source_entity_id]
return [(item.hit, item.score) for item in results]
def _convert_relationship_filters(self, filters: RelationshipFilters) -> dict:
filter_dict = {}
if filters.entity_id:
if isinstance(filters.entity_id, list):
if len(filters.entity_id) != 0:
filter_dict["$or"] = [
{"target_entity_id": {"$in": filters.entity_id}},
{"source_entity_id": {"$in": filters.entity_id}},
]
else:
filter_dict["$or"] = [
{"target_entity_id": {"$eq": filters.entity_id}},
{"source_entity_id": {"$eq": filters.entity_id}},
]
if filters.source_entity_id:
if isinstance(filters.source_entity_id, list):
if len(filters.source_entity_id) != 0:
filter_dict["$or"] = [
{"source_entity_id": {"$in": filters.source_entity_id}}
]
else:
filter_dict["$or"] = [
{"source_entity_id": {"$eq": filters.source_entity_id}}
]
if filters.target_entity_id:
if isinstance(filters.target_entity_id, list):
if len(filters.target_entity_id) != 0:
filter_dict["$or"] = [
{"target_entity_id": {"$in": filters.target_entity_id}}
]
else:
filter_dict["$or"] = [
{"target_entity_id": {"$eq": filters.target_entity_id}}
]
if filters.relationship_id:
if isinstance(filters.relationship_id, list):
if len(filters.relationship_id) != 0:
filter_dict["id"] = {"$in": filters.relationship_id}
else:
filter_dict["id"] = {"$eq": filters.relationship_id}
if (
filters.exclude_relationship_ids
and len(filters.exclude_relationship_ids) != 0
):
filter_dict["id"] = {"$nin": filters.exclude_relationship_ids}
if filters.document_id:
if isinstance(filters.document_id, list):
if len(filters.document_id) != 0:
filter_dict["document_id"] = {"$in": filters.document_id}
else:
filter_dict["document_id"] = {"$eq": filters.document_id}
if filters.chunk_id:
if isinstance(filters.chunk_id, list):
if len(filters.chunk_id) != 0:
filter_dict["chunk_id"] = {"$in": filters.chunk_id}
else:
filter_dict["chunk_id"] = {"$eq": filters.chunk_id}
if filters.metadata:
for key, value in filters.metadata.items():
op = "$in" if isinstance(value, list) else "$eq"
filter_dict[f"meta.{key}"] = {op: value}
return filter_dict
def create_relationship(
self,
source_entity: Entity | UUID,
target_entity: Entity | UUID,
description: Optional[str] = None,
meta: Optional[dict] = {},
embedding: Optional[Any] = None,
) -> Relationship:
"""
Create a relationship between two entities.
"""
if isinstance(source_entity, UUID):
source_entity = self.get_entity(source_entity)
if isinstance(target_entity, UUID):
target_entity = self.get_entity(target_entity)
if embedding is None:
embedding = self._get_relationship_embedding(
source_entity.name,
source_entity.description,
target_entity.name,
target_entity.description,
description,
)
relationship = self._relationship_db_model(
source_entity_id=source_entity.id,
target_entity_id=target_entity.id,
description=description,
meta=meta,
embedding=embedding,
)
return self._relationship_table.insert(relationship)
def _get_relationship_embedding(
self,
source_entity_name: str,
source_entity_description,
target_entity_name: str,
target_entity_description: str,
relationship_desc: str,
) -> List[float]:
embedding_str = (
f"{source_entity_name}({source_entity_description}) -> "
f"{relationship_desc} -> {target_entity_name}({target_entity_description}) "
)
return self._embedding_model.get_text_embedding(embedding_str)
def update_relationship(
self, relationship: Relationship | UUID, update: RelationshipUpdate
) -> Relationship:
if isinstance(relationship, UUID):
relationship = self.get_relationship(relationship)
update_dict = update.model_dump()
if update.embedding is None:
update_dict["embedding"] = self._get_relationship_embedding(
relationship.source_entity.name,
relationship.source_entity.description,
relationship.target_entity.name,
relationship.target_entity.description,
relationship.description,
)
self._relationship_table.update(
values=update_dict, filters={"id": relationship.id}
)
# FIXME: pytidb should return the updated relationship.
relationship = self._relationship_table.get(relationship.id)
return relationship
def delete_relationship(self, relationship_id: UUID):
return self._relationship_table.delete(filters={"id": relationship_id})
# Knowledge Graph Operations
def add(self, knowledge_graph: KnowledgeGraphCreate) -> Optional[KnowledgeGraph]:
with self._db.session():
# Create or find entities
entity_map = {}
for entity in knowledge_graph.entities:
created_entity = self.find_or_create_entity(
entity_type=EntityType.original,
name=entity.name,
description=entity.description,
meta=entity.meta,
)
entity_map[entity.name] = created_entity
entities = list(entity_map.values())
# Create relationships
relationships = []
for rel in knowledge_graph.relationships:
logger.info("Saving relationship: %s", rel.description)
source_entity = entity_map.get(rel.source_entity_name)
if not source_entity:
logger.warning(
"Source entity not found for relationship: %s", str(rel)
)
continue
target_entity = entity_map.get(rel.target_entity_name)
if not target_entity:
logger.warning(
"Target entity not found for relationship: %s", str(rel)
)
continue
relationship = self.create_relationship(
source_entity=source_entity,
target_entity=target_entity,
description=rel.description,
meta=rel.meta,
)
relationships.append(relationship)
return KnowledgeGraph(
entities=[Entity(**entity.model_dump()) for entity in entities],
relationships=[
Relationship(**relationship.model_dump())
for relationship in relationships
],
)
# Graph Store Operations
def reset(self):
with self._db.session():
self._db.execute("SET FOREIGN_KEY_CHECKS = 0")
self._relationship_table.truncate()
self._entity_table.truncate()
self._db.execute("SET FOREIGN_KEY_CHECKS = 1")
def recreate(self):
self._db.drop_table(self._relationship_table.table_name)
self._db.drop_table(self._entity_table.table_name)
self._entity_table = self._db.create_table(schema=self._entity_db_model)
self._relationship_table = self._db.create_table(
schema=self._relationship_db_model
)
def drop(self):
self._db.drop_table(self._relationship_table.table_name)
self._db.drop_table(self._entity_table.table_name)
================================================
FILE: core/autoflow/storage/graph_store/types.py
================================================
from enum import Enum
from typing import List, Optional, Dict, Any
from uuid import UUID
from datetime import datetime
from pydantic import BaseModel, model_validator, Field
# Entity
class EntityType(str, Enum):
original = "original"
synopsis = "synopsis"
def __str__(self):
return self.value
class Entity(BaseModel):
id: Optional[UUID]
entity_type: Optional[EntityType] = Field(
description="Type of the entity", default=EntityType.original
)
name: str = Field(description="Name of the entity")
description: str = Field(description="Description of the entity")
embedding: Optional[Any] = Field(
description="Embedding of the entity", default=None
)
meta: Optional[Dict[str, Any]] = Field(
description="Metadata of the entity", default_factory=dict
)
created_at: Optional[datetime]
updated_at: Optional[datetime]
class EntityCreate(BaseModel):
entity_type: EntityType = EntityType.original
name: Optional[str] = None
description: Optional[str] = None
meta: Optional[dict] = None
class SynopsisEntityCreate(EntityCreate):
topic: str
entities: List[int] = Field(description="The id list of the related entities")
@model_validator(mode="after")
def validate_entities(self):
if len(self.entities) == 0:
raise ValueError("Entities list should not be empty")
return self
class EntityFilters(BaseModel):
entity_type: Optional[EntityType] = None
entity_id: Optional[UUID | List[UUID]] = None
class EntityUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
meta: Optional[dict] = None
embedding: Optional[list[float]] = None
class EntityDegree(BaseModel):
out_degree: int = 0
in_degree: int = 0
degrees: int = 0
# Relationship
class Relationship(BaseModel):
id: Optional[UUID]
source_entity_id: Optional[UUID] = Field(default=None)
source_entity: Optional[Entity] = Field(default=None)
target_entity_id: Optional[UUID] = Field(default=None)
target_entity: Optional[Entity] = Field(default=None)
description: str = Field(description="Description of the relationship")
weight: Optional[float] = Field(default=0, description="Weight of the relationship")
meta: Optional[Dict[str, Any]] = Field(
description="Metadata of the relationship", default_factory=dict
)
embedding: Optional[Any] = Field(
description="Embedding of the relationship", default=None
)
created_at: Optional[datetime] = Field(default=None)
updated_at: Optional[datetime] = Field(default=None)
class RelationshipCreate(BaseModel):
source_entity_name: str
target_entity_name: str
description: str
meta: Dict[str, Any] = Field(default_factory=dict)
weight: Optional[float] = Field(default=0)
chunk_id: Optional[UUID] = Field(default=None)
document_id: Optional[UUID] = Field(default=None)
class RelationshipUpdate(BaseModel):
description: Optional[str] = None
embedding: Optional[list[float]] = None
class RelationshipFilters(BaseModel):
entity_id: Optional[UUID | List[UUID]] = Field(
description="Filter by the entity connected to the relationship",
default=None,
)
target_entity_id: Optional[UUID | List[UUID]] = Field(
description="Filter by the target entity of the relationship",
default=None,
)
source_entity_id: Optional[UUID | List[UUID]] = Field(
description="Filter by the source entity of the relationship",
default=None,
)
chunk_id: Optional[UUID | List[UUID]] = Field(
description="Filter by the chunk which the relationship belongs to",
default=None,
)
document_id: Optional[UUID | List[UUID]] = Field(
description="Filter by the document which the relationship belongs to",
default=None,
)
relationship_id: Optional[UUID | List[UUID]] = Field(
description="Filter by the id of the relationship",
default=None,
)
exclude_relationship_ids: Optional[List[UUID]] = Field(
description="Exclude the relationships by the id",
default=None,
)
metadata: Optional[Dict[str, Any]] = Field(
description="Filter by the metadata of the relationship",
default=None,
)
# Knowledge Graph
class KnowledgeGraph(BaseModel):
entities: List[Entity] = Field(default_factory=list)
relationships: List[Relationship] = Field(default_factory=list)
# Knowledge Graph Create
class KnowledgeGraphCreate(BaseModel):
entities: List[EntityCreate]
relationships: List[RelationshipCreate]
================================================
FILE: core/autoflow/storage/types.py
================================================
from typing import Optional, List
from pydantic import BaseModel
class QueryBundle(BaseModel):
query_str: Optional[str] = None
query_embedding: Optional[List[float]] = None
================================================
FILE: core/autoflow/types.py
================================================
from typing import Literal
from llama_index.core.schema import BaseComponent
BaseComponent = BaseComponent
SearchMode = Literal["vector", "fulltext", "hybird"]
================================================
FILE: core/autoflow/utils/hash.py
================================================
import hashlib
from typing import Optional
def sha256(text: Optional[str]) -> Optional[str]:
return hashlib.sha256(text.encode("utf-8")).hexdigest() if text else None
================================================
FILE: core/autoflow/utils/uuid6.py
================================================
r"""UUID draft version objects (universally unique identifiers).
This module provides the functions uuid6() and uuid7() for
generating version 6 and 7 UUIDs as specified in
https://github.com/uuid6/uuid6-ietf-draft.
Repo: https://github.com/oittaa/uuid6-python
Copy from: https://github.com/jonra1993/fastapi-alembic-sqlmodel-async/blob/main/backend/app/app/utils/uuid6.py
"""
import secrets
import time
from typing import Optional
import uuid
class UUID(uuid.UUID):
r"""UUID draft version objects"""
def __init__(
self,
hex: Optional[str] = None,
bytes: Optional[bytes] = None,
bytes_le: Optional[bytes] = None,
fields: Optional[tuple[int, int, int, int, int, int]] = None,
int: Optional[int] = None,
version: Optional[int] = None,
*,
is_safe=uuid.SafeUUID.unknown,
) -> None:
r"""Create a UUID."""
if int is None or [hex, bytes, bytes_le, fields].count(None) != 4:
super().__init__(
hex=hex,
bytes=bytes,
bytes_le=bytes_le,
fields=fields,
int=int,
version=version,
is_safe=is_safe,
)
if int is not None and not 0 <= int < 1 << 128:
raise ValueError("int is out of range (need a 128-bit value)")
if version is not None:
if not 6 <= version <= 7:
raise ValueError("illegal version number")
# Set the variant to RFC 4122.
int &= ~(0xC000 << 48)
int |= 0x8000 << 48
# Set the version number.
int &= ~(0xF000 << 64)
int |= version << 76
super().__init__(int=int, is_safe=is_safe)
@property
def subsec(self) -> int:
return ((self.int >> 64) & 0x0FFF) << 8 | ((self.int >> 54) & 0xFF)
@property
def time(self) -> int:
if self.version == 6:
return (
(self.time_low << 28)
| (self.time_mid << 12)
| (self.time_hi_version & 0x0FFF)
)
if self.version == 7:
return (self.int >> 80) * 10**6 + _subsec_decode(self.subsec)
return super().time
def _subsec_decode(value: int) -> int:
return -(-value * 10**6 // 2**20)
def _subsec_encode(value: int) -> int:
return value * 2**20 // 10**6
_last_v6_timestamp = None
_last_v7_timestamp = None
def uuid6(clock_seq: int = None) -> UUID:
r"""UUID version 6 is a field-compatible version of UUIDv1, reordered for
improved DB locality. It is expected that UUIDv6 will primarily be
used in contexts where there are existing v1 UUIDs. Systems that do
not involve legacy UUIDv1 SHOULD consider using UUIDv7 instead.
If 'clock_seq' is given, it is used as the sequence number;
otherwise a random 14-bit sequence number is chosen."""
global _last_v6_timestamp
nanoseconds = time.time_ns()
# 0x01b21dd213814000 is the number of 100-ns intervals between the
# UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
timestamp = nanoseconds // 100 + 0x01B21DD213814000
if _last_v6_timestamp is not None and timestamp <= _last_v6_timestamp:
timestamp = _last_v6_timestamp + 1
_last_v6_timestamp = timestamp
if clock_seq is None:
clock_seq = secrets.randbits(14) # instead of stable storage
node = secrets.randbits(48)
time_high_and_time_mid = (timestamp >> 12) & 0xFFFFFFFFFFFF
time_low_and_version = timestamp & 0x0FFF
uuid_int = time_high_and_time_mid << 80
uuid_int |= time_low_and_version << 64
uuid_int |= (clock_seq & 0x3FFF) << 48
uuid_int |= node
return UUID(int=uuid_int, version=6)
def uuid7() -> UUID:
r"""UUID version 7 features a time-ordered value field derived from the
widely implemented and well known Unix Epoch timestamp source, the
number of milliseconds seconds since midnight 1 Jan 1970 UTC, leap
seconds excluded. As well as improved entropy characteristics over
versions 1 or 6.
Implementations SHOULD utilize UUID version 7 over UUID version 1 and
6 if possible."""
global _last_v7_timestamp
nanoseconds = time.time_ns()
if _last_v7_timestamp is not None and nanoseconds <= _last_v7_timestamp:
nanoseconds = _last_v7_timestamp + 1
_last_v7_timestamp = nanoseconds
timestamp_ms, timestamp_ns = divmod(nanoseconds, 10**6)
subsec = _subsec_encode(timestamp_ns)
subsec_a = subsec >> 8
subsec_b = subsec & 0xFF
rand = secrets.randbits(54)
uuid_int = (timestamp_ms & 0xFFFFFFFFFFFF) << 80
uuid_int |= subsec_a << 64
uuid_int |= subsec_b << 54
uuid_int |= rand
return UUID(int=uuid_int, version=7)
================================================
FILE: core/autoflow/utils/vector.py
================================================
def cosine_distance(v1, v2):
import numpy as np
return 1 - np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
================================================
FILE: core/examples/README.md
================================================
# Autoflow Examples
This folder contains examples of how to use Autoflow.
================================================
FILE: core/examples/quickstart/fixtures/analyze-slow-queries.md
================================================
---
title: Analyze Slow Queries
summary: Learn how to locate and analyze slow queries.
---
# Analyze Slow Queries
To address the issue of slow queries, you need to take the following two steps:
1. Among many queries, identify which type of queries are slow.
2. Analyze why this type of queries are slow.
You can easily perform step 1 using the [slow query log](/dashboard/dashboard-slow-query.md) and the [statement summary table](/statement-summary-tables.md) features. It is recommended to use [TiDB Dashboard](/dashboard/dashboard-intro.md), which integrates the two features and directly displays the slow queries in your browser.
This document focuses on how to perform step 2 - analyze why this type of queries are slow.
Generally, slow queries have the following major causes:
- Optimizer issues, such as wrong index selected, wrong join type or sequence selected.
- System issues. All issues not caused by the optimizer are system issues. For example, a busy TiKV instance processes requests slowly; outdated Region information causes slow queries.
In actual situations, optimizer issues might cause system issues. For example, for a certain type of queries, the optimizer uses a full table scan instead of the index. As a result, the SQL queries consume many resources, which causes the CPU usage of some TiKV instances to soar. This seems like a system issue, but in essence, it is an optimizer issue.
To identify system issues is relatively simple. To analyze optimizer issues, you need to determine whether the execution plan is reasonable or not. Therefore, it is recommended to analyze slow queries by following these procedures:
1. Identify the performance bottleneck of the query, that is, the time-consuming part of the query process.
2. Analyze the system issues: analyze the possible causes according to the query bottleneck and the monitoring/log information of that time.
3. Analyze the optimizer issues: analyze whether there is a better execution plan.
The procedures above are explained in the following sections.
## Identify the performance bottleneck of the query
First, you need to have a general understanding of the query process. The key stages of the query execution process in TiDB are illustrated in [TiDB performance map](/media/performance-map.png).
You can get the duration information using the following methods:
- [Slow log](/identify-slow-queries.md). It is recommended to view the slow log in [TiDB Dashboard](/dashboard/dashboard-overview.md).
- [`EXPLAIN ANALYZE` statement](/sql-statements/sql-statement-explain-analyze.md).
The methods above are different in the following aspects:
- The slow log records the duration of almost all stages of a SQL execution, from parsing to returning results, and is relatively comprehensive (you can query and analyze the slow log in TiDB Dashboard in an intuitive way).
- By executing `EXPLAIN ANALYZE`, you can learn the time consumption of each operator in an actual SQL execution. The results have more detailed statistics of the execution duration.
In summary, the slow log and `EXPLAIN ANALYZE` statements help you determine the SQL query is slow in which component (TiDB or TiKV) at which stage of the execution. Therefore, you can accurately identify the performance bottleneck of the query.
In addition, since v4.0.3, the `Plan` field in the slow log also includes the SQL execution information, which is the result of `EXPLAIN ANALYZE`. So you can find all information of SQL duration in the slow log.
## Analyze system issues
System issues can be divided into the following types according to different execution stages of a SQL statement:
1. TiKV is slow in data processing. For example, the TiKV coprocessor processes data slowly.
2. TiDB is slow in execution. For example, a `Join` operator processes data slowly.
3. Other key stages are slow. For example, getting the timestamp takes a long time.
For each slow query, first determine to which type the query belongs, and then analyze it in detail.
### TiKV is slow in data processing
If TiKV is slow in data processing, you can easily identify it in the result of `EXPLAIN ANALYZE`. In the following example, `StreamAgg_8` and `TableFullScan_15`, two `tikv-task`s (as indicated by `cop[tikv]` in the `task` column), take `170ms` to execute. After subtracting `170ms`, the execution time of TiDB operators account for a very small proportion of the total execution time. This indicates that the bottleneck is in TiKV.
```sql
+----------------------------+---------+---------+-----------+---------------+------------------------------------------------------------------------------+---------------------------------+-----------+------+
| id | estRows | actRows | task | access object | execution info | operator info | memory | disk |
+----------------------------+---------+---------+-----------+---------------+------------------------------------------------------------------------------+---------------------------------+-----------+------+
| StreamAgg_16 | 1.00 | 1 | root | | time:170.08572ms, loops:2 | funcs:count(Column#5)->Column#3 | 372 Bytes | N/A |
| └─TableReader_17 | 1.00 | 1 | root | | time:170.080369ms, loops:2, rpc num: 1, rpc time:17.023347ms, proc keys:28672 | data:StreamAgg_8 | 202 Bytes | N/A |
| └─StreamAgg_8 | 1.00 | 1 | cop[tikv] | | time:170ms, loops:29 | funcs:count(1)->Column#5 | N/A | N/A |
| └─TableFullScan_15 | 7.00 | 28672 | cop[tikv] | table:t | time:170ms, loops:29 | keep order:false, stats:pseudo | N/A | N/A |
+----------------------------+---------+---------+-----------+---------------+------------------------------------------------------------------------------+---------------------------------+-----------+------
```
In addition, the `Cop_process` and `Cop_wait` fields in the slow log can also help your analysis. In the following example, the total duration of the query is around `180.85ms`, and the largest `coptask` takes `171ms`. This indicates that the bottleneck of this query is on the TiKV side.
For the description of each field in the slow log, see [fields description](/identify-slow-queries.md#fields-description).
```log
# Query_time: 0.18085
...
# Num_cop_tasks: 1
# Cop_process: Avg_time: 170ms P90_time: 170ms Max_time: 170ms Max_addr: 10.6.131.78
# Cop_wait: Avg_time: 1ms P90_time: 1ms Max_time: 1ms Max_Addr: 10.6.131.78
```
After identifying that TiKV is the bottleneck, you can find out the cause as described in the following sections.
#### TiKV instance is busy
During the execution of a SQL statement, TiDB might fetch data from multiple TiKV instances. If one TiKV instance responds slowly, the overall SQL execution speed is slowed down.
The `Cop_wait` field in the slow log can help you determine this cause.
```log
# Cop_wait: Avg_time: 1ms P90_time: 2ms Max_time: 110ms Max_Addr: 10.6.131.78
```
The log above shows that a `cop-task` sent to the `10.6.131.78` instance waits `110ms` before being executed. It indicates that this instance is busy. You can check the CPU monitoring of that time to confirm the cause.
#### Too many outdated keys
A TiKV instance has much outdated data, which needs to be cleaned up for data scan. This impacts the processing speed.
Check `Total_keys` and `Processed_keys`. If they are greatly different, the TiKV instance has too many keys of the older versions.
```
...
# Total_keys: 2215187529 Processed_keys: 1108056368
...
```
### Other key stages are slow
#### Slow in getting timestamps
You can compare `Wait_TS` and `Query_time` in the slow log. The timestamps are prefetched, so generally `Wait_TS` should be low.
```
# Query_time: 0.0300000
...
# Wait_TS: 0.02500000
```
#### Outdated Region information
Region information on the TiDB side might be outdated. In this situation, TiKV might return the `regionMiss` error. Then TiDB gets the Region information from PD again, which is reflected in the `Cop_backoff` information. Both the failed times and the total duration are recorded.
```
# Cop_backoff_regionMiss_total_times: 200 Cop_backoff_regionMiss_total_time: 0.2 Cop_backoff_regionMiss_max_time: 0.2 Cop_backoff_regionMiss_max_addr: 127.0.0.1 Cop_backoff_regionMiss_avg_time: 0.2 Cop_backoff_regionMiss_p90_time: 0.2
# Cop_backoff_rpcPD_total_times: 200 Cop_backoff_rpcPD_total_time: 0.2 Cop_backoff_rpcPD_max_time: 0.2 Cop_backoff_rpcPD_max_addr: 127.0.0.1 Cop_backoff_rpcPD_avg_time: 0.2 Cop_backoff_rpcPD_p90_time: 0.2
```
#### Subqueries are executed in advance
For statements with non-correlated subqueries, the subquery part might be executed in advance. For example, in `select * from t1 where a = (select max(a) from t2)`, the `select max(a) from t2` part might be executed in advance in the optimization stage. The result of `EXPLAIN ANALYZE` does not show the duration of this type of subqueries.
```sql
mysql> explain analyze select count(*) from t where a=(select max(t1.a) from t t1, t t2 where t1.a=t2.a);
+------------------------------+----------+---------+-----------+---------------+--------------------------+----------------------------------+-----------+------+
| id | estRows | actRows | task | access object | execution info | operator info | memory | disk |
+------------------------------+----------+---------+-----------+---------------+--------------------------+----------------------------------+-----------+------+
| StreamAgg_59 | 1.00 | 1 | root | | time:4.69267ms, loops:2 | funcs:count(Column#10)->Column#8 | 372 Bytes | N/A |
| └─TableReader_60 | 1.00 | 1 | root | | time:4.690428ms, loops:2 | data:StreamAgg_48 | 141 Bytes | N/A |
| └─StreamAgg_48 | 1.00 | | cop[tikv] | | time:0ns, loops:0 | funcs:count(1)->Column#10 | N/A | N/A |
| └─Selection_58 | 16384.00 | | cop[tikv] | | time:0ns, loops:0 | eq(test.t.a, 1) | N/A | N/A |
| └─TableFullScan_57 | 16384.00 | -1 | cop[tikv] | table:t | time:0s, loops:0 | keep order:false | N/A | N/A |
+------------------------------+----------+---------+-----------+---------------+--------------------------+----------------------------------+-----------+------+
5 rows in set (7.77 sec)
```
But you can identify this type of subquery execution in the slow log:
```
# Query_time: 7.770634843
...
# Rewrite_time: 7.765673663 Preproc_subqueries: 1 Preproc_subqueries_time: 7.765231874
```
From log record above, you can see that a subquery is executed in advance and takes `7.76s`.
### TiDB is slow in execution
Assume that the execution plan in TiDB is correct but the execution is slow. To solve this type of issue, you can adjust parameters or use the hint according to the result of `EXPLAIN ANALYZE` for the SQL statement.
If the execution plan is incorrect, see the [Analyze optimizer issues](#analyze-optimizer-issues) section.
#### Low concurrency
If the bottleneck is in the operator with concurrency, speed up the execution by adjusting the concurrency. For example:
```sql
mysql> explain analyze select sum(t1.a) from t t1, t t2 where t1.a=t2.a;
+----------------------------------+--------------+-----------+-----------+---------------+-------------------------------------------------------------------------------------+------------------------------------------------+------------------+---------+
| id | estRows | actRows | task | access object | execution info | operator info | memory | disk |
+----------------------------------+--------------+-----------+-----------+---------------+-------------------------------------------------------------------------------------+------------------------------------------------+------------------+---------+
| HashAgg_11 | 1.00 | 1 | root | | time:9.666832189s, loops:2, PartialConcurrency:4, FinalConcurrency:4 | funcs:sum(Column#6)->Column#5 | 322.125 KB | N/A |
| └─Projection_24 | 268435456.00 | 268435456 | root | | time:9.098644711s, loops:262145, Concurrency:4 | cast(test.t.a, decimal(65,0) BINARY)->Column#6 | 199 KB | N/A |
| └─HashJoin_14 | 268435456.00 | 268435456 | root | | time:6.616773501s, loops:262145, Concurrency:5, probe collision:0, build:881.404µs | inner join, equal:[eq(test.t.a, test.t.a)] | 131.75 KB | 0 Bytes |
| ├─TableReader_21(Build) | 16384.00 | 16384 | root | | time:6.553717ms, loops:17 | data:Selection_20 | 33.6318359375 KB | N/A |
| │ └─Selection_20 | 16384.00 | | cop[tikv] | | time:0ns, loops:0 | not(isnull(test.t.a)) | N/A | N/A |
| │ └─TableFullScan_19 | 16384.00 | -1 | cop[tikv] | table:t2 | time:0s, loops:0 | keep order:false | N/A | N/A |
| └─TableReader_18(Probe) | 16384.00 | 16384 | root | | time:6.880923ms, loops:17 | data:Selection_17 | 33.6318359375 KB | N/A |
| └─Selection_17 | 16384.00 | | cop[tikv] | | time:0ns, loops:0 | not(isnull(test.t.a)) | N/A | N/A |
| └─TableFullScan_16 | 16384.00 | -1 | cop[tikv] | table:t1 | time:0s, loops:0 | keep order:false | N/A | N/A |
+----------------------------------+--------------+-----------+-----------+---------------+-------------------------------------------------------------------------------------+------------------------------------------------+------------------+---------+
9 rows in set (9.67 sec)
```
As shown above, `HashJoin_14` and `Projection_24` consume much of the execution time. Consider increasing their concurrency using SQL variables to speed up execution.
All system variables are documented in [system-variables](/system-variables.md). To increase the concurrency of `HashJoin_14`, you can modify the `tidb_hash_join_concurrency` system variable.
#### Data is spilled to disk
Another cause of slow execution is disk spill that occurs during execution if the memory limit is reached. You can find out this cause in the execution plan and the slow log:
```sql
+-------------------------+-----------+---------+-----------+---------------+------------------------------+----------------------+-----------------------+----------------+
| id | estRows | actRows | task | access object | execution info | operator info | memory | disk |
+-------------------------+-----------+---------+-----------+---------------+------------------------------+----------------------+-----------------------+----------------+
| Sort_4 | 462144.00 | 462144 | root | | time:2.02848898s, loops:453 | test.t.a | 149.68795776367188 MB | 219.3203125 MB |
| └─TableReader_8 | 462144.00 | 462144 | root | | time:616.211272ms, loops:453 | data:TableFullScan_7 | 197.49601364135742 MB | N/A |
| └─TableFullScan_7 | 462144.00 | -1 | cop[tikv] | table:t | time:0s, loops:0 | keep order:false | N/A | N/A |
+-------------------------+-----------+---------+-----------+---------------+------------------------------+----------------------+-----------------------+----------------+
```
```
...
# Disk_max: 229974016
...
```
#### Join operations with Cartesian product
Join operations with Cartesian product generate data volume as large as `left child row count * right child row count`. This is inefficient and should be avoided.
This type of join operations is marked `CARTESIAN` in the execution plan. For example:
```sql
mysql> explain select * from t t1, t t2 where t1.a>t2.a;
+------------------------------+-------------+-----------+---------------+---------------------------------------------------------+
| id | estRows | task | access object | operator info |
+------------------------------+-------------+-----------+---------------+---------------------------------------------------------+
| HashJoin_8 | 99800100.00 | root | | CARTESIAN inner join, other cond:gt(test.t.a, test.t.a) |
| ├─TableReader_15(Build) | 9990.00 | root | | data:Selection_14 |
| │ └─Selection_14 | 9990.00 | cop[tikv] | | not(isnull(test.t.a)) |
| │ └─TableFullScan_13 | 10000.00 | cop[tikv] | table:t2 | keep order:false, stats:pseudo |
| └─TableReader_12(Probe) | 9990.00 | root | | data:Selection_11 |
| └─Selection_11 | 9990.00 | cop[tikv] | | not(isnull(test.t.a)) |
| └─TableFullScan_10 | 10000.00 | cop[tikv] | table:t1 | keep order:false, stats:pseudo |
+------------------------------+-------------+-----------+---------------+---------------------------------------------------------+
```
## Analyze optimizer issues
To analyze optimizer issues, you need to determine whether the execution plan is reasonable or not. You need to have some understanding of the optimization process and each operator.
For the following examples, assume that the table schema is `create table t (id int, a int, b int, c int, primary key(id), key(a), key(b, c))`.
1. `select * from t`: There is no filter condition and a full table scan is performed. So the `TableFullScan` operator is used to read data.
2. `select a from t where a=2`: There is a filter condition and only the index columns are read, so the `IndexReader` operator is used to read data.
3. `select * from t where a=2`: There is a filter condition for `a` but the `a` index cannot fully cover the data to be read, so the `IndexLookup` operator is used.
4. `select b from t where c=3`: Without the prefix condition, the multi-column index cannot be used. So the `IndexFullScan` is used.
5. ...
The examples above are operators used for data reads. For more operators, see [Understand TiDB Execution Plan](/explain-overview.md).
In addition, reading [SQL Tuning Overview](/sql-tuning-overview.md) helps you better understand the TiDB optimizer and determine whether the execution plan is reasonable or not.
Most optimizer issues are explained in [SQL Tuning Overview](/sql-tuning-overview.md). For the solutions, see the following documents:
1. [Wrong Index Solution](/wrong-index-solution.md)
2. [Wrong join order](/join-reorder.md)
3. [Expressions are not pushed down](/blocklist-control-plan.md)
================================================
FILE: core/examples/quickstart/fixtures/tidb-overview.md
================================================
---
title: What is TiDB Self-Managed
summary: Learn about the key features and usage scenarios of TiDB.
aliases: ['/docs/dev/key-features/','/tidb/dev/key-features','/docs/dev/overview/']
---
# What is TiDB Self-Managed
[TiDB](https://github.com/pingcap/tidb) (/'taɪdiːbi:/, "Ti" stands for Titanium) is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL compatible and features horizontal scalability, strong consistency, and high availability. The goal of TiDB is to provide users with a one-stop database solution that covers OLTP (Online Transactional Processing), OLAP (Online Analytical Processing), and HTAP services. TiDB is suitable for various use cases that require high availability and strong consistency with large-scale data.
TiDB Self-Managed is a product option of TiDB, where users or organizations can deploy and manage TiDB on their own infrastructure with complete flexibility. With TiDB Self-Managed, you can enjoy the power of open source, distributed SQL while retaining full control over your environment.
The following video introduces key features of TiDB.
VIDEO
## Key features
- **Easy horizontal scaling**
The TiDB architecture design separates computing from storage, letting you scale out or scale in the computing or storage capacity online as needed. The scaling process is transparent to application operations and maintenance staff.
- **Financial-grade high availability**
Data is stored in multiple replicas, and the Multi-Raft protocol is used to obtain the transaction log. A transaction can only be committed when data has been successfully written into the majority of replicas. This guarantees strong consistency and availability when a minority of replicas go down. You can configure the geographic location and number of replicas as needed to meet different disaster tolerance levels.
- **Real-time HTAP**
TiDB provides two storage engines: [TiKV](/tikv-overview.md), a row-based storage engine, and [TiFlash](/tiflash/tiflash-overview.md), a columnar storage engine.
TiFlash uses the Multi-Raft Learner protocol to replicate data from TiKV in real time, ensuring consistent data between the TiKV row-based storage engine and the TiFlash columnar storage engine. TiKV and TiFlash can be deployed on different machines as needed to solve the problem of HTAP resource isolation.
- **Cloud-native distributed database**
TiDB is a distributed database designed for the cloud, providing flexible scalability, reliability, and security on the cloud platform. Users can elastically scale TiDB to meet the requirements of their changing workloads. In TiDB, each piece of data has at least 3 replicas, which can be scheduled in different cloud availability zones to tolerate the outage of a whole data center. [TiDB Operator](https://docs.pingcap.com/tidb-in-kubernetes/stable/tidb-operator-overview) helps manage TiDB on Kubernetes and automates tasks related to operating the TiDB cluster, making TiDB easier to deploy on any cloud that provides managed Kubernetes. [TiDB Cloud](https://pingcap.com/tidb-cloud/), the fully-managed TiDB service, is the easiest, most economical, and most resilient way to unlock the full power of [TiDB in the cloud](https://docs.pingcap.com/tidbcloud/), allowing you to deploy and run TiDB clusters with just a few clicks.
- **Compatible with the MySQL protocol and MySQL ecosystem**
TiDB is compatible with the MySQL protocol, common features of MySQL, and the MySQL ecosystem. To migrate applications to TiDB, you do not need to change a single line of code in many cases, or only need to modify a small amount of code. In addition, TiDB provides a series of [data migration tools](/ecosystem-tool-user-guide.md) to help easily migrate application data into TiDB.
## See also
- [TiDB Architecture](/tidb-architecture.md)
- [TiDB Storage](/tidb-storage.md)
- [TiDB Computing](/tidb-computing.md)
- [TiDB Scheduling](/tidb-scheduling.md)
================================================
FILE: core/examples/quickstart/quickstart.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "b7e2255151321350",
"metadata": {},
"source": [
"# Autoflow\n",
"\n",
"Autoflow is a RAG framework supported:\n",
"\n",
"- Vector Search Based RAG\n",
"- Knowledge Graph Based RAG (aka. GraphRAG)\n",
"- Knowledge Base and Document Management"
]
},
{
"cell_type": "markdown",
"id": "f4c3f49f",
"metadata": {},
"source": [
"## Installation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7bbed79850462cfe",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-15T01:31:23.724019Z",
"start_time": "2025-04-15T01:31:22.872381Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install -q autoflow-ai==0.0.2.dev5 ipywidgets"
]
},
{
"cell_type": "markdown",
"id": "0b6d5be6",
"metadata": {},
"source": [
"## Prerequisites\n",
"\n",
"- Go [tidbcloud.com](https://tidbcloud.com/) or using [tiup playground](https://docs.pingcap.com/tidb/stable/tiup-playground/) to create a free TiDB database cluster.\n",
"- Go [OpenAI platform](https://platform.openai.com/api-keys) to create your API key."
]
},
{
"cell_type": "markdown",
"id": "66ea056f213efcae",
"metadata": {},
"source": [
"#### For Jupyter Notebook\n",
"\n",
"Configuration can be provided through environment variables, or using `.env`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7de9ab2c65f1880e",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-15T01:31:23.740325Z",
"start_time": "2025-04-15T01:31:23.729076Z"
}
},
"outputs": [],
"source": [
"%%bash\n",
"\n",
"# Check if the .env file is existing.\n",
"if [ -f .env ]; then\n",
" exit 0\n",
"fi\n",
"\n",
"# Create .env file with your configuration.\n",
"cat > .env <"
]
},
"execution_count": 6,
"metadata": {
"application/json": {
"expanded": false,
"root": "root"
}
},
"output_type": "execute_result"
}
],
"source": [
"from autoflow.configs.knowledge_base import IndexMethod\n",
"from autoflow.models.llms import LLM\n",
"from autoflow.models.embedding_models import EmbeddingModel\n",
"from IPython.display import JSON\n",
"\n",
"llm = LLM(\"gpt-4o-mini\")\n",
"embed_model = EmbeddingModel(\"text-embedding-3-small\")\n",
"\n",
"kb = af.create_knowledge_base(\n",
" namespace=\"quickstart\",\n",
" name=\"New KB\",\n",
" description=\"This is a knowledge base for testing\",\n",
" index_methods=[IndexMethod.VECTOR_SEARCH, IndexMethod.KNOWLEDGE_GRAPH],\n",
" llm=llm,\n",
" embedding_model=embed_model,\n",
")\n",
"JSON(kb.model_dump())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7c217f7f8cf956d8",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-15T01:31:26.736798Z",
"start_time": "2025-04-15T01:31:26.516452Z"
}
},
"outputs": [],
"source": [
"# Reset all the data of knowledge base.\n",
"kb.reset()"
]
},
{
"cell_type": "markdown",
"id": "d4ac8a82485d4232",
"metadata": {},
"source": [
"### Custom Chunker"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cddfe61c16ee934e",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-15T01:31:26.744529Z",
"start_time": "2025-04-15T01:31:26.740821Z"
}
},
"outputs": [],
"source": [
"from autoflow.chunkers.text import TextChunker\n",
"from autoflow.configs.chunkers.text import TextChunkerConfig\n",
"\n",
"text_chunker = TextChunker(config=TextChunkerConfig(chunk_size=256, chunk_overlap=20))"
]
},
{
"cell_type": "markdown",
"id": "4cfc2d80",
"metadata": {},
"source": [
"### Import documents from files"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f729326f",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-15T01:31:57.520138Z",
"start_time": "2025-04-15T01:31:26.749953Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" id \n",
" text \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 0196384b-e01b-7e87-8ae3-ceaebc5ea4f0 \n",
" ---\\ntitle: What is TiDB Self-Managed\\nsummary: Learn about the key features and usage scenarios of TiDB.\\naliases: ['/docs/dev/key-features/','/tidb/dev/key-features','/docs/dev/overview/']\\n---\\n\\n# What is TiDB Self-Managed\\n\\n<!-- Localization note for TiDB:\\n\\n- English: use distributed SQL, and start to emphasize HTAP\\n- Chinese: can keep \"NewSQL\" and emphasize one-stop real-time HTAP (\"一栈式实时 HTAP\")\\n- Japanese: use NewSQL because it is well-recognized\\n\\n-->\\n\\n[TiDB](https://github.com/pingcap/tidb) (/'taɪdiːbi:/, \"Ti\" stands for Titanium) is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL compatible and features horizontal scalability, strong consistency, and high availability. The goal of TiDB is to provide users with a one-stop database solution that covers OLTP (Online Transactional Processing), OLAP (Online Analytical Processing), and HTAP services. TiDB is suitable for various use cases that require high availability and strong consistency with large-scale data. \n",
" \n",
" \n",
" 1 \n",
" 0196384b-e01b-7ebc-9a82-f51dac13ba5c \n",
" TiDB is suitable for various use cases that require high availability and strong consistency with large-scale data.\\n\\nTiDB Self-Managed is a product option of TiDB, where users or organizations can deploy and manage TiDB on their own infrastructure with complete flexibility. With TiDB Self-Managed, you can enjoy the power of open source, distributed SQL while retaining full control over your environment.\\n\\nThe following video introduces key features of TiDB.\\n\\n<iframe width=\"600\" height=\"450\" src=\"https://www.youtube.com/embed/aWBNNPm21zg?enablejsapi=1\" title=\"Why TiDB?\" frameborder=\"0\" allow=\"accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture\" allowfullscreen></iframe>\\n\\n## Key features\\n\\n- **Easy horizontal scaling**\\n\\n The TiDB architecture design separates computing from storage, letting you scale out or scale in the computing or storage capacity online as needed. The scaling process is transparent to application operations and maintenance staff.\\n\\n- **Financial-grade high availability**\\n\\n Data is stored in multiple replicas, and the Multi-Raft protocol is used to obtain the transaction log. \n",
" \n",
" \n",
" 2 \n",
" 0196384b-e01b-7ed4-bf69-f65c861aedf7 \n",
" A transaction can only be committed when data has been successfully written into the majority of replicas. This guarantees strong consistency and availability when a minority of replicas go down. You can configure the geographic location and number of replicas as needed to meet different disaster tolerance levels.\\n\\n- **Real-time HTAP**\\n\\n TiDB provides two storage engines: [TiKV](/tikv-overview.md), a row-based storage engine, and [TiFlash](/tiflash/tiflash-overview.md), a columnar storage engine. \\n\\n TiFlash uses the Multi-Raft Learner protocol to replicate data from TiKV in real time, ensuring consistent data between the TiKV row-based storage engine and the TiFlash columnar storage engine. TiKV and TiFlash can be deployed on different machines as needed to solve the problem of HTAP resource isolation.\\n\\n- **Cloud-native distributed database**\\n\\n TiDB is a distributed database designed for the cloud, providing flexible scalability, reliability, and security on the cloud platform. Users can elastically scale TiDB to meet the requirements of their changing workloads. \n",
" \n",
" \n",
" 3 \n",
" 0196384b-e01b-7ee1-91f0-d57434e5d74d \n",
" Users can elastically scale TiDB to meet the requirements of their changing workloads. In TiDB, each piece of data has at least 3 replicas, which can be scheduled in different cloud availability zones to tolerate the outage of a whole data center. [TiDB Operator](https://docs.pingcap.com/tidb-in-kubernetes/stable/tidb-operator-overview) helps manage TiDB on Kubernetes and automates tasks related to operating the TiDB cluster, making TiDB easier to deploy on any cloud that provides managed Kubernetes. [TiDB Cloud](https://pingcap.com/tidb-cloud/), the fully-managed TiDB service, is the easiest, most economical, and most resilient way to unlock the full power of [TiDB in the cloud](https://docs.pingcap.com/tidbcloud/), allowing you to deploy and run TiDB clusters with just a few clicks.\\n\\n- **Compatible with the MySQL protocol and MySQL ecosystem**\\n\\n TiDB is compatible with the MySQL protocol, common features of MySQL, and the MySQL ecosystem. To migrate applications to TiDB, you do not need to change a single line of code in many cases, or only need to modify a small amount of code. \n",
" \n",
" \n",
" 4 \n",
" 0196384b-e01b-7eed-a470-c9bcb5a7eccc \n",
" In addition, TiDB provides a series of [data migration tools](/ecosystem-tool-user-guide.md) to help easily migrate application data into TiDB.\\n\\n## See also\\n\\n- [TiDB Architecture](/tidb-architecture.md)\\n- [TiDB Storage](/tidb-storage.md)\\n- [TiDB Computing](/tidb-computing.md)\\n- [TiDB Scheduling](/tidb-scheduling.md) \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" id \\\n",
"0 0196384b-e01b-7e87-8ae3-ceaebc5ea4f0 \n",
"1 0196384b-e01b-7ebc-9a82-f51dac13ba5c \n",
"2 0196384b-e01b-7ed4-bf69-f65c861aedf7 \n",
"3 0196384b-e01b-7ee1-91f0-d57434e5d74d \n",
"4 0196384b-e01b-7eed-a470-c9bcb5a7eccc \n",
"\n",
" text \n",
"0 ---\\ntitle: What is TiDB Self-Managed\\nsummary: Learn about the key features and usage scenarios of TiDB.\\naliases: ['/docs/dev/key-features/','/tidb/dev/key-features','/docs/dev/overview/']\\n---\\n\\n# What is TiDB Self-Managed\\n\\n\\n\\n[TiDB](https://github.com/pingcap/tidb) (/'taɪdiːbi:/, \"Ti\" stands for Titanium) is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL compatible and features horizontal scalability, strong consistency, and high availability. The goal of TiDB is to provide users with a one-stop database solution that covers OLTP (Online Transactional Processing), OLAP (Online Analytical Processing), and HTAP services. TiDB is suitable for various use cases that require high availability and strong consistency with large-scale data. \n",
"1 TiDB is suitable for various use cases that require high availability and strong consistency with large-scale data.\\n\\nTiDB Self-Managed is a product option of TiDB, where users or organizations can deploy and manage TiDB on their own infrastructure with complete flexibility. With TiDB Self-Managed, you can enjoy the power of open source, distributed SQL while retaining full control over your environment.\\n\\nThe following video introduces key features of TiDB.\\n\\nVIDEO \\n\\n## Key features\\n\\n- **Easy horizontal scaling**\\n\\n The TiDB architecture design separates computing from storage, letting you scale out or scale in the computing or storage capacity online as needed. The scaling process is transparent to application operations and maintenance staff.\\n\\n- **Financial-grade high availability**\\n\\n Data is stored in multiple replicas, and the Multi-Raft protocol is used to obtain the transaction log. \n",
"2 A transaction can only be committed when data has been successfully written into the majority of replicas. This guarantees strong consistency and availability when a minority of replicas go down. You can configure the geographic location and number of replicas as needed to meet different disaster tolerance levels.\\n\\n- **Real-time HTAP**\\n\\n TiDB provides two storage engines: [TiKV](/tikv-overview.md), a row-based storage engine, and [TiFlash](/tiflash/tiflash-overview.md), a columnar storage engine. \\n\\n TiFlash uses the Multi-Raft Learner protocol to replicate data from TiKV in real time, ensuring consistent data between the TiKV row-based storage engine and the TiFlash columnar storage engine. TiKV and TiFlash can be deployed on different machines as needed to solve the problem of HTAP resource isolation.\\n\\n- **Cloud-native distributed database**\\n\\n TiDB is a distributed database designed for the cloud, providing flexible scalability, reliability, and security on the cloud platform. Users can elastically scale TiDB to meet the requirements of their changing workloads. \n",
"3 Users can elastically scale TiDB to meet the requirements of their changing workloads. In TiDB, each piece of data has at least 3 replicas, which can be scheduled in different cloud availability zones to tolerate the outage of a whole data center. [TiDB Operator](https://docs.pingcap.com/tidb-in-kubernetes/stable/tidb-operator-overview) helps manage TiDB on Kubernetes and automates tasks related to operating the TiDB cluster, making TiDB easier to deploy on any cloud that provides managed Kubernetes. [TiDB Cloud](https://pingcap.com/tidb-cloud/), the fully-managed TiDB service, is the easiest, most economical, and most resilient way to unlock the full power of [TiDB in the cloud](https://docs.pingcap.com/tidbcloud/), allowing you to deploy and run TiDB clusters with just a few clicks.\\n\\n- **Compatible with the MySQL protocol and MySQL ecosystem**\\n\\n TiDB is compatible with the MySQL protocol, common features of MySQL, and the MySQL ecosystem. To migrate applications to TiDB, you do not need to change a single line of code in many cases, or only need to modify a small amount of code. \n",
"4 In addition, TiDB provides a series of [data migration tools](/ecosystem-tool-user-guide.md) to help easily migrate application data into TiDB.\\n\\n## See also\\n\\n- [TiDB Architecture](/tidb-architecture.md)\\n- [TiDB Storage](/tidb-storage.md)\\n- [TiDB Computing](/tidb-computing.md)\\n- [TiDB Scheduling](/tidb-scheduling.md) "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"docs = kb.add(\"./fixtures/tidb-overview.md\", chunker=text_chunker)\n",
"\n",
"DataFrame(\n",
" [(c.id, c.text) for c in docs[0].chunks],\n",
" columns=[\"id\", \"text\"],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "84fd9b606e6a11a5",
"metadata": {},
"source": [
"### Search Documents"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "259ad7a9",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-15T01:31:57.567655Z",
"start_time": "2025-04-15T01:31:57.543046Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" text \n",
" score \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" ---\\ntitle: What is TiDB Self-Managed\\nsummary: Learn about the key features and usage scenarios of TiDB.\\naliases: ['/docs/dev/key-features/','/tidb/dev/key-features','/docs/dev/overview/']\\n---\\n\\n# What is TiDB Self-Managed\\n\\n<!-- Localization note for TiDB:\\n\\n- English: use distributed SQL, and start to emphasize HTAP\\n- Chinese: can keep \"NewSQL\" and emphasize one-stop real-time HTAP (\"一栈式实时 HTAP\")\\n- Japanese: use NewSQL because it is well-recognized\\n\\n-->\\n\\n[TiDB](https://github.com/pingcap/tidb) (/'taɪdiːbi:/, \"Ti\" stands for Titanium) is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL compatible and features horizontal scalability, strong consistency, and high availability. The goal of TiDB is to provide users with a one-stop database solution that covers OLTP (Online Transactional Processing), OLAP (Online Analytical Processing), and HTAP services. TiDB is suitable for various use cases that require high availability and strong consistency with large-scale data. \n",
" 0.726047 \n",
" \n",
" \n",
" 1 \n",
" TiDB is suitable for various use cases that require high availability and strong consistency with large-scale data.\\n\\nTiDB Self-Managed is a product option of TiDB, where users or organizations can deploy and manage TiDB on their own infrastructure with complete flexibility. With TiDB Self-Managed, you can enjoy the power of open source, distributed SQL while retaining full control over your environment.\\n\\nThe following video introduces key features of TiDB.\\n\\n<iframe width=\"600\" height=\"450\" src=\"https://www.youtube.com/embed/aWBNNPm21zg?enablejsapi=1\" title=\"Why TiDB?\" frameborder=\"0\" allow=\"accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture\" allowfullscreen></iframe>\\n\\n## Key features\\n\\n- **Easy horizontal scaling**\\n\\n The TiDB architecture design separates computing from storage, letting you scale out or scale in the computing or storage capacity online as needed. The scaling process is transparent to application operations and maintenance staff.\\n\\n- **Financial-grade high availability**\\n\\n Data is stored in multiple replicas, and the Multi-Raft protocol is used to obtain the transaction log. \n",
" 0.669803 \n",
" \n",
" \n",
" 2 \n",
" Users can elastically scale TiDB to meet the requirements of their changing workloads. In TiDB, each piece of data has at least 3 replicas, which can be scheduled in different cloud availability zones to tolerate the outage of a whole data center. [TiDB Operator](https://docs.pingcap.com/tidb-in-kubernetes/stable/tidb-operator-overview) helps manage TiDB on Kubernetes and automates tasks related to operating the TiDB cluster, making TiDB easier to deploy on any cloud that provides managed Kubernetes. [TiDB Cloud](https://pingcap.com/tidb-cloud/), the fully-managed TiDB service, is the easiest, most economical, and most resilient way to unlock the full power of [TiDB in the cloud](https://docs.pingcap.com/tidbcloud/), allowing you to deploy and run TiDB clusters with just a few clicks.\\n\\n- **Compatible with the MySQL protocol and MySQL ecosystem**\\n\\n TiDB is compatible with the MySQL protocol, common features of MySQL, and the MySQL ecosystem. To migrate applications to TiDB, you do not need to change a single line of code in many cases, or only need to modify a small amount of code. \n",
" 0.656657 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" text \\\n",
"0 ---\\ntitle: What is TiDB Self-Managed\\nsummary: Learn about the key features and usage scenarios of TiDB.\\naliases: ['/docs/dev/key-features/','/tidb/dev/key-features','/docs/dev/overview/']\\n---\\n\\n# What is TiDB Self-Managed\\n\\n\\n\\n[TiDB](https://github.com/pingcap/tidb) (/'taɪdiːbi:/, \"Ti\" stands for Titanium) is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL compatible and features horizontal scalability, strong consistency, and high availability. The goal of TiDB is to provide users with a one-stop database solution that covers OLTP (Online Transactional Processing), OLAP (Online Analytical Processing), and HTAP services. TiDB is suitable for various use cases that require high availability and strong consistency with large-scale data. \n",
"1 TiDB is suitable for various use cases that require high availability and strong consistency with large-scale data.\\n\\nTiDB Self-Managed is a product option of TiDB, where users or organizations can deploy and manage TiDB on their own infrastructure with complete flexibility. With TiDB Self-Managed, you can enjoy the power of open source, distributed SQL while retaining full control over your environment.\\n\\nThe following video introduces key features of TiDB.\\n\\nVIDEO \\n\\n## Key features\\n\\n- **Easy horizontal scaling**\\n\\n The TiDB architecture design separates computing from storage, letting you scale out or scale in the computing or storage capacity online as needed. The scaling process is transparent to application operations and maintenance staff.\\n\\n- **Financial-grade high availability**\\n\\n Data is stored in multiple replicas, and the Multi-Raft protocol is used to obtain the transaction log. \n",
"2 Users can elastically scale TiDB to meet the requirements of their changing workloads. In TiDB, each piece of data has at least 3 replicas, which can be scheduled in different cloud availability zones to tolerate the outage of a whole data center. [TiDB Operator](https://docs.pingcap.com/tidb-in-kubernetes/stable/tidb-operator-overview) helps manage TiDB on Kubernetes and automates tasks related to operating the TiDB cluster, making TiDB easier to deploy on any cloud that provides managed Kubernetes. [TiDB Cloud](https://pingcap.com/tidb-cloud/), the fully-managed TiDB service, is the easiest, most economical, and most resilient way to unlock the full power of [TiDB in the cloud](https://docs.pingcap.com/tidbcloud/), allowing you to deploy and run TiDB clusters with just a few clicks.\\n\\n- **Compatible with the MySQL protocol and MySQL ecosystem**\\n\\n TiDB is compatible with the MySQL protocol, common features of MySQL, and the MySQL ecosystem. To migrate applications to TiDB, you do not need to change a single line of code in many cases, or only need to modify a small amount of code. \n",
"\n",
" score \n",
"0 0.726047 \n",
"1 0.669803 \n",
"2 0.656657 "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = kb.search_documents(\n",
" query=\"What is TiDB?\",\n",
" top_k=3,\n",
")\n",
"\n",
"DataFrame(\n",
" [(c.text, c.score) for c in result.chunks],\n",
" columns=[\"text\", \"score\"],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f2a0de8057cdf16b",
"metadata": {},
"source": [
"### Search Knowledge Graph"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "6fc5bc93",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-15T01:31:57.746250Z",
"start_time": "2025-04-15T01:31:57.605589Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" source_entity \n",
" relation \n",
" target_entity \n",
" score \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" TiDB \n",
" TiDB Storage is an essential part of how TiDB manages data. \n",
" TiDB Storage \n",
" 6.546173 \n",
" \n",
" \n",
" 1 \n",
" TiDB \n",
" TiDB provides TiKV as a row-based storage engine for data storage. \n",
" TiKV \n",
" 6.256637 \n",
" \n",
" \n",
" 2 \n",
" TiDB \n",
" TiDB Computing describes the processing capabilities of the TiDB database. \n",
" TiDB Computing \n",
" 5.975210 \n",
" \n",
" \n",
" 3 \n",
" TiDB \n",
" TiDB has key features that include easy horizontal scaling and financial-grade high availability. \n",
" Key features of TiDB \n",
" 5.648048 \n",
" \n",
" \n",
" 4 \n",
" TiDB \n",
" TiDB provides strong consistency, ensuring that all transactions are immediately visible to users. \n",
" Strong Consistency \n",
" 5.378570 \n",
" \n",
" \n",
" 5 \n",
" TiDB \n",
" TiDB Architecture is a key component of the TiDB database system. \n",
" TiDB Architecture \n",
" 5.374958 \n",
" \n",
" \n",
" 6 \n",
" TiDB \n",
" TiDB is designed for high availability, ensuring operational continuity even during failures. \n",
" High Availability \n",
" 5.220304 \n",
" \n",
" \n",
" 7 \n",
" TiDB \n",
" TiDB is MySQL compatible, enabling users to utilize existing MySQL applications with minimal adjustments. \n",
" MySQL Compatibility \n",
" 5.137373 \n",
" \n",
" \n",
" 8 \n",
" TiDB \n",
" TiDB offers OLAP services, enabling fast and interactive access to data for analytical purposes. \n",
" OLAP (Online Analytical Processing) \n",
" 5.021178 \n",
" \n",
" \n",
" 9 \n",
" TiDB \n",
" TiDB provides a series of data migration tools to help easily migrate application data into TiDB. \n",
" data migration tools \n",
" 5.002972 \n",
" \n",
" \n",
" 10 \n",
" TiDB \n",
" TiDB provides TiFlash as a columnar storage engine that replicates data from TiKV. \n",
" TiFlash \n",
" 4.756693 \n",
" \n",
" \n",
" 11 \n",
" TiDB \n",
" TiDB features horizontal scalability, allowing it to expand capacity by adding more machines to the cluster. \n",
" Horizontal Scalability \n",
" 4.715631 \n",
" \n",
" \n",
" 12 \n",
" TiDB \n",
" TiDB provides OLTP services, facilitating transaction-oriented applications for data entry and retrieval. \n",
" OLTP (Online Transactional Processing) \n",
" 4.683033 \n",
" \n",
" \n",
" 13 \n",
" TiDB \n",
" TiDB supports Hybrid Transactional and Analytical Processing (HTAP) workloads, allowing for simultaneous handling of transactional and analytical tasks. \n",
" Hybrid Transactional and Analytical Processing (HTAP) \n",
" 4.431353 \n",
" \n",
" \n",
" 14 \n",
" TiDB \n",
" TiDB Self-Managed is a product option of TiDB that provides users with the ability to deploy and manage TiDB on their own infrastructure. \n",
" TiDB Self-Managed \n",
" 4.256526 \n",
" \n",
" \n",
" 15 \n",
" TiDB \n",
" TiDB Scheduling is involved in managing the execution of tasks within the TiDB database. \n",
" TiDB Scheduling \n",
" 4.245480 \n",
" \n",
" \n",
" 16 \n",
" TiDB \n",
" TiDB is designed as a cloud-native distributed database providing flexible scalability and reliability. \n",
" Cloud-native \n",
" 4.219989 \n",
" \n",
" \n",
" 17 \n",
" TiDB \n",
" TiDB uses the Multi-Raft protocol to ensure high availability by managing transaction logs across multiple replicas. \n",
" Multi-Raft protocol \n",
" 3.848345 \n",
" \n",
" \n",
" 18 \n",
" Key features of TiDB \n",
" Another key feature of TiDB is financial-grade high availability, which is achieved through data replication. \n",
" Financial-grade high availability \n",
" 3.475075 \n",
" \n",
" \n",
" 19 \n",
" Key features of TiDB \n",
" One of the key features of TiDB is easy horizontal scaling, which allows for flexible resource management. \n",
" Easy horizontal scaling \n",
" 3.398429 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" source_entity \\\n",
"0 TiDB \n",
"1 TiDB \n",
"2 TiDB \n",
"3 TiDB \n",
"4 TiDB \n",
"5 TiDB \n",
"6 TiDB \n",
"7 TiDB \n",
"8 TiDB \n",
"9 TiDB \n",
"10 TiDB \n",
"11 TiDB \n",
"12 TiDB \n",
"13 TiDB \n",
"14 TiDB \n",
"15 TiDB \n",
"16 TiDB \n",
"17 TiDB \n",
"18 Key features of TiDB \n",
"19 Key features of TiDB \n",
"\n",
" relation \\\n",
"0 TiDB Storage is an essential part of how TiDB manages data. \n",
"1 TiDB provides TiKV as a row-based storage engine for data storage. \n",
"2 TiDB Computing describes the processing capabilities of the TiDB database. \n",
"3 TiDB has key features that include easy horizontal scaling and financial-grade high availability. \n",
"4 TiDB provides strong consistency, ensuring that all transactions are immediately visible to users. \n",
"5 TiDB Architecture is a key component of the TiDB database system. \n",
"6 TiDB is designed for high availability, ensuring operational continuity even during failures. \n",
"7 TiDB is MySQL compatible, enabling users to utilize existing MySQL applications with minimal adjustments. \n",
"8 TiDB offers OLAP services, enabling fast and interactive access to data for analytical purposes. \n",
"9 TiDB provides a series of data migration tools to help easily migrate application data into TiDB. \n",
"10 TiDB provides TiFlash as a columnar storage engine that replicates data from TiKV. \n",
"11 TiDB features horizontal scalability, allowing it to expand capacity by adding more machines to the cluster. \n",
"12 TiDB provides OLTP services, facilitating transaction-oriented applications for data entry and retrieval. \n",
"13 TiDB supports Hybrid Transactional and Analytical Processing (HTAP) workloads, allowing for simultaneous handling of transactional and analytical tasks. \n",
"14 TiDB Self-Managed is a product option of TiDB that provides users with the ability to deploy and manage TiDB on their own infrastructure. \n",
"15 TiDB Scheduling is involved in managing the execution of tasks within the TiDB database. \n",
"16 TiDB is designed as a cloud-native distributed database providing flexible scalability and reliability. \n",
"17 TiDB uses the Multi-Raft protocol to ensure high availability by managing transaction logs across multiple replicas. \n",
"18 Another key feature of TiDB is financial-grade high availability, which is achieved through data replication. \n",
"19 One of the key features of TiDB is easy horizontal scaling, which allows for flexible resource management. \n",
"\n",
" target_entity score \n",
"0 TiDB Storage 6.546173 \n",
"1 TiKV 6.256637 \n",
"2 TiDB Computing 5.975210 \n",
"3 Key features of TiDB 5.648048 \n",
"4 Strong Consistency 5.378570 \n",
"5 TiDB Architecture 5.374958 \n",
"6 High Availability 5.220304 \n",
"7 MySQL Compatibility 5.137373 \n",
"8 OLAP (Online Analytical Processing) 5.021178 \n",
"9 data migration tools 5.002972 \n",
"10 TiFlash 4.756693 \n",
"11 Horizontal Scalability 4.715631 \n",
"12 OLTP (Online Transactional Processing) 4.683033 \n",
"13 Hybrid Transactional and Analytical Processing (HTAP) 4.431353 \n",
"14 TiDB Self-Managed 4.256526 \n",
"15 TiDB Scheduling 4.245480 \n",
"16 Cloud-native 4.219989 \n",
"17 Multi-Raft protocol 3.848345 \n",
"18 Financial-grade high availability 3.475075 \n",
"19 Easy horizontal scaling 3.398429 "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"kg = kb.search_knowledge_graph(\n",
" query=\"What is TiDB?\",\n",
")\n",
"\n",
"# Notice: score is the result of a weighted formula\n",
"\n",
"DataFrame(\n",
" [\n",
" (r.source_entity.name, r.description, r.target_entity.name, r.score)\n",
" for r in kg.relationships\n",
" ],\n",
" columns=[\"source_entity\", \"relation\", \"target_entity\", \"score\"],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c1f1920c",
"metadata": {},
"source": [
"### Ask question"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "54bab89a",
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"TiDB is an open-source distributed SQL database designed to support Hybrid Transactional and Analytical Processing (HTAP) workloads. It is compatible with MySQL, allowing users to leverage existing MySQL applications and tools with minimal changes. TiDB features several key attributes:\n",
"\n",
"1. **High Availability**: TiDB is designed to ensure operational continuity even during failures, providing financial-grade high availability by storing data in multiple replicas.\n",
"\n",
"2. **Strong Consistency**: It guarantees that all transactions are immediately visible to all users, ensuring a reliable and predictable database experience.\n",
"\n",
"3. **Horizontal Scalability**: TiDB allows for easy horizontal scaling by separating computing from storage, enabling users to scale out or scale in their computing or storage capacity online as needed.\n",
"\n",
"4. **Support for OLTP and OLAP**: TiDB provides a one-stop database solution that covers Online Transactional Processing (OLTP), Online Analytical Processing (OLAP), and HTAP services, making it suitable for various use cases that require high availability and strong consistency with large-scale data.\n",
"\n",
"5. **Cloud-native Architecture**: TiDB is designed for cloud environments, offering flexible scalability, reliability, and security on cloud platforms.\n",
"\n",
"6. **Data Migration Tools**: TiDB includes a series of data migration tools to facilitate the easy transfer of application data into the TiDB database.\n",
"\n",
"7. **Storage Engines**: TiDB utilizes two storage engines: TiKV, a row-based storage engine, and TiFlash, a columnar storage engine that replicates data from TiKV in real time.\n",
"\n",
"Overall, TiDB aims to provide users with a robust and flexible database solution that can adapt to changing workloads and requirements."
],
"text/plain": [
""
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import Markdown\n",
"\n",
"res = kb.ask(\"What is TiDB?\")\n",
"Markdown(res.message.content)"
]
},
{
"cell_type": "markdown",
"id": "3ed0149fb5a9e1cb",
"metadata": {},
"source": [
"### Reset the KnowledgeBase"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "4303dc61b3f073f1",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-15T01:31:57.880832Z",
"start_time": "2025-04-15T01:31:57.878931Z"
}
},
"outputs": [],
"source": [
"# kb.reset()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: core/examples/streamlit/README.md
================================================
# Streamlit Examples
This project demonstrates how to use AutoFlow as a Retrieval-Augmented Generation (RAG) framework and Streamlit as the web framework for building a knowledge search application.
## Prerequisites
- **Python 3.12.4** (Check the version specified in `.python-version`). You can use `pyenv` to manage your Python versions.
- **macOS users:** Ensure `mysqlclient` is installed.
## Installation and Usage
**Step 1: Install Dependencies**
Create a virtual environment and install the required packages:
```bash
python -m venv .venv
source .venv/bin/activate
pip install -r reqs.txt
```
**Step 2: Run the Streamlit App**
Start the application with:
```bash
streamlit run build-knowledge-search-with-autoflow-and-streamlit.py
```
**Step 3: Open in Browser**
Once the app is running, open http://localhost:8501 in your browser and follow these steps:
1. Enter your [OpenAI API key](https://platform.openai.com/api-keys) in left sidebar
2. Enter your TiDB Cloud connection string in the left sidebar. Use the SQLAlchemy format ( `mysql+pymysql://root@localhost:4000/test`) To obtain your connection string, go to the TiDB Cloud console, click **Connect**, and select `SQLAlchemy` from the **Connect With** dropdown.
3. Press **Enter** to start initialization (this may take a few minutes).
4. Once initialization completes, upload a file to build the knowledge base. Then enter your queries to retrieve answers from the newly generated knowledge base.
================================================
FILE: core/examples/streamlit/build-knowledge-search-with-autoflow-and-streamlit.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from uuid import UUID
import streamlit as st
from sqlalchemy import create_engine
from autoflow import Autoflow
from autoflow.types import IndexMethod
from autoflow.llms.chat_models import ChatModel
from autoflow.llms.embeddings import EmbeddingModel
from llama_index.core.llms import ChatMessage
st.set_page_config(
page_title="Demo of Autoflow and Streamlit", page_icon="📖", layout="wide"
)
st.header("📖 Knowledge base app built with Autoflow and Streamlit")
with st.sidebar:
st.markdown(
"## How to use\n"
"1. Enter your [OpenAI API key](https://platform.openai.com/account/api-keys) below 🔑\n" # noqa: E501
"2. Enter your [TiDB Cloud](https://tidbcloud.com) database connection URL below 🔗\n"
"3. Upload a pdf, docx, or txt file 📄\n"
"4. Ask a question about the document 💬\n"
)
openai_api_key_input = st.text_input(
"OpenAI API Key",
type="password",
placeholder="Paste your OpenAI API key here (sk-...)",
help="You can get your API key from https://platform.openai.com/account/api-keys.", # noqa: E501
value=os.environ.get("OPENAI_API_KEY", None)
or st.session_state.get("OPENAI_API_KEY", ""),
)
database_url_input = st.text_input(
"Database URL",
type="password",
placeholder="e.g. mysql+pymysql://root@localhost:4000/test",
autocomplete="off",
help="You can get your database URL from https://tidbcloud.com",
value=os.environ.get("DATABASE_URL", None)
or "mysql+pymysql://root@localhost:4000/test"
or st.session_state.get("DATABASE_URL", ""),
)
st.session_state["OPENAI_API_KEY"] = openai_api_key_input
st.session_state["DATABASE_URL"] = database_url_input
openai_api_key = st.session_state.get("OPENAI_API_KEY")
database_url = st.session_state.get("DATABASE_URL")
if not openai_api_key or not database_url:
st.error("Please enter your OpenAI API key and TiDB Cloud connection string.")
st.stop()
af = Autoflow(create_engine(database_url))
chat_model = ChatModel("gpt-4o-mini", api_key=openai_api_key)
embedding_model = EmbeddingModel(
model_name="text-embedding-3-small",
dimensions=1536,
api_key=openai_api_key,
)
kb = af.create_knowledge_base(
id=UUID(
"655b6cf3-8b30-4839-ba8b-5ed3c502f30e"
), # For not creating a new KB every time
name="New KB",
description="This is a knowledge base for testing",
index_methods=[IndexMethod.VECTOR_SEARCH, IndexMethod.KNOWLEDGE_GRAPH],
chat_model=chat_model,
embedding_model=embedding_model,
)
with st.form(key="file_upload_form"):
uploaded_file = st.file_uploader(
"Upload a .pdf, .docx, .md or .txt file",
type=["pdf", "docx", "txt", "md"],
help="Scanned documents are not supported yet!",
)
upload = st.form_submit_button("Upload")
if upload:
if not uploaded_file:
st.error("Please upload a valid file.")
st.stop()
file_path = f"/tmp/{uploaded_file.name}"
with st.spinner(
"Indexing document... This may take a while ⏳(import time; time.sleep(3))"
):
with open(file_path, "wb") as f:
f.write(uploaded_file.getvalue())
kb.add(file_path)
import time
time.sleep(3)
for line in ["generated", "past", "corpus"]:
if line not in st.session_state:
st.session_state[line] = []
for o in ["kg"]:
if o not in st.session_state:
st.session_state[o] = None
def on_submit():
user_input = st.session_state.user_input
if user_input:
result = kb.search_documents(query=user_input, similarity_top_k=3)
st.session_state["corpus"] = result.chunks
kg = kb.search_knowledge_graph(query=user_input)
st.session_state["kg"] = kg
messages = [
ChatMessage(
role="system",
content="Here are some relevant documents about your query:\n\n"
+ "\n".join(c.chunk.text for c in result.chunks),
),
ChatMessage(
role="user",
content=user_input + "\n(in markdown, removed unused breaklines)",
),
]
resp = chat_model.chat(messages)
st.session_state.past.append(user_input)
st.session_state.generated.append(str(resp.message))
chat_section, corpus_section = st.columns(2)
with chat_section:
st.markdown("##### Chats")
chat_placeholder = st.empty()
with chat_placeholder.container():
for i in range(len(st.session_state["generated"])):
with st.chat_message("user"):
st.write(st.session_state["past"][i])
with st.chat_message("assistant"):
st.write(st.session_state["generated"][i])
with st.container():
st.chat_input(
"Input your question about this document here.",
key="user_input",
on_submit=on_submit,
)
with corpus_section:
st.markdown("##### Vector Search Results")
corpus_placeholder = st.empty()
with corpus_placeholder.container():
[c.chunk for c in st.session_state["corpus"]] if st.session_state[
"corpus"
] else "Please input a query left."
st.markdown("##### Knowledge Graph Search Results")
kg_placeholder = st.empty()
with kg_placeholder.container():
kg = st.session_state["kg"]
[
r.rag_description for r in kg.relationships
] if kg else "Please input a query left."
================================================
FILE: core/examples/streamlit/reqs.txt
================================================
autoflow-ai==0.0.1.dev25
autoflow-ai[experiment]
pymysql
streamlit
httpx[socks]
================================================
FILE: core/experimental/README.md
================================================
# Autoflow Experimental
This folder contains experimental features that are not yet ready for production use.
================================================
FILE: core/experimental/kg_extraction/extract_graph.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"pd.set_option(\"max_colwidth\", None)\n",
"pd.set_option(\"max_seq_items\", None)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"\n",
"from autoflow.knowledge_graph.programs.extract_graph import (\n",
" KnowledgeGraphExtractor,\n",
" KnowledgeGraphOutput,\n",
")\n",
"from autoflow.models.llms import LLM\n",
"from autoflow.models.llms.dspy import get_dspy_lm_by_llm\n",
"\n",
"load_dotenv()\n",
"\n",
"llm = LLM()\n",
"dspy_lm = get_dspy_lm_by_llm(llm)\n",
"extractor = KnowledgeGraphExtractor(dspy_lm=dspy_lm)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"text = Path(\"../tests/fixtures/tidb-overview.md\").read_text()\n",
"actual_knowledge = extractor.forward(text)\n",
"df = actual_knowledge.to_pandas()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" name \n",
" description \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" TiDB \n",
" TiDB is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL compatible and features horizontal scalability, strong consistency, and high availability. TiDB provides users with a one-stop database solution covering OLTP, OLAP, and HTAP services. \n",
" \n",
" \n",
" 1 \n",
" TiDB Self-Managed \n",
" TiDB Self-Managed is a product option of TiDB that allows users or organizations to deploy and manage TiDB on their own infrastructure with complete flexibility. Users can enjoy the power of open-source distributed SQL while retaining full control over their environment. \n",
" \n",
" \n",
" 2 \n",
" Hybrid Transactional and Analytical Processing (HTAP) \n",
" HTAP is a database processing technique that enables both transactional and analytical workloads to be performed on the same database system without the need for separate systems or data duplication. \n",
" \n",
" \n",
" 3 \n",
" OLTP \n",
" OLTP (Online Transactional Processing) is a class of systems that facilitate and manage transaction-oriented applications, typically for data entry and retrieval transactions in a database. \n",
" \n",
" \n",
" 4 \n",
" OLAP \n",
" OLAP (Online Analytical Processing) is a category of software tools that provide analysis of data for business decisions. It allows users to analyze multidimensional data interactively from multiple perspectives. \n",
" \n",
" \n",
" 5 \n",
" TiKV \n",
" TiKV is a row-based storage engine provided by TiDB for storing data in a distributed manner. \n",
" \n",
" \n",
" 6 \n",
" TiFlash \n",
" TiFlash is a columnar storage engine provided by TiDB for real-time replication of data from TiKV, ensuring consistent data between row-based and columnar storage. \n",
" \n",
" \n",
" 7 \n",
" TiDB Operator \n",
" TiDB Operator is a tool that helps manage TiDB on Kubernetes and automates tasks related to operating the TiDB cluster, making TiDB easier to deploy on any cloud that provides managed Kubernetes. \n",
" \n",
" \n",
" 8 \n",
" TiDB Cloud \n",
" TiDB Cloud is a fully-managed TiDB service that provides the easiest, most economical, and most resilient way to deploy and run TiDB clusters in the cloud. \n",
" \n",
" \n",
" 9 \n",
" MySQL Protocol \n",
" The MySQL Protocol is the protocol used by MySQL database systems for communication between clients and servers. TiDB is compatible with the MySQL protocol, allowing seamless integration with the MySQL ecosystem. \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" name \\\n",
"0 TiDB \n",
"1 TiDB Self-Managed \n",
"2 Hybrid Transactional and Analytical Processing (HTAP) \n",
"3 OLTP \n",
"4 OLAP \n",
"5 TiKV \n",
"6 TiFlash \n",
"7 TiDB Operator \n",
"8 TiDB Cloud \n",
"9 MySQL Protocol \n",
"\n",
" description \n",
"0 TiDB is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL compatible and features horizontal scalability, strong consistency, and high availability. TiDB provides users with a one-stop database solution covering OLTP, OLAP, and HTAP services. \n",
"1 TiDB Self-Managed is a product option of TiDB that allows users or organizations to deploy and manage TiDB on their own infrastructure with complete flexibility. Users can enjoy the power of open-source distributed SQL while retaining full control over their environment. \n",
"2 HTAP is a database processing technique that enables both transactional and analytical workloads to be performed on the same database system without the need for separate systems or data duplication. \n",
"3 OLTP (Online Transactional Processing) is a class of systems that facilitate and manage transaction-oriented applications, typically for data entry and retrieval transactions in a database. \n",
"4 OLAP (Online Analytical Processing) is a category of software tools that provide analysis of data for business decisions. It allows users to analyze multidimensional data interactively from multiple perspectives. \n",
"5 TiKV is a row-based storage engine provided by TiDB for storing data in a distributed manner. \n",
"6 TiFlash is a columnar storage engine provided by TiDB for real-time replication of data from TiKV, ensuring consistent data between row-based and columnar storage. \n",
"7 TiDB Operator is a tool that helps manage TiDB on Kubernetes and automates tasks related to operating the TiDB cluster, making TiDB easier to deploy on any cloud that provides managed Kubernetes. \n",
"8 TiDB Cloud is a fully-managed TiDB service that provides the easiest, most economical, and most resilient way to deploy and run TiDB clusters in the cloud. \n",
"9 The MySQL Protocol is the protocol used by MySQL database systems for communication between clients and servers. TiDB is compatible with the MySQL protocol, allowing seamless integration with the MySQL ecosystem. "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[\"entities\"]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" source_entity \n",
" relationship_desc \n",
" target_entity \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" TiDB \n",
" TiDB Self-Managed is a product option of TiDB that allows users or organizations to deploy and manage TiDB on their own infrastructure with complete flexibility. \n",
" TiDB Self-Managed \n",
" \n",
" \n",
" 1 \n",
" TiDB \n",
" TiDB provides users with a one-stop database solution covering OLTP, OLAP, and HTAP services. \n",
" OLTP \n",
" \n",
" \n",
" 2 \n",
" TiDB \n",
" TiDB provides users with a one-stop database solution covering OLTP, OLAP, and HTAP services. \n",
" OLAP \n",
" \n",
" \n",
" 3 \n",
" TiDB \n",
" TiDB provides users with a one-stop database solution covering OLTP, OLAP, and HTAP services. \n",
" Hybrid Transactional and Analytical Processing (HTAP) \n",
" \n",
" \n",
" 4 \n",
" TiDB \n",
" TiDB provides two storage engines: TiKV, a row-based storage engine, and TiFlash, a columnar storage engine. \n",
" TiKV \n",
" \n",
" \n",
" 5 \n",
" TiDB \n",
" TiDB provides two storage engines: TiKV, a row-based storage engine, and TiFlash, a columnar storage engine. \n",
" TiFlash \n",
" \n",
" \n",
" 6 \n",
" TiDB \n",
" TiDB is compatible with the MySQL protocol, allowing seamless integration with the MySQL ecosystem. \n",
" MySQL Protocol \n",
" \n",
" \n",
" 7 \n",
" TiDB Cloud \n",
" TiDB Cloud is a fully-managed TiDB service that provides the easiest, most economical, and most resilient way to deploy and run TiDB clusters in the cloud. \n",
" TiDB \n",
" \n",
" \n",
" 8 \n",
" TiDB Operator \n",
" TiDB Operator helps manage TiDB on Kubernetes and automates tasks related to operating the TiDB cluster, making TiDB easier to deploy on any cloud that provides managed Kubernetes. \n",
" TiDB \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" source_entity \\\n",
"0 TiDB \n",
"1 TiDB \n",
"2 TiDB \n",
"3 TiDB \n",
"4 TiDB \n",
"5 TiDB \n",
"6 TiDB \n",
"7 TiDB Cloud \n",
"8 TiDB Operator \n",
"\n",
" relationship_desc \\\n",
"0 TiDB Self-Managed is a product option of TiDB that allows users or organizations to deploy and manage TiDB on their own infrastructure with complete flexibility. \n",
"1 TiDB provides users with a one-stop database solution covering OLTP, OLAP, and HTAP services. \n",
"2 TiDB provides users with a one-stop database solution covering OLTP, OLAP, and HTAP services. \n",
"3 TiDB provides users with a one-stop database solution covering OLTP, OLAP, and HTAP services. \n",
"4 TiDB provides two storage engines: TiKV, a row-based storage engine, and TiFlash, a columnar storage engine. \n",
"5 TiDB provides two storage engines: TiKV, a row-based storage engine, and TiFlash, a columnar storage engine. \n",
"6 TiDB is compatible with the MySQL protocol, allowing seamless integration with the MySQL ecosystem. \n",
"7 TiDB Cloud is a fully-managed TiDB service that provides the easiest, most economical, and most resilient way to deploy and run TiDB clusters in the cloud. \n",
"8 TiDB Operator helps manage TiDB on Kubernetes and automates tasks related to operating the TiDB cluster, making TiDB easier to deploy on any cloud that provides managed Kubernetes. \n",
"\n",
" target_entity \n",
"0 TiDB Self-Managed \n",
"1 OLTP \n",
"2 OLAP \n",
"3 Hybrid Transactional and Analytical Processing (HTAP) \n",
"4 TiKV \n",
"5 TiFlash \n",
"6 MySQL Protocol \n",
"7 TiDB \n",
"8 TiDB "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[\"relationships\"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"expected_knowledge = KnowledgeGraphOutput(\n",
" entities=[\n",
" {\"name\": \"TiDB\", \"description\": \"A distributed SQL database\"},\n",
" {\"name\": \"TiKV\", \"description\": \"TiKV is a row-based storage engine.\"},\n",
" {\"name\": \"TiFlash\", \"description\": \"TiFlash is a column-based storage engine.\"},\n",
" {\n",
" \"name\": \"OLTP\",\n",
" \"description\": \"OLTP is a type of database that is optimized for transaction processing.\",\n",
" },\n",
" {\n",
" \"name\": \"OLAP\",\n",
" \"description\": \"OLAP is a type of database that is optimized for analytics.\",\n",
" },\n",
" {\n",
" \"name\": \"HTAP\",\n",
" \"description\": \"HTAP is a type of database that is optimized for both transaction processing and analytics.\",\n",
" },\n",
" {\n",
" \"name\": \"TiDB Self-Managed\",\n",
" \"description\": \"TiDB Self-Managed is a product option of TiDB.\",\n",
" },\n",
" {\n",
" \"name\": \"TiDB Cloud\",\n",
" \"description\": \"TiDB Cloud is a fully-managed TiDB service.\",\n",
" },\n",
" {\n",
" \"name\": \"TiDB Operator\",\n",
" \"description\": \"TiDB Operator is a tool that helps manage TiDB on Kubernetes.\",\n",
" },\n",
" ],\n",
" relationships=[\n",
" {\n",
" \"source_entity\": \"TiDB\",\n",
" \"target_entity\": \"TiKV\",\n",
" \"relationship_desc\": \"TiDB uses TiKV as its storage engine\",\n",
" },\n",
" {\n",
" \"source_entity\": \"TiDB\",\n",
" \"target_entity\": \"TiFlash\",\n",
" \"relationship_desc\": \"TiDB uses TiFlash as its analytics engine\",\n",
" },\n",
" {\n",
" \"source_entity\": \"TiDB\",\n",
" \"target_entity\": \"OLTP\",\n",
" \"relationship_desc\": \"TiDB supports OLTP workloads\",\n",
" },\n",
" {\n",
" \"source_entity\": \"TiDB\",\n",
" \"target_entity\": \"OLAP\",\n",
" \"relationship_desc\": \"TiDB supports OLAP workloads\",\n",
" },\n",
" {\n",
" \"source_entity\": \"TiDB\",\n",
" \"target_entity\": \"HTAP\",\n",
" \"relationship_desc\": \"TiDB supports HTAP workloads\",\n",
" },\n",
" {\n",
" \"source_entity\": \"TiDB Self-Managed\",\n",
" \"target_entity\": \"TiDB\",\n",
" \"relationship_desc\": \"TiDB Self-Managed is a product option of TiDB\",\n",
" },\n",
" {\n",
" \"source_entity\": \"TiDB Cloud\",\n",
" \"target_entity\": \"TiDB\",\n",
" \"relationship_desc\": \"TiDB Cloud is a fully-managed TiDB service.\",\n",
" },\n",
" {\n",
" \"source_entity\": \"TiDB Operator\",\n",
" \"target_entity\": \"TiDB Cloud\",\n",
" \"relationship_desc\": \"TiDB Operator is a tool that helps manage TiDB on Kubernetes.\",\n",
" },\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-11T09:00:13.095249Z",
"start_time": "2025-04-11T09:00:11.025732Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"0.6111111111111112"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from autoflow.knowledge_graph.programs.eval_graph import KnowledgeGraphEvaluator\n",
"\n",
"evaluator = KnowledgeGraphEvaluator(dspy_lm=dspy_lm)\n",
"\n",
"evaluation_result = evaluator.forward(expected_knowledge, actual_knowledge)\n",
"evaluation_result.score"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "\"EntityOutput\" object has no field \"metadata\"",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mValueError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 4\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mautoflow\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mknowledge_graph\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mprograms\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mextract_covariates\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m EntityCovariateExtractor\n\u001b[32m 3\u001b[39m metadata_extractor = EntityCovariateExtractor(dspy_lm=dspy_lm)\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m actual_knowledge.entities = \u001b[43mmetadata_extractor\u001b[49m\u001b[43m.\u001b[49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mactual_knowledge\u001b[49m\u001b[43m.\u001b[49m\u001b[43mentities\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m entity \u001b[38;5;129;01min\u001b[39;00m actual_knowledge.entities:\n\u001b[32m 7\u001b[39m \u001b[38;5;28mprint\u001b[39m(entity.name, entity.metadata)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Projects/autoflow.ai/core/autoflow/knowledge_graph/programs/extract_covariates.py:81\u001b[39m, in \u001b[36mEntityCovariateExtractor.forward\u001b[39m\u001b[34m(self, text, entities)\u001b[39m\n\u001b[32m 78\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m entity \u001b[38;5;129;01min\u001b[39;00m entities:\n\u001b[32m 79\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m entity.name \u001b[38;5;129;01min\u001b[39;00m output_entity_map:\n\u001b[32m 80\u001b[39m \u001b[38;5;66;03m# Update the covariates in the metadata of the entity.\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m81\u001b[39m \u001b[43mentity\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmetadata\u001b[49m = output_entity_map[entity.name].covariates\n\u001b[32m 83\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m entities\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/Projects/autoflow.ai/core/.venv/lib/python3.12/site-packages/pydantic/main.py:925\u001b[39m, in \u001b[36mBaseModel.__setattr__\u001b[39m\u001b[34m(self, name, value)\u001b[39m\n\u001b[32m 922\u001b[39m \u001b[38;5;28mself\u001b[39m.__pydantic_validator__.validate_assignment(\u001b[38;5;28mself\u001b[39m, name, value)\n\u001b[32m 923\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.model_config.get(\u001b[33m'\u001b[39m\u001b[33mextra\u001b[39m\u001b[33m'\u001b[39m) != \u001b[33m'\u001b[39m\u001b[33mallow\u001b[39m\u001b[33m'\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m name \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.__pydantic_fields__:\n\u001b[32m 924\u001b[39m \u001b[38;5;66;03m# TODO - matching error\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m925\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.\u001b[34m__class__\u001b[39m.\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\u001b[33m object has no field \u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m 926\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.model_config.get(\u001b[33m'\u001b[39m\u001b[33mextra\u001b[39m\u001b[33m'\u001b[39m) == \u001b[33m'\u001b[39m\u001b[33mallow\u001b[39m\u001b[33m'\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m name \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.__pydantic_fields__:\n\u001b[32m 927\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.model_extra \u001b[38;5;129;01mand\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.model_extra:\n",
"\u001b[31mValueError\u001b[39m: \"EntityOutput\" object has no field \"metadata\""
]
}
],
"source": [
"from autoflow.knowledge_graph.programs.extract_covariates import (\n",
" EntityCovariateExtractor,\n",
")\n",
"\n",
"metadata_extractor = EntityCovariateExtractor(dspy_lm=dspy_lm)\n",
"actual_knowledge.entities = metadata_extractor.forward(text, actual_knowledge.entities)\n",
"\n",
"for entity in actual_knowledge.entities:\n",
" print(entity.name, entity.metadata)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: core/pyproject.toml
================================================
[project]
name = "autoflow-ai"
version = "0.0.2.dev5"
description = "Framework to index and search your vector data, make your data ready for AI apps, developed by TiDB."
authors = [
{ name = "Mini256", email = "minianter@foxmail.com" },
{ name = "sykp241095", email = "sykp241095@gmail.com" },
]
readme = "README.md"
requires-python = ">=3.10"
# Dependencies
dependencies = [
"litellm>=1.77.5",
"llama-index-core>=0.12.23.post2",
"llama-index-llms-litellm>=0.3.0",
"llama-index-readers-file>=0.4.6",
"deepdiff>=8.2.0",
"pytidb==0.0.4.dev1",
"markdownify>=0.13.1",
"playwright>=1.20.0",
"dspy>=2.6.6",
"tokenizers>=0.21.0",
"mypy>=1.15.0",
"banks>=2.1.1",
]
[dependency-groups]
dev = [
"pytest>=8.3.4",
"pytest-dotenv>=0.5.2",
"pre-commit>=4.1.0",
"pip>=25.0.1",
"notebook>=7.3.2",
"markdownify>=0.13.1",
"playwright>=1.20.0",
"dspy>=2.6.6",
"tokenizers>=0.21.0",
"llama-index-readers-file>=0.4.6",
"mypy>=1.15.0",
"numpy>=1.26.4",
"ipywidgets>=8.1.6",
]
# Check
[tool.pyright]
include = ["autoflow"]
# Test
[tool.pytest.ini_options]
log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
env_files = ["./tests/.env"]
# Build
[tool.setuptools]
packages = ["autoflow"]
[build-system]
requires = ["hatchling==1.26.3"]
build-backend = "hatchling.build"
[tool.hatch.build]
exclude = [
"/.*",
"/dist",
"/docs",
"/tests",
]
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["autoflow"]
================================================
FILE: core/tests/__init__.py
================================================
================================================
FILE: core/tests/conftest.py
================================================
import logging
import os
from pathlib import Path
import pytest
from dotenv import load_dotenv
from pytidb import TiDBClient
from autoflow.models.llms import LLM
from autoflow.models.embedding_models import EmbeddingModel
from autoflow.configs.db import DatabaseConfig
from autoflow.db import get_db_engine_from_config
logger = logging.getLogger(__name__)
@pytest.fixture(scope="session", autouse=True)
def env():
logger.info(f"Loading environment variables from {Path.cwd() / '.env'}")
load_dotenv()
@pytest.fixture(scope="session")
def db_engine():
config = DatabaseConfig(
host=os.getenv("TIDB_HOST"),
port=int(os.getenv("TIDB_PORT")),
username=os.getenv("TIDB_USERNAME"),
password=os.getenv("TIDB_PASSWORD"),
database=os.getenv("TIDB_DATABASE"),
enable_ssl=False,
)
return get_db_engine_from_config(config)
@pytest.fixture(scope="session")
def llm():
return LLM(model="openai/gpt-4o-mini")
@pytest.fixture(scope="session")
def embedding_model():
return EmbeddingModel(model_name="text-embedding-3-small")
@pytest.fixture(scope="session")
def tidb_client(db_engine):
return TiDBClient(db_engine=db_engine)
================================================
FILE: core/tests/fixtures/analyze-slow-queries.md
================================================
---
title: Analyze Slow Queries
summary: Learn how to locate and analyze slow queries.
---
# Analyze Slow Queries
To address the issue of slow queries, you need to take the following two steps:
1. Among many queries, identify which type of queries are slow.
2. Analyze why this type of queries are slow.
You can easily perform step 1 using the [slow query log](/dashboard/dashboard-slow-query.md) and the [statement summary table](/statement-summary-tables.md) features. It is recommended to use [TiDB Dashboard](/dashboard/dashboard-intro.md), which integrates the two features and directly displays the slow queries in your browser.
This document focuses on how to perform step 2 - analyze why this type of queries are slow.
Generally, slow queries have the following major causes:
- Optimizer issues, such as wrong index selected, wrong join type or sequence selected.
- System issues. All issues not caused by the optimizer are system issues. For example, a busy TiKV instance processes requests slowly; outdated Region information causes slow queries.
In actual situations, optimizer issues might cause system issues. For example, for a certain type of queries, the optimizer uses a full table scan instead of the index. As a result, the SQL queries consume many resources, which causes the CPU usage of some TiKV instances to soar. This seems like a system issue, but in essence, it is an optimizer issue.
To identify system issues is relatively simple. To analyze optimizer issues, you need to determine whether the execution plan is reasonable or not. Therefore, it is recommended to analyze slow queries by following these procedures:
1. Identify the performance bottleneck of the query, that is, the time-consuming part of the query process.
2. Analyze the system issues: analyze the possible causes according to the query bottleneck and the monitoring/log information of that time.
3. Analyze the optimizer issues: analyze whether there is a better execution plan.
The procedures above are explained in the following sections.
## Identify the performance bottleneck of the query
First, you need to have a general understanding of the query process. The key stages of the query execution process in TiDB are illustrated in [TiDB performance map](/media/performance-map.png).
You can get the duration information using the following methods:
- [Slow log](/identify-slow-queries.md). It is recommended to view the slow log in [TiDB Dashboard](/dashboard/dashboard-overview.md).
- [`EXPLAIN ANALYZE` statement](/sql-statements/sql-statement-explain-analyze.md).
The methods above are different in the following aspects:
- The slow log records the duration of almost all stages of a SQL execution, from parsing to returning results, and is relatively comprehensive (you can query and analyze the slow log in TiDB Dashboard in an intuitive way).
- By executing `EXPLAIN ANALYZE`, you can learn the time consumption of each operator in an actual SQL execution. The results have more detailed statistics of the execution duration.
In summary, the slow log and `EXPLAIN ANALYZE` statements help you determine the SQL query is slow in which component (TiDB or TiKV) at which stage of the execution. Therefore, you can accurately identify the performance bottleneck of the query.
In addition, since v4.0.3, the `Plan` field in the slow log also includes the SQL execution information, which is the result of `EXPLAIN ANALYZE`. So you can find all information of SQL duration in the slow log.
## Analyze system issues
System issues can be divided into the following types according to different execution stages of a SQL statement:
1. TiKV is slow in data processing. For example, the TiKV coprocessor processes data slowly.
2. TiDB is slow in execution. For example, a `Join` operator processes data slowly.
3. Other key stages are slow. For example, getting the timestamp takes a long time.
For each slow query, first determine to which type the query belongs, and then analyze it in detail.
### TiKV is slow in data processing
If TiKV is slow in data processing, you can easily identify it in the result of `EXPLAIN ANALYZE`. In the following example, `StreamAgg_8` and `TableFullScan_15`, two `tikv-task`s (as indicated by `cop[tikv]` in the `task` column), take `170ms` to execute. After subtracting `170ms`, the execution time of TiDB operators account for a very small proportion of the total execution time. This indicates that the bottleneck is in TiKV.
```sql
+----------------------------+---------+---------+-----------+---------------+------------------------------------------------------------------------------+---------------------------------+-----------+------+
| id | estRows | actRows | task | access object | execution info | operator info | memory | disk |
+----------------------------+---------+---------+-----------+---------------+------------------------------------------------------------------------------+---------------------------------+-----------+------+
| StreamAgg_16 | 1.00 | 1 | root | | time:170.08572ms, loops:2 | funcs:count(Column#5)->Column#3 | 372 Bytes | N/A |
| └─TableReader_17 | 1.00 | 1 | root | | time:170.080369ms, loops:2, rpc num: 1, rpc time:17.023347ms, proc keys:28672 | data:StreamAgg_8 | 202 Bytes | N/A |
| └─StreamAgg_8 | 1.00 | 1 | cop[tikv] | | time:170ms, loops:29 | funcs:count(1)->Column#5 | N/A | N/A |
| └─TableFullScan_15 | 7.00 | 28672 | cop[tikv] | table:t | time:170ms, loops:29 | keep order:false, stats:pseudo | N/A | N/A |
+----------------------------+---------+---------+-----------+---------------+------------------------------------------------------------------------------+---------------------------------+-----------+------
```
In addition, the `Cop_process` and `Cop_wait` fields in the slow log can also help your analysis. In the following example, the total duration of the query is around `180.85ms`, and the largest `coptask` takes `171ms`. This indicates that the bottleneck of this query is on the TiKV side.
For the description of each field in the slow log, see [fields description](/identify-slow-queries.md#fields-description).
```log
# Query_time: 0.18085
...
# Num_cop_tasks: 1
# Cop_process: Avg_time: 170ms P90_time: 170ms Max_time: 170ms Max_addr: 10.6.131.78
# Cop_wait: Avg_time: 1ms P90_time: 1ms Max_time: 1ms Max_Addr: 10.6.131.78
```
After identifying that TiKV is the bottleneck, you can find out the cause as described in the following sections.
#### TiKV instance is busy
During the execution of a SQL statement, TiDB might fetch data from multiple TiKV instances. If one TiKV instance responds slowly, the overall SQL execution speed is slowed down.
The `Cop_wait` field in the slow log can help you determine this cause.
```log
# Cop_wait: Avg_time: 1ms P90_time: 2ms Max_time: 110ms Max_Addr: 10.6.131.78
```
The log above shows that a `cop-task` sent to the `10.6.131.78` instance waits `110ms` before being executed. It indicates that this instance is busy. You can check the CPU monitoring of that time to confirm the cause.
#### Too many outdated keys
A TiKV instance has much outdated data, which needs to be cleaned up for data scan. This impacts the processing speed.
Check `Total_keys` and `Processed_keys`. If they are greatly different, the TiKV instance has too many keys of the older versions.
```
...
# Total_keys: 2215187529 Processed_keys: 1108056368
...
```
### Other key stages are slow
#### Slow in getting timestamps
You can compare `Wait_TS` and `Query_time` in the slow log. The timestamps are prefetched, so generally `Wait_TS` should be low.
```
# Query_time: 0.0300000
...
# Wait_TS: 0.02500000
```
#### Outdated Region information
Region information on the TiDB side might be outdated. In this situation, TiKV might return the `regionMiss` error. Then TiDB gets the Region information from PD again, which is reflected in the `Cop_backoff` information. Both the failed times and the total duration are recorded.
```
# Cop_backoff_regionMiss_total_times: 200 Cop_backoff_regionMiss_total_time: 0.2 Cop_backoff_regionMiss_max_time: 0.2 Cop_backoff_regionMiss_max_addr: 127.0.0.1 Cop_backoff_regionMiss_avg_time: 0.2 Cop_backoff_regionMiss_p90_time: 0.2
# Cop_backoff_rpcPD_total_times: 200 Cop_backoff_rpcPD_total_time: 0.2 Cop_backoff_rpcPD_max_time: 0.2 Cop_backoff_rpcPD_max_addr: 127.0.0.1 Cop_backoff_rpcPD_avg_time: 0.2 Cop_backoff_rpcPD_p90_time: 0.2
```
#### Subqueries are executed in advance
For statements with non-correlated subqueries, the subquery part might be executed in advance. For example, in `select * from t1 where a = (select max(a) from t2)`, the `select max(a) from t2` part might be executed in advance in the optimization stage. The result of `EXPLAIN ANALYZE` does not show the duration of this type of subqueries.
```sql
mysql> explain analyze select count(*) from t where a=(select max(t1.a) from t t1, t t2 where t1.a=t2.a);
+------------------------------+----------+---------+-----------+---------------+--------------------------+----------------------------------+-----------+------+
| id | estRows | actRows | task | access object | execution info | operator info | memory | disk |
+------------------------------+----------+---------+-----------+---------------+--------------------------+----------------------------------+-----------+------+
| StreamAgg_59 | 1.00 | 1 | root | | time:4.69267ms, loops:2 | funcs:count(Column#10)->Column#8 | 372 Bytes | N/A |
| └─TableReader_60 | 1.00 | 1 | root | | time:4.690428ms, loops:2 | data:StreamAgg_48 | 141 Bytes | N/A |
| └─StreamAgg_48 | 1.00 | | cop[tikv] | | time:0ns, loops:0 | funcs:count(1)->Column#10 | N/A | N/A |
| └─Selection_58 | 16384.00 | | cop[tikv] | | time:0ns, loops:0 | eq(test.t.a, 1) | N/A | N/A |
| └─TableFullScan_57 | 16384.00 | -1 | cop[tikv] | table:t | time:0s, loops:0 | keep order:false | N/A | N/A |
+------------------------------+----------+---------+-----------+---------------+--------------------------+----------------------------------+-----------+------+
5 rows in set (7.77 sec)
```
But you can identify this type of subquery execution in the slow log:
```
# Query_time: 7.770634843
...
# Rewrite_time: 7.765673663 Preproc_subqueries: 1 Preproc_subqueries_time: 7.765231874
```
From log record above, you can see that a subquery is executed in advance and takes `7.76s`.
### TiDB is slow in execution
Assume that the execution plan in TiDB is correct but the execution is slow. To solve this type of issue, you can adjust parameters or use the hint according to the result of `EXPLAIN ANALYZE` for the SQL statement.
If the execution plan is incorrect, see the [Analyze optimizer issues](#analyze-optimizer-issues) section.
#### Low concurrency
If the bottleneck is in the operator with concurrency, speed up the execution by adjusting the concurrency. For example:
```sql
mysql> explain analyze select sum(t1.a) from t t1, t t2 where t1.a=t2.a;
+----------------------------------+--------------+-----------+-----------+---------------+-------------------------------------------------------------------------------------+------------------------------------------------+------------------+---------+
| id | estRows | actRows | task | access object | execution info | operator info | memory | disk |
+----------------------------------+--------------+-----------+-----------+---------------+-------------------------------------------------------------------------------------+------------------------------------------------+------------------+---------+
| HashAgg_11 | 1.00 | 1 | root | | time:9.666832189s, loops:2, PartialConcurrency:4, FinalConcurrency:4 | funcs:sum(Column#6)->Column#5 | 322.125 KB | N/A |
| └─Projection_24 | 268435456.00 | 268435456 | root | | time:9.098644711s, loops:262145, Concurrency:4 | cast(test.t.a, decimal(65,0) BINARY)->Column#6 | 199 KB | N/A |
| └─HashJoin_14 | 268435456.00 | 268435456 | root | | time:6.616773501s, loops:262145, Concurrency:5, probe collision:0, build:881.404µs | inner join, equal:[eq(test.t.a, test.t.a)] | 131.75 KB | 0 Bytes |
| ├─TableReader_21(Build) | 16384.00 | 16384 | root | | time:6.553717ms, loops:17 | data:Selection_20 | 33.6318359375 KB | N/A |
| │ └─Selection_20 | 16384.00 | | cop[tikv] | | time:0ns, loops:0 | not(isnull(test.t.a)) | N/A | N/A |
| │ └─TableFullScan_19 | 16384.00 | -1 | cop[tikv] | table:t2 | time:0s, loops:0 | keep order:false | N/A | N/A |
| └─TableReader_18(Probe) | 16384.00 | 16384 | root | | time:6.880923ms, loops:17 | data:Selection_17 | 33.6318359375 KB | N/A |
| └─Selection_17 | 16384.00 | | cop[tikv] | | time:0ns, loops:0 | not(isnull(test.t.a)) | N/A | N/A |
| └─TableFullScan_16 | 16384.00 | -1 | cop[tikv] | table:t1 | time:0s, loops:0 | keep order:false | N/A | N/A |
+----------------------------------+--------------+-----------+-----------+---------------+-------------------------------------------------------------------------------------+------------------------------------------------+------------------+---------+
9 rows in set (9.67 sec)
```
As shown above, `HashJoin_14` and `Projection_24` consume much of the execution time. Consider increasing their concurrency using SQL variables to speed up execution.
All system variables are documented in [system-variables](/system-variables.md). To increase the concurrency of `HashJoin_14`, you can modify the `tidb_hash_join_concurrency` system variable.
#### Data is spilled to disk
Another cause of slow execution is disk spill that occurs during execution if the memory limit is reached. You can find out this cause in the execution plan and the slow log:
```sql
+-------------------------+-----------+---------+-----------+---------------+------------------------------+----------------------+-----------------------+----------------+
| id | estRows | actRows | task | access object | execution info | operator info | memory | disk |
+-------------------------+-----------+---------+-----------+---------------+------------------------------+----------------------+-----------------------+----------------+
| Sort_4 | 462144.00 | 462144 | root | | time:2.02848898s, loops:453 | test.t.a | 149.68795776367188 MB | 219.3203125 MB |
| └─TableReader_8 | 462144.00 | 462144 | root | | time:616.211272ms, loops:453 | data:TableFullScan_7 | 197.49601364135742 MB | N/A |
| └─TableFullScan_7 | 462144.00 | -1 | cop[tikv] | table:t | time:0s, loops:0 | keep order:false | N/A | N/A |
+-------------------------+-----------+---------+-----------+---------------+------------------------------+----------------------+-----------------------+----------------+
```
```
...
# Disk_max: 229974016
...
```
#### Join operations with Cartesian product
Join operations with Cartesian product generate data volume as large as `left child row count * right child row count`. This is inefficient and should be avoided.
This type of join operations is marked `CARTESIAN` in the execution plan. For example:
```sql
mysql> explain select * from t t1, t t2 where t1.a>t2.a;
+------------------------------+-------------+-----------+---------------+---------------------------------------------------------+
| id | estRows | task | access object | operator info |
+------------------------------+-------------+-----------+---------------+---------------------------------------------------------+
| HashJoin_8 | 99800100.00 | root | | CARTESIAN inner join, other cond:gt(test.t.a, test.t.a) |
| ├─TableReader_15(Build) | 9990.00 | root | | data:Selection_14 |
| │ └─Selection_14 | 9990.00 | cop[tikv] | | not(isnull(test.t.a)) |
| │ └─TableFullScan_13 | 10000.00 | cop[tikv] | table:t2 | keep order:false, stats:pseudo |
| └─TableReader_12(Probe) | 9990.00 | root | | data:Selection_11 |
| └─Selection_11 | 9990.00 | cop[tikv] | | not(isnull(test.t.a)) |
| └─TableFullScan_10 | 10000.00 | cop[tikv] | table:t1 | keep order:false, stats:pseudo |
+------------------------------+-------------+-----------+---------------+---------------------------------------------------------+
```
## Analyze optimizer issues
To analyze optimizer issues, you need to determine whether the execution plan is reasonable or not. You need to have some understanding of the optimization process and each operator.
For the following examples, assume that the table schema is `create table t (id int, a int, b int, c int, primary key(id), key(a), key(b, c))`.
1. `select * from t`: There is no filter condition and a full table scan is performed. So the `TableFullScan` operator is used to read data.
2. `select a from t where a=2`: There is a filter condition and only the index columns are read, so the `IndexReader` operator is used to read data.
3. `select * from t where a=2`: There is a filter condition for `a` but the `a` index cannot fully cover the data to be read, so the `IndexLookup` operator is used.
4. `select b from t where c=3`: Without the prefix condition, the multi-column index cannot be used. So the `IndexFullScan` is used.
5. ...
The examples above are operators used for data reads. For more operators, see [Understand TiDB Execution Plan](/explain-overview.md).
In addition, reading [SQL Tuning Overview](/sql-tuning-overview.md) helps you better understand the TiDB optimizer and determine whether the execution plan is reasonable or not.
Most optimizer issues are explained in [SQL Tuning Overview](/sql-tuning-overview.md). For the solutions, see the following documents:
1. [Wrong Index Solution](/wrong-index-solution.md)
2. [Wrong join order](/join-reorder.md)
3. [Expressions are not pushed down](/blocklist-control-plan.md)
================================================
FILE: core/tests/fixtures/tidb-overview.md
================================================
---
title: What is TiDB Self-Managed
summary: Learn about the key features and usage scenarios of TiDB.
aliases: ['/docs/dev/key-features/','/tidb/dev/key-features','/docs/dev/overview/']
---
# What is TiDB Self-Managed
[TiDB](https://github.com/pingcap/tidb) (/'taɪdiːbi:/, "Ti" stands for Titanium) is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL compatible and features horizontal scalability, strong consistency, and high availability. The goal of TiDB is to provide users with a one-stop database solution that covers OLTP (Online Transactional Processing), OLAP (Online Analytical Processing), and HTAP services. TiDB is suitable for various use cases that require high availability and strong consistency with large-scale data.
TiDB Self-Managed is a product option of TiDB, where users or organizations can deploy and manage TiDB on their own infrastructure with complete flexibility. With TiDB Self-Managed, you can enjoy the power of open source, distributed SQL while retaining full control over your environment.
The following video introduces key features of TiDB.
VIDEO
## Key features
- **Easy horizontal scaling**
The TiDB architecture design separates computing from storage, letting you scale out or scale in the computing or storage capacity online as needed. The scaling process is transparent to application operations and maintenance staff.
- **Financial-grade high availability**
Data is stored in multiple replicas, and the Multi-Raft protocol is used to obtain the transaction log. A transaction can only be committed when data has been successfully written into the majority of replicas. This guarantees strong consistency and availability when a minority of replicas go down. You can configure the geographic location and number of replicas as needed to meet different disaster tolerance levels.
- **Real-time HTAP**
TiDB provides two storage engines: [TiKV](/tikv-overview.md), a row-based storage engine, and [TiFlash](/tiflash/tiflash-overview.md), a columnar storage engine. TiFlash uses the Multi-Raft Learner protocol to replicate data from TiKV in real time, ensuring consistent data between the TiKV row-based storage engine and the TiFlash columnar storage engine. TiKV and TiFlash can be deployed on different machines as needed to solve the problem of HTAP resource isolation.
- **Cloud-native distributed database**
TiDB is a distributed database designed for the cloud, providing flexible scalability, reliability, and security on the cloud platform. Users can elastically scale TiDB to meet the requirements of their changing workloads. In TiDB, each piece of data has at least 3 replicas, which can be scheduled in different cloud availability zones to tolerate the outage of a whole data center. [TiDB Operator](https://docs.pingcap.com/tidb-in-kubernetes/stable/tidb-operator-overview) helps manage TiDB on Kubernetes and automates tasks related to operating the TiDB cluster, making TiDB easier to deploy on any cloud that provides managed Kubernetes. [TiDB Cloud](https://pingcap.com/tidb-cloud/), the fully-managed TiDB service, is the easiest, most economical, and most resilient way to unlock the full power of [TiDB in the cloud](https://docs.pingcap.com/tidbcloud/), allowing you to deploy and run TiDB clusters with just a few clicks.
- **Compatible with the MySQL protocol and MySQL ecosystem**
TiDB is compatible with the MySQL protocol, common features of MySQL, and the MySQL ecosystem. To migrate applications to TiDB, you do not need to change a single line of code in many cases, or only need to modify a small amount of code. In addition, TiDB provides a series of [data migration tools](/ecosystem-tool-user-guide.md) to help easily migrate application data into TiDB.
## See also
- [TiDB Architecture](/tidb-architecture.md)
- [TiDB Storage](/tidb-storage.md)
- [TiDB Computing](/tidb-computing.md)
- [TiDB Scheduling](/tidb-scheduling.md)
================================================
FILE: core/tests/knowledge_base/__init__.py
================================================
================================================
FILE: core/tests/knowledge_base/test_kb_with_namespace.py
================================================
import logging
import pytest
from autoflow.configs.knowledge_base import IndexMethod
from autoflow.knowledge_base import KnowledgeBase
logger = logging.getLogger(__name__)
@pytest.fixture(scope="module")
def kb(db_engine, llm, embedding_model):
kb = KnowledgeBase(
namespace="test",
name="Test",
description="Here is a knowledge base with namespace",
index_methods=[IndexMethod.VECTOR_SEARCH, IndexMethod.KNOWLEDGE_GRAPH],
llm=llm,
embedding_model=embedding_model,
db_engine=db_engine,
)
logger.info(
"Created a knowledge base with namespace <%s> successfully.", kb.namespace
)
return kb
def test_add_documents_via_filepath(kb: KnowledgeBase):
docs = kb.add("./tests/fixtures/analyze-slow-queries.md")
assert len(docs) == 1
def test_add_documents_via_url(kb: KnowledgeBase):
docs = kb.add("https://docs.pingcap.com/tidbcloud/tidb-cloud-intro")
assert len(docs) == 1
def test_search_documents(kb: KnowledgeBase):
result = kb.search_documents(
query="What is TiDB?",
top_k=2,
)
assert len(result.chunks) > 0
def test_search_knowledge_graph(kb: KnowledgeBase):
knowledge_graph = kb.search_knowledge_graph(
query="What is TiDB?",
)
assert len(knowledge_graph.entities) > 0
assert len(knowledge_graph.relationships) > 0
================================================
FILE: core/tests/knowledge_base/test_kb_without_namespace.py
================================================
import logging
import pytest
from autoflow.configs.knowledge_base import IndexMethod
from autoflow.knowledge_base import KnowledgeBase
logger = logging.getLogger(__name__)
@pytest.fixture(scope="module")
def kb(db_engine, llm, embedding_model):
kb = KnowledgeBase(
db_engine=db_engine,
name="Test",
description="Here is a knowledge base without namespace",
index_methods=[IndexMethod.VECTOR_SEARCH, IndexMethod.KNOWLEDGE_GRAPH],
llm=llm,
embedding_model=embedding_model,
)
logger.info("Created a knowledge base successfully.")
return kb
def test_add_documents_via_filepath(kb: KnowledgeBase):
docs = kb.add("./tests/fixtures/analyze-slow-queries.md")
assert len(docs) == 1
def test_add_documents_via_url(kb):
docs = kb.add("https://docs.pingcap.com/tidbcloud/tidb-cloud-intro")
assert len(docs) == 1
def test_search_documents(kb):
result = kb.search_documents(
query="What is TiDB?",
similarity_top_k=2,
)
assert len(result.chunks) > 0
def test_search_knowledge_graph(kb):
knowledge_graph = kb.search_knowledge_graph(
query="What is TiDB?",
)
assert len(knowledge_graph.entities) > 0
assert len(knowledge_graph.relationships) > 0
================================================
FILE: core/tests/knowledge_graph/programs/test_extract_graph.py
================================================
import logging
from pathlib import Path
import pytest
from autoflow.knowledge_graph.programs.eval_graph import KnowledgeGraphEvaluator
from autoflow.knowledge_graph.programs.extract_graph import KnowledgeGraphExtractor
from autoflow.knowledge_graph.types import GeneratedKnowledgeGraph
from autoflow.models.llms.dspy import get_dspy_lm_by_llm
logger = logging.getLogger(__name__)
@pytest.fixture(scope="module")
def extractor(llm):
dspy_lm = get_dspy_lm_by_llm(llm)
extractor = KnowledgeGraphExtractor(dspy_lm=dspy_lm)
return extractor
@pytest.fixture(scope="module")
def evaluator(llm):
dspy_lm = get_dspy_lm_by_llm(llm)
evaluator = KnowledgeGraphEvaluator(dspy_lm=dspy_lm)
return evaluator
def test_extract_graph(extractor, evaluator):
text = Path("tests/fixtures/tidb-overview.md").read_text()
# Expected knowledge graph output
expected_knowledge = GeneratedKnowledgeGraph.model_validate(
{
"entities": [
{
"name": "TiDB",
"description": "An open-source distributed SQL database that supports HTAP workloads.",
},
{
"name": "TiDB Self-Managed",
"description": "A product option of TiDB where users deploy and manage TiDB on their own infrastructure.",
},
{
"name": "TiDB Cloud",
"description": "The fully-managed TiDB service for deploying and running TiDB clusters in the cloud.",
},
{
"name": "TiDB Operator",
"description": "A tool that helps manage TiDB on Kubernetes and automates tasks related to operating TiDB clusters",
},
{
"name": "TiKV",
"description": "A row-based storage engine used by TiDB.",
},
{
"name": "TiFlash",
"description": "A columnar storage engine used by TiDB.",
},
{
"name": "Multi-Raft Learner protocol",
"description": "A protocol used by TiDB to replicate data from TiKV to TiFlash.",
},
],
"relationships": [
{
"source_entity_name": "TiDB",
"target_entity_name": "TiKV",
"description": "TiDB uses TiKV as its storage engine",
},
{
"source_entity_name": "TiDB",
"target_entity_name": "TiFlash",
"description": "TiDB uses TiFlash as its analytics engine",
},
{
"source_entity_name": "TiDB",
"target_entity_name": "Multi-Raft Learner protocol",
"description": "TiDB uses the Multi-Raft Learner protocol to replicate data from TiKV to TiFlash.",
},
{
"source_entity_name": "TiDB",
"target_entity_name": "HTAP",
"description": "TiDB supports HTAP workloads",
},
{
"source_entity_name": "TiDB Self-Managed",
"target_entity_name": "TiDB",
"description": "TiDB Self-Managed is a product option of TiDB",
},
{
"source_entity_name": "TiDB Cloud",
"target_entity_name": "TiDB",
"description": "TiDB Cloud is a fully-managed TiDB service",
},
{
"source_entity_name": "TiDB Operator",
"target_entity_name": "TiDB Cloud",
"description": "TiDB Operator is a tool that helps manage TiDB on Kubernetes and automates tasks related to operating TiDB clusters",
},
],
}
)
# Generate knowledge graph
actual_knowledge = extractor.forward(text)
# Use LLM to evaluate the completeness
evaluation_result = evaluator.forward(expected_knowledge, actual_knowledge)
final_score = evaluation_result.score
logger.info(f"Final score: {final_score}")
assert final_score > 0.4, "The completeness score should be greater than 0.4."
================================================
FILE: core/tests/knowledge_graph/test_kg_extractor.py
================================================
from pathlib import Path
from autoflow.knowledge_graph.extractors.simple import SimpleKGExtractor
from autoflow.models.llms.dspy import get_dspy_lm_by_llm
def test_kg_extractor(llm):
text = Path("./tests/fixtures/tidb-overview.md").read_text()
dspy_lm = get_dspy_lm_by_llm(llm)
extractor = SimpleKGExtractor(dspy_lm)
knowledge_graph = extractor.extract(text)
assert knowledge_graph is not None
assert len(knowledge_graph.entities) >= 2
assert len(knowledge_graph.relationships) >= 1
for entity in knowledge_graph.entities:
assert entity.name is not None
assert entity.description is not None
assert len(entity.meta) > 0
for relationship in knowledge_graph.relationships:
assert relationship.source_entity_name is not None
assert relationship.target_entity_name is not None
assert relationship.description is not None
================================================
FILE: core/tests/models/test_model_manager.py
================================================
import logging
import os
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.schema import NodeWithScore, TextNode
import pytest
from autoflow.configs.models.providers import ModelProviders
from autoflow.configs.models.providers.openai import OpenAIConfig
from autoflow.models.manager import (
model_manager,
ProviderConfig,
)
logger = logging.getLogger(__name__)
@pytest.fixture(scope="module", autouse=True)
def setup_model_manager():
model_manager.registry_provider(
name=ModelProviders.OPENAI,
config=OpenAIConfig(
api_key=os.getenv("OPENAI_API_KEY"),
),
)
model_manager.registry_provider(
name=ModelProviders.JINA_AI,
config=ProviderConfig(
api_key=os.getenv("JINAAI_API_KEY"),
),
)
def test_llm():
llm = model_manager.resolve_llm(
provider=ModelProviders.OPENAI,
config={
"model": "gpt-4o",
},
)
res = llm.chat(
messages=[
ChatMessage(
role="user",
content="Does TiDB Support Vector Search (Y/N)?",
)
],
max_tokens=1,
)
assert res.message.content is not None
logger.info(
f"LLM Answer: {res.message.content}",
)
def test_embedding_model():
embed_model = model_manager.resolve_embedding_model(
provider=ModelProviders.OPENAI,
config={
"model": "text-embedding-3-small",
"dimensions": 1536,
},
)
vector = embed_model.get_query_embedding("What is TiDB?")
assert len(vector) == 1536
def test_reranker_model():
reranker_model = model_manager.resolve_rerank_model(
provider=ModelProviders.JINA_AI,
config={"model": "jina-reranker-v2-base-multilingual"},
)
nodes = reranker_model.postprocess_nodes(
query_str="Database",
nodes=[
NodeWithScore(node=TextNode(text="Redis")),
NodeWithScore(node=TextNode(text="OpenAI")),
NodeWithScore(node=TextNode(text="TiDB")),
],
)
assert len(nodes) == 3
================================================
FILE: core/tests/storage/__init__.py
================================================
================================================
FILE: core/tests/storage/doc_store/test_tidb_doc_store.py
================================================
import pytest
from pytidb import TiDBClient
from autoflow.models.embedding_models import EmbeddingModel
from autoflow.storage.doc_store.tidb_doc_store import TiDBDocumentStore
from autoflow.storage.doc_store.types import Document, Chunk
from autoflow.utils.hash import sha256
@pytest.fixture(scope="session")
def doc_store():
tidb_client = TiDBClient.connect()
return TiDBDocumentStore(namespace="doc_store", client=tidb_client, vector_dims=3)
@pytest.fixture(scope="session")
def doc_store_with_auto_embed():
tidb_client = TiDBClient.connect()
embedding_model = EmbeddingModel(model_name="text-embedding-3-small")
return TiDBDocumentStore(
namespace="doc_store_with_auto_embed",
client=tidb_client,
embedding_model=embedding_model,
)
def test_crud(doc_store):
doc_store.reset()
# Create
documents = doc_store.add(
[
Document(
name="TiDB",
content="TiDB is a distributed SQL database.",
chunks=[
Chunk(
text="TiDB is a distributed SQL database.", text_vec=[1, 2, 3]
),
],
),
Document(
name="TiKV",
content="TiKV is a distributed key-value storage engine.",
chunks=[
Chunk(
text="TiKV is a distributed key-value storage engine.",
text_vec=[4, 5, 6],
),
],
),
Document(
name="TiFlash",
content="TiFlash is a column-oriented storage engine.",
chunks=[
Chunk(
text="TiFlash is a column-oriented storage engine.",
text_vec=[7, 8, 9],
),
],
),
]
)
assert len(documents) == 3
for doc in documents:
assert doc.id is not None
assert doc.created_at is not None
assert len(doc.chunks) == 1
chunk = doc.chunks[0]
assert chunk.id is not None
assert chunk.document_id == doc.id
assert chunk.text == doc.content
assert chunk.text_vec is not None
assert len(chunk.text_vec) == 3
# Retrieve - Vector Search
results = doc_store.search([4, 5, 6], top_k=2)
assert len(results.documents) == 2
assert results.documents[0].name == "TiKV"
assert results.chunks[0].score > 0
# Update
document_id = results.chunks[0].document_id
old_chunk = results.chunks[0]
old_vector_sha = sha256(str(old_chunk.text_vec))
new_chunk = doc_store.update_chunk(
old_chunk.id,
{
"text": "TiKV is a distributed key-value storage engine for TiDB.",
"text_vec": [3, 6, 9],
},
)
new_vector_sha = sha256(str(new_chunk.text_vec))
assert new_vector_sha != old_vector_sha
# Delete
doc_store.delete_chunk(new_chunk.id)
chunks = doc_store.list_doc_chunks(document_id)
assert len(chunks) == 0
def test_crud_with_auto_embed(doc_store_with_auto_embed):
doc_store_with_auto_embed.reset()
# Create
documents = doc_store_with_auto_embed.add(
[
Document(
name="TiDB",
content="TiDB is a distributed SQL database.",
chunks=[
Chunk(text="TiDB is a distributed SQL database."),
],
),
Document(
name="TiKV",
content="TiKV is a distributed key-value storage engine.",
chunks=[
Chunk(text="TiKV is a distributed key-value storage engine."),
],
),
Document(
name="TiFlash",
content="TiFlash is a column-oriented storage engine.",
chunks=[
Chunk(text="TiFlash is a column-oriented storage engine."),
],
),
]
)
assert len(documents) == 3
for doc in documents:
assert doc.id is not None
assert doc.created_at is not None
assert len(doc.chunks) == 1
chunk = doc.chunks[0]
assert chunk.id is not None
assert chunk.document_id == doc.id
assert chunk.text == doc.content
assert chunk.text_vec is not None
assert len(chunk.text_vec) == 1536
# Retrieve - Vector Search
results = doc_store_with_auto_embed.search("tikv", top_k=2)
assert len(results.documents) == 2
assert results.documents[0].name == "TiKV"
assert results.chunks[0].score > 0
# Update
document_id = results.chunks[0].document_id
old_chunk = results.chunks[0]
old_vector_sha = sha256(str(old_chunk.text_vec))
new_chunk = doc_store_with_auto_embed.update_chunk(
old_chunk.id,
{"text": "TiKV is a distributed key-value storage engine for TiDB."},
)
new_vector_sha = sha256(str(new_chunk.text_vec))
# To check the auto embedding_models is work on updating.
assert new_vector_sha != old_vector_sha
# Delete
doc_store_with_auto_embed.delete_chunk(new_chunk.id)
chunks = doc_store_with_auto_embed.list_doc_chunks(document_id)
assert len(chunks) == 0
================================================
FILE: core/tests/storage/graph_store/__init__.py
================================================
================================================
FILE: core/tests/storage/graph_store/test_tidb_graph_store.py
================================================
from hashlib import sha256
import logging
import pytest
from autoflow.storage.graph_store import TiDBGraphStore
from autoflow.storage.graph_store.types import (
EntityType,
EntityUpdate,
RelationshipUpdate,
)
logger = logging.getLogger(__name__)
@pytest.fixture(scope="session")
def graph_store(tidb_client, embedding_model):
return TiDBGraphStore(
client=tidb_client,
embedding_model=embedding_model,
namespace="tidb_graph_store_test",
)
def test_entity_crud(graph_store: TiDBGraphStore):
graph_store.reset()
# Create entities
tidb_entity = graph_store.create_entity(
name="TiDB", description="TiDB is a relational database."
)
assert tidb_entity.id is not None
logger.info(
"tidb_graph_store: add new entity (name: %s, id: %s)",
tidb_entity.name,
tidb_entity.id,
)
tikv_entity = graph_store.create_entity(
name="TiKV", description="TiKV is a distributed key-value storage engine."
)
assert tikv_entity.id is not None
logger.info(
"tidb_graph_store: add new entity (name: %s, id: %s)",
tikv_entity.name,
tikv_entity.id,
)
# Get entity
entity = graph_store.get_entity(tidb_entity.id)
assert entity.id is not None
assert entity.entity_type == EntityType.original
assert entity.name == "TiDB"
assert entity.embedding is not None
assert entity.created_at is not None
assert entity.updated_at is not None
# List entities
entities = graph_store.list_entities(entity_id=tidb_entity.id)
assert len(entities) == 1
assert entities[0].id == tidb_entity.id
entities = graph_store.list_entities(entity_id=[tidb_entity.id])
assert len(entities) == 1
assert entities[0].id == tidb_entity.id
entities = graph_store.list_entities(entity_type=EntityType.original)
assert len(entities) == 2
# Update entity
old_embedding = tidb_entity.embedding
updated_tidb_entity = graph_store.update_entity(
entity=tidb_entity,
update=EntityUpdate(
name="TiDB", description="TiDB is a MySQL-compatible database."
),
)
new_embedding = updated_tidb_entity.embedding
assert updated_tidb_entity.id == tidb_entity.id
assert updated_tidb_entity.name == "TiDB"
assert updated_tidb_entity.description == "TiDB is a MySQL-compatible database."
assert sha256(new_embedding) != sha256(old_embedding)
# Delete entity
graph_store.delete_entity(tidb_entity.id)
try:
graph_store.get_entity(tidb_entity.id)
raise AssertionError("Entity should be deleted")
except Exception as e:
logger.info(
"tidb_graph_store: entity %s should be deleted: %s", tidb_entity.id, e
)
graph_store.reset()
def test_relationship_crud(graph_store: TiDBGraphStore):
graph_store.reset()
# Create entities
tidb_entity = graph_store.create_entity(
name="TiDB", description="TiDB is a relational database."
)
tikv_entity = graph_store.create_entity(
name="TiKV", description="TiKV is a distributed key-value storage engine."
)
# Create relationships
relationship = graph_store.create_relationship(
source_entity=tidb_entity,
target_entity=tikv_entity,
description="TiDB uses TiKV as its storage engine.",
meta={
"source": "TiDB's Documentation",
},
)
assert relationship.id is not None
assert relationship.source_entity_id == tidb_entity.id
assert relationship.target_entity_id == tikv_entity.id
assert relationship.description == "TiDB uses TiKV as its storage engine."
assert relationship.embedding is not None
assert relationship.created_at is not None
assert relationship.updated_at is not None
# List relationships
relationships = graph_store.list_relationships(entity_id=tidb_entity.id)
assert len(relationships) == 1
assert relationships[0].id == relationship.id
# Update relationship
old_embedding = relationship.embedding
old_updated_at = relationship.updated_at
updated_relationship = graph_store.update_relationship(
relationship=relationship,
update=RelationshipUpdate(
description="TiDB uses TiKV as its storage engine for TP workloads."
),
)
new_embedding = updated_relationship.embedding
new_updated_at = updated_relationship.updated_at
assert (
updated_relationship.description
== "TiDB uses TiKV as its storage engine for TP workloads."
)
assert sha256(new_embedding) != sha256(old_embedding)
assert new_updated_at > old_updated_at
# Delete relationship
graph_store.delete_relationship(relationship.id)
try:
graph_store.get_relationship(relationship.id)
raise AssertionError("Relationship should be deleted")
except Exception as e:
logger.info(
"tidb_graph_store: relationship %s should be deleted: %s",
relationship.id,
e,
)
graph_store.reset()
def test_entity_degree(graph_store: TiDBGraphStore):
graph_store.reset()
# Create entities
tidb_entity = graph_store.create_entity(
name="TiDB", description="TiDB is a relational database."
)
tikv_entity = graph_store.create_entity(
name="TiKV", description="TiKV is a distributed key-value storage engine."
)
tiflash_entity = graph_store.create_entity(
name="TiFlash", description="TiFlash is a column-oriented database engine."
)
# Create relationships
graph_store.create_relationship(
source_entity=tidb_entity,
target_entity=tikv_entity,
description="TiDB uses TiKV as its storage engine.",
)
graph_store.create_relationship(
source_entity=tidb_entity,
target_entity=tiflash_entity,
description="TiDB uses TiFlash as its analytical engine.",
)
# Calculate entity degree
out_degree = graph_store.calc_entity_out_degree(tidb_entity.id)
assert out_degree == 2
in_degree = graph_store.calc_entity_in_degree(tidb_entity.id)
assert in_degree == 0
degree = graph_store.calc_entity_degree(tidb_entity.id)
assert degree == 2
# Calculate entities degree
degrees = graph_store.calc_entities_degrees(
[tidb_entity.id, tikv_entity.id, tiflash_entity.id]
)
assert degrees[tidb_entity.id].out_degree == 2
assert degrees[tidb_entity.id].in_degree == 0
assert degrees[tidb_entity.id].degrees == 2
assert degrees[tikv_entity.id].out_degree == 0
assert degrees[tikv_entity.id].in_degree == 1
assert degrees[tikv_entity.id].degrees == 1
assert degrees[tiflash_entity.id].out_degree == 0
assert degrees[tiflash_entity.id].in_degree == 1
assert degrees[tiflash_entity.id].degrees == 1
graph_store.reset()
================================================
FILE: docker-compose-cn.yml
================================================
name: tidb-ai
services:
redis:
image: registry.cn-beijing.aliyuncs.com/pingcap-ee/redis:6.0.16
restart: always
volumes:
- ./redis-data:/data
command: ["redis-server", "--loglevel", "warning"]
backend:
image: registry.cn-beijing.aliyuncs.com/pingcap-ee/tidb.ai-backend:0.4.0
restart: always
depends_on:
- redis
ports:
- "8000:80"
env_file:
- .env
volumes:
- ./data:/shared/data
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
frontend:
image: registry.cn-beijing.aliyuncs.com/pingcap-ee/tidb.ai-frontend:0.4.0
restart: always
depends_on:
- backend
ports:
- 3000:3000
environment:
BASE_URL: http://backend
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
background:
image: registry.cn-beijing.aliyuncs.com/pingcap-ee/tidb.ai-backend:0.4.0
restart: always
depends_on:
- redis
ports:
- "5555:5555"
env_file:
- .env
volumes:
- ./data:/shared/data
command: /usr/bin/supervisord
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
local-embedding-reranker:
image: registry.cn-beijing.aliyuncs.com/pingcap-ee/tidb.ai-local-embedding-reranker:v4-with-cache
ports:
- 5001:5001
environment:
- HF_ENDPOINT=https://hf-mirror.com
- PRE_LOAD_DEFAULT_EMBEDDING_MODEL=true
# If you want to pre-load the default reranker model, change the following environment to true
- PRE_LOAD_DEFAULT_RERANKER_MODEL=false
- TRANSFORMERS_OFFLINE=1
# volumes:
# - ./local-embedding-reranker:/root/.cache/huggingface
# If you are using NVIDIA GPU, you can uncomment the following lines to enable GPU support
# deploy:
# resources:
# reservations:
# devices:
# - driver: nvidia
# count: 1
# capabilities: [gpu]
profiles:
- local-embedding-reranker
================================================
FILE: docker-compose.dev.yml
================================================
name: tidb-ai-build
services:
backend:
build:
context: backend
dockerfile: Dockerfile
args:
BUILDKIT_INLINE_CACHE: 1
ports:
- "8006:80"
env_file:
- .env
volumes:
- ./data:/shared/data
depends_on:
- redis
frontend:
build:
context: .
dockerfile: ./frontend/Dockerfile
args:
BUILDKIT_INLINE_CACHE: 1
ports:
- "3001:3000"
environment:
BASE_URL: http://backend
depends_on:
- backend
background:
build:
context: backend
dockerfile: Dockerfile
args:
BUILDKIT_INLINE_CACHE: 1
ports:
- "5556:5555"
env_file:
- .env
volumes:
- ./data:/shared/data
command: /usr/bin/supervisord
depends_on:
- redis
local-embedding-reranker:
build:
context: backend/local_embedding_reranker
dockerfile: Dockerfile
args:
BUILDKIT_INLINE_CACHE: 1
ports:
- "5002:5001"
environment:
- PRE_LOAD_DEFAULT_EMBEDDING_MODEL=true
- PRE_LOAD_DEFAULT_RERANKER_MODEL=false
- TRANSFORMERS_OFFLINE=1
profiles:
- local-embedding-reranker
redis:
image: redis:6.0.16
volumes:
- ./redis-data:/data
command: ["redis-server", "--loglevel", "warning"]
================================================
FILE: docker-compose.yml
================================================
name: tidb-ai
services:
redis:
image: redis:6.0.16
restart: always
volumes:
- ./redis-data:/data
command: ["redis-server", "--loglevel", "warning"]
backend:
image: tidbai/backend:0.4.0
restart: always
depends_on:
- redis
ports:
- "8000:80"
env_file:
- .env
volumes:
- ./data:/shared/data
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
frontend:
image: tidbai/frontend:0.4.0
restart: always
depends_on:
- backend
ports:
- 3000:3000
environment:
BASE_URL: http://backend
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
background:
image: tidbai/backend:0.4.0
restart: always
depends_on:
- redis
ports:
- "5555:5555"
env_file:
- .env
volumes:
- ./data:/shared/data
command: /usr/bin/supervisord
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
local-embedding-reranker:
image: tidbai/local-embedding-reranker:v4-with-cache
ports:
- 5001:5001
environment:
- PRE_LOAD_DEFAULT_EMBEDDING_MODEL=true
# If you want to pre-load the default reranker model, change the following environment to true
- PRE_LOAD_DEFAULT_RERANKER_MODEL=false
- TRANSFORMERS_OFFLINE=1
# volumes:
# - ./local-embedding-reranker:/root/.cache/huggingface
# If you are using NVIDIA GPU, you can uncomment the following lines to enable GPU support
# deploy:
# resources:
# reservations:
# devices:
# - driver: nvidia
# count: 1
# capabilities: [gpu]
profiles:
- local-embedding-reranker
================================================
FILE: docs/.gitignore
================================================
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
/node_modules
/.pnp
.pnp.js
.yarn/install-state.gz
# testing
/coverage
# next.js
/.next/
/out/
# production
/build
# misc
.DS_Store
*.pem
# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*
# local env files
.env*.local
# vercel
.vercel
# typescript
*.tsbuildinfo
next-env.d.ts
certificates
public/widget.js
public/robots.txt
public/sitemap.xml
public/sitemap-*.xml
*storybook.log
storybook-static
_pagefind/
================================================
FILE: docs/mdx-components.ts
================================================
import { useMDXComponents as getDocsMDXComponents } from 'nextra-theme-docs'
const docsComponents = getDocsMDXComponents()
export const useMDXComponents = components => ({
...docsComponents,
...components
})
================================================
FILE: docs/next-sitemap.config.js
================================================
module.exports = {
siteUrl: process.env.SITE_URL || 'https://tidb.ai',
generateRobotsTxt: true,// (optional)
// ...other options
}
================================================
FILE: docs/next.config.mjs
================================================
import nextra from 'nextra'
const withNextra = nextra({
latex: true,
defaultShowCopyCode: true,
search: {
codeblocks: true
},
contentDirBasePath: '/',
})
export default withNextra({
// reactStrictMode: true
})
================================================
FILE: docs/package.json
================================================
{
"name": "example-docs",
"license": "MIT",
"private": true,
"scripts": {
"build": "next build",
"dev": "next --turbopack",
"postbuild": "pagefind --site .next/server/app --output-path public/_pagefind",
"start": "next start"
},
"dependencies": {
"next": "15.0.7",
"nextra": "^4.2.17",
"nextra-theme-docs": "^4.2.17",
"react": "19.1.0",
"react-dom": "19.1.0",
"tailwindcss": "^4.1.4"
},
"devDependencies": {
"@types/node": "22.14.1",
"@types/react": "19.1.2",
"pagefind": "^1.3.0"
},
"packageManager": "pnpm@9.15.0+sha512.76e2379760a4328ec4415815bcd6628dee727af3779aaa4c914e3944156c4299921a89f976381ee107d41f12cfa4b66681ca9c718f0668fa0831ed4c6d8ba56c"
}
================================================
FILE: docs/src/app/[[...mdxPath]]/page.jsx
================================================
import { generateStaticParamsFor, importPage } from 'nextra/pages'
import { useMDXComponents as getMDXComponents } from '../../../mdx-components'
export const generateStaticParams = generateStaticParamsFor('mdxPath')
export async function generateMetadata(props) {
const params = await props.params
const { metadata } = await importPage(params.mdxPath)
return metadata
}
const Wrapper = getMDXComponents().wrapper
export default async function Page(props) {
const params = await props.params
const result = await importPage(params.mdxPath)
const { default: MDXContent, toc, metadata } = result
return (
)
}
================================================
FILE: docs/src/app/_app.tsx
================================================
import "./globals.css";
import type { AppProps } from "next/app";
export default function App({ Component, pageProps }: AppProps) {
return ;
}
================================================
FILE: docs/src/app/_ignored/_meta.js
================================================
// This file will be NOT treated as `_meta` file, since directory starts with underscore
export default {}
================================================
FILE: docs/src/app/_ignored/page.mdx
================================================
This file will be NOT treated as page, since directory starts with underscore
================================================
FILE: docs/src/app/_meta.ts
================================================
import type { Meta } from 'nextra'
export default {
docs: {
type: 'page',
title: '📘 Documentation'
}
} satisfies Meta
================================================
FILE: docs/src/app/globals.css
================================================
@tailwind base;
@tailwind components;
@tailwind utilities;
:root {
--foreground-rgb: 0, 0, 0;
--background-start-rgb: 214, 219, 220;
--background-end-rgb: 255, 255, 255;
--logo-color: #212121;
}
@media (prefers-color-scheme: dark) {
:root {
--foreground-rgb: 255, 255, 255;
--background-start-rgb: 0, 0, 0;
--background-end-rgb: 0, 0, 0;
--logo-color: #f5f5f5;
}
}
/* body {
color: rgb(var(--foreground-rgb));
background: linear-gradient(
to bottom,
transparent,
rgb(var(--background-end-rgb))
)
rgb(var(--background-start-rgb));
} */
@layer utilities {
.text-balance {
text-wrap: balance;
}
}
.dark {
.logo {
.logo-bg {
fill: #f5f5f5;
}
.logo-circle {
fill: #212121;
}
}
}
.logo {
.logo-bg {
fill: var(--logo-color);
}
.logo-circle {
fill: #f5f5f5;
}
}
================================================
FILE: docs/src/app/layout.jsx
================================================
import 'nextra-theme-docs/style.css'
import './globals.css';
import { Banner, Head } from 'nextra/components'
/* eslint-env node */
import { Footer, Layout, Navbar } from 'nextra-theme-docs'
import { getPageMap } from 'nextra/page-map'
export const metadata = {
metadataBase: new URL('https://autoflow.ai'),
title: {
template: 'AutoFlow'
},
description: 'Docs & Blogs of AutoFlow',
applicationName: 'AutoFlow',
generator: 'Next.js',
twitter: {
site: 'https://twitter.com/tidb_developer'
}
}
export default async function RootLayout({ children }) {
const navbar = (
AutoFlow
}
logoLink="/"
projectLink="https://github.com/pingcap/autoflow"
>
)
const pageMap = await getPageMap()
return (
{new Date().getFullYear()} © PingCAP . All rights reserved.
}
editLink="Edit this page on GitHub"
docsRepositoryBase="https://github.com/pingcap/autoflow"
sidebar={{ toggleButton: true, defaultMenuCollapseLevel: 1 }}
pageMap={pageMap}
>
{children}
)
}
================================================
FILE: docs/src/content/README.md
================================================
AutoFlow
[](https://hub.docker.com/r/tidbai/backend)
[](https://hub.docker.com/r/tidbai/frontend)
[](https://tidb-ai-playwright.vercel.app/)
> [!WARNING]
> Autoflow is still in the early stages of development. And we are actively working on it, the next move is to make it to a python package and make it more user-friendly e.g. `pip install autoflow-ai`. If you have any questions or suggestions, please feel free to contact us on [Discussion](https://github.com/pingcap/autoflow/discussions).
## Introduction
AutoFlow is an open source graph rag (graphrag: knowledge graph rag) based knowledge base tool built on top of [TiDB Vector](https://www.pingcap.com/ai?utm_source=tidb.ai&utm_medium=community) and [LlamaIndex](https://github.com/run-llama/llama_index) and [DSPy](https://github.com/stanfordnlp/dspy).
- **Live Demo**: [https://tidb.ai](https://tidb.ai?utm_source=tidb.ai&utm_medium=community)
- **Deployment Docs**: [Deployment Docs](https://autoflow.tidb.ai/?utm_source=github&utm_medium=tidb.ai)
## Features
1. **Perplexity-style Conversational Search page**: Our platform features an advanced built-in website crawler, designed to elevate your browsing experience. This crawler effortlessly navigates official and documentation sites, ensuring comprehensive coverage and streamlined search processes through sitemap URL scraping.

2. **Embeddable JavaScript Snippet**: Integrate our conversational search window effortlessly into your website by copying and embedding a simple JavaScript code snippet. This widget, typically placed at the bottom right corner of your site, facilitates instant responses to product-related queries.

## Deploy
- [Deploy with Docker Compose](https://autoflow.tidb.ai/deploy-with-docker) (with: 4 CPU cores and 8GB RAM)
## Tech Stack
- [TiDB](https://www.pingcap.com/ai?utm_source=tidb.ai&utm_medium=community) – Database to store chat history, vector, json, and analytic
- [LlamaIndex](https://www.llamaindex.ai/) - RAG framework
- [DSPy](https://github.com/stanfordnlp/dspy) - The framework for programming—not prompting—foundation models
- [Next.js](https://nextjs.org/) – Framework
- [Tailwind CSS](https://tailwindcss.com/) – CSS framework
- [shadcn/ui](https://ui.shadcn.com/) - Design
## Contributing
We welcome contributions from the community. If you are interested in contributing to the project, please read the [Contributing Guidelines](/CONTRIBUTING.md).
## License
AutoFlow is open-source under the Apache License, Version 2.0. You can [find it here](https://github.com/pingcap/autoflow/blob/main/LICENSE.txt).
## Contact
You can reach out to us on [Discord](https://discord.gg/XzSW23Jg9p).
================================================
FILE: docs/src/content/_meta.ts
================================================
import type { Meta } from 'nextra';
export default {
"-- Getting Started": {
"type": "separator",
"title": "Getting Started"
},
"index": "Introduction",
"quick-start": "Quick Start",
"resources": "Resources",
"faq": "FAQ",
"-- Deploy This RAG": {
"type": "separator",
"title": "Deployment"
},
"requirements": "Requirements",
"deploy-with-docker": "Deploy with Docker",
"-- Configuration": {
"type": "separator",
"title": "Configuration"
},
"llm": {
"title": "LLM - Large Language Model"
},
"embedding-model": {
"title": "Embedding Model"
},
"reranker-model": {
"title": "Reranker Model"
},
"knowledge-base": {
"title": "Knowledge Base"
},
"chat-engine": {
"title": "Chat Engine"
},
"evaluation": {
"title": "Evaluation"
},
"javascript": {
"title": "JS Widget"
},
"README": {
"display": "hidden"
},
"--": {
"type": "separator"
},
"releases": "Releases"
} satisfies Meta
================================================
FILE: docs/src/content/chat-engine.mdx
================================================
# Chat Engine
Chat Engine is a set of configurations that allow you to customize the chat experience for your users. You can configure the chat engine to use different models for generating responses, reranking results, and more.
## Configure Chat Engine
After logging in with an admin account, you can configure the Chat Engine in the admin panel.
1. Click on the `Chat Engines` tab;
2. Click on the `New Chat Engine` button to create a new chat engine;

3. In the `Retrieval` section, you can configure [knowledge base](./knowledge-base.mdx) as the knowledge source and related retrieval parameters.

4. You can also change the prompt to customize the chat experience for your users. The prompt is the message that the chatbot sends to the user to start the conversation.

5. Click the `Create Chat Engine` button to finish the configuration.
6. Now you can chat with LLM through the chat engine in the chat interface.

================================================
FILE: docs/src/content/deploy-with-docker.mdx
================================================
# Deploy with Docker Compose
import { Callout } from 'nextra/components'
This document provides instructions for deploying the entire Autoflow application using Docker Compose.
## Prerequisites
- Set up a TiDB cluster, you can use either:
- [TiDB Cloud Serverless](https://docs.pingcap.com/tidbcloud/tidb-cloud-quickstart)(recommended)
- [TiDB Self-Managed](https://docs.pingcap.com/tidb/stable/production-deployment-using-tiup)(>=v8.4).
- Install [Docker Compose](https://docs.docker.com/compose/install/).
## Deploy
1. Clone the repository:
```bash
git clone https://github.com/pingcap/autoflow.git;
cd autoflow/;
```
2. Copy and edit the `.env` file:
```bash
cp .env.example .env
vim .env # or use another text editor to edit this file
```
Replace the following placeholders with your own values:
- `SECRET_KEY`: you can generate a random secret key using:
```python
`python3 -c "import secrets; print(secrets.token_urlsafe(32))"`
```
- `TIDB_HOST`, `TIDB_USER`, `TIDB_PASSWORD` and `TIDB_DATABASE`: you can get these values from the TiDB cluster you set up before.
- Note: if you are using a self-managed TiDB cluster, you need to set `TIDB_SSL=false`.
- `EMBEDDING_MAX_TOKENS`: set them according to the embedding model you choose before, it can not be changed after the deployment.(We will remove it in the future, and move it to the admin panel)
3. Migrate the database schema:
```bash
docker compose run backend /bin/sh -c "alembic upgrade head"
```
4. Bootstrap the database with initial data:
```bash
# Use default admin credentials (admin@example.com with random password)
docker compose run backend /bin/sh -c "python bootstrap.py"
# Or specify a custom admin email
docker compose run backend /bin/sh -c "python bootstrap.py --email new-admin@example.com"
```
Running the bootstrap script creates an admin user. You can find the username and password in the output.
```bash
# Reset admin password (random generated)
docker compose run backend /bin/sh -c "python bootstrap.py -r"
# Or specify a new password
docker compose run backend /bin/sh -c "python bootstrap.py -r --password "
```
5. Start the services:
If you are using a SaaS embedding model, start the services with the following command:
```bash
docker compose up
```
If you want to use the built-in local embedding reranker, start the services with the following command:
```bash
docker compose --profile local-embedding-reranker up
```
6. Done! Now you can open your browser and visit [http://localhost:3000](http://localhost:3000) locally
## Configuration
After you deploy the application, you need to initialize the application by following [quick start guide](./quick-start.mdx).
1. Set up the default [LLM model](./llm.mdx) in the **Models > LLMs** page.
2. Set up the default [Embedding model](./embedding-model.mdx) in the **Models > Embedding Models** page.
3. Add a new [Knowledge Base](./knowledge-base.mdx) in the **Knowledge Bases** page.
4. Configure default [Chat Engine](./chat-engine.mdx) and set up the new knowledge base as the retrieval database.
## Upgrade
This section will help you upgrade pingcap/autoflow to the new version.
Suppose you want to upgrade pingcap/autoflow from 0.3.0 to 0.3.1. Follow these steps:
1. Edit your docker-compose.yml file to use the new image version.
```yaml
services:
backend:
image: tidbai/backend:0.3.1
frontend:
image: tidbai/frontend:0.3.1
background:
image: tidbai/backend:0.3.1
```
2. Pull the new image:
```bash
docker compose pull
```
3. Migrate the database schema:
```bash
docker compose run backend /bin/sh -c "alembic upgrade head"
```
4. Recreate the docker containers:
```bash
docker compose up -d --force-recreate
```
5. Check the logs to ensure everything is working correctly:
```bash
docker compose logs -f
```
6. Done!
================================================
FILE: docs/src/content/embedding-model.mdx
================================================
# Embedding Model
import { Callout } from 'nextra/components'
The Embedding Model converts given input into numerical vectors (embeddings) that represent the semantic meaning of the input text.
In Autoflow, we use the Embedding Model to vectorize documents and store them in TiDB. This enables us to leverage TiDB's Vector Search capability to retrieve relevant documents for user queries.
## Configure Embedding Model
After logging in with an admin account, you can configure the Embedding Model in the admin panel.
1. Click on the `Models > Embedding Models` tab;
2. Click the `New Embedding Model` button, select your preferred embedding model provider, and configure the model parameters.

## Supported Providers
Currently Autoflow supports the following embedding model providers:
### OpenAI
OpenAI provides a variety of Embedding Models, we recommend using the OpenAI `text-embedding-3-small` model due to its performance and compatibility with Autoflow.
**Supported Models**:
| Embedding Model | Vector Dimensions | Max Tokens |
| ------------------------ | ----------------- | ---------- |
| `text-embedding-3-small` | 1536 | 8191 |
For more information, see the [OpenAI Embedding Models documentation](https://platform.openai.com/docs/guides/embeddings#embedding-models).
### OpenAI-Like
Autoflow also supports embedding model providers (such as [ZhipuAI](#zhipuai)) that conform to the OpenAI API specification.
You can also use models deployed on local AI model platforms (such as [vLLM](#vllm) and [Xinference](https://inference.readthedocs.io/en/latest/index.html)) that conform to the OpenAI API specification in Autoflow.
To use OpenAI-Like embedding model providers, you need to provide the **base URL** of the embedding API as the following JSON format in **Advanced Settings**:
```json
{
"api_base": "{api_base_url}"
}
```
#### ZhipuAI BigModel
For example, the embedding API endpoint for ZhipuAI is:
`https://open.bigmodel.cn/api/paas/v4/embeddings`
You need to set up the base URL in the **Advanced Settings** as follows:
```json
{
"api_base": "https://open.bigmodel.cn/api/paas/v4/"
}
```
**Supported Models**:
| Embedding Model | Vector Dimensions | Max Tokens |
| --------------- | ----------------- | ---------- |
| `embedding-3` | 2048 | 8192 |
For more information, see the [ZhipuAI embedding models documentation](https://open.bigmodel.cn/dev/api/vector/embedding-3).
#### vLLM
When serving locally, the default embedding API endpoint for vLLM is:
`http://localhost:8000/v1/embeddings`
You need to set up the base URL in the **Advanced Settings** as follows:
```json
{
"api_base": "http://localhost:8000/v1/"
}
```
For more information, see the [vLLM documentation](https://docs.vllm.ai/en/stable/).
### JinaAI
JinaAI provides multimodal multilingual long-context Embedding Models for RAG applications.
**Supported Models**:
| Embedding Model | Vector Dimensions | Max Tokens |
| -------------------- | ----------------- | ---------- |
| `jina-clip-v1` | 768 | 8192 |
| `jina-embeddings-v3` | 1024 | 8192 |
For more information, see the [JinaAI embedding models documentation](https://jina.ai/embeddings/).
### Cohere
Cohere provides industry-leading large language models (LLMs) and RAG capabilities tailored to meet the needs of enterprise use cases that solve real-world problems.
**Supported Models**:
| Embedding Model | Vector Dimensions | Max Tokens |
| ------------------------- | ----------------- | ---------- |
| `embed-multilingual-v3.0` | 1024 | 512 |
For more information, see the [Cohere Embed documentation](https://docs.cohere.com/docs/cohere-embed).
### Amazon Bedrock
Amazon Bedrock is a fully managed foundation models service that provides a range of large language models and embedding models.
**Featured Models**:
| Embedding Model | Vector Dimensions | Max Tokens |
| ------------------------------- | ----------------- | ---------- |
| `amazon.titan-embed-text-v2:0` | 1024 | 8192 |
| `amazon.titan-embed-text-v1` | 1536 | 8192 |
| `amazon.titan-embed-g1-text-02` | 1536 | 8192 |
| `cohere.embed-english-v3` | 1024 | 512 |
| `cohere.embed-multilingual-v3` | 1024 | 512 |
To check all embbeding models supported by Bedrock, go to [Bedrock console](https://console.aws.amazon.com/bedrock).
To use Amazon Bedrock, you'll need to provide a JSON Object of your AWS Credentials, as described in the [AWS CLI config global settings](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-global):
```json
{
"aws_access_key_id": "****",
"aws_secret_access_key": "****",
"aws_region_name": "us-west-2"
}
```
For more information, see the [Amazon Bedrock documentation](https://docs.aws.amazon.com/bedrock/).
### Ollama
Ollama is a lightweight framework for building and running large language models and embedding models locally.
**Supported Models**:
| Embedding Model | Vector Dimensions | Max Tokens |
| ------------------ | ----------------- | ---------- |
| `nomic-embed-text` | 768 | 8192 |
| `bge-m3` | 1024 | 8192 |
To use Ollama, you'll need to configure the API base URL in the **Advanced Settings**:
```json
{
"base_url": "http://localhost:11434"
}
```
For more information, see the [Ollama embedding models documentation](https://ollama.com/blog/embedding-models).
### Gitee AI
Gitee AI is a third-party model provider that offers ready-to-use cutting-edge model APIs for AI developers.
**Supported Models**:
| Embedding Model | Vector Dimensions | Max Tokens |
| ------------------- | ----------------- | ---------- |
| `bge-m3` | 1024 | 8192 |
| `bge-large-zh-v1.5` | 1024 | 512 |
| `bge-small-zh-v1.5` | 512 | 512 |
For more information, see the [Gitee AI embedding models documentation](https://ai.gitee.com/docs/openapi/v1#tag/%E7%89%B9%E5%BE%81%E6%8A%BD%E5%8F%96/POST/embeddings).
### Azure OpenAI
Azure OpenAI is a cloud-based AI service that provides a OpenAI-like API on Azure.
**Supported Models**:
| Embedding Model | Vector Dimensions | Max Tokens |
| ------------------------ | ----------------- | ---------- |
| `text-embedding-3-small` | 1536 | 8191 |
For more information, see:
- [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
- [Create and deploy an Azure OpenAI Service resource](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal)
After creating the Azure OpenAI Service resource, you can configure the API base URL in the **Advanced Settings**:
```json
{
"azure_endpoint": "https://.openai.azure.com/",
"api_version": ""
}
```
You can find those parameters in the [Deployment Tab](https://ai.azure.com/resource/deployments) of your Azure OpenAI Service resource.

### Local Embedding Server
Autoflow's local embedding server is a self-hosted embedding service built upon [sentence-transformers](https://www.sentence-transformers.org/) and deployed on your own infrastructure.
You can choose from a variety of pre-trained models from [Hugging Face](https://huggingface.co/models), such as:
| Embedding Model | Vector Dimensions | Max Tokens |
| --------------- | ----------------- | ---------- |
| `BAAI/bge-m3` | 1024 | 8192 |
To configure the Local Embedding Service, set the API URL in the **Advanced Settings**:
```json
{
"api_url": "http://local-embedding-reranker:5001/api/v1/embedding"
}
```
================================================
FILE: docs/src/content/evaluation.mdx
================================================
# Evaluation (beta)
The **Evaluation** module is an integral part of the Chat Engine of the AutoFlow, designed to assess the performance and reliability of the Chat Engine's outputs.
Currently, the module provides evaluations based on two key metrics:
1. **Factual Correctness**: This metric measures the degree to which the generated responses align with verified facts. It ensures that the Chat Engine delivers accurate and trustworthy information.
2. **Semantic Similarity**: This metric evaluates the closeness in meaning between the generated responses and the expected outputs. It helps gauge the contextual relevance and coherence of the Chat Engine's performance.
With these metrics, the Evaluation component empowers developers and users to analyze and optimize the Chat Engine's capabilities effectively.
## Prerequisites
- An admin account to access the Evaluation panel.
- (Optional) A CSV dataset with at least two columns:
- `query`: i.e. question.
- `reference`: i.e. expected answer.
## How to Evaluate
To evaluate the Chat Engine, follow these steps:
1. Create an evaluation dataset:
1. Click on the **Evaluation** in the left panel, and then click the **Datasets** button.

2. Click on the **New Evaluation Dataset** button.
3. Type in the dataset name, and if you have a CSV file with the required columns, you can upload it to initial the evaluation dataset.

4. Click on the **Create** button.
2. Create an evaluation task:
1. Click on the **Evaluation** in the left panel, and then click the **Tasks** button.
2. Click on the **New Evaluation Task** button.
3. Type in the task name, select the evaluation dataset, select the evaluation targeting Chat Engine, and type in the run size.
> **Note:**
>
> The **Run Size** is a parameter that can cut your dataset into smaller amount to evaluation task.
>
> - For example, your dataset has 1000 rows, and you set the run size to 100, then the evaluation task will only evaluate the first 100 rows.
> - Run size cannot change the evaluation dataset, it only changes the amount of data to evaluation task.

4. Click on the **Create** button.
3. Waiting for the evaluation task to finish, and you can see the evaluation result in the task detail.
1. Click on the **Evaluation** in the left panel, and then click the **Tasks** button.
2. Click on the **Name** of the task you want to see the result.
3. Make your insight from the evaluation result.

================================================
FILE: docs/src/content/faq.mdx
================================================
## FAQ of this RAG app
The FAQ is a collection of useful questions and answers about the
project. If you have a question that isn't answered here, please
[open a discussion](https://github.com/pingcap/autoflow/discussions).
export function FAQBox({ title, children }) {
return (
{title}
{children}
)
}

- [Free] The code of this RAG app pingcap/autoflow is open source, it's free to use;
- [Free] The database of this RAG app is free for **25GiB storage** as we use [TiDB Cloud Serverless](https://www.pingcap.com/tidb-cloud-serverless/?utm_source=tidb.ai&utm_medium=community);
- [Paid] **But** the token of OpenAI is paid, you should have an account of https://platform.openai.com for an API token.
AutoFlow uses LLM to extract knowledge graph from docs and generate responses to user queries, so the LLM is the core of this RAG app, the Smarter LLM, the better the performance.
Currently the online demo(https://tidb.ai) is using OpenAI as chat model by default, and the following LLMs are supported/tested by this RAG app:
- [gpt-4/gpt-4o/gpt-4o-mini/o1/o1-mini](https://platform.openai.com/docs/models)
- [gemini-1.5-pro/gemini-2.0-flash](https://gemini.google.com/)
- [claude-3.5-sonnet](https://claude.ai/)
- [glm-4-plus](https://bigmodel.cn/)
- ...
Yes
- [Change Default Reranker](/docs/reranker-model)
- [Change Default Embedding API](/docs/embedding-model)
- TiDB is a distributed MySQL SQL database, it's easy to scale out and has a good performance;
- TiDB is a multi-model and all-in-one database, it supports MySQL SQL, JSON, Vector, and more, which can replace multiple databases;
- The performance of TiDB is enough to store and query the Knowledge Graph, as we just search the first-level or second-level neighbors of the node.
================================================
FILE: docs/src/content/index.mdx
================================================
import ReadmeContent from "./README.md";
================================================
FILE: docs/src/content/javascript.mdx
================================================
# Embeddable JavaScript Snippet
What if you could integrate a conversational search window into your website to facilitate instant responses to product-related queries? With the embeddable JavaScript snippet, you can do just that. This feature allows you to copy and embed a simple JavaScript code snippet into your website, typically placed at the bottom right corner of your site.
## Pre-requisites
If you want to embed the JavaScript snippet into your other website, you need to enable it in `.env` file.
* Step 1: Enable the JavaScript Widget
```bash
# JS Widgets: if you want to use JS widgets, you need to set the following variables to enable CORS.
BACKEND_CORS_ORIGINS=https://your-domain.com
```
* Step 2: Restart your service
## How to Embed
To embed the JavaScript snippet, follow these steps:
1. Enter Settings > JavaScript, and copy the JavaScript code snippet like this:

2. Paste the JavaScript code snippet into your website's HTML file;
3. Initialize the JavaScript snippet with input parameters like `Widget Title`, `Site Name`, `Search Titles`, `Example Questions`, `Footer Links`, and `Social Media Links`;
4. Restart your website to see the conversational search window appear at the bottom right corner of your site.

That's all you need to do to integrate our conversational search window into your website. Enjoy! You can also go to https://tidb.ai and https://docs.pingcap.com to see the chat window in action.
================================================
FILE: docs/src/content/knowledge-base.mdx
================================================
import { Callout } from 'nextra/components'
# Knowledge Base
Knowledge base is a collection of documents and represents a certain knowledge domain, which can be used by LLM to answer users' questions.
## Configure Knowledge Base
After logging in with an admin account, you can configure the Knowledge Base in the admin panel.
1. Click on the `Knowledge Bases` tab;
2. Click on the `New Knowledge Base` button to add a new knowledge base;

Set up the following fields and click `Create` to create the knowledge base.
| Field | Description | Required | Updatable |
| ---------------- | ------------------------------------------------------------------------------------ | -------- | --------- |
| Name | The name of the knowledge base. | Yes | Yes |
| Description | Describe the knowledge base. | No | Yes |
| LLM | The LLM used during building vector index and knowledge graph index. | Yes | No |
| Embedding Model | the embedding model used by vector index and knowledge graph index. | Yes | No |
| Indexing Methods | Vector index is forced to be enabled. The knowledge graph index is used by GraphRAG. | Yes | No |
3. Click the `Data Sources` tab and add a new data source.
4. Once the data source added, Autoflow will asynchronously import and index documents from the data source, you can check the progress in the `Index Process` tab.
5. Go to `Chat Engine` configuration page, select the knowledge base you created and click `Save` to enable it.

## Data Source Management
Each Knowledge Base could contain multiple data sources.
You can click the **Data Sources** tab to view all data sources.

### Create Data Source
You can click the button on top of data sources list to create new data sources.
Currently, AutoFlow supports three types of data sources:
- File uploads from local:
- Markdown (\*.md)
- PDF (\*.pdf)
- Microsoft Word (\*.docx)
- Microsoft PowerPoint (\*.pptx)
- Microsoft Excel (\*.xlsx)
- Text (\*.txt) files
- Website pages by sitemap.xml
- Certain web pages by URL
### Update Data source
You can rename data source by click the **Configure** button of data source.
### Delete Data source
You can delete data source by click the **Delete** button of data source.
**This operation cannot be undone**.
Deleting data sources will trigger an **asynchronous** deletion process for related **documents**, **chunks of vector index**
and **entities and relationships of knowledge graph index**.
## Documents Management
You can manage documents in the **Documents** tab.

#### Delete Documents
Scroll horizontally to the end of the documents table, click the `...` button and click the `Delete` item on the dropdown menu.
**This operation cannot be undone**.
Deleting documents will trigger an **asynchronous** deletion process for related **chunks of vector index**
and **entities and relationships of knowledge graph index**.
================================================
FILE: docs/src/content/llm.mdx
================================================
# LLM - Large Language Model
In this app, LLM is used for several purposes:
1. Extracting knowledge from docs;
2. Generating responses to user queries.
## Configure LLM
After logging in with an admin account, you can configure the LLM in the admin panel.
1. Click on the `Models > LLMs` tab;
2. Click on the `New LLM` button to add a new LLM;

3. Input your LLM information and click `Create LLM` button;
4. Done!
import { Callout } from 'nextra/components'
If you want to use the new LLM while answering user queries, you need switch to `Chat Engines` tab and set the new LLM as LLM.
## Supported LLM providers
Currently Autoflow supports the following LLM providers:
### Amazon Bedrock
To use Amazon Bedrock, you'll need to provide a JSON Object of your AWS Credentials, as described in the [AWS CLI config global settings](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-global):
```json
{
"aws_access_key_id": "****",
"aws_secret_access_key": "****",
"aws_region_name": "us-west-2"
}
```
To learn more about Amazon Bedrock, please visit [Amazon Bedrock](https://aws.amazon.com/bedrock/).
### Google Gemini
To learn more about Google Gemini, please visit [Google Gemini](https://gemini.google.com/).
### Google Vertex AI
To learn more about Vertex AI, please visit [Vertex AI](https://cloud.google.com/vertex-ai).
### Gitee AI
Follow the UI to configure the Gitee AI provider. To learn more about Gitee AI, please visit [Gitee AI](https://ai.gitee.com/serverless-api).
### OpenAI
To learn more about OpenAI, please visit [OpenAI](https://platform.openai.com/).
### OpenAI-Like
Autoflow also support the providers that conform to the OpenAI API specification.
To use OpenAI-Like LLM providers, you need to provide the **api_base** of the LLM API as the following JSON format in **Advanced Settings**:
```json
{
"api_base": "{api_base_url}"
}
```
#### OpenRouter
Default config:
```json
{
"api_base": "https://openrouter.ai/api/v1/"
}
```
To learn more about OpenRouter, please visit [OpenRouter](https://openrouter.ai/).
{/*
#### ZhipuAI BigModel
Default config:
```json
{
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"is_chat_model": true
}
```
To learn more about BigModel, please visit [BigModel](https://open.bigmodel.cn/).
*/}
#### Ollama
Default config:
```json
{
"api_base": "http://localhost:11434"
}
```
To learn more about Ollama, please visit [Ollama](https://ollama.com/).
#### vLLM
Default config:
```json
{
"api_base": "http://localhost:8000/v1/"
}
```
To learn more about vLLM, please visit [vLLM](https://docs.vllm.ai/en/stable/).
#### Xinference
If you assigned a model uid different from the model name, you need to fill in **model uid** in the box `model`.
Default config:
```json
{
"api_base": "http://localhost:9997/v1/"
}
```
To learn more about Xinference, please visit [Xinference](https://inference.readthedocs.io/en/latest/).
#### Azure OpenAI
To learn more about Azure OpenAI, please visit:
- [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
- [Create and deploy an Azure OpenAI Service resource](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal)
After creating the Azure OpenAI Service resource, you can configure the API base URL in the **Advanced Settings**:
```json
{
"azure_endpoint": "https://.openai.azure.com/",
"api_version": "",
"engine": ""
}
```
You can find those parameters in the [Deployment Tab](https://ai.azure.com/resource/deployments) of your Azure OpenAI Service resource.
Do not mix `Model version` and `API version` up, they are different.

#### Novita AI
Default config:
```json
{
"api_base": "https://api.novita.ai/v3/openai"
}
```
To learn more about Novita AI, please visit [Novita AI](https://novita.ai/).
#### DeepSeek
DeepSeek provides chat model `deepseek-chat`.
Default config:
```json
{
"api_base": "https://api.deepseek.com/v1",
"is_chat_model": true
}
```
To learn more about DeepSeek, please visit [DeepSeek](https://www.deepseek.com/).
================================================
FILE: docs/src/content/quick-start.mdx
================================================
# Quick Start
Here is a documentation page that shows how to setup a same tool like https://tidb.ai from deployment to usage.
## Step 1: Deployment
You can deploy self-hosted Autoflow on your server with Docker Compose.
[Read the deployment guide](./deploy-with-docker)
## Step 2: Configure
After deployment, you need to login to the admin dashboard to configure the tool withyour own settings.
### Configure the LLM - Large Language Model
Go to the **Models > LLMs** page to [configure the LLM model](./llm).
> The LLM is used for extracting knowledge from docs and generating responses. You can change the default LLM to another one.

### Configure the Embedding Model
Go to the **Models > Embedding Models** page to [configure the embedding model](./embedding-model).
> The Embedding Model is a machine learning model that is trained to generate embeddings for a given input. We must translate text to vectors with this model before insert vector to database.

### Configure the Reranker [Optional]
> The Reranker is an essential tool that optimizes the order of results from initial searches. It is optional but recommended.
Go to the **Models > Rerankers** page to configure [the reranker model](./reranker-model).

## Step 3: Add a New Knowledge Base and Upload Documents
Go to the **Knowledge Base** page to add a new knowledge base and upload documents.

After adding a new knowledge base, you can upload your documents from local or crawl from the web in the **Data Source** subpage.

> After adding data source, there will be a period of time for indexing the data.
For more details, please refer to [Knowledge Base](./knowledge-base) documentation.
## Step 4: Set up the Chat Engine
Go to the **Chat Engines** page to [set up the chat engine](./chat-engine).
> The chat engine is used to chat with users.

## Step 5: Usage
After deployment, configuration and uploading documents, you can use the tool to chat with users to answer their questions.
pingcap/autoflow provides several features to help you chat with users:
1. Out-of-the-box chat interface, e.g. https://tidb.ai
2. API to chat with users programmatically, e.g. https://tidb.ai/api-docs
3. Embeddable chat widget to integrate with your website
================================================
FILE: docs/src/content/releases/_meta.ts
================================================
import type { Meta } from 'nextra';
export default {
"v0.4.0": "v0.4.0",
"v0.3.0": "v0.3.0",
"v0.2.0": "v0.2.0",
"v0.1.0": "v0.1.0",
} satisfies Meta
================================================
FILE: docs/src/content/releases/index.mdx
================================================
---
title: Releases
asIndexPage: true
---
# Releases
import { Cards } from 'nextra/components'
================================================
FILE: docs/src/content/releases/v0.1.0.md
================================================
# Release Notes for v0.1.0
## Overview
This is the first release of TiDB.AI.
## New Features
- LLM Model Support
- OpenAI
- Gemini
- OpenAI-Like
- openrouter
- zhipuai
- ...
- AnthropicVertex
- Bedrock
- Rerank Model Support
- Jina
- Cohere
- Embedding Model Support
- OpenAI
- Data Source Support
- Website
- Single URL
- Sitemap
- Local file
- .pdf
- .md
- .txt
## Known Issues
For a list of known issues, please visit our [GitHub Issues page](https://github.com/pingcap/autoflow/issues). If you encounter any problems, we encourage you to report them.
================================================
FILE: docs/src/content/releases/v0.2.0.md
================================================
# Release Notes for v0.2.0
## Improvements
- Added an end-to-end test flow after releasing new versions to enhance system robustness.
- Optimized the UI during answer retrieval to provide better feedback on current processes.
- Support for connecting to Self-Managed TiDB clusters that also have vector search capabilities.
- Integrated support for the BaiSheng reranker.
## Bug Fixes
- Fixed an issue when creating synopsis entity nodes.
- Fix the delete button in LLM/Reranker list page.
- Fixed a SiteSetting cache issue that prevented proper synchronization across processes.
If you are deploying tidb.ai using docker, please refer to this [document](https://autoflow.tidb.ai/deploy-with-docker#upgrade) for upgrading your tidb.ai.
================================================
FILE: docs/src/content/releases/v0.3.0.md
================================================
# Release Notes for v0.3.0
## Highlights
- Rename project to `autoflow`
- Multiple Knowledge Bases support
- Support new LLM providers
- [OpenRouter](../llm.mdx#openrouter)
- [ZhipuAI BigModel](../llm.mdx#zhipuai-bigmodel)
- [Ollama](../llm.mdx#ollama)
- Support new embedding models providers
- [Ollama](../embedding-model.mdx#ollama)
- Support [OpenAI Like](../embedding-model.mdx#openai-like) embedding model providers
- [ZhipuAI BigModel](../embedding-model.mdx#zhipuai-bigmodel)
## Breaking Changes
- [Knowledge Base](/docs/knowledge-base)
- Deprecate old datasource management. [Manual migration](#manual-migration) is required.
- Support different Embedding Models for each KB
- Individual vector index and knowledge graph index for each KB
- Move vector search and graph search API to knowledge base level
- Move index progress and error retries to knowledge base
## Improvements
- Support create chat engine from default config
## Manual Migration
### Migrating from v0.2.x to v0.3.0
In previous versions, there was no concept of Knowledge Base, and Autoflow stored index data in the following tables:
- `chunks`
- `entities`
- `relationships`
In the v0.3.0, index data will be stored separately in tables corresponding to each knowledge base:
- `chunks_{knowledge_base_id}`
- `entities_{knowledge_base_id}`
- `relationships_{knowledge_base_id}`
After upgrading to v0.3.0, the data in the original tables will be **preserved**. If you want to migrate the old data to the new knowledge base, you can follow the steps below:
1. Follow the [Upgrade](../deploy-with-docker.mdx#upgrade) guide to upgrade the application to v0.3.0.
2. Create a new knowledge base in the admin panel.
3. Connect to your database using TiDB Serverless Web Console or using mysql client.
4. Obtain the new knowledge base ID:
```sql
SELECT id, name FROM knowledge_bases;
```
5. Replace the `{knowledge_base_id}` in the following SQL scripts with the new knowledge base ID and execute them:
```sql
BEGIN;
INSERT INTO knowledge_base_datasources (knowledge_base_id, data_source_id)
SELECT {knowledge_base_id}, id
FROM data_sources
WHERE id NOT IN (SELECT data_source_id FROM knowledge_base_datasources);
UPDATE documents SET knowledge_base_id = {knowledge_base_id} WHERE knowledge_base_id IS NULL;
INSERT INTO chunks_{knowledge_base_id} (id, hash, text, meta, embedding, document_id, relations, source_uri, index_status, index_result, created_at, updated_at)
SELECT id, hash, text, meta, embedding, document_id, relations, source_uri, index_status, index_result, created_at, updated_at
FROM chunks;
INSERT INTO entities_{knowledge_base_id} (id, name, description, meta, entity_type, synopsis_info, description_vec, meta_vec)
SELECT id, name, description, meta, entity_type, synopsis_info, description_vec, meta_vec
FROM entities;
INSERT INTO relationships_{knowledge_base_id} (id, description, meta, weight, source_entity_id, target_entity_id, last_modified_at, document_id, chunk_id, description_vec)
SELECT id, description, meta, weight, source_entity_id, target_entity_id, last_modified_at, document_id, chunk_id, description_vec
FROM relationships;
COMMIT;
```
================================================
FILE: docs/src/content/releases/v0.4.0.md
================================================
# Release Notes for v0.4.0
## Highlights
- Support [Evaluation (beta)](../evaluation.mdx) tool to evaluate the performance and reliability of the Chat Engine’s outputs.
- Current support key metrics:
- Factual Correctness
- Semantic Similarity
- Support new LLM providers
- [Gitee AI](../llm.mdx#gitee-ai)
- Test new OpenAI-like providers
- [vLLM](../llm.mdx#vllm)
- [Xinference](../llm.mdx#xinference)
- Support new embedding model providers
- [Gitee AI](../embedding-model.mdx#gitee-ai)
- [Amazon Bedrock](../embedding-model.mdx#amazon-bedrock)
## Improvements
- Limit the upload file size via `max_upload_file_size` (10MB by default) parameter on site setting
- Support download the reference file in the chat page
If you are deploying Autoflow using docker, please follow the [Upgrade](../deploy-with-docker.mdx#upgrade) guide to upgrade your Autoflow.
================================================
FILE: docs/src/content/requirements.mdx
================================================
# Deployment Requirements
In this section, we will cover the requirements for deploying the project.
## LLM(Large Language Model) and Embedding Model
* A saas LLM model like OpenAI API or self-hosted LLM model with requirements:
* Smarter than GPT-3.5
* Provide openai-like API
* Embedding model: AutoFlow needs an embedding model to translate the text into vectors. You can use the following:
* OpenAI-like embedding model
* Cohere embedding model
* ZhipuAI embedding model
* You can also use the [Jina AI API](https://jina.ai/) for this purpose. It is free for 1M tokens.
* (Optional) Reranker. You can use the [Jina AI API](https://jina.ai/) for this purpose. It is free for 1M tokens.
## TiDB
* With [TiDB Serverless](https://pingcap.com/ai) account, you can setup a TiDB cluster with Vector Search enabled. Free quota is available for 1M RU per month.
* You can also use a self-hosted TiDB cluster(>v8.4) with Vector Search enabled, but please note it will require TiFlash enabled for Vector Search.
## Hardware
### If you are using a Cloud TiDB and SaaS LLM
You can use any of the following web hosting services to deploy the project:
* Cloud server providers like [AWS](https://aws.amazon.com/), [Google Cloud](https://cloud.google.com/), [Azure](https://azure.microsoft.com/), etc.
* Or your own server.
We suggest the following configuration for the server:
| Name | Value |
|----------------------|------------------|
| CPU | 4 vCPUs |
| Memory | 8 GB RAM |
| Disk | 200 GB SSD |
| Number of servers | 1 |
### If you are using a self-hosted TiDB and self-hosted LLM
If you use a self-hosted TiDB and self-hosted LLM, you need a powerful server to handle the load. We suggest the following configuration for the server:
| Name | Value |
|----------------------|------------------|
| CPU | 32 vCPUs |
| Memory | 64 GB RAM |
| Disk | 500 GB SSD |
| GPU | 1 x NVIDIA A100 |
| Number of servers | 1 |
GPU here is used for the LLM model, you can use any other GPU model that can be used for the LLM model which has capability more than gpt-3.5.
================================================
FILE: docs/src/content/reranker-model.mdx
================================================
# Configure Reranker Model
In a RAG system, a reranker is a model that re-ranks the retrieved documents based on the user's query. The reranker model is used to improve the relevance of the retrieved documents.
## Configure Reranker Model
After logging in with an admin account, you can configure the Reranker Model in the admin panel.
1. Click on the `Models > Reranker Models` tab;
2. Click on the `New Reranker Model` button to add a new Reranker;

3. Input your Reranker information and click `Create Reranker` button;
## Supported Reranker Providers
Currently AutoFlow supports the following reranker providers:
### JinaAI
To learn more about JinaAI reranking, please visit [Jina AI Reranker](https://jina.ai/reranker/).
### Cohere
To learn more about Cohere reranking, please visit [Cohere Rerank](https://cohere.com/rerank/).
### vLLM
To use vLLM rerankers, you need to provide the **base_url** of the reranker API as the following JSON format in **Advanced Settings**:
```json
{
"base_url": "{api_base_url}"
}
```
Default config:
```json
{
"base_url": "http://localhost:8000"
}
```
To learn more about vLLM reranking, please visit [vLLM Sentence Pair Scoring Models](https://docs.vllm.ai/en/latest/models/supported_models.html#sentence-pair-scoring-task-score).
### Xorbits Inference (Xinference)
To use Xinference rerankers, you need to provide the **base_url** of the reranker API as the following JSON format in **Advanced Settings**:
```json
{
"base_url": "{api_base_url}"
}
```
Default config:
```json
{
"base_url": "http://localhost:9997"
}
```
To learn more about Xinference reranking, please visit [Xinference Rerank](https://inference.readthedocs.io/en/latest/models/model_abilities/rerank.html).
### Amazon Bedrock
To use Amazon Bedrock rerankers, you'll need to provide a JSON Object of your AWS Credentials, as described in the [AWS CLI config global settings](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-global):
```json
{
"aws_access_key_id": "****",
"aws_secret_access_key": "****",
"aws_region_name": "us-west-2"
}
```
To find more reranking models supported by Amazon Bedrock, please visit [Amazon Bedrock Models Reference](https://docs.aws.amazon.com/bedrock/latest/userguide/foundation-models-reference.html).
================================================
FILE: docs/src/content/resources.mdx
================================================
# Resources
Here are some videos tutorials help you to understand the basics of the platform and how to use it.
## Presentation: Intro to TiDB (Graph) RAG
> Agenda:
> - Intro to TiDB RAG / Graph RAG
> - How it works?
> - Vector RAG + Graph RAG + Fulltext Search(WIP)
> - Plan-able RAG engine - ngaut/StackVM
> - Benefits of using TiDB for AI app storage
> - How to copy one?
Google Slides: [Intro to TiDB (Graph) RAG](https://docs.google.com/presentation/d/1uK9oOW3QGHiNoqYUpFRqHPN4J_FvB2qeDsmAFMKASaY/edit?usp=sharing)
## Video: Intro to PingCAP/AutoFlow Post-Verification and Graph Search
> This video will show you how to use the platform to search(from graph) for information and verify the results.
> We introduce a new feature called Post-Verification which allows you to verify the results and correct them if needed.
VIDEO
================================================
FILE: docs/tsconfig.json
================================================
{
"compilerOptions": {
"target": "ES2017",
"lib": [
"dom",
"dom.iterable",
"esnext"
],
"allowJs": true,
"skipLibCheck": true,
"strict": false,
"noEmit": true,
"incremental": true,
"module": "esnext",
"esModuleInterop": true,
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
"plugins": [
{
"name": "next"
}
]
},
"include": [
"next-env.d.ts",
".next/types/**/*.ts",
"**/*.ts",
"**/*.tsx",
"**.ts",
],
"exclude": [
"node_modules"
]
}
================================================
FILE: e2e/.gitignore
================================================
node_modules/
/test-results/
/playwright-report/
/blob-report/
/playwright/.cache/
.idea
.env.local
local-test.sh
# Internal output
.credentials
*.stdout
*.stderr
screenshots
.vercel
================================================
FILE: e2e/README.md
================================================
# Integration Test for TIDB.ai
================================================
FILE: e2e/deploy-test-result.sh
================================================
#!/bin/bash
set -e
if [[ "${VERCEL_TOKEN}" && "${VERCEL_ORG_ID}" && "${VERCEL_PROJECT_ID}" ]]; then
npx vercel deploy --yes --token "${VERCEL_TOKEN}" "${VERCEL_CLI_ARGS}"
else
echo ""
echo "> Not configured"
echo ""
exit 1
fi
================================================
FILE: e2e/docker-compose.yml
================================================
name: tidb-ai-integration-test
services:
redis:
image: redis:6.0.16
restart: always
pull_policy: always
volumes:
- ${E2E_DATA_REDIS_DIR:-.data/redis}:/data
backend:
image: tidbai/backend:${E2E_DOCKER_TAG_BACKEND}
platform: ${E2E_DOCKER_PLATFORM}
restart: always
depends_on:
- redis
ports:
- "5001:80"
env_file:
- .env.backend
volumes:
- ${E2E_DATA_STORAGE_DIR:-.data/storage}:/shared/data
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
extra_hosts:
host.docker.internal: host-gateway
frontend:
image: tidbai/frontend:${E2E_DOCKER_TAG_FRONTEND}
platform: ${E2E_DOCKER_PLATFORM}
restart: always
depends_on:
- backend
ports:
- "3000:3000"
environment:
BASE_URL: http://backend
env_file:
- .env.frontend
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
background:
image: tidbai/backend:${E2E_DOCKER_TAG_BACKEND}
platform: ${E2E_DOCKER_PLATFORM}
restart: always
depends_on:
- redis
ports:
- "5555:5555"
env_file:
- .env.backend
command: /usr/bin/supervisord
volumes:
- ${E2E_DATA_STORAGE_DIR:-.data/storage}:/shared/data
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
extra_hosts:
host.docker.internal: host-gateway
static-web-server:
image: joseluisq/static-web-server:2
restart: always
ports:
- "4001:80"
volumes:
- ./test-html:/public
================================================
FILE: e2e/global.setup.ts
================================================
import { config } from 'dotenv';
export default function () {
config({
path: '.credentials',
});
}
================================================
FILE: e2e/package.json
================================================
{
"name": "tidb.ai-playwright",
"version": "1.0.0",
"main": "index.js",
"keywords": [],
"author": "",
"license": "ISC",
"description": "",
"devDependencies": {
"@playwright/test": "^1.46.0",
"@types/node": "^22.2.0"
},
"dependencies": {
"dotenv": "^16.4.5"
}
}
================================================
FILE: e2e/playwright.config.ts
================================================
import { defineConfig, devices } from '@playwright/test';
/**
* Read environment variables from file.
* https://github.com/motdotla/dotenv
*/
// import dotenv from 'dotenv';
// dotenv.config({ path: path.resolve(__dirname, '.env') });
/**
* See https://playwright.dev/docs/test-configuration.
*/
export default defineConfig({
testDir: './tests',
/* Run tests in files in parallel */
fullyParallel: true,
/* Fail the build on CI if you accidentally left test.only in the source code. */
forbidOnly: !!process.env.CI,
/* Retry on CI only */
retries: 0,
/* Opt out of parallel tests on CI. */
workers: process.env.CI ? 1 : undefined,
/* Reporter to use. See https://playwright.dev/docs/test-reporters */
reporter: 'html',
globalSetup: 'global.setup.ts',
/* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */
use: {
/* Base URL to use in actions like `await page.goto('/')`. */
baseURL: 'http://127.0.0.1:3000',
/* Collect trace when retrying the failed test. See https://playwright.dev/docs/trace-viewer */
trace: 'on',
screenshot: 'off',
video: 'on',
},
/* Configure projects for major browsers */
projects: [
{
name: 'bootstrap',
use: { ...devices['Desktop Chrome'] },
testMatch: 'tests/bootstrap.ts',
},
{
name: 'chromium',
use: { ...devices['Desktop Chrome'] },
dependencies: ['bootstrap'],
},
// {
// name: 'firefox',
// use: { ...devices['Desktop Firefox'] },
// },
//
// {
// name: 'webkit',
// use: { ...devices['Desktop Safari'] },
// },
/* Test against mobile viewports. */
// {
// name: 'Mobile Chrome',
// use: { ...devices['Pixel 5'] },
// },
// {
// name: 'Mobile Safari',
// use: { ...devices['iPhone 12'] },
// },
/* Test against branded browsers. */
// {
// name: 'Microsoft Edge',
// use: { ...devices['Desktop Edge'], channel: 'msedge' },
// },
// {
// name: 'Google Chrome',
// use: { ...devices['Desktop Chrome'], channel: 'chrome' },
// },
],
/* Run your local dev server before starting the tests */
// webServer: {
// command: 'npm run start',
// url: 'http://127.0.0.1:3000',
// reuseExistingServer: !process.env.CI,
// },
});
================================================
FILE: e2e/prepare-test.sh
================================================
#!/bin/bash
set -e
TIDB_VERSION=v8.5.0
# Download tiup
function prepare_tidb () {
curl --proto '=https' --tlsv1.2 -sSf https://tiup-mirrors.pingcap.com/install.sh | sh
PATH=${PATH}:/home/runner/.tiup/bin
tiup install playground tidb:${TIDB_VERSION} pd:${TIDB_VERSION} tikv:${TIDB_VERSION} tiflash:${TIDB_VERSION}
}
# Pull docker images
function prepare_docker_images() {
docker compose pull frontend background backend redis static-web-server
}
prepare_tidb &
PID1=$!
prepare_docker_images &
PID2=$!
wait $PID1
wait $PID2
================================================
FILE: e2e/res/sample-evaluation-dataset.csv
================================================
"id","query","reference","topic_type"
"42","Example Query","Example Reference","Basic Knowledge"
================================================
FILE: e2e/start-test.sh
================================================
#!/bin/bash
set -e
# Formats
NO_FORMAT="\033[0m"
F_BOLD="\033[1m"
F_UNDERLINED="\033[4m"
C_AQUA="\033[38;5;14m"
TAG="${F_BOLD}${F_UNDERLINED}${C_AQUA}[TiDB.AI Integration Test]${NO_FORMAT}"
TIDB_VERSION=v8.5.0
PATH=${PATH}:/home/runner/.tiup/bin
echo -e "$TAG Creating temp dir"
export E2E_DATA_STORAGE_DIR=$(mktemp -d "${TMPDIR:-/tmp/}"/tidbai-storage.XXXXXXXX | sed 's#//#/#g')
export E2E_DATA_REDIS_DIR=$(mktemp -d "${TMPDIR:-/tmp/}"/tidbai-redis.XXXXXXXX | sed 's#//#/#g')
echo E2E_DOCKER_TAG_FRONTEND: ${E2E_DOCKER_TAG_FRONTEND}
echo E2E_DOCKER_TAG_BACKEND: ${E2E_DOCKER_TAG_BACKEND}
echo E2E_DATA_STORAGE_DIR: ${E2E_DATA_STORAGE_DIR}
echo E2E_DATA_REDIS_DIR: ${E2E_DATA_REDIS_DIR}
echo -e "$TAG Starting TiDB"
CLUSTER_TAG=e2e_$(xxd -l4 -ps /dev/urandom)
TIDB_PID=
# Cleanups
function clean_up {
ARG=$?
echo -e "$TAG Cleaning up..."
# Stop dockers
echo -e "$TAG Shutdown dockers..."
docker compose down frontend background backend redis static-web-server
# Stop tiup playground and cleanup data
echo -e "$TAG Stopping tiup playground cluster..."
echo -e "$TAG Wait until TiDB down..."
kill $TIDB_PID 2>/dev/null || true
while ! tiup playground display ${CLUSTER_TAG} >/dev/null 2>/dev/null
do
sleep 1
done
echo -e "$TAG Cleaning tiup playground data..."
tiup clean ${CLUSTER_TAG}
# Remove temp dirs
echo -e "$TAG Cleaning temp data dirs"
rm -rf ${E2E_DATA_STORAGE_DIR} ${E2E_DATA_REDIS_DIR} 2>/dev/null || true
exit $ARG
}
trap clean_up EXIT
echo -e "$TAG Create tiup playground cluster..."
tiup playground ${TIDB_VERSION} --tag ${CLUSTER_TAG} --without-monitor \
--db 1 --pd 1 --tiflash 1 --kv 1 \
--db.host 0.0.0.0 --pd.host 0.0.0.0 &
TIDB_PID=$!
echo -e "$TAG Wait until TiDB ready..."
while ! tiup playground display ${CLUSTER_TAG} >/dev/null 2>/dev/null
do
sleep 1
done
echo -e "$TAG Execute migrations"
docker compose run --rm backend /bin/sh -c "alembic upgrade head"
echo -e "$TAG Execute bootstrap"
docker compose run --rm backend /bin/sh -c "python bootstrap.py" > bootstrap.stdout
echo -e "$TAG Extract initial username and password"
cat bootstrap.stdout | grep IMPORTANT | sed 's/^.*email: \(.*\) and password: \(.*\)$/USERNAME=\1\nPASSWORD=\2/' > .credentials
cat .credentials
echo -e "$TAG Start components"
docker compose up -d redis frontend backend background static-web-server
echo -e "$TAG Wait until tidb.ai frontend ready..."
while ! curl http://127.0.0.1:3000 > /dev/null 2>/dev/null
do
sleep 1
done
echo -e "$TAG Wait until tidb.ai backend ready..."
while ! curl http://127.0.0.1:5001 > /dev/null 2>/dev/null
do
sleep 1
done
npx playwright test ${PLAYWRIGHT_ARGS}
if [ ! "${CI}" ]; then
npx playwright show-report
fi
================================================
FILE: e2e/test-html/example-doc-1.html
================================================
Example Document 1
This is an example document
Good job, you found this document!
================================================
FILE: e2e/test-html/example-doc-2.html
================================================
Example Document 2
This is an example document
Good job, you found this document!
================================================
FILE: e2e/test-html/example-sitemap.xml
================================================
http://static-web-server/example-doc-1.html
daily
0.9
http://static-web-server/example-doc-2.html
daily
0.9
================================================
FILE: e2e/test-html/widget-controlled.html
================================================
Document
================================================
FILE: e2e/test-html/widget.html
================================================
Document
================================================
FILE: e2e/tests/api-keys.spec.ts
================================================
import { expect, test } from '@playwright/test';
import { loginViaApi } from '../utils/login';
test.use({
permissions: ['clipboard-read'],
});
test('API Keys Operations', async ({ page, baseURL }) => {
await test.step('Login', async () => {
await loginViaApi(page);
});
await test.step('Click Nav', async () => {
await page.goto('/');
await page.getByText('API Keys').click();
await page.waitForURL('/api-keys');
});
const apiKey = await test.step('Create API Key', async () => {
await page.getByText('Create', { exact: true }).click();
const dialog = page.getByRole('dialog', { name: 'Create API Key' });
await dialog.waitFor({ state: 'visible' });
await page.fill('input[name=description]', 'Test API Key');
await page.getByRole('button', { name: 'Create API Key', exact: true }).click();
await dialog.waitFor({ state: 'hidden' });
const apiKey = await page.evaluate('navigator.clipboard.readText()');
expect(apiKey).not.toBeFalsy();
return apiKey;
});
await test.step(`Try fetch with API Key ${apiKey}`, async () => {
const response = await fetch(`${baseURL}/api/v1/api-keys`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
});
if (!response.ok) {
console.error(`${response.status} ${response.statusText}`, await response.text());
}
expect(response.ok).toBe(true);
});
await test.step('Delete API Key', async () => {
const row = page.getByText('Test API Key').locator('..');
await row.getByRole('button', { name: 'Delete' }).click();
const dialog = page.getByRole('alertdialog', { name: 'Are you absolutely sure?' });
await dialog.waitFor({ state: 'visible' });
await dialog.getByRole('button', { name: 'Continue' }).click();
await dialog.getByRole('button', { name: 'Continue' }).waitFor({ state: 'detached' });
await dialog.waitFor({ state: 'hidden' });
await row.waitFor({ state: 'detached' });
});
await test.step(`Try fetch with API Key (expect 401 Unauthorized)`, async () => {
const response = await fetch(`${baseURL}/api/v1/api-keys`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
});
expect(response.status).toBe(401);
});
});
================================================
FILE: e2e/tests/api.spec.ts
================================================
import { expect, test } from '@playwright/test';
import type { APIResponse } from 'playwright-core';
import { loginViaApi } from '../utils/login';
// let key: string = '';
//
// test.beforeAll(async ({ request }) => {
// const response = await request.post('/api/v1/api-leys', {
// data: JSON.stringify({
// description: 'E2E API Tests',
// }),
// });
//
// const { api_key } = await response.json();
// key = api_key;
//
// console.log(`Created API Key, ${key}`);
// });
test.describe('API', () => {
test('Bootstrap Status API', async ({ request }) => {
await expectOk(request.get('/api/v1/system/bootstrap-status'));
});
test('Retrieve Entity or Entites', async ({ request }) => {
async function expectGetOkStep (url: string) {
await test.step(url, async () => {
const response = await request.get(url);
if (!response.ok()) {
console.error(`${response.status()} ${response.statusText()}`, await response.text());
}
await expectOk(response);
});
}
await loginViaApi({ request });
//
// =====
// User
await expectGetOkStep('/api/v1/chats');
await expectGetOkStep('/api/v1/site-config');
await expectGetOkStep('/api/v1/api-keys');
//
// =====
// Admin
await expectGetOkStep('/api/v1/admin/chat-engines');
// Deprecated?
await expectGetOkStep('/api/v1/admin/documents');
await expectGetOkStep('/api/v1/admin/documents?page=1&size=10&language=en');
await expectGetOkStep('/api/v1/admin/documents?page=1&size=10&source_uri=system-variables&created_at_start=2024-07-04T11:58:00&created_at_end=2024-07-04T12:58:00&updated_at_start=2024-08-22T14:36:59&updated_at_end=2024-08-23T14:37:59&last_modified_at_start=2024-07-04T11:57:58&last_modified_at_end=2024-07-05T11:57:59&name=System&mime_type=text/markdown&index_status=failed');
await expectGetOkStep('/api/v1/admin/knowledge_bases');
await expectGetOkStep('/api/v1/admin/knowledge_bases/1/documents');
await expectGetOkStep('/api/v1/admin/knowledge_bases/1/documents?page=1&size=10&language=en');
await expectGetOkStep('/api/v1/admin/knowledge_bases/1/documents?page=1&size=10&source_uri=system-variables&created_at_start=2024-07-04T11:58:00&created_at_end=2024-07-04T12:58:00&updated_at_start=2024-08-22T14:36:59&updated_at_end=2024-08-23T14:37:59&last_modified_at_start=2024-07-04T11:57:58&last_modified_at_end=2024-07-05T11:57:59&name=System&mime_type=text/markdown&index_status=failed');
await expectGetOkStep('/api/v1/admin/site-settings');
await expectGetOkStep('/api/v1/admin/feedbacks');
await expectGetOkStep('/api/v1/admin/llms');
await expectGetOkStep('/api/v1/admin/llms/providers/options');
await expectGetOkStep('/api/v1/admin/embedding-models');
await expectGetOkStep('/api/v1/admin/embedding-models/providers/options');
await expectGetOkStep('/api/v1/admin/reranker-models');
await expectGetOkStep('/api/v1/admin/reranker-models/providers/options');
await expectGetOkStep('/api/v1/admin/retrieve/documents?chat_engine=1&question=what%20is%20tidb&chat_engine=default&top_k=5');
await expectGetOkStep('/api/v1/admin/embedding_retrieve?chat_engine=1&question=what%20is%20tidb&chat_engine=default&top_k=5');
await expectGetOkStep('/api/v1/admin/evaluation/datasets');
await expectGetOkStep('/api/v1/admin/evaluation/tasks');
});
});
async function expectOk (response: APIResponse | Promise) {
expect((await response).ok()).toBe(true);
}
================================================
FILE: e2e/tests/bootstrap.ts
================================================
import { expect, test } from '@playwright/test';
test.use({
trace: !!process.env.CI ? 'off' : 'on',
});
test('Bootstrap', async ({ browser, page }) => {
test.slow();
const {
USERNAME,
PASSWORD,
E2E_LLM_PROVIDER,
E2E_LLM_MODEL,
E2E_LLM_CREDENTIALS,
E2E_EMBEDDING_PROVIDER,
E2E_EMBEDDING_MODEL,
E2E_EMBEDDING_CREDENTIALS,
E2E_RERANKER_PROVIDER,
E2E_RERANKER_MODEL,
E2E_RERANKER_CREDENTIALS,
} = process.env;
await test.step('Visit home page', async () => {
await page.goto('/');
// IMPORTANT: Prevent recording credentials
await page.addStyleTag({
content: `[name=credentials] { filter: blur(1.5rem); }`,
});
await expect(page).toHaveTitle('TiDB.AI');
await expect(page.getByText('Ask anything about TiDB')).toBeVisible();
});
const hasWizardAlert = await page.getByText('This site is not ready to use yet.').isVisible();
if (!hasWizardAlert) {
return;
}
await test.step('Login', async () => {
if (await page.getByRole('link', { name: 'Login', exact: true }).count() === 0) {
console.warn('Already logged in');
return;
}
await page.getByRole('link', { name: 'Login', exact: true }).click();
const usernameInput = await page.waitForSelector('[name=username]');
const passwordInput = await page.waitForSelector('[name=password]');
const loginButton = page.getByRole('button', { name: 'Login', exact: true });
// Fill in credentials
await usernameInput.fill(USERNAME);
await passwordInput.fill(PASSWORD);
// Click login
await loginButton.click();
// Wait for dialog dismiss
await page.getByRole('dialog', { name: 'Sign In' }).waitFor({ state: 'detached' });
// Wait login
await page.getByText(USERNAME).waitFor({ state: 'visible' });
});
await test.step('Open admin side menu', async () => {
const modelTab = page.getByText('Models', { exact: true }).and(page.locator('[data-sidebar="menu-button"]'));
if ((await modelTab.getAttribute('data-state')) !== 'open') {
await modelTab.click();
}
});
async function clickTab (text: string, url: string) {
await test.step(`Goto ${text} page`, async () => {
await page.getByText(text, { exact: true }).and(page.locator('[data-sidebar="menu-sub-button"]').or(page.locator('[data-sidebar="menu-button"]'))).click();
await page.waitForURL(url);
await page.getByText(`New ${text.replace(/s$/, '')}`).waitFor({ state: 'visible' });
});
}
// Setup reranker
await test.step(`Create Default Reranker (${E2E_RERANKER_PROVIDER} ${E2E_RERANKER_MODEL})`, async () => {
await clickTab('Reranker Models', '/reranker-models');
await page.getByText('Loading Data').waitFor({ state: 'detached' });
if (await page.getByText('My Reranker').count() === 0) {
await page.getByText('New Reranker Model').click();
// Fill name
const nameInput = await page.waitForSelector('[name=name]');
await nameInput.fill('My Reranker');
// Select provider
await page.getByLabel('Provider').locator('..').locator('button').click();
await page.getByRole('option').filter({
has: page.getByText(E2E_RERANKER_PROVIDER, { exact: true }),
}).click();
// Fill model if provided
if (E2E_RERANKER_MODEL) {
const modelInput = await page.waitForSelector('[name=model]');
await modelInput.fill(E2E_RERANKER_MODEL);
}
// Fill credentials
if (E2E_RERANKER_CREDENTIALS) {
const credentialsInput = await page.waitForSelector('[name=credentials]');
await credentialsInput.fill(E2E_RERANKER_CREDENTIALS);
}
// Click create button
const createButton = page.getByRole('button', { name: 'Create Reranker' });
await createButton.scrollIntoViewIfNeeded();
await createButton.click();
// Wait for finish by check the url changes
await page.waitForURL(/\/reranker-models\/\d+/);
}
});
await test.step(`Create Default LLM (${E2E_LLM_PROVIDER} ${E2E_LLM_MODEL})`, async () => {
await clickTab('LLMs', '/llms');
await page.getByText('Loading Data').waitFor({ state: 'detached' });
if (await page.getByText('My LLM').count() === 0) {
await page.getByText('New LLM').click();
// Fill name
const nameInput = await page.waitForSelector('[name=name]');
await nameInput.fill('My LLM');
// Select provider
await page.getByLabel('Provider').locator('..').locator('button').click();
await page.getByRole('option').filter({
has: page.getByText(E2E_LLM_PROVIDER, { exact: true }),
}).click();
// Fill model if provided
if (E2E_LLM_MODEL) {
const modelInput = await page.waitForSelector('[name=model]');
await modelInput.fill(E2E_LLM_MODEL);
}
// Fill credentials
const credentialsInput = await page.waitForSelector('[name=credentials]');
await credentialsInput.fill(E2E_LLM_CREDENTIALS);
// Click create button
const createButton = page.getByRole('button', { name: 'Create LLM' });
await createButton.scrollIntoViewIfNeeded();
await createButton.click();
// Wait for finish by check the url changes
await page.waitForURL(/\/llms\/\d+/);
}
});
await test.step(`Create Default Embedding model (${E2E_EMBEDDING_PROVIDER} ${E2E_EMBEDDING_MODEL || 'default'})`, async () => {
await clickTab('Embedding Models', '/embedding-models');
await page.getByText('Loading Data').waitFor({ state: 'detached' });
if (await page.getByText('My Embedding Model').count() === 0) {
await page.getByText('New Embedding Model').click();
// Fill name
const nameInput = await page.waitForSelector('[name=name]');
await nameInput.fill('My Embedding Model');
// Select provider
await page.getByLabel('Provider').locator('..').locator('button').click();
await page.getByRole('option').filter({
has: page.getByText(E2E_EMBEDDING_PROVIDER, { exact: true }),
}).click();
// Fill model if provided
if (E2E_EMBEDDING_MODEL) {
const modelInput = await page.waitForSelector('[name=model]');
await modelInput.fill(E2E_EMBEDDING_MODEL);
}
// Fill credentials
const credentialsInput = await page.waitForSelector('[name=credentials]');
await credentialsInput.fill(E2E_EMBEDDING_CREDENTIALS);
const vectorDimensionInput = await page.waitForSelector('[name=vector_dimension]');
await vectorDimensionInput.fill('1536');
// Click create button
const createButton = page.getByRole('button', { name: 'Create Embedding Model' });
await createButton.scrollIntoViewIfNeeded();
await createButton.click();
// Wait for finish by check the url changes
await page.waitForURL(/\/embedding-models\/\d+/);
}
});
// Create Knowledge Base
await test.step('Create Knowledge Base', async () => {
await clickTab('Knowledge Bases', '/knowledge-bases');
await page.getByText('Loading Data').waitFor({ state: 'detached' });
if (await page.getByText('My Knowledge Base').count() === 0) {
await page.getByText('New Knowledge Base').click();
await page.waitForSelector('[name=name]');
await page.fill('input[name=name]', 'My Knowledge Base');
await page.fill('textarea[name=description]', 'This is E2E Knowledge Base.');
await page.getByRole('button', { name: 'Create', exact: true }).click();
await page.waitForURL(/\/knowledge-bases\/1\/data-sources/);
}
// Create Datasource
await test.step('Create Datasource', async () => {
await page.goto('/knowledge-bases/1/data-sources');
if (await page.getByText('sample.pdf').count() === 0) {
await page.getByRole('button', { name: 'Files' }).click();
const nameInput = await page.waitForSelector('[name=name]');
await nameInput.fill('sample.pdf');
await page.setInputFiles('[name=files]', 'res/sample.pdf');
const createButton = page.getByRole('button', { name: 'Create' });
await createButton.scrollIntoViewIfNeeded();
await createButton.click();
// Jump back to KB data source page
await page.waitForURL(/\/knowledge-bases\/1\/data-sources$/);
}
});
});
// Update default Chat Engine
await test.step('Update Chat Engine', async () => {
await clickTab('Chat Engines', '/chat-engines');
await page.getByText('Loading Data').waitFor({ state: 'detached' });
await page.getByRole('link', { name: 'default' }).click();
await page.getByRole('tab', { name: 'Retrieval' }).click();
await page.getByRole('button', { name: 'Knowledge Bases', exact: true }).click();
await page.getByRole('option').filter({ has: page.getByText('My Knowledge Base') }).click();
await page.click('body');
await page.getByRole('button', { name: 'Save', exact: true }).click();
await page.getByRole('button', { name: 'Save', exact: true }).waitFor({ state: 'detached' });
});
await test.step('Reload and check wizard alert', async () => {
await page.goto('/');
await page.getByText('This site is not ready to use yet.').waitFor({ state: 'detached' });
});
await test.step('Documents count greater than 0', async () => {
await page.goto('/knowledge-bases/1');
await page.getByRole('button', { name: 'sample.pdf' }).waitFor({ state: 'visible' });
});
await test.step('Wait for indexing', async () => {
while (true) {
const response = await page.request.get('/api/v1/admin/knowledge_bases/1/overview');
if (!response.ok()) {
console.warn(`${response.status()} ${response.statusText()}`, await response.text());
} else {
const json = await response.json();
if (json.vector_index.completed > 0) {
break;
}
}
await page.waitForTimeout(500);
}
});
});
================================================
FILE: e2e/tests/chat-engine.spec.ts
================================================
import { expect, type Locator, type Page, test } from '@playwright/test';
import { checkCheckbox, selectOption, turnSwitch } from '../utils/forms';
import { loginViaApi } from '../utils/login';
test.describe('Chat Engine', () => {
test.describe('Configurations', () => {
test('Create with default configuration', async ({ page }) => {
await test.step('Goto page', async () => {
await loginViaApi(page);
await page.goto('/chat-engines');
await page.getByRole('button', { name: 'New Chat Engine' }).click();
await page.waitForURL('/chat-engines/new');
});
const name = 'All default configuration';
await test.step('Fill in fields', async () => {
// Fill in name
await page.getByRole('textbox', { name: 'Name' }).fill(name);
// Goto retrieval tab
await page.getByRole('tab', { name: 'Retrieval' }).click();
// Select default knowledge base
await selectOption(page, 'Knowledge Bases', /My Knowledge Base/, true);
});
const chatEngineId = await test.step('Create', async () => {
await page.getByRole('button', { name: 'Create Chat Engine' }).click();
await page.waitForURL(/\/chat-engines\/\d+$/);
const [_, idString] = /\/chat-engines\/(\d+)$/.exec(page.url());
return parseInt(idString);
});
await test.step('Validate configurations', async () => {
// Validate chat engine configurations
const chatEngine = await getChatEngine(page, chatEngineId);
expect(chatEngine.name).toBe(name);
expect(chatEngine.engine_options).toStrictEqual({
knowledge_base: {
linked_knowledge_bases: [{
id: 1,
}],
},
knowledge_graph: {
enabled: true,
},
hide_sources: false,
clarify_question: false,
further_questions: false,
});
expect(chatEngine.llm_id).toBeNull();
expect(chatEngine.fast_llm_id).toBeNull();
expect(chatEngine.reranker_id).toBeNull();
});
await test.step('Check availability', async () => {
await checkChatEngineAvailability(page, name);
});
});
test('Create with featured configuration', async ({ page }) => {
await test.step('Goto page', async () => {
await loginViaApi(page);
await page.goto('/chat-engines');
await page.getByRole('button', { name: 'New Chat Engine' }).click();
await page.waitForURL('/chat-engines/new');
});
const name = 'Featured configuration';
await test.step('Fill in fields', async () => {
// Fill in name
await page.getByRole('textbox', { name: 'Name' }).fill(name);
// Set LLM & Fast LLM
await selectOption(page, 'LLM', /My LLM/);
await selectOption(page, 'Fast LLM', /My LLM/);
// TODO: Create a Fast LLM in place
// Goto retrieval tab
await page.getByRole('tab', { name: 'Retrieval' }).click();
await selectOption(page, 'Knowledge Bases', /My Knowledge Base/, true);
await checkCheckbox(page, 'Hide Sources');
// Semantic Search Subsection
await selectOption(page, 'Reranker', /My Reranker/);
// Knowledge Graph Subsection
await page.getByRole('spinbutton', { name: 'Depth' }).fill('1'); // Do not use 2 for default value is 2
await checkCheckbox(page, 'Include Metadata');
await checkCheckbox(page, 'Using Intent Search');
// Goto Generation tab
await page.getByRole('tab', { name: 'Generation' }).click();
await turnSwitch(page, 'Clarify Question');
});
const chatEngineId = await test.step('Create', async () => {
await page.getByRole('button', { name: 'Create Chat Engine' }).click();
await page.waitForURL(/\/chat-engines\/\d+$/);
const [_, idString] = /\/chat-engines\/(\d+)$/.exec(page.url());
return parseInt(idString);
});
await test.step('Validate configurations', async () => {
// Validate chat engine configurations
const chatEngine = await getChatEngine(page, chatEngineId);
expect(chatEngine.name).toBe(name);
expect(chatEngine.engine_options).toStrictEqual({
knowledge_base: {
linked_knowledge_bases: [{
id: 1,
}],
},
knowledge_graph: {
enabled: true,
depth: 1,
include_meta: true,
using_intent_search: true,
},
hide_sources: true,
clarify_question: true,
further_questions: false,
});
expect(chatEngine.llm_id).toBe(1);
expect(chatEngine.fast_llm_id).toBe(1);
expect(chatEngine.reranker_id).toBe(1);
});
await test.step('Check availability', async () => {
await checkChatEngineAvailability(page, name);
});
});
test('Update', async ({ page }) => {
await test.step('Goto page', async () => {
await loginViaApi(page);
await page.goto('/chat-engines');
await page.getByRole('button', { name: 'New Chat Engine' }).click();
await page.waitForURL('/chat-engines/new');
const name = 'Chat Engine to be updated';
await test.step('Fill in fields', async () => {
// Fill in name
await page.getByRole('textbox', { name: 'Name' }).fill(name);
// Goto retrieval tab
await page.getByRole('tab', { name: 'Retrieval' }).click();
// Select default knowledge base
await selectOption(page, 'Knowledge Bases', /My Knowledge Base/, true);
});
const chatEngineId = await test.step('Create', async () => {
await page.getByRole('button', { name: 'Create Chat Engine' }).click();
await page.waitForURL(/\/chat-engines\/\d+$/);
const [_, idString] = /\/chat-engines\/(\d+)$/.exec(page.url());
return parseInt(idString);
});
await page.goto('/chat-engines/' + chatEngineId);
await test.step('Update Name', async () => {
await page.getByRole('textbox', { name: 'Name' }).fill('Chat Engine to be updated (updated)');
await waitUpdate(page, page.getByRole('textbox', { name: 'Name', disabled: false }));
expect(await getChatEngine(page, chatEngineId).then(ce => ce.name)).toBe('Chat Engine to be updated (updated)');
});
await test.step('Update LLM', async () => {
await selectOption(page, 'LLM', /My LLM/);
await waitUpdate(page, page.getByRole('button', { name: 'LLM', exact: true, disabled: false }));
expect(await getChatEngine(page, chatEngineId).then(ce => ce.llm_id)).toBe(1);
});
await page.getByRole('tab', { name: 'Retrieval' }).click();
await test.step('Update KG Depth', async () => {
await page.getByRole('spinbutton', { name: 'Depth' }).fill('3');
await waitUpdate(page, page.getByRole('spinbutton', { name: 'Depth', disabled: false }));
expect(await getChatEngine(page, chatEngineId).then(ce => ce.engine_options.knowledge_graph.depth)).toBe(3);
});
// TODO: add cases for rest fields
});
});
});
});
// TODO: The selectors are tricky. Update the select component to simplify the validation.
async function checkChatEngineAvailability (page: Page, name: string) {
await page.locator('[data-sidebar="menu"] li').filter({ hasText: /Chat Engines/ }).getByRole('link').click();
// wait for chat engine table updated.
await page.getByText(name).waitFor();
await page.goto('/');
// Select the 'Select Chat Engine' combobox
const selector = page.getByRole('combobox').and(page.getByText('Select Chat Engine', { exact: true }).locator('..'));
await selector.click();
await page.getByRole('option', { name: name }).click();
// Input question
await page.getByPlaceholder('Input your question here...').fill('Hello');
// Send message
await page.keyboard.press('ControlOrMeta+Enter');
// Wait page url to be changed. When changed, the chat was created correctly.
// Ignore the returned message which is not important.
await page.waitForURL(/\/c\/.+$/);
}
async function getChatEngine (page: Page, id: number) {
const ceResponse = await page.request.get(`/api/v1/admin/chat-engines/${id}`);
expect(ceResponse.ok()).toBe(true);
return await ceResponse.json();
}
async function waitUpdate (page: Page, locator: Locator) {
await page.getByRole('button', { name: 'Save' }).click();
await page.getByRole('button', { name: 'Save' }).waitFor({ state: 'detached' });
await locator.waitFor();
}
================================================
FILE: e2e/tests/chat.spec.ts
================================================
import { expect, test } from '@playwright/test';
import { getChatRequestPromise, QUESTION, testNewChat } from '../utils/chat';
import { loginViaApi } from '../utils/login';
test.describe.serial('Chat', () => {
test('From Home Page', async ({ page, baseURL }) => {
await test.step('Visit home page', async () => {
await page.goto('/');
});
const chatRequest = await test.step('Input text and ask', async () => {
await page.getByPlaceholder('Input your question here...').fill(QUESTION);
// https://playwright.dev/docs/events#waiting-for-event
const chatRequestPromise = getChatRequestPromise(page, baseURL);
const trigger = page.locator('button', { has: page.locator('svg.lucide-arrow-up') });
await trigger.click();
await expect(trigger).toBeDisabled();
return await chatRequestPromise;
});
await testNewChat(page, chatRequest, true, true);
});
test('From Keyboard Shortcut', async ({ page, baseURL }) => {
await test.step('Visit home page', async () => {
await page.goto('/');
});
const chatRequest = await test.step('Input text and ask', async () => {
await page.keyboard.press('ControlOrMeta+k');
await page.keyboard.insertText(QUESTION);
// https://playwright.dev/docs/events#waiting-for-event
const chatRequestPromise = getChatRequestPromise(page, baseURL);
await page.keyboard.press('ControlOrMeta+Enter');
return await chatRequestPromise;
});
await testNewChat(page, chatRequest, true, false);
});
test('Admin Feedback Page', async ({ page }) => {
await loginViaApi(page);
await page.goto('/feedbacks');
expect(await page.getByText('Good Good Good').count()).toBeGreaterThan(0);
expect(await page.getByText('Bad Bad Bad').count()).toBeGreaterThan(0);
});
});
================================================
FILE: e2e/tests/datasource.spec.ts
================================================
import { expect, test } from '@playwright/test';
import { loginViaApi } from '../utils/login';
test.describe('Datasource', () => {
test.fixme('Web Single Page', async ({ page }) => {
test.slow();
await test.step('Login and visit KB page', async () => {
await loginViaApi(page);
await page.goto('/knowledge-bases/1/data-sources');
});
await test.step('Add Single Page Datasource', async () => {
await page.getByRole('button', { name: 'Web Pages' }).click();
await page.waitForURL('/knowledge-bases/1/data-sources/new?type=web_single_page');
await page.getByLabel('Name').fill('example site');
await page.getByRole('button', { name: 'New Item' }).click();
await page.locator('input[name="urls.0"]').fill('https://example.com');
await page.getByRole('button', { name: 'New Item' }).click();
await page.locator('input[name="urls.1"]').fill('https://www.iana.org/help/example-domains');
await page.getByRole('button', { name: 'Create' }).click();
await page.waitForURL('/knowledge-bases/1/data-sources');
test.fixme('check index status', async () => {
const id = /\/datasources\/(\d+)/.exec(page.url())[1];
while (true) {
const response = await page.request.get(`/api/v1/admin/datasources/${id}/overview`);
if (response.ok()) {
const json = await response.json();
if (json.vector_index.completed === 2) {
break;
}
} else {
console.warn(`${response.status()} ${response.statusText()}`, await response.text());
}
await page.waitForTimeout(500);
}
});
});
await test.step('Check Documents Page', async () => {
await page.goto('/documents');
await expect(page.getByRole('link', { name: 'https://example.com' })).toBeVisible();
await expect(page.getByRole('link', { name: 'https://www.iana.org/help/example-domains' })).toBeVisible();
});
});
test.fixme('Web Sitemap', async ({ page }) => {
test.slow();
await test.step('Login and visit page', async () => {
await loginViaApi(page);
await page.goto('/datasources');
await expect(page.getByRole('heading', { name: 'Datasources' })).toBeVisible();
});
await test.step('Add Sitemap Datasource', async () => {
await page.getByRole('button', { name: 'Create' }).click();
await page.getByRole('tab', { name: 'Web Sitemap' }).click();
await page.waitForURL('/datasources/create/web-sitemap');
await page.getByLabel('Name').fill('example site from sitemap');
await page.getByLabel('Description').fill('This is example sitemap');
await page.locator('input[name="url"]').fill('http://static-web-server/example-sitemap.xml');
await page.getByRole('button', { name: 'Create Datasource' }).click();
await page.waitForURL(/\/datasources\/\d+/);
const id = /\/datasources\/(\d+)/.exec(page.url())[1];
while (true) {
const response = await page.request.get(`/api/v1/admin/datasources/${id}/overview`);
if (response.ok()) {
const json = await response.json();
if (json.vector_index.completed === 2) {
break;
}
} else {
console.warn(`${response.status()} ${response.statusText()}`, await response.text());
}
await page.waitForTimeout(500);
}
});
await test.step('Check Documents Page', async () => {
await page.goto('/documents');
await expect(page.getByRole('link', { name: 'http://static-web-server/example-doc-1.html' })).toBeVisible();
await expect(page.getByRole('link', { name: 'http://static-web-server/example-doc-2.html' })).toBeVisible();
});
});
test.fixme('Files', () => {
test.fixme(true, 'Already tested in bootstrap');
});
});
================================================
FILE: e2e/tests/evaluation.spec.ts
================================================
import { expect, type Page, test } from '@playwright/test';
import { loginViaApi } from '../utils/login';
test.describe('Evaluation Dataset Management', () => {
test('Create dataset with CSV', async ({ page }) => {
await loginViaApi(page);
await createEvaluationDataset(page, 'Example Dataset from CSV', 'res/sample-evaluation-dataset.csv');
// Data from CSV
await expect(page.getByText('Example Query')).toBeVisible();
});
test('Create dataset from scratch', async ({ page }) => {
await loginViaApi(page);
await createEvaluationDataset(page, 'Example Dataset from scratch', 'res/sample-evaluation-dataset.csv');
// Empty dataset items list.
await expect(page.getByText('Empty List')).toBeVisible();
});
test('Delete dataset', async ({ page }) => {
await loginViaApi(page);
await createEvaluationDataset(page, 'Example Dataset to delete');
await page.goto('/evaluation/datasets');
await expect(page.getByRole('row').filter({ hasText: 'Example Dataset to delete' })).toBeVisible();
await page.getByRole('row').filter({ hasText: 'Example Dataset to delete' }).locator('button').last().click();
await page.getByRole('menuitem', { name: 'Delete' }).click();
await page.getByRole('button', { name: 'Continue' }).click();
await page.getByRole('button', { name: 'Continue' }).waitFor({ state: 'detached' });
await page.getByRole('row').filter({ hasText: 'Example Dataset to delete' }).waitFor({ state: 'detached' });
});
test('Mutate dataset items', async ({ page }) => {
await loginViaApi(page);
const datasetId = await createEvaluationDataset(page, 'Example Dataset to update');
// Empty dataset items list.
await expect(page.getByText('Empty List')).toBeVisible();
await test.step('Add item', async () => {
await page.getByRole('button', { name: 'New Item' }).click();
await page.waitForURL(`/evaluation/datasets/${datasetId}/items/new`);
await page.getByRole('textbox', { name: 'Query' }).fill('Example Query');
await page.getByRole('textbox', { name: 'Reference' }).fill('Example Reference');
await page.getByRole('button', { name: 'Create' }).click();
await page.waitForURL(`/evaluation/datasets/${datasetId}`);
await page.getByRole('row').filter({ hasText: 'Example Query' }).waitFor({ state: 'attached' });
await page.getByRole('row').filter({ hasText: 'Example Reference' }).waitFor({ state: 'attached' });
});
await test.step('Delete Item', async () => {
await page.getByRole('row').filter({ hasText: 'Example Query' }).locator('button').last().click();
await page.getByRole('menuitem', { name: 'Delete' }).click();
await page.getByRole('button', { name: 'Continue' }).click();
await page.getByRole('button', { name: 'Continue' }).waitFor({ state: 'detached' });
await page.getByRole('row').filter({ hasText: 'Example Query' }).waitFor({ state: 'detached' });
await page.getByRole('row').filter({ hasText: 'Example Reference' }).waitFor({ state: 'detached' });
});
});
});
async function createEvaluationDataset (page: Page, name: string, file?: string) {
await page.goto('/');
await test.step('Navigate to Create Evaluation Dataset Page', async () => {
await page.getByRole('button', { name: 'Evaluation' }).click();
await page.getByRole('link', { name: 'Datasets' }).click();
await page.waitForURL('/evaluation/datasets');
await page.getByRole('button', { name: 'New Evaluation Dataset' }).click();
await page.waitForURL('/evaluation/datasets/create');
});
return await test.step('Fill in form and submit', async () => {
await page.getByRole('textbox', { name: 'Name' }).fill(name);
if (file) {
await page.locator('[name=upload_file]').setInputFiles(file);
}
await page.getByRole('button', { name: 'Create' }).click();
await page.waitForURL(/\/evaluation\/datasets\/\d+/);
const [_, idString] = /\/evaluation\/datasets\/(\d+)/.exec(page.url());
return parseInt(idString);
});
}
================================================
FILE: e2e/tests/knowledge-base.spec.ts
================================================
import { expect, type Page, test } from '@playwright/test';
import { selectOption, turnSwitch } from '../utils/forms';
import { loginViaApi } from '../utils/login';
test.describe('Knowledge Base', () => {
test('Configure Data Sources', async ({ page }) => {
const kbId = await createFeaturedKnowledgeBase(page, 'KnowledgeBase 1', true);
await test.step('Configure Files Data Source', async () => {
await page.getByRole('button', { name: 'Upload Files' }).click();
await page.waitForURL(/data-sources\/new\?type=file/);
await page.setInputFiles('[name=files]', 'res/sample.pdf');
await page.getByRole('textbox', { name: 'Datasource Name' }).fill('Files DataSource');
await page.getByRole('button', { name: 'Create' }).click();
await page.waitForURL(/\/knowledge-bases\/\d+\/data-sources/);
await pollKbOverviewUntill(page, kbId, overview => overview.documents.total === 1);
});
await test.step('Configure Web Pages Data Source', async () => {
await page.getByRole('button', { name: 'Select Pages' }).click();
await page.waitForURL(/data-sources\/new\?type=web_single_page/);
await page.getByRole('button', { name: 'New Item' }).click();
await page.getByPlaceholder('https://example.com').fill('http://static-web-server/example-doc-1.html');
await page.getByRole('textbox', { name: 'Datasource Name' }).fill('Web Pages DataSource');
await page.getByRole('button', { name: 'Create' }).click();
await page.waitForURL(/\/knowledge-bases\/\d+\/data-sources/);
await pollKbOverviewUntill(page, kbId, overview => overview.documents.total === 2);
// Check document exists
await page.getByRole('button', { name: /^Documents/ }).click();
await expect(page.getByRole('button', { name: 'Example Document 1' })).toBeVisible();
await page.getByRole('button', { name: /^Data Sources/ }).click();
await page.waitForURL(/\/knowledge-bases\/\d+\/data-sources/);
});
await test.step('Configure Sitemap Data Source', async () => {
await page.getByRole('button', { name: 'Select web sitemap.' }).click();
await page.waitForURL(/data-sources\/new\?type=web_sitemap/);
await page.getByRole('textbox', { name: 'Sitemap URL' }).fill('http://static-web-server/example-sitemap.xml');
await page.getByRole('textbox', { name: 'Datasource Name' }).fill('Web Sitemap DataSource');
await page.getByRole('button', { name: 'Create' }).click();
await page.waitForURL(/\/knowledge-bases\/\d+\/data-sources/);
await pollKbOverviewUntill(page, kbId, overview => overview.documents.total === 4);
// Check document exists
await page.getByRole('button', { name: /^Documents/ }).click();
await expect(page.getByRole('button', { name: 'Example Document 1' })).toHaveCount(2); // Documents are not deduplicated.
await expect(page.getByRole('button', { name: 'Example Document 2' })).toBeVisible();
await page.getByRole('button', { name: /^Data Sources/ }).click();
});
test.slow();
await test.step('Check for index progress', async () => {
await pollKbOverviewUntill(page, kbId,
overview =>
overview.documents.total === 4
&& overview.chunks.total === 4
&& overview.vector_index.completed === 4
&& overview.kg_index.completed === 4,
);
});
});
test('Delete Data Sources', async ({ page }) => {
const kbId = await createFeaturedKnowledgeBase(page, 'KnowledgeBase 2');
await configureSimpleDataSource(page, kbId);
await test.step('Delete Document', async () => {
await page.goto(`/knowledge-bases/${kbId}/data-sources`);
await page.getByRole('button', { name: 'Delete' }).click();
await page.getByRole('button', { name: 'Continue' }).click();
await page.getByRole('button', { name: 'Continue' }).waitFor({ state: 'detached' });
// FIXME: reload data sources after deletion
await pollKbOverviewUntill(page, kbId, overview => {
return overview.documents.total === 0 && overview.chunks.total === 0;
});
});
await test.step('Wait for documents and chunks to be deleted', async () => {
await pollKbOverviewUntill(page, kbId, overview => {
return overview.documents.total === 0 && overview.chunks.total === 0;
});
});
});
test('Delete Documents', async ({ page }) => {
const kbId = await createFeaturedKnowledgeBase(page, 'KnowledgeBase 3');
await configureSimpleDataSource(page, kbId);
await test.step('Delete Document', async () => {
await page.goto(`/knowledge-bases/${kbId}`);
// FIXME: add aria roles
await page.getByRole('button').filter({ has: page.locator('.lucide-ellipsis') }).click();
await page.getByRole('menuitem', { name: 'Delete' }).click();
await page.getByRole('button', { name: 'Continue' }).click();
await page.getByRole('button', { name: 'Continue' }).waitFor({ state: 'detached' });
});
await test.step('Wait for documents and chunks to be deleted', async () => {
await pollKbOverviewUntill(page, kbId, overview => {
return overview.documents.total === 0 && overview.chunks.total === 0;
});
});
});
});
async function createFeaturedKnowledgeBase (page: Page, name: string, enableKnowledgeGraph = false) {
await loginViaApi(page);
return await test.step(`Create KnowledgeBase ${name} (kg_index ${enableKnowledgeGraph ? 'enabled' : 'disabled'})`, async () => {
await test.step('Navigate to Create KnowledgeBase Page', async () => {
await page.goto('/knowledge-bases');
await page.getByRole('button', { name: 'New Knowledge Base' }).click();
await page.waitForURL('/knowledge-bases/new');
});
await test.step('Fill KnowledgeBase Form', async () => {
await page.getByRole('textbox', { name: 'Name' }).fill(name);
await page.getByRole('textbox', { name: 'Description' }).fill(`KnowledgeBase Description for ${name}`);
await selectOption(page, 'LLM', /My LLM/);
await selectOption(page, 'Embedding Model', /My Embedding Model/);
if (enableKnowledgeGraph) {
await turnSwitch(page, 'Knowledge Graph Index');
}
});
return await test.step('Create and jump to data sources page', async () => {
await page.getByRole('button', { name: 'Create' }).click();
await page.waitForURL(/\/knowledge-bases\/\d+\/data-sources/);
const [, idString] = /\/knowledge-bases\/(\d+)\/data-sources/.exec(page.url());
return parseInt(idString);
});
});
}
async function configureSimpleDataSource (page: Page, kbId: number, enableKnowledgeGraph = false) {
await test.step(`Configure simple data source`, async () => {
await test.step(`Upload simple file`, async () => {
await page.getByRole('button', { name: 'Upload Files' }).click();
await page.waitForURL(/data-sources\/new\?type=file/);
await page.setInputFiles('[name=files]', 'res/sample.pdf');
await page.getByRole('textbox', { name: 'Datasource Name' }).fill('Simple DataSource');
await page.getByRole('button', { name: 'Create' }).click();
await page.waitForURL(/\/knowledge-bases\/\d+\/data-sources/);
});
await test.step(`Wait for index progress`, async () => {
await pollKbOverviewUntill(page, kbId, overview => {
expect(overview.documents.total).toBe(1);
return !!overview.vector_index.completed && (!enableKnowledgeGraph || !!overview.kg_index.completed);
});
});
return kbId;
});
}
async function pollKbOverviewUntill (page: Page, kbId: number, isOk: (json: any) => boolean) {
await test.step('Poll kb overview api', async () => {
let i = 0;
while (true) {
const ok = await test.step(`Poll rounds ${++i}`, async () => {
await page.waitForTimeout(500);
const response = await page.request.get(`/api/v1/admin/knowledge_bases/${kbId}/overview`);
expect(response.ok()).toBe(true);
const overview = await response.json();
return isOk(overview);
});
if (ok) {
break;
}
}
});
}
================================================
FILE: e2e/tests/site-settings.spec.ts
================================================
import { expect, test } from '@playwright/test';
import { loginViaApi } from '../utils/login';
test.describe('Site Sittings', () => {
test('Basic Settings', async ({ page, browser, baseURL }) => {
const homePage = await test.step('Visit Settings Page', async () => {
await loginViaApi(page);
await page.goto('/site-settings');
const homePage = await browser.newPage({
baseURL,
});
await homePage.goto('/');
return homePage;
});
await test.step('Title and Description', async () => {
await page.getByLabel('Title', { exact: true }).fill('FooBar.AI');
await submitAndWaitSavedByLabel('Title');
await page.getByLabel('Description', { exact: true }).fill('FooBar AI Description');
await submitAndWaitSavedByLabel('Description');
await page.getByLabel('Homepage Title', { exact: true }).fill('Ask anything about FooBar');
await submitAndWaitSavedByLabel('Homepage Title');
await page.reload();
await homePage.waitForTimeout(7000); // wait for settings cache
await homePage.reload();
expect(await homePage.title()).toBe('FooBar.AI');
await expect(homePage.locator('h1')).toHaveText('Ask anything about FooBar');
await expect(homePage.locator('h1 + p')).toHaveText('FooBar AI Description');
await expect(homePage.locator('meta[name=description]')).toHaveAttribute('content', 'FooBar AI Description');
});
async function submitAndWaitSavedByLabel (label: string) {
const button = page.getByText(label, { exact: true }).locator('..').locator('..').getByRole('button', { name: 'Save', exact: true });
// Click the save button in the field form
await button.click();
// Wait the save button to be vanished. (Saved)
await button.waitFor({ state: 'hidden' });
}
});
});
================================================
FILE: e2e/tests/widget.spec.ts
================================================
import { expect, type Locator, type Page, test } from '@playwright/test';
import { getChatRequestPromise, QUESTION, testNewChat } from '../utils/chat';
test('JS Widget', async ({ page }) => {
await page.goto('/');
await page.getByRole('button', { name: 'Ask AI' }).waitFor({ state: 'visible' });
expect(await page.evaluate('tidbai')).toMatchObject({ open: false });
});
test('Embedded JS Widget with trigger button', async ({ page }) => {
const trigger = await test.step('Wait trigger visible and tidbai object ready', async () => {
await page.goto('http://localhost:4001/widget.html');
const trigger = page.getByRole('button', { name: 'Ask AI' });
await trigger.waitFor({ state: 'visible' });
expect(await page.evaluate('tidbai')).toMatchObject({ open: false });
return trigger;
});
const dialog = await test.step('Click and show dialog', async () => {
await trigger.click();
const dialog = page.getByRole('dialog', { name: 'Ask AI' });
await dialog.waitFor({ state: 'visible' });
return dialog;
});
await testWidgetChat(page, dialog);
});
// Used by docs.pingcap.com
test('Embedded JS Widget controlled by js', async ({ page }) => {
await test.step('Wait trigger visible and tidbai object ready', async () => {
await page.goto('http://localhost:4001/widget-controlled.html');
const trigger = page.getByRole('button', { name: 'Ask AI' });
await expect(trigger).toBeHidden();
await page.waitForFunction(() => (window as any).tidbai);
expect(await page.evaluate('window.tidbai')).toMatchObject({ open: false });
});
const dialog = await test.step('JS api call and show dialog', async () => {
await page.evaluate('tidbai.open = true');
const dialog = page.getByRole('dialog', { name: 'Ask AI' });
await dialog.waitFor({ state: 'visible' });
return dialog;
});
await testWidgetChat(page, dialog);
});
async function testWidgetChat (page: Page, dialog: Locator) {
await test.step('Fill in question', async () => {
const input = dialog.getByPlaceholder('Input your question here...');
await input.focus();
await input.fill(QUESTION);
});
const chatRequestPromise = await test.step('Trigger ask by press ControlOrMeta+Enter', async () => {
const chatRequestPromise = getChatRequestPromise(page, 'http://127.0.0.1:3000');
await page.keyboard.press('ControlOrMeta+Enter');
return chatRequestPromise;
});
await testNewChat(page, chatRequestPromise, false);
}
================================================
FILE: e2e/utils/chat.ts
================================================
import { expect, type Page, type Request, test } from '@playwright/test';
export const QUESTION = 'What is the content of sample.pdf?';
export function getChatRequestPromise (page: Page, baseURL: string) {
return page.waitForRequest(request => request.url() === `${baseURL}/api/v1/chats` && request.method() === 'POST');
}
export async function testNewChat (page: Page, chatRequest: Request, validatePageUrlAndTitle: boolean, feedbackLike?: boolean) {
await test.step('Wait page changes', async () => {
if (validatePageUrlAndTitle) {
await page.waitForURL(/\/c\/.+/);
expect(await page.title()).toContain(QUESTION);
}
await page.getByRole('heading', { name: QUESTION }).waitFor({ state: 'visible' });
});
const streamText = await test.step('Wait for chat stop', async () => {
const chatResponse = await chatRequest.response();
expect(chatResponse.ok()).toBe(true);
// Feedback button indicates chat ends.
await page.getByRole('button', { name: 'Like This Answer', exact: true }).waitFor({ state: 'visible' });
return await chatResponse.text();
});
await test.step('Check response text', async () => {
const lastLine = streamText.split('\n').filter(t => !!t.trim()).slice(-1)[0];
expect(lastLine).toMatch(/^2:/);
const message = JSON.parse(lastLine.slice(2))[0].assistant_message;
expect(message.finished_at).toBeTruthy();
expect(message.content.trim().length).toBeGreaterThan(0);
});
if (typeof feedbackLike === 'boolean') {
await test.step('Feedback', async () => {
const feedbackButton = page.getByRole('button', { name: feedbackLike ? 'Like This Answer' : 'Dislike This Answer', exact: true });
await feedbackButton.click();
const dialog = page.getByRole('dialog', { name: 'Feedback' });
await dialog.waitFor({ state: 'visible' });
const comments = page.getByPlaceholder('Comments...');
await comments.pressSequentially(feedbackLike ? 'Good Good Good' : 'Bad Bad Bad');
await page.getByText('Add feedback', { exact: true }).click();
await dialog.waitFor({ state: 'hidden' });
});
}
}
================================================
FILE: e2e/utils/forms.ts
================================================
import { expect, type Page, test } from '@playwright/test';
export async function selectOption (page: Page, name: string, value: string | RegExp, clickWindow = false) {
await test.step(`Select field ${name}`, async () => {
await page.getByRole('button', { name: name, exact: true }).click();
await page.getByRole('option', { name: value }).click();
if (clickWindow) {
await page.click('body');
}
await expect(page.getByRole('button', { name: name, exact: true })).toHaveText(value);
});
}
export async function turnSwitch (page: Page, name: string, on: boolean = true) {
await test.step(`Turn ${on ? 'on' : 'off'} switch ${name}`, async () => {
const locator = page.getByRole('switch', { name: name, exact: true });
if (on) {
if (await locator.getAttribute('aria-checked') === 'true') {
return;
}
await locator.click();
await expect(locator).toHaveAttribute('aria-checked', 'true');
} else {
if (await locator.getAttribute('aria-checked') === 'false') {
return;
}
await locator.click();
await expect(locator).toHaveAttribute('aria-checked', 'false');
}
});
}
export async function checkCheckbox (page: Page, name: string, on: boolean = true) {
await test.step(`${on ? 'Check' : 'Uncheck'} checkbox ${name}`, async () => {
const locator = page.getByRole('checkbox', { name: name, exact: true });
if (on) {
await locator.check();
} else {
await locator.uncheck();
}
});
}
================================================
FILE: e2e/utils/login.ts
================================================
import { type APIRequestContext, expect, test } from '@playwright/test';
export async function loginViaApi ({ request }: { request: APIRequestContext }) {
await test.step('Login via API', async () => {
const usp = new URLSearchParams();
usp.set('username', process.env.USERNAME);
usp.set('password', process.env.PASSWORD);
const response = await request.post('/api/v1/auth/login', {
data: usp.toString(),
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
});
expect(response.ok()).toBe(true);
});
}
================================================
FILE: e2e/vercel.json
================================================
{
"outputDirectory": "playwright-report"
}
================================================
FILE: frontend/.gitignore
================================================
node_modules
.idea
================================================
FILE: frontend/.nvmrc
================================================
v22.12.0
================================================
FILE: frontend/.prettierignore
================================================
./app/src/components/ui/**
================================================
FILE: frontend/Dockerfile
================================================
FROM node:20-alpine AS base
# 1. Install dependencies only when needed
FROM base AS deps
# Check https://github.com/nodejs/docker-node/tree/b4117f9333da4138b03a546ec926ef50a31506c3#nodealpine to understand why libc6-compat might be needed.
RUN apk add --no-cache libc6-compat
WORKDIR /tidb.ai
RUN mkdir -p frontend/patches
RUN mkdir -p frontend/app
RUN mkdir -p frontend/packages/widget-react
# Install dependencies based on the preferred package manager
COPY frontend/package.json ./frontend
COPY frontend/pnpm-*.yaml ./frontend
COPY frontend/patches/* ./frontend/patches
COPY frontend/app/package.json ./frontend/app
COPY frontend/packages/widget-react/package.json ./frontend/packages/widget-react
WORKDIR /tidb.ai/frontend
RUN npm install -g corepack@latest
RUN corepack enable pnpm
RUN pnpm i --frozen-lockfile
# 2. Rebuild the source code only when needed
FROM base AS builder
# Need git client to fetch git revision info
RUN apk add --no-cache git
WORKDIR /tidb.ai
COPY --from=deps /tidb.ai/frontend/node_modules ./frontend/node_modules
COPY --from=deps /tidb.ai/frontend/app/node_modules ./frontend/app/node_modules
COPY --from=deps /tidb.ai/frontend/packages/widget-react/node_modules ./frontend/packages/widget-react/node_modules
# For symbol link of root README.md
COPY . .
# This will do the trick, use the corresponding env file for each environment.
#COPY .env.production.sample .env.production
WORKDIR /tidb.ai/frontend
ENV BASE_URL=""
ENV SITE_URL=""
RUN rm -f app/.env
RUN echo BASE_URL=${BASE_URL:-'""'} >> app/.env.production
RUN npm install -g corepack@latest
RUN corepack enable pnpm
RUN pnpm run build:docker
RUN pnpm run build:widget-react
# 3. Production image, copy all the files and run next
FROM base AS runner
WORKDIR /tidb.ai
ENV NODE_ENV=production
ENV PORT=3000
ENV HOSTNAME=0.0.0.0
RUN addgroup -g 1001 -S nodejs
RUN adduser -S nextjs -u 1001
# Automatically leverage output traces to reduce image size
# https://nextjs.org/docs/advanced-features/output-file-tracing
COPY --from=builder --chown=nextjs:nodejs /tidb.ai/frontend/app/.next/standalone .
COPY --from=builder --chown=nextjs:nodejs /tidb.ai/frontend/app/.next/static app/.next/static
COPY --from=builder /tidb.ai/frontend/app/public app/public
USER nextjs
EXPOSE 3000
CMD ["node", "app/server.js"]
================================================
FILE: frontend/app/.eslintrc.json
================================================
{
"extends": [
"next/core-web-vitals",
"plugin:storybook/recommended"
]
}
================================================
FILE: frontend/app/.gitignore
================================================
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
/node_modules
/.pnp
.pnp.js
.yarn/install-state.gz
# testing
/coverage
# next.js
/.next/
/out/
# production
/build
# misc
.DS_Store
*.pem
# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*
# local env files
.env*.local
# vercel
.vercel
# typescript
*.tsbuildinfo
next-env.d.ts
certificates
public/widget.js
public/robots.txt
public/sitemap.xml
public/sitemap-*.xml
*storybook.log
storybook-static
================================================
FILE: frontend/app/.storybook/main.ts
================================================
import type { StorybookConfig } from '@storybook/nextjs';
import MonacoEditorWebpackPlugin = require('monaco-editor-webpack-plugin');
const config: StorybookConfig = {
stories: [
'../src/!(pages)/**/*.stories.@(js|jsx|mjs|ts|tsx|mdx)',
],
addons: [
'@storybook/addon-onboarding',
'@storybook/addon-links',
'@storybook/addon-essentials',
'@chromatic-com/storybook',
'@storybook/addon-interactions',
],
framework: {
name: '@storybook/nextjs',
options: {},
},
webpack (config) {
config.plugins.push(new MonacoEditorWebpackPlugin({
languages: ['json'],
filename: 'static/[name].worker.js',
}));
return config;
},
staticDirs: ['../public'],
};
export default config;
================================================
FILE: frontend/app/.storybook/preview.ts
================================================
import type { Preview } from "@storybook/react";
import '../src/app/globals.css';
import '../src/app/chart-theme.css';
const preview: Preview = {
parameters: {
controls: {
matchers: {
color: /(background|color)$/i,
date: /Date$/i,
},
},
},
};
export default preview;
================================================
FILE: frontend/app/README.md
================================================
This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app).
## Getting Started
First, run the development server:
```bash
npm run dev
# or
yarn dev
# or
pnpm dev
# or
bun dev
```
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font.
## Learn More
To learn more about Next.js, take a look at the following resources:
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js/) - your feedback and contributions are welcome!
## Deploy on Vercel
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
Check out our [Next.js deployment documentation](https://nextjs.org/docs/deployment) for more details.
================================================
FILE: frontend/app/components.json
================================================
{
"$schema": "https://ui.shadcn.com/schema.json",
"style": "default",
"rsc": true,
"tsx": true,
"tailwind": {
"config": "tailwind.config.ts",
"css": "src/app/globals.css",
"baseColor": "zinc",
"cssVariables": true,
"prefix": ""
},
"aliases": {
"components": "@/components",
"utils": "@/lib/utils"
}
}
================================================
FILE: frontend/app/jest.config.ts
================================================
/**
* For a detailed explanation regarding each configuration property, visit:
* https://jestjs.io/docs/configuration
*/
import type { Config } from 'jest';
import nextJest from 'next/jest';
const createJestConfig = nextJest({
// Provide the path to your Next.js app to load next.config.js and .env files in your test environment
dir: './',
});
const config: Config = {
extensionsToTreatAsEsm: ['.ts', '.tsx'],
// All imported modules in your tests should be mocked automatically
// automock: false,
// Stop running tests after `n` failures
// bail: 0,
// The directory where Jest should store its cached dependency information
// cacheDirectory: "/private/var/folders/46/7yp8x0hd3rs0j60c_ymsjgpc0000gn/T/jest_dx",
// Automatically clear mock calls, instances, contexts and results before every test
clearMocks: true,
// Indicates whether the coverage information should be collected while executing the test
collectCoverage: true,
// An array of glob patterns indicating a set of files for which coverage information should be collected
// collectCoverageFrom: undefined,
// The directory where Jest should output its coverage files
coverageDirectory: 'coverage',
// An array of regexp pattern strings used to skip coverage collection
// coveragePathIgnorePatterns: [
// "/node_modules/"
// ],
// Indicates which provider should be used to instrument code for coverage
coverageProvider: 'v8',
// A list of reporter names that Jest uses when writing coverage reports
// coverageReporters: [
// "json",
// "text",
// "lcov",
// "clover"
// ],
// An object that configures minimum threshold enforcement for coverage results
// coverageThreshold: undefined,
// A path to a custom dependency extractor
// dependencyExtractor: undefined,
// Make calling deprecated APIs throw helpful error messages
// errorOnDeprecated: false,
// The default configuration for fake timers
// fakeTimers: {
// "enableGlobally": false
// },
// Force coverage collection from ignored files using an array of glob patterns
// forceCoverageMatch: [],
// A path to a module which exports an async function that is triggered once before all test suites
// globalSetup: undefined,
// A path to a module which exports an async function that is triggered once after all test suites
// globalTeardown: undefined,
// A set of global variables that need to be available in all test environments
// globals: {},
// The maximum amount of workers used to run your tests. Can be specified as % or a number. E.g. maxWorkers: 10% will use 10% of your CPU amount + 1 as the maximum worker number. maxWorkers: 2 will use a maximum of 2 workers.
// maxWorkers: "50%",
// An array of directory names to be searched recursively up from the requiring module's location
// moduleDirectories: [
// "node_modules"
// ],
// An array of file extensions your modules use
// moduleFileExtensions: [
// "js",
// "mjs",
// "cjs",
// "jsx",
// "ts",
// "tsx",
// "json",
// "node"
// ],
// A map from regular expressions to module names or to arrays of module names that allow to stub out resources with a single module
// moduleNameMapper: {},
// An array of regexp pattern strings, matched against all module paths before considered 'visible' to the module loader
// modulePathIgnorePatterns: [],
// Activates notifications for test results
// notify: false,
// An enum that specifies notification mode. Requires { notify: true }
// notifyMode: "failure-change",
// A preset that is used as a base for Jest's configuration
// preset: undefined,
// Run tests from one or more projects
// projects: undefined,
// Use this configuration option to add custom reporters to Jest
// reporters: undefined,
// Automatically reset mock state before every test
// resetMocks: false,
// Reset the module registry before running each individual test
// resetModules: false,
// A path to a custom resolver
// resolver: undefined,
// Automatically restore mock state and implementation before every test
// restoreMocks: false,
// The root directory that Jest should scan for tests and modules within
// rootDir: undefined,
// A list of paths to directories that Jest should use to search for files in
// roots: [
// ""
// ],
// Allows you to use a custom runner instead of Jest's default test runner
// runner: "jest-runner",
// The paths to modules that run some code to configure or set up the testing environment before each test
setupFiles: ['/jest.polyfills.js'],
// A list of paths to modules that run some code to configure or set up the testing framework before each test
// setupFilesAfterEnv: [],
// The number of seconds after which a test is considered as slow and reported as such in the results.
// slowTestThreshold: 5,
// A list of paths to snapshot serializer modules Jest should use for snapshot testing
// snapshotSerializers: [],
// The test environment that will be used for testing
testEnvironment: 'jsdom',
// Options that will be passed to the testEnvironment
// testEnvironmentOptions: {},
// Adds a location field to test results
// testLocationInResults: false,
// The glob patterns Jest uses to detect test files
testMatch: [
"**/__tests__/**/*.[jt]s?(x)",
"**/?(*.)+(spec|test).[tj]s?(x)"
],
// An array of regexp pattern strings that are matched against all test paths, matched tests are skipped
// testPathIgnorePatterns: [
// "/node_modules/"
// ],
// The regexp pattern or array of patterns that Jest uses to detect test files
// testRegex: [],
// This option allows the use of a custom results processor
// testResultsProcessor: undefined,
// This option allows use of a custom test runner
// testRunner: "jest-circus/runner",
// A map from regular expressions to paths to transformers
transform: {
'\\.ya?ml$': 'jest-transform-yaml',
'^.+\\.tsx?$': [
'ts-jest',
{
useESM: true,
},
],
},
// An array of regexp pattern strings that are matched against all source file paths, matched files will skip transformation
transformIgnorePatterns: [
'/node_modules/',
'\\.pnp\\.[^\\/]+$',
// 'node_modules/(?!(rehype-react))',
],
// An array of regexp pattern strings that are matched against all modules before the module loader will automatically return a mock for them
// unmockedModulePathPatterns: undefined,
// Indicates whether each individual test should be reported during the run
// verbose: undefined,
// An array of regexp patterns that are matched against all source file paths before re-running tests in watch mode
// watchPathIgnorePatterns: [],
// Whether to use watchman for file crawling
// watchman: true,
};
export default createJestConfig(config);
================================================
FILE: frontend/app/jest.polyfills.js
================================================
// jest.polyfills.js
/**
* @note The block below contains polyfills for Node.js globals
* required for Jest to function when running JSDOM tests.
* These HAVE to be require's and HAVE to be in this exact
* order, since "undici" depends on the "TextEncoder" global API.
*
* Consider migrating to a more modern test runner if
* you don't want to deal with this.
*/
const { TextDecoder, TextEncoder } = require("node:util");
const { ReadableStream, TransformStream } = require('node:stream/web')
Object.defineProperties(globalThis, {
TextDecoder: { value: TextDecoder },
TextEncoder: { value: TextEncoder },
ReadableStream: { value: ReadableStream },
TransformStream: { value: TransformStream },
})
const { Blob, File } = require("node:buffer")
const { fetch, Headers, FormData, Request, Response } = require("undici")
Object.defineProperties(globalThis, {
fetch: { value: fetch, writable: true },
Blob: { value: Blob },
File: { value: File },
Headers: { value: Headers },
FormData: { value: FormData },
Request: { value: Request },
Response: { value: Response },
})
class ResizeObserver {
observe() {
}
disconnect() {
}
}
Object.defineProperties(globalThis, {
ResizeObserver: { value: ResizeObserver }
})
================================================
FILE: frontend/app/next-sitemap.config.js
================================================
module.exports = {
siteUrl: process.env.SITE_URL || 'https://tidb.ai',
generateRobotsTxt: true,// (optional)
// ...other options
}
================================================
FILE: frontend/app/next.config.ts
================================================
import MonacoWebpackPlugin from 'monaco-editor-webpack-plugin';
import { NextConfig } from 'next';
const nextConfig: NextConfig = {
output: process.env.STANDALONE ? 'standalone' : undefined,
transpilePackages: ['monaco-editor'],
experimental: {
optimizePackageImports: ['ai', 'lucide-react'],
turbo: {
rules: {
'*.svg': {
loaders: ['@svgr/webpack'],
as: '*.js',
},
},
},
},
webpack (config, options) {
config.module.rules.push({
test: /\.svg$/,
use: '@svgr/webpack',
});
if (!options.isServer) {
config.plugins.push(new MonacoWebpackPlugin({
languages: ['json', 'markdown'],
filename: 'static/[name].worker.js',
}));
}
return config;
},
};
export default nextConfig;
================================================
FILE: frontend/app/notice.md
================================================
Deprecating `react-hook-form`. Use `@tanstack/form` instead.
================================================
FILE: frontend/app/package.json
================================================
{
"name": "app",
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "next dev",
"dev:local": "next dev",
"build": "next build && next-sitemap",
"build:standalone": "STANDALONE=1 next build --no-lint && next-sitemap",
"start": "next start",
"lint": "next lint",
"test": "NODE_OPTIONS=\"$NODE_OPTIONS --experimental-vm-modules\" jest",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build"
},
"dependencies": {
"@hookform/resolvers": "^3.9.1",
"@next/third-parties": "15.1.9",
"@radix-ui/react-accordion": "^1.2.3",
"@radix-ui/react-alert-dialog": "^1.1.6",
"@radix-ui/react-aspect-ratio": "^1.1.2",
"@radix-ui/react-avatar": "^1.1.3",
"@radix-ui/react-checkbox": "^1.1.4",
"@radix-ui/react-collapsible": "^1.1.3",
"@radix-ui/react-context-menu": "^2.2.6",
"@radix-ui/react-dialog": "^1.1.6",
"@radix-ui/react-dropdown-menu": "^2.1.6",
"@radix-ui/react-hover-card": "^1.1.6",
"@radix-ui/react-label": "^2.1.2",
"@radix-ui/react-menubar": "^1.1.6",
"@radix-ui/react-navigation-menu": "^1.2.5",
"@radix-ui/react-popover": "^1.1.6",
"@radix-ui/react-progress": "^1.1.2",
"@radix-ui/react-radio-group": "^1.2.3",
"@radix-ui/react-scroll-area": "^1.2.3",
"@radix-ui/react-select": "^2.1.6",
"@radix-ui/react-separator": "^1.1.2",
"@radix-ui/react-slider": "^1.2.3",
"@radix-ui/react-slot": "^1.1.2",
"@radix-ui/react-switch": "^1.1.3",
"@radix-ui/react-tabs": "^1.1.3",
"@radix-ui/react-toast": "^1.2.6",
"@radix-ui/react-toggle": "^1.1.2",
"@radix-ui/react-toggle-group": "^1.1.2",
"@radix-ui/react-tooltip": "^1.1.8",
"@tailwindcss/typography": "^0.5.15",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"cmdk": "^1.0.4",
"date-fns": "^3.6.0",
"embla-carousel-react": "^8.5.1",
"force-graph": "^1.49.6",
"input-otp": "^1.4.1",
"js-cookie": "^3.0.5",
"lucide-react": "^0.400.0",
"merge-refs": "^1.3.0",
"next": "15.1.11",
"next-themes": "^0.4.4",
"react": "19.0.0",
"react-day-picker": "^8.10.1",
"react-dom": "19.0.0",
"react-hook-form": "^7.54.0",
"react-is": "19.0.0",
"react-resizable-panels": "^2.1.7",
"recharts": "^2.15.1",
"rehype-highlight": "^7.0.1",
"rehype-react": "^8.0.0",
"remark-gfm": "^4.0.0",
"remark-parse": "^11.0.0",
"remark-rehype": "^11.1.1",
"sonner": "^1.7.1",
"tailwind-merge": "^2.5.5",
"tailwindcss-animate": "^1.0.7",
"unified": "^11.0.5",
"vaul": "^1.1.1",
"zod": "^3.24.0"
},
"devDependencies": {
"@chromatic-com/storybook": "^1.9.0",
"@jest/globals": "^29.7.0",
"@storybook/addon-essentials": "^8.4.7",
"@storybook/addon-interactions": "^8.4.7",
"@storybook/addon-links": "^8.4.7",
"@storybook/addon-onboarding": "^8.4.7",
"@storybook/blocks": "^8.4.7",
"@storybook/nextjs": "^8.4.7",
"@storybook/react": "^8.4.7",
"@storybook/test": "^8.4.7",
"@svgr/webpack": "^8.1.0",
"@tanstack/react-form": "^0.40.4",
"@tanstack/react-table": "^8.20.5",
"@tanstack/table-core": "^8.20.5",
"@tanstack/zod-form-adapter": "^0.40.4",
"@testing-library/jest-dom": "^6.6.3",
"@testing-library/react": "^16.1.0",
"@types/d3": "^7.4.3",
"@types/gtag.js": "^0.0.20",
"@types/is-hotkey": "^0.1.10",
"@types/jest": "^29.5.14",
"@types/js-cookie": "^3.0.6",
"@types/mdast": "^4.0.4",
"@types/node": "^20.17.9",
"@types/react": "19.0.1",
"@types/react-dom": "19.0.2",
"@types/react-is": "^19.0.0",
"@types/wcwidth": "^1.0.2",
"@uiw/react-json-view": "2.0.0-alpha.30",
"ai": "^3.4.33",
"change-case-all": "^2.1.0",
"copy-to-clipboard": "^3.3.3",
"d3": "^7.9.0",
"eslint": "^9.16.0",
"eslint-config-next": "15.1.9",
"eslint-plugin-storybook": "^0.11.1",
"eventemitter3": "^5.0.1",
"fast-equals": "^5.0.1",
"filesize": "^10.1.6",
"framer-motion": "^11.18.2",
"git-revision-webpack-plugin": "^5.0.0",
"hastscript": "^9.0.0",
"highlight.js": "^11.10.0",
"is-hotkey": "^0.2.0",
"jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0",
"jest-transform-yaml": "^1.1.2",
"monaco-editor": "^0.50.0",
"monaco-editor-webpack-plugin": "^7.1.0",
"next-sitemap": "^4.2.3",
"postcss": "^8.4.49",
"react-textarea-autosize": "^8.5.7",
"remark-directive": "^3.0.0",
"sass": "^1.82.0",
"sql-formatter": "^15.4.6",
"storybook": "^8.4.7",
"swr": "^2.2.5",
"tailwindcss": "^3.4.16",
"ts-jest": "^29.2.5",
"ts-node": "^10.9.2",
"typescript": "^5.7.2",
"undici": "^6.21.0",
"unist-util-visit": "^5.0.0",
"wcwidth": "^1.0.1",
"yaml-loader": "^0.8.1"
},
"imports": {
"#*": {
"react-server": "./src/*.react-server.ts",
"storybook": "./src/*.mock.ts",
"tidbai-widget": "./src/*.tidbai-widget.ts",
"default": "./src/*.ts"
}
}
}
================================================
FILE: frontend/app/postcss.config.mjs
================================================
/** @type {import('postcss-load-config').Config} */
const config = {
plugins: {
tailwindcss: {},
},
};
export default config;
================================================
FILE: frontend/app/public/chats.mock.txt
================================================
0:""
2:[{"chat":{"origin": null,"""updated_at":"2024-07-09T09:55:30","id":"00000000-0000-0000-0000-00000000000","title":"what is tidb?","engine_id":1,"engine_options":"{\"llm\":{\"provider\":\"openai\",\"openai_chat_model\":\"gpt-4o\",\"gemini_chat_model\":\"models/gemini-1.5-flash\",\"reranker_provider\":\"jinaai\",\"reranker_top_k\":10,\"intent_graph_knowledge\":\"Given a list of relationships of a knowledge graph as follows. When there is a conflict in meaning between knowledge relationships, the relationship with the higher `weight` and newer `last_modified_at` value takes precedence.\\n\\nKnowledge sub-queries:\\n\\n{% for sub_query, data in sub_queries.items() %}\\n\\nSub-query: {{ sub_query }}\\n\\n - Entities:\\n\\n{% for entity in data['entities'] %}\\n\\n - Name: {{ entity.name }}\\n - Description: {{ entity.description }}\\n\\n{% endfor %}\\n\\n - Relationships:\\n\\n{% for relationship in data['relationships'] %}\\n\\n - Description: {{ relationship.rag_description }}\\n - Last Modified At: {{ relationship.last_modified_at }}\\n - Meta: {{ relationship.meta | tojson(indent=2) }}\\n\\n{% endfor %}\\n\\n{% endfor %}\\n\",\"normal_graph_knowledge\":\"Given a list of relationships of a knowledge graph as follows. When there is a conflict in meaning between knowledge relationships, the relationship with the higher `weight` and newer `last_modified_at` value takes precedence.\\n\\n---------------------\\nEntities:\\n\\n{% for entity in entities %}\\n\\n- Name: {{ entity.name }}\\n- Description: {{ entity.description }}\\n\\n{% endfor %}\\n\\n---------------------\\n\\nKnowledge relationships:\\n\\n{% for relationship in relationships %}\\n\\n- Description: {{ relationship.rag_description }}\\n- Weight: {{ relationship.weight }}\\n- Last Modified At: {{ relationship.last_modified_at }}\\n- Meta: {{ relationship.meta | tojson(indent=2) }}\\n\\n{% endfor %}\\n\"},\"knowledge_graph\":{\"enabled\":true,\"depth\":2,\"include_meta\":true,\"with_degree\":false,\"using_intent_search\":true}}","user_id":"01907db8-8850-795d-855b-552663c18c9f","created_at":"2024-07-09T09:55:30","deleted_at":null},"user_message":{"id":120007,"created_at":"2024-07-09T09:55:31","role":"user","trace_url":null,"finished_at":null,"user_id":"01907db8-8850-795d-855b-552663c18c9f","updated_at":"2024-07-09T09:55:31","ordinal":1,"content":"what is tidb?","error":null,"sources":[],"chat_id":"00000000-0000-0000-0000-00000000000"},"assistant_message":{"id":120008,"created_at":"2024-07-09T09:55:31","role":"assistant","trace_url":"https://us.cloud.langfuse.com/trace/ccf8aea4-086b-4d72-a7a1-a05ef2e6830a","finished_at":null,"user_id":"01907db8-8850-795d-855b-552663c18c9f","updated_at":"2024-07-09T09:55:31","ordinal":2,"content":"","error":null,"sources":[],"chat_id":"00000000-0000-0000-0000-00000000000"}}]
8:[{"state":"TRACE","display":"Start knowledge graph searching ...","context":{"langfuse_url":"https://us.cloud.langfuse.com/trace/ccf8aea4-086b-4d72-a7a1-a05ef2e6830a"}}]
8:[{"state":"REFINE_QUESTION","display":"Refine the user question ...","context":""}]
8:[{"state":"SEARCH_RELATED_DOCUMENTS","display":"Search related documents ...","context":""}]
8:[{"state":"SOURCE_NODES","display":"","context":[{"id":8600,"name":"TiDB FAQs","source_uri":"https://docs.pingcap.com/tidb/v7.5/tidb-faq"},{"id":9057,"name":"Overview","source_uri":"https://docs.pingcap.com/tidbcloud/tidb-architecture"},{"id":9373,"name":"TiDB Introduction","source_uri":"https://docs.pingcap.com/tidb/v7.1/overview"},{"id":9865,"name":"Overview","source_uri":"https://docs.pingcap.com/tidb/v7.1/tidb-architecture"},{"id":10191,"name":"TiDB FAQs","source_uri":"https://docs.pingcap.com/tidb/v7.1/tidb-faq"},{"id":11370,"name":"TiDB FAQs","source_uri":"https://docs.pingcap.com/tidb/v6.5/tidb-faq"},{"id":12985,"name":"TiDB Introduction","source_uri":"https://docs.pingcap.com/tidb/v5.4/overview"},{"id":13942,"name":"Overview","source_uri":"https://docs.pingcap.com/tidb/v5.3/tidb-architecture"},{"id":14218,"name":"TiDB FAQs","source_uri":"https://docs.pingcap.com/tidb/v5.3/tidb-faq"}]}]
0:""
0:"###"
0:" Comprehensive"
0:" Overview"
0:" of"
0:" Ti"
0:"DB"
0:"\n\n"
0:"####"
0:" What"
0:" is"
0:" Ti"
0:"DB"
0:"?\n\n"
0:"Ti"
0:"DB"
0:" is"
0:" an"
0:" open"
0:"-source"
0:" distributed"
0:" SQL"
0:" database"
0:" designed"
0:" to"
0:" support"
0:" Hybrid"
0:" Transaction"
0:"al"
0:" and"
0:" Analytical"
0:" Processing"
0:" ("
0:"HT"
0:"AP"
0:")"
0:" workloads"
0:"."
0:" It"
0:" is"
0:" My"
0:"SQL"
0:"-compatible"
0:" and"
0:" features"
0:" horizontal"
0:" scalability"
0:","
0:" strong"
0:" consistency"
0:","
0:" and"
0:" high"
0:" availability"
0:"[^"
0:"1"
0:"]["
0:"^"
0:"2"
0:"]."
0:" Ti"
0:"DB"
0:" aims"
0:" to"
0:" provide"
0:" a"
0:" one"
0:"-stop"
0:" database"
0:" solution"
0:" that"
0:" covers"
0:" OL"
0:"TP"
0:" ("
0:"Online"
0:" Transaction"
0:"al"
0:" Processing"
0:"),"
0:" OL"
0:"AP"
0:" ("
0:"Online"
0:" Analytical"
0:" Processing"
0:"),"
0:" and"
0:" HT"
0:"AP"
0:" services"
0:"[^"
0:"3"
0:"].\n\n"
0:"####"
0:" Architecture"
0:"\n\n"
0:"The"
0:" Ti"
0:"DB"
0:" cluster"
0:" consists"
0:" of"
0:" three"
0:" main"
0:" components"
0:"[^"
0:"4"
0:"]:\n\n"
0:"1"
0:"."
0:" **"
0:"Ti"
0:"DB"
0:" Server"
0:"**"
0:":"
0:" Handles"
0:" SQL"
0:" parsing"
0:","
0:" query"
0:" planning"
0:","
0:" and"
0:" execution"
0:".\n"
0:"2"
0:"."
0:" **"
0:"Ti"
0:"KV"
0:" Server"
0:"**"
0:":"
0:" Acts"
0:" as"
0:" the"
0:" distributed"
0:" key"
0:"-value"
0:" storage"
0:" engine"
0:","
0:" storing"
0:" the"
0:" actual"
0:" data"
0:".\n"
0:"3"
0:"."
0:" **"
0:"PD"
0:" ("
0:"Placement"
0:" Driver"
0:")"
0:" Server"
0:"**"
0:":"
0:" Man"
0:"ages"
0:" cluster"
0:" metadata"
0:","
0:" alloc"
0:"ates"
0:" timestamps"
0:","
0:" and"
0:" handles"
0:" data"
0:" placement"
0:" and"
0:" load"
0:" balancing"
0:".\n\n"
0:"Additionally"
0:","
0:" Ti"
0:"DB"
0:" includes"
0:" other"
0:" components"
0:" such"
0:" as"
0:" Ti"
0:"Flash"
0:" for"
0:" column"
0:"ar"
0:" storage"
0:","
0:" Ti"
0:"Spark"
0:" for"
0:" Spark"
0:" integration"
0:","
0:" Ti"
0:"DB"
0:" Bin"
0:"log"
0:" for"
0:" data"
0:" replication"
0:","
0:" and"
0:" Ti"
0:"DB"
0:" Lightning"
0:" for"
0:" high"
0:"-performance"
0:" data"
0:" import"
0:"[^"
0:"5"
0:"].\n\n"
0:"####"
0:" Key"
0:" Features"
0:"\n\n"
0:"1"
0:"."
0:" **"
0:"Sc"
0:"al"
0:"ability"
0:"**"
0:":"
0:" Ti"
0:"DB"
0:" allows"
0:" for"
0:" easy"
0:" horizontal"
0:" scaling"
0:" of"
0:" both"
0:" computing"
0:" and"
0:" storage"
0:" resources"
0:"."
0:" The"
0:" architecture"
0:" separates"
0:" computing"
0:" from"
0:" storage"
0:","
0:" enabling"
0:" independent"
0:" scaling"
0:"[^"
0:"6"
0:"]."
0:" This"
0:" design"
0:" ensures"
0:" flexibility"
0:" and"
0:" adaptability"
0:" to"
0:" changing"
0:" workloads"
0:"[^"
0:"7"
0:"].\n\n"
0:"2"
0:"."
0:" **"
0:"High"
0:" Availability"
0:"**"
0:":"
0:" Ti"
0:"DB"
0:" guarantees"
0:" strong"
0:" consistency"
0:" and"
0:" availability"
0:" through"
0:" its"
0:" multi"
0:"-re"
0:"p"
0:"lica"
0:" architecture"
0:" and"
0:" Multi"
0:"-R"
0:"aft"
0:" protocol"
0:"."
0:" Data"
0:" is"
0:" stored"
0:" in"
0:" multiple"
0:" replicas"
0:","
0:" and"
0:" a"
0:" transaction"
0:" can"
0:" only"
0:" be"
0:" committed"
0:" when"
0:" data"
0:" has"
0:" been"
0:" successfully"
0:" written"
0:" into"
0:" the"
0:" majority"
0:" of"
0:" replicas"
0:"[^"
0:"8"
0:"]."
0:" This"
0:" ensures"
0:" data"
0:" integrity"
0:" even"
0:" in"
0:" the"
0:" event"
0:" of"
0:" failures"
0:"[^"
0:"9"
0:"].\n\n"
0:"3"
0:"."
0:" **"
0:"HT"
0:"AP"
0:" Cap"
0:"abilities"
0:"**"
0:":"
0:" Ti"
0:"DB"
0:" supports"
0:" both"
0:" row"
0:"-based"
0:" ("
0:"Ti"
0:"KV"
0:")"
0:" and"
0:" column"
0:"ar"
0:" ("
0:"Ti"
0:"Flash"
0:")"
0:" storage"
0:" engines"
0:","
0:" enabling"
0:" real"
0:"-time"
0:" processing"
0:" of"
0:" both"
0:" transactional"
0:" and"
0:" analytical"
0:" workloads"
0:"."
0:" Ti"
0:"Flash"
0:" uses"
0:" the"
0:" Multi"
0:"-R"
0:"aft"
0:" Lear"
0:"ner"
0:" protocol"
0:" to"
0:" replicate"
0:" data"
0:" from"
0:" Ti"
0:"KV"
0:" in"
0:" real"
0:" time"
0:","
0:" ensuring"
0:" consistent"
0:" data"
0:" across"
0:" both"
0:" engines"
0:"[^"
0:"10"
0:"].\n\n"
0:"4"
0:"."
0:" **"
0:"Cloud"
0:"-N"
0:"ative"
0:" Design"
0:"**"
0:":"
0:" Ti"
0:"DB"
0:" is"
0:" built"
0:" for"
0:" cloud"
0:" environments"
0:","
0:" offering"
0:" flexible"
0:" scalability"
0:","
0:" reliability"
0:","
0:" and"
0:" security"
0:" on"
0:" various"
0:" cloud"
0:" platforms"
0:"."
0:" It"
0:" integrates"
0:" seamlessly"
0:" with"
0:" Kubernetes"
0:" and"
0:" offers"
0:" a"
0:" fully"
0:" managed"
0:" service"
0:","
0:" Ti"
0:"DB"
0:" Cloud"
0:"[^"
0:"11"
0:"].\n\n"
0:"5"
0:"."
0:" **"
0:"My"
0:"SQL"
0:" Compatibility"
0:"**"
0:":"
0:" Ti"
0:"DB"
0:" is"
0:" compatible"
0:" with"
0:" the"
0:" My"
0:"SQL"
0:" "
0:"5"
0:"."
0:"7"
0:" protocol"
0:" and"
0:" ecosystem"
0:","
0:" allowing"
0:" for"
0:" easy"
0:" migration"
0:" of"
0:" applications"
0:" with"
0:" minimal"
0:" code"
0:" changes"
0:"[^"
0:"12"
0:"]."
0:" However"
0:","
0:" it"
0:" does"
0:" not"
0:" support"
0:" certain"
0:" My"
0:"SQL"
0:" features"
0:" like"
0:" triggers"
0:","
0:" stored"
0:" procedures"
0:","
0:" user"
0:"-defined"
0:" functions"
0:","
0:" and"
0:" foreign"
0:" keys"
0:"[^"
0:"13"
0:"].\n\n"
0:"####"
0:" Distributed"
0:" Transactions"
0:"\n\n"
0:"Ti"
0:"DB"
0:" supports"
0:" distributed"
0:" transactions"
0:" across"
0:" the"
0:" cluster"
0:" using"
0:" a"
0:" two"
0:"-phase"
0:" commit"
0:" protocol"
0:" with"
0:" optim"
0:"izations"
0:" inspired"
0:" by"
0:" Google's"
0:" Per"
0:"col"
0:"ator"
0:"."
0:" This"
0:" model"
0:" relies"
0:" on"
0:" a"
0:" timestamp"
0:" allocator"
0:" to"
0:" assign"
0:" monot"
0:"one"
0:" increasing"
0:" timestamps"
0:" for"
0:" each"
0:" transaction"
0:","
0:" ensuring"
0:" conflict"
0:" detection"
0:"[^"
0:"14"
0:"]."
0:" The"
0:" PD"
0:" server"
0:" acts"
0:" as"
0:" the"
0:" timestamp"
0:" allocator"
0:" in"
0:" a"
0:" Ti"
0:"DB"
0:" cluster"
0:"[^"
0:"15"
0:"].\n\n"
0:"####"
0:" Ease"
0:" of"
0:" Use"
0:"\n\n"
0:"Ti"
0:"DB"
0:" is"
0:" designed"
0:" to"
0:" be"
0:" user"
0:"-friendly"
0:","
0:" allowing"
0:" users"
0:" to"
0:" manage"
0:" it"
0:" with"
0:" familiar"
0:" My"
0:"SQL"
0:" tools"
0:" and"
0:" replace"
0:" My"
0:"SQL"
0:" in"
0:" applications"
0:" without"
0:" significant"
0:" code"
0:" changes"
0:"[^"
0:"16"
0:"]."
0:" It"
0:" supports"
0:" various"
0:" programming"
0:" languages"
0:" through"
0:" My"
0:"SQL"
0:" clients"
0:" or"
0:" drivers"
0:"[^"
0:"17"
0:"].\n\n"
0:"####"
0:" Data"
0:" Migration"
0:" Tools"
0:"\n\n"
0:"Ti"
0:"DB"
0:" provides"
0:" a"
0:" suite"
0:" of"
0:" tools"
0:" for"
0:" migrating"
0:","
0:" replic"
0:"ating"
0:","
0:" and"
0:" backing"
0:" up"
0:" data"
0:"."
0:" These"
0:" tools"
0:" facilitate"
0:" the"
0:" transition"
0:" from"
0:" My"
0:"SQL"
0:" to"
0:" Ti"
0:"DB"
0:","
0:" making"
0:" the"
0:" migration"
0:" process"
0:" straightforward"
0:"[^"
0:"18"
0:"].\n\n"
0:"###"
0:" Conclusion"
0:"\n\n"
0:"Ti"
0:"DB"
0:" is"
0:" a"
0:" robust"
0:","
0:" scalable"
0:","
0:" and"
0:" highly"
0:" available"
0:" distributed"
0:" SQL"
0:" database"
0:" that"
0:" offers"
0:" comprehensive"
0:" support"
0:" for"
0:" HT"
0:"AP"
0:" workloads"
0:"."
0:" Its"
0:" compatibility"
0:" with"
0:" My"
0:"SQL"
0:","
0:" combined"
0:" with"
0:" its"
0:" advanced"
0:" features"
0:" and"
0:" cloud"
0:"-native"
0:" design"
0:","
0:" makes"
0:" it"
0:" a"
0:" versatile"
0:" choice"
0:" for"
0:" various"
0:" data"
0:"-intensive"
0:" applications"
0:".\n\n"
0:"[^"
0:"1"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Overview"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/st"
0:"able"
0:"/"
0:"overview"
0:")\n"
0:"[^"
0:"2"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Introduction"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"5"
0:"."
0:"4"
0:"/"
0:"overview"
0:")\n"
0:"[^"
0:"3"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Introduction"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"7"
0:"."
0:"1"
0:"/t"
0:"id"
0:"b"
0:"-f"
0:"aq"
0:")\n"
0:"[^"
0:"4"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Architecture"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"7"
0:"."
0:"1"
0:"/t"
0:"id"
0:"b"
0:"-"
0:"architecture"
0:")\n"
0:"[^"
0:"5"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Architecture"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"5"
0:"."
0:"3"
0:"/t"
0:"id"
0:"b"
0:"-"
0:"architecture"
0:")\n"
0:"[^"
0:"6"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Key"
0:" Features"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"7"
0:"."
0:"1"
0:"/"
0:"overview"
0:")\n"
0:"[^"
0:"7"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Key"
0:" Features"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"5"
0:"."
0:"4"
0:"/"
0:"overview"
0:")\n"
0:"[^"
0:"8"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" High"
0:" Availability"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"7"
0:"."
0:"1"
0:"/"
0:"overview"
0:")\n"
0:"[^"
0:"9"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" High"
0:" Availability"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"5"
0:"."
0:"4"
0:"/"
0:"overview"
0:")\n"
0:"[^"
0:"10"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" HT"
0:"AP"
0:" Cap"
0:"abilities"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"7"
0:"."
0:"1"
0:"/"
0:"overview"
0:")\n"
0:"[^"
0:"11"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Cloud"
0:"-N"
0:"ative"
0:" Design"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"5"
0:"."
0:"4"
0:"/"
0:"overview"
0:")\n"
0:"[^"
0:"12"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" My"
0:"SQL"
0:" Compatibility"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"7"
0:"."
0:"1"
0:"/t"
0:"id"
0:"b"
0:"-f"
0:"aq"
0:")\n"
0:"[^"
0:"13"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" My"
0:"SQL"
0:" Compatibility"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"6"
0:"."
0:"5"
0:"/t"
0:"id"
0:"b"
0:"-f"
0:"aq"
0:")\n"
0:"[^"
0:"14"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Distributed"
0:" Transactions"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"7"
0:"."
0:"1"
0:"/t"
0:"id"
0:"b"
0:"-f"
0:"aq"
0:")\n"
0:"[^"
0:"15"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Distributed"
0:" Transactions"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"6"
0:"."
0:"5"
0:"/t"
0:"id"
0:"b"
0:"-f"
0:"aq"
0:")\n"
0:"[^"
0:"16"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Ease"
0:" of"
0:" Use"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"7"
0:"."
0:"1"
0:"/t"
0:"id"
0:"b"
0:"-f"
0:"aq"
0:")\n"
0:"[^"
0:"17"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Programming"
0:" Language"
0:" Support"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"7"
0:"."
0:"1"
0:"/t"
0:"id"
0:"b"
0:"-f"
0:"aq"
0:")\n"
0:"[^"
0:"18"
0:"]:"
0:" ["
0:"Ti"
0:"DB"
0:" Data"
0:" Migration"
0:" Tools"
0:" |"
0:" Ping"
0:"CAP"
0:" Docs"
0:"]("
0:"https"
0:"://"
0:"docs"
0:".p"
0:"ing"
0:"cap"
0:".com"
0:"/t"
0:"id"
0:"b"
0:"/v"
0:"5"
0:"."
0:"3"
0:"/t"
0:"id"
0:"b"
0:"-"
0:"architecture"
0:")"
0:""
8:[{"state":"FINISHED","display":"","context":""}]
2:[{"chat":{"origin": null,""updated_at":"2024-07-09T09:55:30","id":"00000000-0000-0000-0000-00000000000","title":"what is tidb?","engine_id":1,"engine_options":"{\"llm\":{\"provider\":\"openai\",\"openai_chat_model\":\"gpt-4o\",\"gemini_chat_model\":\"models/gemini-1.5-flash\",\"reranker_provider\":\"jinaai\",\"reranker_top_k\":10,\"intent_graph_knowledge\":\"Given a list of relationships of a knowledge graph as follows. When there is a conflict in meaning between knowledge relationships, the relationship with the higher `weight` and newer `last_modified_at` value takes precedence.\\n\\nKnowledge sub-queries:\\n\\n{% for sub_query, data in sub_queries.items() %}\\n\\nSub-query: {{ sub_query }}\\n\\n - Entities:\\n\\n{% for entity in data['entities'] %}\\n\\n - Name: {{ entity.name }}\\n - Description: {{ entity.description }}\\n\\n{% endfor %}\\n\\n - Relationships:\\n\\n{% for relationship in data['relationships'] %}\\n\\n - Description: {{ relationship.rag_description }}\\n - Last Modified At: {{ relationship.last_modified_at }}\\n - Meta: {{ relationship.meta | tojson(indent=2) }}\\n\\n{% endfor %}\\n\\n{% endfor %}\\n\",\"normal_graph_knowledge\":\"Given a list of relationships of a knowledge graph as follows. When there is a conflict in meaning between knowledge relationships, the relationship with the higher `weight` and newer `last_modified_at` value takes precedence.\\n\\n---------------------\\nEntities:\\n\\n{% for entity in entities %}\\n\\n- Name: {{ entity.name }}\\n- Description: {{ entity.description }}\\n\\n{% endfor %}\\n\\n---------------------\\n\\nKnowledge relationships:\\n\\n{% for relationship in relationships %}\\n\\n- Description: {{ relationship.rag_description }}\\n- Weight: {{ relationship.weight }}\\n- Last Modified At: {{ relationship.last_modified_at }}\\n- Meta: {{ relationship.meta | tojson(indent=2) }}\\n\\n{% endfor %}\\n\"},\"knowledge_graph\":{\"enabled\":true,\"depth\":2,\"include_meta\":true,\"with_degree\":false,\"using_intent_search\":true}}","user_id":"01907db8-8850-795d-855b-552663c18c9f","created_at":"2024-07-09T09:55:30","deleted_at":null},"user_message":{"post_verification_result_url": null,"id":120007,"created_at":"2024-07-09T09:55:31","role":"user","trace_url":null,"finished_at":null,"user_id":"01907db8-8850-795d-855b-552663c18c9f","updated_at":"2024-07-09T09:55:31","ordinal":1,"content":"what is tidb?","error":null,"sources":[],"chat_id":"00000000-0000-0000-0000-00000000000"},"assistant_message":{"post_verification_result_url": null,"id":120008,"created_at":"2024-07-09T09:55:31","role":"assistant","trace_url":"https://us.cloud.langfuse.com/trace/ccf8aea4-086b-4d72-a7a1-a05ef2e6830a","finished_at":"2024-07-09T09:58:07.692084Z","user_id":"01907db8-8850-795d-855b-552663c18c9f","ordinal":2,"content":"### Comprehensive Overview of TiDB\n\n#### What is TiDB?\n\nTiDB is an open-source distributed SQL database designed to support Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL-compatible and features horizontal scalability, strong consistency, and high availability[^1][^2]. TiDB aims to provide a one-stop database solution that covers OLTP (Online Transactional Processing), OLAP (Online Analytical Processing), and HTAP services[^3].\n\n#### Architecture\n\nThe TiDB cluster consists of three main components[^4]:\n\n1. **TiDB Server**: Handles SQL parsing, query planning, and execution.\n2. **TiKV Server**: Acts as the distributed key-value storage engine, storing the actual data.\n3. **PD (Placement Driver) Server**: Manages cluster metadata, allocates timestamps, and handles data placement and load balancing.\n\nAdditionally, TiDB includes other components such as TiFlash for columnar storage, TiSpark for Spark integration, TiDB Binlog for data replication, and TiDB Lightning for high-performance data import[^5].\n\n#### Key Features\n\n1. **Scalability**: TiDB allows for easy horizontal scaling of both computing and storage resources. The architecture separates computing from storage, enabling independent scaling[^6]. This design ensures flexibility and adaptability to changing workloads[^7].\n\n2. **High Availability**: TiDB guarantees strong consistency and availability through its multi-replica architecture and Multi-Raft protocol. Data is stored in multiple replicas, and a transaction can only be committed when data has been successfully written into the majority of replicas[^8]. This ensures data integrity even in the event of failures[^9].\n\n3. **HTAP Capabilities**: TiDB supports both row-based (TiKV) and columnar (TiFlash) storage engines, enabling real-time processing of both transactional and analytical workloads. TiFlash uses the Multi-Raft Learner protocol to replicate data from TiKV in real time, ensuring consistent data across both engines[^10].\n\n4. **Cloud-Native Design**: TiDB is built for cloud environments, offering flexible scalability, reliability, and security on various cloud platforms. It integrates seamlessly with Kubernetes and offers a fully managed service, TiDB Cloud[^11].\n\n5. **MySQL Compatibility**: TiDB is compatible with the MySQL 5.7 protocol and ecosystem, allowing for easy migration of applications with minimal code changes[^12]. However, it does not support certain MySQL features like triggers, stored procedures, user-defined functions, and foreign keys[^13].\n\n#### Distributed Transactions\n\nTiDB supports distributed transactions across the cluster using a two-phase commit protocol with optimizations inspired by Google's Percolator. This model relies on a timestamp allocator to assign monotone increasing timestamps for each transaction, ensuring conflict detection[^14]. The PD server acts as the timestamp allocator in a TiDB cluster[^15].\n\n#### Ease of Use\n\nTiDB is designed to be user-friendly, allowing users to manage it with familiar MySQL tools and replace MySQL in applications without significant code changes[^16]. It supports various programming languages through MySQL clients or drivers[^17].\n\n#### Data Migration Tools\n\nTiDB provides a suite of tools for migrating, replicating, and backing up data. These tools facilitate the transition from MySQL to TiDB, making the migration process straightforward[^18].\n\n### Conclusion\n\nTiDB is a robust, scalable, and highly available distributed SQL database that offers comprehensive support for HTAP workloads. Its compatibility with MySQL, combined with its advanced features and cloud-native design, makes it a versatile choice for various data-intensive applications.\n\n[^1]: [TiDB Overview | PingCAP Docs](https://docs.pingcap.com/tidb/stable/overview)\n[^2]: [TiDB Introduction | PingCAP Docs](https://docs.pingcap.com/tidb/v5.4/overview)\n[^3]: [TiDB Introduction | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/tidb-faq)\n[^4]: [TiDB Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/tidb-architecture)\n[^5]: [TiDB Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v5.3/tidb-architecture)\n[^6]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/overview)\n[^7]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v5.4/overview)\n[^8]: [TiDB High Availability | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/overview)\n[^9]: [TiDB High Availability | PingCAP Docs](https://docs.pingcap.com/tidb/v5.4/overview)\n[^10]: [TiDB HTAP Capabilities | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/overview)\n[^11]: [TiDB Cloud-Native Design | PingCAP Docs](https://docs.pingcap.com/tidb/v5.4/overview)\n[^12]: [TiDB MySQL Compatibility | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/tidb-faq)\n[^13]: [TiDB MySQL Compatibility | PingCAP Docs](https://docs.pingcap.com/tidb/v6.5/tidb-faq)\n[^14]: [TiDB Distributed Transactions | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/tidb-faq)\n[^15]: [TiDB Distributed Transactions | PingCAP Docs](https://docs.pingcap.com/tidb/v6.5/tidb-faq)\n[^16]: [TiDB Ease of Use | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/tidb-faq)\n[^17]: [TiDB Programming Language Support | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/tidb-faq)\n[^18]: [TiDB Data Migration Tools | PingCAP Docs](https://docs.pingcap.com/tidb/v5.3/tidb-architecture)","error":null,"sources":[{"id":8600,"name":"TiDB FAQs","source_uri":"https://docs.pingcap.com/tidb/v7.5/tidb-faq"},{"id":9057,"name":"Overview","source_uri":"https://docs.pingcap.com/tidbcloud/tidb-architecture"},{"id":9373,"name":"TiDB Introduction","source_uri":"https://docs.pingcap.com/tidb/v7.1/overview"},{"id":9865,"name":"Overview","source_uri":"https://docs.pingcap.com/tidb/v7.1/tidb-architecture"},{"id":10191,"name":"TiDB FAQs","source_uri":"https://docs.pingcap.com/tidb/v7.1/tidb-faq"},{"id":11370,"name":"TiDB FAQs","source_uri":"https://docs.pingcap.com/tidb/v6.5/tidb-faq"},{"id":12985,"name":"TiDB Introduction","source_uri":"https://docs.pingcap.com/tidb/v5.4/overview"},{"id":13942,"name":"Overview","source_uri":"https://docs.pingcap.com/tidb/v5.3/tidb-architecture"},{"id":14218,"name":"TiDB FAQs","source_uri":"https://docs.pingcap.com/tidb/v5.3/tidb-faq"}],"chat_id":"00000000-0000-0000-0000-00000000000"}}]
================================================
FILE: frontend/app/src/api/.gitignore
================================================
*.http
http-client.private.env.json
================================================
FILE: frontend/app/src/api/api-keys.ts
================================================
import { authenticationHeaders, handleErrors, handleResponse, type Page, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { z, type ZodType, type ZodTypeDef } from 'zod';
export interface ApiKey {
created_at: Date;
updated_at: Date;
description: string;
api_key_display: string;
is_active: boolean;
user_id: string;
id: number;
}
export interface CreateApiKey {
description: string;
}
export interface CreateApiKeyResponse {
api_key: string;
}
const apiKeySchema = z.object({
created_at: zodJsonDate(),
updated_at: zodJsonDate(),
description: z.string(),
api_key_display: z.string(),
is_active: z.boolean(),
user_id: z.string(),
id: z.number(),
}) satisfies ZodType;
const createApiKeyResponseSchema = z.object({
api_key: z.string(),
}) satisfies ZodType;
export async function listApiKeys ({ page = 1, size = 10 }: PageParams = {}): Promise> {
return await fetch(requestUrl('/api/v1/api-keys', { page, size }), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(apiKeySchema)));
}
export async function createApiKey (create: CreateApiKey): Promise {
return await fetch(requestUrl('/api/v1/api-keys'), {
method: 'POST',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(create),
}).then(handleResponse(createApiKeyResponseSchema));
}
export async function deleteApiKey (id: number): Promise {
await fetch(requestUrl(`/api/v1/api-keys/${id}`), {
method: 'DELETE',
headers: {
...await authenticationHeaders(),
},
}).then(handleErrors);
}
================================================
FILE: frontend/app/src/api/auth.ts
================================================
import { authenticationHeaders, buildUrlParams, handleErrors, requestUrl } from '@/lib/request';
export interface LoginParams {
username: string;
password: string;
}
export async function login (params: LoginParams) {
const usp = buildUrlParams(params);
await fetch(requestUrl('/api/v1/auth/login'), {
method: 'POST',
body: usp,
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/x-www-form-urlencoded',
},
}).then(handleErrors);
}
export async function logout () {
await fetch(requestUrl('/api/v1/auth/logout'), {
headers: {
...await authenticationHeaders(),
},
method: 'POST',
}).then(handleErrors);
}
================================================
FILE: frontend/app/src/api/chat-engines.ts
================================================
import { authenticationHeaders, handleErrors, handleResponse, type Page, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { z, type ZodType } from 'zod';
export interface ChatEngine {
id: number;
name: string;
updated_at: Date;
created_at: Date;
deleted_at: Date | null;
engine_options: ChatEngineOptions;
llm_id: number | null;
fast_llm_id: number | null;
reranker_id: number | null;
is_default: boolean;
is_public: boolean;
}
export interface CreateChatEngineParams {
name: string;
engine_options: ChatEngineOptions;
llm_id?: number | null;
fast_llm_id?: number | null;
reranker_id?: number | null;
}
export interface ChatEngineOptions {
external_engine_config?: {
stream_chat_api_url?: string | null
} | null;
clarify_question?: boolean | null;
further_questions?: boolean | null;
knowledge_base?: ChatEngineKnowledgeBaseOptions | null;
knowledge_graph?: ChatEngineKnowledgeGraphOptions | null;
llm?: ChatEngineLLMOptions | null;
post_verification_url?: string | null;
post_verification_token?: string | null;
hide_sources?: boolean | null;
}
export interface ChatEngineKnowledgeBaseOptions {
/**
* @deprecated
*/
linked_knowledge_base?: LinkedKnowledgeBaseOptions | null;
linked_knowledge_bases?: { id: number }[] | null;
}
export interface ChatEngineKnowledgeGraphOptions {
depth?: number | null;
enabled?: boolean | null;
include_meta?: boolean | null;
with_degree?: boolean | null;
using_intent_search?: boolean | null;
}
export type ChatEngineLLMOptions = {
condense_question_prompt?: string | null
text_qa_prompt?: string | null
intent_graph_knowledge?: string | null
normal_graph_knowledge?: string | null
clarifying_question_prompt?: string | null
generate_goal_prompt?: string | null
further_questions_prompt?: string | null
}
/**
* @deprecated
*/
export interface LinkedKnowledgeBaseOptions {
id?: number | null;
}
const kbOptionsSchema = z.object({
linked_knowledge_base: z.object({ id: z.number().nullable().optional() }).nullable().optional(),
linked_knowledge_bases: z.object({ id: z.number() }).array().nullable().optional(),
}).passthrough();
const kgOptionsSchema = z.object({
depth: z.number().nullable().optional(),
enabled: z.boolean().nullable().optional(),
include_meta: z.boolean().nullable().optional(),
with_degree: z.boolean().nullable().optional(),
using_intent_search: z.boolean().nullable().optional(),
}).passthrough() satisfies ZodType;
const llmOptionsSchema =
z.object({
condense_question_prompt: z.string().nullable().optional(),
text_qa_prompt: z.string().nullable().optional(),
intent_graph_knowledge: z.string().nullable().optional(),
normal_graph_knowledge: z.string().nullable().optional(),
clarifying_question_prompt: z.string().nullable().optional(),
generate_goal_prompt: z.string().nullable().optional(),
further_questions_prompt: z.string().nullable().optional(),
// provider: z.string(),
// reranker_provider: z.string(),
// reranker_top_k: z.number(),
}).passthrough() as ZodType;
const chatEngineOptionsSchema = z.object({
external_engine_config: z.object({
stream_chat_api_url: z.string().optional().nullable(),
}).nullable().optional(),
clarify_question: z.boolean().nullable().optional(),
further_questions: z.boolean().nullable().optional(),
knowledge_base: kbOptionsSchema.nullable().optional(),
knowledge_graph: kgOptionsSchema.nullable().optional(),
llm: llmOptionsSchema.nullable().optional(),
post_verification_url: z.string().nullable().optional(),
post_verification_token: z.string().nullable().optional(),
hide_sources: z.boolean().nullable().optional(),
}).passthrough()
.refine(option => {
if (!option.knowledge_base?.linked_knowledge_bases?.length) {
if (option.knowledge_base?.linked_knowledge_base?.id != null) {
// Frontend temporary migration. Should be removed after backend removed linked_knowledge_base field.
option.knowledge_base.linked_knowledge_bases = [{
id: option.knowledge_base.linked_knowledge_base.id,
}];
delete option.knowledge_base.linked_knowledge_base;
}
}
return option;
}) satisfies ZodType;
const chatEngineSchema = z.object({
id: z.number(),
name: z.string(),
updated_at: zodJsonDate(),
created_at: zodJsonDate(),
deleted_at: zodJsonDate().nullable(),
engine_options: chatEngineOptionsSchema,
llm_id: z.number().nullable(),
fast_llm_id: z.number().nullable(),
reranker_id: z.number().nullable(),
is_default: z.boolean(),
is_public: z.boolean(),
}) satisfies ZodType;
export async function getDefaultChatEngineOptions (): Promise {
return await fetch(requestUrl('/api/v1/admin/chat-engines-default-config'), {
headers: await authenticationHeaders(),
})
.then(handleResponse(chatEngineOptionsSchema));
}
export async function listChatEngines ({ page = 1, size = 10 }: PageParams = {}): Promise> {
return await fetch(requestUrl('/api/v1/admin/chat-engines', { page, size }), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(chatEngineSchema)));
}
export async function getChatEngine (id: number): Promise {
return await fetch(requestUrl(`/api/v1/admin/chat-engines/${id}`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(chatEngineSchema));
}
export async function updateChatEngine (id: number, partial: Partial>): Promise {
await fetch(requestUrl(`/api/v1/admin/chat-engines/${id}`), {
method: 'PUT',
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
body: JSON.stringify(partial),
})
.then(handleErrors);
}
export async function createChatEngine (create: CreateChatEngineParams) {
return await fetch(requestUrl(`/api/v1/admin/chat-engines`), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
body: JSON.stringify(create),
})
.then(handleResponse(chatEngineSchema));
}
export async function deleteChatEngine (id: number): Promise {
await fetch(requestUrl(`/api/v1/admin/chat-engines/${id}`), {
method: 'DELETE',
headers: {
...await authenticationHeaders(),
},
})
.then(handleErrors);
}
export async function listPublicChatEngines ({ page = 1, size = 10 }: PageParams = {}): Promise> {
return await fetch(requestUrl('/api/v1/chat-engines', { page, size }), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(chatEngineSchema)));
}
export async function getPublicChatEngine (id: number): Promise {
return await fetch(requestUrl(`/api/v1/chat-engines/${id}`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(chatEngineSchema));
}
================================================
FILE: frontend/app/src/api/chats.ts
================================================
import type { ChatEngineOptions } from '@/api/chat-engines';
import { type KnowledgeGraph, KnowledgeGraphEntityType, knowledgeGraphSchema } from '@/api/graph';
import { bufferedReadableStreamTransformer } from '@/lib/buffered-readable-stream';
import { authenticationHeaders, handleErrors, handleResponse, type Page, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { parseStreamPart } from 'ai';
import { z, type ZodType } from 'zod';
type ClientEngineOptions = Omit;
export interface Chat {
title: string;
engine_id: number;
engine_options: ClientEngineOptions;
deleted_at: Date | null;
user_id: string | null;
browser_id: string | null;
updated_at: Date;
created_at: Date;
id: string;
origin: string | null;
}
export interface ChatDetail {
chat: Chat;
messages: ChatMessage[];
}
export const enum ChatMessageRole {
assistant = 'assistant',
user = 'user'
}
export interface ChatMessage {
id: number;
role: ChatMessageRole;
error: string | null;
trace_url: string | null;
finished_at: Date | null;
user_id: string | null;
created_at: Date;
updated_at: Date;
ordinal: number;
content: string;
sources: ChatMessageSource[];
chat_id: string;
post_verification_result_url: string | null;
}
export interface ChatMessageSource {
id: number;
name: string;
source_uri: string;
}
export const chatSchema = z.object({
title: z.string(),
engine_id: z.number(),
engine_options: z.object({}).passthrough().transform(value => value as never as ChatEngineOptions),
deleted_at: zodJsonDate().nullable(),
user_id: z.string().nullable(),
browser_id: z.string().nullable(),
updated_at: zodJsonDate(),
created_at: zodJsonDate(),
id: z.string(),
origin: z.string().nullable(),
}) satisfies ZodType;
const chatMessageSourceSchema = z.object({
id: z.number(),
name: z.string(),
source_uri: z.string(),
});
export const chatMessageSchema = z.object({
id: z.number(),
role: z.enum([ChatMessageRole.user, ChatMessageRole.assistant]),
error: z.string().nullable(),
trace_url: z.string().nullable(),
finished_at: zodJsonDate().nullable(),
user_id: z.string().nullable(),
created_at: zodJsonDate(),
updated_at: zodJsonDate(),
ordinal: z.number(),
content: z.string(),
sources: chatMessageSourceSchema.array(),
chat_id: z.string(),
post_verification_result_url: z.string().url().nullable(),
}) satisfies ZodType;
const chatDetailSchema = z.object({
chat: chatSchema,
messages: chatMessageSchema.array(),
});
export interface FeedbackParams {
feedback_type: 'like' | 'dislike';
comment: string;
}
export interface PostChatParams {
chat_id?: string;
chat_engine?: string;
content: string;
headers?: HeadersInit;
signal?: AbortSignal;
}
export async function listChats ({ page = 1, size = 10 }: PageParams = {}): Promise> {
return await fetch(requestUrl('/api/v1/chats', { page, size }), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(chatSchema)));
}
export async function getChat (id: string): Promise {
return await fetch(requestUrl(`/api/v1/chats/${id}`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(chatDetailSchema));
}
export async function deleteChat (id: string): Promise {
await fetch(requestUrl(`/api/v1/chats/${id}`), {
method: 'delete',
headers: await authenticationHeaders(),
}).then(handleErrors);
}
export async function postFeedback (chatMessageId: number, feedback: FeedbackParams) {
return await fetch(requestUrl(`/api/v1/chat-messages/${chatMessageId}/feedback`), {
method: 'post',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
credentials: 'include',
body: JSON.stringify(feedback),
}).then(handleErrors);
}
export async function getChatMessageSubgraph (chatMessageId: number): Promise {
return await fetch(requestUrl(`/api/v1/chat-messages/${chatMessageId}/subgraph`), {
headers: await authenticationHeaders(),
credentials: 'include',
})
.then(handleResponse(knowledgeGraphSchema));
}
export async function getChatMessageRecommendedQuestions (chatMessageId: number) {
return await fetch(requestUrl(`/api/v1/chat-messages/${chatMessageId}/recommended-questions`), {
headers: await authenticationHeaders(),
credentials: 'include',
})
.then(handleResponse(z.string().array()));
}
export async function reloadChatMessageRecommendedQuestions (chatMessageId: number) {
return await fetch(requestUrl(`/api/v1/chat-messages/${chatMessageId}/recommended-questions`), {
method: 'POST',
headers: await authenticationHeaders(),
credentials: 'include',
})
.then(handleResponse(z.string().array()));
}
export async function* chat ({ chat_id, chat_engine, content, headers: headersInit, signal }: PostChatParams, onResponse?: (response: Response) => void) {
const headers = new Headers(headersInit);
headers.set('Content-Type', 'application/json');
for (let [key, value] of Object.entries(await authenticationHeaders())) {
headers.set(key, value);
}
const response = await fetch(requestUrl(`/api/v1/chats`), {
method: 'POST',
headers,
credentials: 'include',
body: JSON.stringify({
chat_id,
chat_engine,
stream: true,
messages: [{
'role': 'user',
content,
}],
}),
signal,
}).then(handleErrors);
onResponse?.(response);
if (!response.body) {
throw new Error(`${response.status} ${response.statusText} Empty response body`);
}
const reader = response.body.pipeThrough(bufferedReadableStreamTransformer()).getReader();
while (true) {
const chunk = await reader.read();
if (chunk.done) {
break;
}
if (!!chunk.value.trim()) {
yield parseStreamPart(chunk.value);
}
}
}
================================================
FILE: frontend/app/src/api/commons.ts
================================================
import { handleErrors, requestUrl } from '@/lib/request';
export async function setDefault (type: 'embedding-models' | 'llms' | 'reranker-models', id: number) {
await fetch(requestUrl(`/api/v1/admin/${type}/${id}/set_default`), {
method: 'PUT',
}).then(handleErrors);
}
================================================
FILE: frontend/app/src/api/datasources.ts
================================================
import { type IndexProgress, indexSchema, type IndexTotalStats, totalSchema } from '@/api/rag';
import { authenticationHeaders, handleErrors, handleResponse, type Page, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { z, type ZodType } from 'zod';
export interface DatasourceBase {
id: number;
name: string;
}
type DatasourceSpec = ({
data_source_type: 'file'
config: { file_id: number, file_name: string }[]
} | {
data_source_type: 'web_sitemap'
config: { url: string }
} | {
data_source_type: 'web_single_page'
config: { urls: string[] }
})
export type Datasource = DatasourceBase & DatasourceSpec;
export type DataSourceIndexProgress = {
vector_index: IndexProgress
documents: IndexTotalStats
chunks: IndexTotalStats
kg_index?: IndexProgress
relationships?: IndexTotalStats
}
export interface BaseCreateDatasourceParams {
name: string;
}
export type CreateDatasourceSpecParams = ({
data_source_type: 'file'
config: { file_id: number, file_name: string }[]
} | {
data_source_type: 'web_single_page'
config: { urls: string[] }
} | {
data_source_type: 'web_sitemap'
config: { url: string }
});
export type CreateDatasourceParams = BaseCreateDatasourceParams & CreateDatasourceSpecParams;
export interface Upload {
created_at?: Date;
updated_at?: Date;
id: number;
name: string;
size: number;
path: string;
mime_type: string;
user_id: string;
}
export type DatasourceVectorIndexError = {
document_id: number
document_name: string
source_uri: string
error: string | null
}
export type DatasourceKgIndexError = {
document_id: number
document_name: string
chunk_id: string
source_uri: string
error: string | null
}
const datasourceSpecSchema = z.discriminatedUnion('data_source_type', [
z.object({
data_source_type: z.literal('file'),
config: z.array(z.object({ file_id: z.number(), file_name: z.string() })),
}),
z.object({
data_source_type: z.enum(['web_single_page']),
config: z.object({ urls: z.string().array() }).or(z.object({ url: z.string() })).transform(obj => {
if ('url' in obj) {
return { urls: [obj.url] };
} else {
return obj;
}
}),
}),
z.object({
data_source_type: z.enum(['web_sitemap']),
config: z.object({ url: z.string() }),
})],
) satisfies ZodType;
export const datasourceSchema = z.object({
id: z.number(),
name: z.string(),
}).and(datasourceSpecSchema) satisfies ZodType;
const uploadSchema = z.object({
id: z.number(),
name: z.string(),
size: z.number(),
path: z.string(),
mime_type: z.string(),
user_id: z.string(),
created_at: zodJsonDate().optional(),
updated_at: zodJsonDate().optional(),
}) satisfies ZodType;
z.object({
vector_index: indexSchema,
documents: totalSchema,
chunks: totalSchema,
kg_index: indexSchema.optional(),
relationships: totalSchema.optional(),
}) satisfies ZodType;
export async function listDataSources (kbId: number, { page = 1, size = 10 }: PageParams = {}): Promise> {
return fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/datasources`, { page, size }), {
headers: await authenticationHeaders(),
}).then(handleResponse(zodPage(datasourceSchema)));
}
export async function getDatasource (kbId: number, id: number): Promise {
return fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/datasources/${id}`), {
headers: await authenticationHeaders(),
}).then(handleResponse(datasourceSchema));
}
export async function deleteDatasource (kbId: number, id: number): Promise {
await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/datasources/${id}`), {
method: 'DELETE',
headers: await authenticationHeaders(),
}).then(handleErrors);
}
export async function createDatasource (kbId: number, params: CreateDatasourceParams) {
return fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/datasources`), {
method: 'POST',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
}).then(handleResponse(datasourceSchema));
}
export async function updateDatasource (kbId: number, id: number, params: { name: string }) {
return fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/datasources/${id}`), {
method: 'PUT',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
}).then(handleResponse(datasourceSchema));
}
export async function uploadFiles (files: File[]) {
const formData = new FormData();
files.forEach((file) => {
formData.append('files', file);
});
return fetch(requestUrl(`/api/v1/admin/uploads`), {
method: 'POST',
headers: {
...await authenticationHeaders(),
},
body: formData,
}).then(handleResponse(uploadSchema.array()));
}
================================================
FILE: frontend/app/src/api/documents.ts
================================================
import { indexStatuses } from '@/api/rag';
import { authenticationHeaders, handleResponse, type Page, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { z, type ZodType } from 'zod';
export const mimeTypes = [
{ name: 'Text', value: 'text/plain' },
{ name: 'Markdown', value: 'text/markdown' },
{ name: 'Pdf', value: 'application/pdf' },
{ name: 'Microsoft Word (docx)', value: 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' },
{ name: 'Microsoft PowerPoint (pptx)', value: 'application/vnd.openxmlformats-officedocument.presentationml.presentation' },
{ name: 'Microsoft Excel (xlsx)', value: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' },
] as const satisfies MimeType[];
const mimeValues: (typeof mimeTypes)[number]['value'] = mimeTypes.map(m => m.value) as never;
export interface Document {
id: number,
name: string,
created_at?: Date | undefined;
updated_at?: Date | undefined
last_modified_at: Date,
hash: string
content: string
meta: object,
mime_type: string,
source_uri: string,
index_status: string,
index_result?: unknown
data_source: {
id: number
name: string
}
knowledge_base: {
id: number
name: string
} | null
}
export const documentSchema = z.object({
id: z.number(),
name: z.string(),
created_at: zodJsonDate(),
updated_at: zodJsonDate(),
last_modified_at: zodJsonDate(),
hash: z.string(),
content: z.string(),
meta: z.object({}).passthrough(),
mime_type: z.string(),
source_uri: z.string(),
index_status: z.string(),
index_result: z.unknown(),
data_source: z.object({
id: z.number(),
name: z.string(),
}),
knowledge_base: z.object({
id: z.number(),
name: z.string(),
}).nullable(),
}) satisfies ZodType;
const zDate = z.coerce.date().or(z.literal('').transform(() => undefined)).optional();
const zDateRange = z.tuple([zDate, zDate]).optional();
export const listDocumentsFiltersSchema = z.object({
search: z.string().optional(),
knowledge_base_id: z.number().optional(),
created_at: zDateRange,
updated_at: zDateRange,
last_modified_at: zDateRange,
mime_type: z.enum(mimeValues).optional(),
index_status: z.enum(indexStatuses).optional(),
});
export type ListDocumentsTableFilters = z.infer;
export async function listDocuments ({ page = 1, size = 10, knowledge_base_id, search, ...filters }: PageParams & ListDocumentsTableFilters = {}): Promise> {
const apiFilters = {
...filters,
knowledge_base_id,
search: search
};
const api_url = knowledge_base_id != null ? `/api/v1/admin/knowledge_bases/${knowledge_base_id}/documents` : '/api/v1/admin/documents';
return await fetch(requestUrl(api_url, { page, size, ...apiFilters }), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(documentSchema)));
}
export interface MimeType {
name: string;
value: string;
}
================================================
FILE: frontend/app/src/api/embedding-models.ts
================================================
import { type ProviderOption, providerOptionSchema } from '@/api/providers';
import { authenticationHeaders, handleNullableResponse, handleResponse, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { z, type ZodType, type ZodTypeDef } from 'zod';
export interface EmbeddingModelSummary {
id: number;
name: string;
provider: string;
model: string;
vector_dimension: number,
is_default: boolean
}
export interface EmbeddingModel extends EmbeddingModelSummary {
config?: any;
created_at?: Date | null;
updated_at?: Date | null;
}
export interface EmbeddingModelOption extends ProviderOption {
default_embedding_model: string;
embedding_model_description: string;
}
export interface CreateEmbeddingModel {
name: string;
provider: string;
model: string;
vector_dimension: number;
config?: any;
credentials: string | object;
}
export interface UpdateEmbeddingModel {
name?: string;
config?: any;
credentials?: string | object;
}
export const embeddingModelSummarySchema = z.object({
id: z.number(),
name: z.string(),
provider: z.string(),
model: z.string(),
vector_dimension: z.number(),
is_default: z.boolean(),
}) satisfies ZodType;
const embeddingModelSchema = embeddingModelSummarySchema.extend({
config: z.any(),
created_at: zodJsonDate().nullable().optional(),
updated_at: zodJsonDate().nullable().optional(),
}) satisfies ZodType;
const embeddingModelOptionSchema = providerOptionSchema.and(z.object({
default_embedding_model: z.string(),
embedding_model_description: z.string(),
})) satisfies ZodType;
export async function listEmbeddingModelOptions () {
return await fetch(requestUrl(`/api/v1/admin/embedding-models/providers/options`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(embeddingModelOptionSchema.array()));
}
export async function getEmbeddingModel (id: number) {
return await fetch(requestUrl(`/api/v1/admin/embedding-models/${id}`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(embeddingModelSchema));
}
export async function listEmbeddingModels (params: PageParams) {
return await fetch(requestUrl(`/api/v1/admin/embedding-models`, params), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(embeddingModelSchema)));
}
export async function createEmbeddingModel (create: CreateEmbeddingModel) {
return await fetch(requestUrl(`/api/v1/admin/embedding-models`), {
method: 'POST',
body: JSON.stringify(create),
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
}).then(handleResponse(embeddingModelSchema));
}
export async function updateEmbeddingModel (id: number, update: UpdateEmbeddingModel) {
return await fetch(requestUrl(`/api/v1/admin/embedding-models/${id}`), {
method: 'PUT',
body: JSON.stringify(update),
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
}).then(handleResponse(embeddingModelSchema));
}
export async function testEmbeddingModel (createEmbeddingModel: CreateEmbeddingModel) {
return await fetch(requestUrl(`/api/v1/admin/embedding-models/test`), {
method: 'POST',
body: JSON.stringify(createEmbeddingModel),
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
})
.then(handleResponse(z.object({
success: z.boolean(),
error: z.string().optional(),
})));
}
================================================
FILE: frontend/app/src/api/evaluations.ts
================================================
import { authenticationHeaders, handleErrors, handleResponse, type Page, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJson, zodJsonDate } from '@/lib/zod';
import { z, type ZodType } from 'zod';
export interface EvaluationDataset {
id: number;
name: string;
user_id: string;
created_at: Date;
updated_at: Date;
}
export interface EvaluationDatasetItem {
created_at: Date;
updated_at: Date;
id: number;
query: string;
reference: string;
retrieved_contexts: string[];
extra: any;
evaluation_dataset_id: number;
}
export interface EvaluationTask {
id: number;
name: string;
user_id: string;
created_at: Date;
updated_at: Date;
dataset_id: number;
}
export interface EvaluationTaskWithSummary extends EvaluationTask {
summary: EvaluationTaskSummary;
}
export const EvaluationTaskSummaryMetrics = [
'avg_factual_correctness',
'avg_semantic_similarity',
'min_factual_correctness',
'min_semantic_similarity',
'max_factual_correctness',
'max_semantic_similarity',
'std_factual_correctness',
'std_semantic_similarity',
] as const;
export type EvaluationTaskSummaryMetric = typeof EvaluationTaskSummaryMetrics[number];
export interface EvaluationTaskSummary extends Record {
not_start: number;
succeed: number;
errored: number;
progressing: number;
}
export type EvaluationTaskItemStatus = 'not_start' | 'evaluating' | 'done' | 'error' | 'cancel'
export interface EvaluationTaskItem {
created_at: Date;
updated_at: Date;
id: number;
chat_engine: string;
status: EvaluationTaskItemStatus;
query: string;
reference: string;
response: string | null;
retrieved_contexts: string[] | null;
extra: any | null;
error_msg: string | null;
factual_correctness: number | null;
semantic_similarity: number | null;
evaluation_task_id: number;
}
export interface CreateEvaluationDatasetParams {
name: string;
upload_id?: number;
}
export interface UpdateEvaluationDatasetParams {
name: string;
}
export interface CreateEvaluationDatasetItemParams {
query: string;
reference: string;
retrieved_contexts: string[];
extra: any;
}
export interface UpdateEvaluationDatasetItemParams {
query: string;
reference: string;
retrieved_contexts: string[];
extra: any;
}
export interface CreateEvaluationTaskParams {
name: string;
evaluation_dataset_id: number;
chat_engine?: string;
run_size?: number;
}
const evaluationDatasetSchema = z.object({
id: z.number(),
name: z.string(),
user_id: z.string(),
created_at: zodJsonDate(),
updated_at: zodJsonDate(),
}) satisfies ZodType;
const evaluationDatasetItemSchema = z.object({
created_at: zodJsonDate(),
updated_at: zodJsonDate(),
id: z.number(),
query: z.string(),
reference: z.string(),
retrieved_contexts: z.string().array(),
extra: zodJson(),
evaluation_dataset_id: z.number(),
}) satisfies ZodType;
const evaluationTaskSummarySchema = z.object({
not_start: z.number(),
succeed: z.number(),
errored: z.number(),
progressing: z.number(),
avg_factual_correctness: z.number().nullable(),
avg_semantic_similarity: z.number().nullable(),
min_factual_correctness: z.number().nullable(),
min_semantic_similarity: z.number().nullable(),
max_factual_correctness: z.number().nullable(),
max_semantic_similarity: z.number().nullable(),
std_factual_correctness: z.number().nullable(),
std_semantic_similarity: z.number().nullable(),
}) satisfies ZodType;
const evaluationTaskSchema = z.object({
id: z.number(),
name: z.string(),
user_id: z.string(),
created_at: zodJsonDate(),
updated_at: zodJsonDate(),
dataset_id: z.number(),
}) satisfies ZodType;
const evaluationTaskWithSummarySchema = evaluationTaskSchema.extend({
summary: evaluationTaskSummarySchema,
});
const evaluationTaskItemSchema = z.object({
created_at: zodJsonDate(),
updated_at: zodJsonDate(),
id: z.number(),
chat_engine: z.string(),
status: z.enum(['not_start', 'evaluating', 'done', 'error', 'cancel']),
query: z.string(),
reference: z.string(),
response: z.string().nullable(),
retrieved_contexts: z.string().array().nullable(),
extra: zodJson().nullable(),
error_msg: z.string().nullable(),
factual_correctness: z.number().nullable(),
semantic_similarity: z.number().nullable(),
evaluation_task_id: z.number(),
}) satisfies ZodType;
// Datasets
export async function listEvaluationDatasets ({ ...params }: PageParams & { keyword?: string }): Promise> {
return fetch(requestUrl('/api/v1/admin/evaluation/datasets', params), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(evaluationDatasetSchema)));
}
export async function createEvaluationDataset (params: CreateEvaluationDatasetParams): Promise {
return fetch(requestUrl('/api/v1/admin/evaluation/datasets'), {
method: 'POST',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
})
.then(handleResponse(evaluationDatasetSchema));
}
export async function updateEvaluationDataset (id: number, params: UpdateEvaluationDatasetParams): Promise {
return fetch(requestUrl(`/api/v1/admin/evaluation/datasets/${id}`), {
method: 'PUT',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
})
.then(handleResponse(evaluationDatasetSchema));
}
export async function deleteEvaluationDataset (id: number): Promise {
await fetch(requestUrl(`/api/v1/admin/evaluation/datasets/${id}`), {
method: 'DELETE',
headers: {
...await authenticationHeaders(),
},
});
}
// Dataset Items
export async function listEvaluationDatasetItems (datasetId: number, { ...params }: PageParams & { keyword?: string }): Promise> {
return fetch(requestUrl(`/api/v1/admin/evaluation/datasets/${datasetId}/dataset-items`, params), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(evaluationDatasetItemSchema)));
}
export async function createEvaluationDatasetItem (datasetId: number, params: CreateEvaluationDatasetItemParams): Promise {
return fetch(requestUrl(`/api/v1/admin/evaluation/dataset-items`), {
method: 'POST',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify({
evaluation_dataset_id: datasetId,
...params,
}),
})
.then(handleResponse(evaluationDatasetItemSchema));
}
export async function updateEvaluationDatasetItem (datasetId: number, id: number, params: UpdateEvaluationDatasetItemParams): Promise {
return fetch(requestUrl(`/api/v1/admin/evaluation/dataset-items/${id}`), {
method: 'PUT',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify({
evaluation_dataset_id: datasetId,
...params,
}),
})
.then(handleResponse(evaluationDatasetItemSchema));
}
export async function getEvaluationDatasetItem (datasetId: number, id: number) {
return await fetch(requestUrl(`/api/v1/admin/evaluation/dataset-items/${id}`), {
method: 'GET',
headers: {
...await authenticationHeaders(),
},
})
.then(handleResponse(evaluationDatasetItemSchema));
}
export async function deleteEvaluationDatasetItem (datasetId: number, id: number): Promise {
await fetch(requestUrl(`/api/v1/admin/evaluation/dataset-items/${id}`), {
method: 'DELETE',
headers: {
...await authenticationHeaders(),
},
})
.then(handleErrors);
}
// Tasks
export async function createEvaluationTask (params: CreateEvaluationTaskParams): Promise {
return fetch(requestUrl('/api/v1/admin/evaluation/tasks'), {
method: 'POST',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
})
.then(handleResponse(evaluationTaskSchema));
}
export async function listEvaluationTasks ({ ...params }: PageParams & { keyword?: string }): Promise> {
return fetch(requestUrl('/api/v1/admin/evaluation/tasks', params), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(evaluationTaskWithSummarySchema)));
}
export async function getEvaluationTaskWithSummary (id: number): Promise {
return fetch(requestUrl(`/api/v1/admin/evaluation/tasks/${id}/summary`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(evaluationTaskWithSummarySchema));
}
export async function cancelEvaluationTask (id: number): Promise {
await fetch(requestUrl(`/api/v1/admin/evaluation/tasks/${id}`), {
method: 'DELETE',
headers: await authenticationHeaders(),
})
.then(handleErrors);
}
export async function listEvaluationTaskItems (id: number, params: PageParams & { keyword?: string }): Promise> {
return fetch(requestUrl(`/api/v1/admin/evaluation/tasks/${id}/items`, params), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(evaluationTaskItemSchema)));
}
================================================
FILE: frontend/app/src/api/feedbacks.ts
================================================
import { authenticationHeaders, handleResponse, type Page, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { z, type ZodType } from 'zod';
export const enum FeedbackType {
like = 'like',
dislike = 'dislike',
}
export interface Feedback {
id: number;
feedback_type: FeedbackType;
comment: string;
chat_id: string;
chat_message_id: number;
created_at: Date;
updated_at: Date;
user_id: string | null;
chat_title: string;
origin: string | null;
chat_origin: string | null;
chat_message_content: string;
user_email: string | null;
}
const feedbackSchema = z.object({
id: z.number(),
feedback_type: z.enum([FeedbackType.like, FeedbackType.dislike]),
comment: z.string(),
chat_id: z.string(),
chat_message_id: z.number(),
created_at: zodJsonDate(),
updated_at: zodJsonDate(),
user_id: z.string().nullable(),
user_email: z.string().nullable(),
origin: z.string().nullable(),
chat_title: z.string(),
chat_origin: z.string().nullable(),
chat_message_content: z.string(),
}) satisfies ZodType;
export async function listFeedbacks ({ page = 1, size = 10 }: PageParams = {}): Promise> {
return await fetch(requestUrl('/api/v1/admin/feedbacks', { page, size }), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(feedbackSchema)));
}
================================================
FILE: frontend/app/src/api/graph.ts
================================================
import { authenticationHeaders, handleResponse, requestUrl } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { bufferedReadableStreamTransformer } from '@/lib/buffered-readable-stream';
import { z, type ZodType } from 'zod';
export interface KnowledgeGraph {
entities: KnowledgeGraphEntity[];
relationships: KnowledgeGraphRelationship[];
}
export const enum KnowledgeGraphEntityType {
original = 'original',
synopsis = 'synopsis',
}
export interface KnowledgeGraphEntity {
id: number;
knowledge_base_id?: number | null;
name: string;
description: string;
meta: object;
entity_type: KnowledgeGraphEntityType;
synopsis_info?: {
entities: number[]
topic: string
} | null;
}
export interface KnowledgeGraphRelationship {
id: number;
source_entity_id: number;
target_entity_id: number;
knowledge_base_id?: number | null;
description: string;
meta: object;
weight: number;
}
export const entitySchema = z.object({
id: z.number(),
knowledge_base_id: z.number().nullable().optional(),
name: z.string(),
description: z.string(),
meta: z.object({}).passthrough(),
entity_type: z.enum([KnowledgeGraphEntityType.original, KnowledgeGraphEntityType.synopsis]),
synopsis_info: z.object({
entities: z.number().array(),
topic: z.string(),
}).nullish(),
}) satisfies ZodType;
export const relationshipSchema = z.object({
id: z.number(),
knowledge_base_id: z.number().nullable().optional(),
source_entity_id: z.number(),
target_entity_id: z.number(),
description: z.string(),
last_modified_at: zodJsonDate().nullish(),
meta: z.object({}).passthrough(),
weight: z.number(),
}) satisfies ZodType;
export const knowledgeGraphSchema = z.object({
entities: entitySchema.array(),
relationships: relationshipSchema.array(),
}) satisfies ZodType;
export interface UpdateEntityParams {
name: string | null;
description: string | null;
meta: object | null;
}
export interface CreateSynopsisEntityParams {
name: string;
description: string;
meta: object;
topic: string;
entities: number[];
}
export interface UpdateRelationshipParams {
description: string | null;
meta: object | null;
weight: number | null;
}
export interface GraphSearchParams {
query: string;
include_meta?: boolean;
depth?: number;
with_degree?: boolean;
}
export interface KBRetrieveKnowledgeGraphParams {
query: string;
llm_id: number;
retrieval_config: {
knowledge_graph: KnowledgeGraphRetrievalConfig;
}
}
export interface KnowledgeGraphRetrievalConfig {
depth?: number;
include_meta?: boolean;
with_degree?: boolean;
metadata_filter?: {
enabled?: boolean;
filters?: Record
}
}
export async function search (kbId: number, params: GraphSearchParams) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/graph/search`), {
method: 'post',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
}).then(handleResponse(knowledgeGraphSchema));
}
export async function searchEntity (kbId: number, query: string, top_k: number = 10) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/graph/entities/search`, { query, top_k }), {
headers: {
...await authenticationHeaders(),
},
})
.then(handleResponse(entitySchema.array()));
}
export async function getEntity (kbId: number, id: number) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/graph/entities/${id}`), {
headers: {
...await authenticationHeaders(),
},
})
.then(handleResponse(entitySchema));
}
export async function updateEntity (kbId: number, id: number, params: UpdateEntityParams) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/graph/entities/${id}`), {
method: 'put',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
}).then(handleResponse(entitySchema));
}
export async function createSynopsisEntity (kbId: number, params: CreateSynopsisEntityParams) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/graph/entities/synopsis`), {
method: 'post',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
}).then(handleResponse(entitySchema));
}
export async function getEntitySubgraph (kbId: number, id: number) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/graph/entities/${id}/subgraph`), {
headers: {
...await authenticationHeaders(),
},
})
.then(handleResponse(knowledgeGraphSchema));
}
export async function getEntireKnowledgeGraph (kbId: number, params: KBRetrieveKnowledgeGraphParams) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/graph/entire_graph`), {
method: 'post',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
})
.then(handleResponse(knowledgeGraphSchema));
}
export async function streamEntireKnowledgeGraph (kbId: number): Promise {
const entities: KnowledgeGraphEntity[] = [];
const relationships: KnowledgeGraphRelationship[] = [];
const response = await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/graph/entire_graph/stream`), {
method: 'GET',
headers: await authenticationHeaders(),
credentials: 'include',
});
if (!response.ok) {
throw new Error(`${response.status} ${response.statusText}`);
}
if (!response.body) {
throw new Error('Empty response body');
}
const reader = response.body.pipeThrough(bufferedReadableStreamTransformer()).getReader();
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
if (value.trim() && value.startsWith('data: ')) {
const dataStr = value.substring(6).trim();
if (dataStr) {
try {
const data = JSON.parse(dataStr);
if (data.type === 'entities') {
entities.push(...data.data);
// console.log(`Received ${data.data.length} entities, total: ${entities.length}`);
} else if (data.type === 'relationships') {
relationships.push(...data.data);
// console.log(`Received ${data.data.length} relationships, total: ${relationships.length}`);
} else if (data.type === 'complete') {
// console.log(`Streaming complete. Final counts - entities: ${entities.length}, relationships: ${relationships.length}`);
return { entities, relationships };
}
} catch (error) {
console.warn('Failed to parse streaming data:', error, 'Data:', dataStr);
}
}
}
}
} finally {
reader.releaseLock();
}
return { entities, relationships };
}
export async function getRelationship (kbId: number, id: number) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/graph/relationships/${id}`), {
headers: {
...await authenticationHeaders(),
},
})
.then(handleResponse(relationshipSchema));
}
export async function updateRelationship (kbId: number, id: number, params: UpdateRelationshipParams) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kbId}/graph/relationships/${id}`), {
method: 'put',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
}).then(handleResponse(relationshipSchema));
}
================================================
FILE: frontend/app/src/api/knowledge-base.ts
================================================
import { type BaseCreateDatasourceParams, type CreateDatasourceSpecParams, type Datasource, type DatasourceKgIndexError, datasourceSchema, type DatasourceVectorIndexError } from '@/api/datasources';
import { documentSchema } from '@/api/documents';
import { type EmbeddingModelSummary, embeddingModelSummarySchema } from '@/api/embedding-models';
import { type LLMSummary, llmSummarySchema } from '@/api/llms';
import { type IndexProgress, indexSchema, indexStatusSchema, type IndexTotalStats, totalSchema } from '@/api/rag';
import { authenticationHeaders, handleErrors, handleResponse, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { z, type ZodType } from 'zod';
export type KnowledgeBaseIndexMethod = 'vector' | 'knowledge_graph';
export interface CreateKnowledgeBaseParams {
name: string;
description?: string | null;
index_methods: KnowledgeBaseIndexMethod[];
llm_id?: number | null;
embedding_model_id?: number | null;
data_sources: (BaseCreateDatasourceParams & CreateDatasourceSpecParams)[];
}
export interface UpdateKnowledgeBaseParams {
name?: string;
description?: string | null;
}
export interface KnowledgeBaseSummary {
id: number;
name: string;
description: string | null;
index_methods: KnowledgeBaseIndexMethod[];
documents_total?: number;
data_sources_total?: number;
created_at: Date;
updated_at: Date;
creator: {
id: string;
};
}
export interface KnowledgeBase extends KnowledgeBaseSummary {
data_sources: Datasource[];
llm?: LLMSummary | null;
embedding_model?: EmbeddingModelSummary | null;
chunking_config: KnowledgeBaseChunkingConfig | null;
}
export type KnowledgeGraphIndexProgress = {
vector_index: IndexProgress
documents: IndexTotalStats
chunks: IndexTotalStats
kg_index?: IndexProgress
entities?: IndexTotalStats
relationships?: IndexTotalStats
}
export type KnowledgeBaseSplitterType = KnowledgeBaseChunkingSplitterRule['splitter'];
export type KnowledgeBaseChunkingSentenceSplitterConfig = {
chunk_size: number
chunk_overlap: number
paragraph_separator: string
}
export type KnowledgeBaseChunkingMarkdownSplitterConfig = {
chunk_size: number
chunk_header_level: number
}
export type KnowledgeBaseChunkingSentenceSplitterRule = {
splitter: 'SentenceSplitter'
splitter_config: KnowledgeBaseChunkingSentenceSplitterConfig
}
export type KnowledgeBaseChunkingMarkdownSplitterRule = {
splitter: 'MarkdownSplitter'
splitter_config: KnowledgeBaseChunkingMarkdownSplitterConfig
}
export type KnowledgeBaseChunkingSplitterRule = KnowledgeBaseChunkingSentenceSplitterRule | KnowledgeBaseChunkingMarkdownSplitterRule;
export type KnowledgeBaseChunkingConfigGeneral = {
mode: 'general'
} & KnowledgeBaseChunkingSentenceSplitterConfig;
export type KnowledgeBaseChunkingConfigAdvanced = {
mode: 'advanced'
rules: {
'text/plain': KnowledgeBaseChunkingSplitterRule;
'text/markdown': KnowledgeBaseChunkingSplitterRule
}
}
export type KnowledgeBaseChunkingConfig = KnowledgeBaseChunkingConfigGeneral | KnowledgeBaseChunkingConfigAdvanced;
export type KnowledgeGraphDocumentChunk = z.infer;
const knowledgeBaseChunkingSentenceSplitterConfigSchema = z.object({
chunk_size: z.number().int().min(1),
chunk_overlap: z.number().int().min(0),
paragraph_separator: z.string(),
}) satisfies z.ZodType;
const knowledgeBaseChunkingMarkdownSplitterConfigSchema = z.object({
chunk_size: z.number().int().min(1),
chunk_header_level: z.number().int().min(1).max(6),
}) satisfies z.ZodType;
const knowledgeBaseChunkingSplitterRuleSchema = z.discriminatedUnion('splitter', [
z.object({
splitter: z.literal('MarkdownSplitter'),
splitter_config: knowledgeBaseChunkingMarkdownSplitterConfigSchema,
}),
z.object({
splitter: z.literal('SentenceSplitter'),
splitter_config: knowledgeBaseChunkingSentenceSplitterConfigSchema,
}),
]) satisfies z.ZodType;
export const knowledgeBaseChunkingConfigSchema = z.discriminatedUnion('mode', [
z.object({
mode: z.literal('general'),
chunk_size: z.number().int().min(1),
chunk_overlap: z.number().int().min(0),
paragraph_separator: z.string(),
}),
z.object({
mode: z.literal('advanced'),
rules: z.object({
'text/plain': knowledgeBaseChunkingSplitterRuleSchema,
'text/markdown': knowledgeBaseChunkingSplitterRuleSchema,
}),
}),
]) satisfies z.ZodType;
const knowledgeBaseSummarySchema = z.object({
id: z.number(),
name: z.string(),
description: z.string().nullable(),
index_methods: z.enum(['vector', 'knowledge_graph']).array(),
documents_total: z.number().optional(),
data_sources_total: z.number().optional(),
created_at: zodJsonDate(),
updated_at: zodJsonDate(),
creator: z.object({
id: z.string(),
}),
}) satisfies ZodType;
const knowledgeBaseSchema = knowledgeBaseSummarySchema.extend({
data_sources: datasourceSchema.array(),
llm: llmSummarySchema.nullable().optional(),
embedding_model: embeddingModelSummarySchema.nullable().optional(),
chunking_config: knowledgeBaseChunkingConfigSchema.nullable(),
}) satisfies ZodType;
const knowledgeGraphIndexProgressSchema = z.object({
vector_index: indexSchema,
documents: totalSchema,
chunks: totalSchema,
kg_index: indexSchema.optional(),
entities: totalSchema.optional(),
relationships: totalSchema.optional(),
}) satisfies ZodType;
const knowledgeGraphDocumentChunkSchema = z.object({
id: z.string(),
document_id: z.number(),
hash: z.string(),
text: z.string(),
meta: z.object({}).passthrough(),
embedding: z.number().array(),
relations: z.any(),
source_uri: z.string(),
index_status: indexStatusSchema,
index_result: z.string().nullable(),
created_at: zodJsonDate(),
updated_at: zodJsonDate(),
});
const vectorIndexErrorSchema = z.object({
document_id: z.number(),
document_name: z.string(),
source_uri: z.string(),
error: z.string().nullable(),
}) satisfies ZodType;
const kgIndexErrorSchema = z.object({
document_id: z.number(),
document_name: z.string(),
chunk_id: z.string(),
source_uri: z.string(),
error: z.string().nullable(),
}) satisfies ZodType;
const knowledgeBaseLinkedChatEngine = z.object({
id: z.number(),
name: z.string(),
is_default: z.boolean(),
});
export async function listKnowledgeBases ({ page = 1, size = 10 }: PageParams) {
return await fetch(requestUrl('/api/v1/admin/knowledge_bases', { page, size }), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(knowledgeBaseSummarySchema)));
}
export async function getKnowledgeBaseById (id: number): Promise {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(knowledgeBaseSchema));
}
export async function getKnowledgeBaseDocumentChunks (id: number, documentId: number) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}/documents/${documentId}/chunks`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(knowledgeGraphDocumentChunkSchema.array()));
}
export async function getKnowledgeBaseDocument (id: number, documentId: number) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}/documents/${documentId}`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(documentSchema.omit({ knowledge_base: true, data_source: true })));
}
export async function getKnowledgeBaseLinkedChatEngines (id: number) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}/linked_chat_engines`), {
headers: await authenticationHeaders(),
})
.then(handleResponse(knowledgeBaseLinkedChatEngine.array()));
}
export async function deleteKnowledgeBaseDocument (id: number, documentId: number) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}/documents/${documentId}`), {
method: 'DELETE',
headers: await authenticationHeaders(),
})
.then(handleErrors);
}
export async function rebuildKBDocumentIndex (kb_id: number, doc_id: number) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${kb_id}/documents/${doc_id}/reindex`), {
method: 'POST',
headers: await authenticationHeaders(),
})
.then(handleErrors);
}
export async function createKnowledgeBase (params: CreateKnowledgeBaseParams) {
return await fetch(requestUrl('/api/v1/admin/knowledge_bases'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
body: JSON.stringify(params),
}).then(handleResponse(knowledgeBaseSchema));
}
export async function updateKnowledgeBase (id: number, params: UpdateKnowledgeBaseParams) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}`), {
method: 'PUT',
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
body: JSON.stringify(params),
}).then(handleResponse(knowledgeBaseSchema));
}
export async function getKnowledgeGraphIndexProgress (id: number): Promise {
return fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}/overview`), {
headers: await authenticationHeaders(),
}).then(handleResponse(knowledgeGraphIndexProgressSchema));
}
export async function listKnowledgeBaseVectorIndexErrors (id: number, { page = 1, size = 10 }: PageParams = {}) {
return fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}/vector-index-errors`, { page, size }), {
headers: await authenticationHeaders(),
}).then(handleResponse(zodPage(vectorIndexErrorSchema)));
}
export async function listKnowledgeBaseKgIndexErrors (id: number, { page = 1, size = 10 }: PageParams = {}) {
return fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}/kg-index-errors`, { page, size }), {
headers: await authenticationHeaders(),
}).then(handleResponse(zodPage(kgIndexErrorSchema)));
}
export async function retryKnowledgeBaseAllFailedTasks (id: number) {
return fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}/retry-failed-index-tasks`), {
method: 'POST',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
}).then(handleErrors);
}
export async function deleteKnowledgeBase (id: number) {
return await fetch(requestUrl(`/api/v1/admin/knowledge_bases/${id}`), {
method: 'DELETE',
headers: await authenticationHeaders(),
})
.then(handleErrors);
}
================================================
FILE: frontend/app/src/api/llms.ts
================================================
import { type ProviderOption, providerOptionSchema } from '@/api/providers';
import { authenticationHeaders, handleErrors, handleResponse, type Page, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { z, type ZodType, type ZodTypeDef } from 'zod';
export interface LLMSummary {
id: number;
name: string;
provider: string;
model: string;
is_default: boolean;
}
export interface LLM extends LLMSummary {
config?: any;
created_at: Date | null;
updated_at: Date | null;
}
export interface LlmOption extends ProviderOption {
default_llm_model: string;
llm_model_description: string;
}
export interface CreateLLM {
name: string;
provider: string;
model: string;
config?: any;
is_default?: boolean;
credentials: string | object;
}
export interface UpdateLLM {
name?: string;
config?: any;
credentials?: string | object;
}
export const llmSummarySchema = z.object({
id: z.number(),
name: z.string(),
provider: z.string(),
model: z.string(),
is_default: z.boolean(),
}) satisfies ZodType;
const llmSchema = llmSummarySchema.extend({
config: z.any(),
created_at: zodJsonDate().nullable(),
updated_at: zodJsonDate().nullable(),
}) satisfies ZodType;
const llmOptionSchema = providerOptionSchema.and(z.object({
default_llm_model: z.string(),
llm_model_description: z.string(),
})) satisfies ZodType;
export async function listLlmOptions () {
return await fetch(requestUrl(`/api/v1/admin/llms/providers/options`), {
headers: {
...await authenticationHeaders(),
},
})
.then(handleResponse(llmOptionSchema.array()));
}
export async function listLlms ({ page = 1, size = 10 }: PageParams = {}): Promise> {
return await fetch(requestUrl('/api/v1/admin/llms', { page, size }), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(llmSchema)));
}
export async function getLlm (id: number): Promise {
return await fetch(requestUrl(`/api/v1/admin/llms/${id}`), {
headers: await authenticationHeaders(),
}).then(handleResponse(llmSchema));
}
export async function createLlm (create: CreateLLM) {
return await fetch(requestUrl(`/api/v1/admin/llms`), {
method: 'POST',
body: JSON.stringify(create),
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
}).then(handleResponse(llmSchema));
}
export async function updateLlm (id: number, update: UpdateLLM) {
return await fetch(requestUrl(`/api/v1/admin/llms/${id}`), {
method: 'PUT',
body: JSON.stringify(update),
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
}).then(handleResponse(llmSchema));
}
export async function deleteLlm (id: number) {
await fetch(requestUrl(`/api/v1/admin/llms/${id}`), {
method: 'DELETE',
headers: await authenticationHeaders(),
}).then(handleErrors);
}
export async function testLlm (createLLM: CreateLLM) {
return await fetch(requestUrl(`/api/v1/admin/llms/test`), {
method: 'POST',
body: JSON.stringify(createLLM),
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
})
.then(handleResponse(z.object({
success: z.boolean(),
error: z.string().optional(),
})));
}
================================================
FILE: frontend/app/src/api/providers.ts
================================================
import { z, type ZodType } from 'zod';
export interface ProviderOption {
provider: string;
provider_display_name: string | null;
provider_description: string | null;
provider_url: string | null;
credentials_display_name: string;
credentials_description: string;
credentials_type: 'str' | 'dict';
default_credentials: any;
default_config: object;
config_description: string;
}
export const providerOptionSchema = z.object({
provider: z.string(),
provider_display_name: z.string().nullable(),
provider_description: z.string().nullable(),
provider_url: z.string().nullable(),
credentials_display_name: z.string(),
credentials_description: z.string(),
default_config: z.object({}).passthrough(),
config_description: z.string(),
}).and(z.discriminatedUnion('credentials_type', [
z.object({
credentials_type: z.literal('str'),
default_credentials: z.string(),
}),
z.object({
credentials_type: z.literal('dict'),
default_credentials: z.object({}).passthrough(),
}),
])) satisfies ZodType;
================================================
FILE: frontend/app/src/api/rag.ts
================================================
import { z, type ZodType } from 'zod';
export const indexStatuses = [
'not_started',
'pending',
'running',
'completed',
'failed',
] as const;
export type IndexStatus = typeof indexStatuses[number];
export type IndexProgress = Partial>
export type IndexTotalStats = {
total: number
}
export const indexStatusSchema = z.enum(indexStatuses) satisfies ZodType;
export const totalSchema = z.object({
total: z.number(),
}) satisfies ZodType;
export const indexSchema = z.object({
not_started: z.number().optional(),
pending: z.number().optional(),
running: z.number().optional(),
completed: z.number().optional(),
failed: z.number().optional(),
}) satisfies ZodType;
================================================
FILE: frontend/app/src/api/rerankers.ts
================================================
import { type ProviderOption, providerOptionSchema } from '@/api/providers';
import { authenticationHeaders, handleErrors, handleResponse, type Page, type PageParams, requestUrl, zodPage } from '@/lib/request';
import { zodJsonDate } from '@/lib/zod';
import { z, type ZodType, type ZodTypeDef } from 'zod';
export interface Reranker {
id: number;
name: string;
provider: string;
model: string;
top_n: number;
config?: any;
is_default: boolean;
created_at: Date | null;
updated_at: Date | null;
}
export interface RerankerOption extends ProviderOption {
default_reranker_model: string;
reranker_model_description: string;
default_top_n: number;
}
export interface CreateReranker {
name: string;
provider: string;
model: string;
config?: any;
top_n: number;
is_default?: boolean;
credentials: string | object;
}
export interface UpdateReranker {
name?: string;
config?: any;
top_n?: number;
credentials?: string | object;
}
const rerankerSchema = z.object({
id: z.number(),
name: z.string(),
provider: z.string(),
model: z.string(),
top_n: z.number(),
config: z.any(),
is_default: z.boolean(),
created_at: zodJsonDate().nullable(),
updated_at: zodJsonDate().nullable(),
}) satisfies ZodType;
const rerankerOptionSchema = providerOptionSchema.and(z.object({
default_top_n: z.number(),
default_reranker_model: z.string(),
reranker_model_description: z.string(),
})) satisfies ZodType;
export async function listRerankerOptions () {
return await fetch(requestUrl(`/api/v1/admin/reranker-models/providers/options`), {
headers: {
...await authenticationHeaders(),
},
})
.then(handleResponse(rerankerOptionSchema.array()));
}
export async function listRerankers ({ page = 1, size = 10 }: PageParams = {}): Promise> {
return await fetch(requestUrl('/api/v1/admin/reranker-models', { page, size }), {
headers: await authenticationHeaders(),
})
.then(handleResponse(zodPage(rerankerSchema)));
}
export async function getReranker (id: number): Promise {
return await fetch(requestUrl(`/api/v1/admin/reranker-models/${id}`), {
headers: await authenticationHeaders(),
}).then(handleResponse(rerankerSchema));
}
export async function createReranker (create: CreateReranker) {
return await fetch(requestUrl(`/api/v1/admin/reranker-models`), {
method: 'POST',
body: JSON.stringify(create),
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
}).then(handleResponse(rerankerSchema));
}
export async function updateReranker (id: number, update: UpdateReranker) {
return await fetch(requestUrl(`/api/v1/admin/reranker-models/${id}`), {
method: 'PUT',
body: JSON.stringify(update),
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
}).then(handleResponse(rerankerSchema));
}
export async function deleteReranker (id: number) {
await fetch(requestUrl(`/api/v1/admin/reranker-models/${id}`), {
method: 'DELETE',
headers: await authenticationHeaders(),
}).then(handleErrors);
}
export async function testReranker (createReranker: CreateReranker) {
return await fetch(requestUrl(`/api/v1/admin/reranker-models/test`), {
method: 'POST',
body: JSON.stringify(createReranker),
headers: {
'Content-Type': 'application/json',
...await authenticationHeaders(),
},
})
.then(handleResponse(z.object({
success: z.boolean(),
error: z.string().optional(),
})));
}
================================================
FILE: frontend/app/src/api/site-settings.ts
================================================
import { authenticationHeaders, handleErrors, handleResponse, requestUrl } from '@/lib/request';
import { z } from 'zod';
interface SettingItemBase {
name: string;
description: string;
group: string;
data_type: K;
value: T | null;
client: boolean | null;
default: T;
}
export type IntSettingItem = SettingItemBase<'int', number>
export type FloatSettingItem = SettingItemBase<'float', number>
export type BoolSettingItem = SettingItemBase<'bool', boolean>
export type StringSettingItem = SettingItemBase<'str', string>
export type ListSettingItem = SettingItemBase<'list', any[]>
export type DictSettingItem = SettingItemBase<'dict', object>
export type SettingItem =
IntSettingItem
| FloatSettingItem
| BoolSettingItem
| StringSettingItem
| ListSettingItem
| DictSettingItem;
export interface PublicWebsiteSettings {
'title': string;
'description': string;
'homepage_title': string;
'homepage_example_questions': string[];
'homepage_footer_links': { text: string, href: string }[];
'logo_in_dark_mode': string;
'logo_in_light_mode': string;
'social_github': string | null;
'social_twitter': string | null;
'social_discord': string | null;
'custom_js_example_questions': string[];
'custom_js_button_label': string;
'custom_js_button_img_src': string;
'custom_js_logo_src': string;
'ga_id': string | null;
'max_upload_file_size': number | null;
'enable_post_verifications': boolean;
'enable_post_verifications_for_widgets': boolean;
}
const settingsItemSchema = z.object({
name: z.string(),
description: z.string(),
client: z.boolean().nullable(),
group: z.string(),
// data_type: z.enum(['list', 'dict', 'str', 'int', 'float', 'bool']),
}).and(z.discriminatedUnion('data_type', [
z.object({
data_type: z.literal('str'),
value: z.string().nullable(),
default: z.string(),
}),
z.object({
data_type: z.literal('int'),
value: z.number().int().nullable(),
default: z.number().int(),
}),
z.object({
data_type: z.literal('float'),
value: z.number().nullable(),
default: z.number(),
}),
z.object({
data_type: z.literal('bool'),
value: z.boolean().nullable(),
default: z.coerce.boolean(),
}),
z.object({
data_type: z.literal('list'),
value: z.any().array().nullable(),
default: z.any().array(),
}),
z.object({
data_type: z.literal('dict'),
value: z.object({}).passthrough().nullable(),
default: z.object({}).passthrough(),
}),
]));
export type AllSettings = Record
export async function getAllSiteSettings (): Promise {
return await fetch(requestUrl(`/api/v1/admin/site-settings`),
{
headers: await authenticationHeaders(),
})
.then(handleResponse(z.record(settingsItemSchema)));
}
export async function updateSiteSetting (name: string, value: any) {
await fetch(requestUrl(`/api/v1/admin/site-settings/${name}`), {
method: 'PUT',
headers: {
...await authenticationHeaders(),
'Content-Type': 'application/json',
},
body: JSON.stringify({ value }),
}).then(handleErrors);
}
export async function getPublicSiteSettings (): Promise {
return fetch(requestUrl(`/api/v1/site-config`), {
headers: await authenticationHeaders(),
credentials: 'include',
}).then(handleErrors).then(res => res.json());
}
================================================
FILE: frontend/app/src/api/stats.ts
================================================
import { authenticationHeaders, handleResponse, requestUrl } from '@/lib/request';
import { zodDateOnlyString } from '@/lib/zod';
import { format } from 'date-fns';
import { z, type ZodType } from 'zod';
export interface TrendBaseItem {
date: Date;
}
export interface ChatUserTrendItem extends TrendBaseItem {
user: number;
anonymous: number;
}
export type ChatOriginTrendItem = TrendBaseItem & Omit<{
[key: string]: number
}, 'date'>
export interface TrendResponse {
start_date: Date;
end_date: Date;
values: T[];
}
const trendBaseSchema = z.object({
date: zodDateOnlyString(),
});
const chatUserTrendItemSchema = trendBaseSchema.extend({
user: z.number(),
anonymous: z.number(),
}) satisfies ZodType;
const chatOriginTrendItemSchema = trendBaseSchema.and(
z.preprocess((input) => {
if (input && typeof input === 'object') {
const { date: _, ...ret } = input as any;
return ret;
} else {
return input;
}
}, z.record(z.string(), z.number())),
) satisfies ZodType;
function trendResponse (item: ZodType): ZodType, any, any> {
return z.object({
start_date: zodDateOnlyString(),
end_date: zodDateOnlyString(),
values: item.array(),
});
}
function trendParams (start: Date, end: Date) {
return {
start_date: format(start, 'yyyy-MM-dd'),
end_date: format(end, 'yyyy-MM-dd'),
};
}
export async function getChatUserTrend (startDate: Date, endDate: Date) {
return await fetch(requestUrl('/api/v1/admin/stats/trend/chat-user', trendParams(startDate, endDate)), {
headers: {
...await authenticationHeaders(),
},
})
.then(handleResponse(trendResponse(chatUserTrendItemSchema)));
}
export async function getChatOriginTrend (startDate: Date, endDate: Date) {
return await fetch(requestUrl('/api/v1/admin/stats/trend/chat-origin', trendParams(startDate, endDate)), {
headers: {
...await authenticationHeaders(),
},
})
.then(handleResponse(trendResponse(chatOriginTrendItemSchema)));
}
================================================
FILE: frontend/app/src/api/system.ts
================================================
import { authenticationHeaders, handleResponse, requestUrl } from '@/lib/request';
import { z } from 'zod';
export interface RequiredBootstrapStatus {
default_llm: boolean;
default_embedding_model: boolean;
default_chat_engine: boolean;
knowledge_base: boolean;
}
export interface OptionalBootstrapStatus {
langfuse: boolean;
default_reranker: boolean;
}
export interface NeedMigrationStatus {
chat_engines_without_kb_configured?: number[];
}
export interface BootstrapStatus {
required: RequiredBootstrapStatus;
optional: OptionalBootstrapStatus;
need_migration: NeedMigrationStatus;
}
const requiredBootstrapStatusSchema = z.object({
default_llm: z.boolean(),
default_embedding_model: z.boolean(),
default_chat_engine: z.boolean(),
knowledge_base: z.boolean(),
});
const optionalBootstrapStatusSchema = z.object({
langfuse: z.boolean(),
default_reranker: z.boolean(),
});
const needMigrationStatusSchema = z.object({
chat_engines_without_kb_configured: z.number().array().optional(),
});
const bootstrapStatusSchema = z.object({
required: requiredBootstrapStatusSchema,
optional: optionalBootstrapStatusSchema,
need_migration: needMigrationStatusSchema,
});
export async function getBootstrapStatus (): Promise {
return await fetch(requestUrl(`/api/v1/system/bootstrap-status`), {
headers: {
...await authenticationHeaders(),
},
credentials: 'include',
}).then(handleResponse(bootstrapStatusSchema));
}
export function isBootstrapStatusPassed (bootstrapStatus: BootstrapStatus): boolean {
return Object.values(bootstrapStatus.required).reduce((res, flag) => res && flag, true);
}
================================================
FILE: frontend/app/src/api/users.ts
================================================
import { authenticationHeaders, handleResponse, requestUrl } from '@/lib/request';
import { z } from 'zod';
export interface MeInfo {
id: string;
email: string;
is_active: boolean;
is_superuser: boolean;
is_verified: boolean;
}
const userSchema = z.object({
id: z.string(),
email: z.string(),
is_active: z.boolean(),
is_superuser: z.boolean(),
is_verified: z.boolean(),
});
export async function getMe (): Promise {
return await fetch(requestUrl('/api/v1/users/me'), {
headers: await authenticationHeaders(),
})
.then(handleResponse(userSchema));
}
================================================
FILE: frontend/app/src/app/(experimental)/experimental-features/route.ts
================================================
import { experimentalFeatures } from '@/experimental/experimental-features';
import { NextResponse } from 'next/server';
export function GET () {
return NextResponse.json(experimentalFeatures(), {
headers: {
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Methods': 'GET',
},
});
}
export const dynamic = 'force-dynamic';
================================================
FILE: frontend/app/src/app/(main)/(.)auth/login/loading.tsx
================================================
'use client';
import { Loader } from '@/components/loader';
export default function Loading () {
return (
Loading
);
}
================================================
FILE: frontend/app/src/app/(main)/(.)auth/login/page.client.tsx
================================================
'use client';
import { Signin } from '@/components/signin';
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from '@/components/ui/dialog';
import { useRouter } from 'next/navigation';
export function SigninDialog ({ callbackUrl }: { callbackUrl?: string }) {
const router = useRouter();
return (
{
if (!open) {
router.back();
}
}}
>
Sign In
Sign in to continue to your account.
);
}
================================================
FILE: frontend/app/src/app/(main)/(.)auth/login/page.tsx
================================================
import { headers } from 'next/headers';
import { SigninDialog } from './page.client';
export default async function Page () {
const referer = (await headers()).get('Referer') ?? undefined;
return (
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/chat-engines/[id]/page.tsx
================================================
import { getChatEngine, getDefaultChatEngineOptions } from '@/api/chat-engines';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { UpdateChatEngineForm } from '@/components/chat-engine/update-chat-engine-form';
import { getBootstrapStatus } from '@/api/system';
export default async function ChatEnginePage(props: { params: Promise<{ id: string }> }) {
const params = await props.params;
const [chatEngine, defaultChatEngineOptions, bootstrapStatus] = await Promise.all([
getChatEngine(parseInt(params.id)),
getDefaultChatEngineOptions(),
getBootstrapStatus(),
]);
return (
<>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/chat-engines/new/page.tsx
================================================
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CreateChatEngineForm } from '@/components/chat-engine/create-chat-engine-form';
import { getDefaultChatEngineOptions } from '@/api/chat-engines';
export default async function NewChatEnginePage () {
const defaultOptions = await getDefaultChatEngineOptions();
return (
<>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/chat-engines/page.tsx
================================================
import { AdminPageHeading } from '@/components/admin-page-heading';
import { ChatEnginesTable } from '@/components/chat-engine/chat-engines-table';
import { NextLink } from '@/components/nextjs/NextLink';
export default function ChatEnginesPage () {
return (
<>
New Chat Engine
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/embedding-models/[id]/page.tsx
================================================
import { getEmbeddingModel } from '@/api/embedding-models';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { UpdateEmbeddingModelForm } from '@/components/embedding-models/UpdateEmbeddingModelForm';
export default async function Page (props: { params: Promise<{ id: string }> }) {
const params = await props.params;
const embeddingModel = await getEmbeddingModel(parseInt(params.id));
return (
<>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/embedding-models/create/page.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CreateEmbeddingModelForm } from '@/components/embedding-models/CreateEmbeddingModelForm';
import { useRouter } from 'next/navigation';
import { useTransition } from 'react';
export default function Page () {
const router = useRouter();
const [transitioning, startTransition] = useTransition();
return (
<>
{
startTransition(() => {
router.push(`/embedding-models/${embeddingModel.id}`);
router.refresh();
});
}}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/embedding-models/page.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { EmbeddingModelsTable } from '@/components/embedding-models/EmbeddingModelsTable';
import { NextLink } from '@/components/nextjs/NextLink';
import { PlusIcon } from 'lucide-react';
export default function EmbeddingModelPage () {
return (
<>
New Embedding Model
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/datasets/[id]/items/[itemId]/page.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { Loader2Icon } from 'lucide-react';
import { UpdateEvaluationDatasetItemForm } from '@/components/evaluations/update-evaluation-dataset-item-form';
import { use } from 'react';
import { useEvaluationDataset } from '@/components/evaluations/hooks';
export default function Page (props: { params: Promise<{ id: string, itemId: string }> }) {
const params = use(props.params);
const evaluationDatasetId = parseInt(decodeURIComponent(params.id));
const evaluationDatasetItemId = parseInt(decodeURIComponent(params.itemId));
const { evaluationDataset } = useEvaluationDataset(evaluationDatasetId);
return (
<>
, url: `/evaluation/datasets/${evaluationDatasetId}` },
{ title: `${evaluationDatasetItemId}` },
]}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/datasets/[id]/items/new/page.tsx
================================================
'use client';
import { mutateEvaluationDataset, useEvaluationDataset } from '@/components/evaluations/hooks';
import { use, useTransition } from 'react';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CreateEvaluationDatasetItemForm } from '@/components/evaluations/create-evaluation-dataset-item-form';
import { Loader2Icon } from 'lucide-react';
import { useRouter } from 'next/navigation';
export default function CreateEvaluationDatasetItemPage (props: { params: Promise<{ id: string }> }) {
const params = use(props.params);
const evaluationDatasetId = parseInt(decodeURIComponent(params.id));
const { evaluationDataset } = useEvaluationDataset(evaluationDatasetId);
const router = useRouter();
const [transitioning, startTransition] = useTransition();
return (
<>
, url: `/evaluation/datasets/${evaluationDatasetId}` },
{ title: 'New Item' },
]}
/>
{
startTransition(() => {
router.push(`/evaluation/datasets/${evaluationDatasetId}`);
router.refresh();
void mutateEvaluationDataset(evaluationDatasetId);
});
}}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/datasets/[id]/not-found.tsx
================================================
import { AdminPageHeading } from '@/components/admin-page-heading';
import { ResourceNotFound } from '@/components/resource-not-found';
export default function NotFound () {
return (
<>
Not Found },
]}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/datasets/[id]/page.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { EvaluationDatasetInfo } from '@/components/evaluations/evaluation-dataset-info';
import { EvaluationDatasetItemsTable } from '@/components/evaluations/evaluation-dataset-items-table';
import { Loader2Icon } from 'lucide-react';
import { NextLink } from '@/components/nextjs/NextLink';
import { Separator } from '@/components/ui/separator';
import { use } from 'react';
import { useEvaluationDataset } from '@/components/evaluations/hooks';
export default function EvaluationDatasetPage (props: { params: Promise<{ id: string }> }) {
const params = use(props.params);
const evaluationDatasetId = parseInt(decodeURIComponent(params.id));
const { evaluationDataset, isLoading } = useEvaluationDataset(evaluationDatasetId);
return (
<>
},
]}
/>
New Item
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/datasets/create/page.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CreateEvaluationDatasetForm } from '@/components/evaluations/create-evaluation-dataset-form';
import { mutateEvaluationDatasets } from '@/components/evaluations/hooks';
import { useRouter } from 'next/navigation';
import { useTransition } from 'react';
export default function EvaluationTaskPage () {
const [transitioning, startTransition] = useTransition();
const router = useRouter();
return (
<>
{
void mutateEvaluationDatasets();
startTransition(() => {
router.push(`/evaluation/datasets/${evaluationDataset.id}`);
router.refresh();
});
}}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/datasets/page.tsx
================================================
import { AdminPageHeading } from '@/components/admin-page-heading';
import { EvaluationDatasetsTable } from '@/components/evaluations/evaluation-datasets-table';
import { NextLink } from '@/components/nextjs/NextLink';
export default function EvaluationDatasetsPage () {
return (
<>
New Evaluation Dataset
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/page.tsx
================================================
import { redirect } from 'next/navigation';
export default function EvaluationsPage () {
redirect('/evaluation/tasks');
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/tasks/[id]/not-found.tsx
================================================
import { AdminPageHeading } from '@/components/admin-page-heading';
import { ResourceNotFound } from '@/components/resource-not-found';
export default function NotFound () {
return (
<>
Not Found },
]}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/tasks/[id]/page.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { EvaluationTaskInfo } from '@/components/evaluations/evaluation-task-info';
import { EvaluationTaskItemsTable } from '@/components/evaluations/evaluation-task-items-table';
import { Loader2Icon } from 'lucide-react';
import { use } from 'react';
import { useEvaluationTask } from '@/components/evaluations/hooks';
export default function EvaluationTaskPage (props: { params: Promise<{ id: string }> }) {
const params = use(props.params);
const evaluationTaskId = parseInt(decodeURIComponent(params.id));
const { evaluationTask } = useEvaluationTask(evaluationTaskId);
return (
<>
},
]}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/tasks/create/page.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CreateEvaluationTaskForm } from '@/components/evaluations/create-evaluation-task-form';
import { mutateEvaluationTasks } from '@/components/evaluations/hooks';
import { useRouter } from 'next/navigation';
import { useTransition } from 'react';
export default function EvaluationTaskPage () {
const [transitioning, startTransition] = useTransition();
const router = useRouter();
return (
<>
{
void mutateEvaluationTasks();
startTransition(() => {
router.push(`/evaluation/tasks/${evaluationTask.id}`);
router.refresh();
});
}}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/evaluation/tasks/page.tsx
================================================
import { AdminPageHeading } from '@/components/admin-page-heading';
import { EvaluationTasksTable } from '@/components/evaluations/evaluation-tasks-table';
import { NextLink } from '@/components/nextjs/NextLink';
export default function EvaluationTasksPage () {
return (
<>
New Evaluation Task
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/feedbacks/page.tsx
================================================
import { AdminPageHeading } from '@/components/admin-page-heading';
import { FeedbacksTable } from '@/components/feedbacks/feedbacks-table';
export default function ChatEnginesPage () {
return (
<>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/(special)/data-sources/new/page.tsx
================================================
'use client';
;
import { mutateKnowledgeBases, useKnowledgeBase } from '@/components/knowledge-base/hooks';
import { use, useTransition } from 'react';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CreateDatasourceForm } from '@/components/datasource/create-datasource-form';
import { Loader2Icon } from 'lucide-react';
import { useRouter } from 'next/navigation';
export default function NewKnowledgeBaseDataSourcePage (props: { params: Promise<{ id: string }> }) {
const params = use(props.params);
const id = parseInt(decodeURIComponent(params.id));
const { knowledgeBase } = useKnowledgeBase(id);
const [transitioning, startTransition] = useTransition();
const router = useRouter();
return (
<>
, url: `/knowledge-bases/${id}` },
{ title: 'DataSources', url: `/knowledge-bases/${id}/data-sources` },
{ title: 'New' },
]}
/>
{
startTransition(() => {
router.back();
});
mutateKnowledgeBases();
}}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/(special)/documents/[documentId]/chunks/page.tsx
================================================
'use client';;
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { getKnowledgeBaseDocument, getKnowledgeBaseDocumentChunks } from '@/api/knowledge-base';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CodeInput } from '@/components/form/widgets/CodeInput';
import { DateFormat } from '@/components/date-format';
import { Loader2Icon } from 'lucide-react';
import { use } from "react";
import { useKnowledgeBase } from '@/components/knowledge-base/hooks';
import useSWR from 'swr';
export default function DocumentChunksPage(props: { params: Promise<{ id: string, documentId: string }> }) {
const params = use(props.params);
const kbId = parseInt(decodeURIComponent(params.id));
const documentId = parseInt(decodeURIComponent(params.documentId));
const { knowledgeBase } = useKnowledgeBase(kbId);
const { data: document } = useSWR(`api.knowledge-bases.${kbId}.documents.${documentId}`, () => getKnowledgeBaseDocument(kbId, documentId));
const { data = [], isLoading } = useSWR(`api.knowledge-bases.${kbId}.documents.${documentId}.chunks`, () => getKnowledgeBaseDocumentChunks(kbId, documentId), {
revalidateOnFocus: false,
});
return (
<>
, url: `/knowledge-bases/${kbId}` },
{ title: document?.name ?? },
{ title: 'Chunks' },
]}
/>
{data.map(chunk => (
{chunk.id}
Hash: {chunk.hash}
Source: {chunk.source_uri}
Updated At:
))}
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/(tabs)/data-sources/page.tsx
================================================
'use client';;
import { use } from "react";
import { DatasourceCard } from '@/components/datasource/datasource-card';
import { DatasourceCreateOption } from '@/components/datasource/datasource-create-option';
import { NoDatasourcePlaceholder } from '@/components/datasource/no-datasource-placeholder';
import { useAllKnowledgeBaseDataSources } from '@/components/knowledge-base/hooks';
import { Skeleton } from '@/components/ui/skeleton';
import { FileDownIcon, GlobeIcon, PaperclipIcon } from 'lucide-react';
export default function KnowledgeBaseDataSourcesPage(props: { params: Promise<{ id: string }> }) {
const params = use(props.params);
const id = parseInt(decodeURIComponent(params.id));
const { data: dataSources, isLoading } = useAllKnowledgeBaseDataSources(id);
return (
Create Data Source
}
title="Files"
>
Upload files
}
title="Web Pages"
>
Select pages.
}
title="Website by sitemap"
>
Select web sitemap.
Browse existing Data Sources
{isLoading && }
{dataSources?.map(datasource => (
))}
{dataSources?.length === 0 && (
)}
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/(tabs)/index-progress/page.tsx
================================================
import { KnowledgeBaseIndexProgress } from '@/components/knowledge-base/knowledge-base-index';
export default async function KnowledgeBaseIndexProgressPage(props: { params: Promise<{ id: string }> }) {
const params = await props.params;
const id = parseInt(decodeURIComponent(params.id));
return (
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/(tabs)/knowledge-graph-explorer/create-synopsis-entity/page.tsx
================================================
'use client';;
import { use } from "react";
import type { KnowledgeGraphEntity } from '@/api/graph';
import { GraphCreateEntity } from '@/components/graph/GraphCreateEntity';
import { NextLink } from '@/components/nextjs/NextLink';
import { useRouter } from 'next/navigation';
export default function CreateSynopsisEntityPage(props: { params: Promise<{ id: string }> }) {
const params = use(props.params);
const kbId = parseInt(decodeURIComponent(params.id));
const router = useRouter();
const handleCreateEntity = (entity: KnowledgeGraphEntity) => {
router.push(`/knowledge-bases/${kbId}/knowledge-graph-explorer?query=entity:${entity.id}`);
};
return (
<>
Back
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/(tabs)/knowledge-graph-explorer/page.tsx
================================================
import { GraphEditor } from '@/components/graph/GraphEditor';
export default async function KnowledgeGraphExplorerPage(props: { params: Promise<{ id: string }> }) {
const params = await props.params;
const id = parseInt(decodeURIComponent(params.id));
return (
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/(tabs)/layout.tsx
================================================
'use client';;
import { SecondaryNavigatorLayout, SecondaryNavigatorList, SecondaryNavigatorMain } from '@/components/secondary-navigator-list';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { ArrowRightIcon } from '@/components/icons';
import { KnowledgeBaseTabs } from '@/app/(main)/(admin)/knowledge-bases/[id]/(tabs)/tabs';
import Link from 'next/link';
import { Loader2Icon } from 'lucide-react';
import type { ReactNode } from 'react';
import { use } from "react";
import { useKnowledgeBase } from '@/components/knowledge-base/hooks';
export default function KnowledgeBaseLayout(props: { params: Promise<{ id: string }>, children: ReactNode }) {
const params = use(props.params);
const {
children
} = props;
const id = parseInt(decodeURIComponent(params.id));
const { knowledgeBase } = useKnowledgeBase(id);
return (
<>
This Knowledge Base has no datasource.
Create Data Source
>,
} : undefined,
title: knowledgeBase?.name ?? ,
},
]}
/>
{children}
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/(tabs)/page.tsx
================================================
import { DocumentsTable } from '@/components/documents/documents-table';
export default async function KnowledgeBasePage(props: { params: Promise<{ id: string }> }) {
const params = await props.params;
const id = parseInt(decodeURIComponent(params.id));
return (
<>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/(tabs)/settings/page.tsx
================================================
import { cachedGetKnowledgeBaseById } from '@/app/(main)/(admin)/knowledge-bases/[id]/api';
import { KnowledgeBaseSettingsForm } from '@/components/knowledge-base/knowledge-base-settings-form';
export default async function KnowledgeBaseSettingsPage(props: { params: Promise<{ id: string }> }) {
const params = await props.params;
const id = parseInt(decodeURIComponent(params.id));
const kb = await cachedGetKnowledgeBaseById(id);
return (
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/(tabs)/tabs.tsx
================================================
'use client';
import { useKnowledgeBase } from '@/components/knowledge-base/hooks';
import { SecondaryNavigatorLink } from '@/components/secondary-navigator-list';
export function KnowledgeBaseTabs ({ knowledgeBaseId }: { knowledgeBaseId: number }) {
const { knowledgeBase } = useKnowledgeBase(knowledgeBaseId);
return (
<>
Documents
{knowledgeBase?.documents_total}
Data Sources
{knowledgeBase?.data_sources_total}
Index Progress
{/* startTransition(() => {*/}
{/* router.push(`/knowledge-bases/${knowledgeBase.id}/retrieval-tester`);*/}
{/* })}*/}
{/*>*/}
{/* Retrieval Tester*/}
{/* */}
Graph Explorer
Settings
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/api.ts
================================================
import { getKnowledgeBaseById } from '@/api/knowledge-base';
import { cache } from 'react';
export const cachedGetKnowledgeBaseById = cache(getKnowledgeBaseById);
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/[id]/context.tsx
================================================
'use client';
import type { KnowledgeBase } from '@/api/knowledge-base';
import { createContext, type ReactNode, useContext } from 'react';
const KBContext = createContext(null as any);
export function KBProvider ({ children, value }: { children: ReactNode, value: KnowledgeBase }) {
return (
{children}
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/new/page.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CreateKnowledgeBaseForm } from '@/components/knowledge-base/create-knowledge-base-form';
export default function NewKnowledgeBasePage () {
return (
<>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/knowledge-bases/page.tsx
================================================
'use client';
import { KnowledgeBaseCard, KnowledgeBaseCardPlaceholder } from '@/components/knowledge-base/knowledge-base-card';
import { AdminPageHeading } from '@/components/admin-page-heading';
import KnowledgeBaseEmptyState from '@/components/knowledge-base/empty-state';
import { NextLink } from '@/components/nextjs/NextLink';
import { useAllKnowledgeBases } from '@/components/knowledge-base/hooks';
export default function KnowledgeBasesPage () {
const { data: knowledgeBases, isLoading } = useAllKnowledgeBases();
return (
<>
New Knowledge Base
{
isLoading
?
: !!knowledgeBases?.length
?
{knowledgeBases.map(kb => (
))}
:
}
>
);
}
export const dynamic = 'force-dynamic';
================================================
FILE: frontend/app/src/app/(main)/(admin)/layout.tsx
================================================
import { AdminPageLayout } from '@/components/admin-page-layout';
import { requireAuth } from '@/lib/auth';
import { type ReactNode } from 'react';
export default async function Layout ({ children }: { children: ReactNode }) {
await requireAuth();
return (
{children}
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/llms/[id]/page.tsx
================================================
import { getLlm } from '@/api/llms';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { UpdateLlmForm } from '@/components/llm/UpdateLLMForm';
export default async function Page (props: { params: Promise<{ id: string }> }) {
const params = await props.params;
const llm = await getLlm(parseInt(params.id));
return (
<>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/llms/create/page.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CreateLLMForm } from '@/components/llm/CreateLLMForm';
import { useRouter } from 'next/navigation';
import { useTransition } from 'react';
export default function Page () {
const router = useRouter();
const [transitioning, startTransition] = useTransition();
return (
<>
{
startTransition(() => {
router.push(`/llms/${llm.id}`);
router.refresh();
});
}}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/llms/page.tsx
================================================
import { AdminPageHeading } from '@/components/admin-page-heading';
import { LLMsTable } from '@/components/llm/LLMsTable';
import { NextLink } from '@/components/nextjs/NextLink';
import { PlusIcon } from 'lucide-react';
export default function Page () {
return (
<>
New LLM
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/reranker-models/[id]/page.tsx
================================================
import { getReranker } from '@/api/rerankers';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { UpdateRerankerForm } from '@/components/reranker/UpdateRerankerForm';
export default async function Page (props: { params: Promise<{ id: string }> }) {
const params = await props.params;
const reranker = await getReranker(parseInt(params.id));
return (
<>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/reranker-models/create/page.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CreateRerankerForm } from '@/components/reranker/CreateRerankerForm';
import { useRouter } from 'next/navigation';
import { useTransition } from 'react';
export default function Page () {
const router = useRouter();
const [transitioning, startTransition] = useTransition();
return (
<>
{
startTransition(() => {
router.push(`/reranker-models/${reranker.id}`);
router.refresh();
});
}}
/>
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/reranker-models/page.tsx
================================================
import { AdminPageHeading } from '@/components/admin-page-heading';
import { NextLink } from '@/components/nextjs/NextLink';
import { PlusIcon } from 'lucide-react';
import RerankerModelsTable from '@/components/reranker/RerankerModelsTable';
export default function Page () {
return (
<>
New Reranker Model
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/site-settings/custom_js/page.tsx
================================================
import { getAllSiteSettings } from '@/api/site-settings';
import { CustomJsSettings } from '@/components/settings/CustomJsSettings';
import { WidgetSnippet } from '@/components/settings/WidgetSnippet';
export default async function CustomJsSettingsPage () {
const settings = await getAllSiteSettings();
return (
<>
Copy this HTML fragment to your page.
>
);
}
export const dynamic = 'force-dynamic';
================================================
FILE: frontend/app/src/app/(main)/(admin)/site-settings/integrations/page.tsx
================================================
import { getAllSiteSettings } from '@/api/site-settings';
import { IntegrationsSettings } from '@/components/settings/IntegrationsSettings';
export default async function LangfuseSettingsPage () {
const settings = await getAllSiteSettings();
return (
<>
>
);
}
export const dynamic = 'force-dynamic';
================================================
FILE: frontend/app/src/app/(main)/(admin)/site-settings/layout.tsx
================================================
'use client';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { SecondaryNavigatorLayout, SecondaryNavigatorLink, SecondaryNavigatorList, SecondaryNavigatorMain } from '@/components/secondary-navigator-list';
import { type ReactNode } from 'react';
export default function SiteSettingsLayout ({ children }: { children: ReactNode }) {
return (
Website
Integrations
JS Widget
{children}
);
}
================================================
FILE: frontend/app/src/app/(main)/(admin)/site-settings/page.tsx
================================================
import { getAllSiteSettings } from '@/api/site-settings';
import { WebsiteSettings } from '@/components/settings/WebsiteSettings';
export default async function SiteSettingsPage () {
const settings = await getAllSiteSettings();
return (
<>
>
);
}
export const dynamic = 'force-dynamic';
================================================
FILE: frontend/app/src/app/(main)/(admin)/stats/trending/page.tsx
================================================
'use client';
import { getChatOriginTrend, getChatUserTrend } from '@/api/stats';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { TrendsChart } from '@/components/charts/TrendsChart';
import { startOfToday, subDays } from 'date-fns';
import { useMemo } from 'react';
import useSWR from 'swr';
const today = startOfToday();
export default function Page () {
const start = subDays(today, 30);
const end = today;
const { data: chatUserTrend } = useSWR(`api.stats.trends.chat-user.${start}-${end}`, () => getChatUserTrend(start, end));
const { data: chatOriginTrend } = useSWR(`api.stats.trends.chat-origin.${start}-${end}`, () => getChatOriginTrend(start, end));
const originKeys = useMemo(() => {
if (!chatOriginTrend) {
return [];
}
const set = new Set();
chatOriginTrend.values.forEach(item => {
Object.keys(item).forEach(key => set.add(key));
});
set.delete('date');
return Array.from(set);
}, [chatOriginTrend]);
return (
<>
{chatUserTrend && }
{chatOriginTrend && ([key, {
label: key,
color: `hsl(var(--chart-${(i % 5) + 1}))`,
}])))}
/>}
>
);
}
================================================
FILE: frontend/app/src/app/(main)/(user)/api-keys/page.tsx
================================================
'use client';
import { type ApiKey, type CreateApiKeyResponse, deleteApiKey, listApiKeys } from '@/api/api-keys';
import { AdminPageHeading } from '@/components/admin-page-heading';
import { CreateApiKeyForm } from '@/components/api-keys/CreateApiKeyForm';
import { datetime } from '@/components/cells/datetime';
import { CopyButton } from '@/components/copy-button';
import { DangerousActionButton } from '@/components/dangerous-action-button';
import { DataTableRemote } from '@/components/data-table-remote';
import { ManagedDialog } from '@/components/managed-dialog';
import { ManagedDialogClose } from '@/components/managed-dialog-close';
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert';
import { Button } from '@/components/ui/button';
import { DialogContent, DialogHeader, DialogTitle, DialogTrigger } from '@/components/ui/dialog';
import { DataTableConsumer, useDataTable } from '@/components/use-data-table';
import type { CellContext, ColumnDef } from '@tanstack/react-table';
import { createColumnHelper } from '@tanstack/table-core';
import { CircleCheckIcon, PlusIcon, TrashIcon } from 'lucide-react';
import { useState } from 'react';
const helper = createColumnHelper();
const mono = (cell: CellContext) => {cell.getValue()} ;
const columns = [
helper.accessor('api_key_display', { header: 'API Key', cell: mono }),
helper.accessor('description', { header: 'Description' }),
helper.accessor('created_at', { header: 'Created At', cell: datetime }),
helper.accessor('updated_at', { header: 'Updated At', cell: datetime }),
helper.display({
header: 'Actions',
cell: ({ row }) => (
),
}),
] as ColumnDef[];
export default function ChatEnginesPage () {
const [recentlyCreated, setRecentlyCreated] = useState();
return (
<>
{recentlyCreated && (
API Key created
Please note that your API key will only be shown once. Make sure to save it in a secure location as it won't be displayed again. Not storing your key safely may result in you needing to generate a new API key.
{recentlyCreated.api_key}
)}
Create
Create API Key
{(table) => (
{close => (
{
close();
setRecentlyCreated(data);
table?.reload?.();
}}
/>
)}
)}
)}
columns={columns}
apiKey="api.api-keys.list"
api={listApiKeys}
idColumn="id"
/>
>
);
}
function DeleteButton ({ apiKey }: { apiKey: ApiKey }) {
const { reload } = useDataTable();
return (
{
await deleteApiKey(apiKey.id);
reload?.();
}}
variant="ghost"
className="text-xs text-destructive hover:text-destructive hover:bg-destructive/20"
>
Delete
);
}
================================================
FILE: frontend/app/src/app/(main)/(user)/c/page.tsx
================================================
import { AdminPageHeading } from '@/components/admin-page-heading';
import { ChatsTable } from '@/components/chat/chats-table';
import { requireAuth } from '@/lib/auth';
export default async function ConversationsListPage () {
await requireAuth();
return (
<>
>
);
}
export const dynamic = 'force-dynamic';
================================================
FILE: frontend/app/src/app/(main)/(user)/layout.tsx
================================================
import { AdminPageLayout } from '@/components/admin-page-layout';
import { requireAuth } from '@/lib/auth';
import { type ReactNode } from 'react';
export default async function Layout ({ children }: { children: ReactNode }) {
await requireAuth();
return (
{children}
);
}
================================================
FILE: frontend/app/src/app/(main)/c/[id]/page.tsx
================================================
import { type Chat, type ChatMessage, getChat } from '@/api/chats';
import { AutoScroll, ManualScrollVoter } from '@/components/auto-scroll';
import { Conversation } from '@/components/chat/conversation';
import { ErrorCard } from '@/components/error-card';
import { Button } from '@/components/ui/button';
import { auth } from '@/lib/auth';
import { isServerError } from '@/lib/request';
import type { Metadata } from 'next';
import { cookies } from 'next/headers';
import Link from 'next/link';
import { notFound } from 'next/navigation';
import { cache } from 'react';
const cachedGetChat = cache((id: string) => getChat(id)
.then(res => {
return res;
})
.catch(error => {
if (isServerError(error, [404, 422 /* handle not UUID */])) {
notFound();
} else {
return Promise.reject(error);
}
}));
export default async function ChatDetailPage(props: { params: Promise<{ id: string }> }) {
const params = await props.params;
const id = params.id;
const me = await auth();
const bid = (await cookies()).get('bid')?.value;
let chat: Chat | undefined;
let messages: ChatMessage[];
try {
const detail = await cachedGetChat(id);
chat = detail.chat;
messages = detail.messages;
} catch (error) {
if (isServerError(error, 403)) {
return (
{!me && (
Login to continue
)}
Back to homepage
);
}
throw error;
}
const shouldOpen = me
? me.id === chat?.user_id
: bid === chat?.browser_id;
return (
);
}
export async function generateMetadata(props: { params: Promise<{ id: string }> }): Promise {
const params = await props.params;
try {
const chat = await cachedGetChat(params.id);
return {
title: chat.chat.title,
};
} catch (error) {
if (isServerError(error, 403)) {
return {};
} else {
throw error;
}
}
}
export const dynamic = 'force-dynamic';
================================================
FILE: frontend/app/src/app/(main)/layout.tsx
================================================
'use client';
import { SiteSidebar } from '@/app/(main)/nav';
import { SiteHeaderLargeScreen, SiteHeaderSmallScreen } from '@/components/site-header';
import { SidebarProvider, SidebarTrigger } from '@/components/ui/sidebar';
import { useSettingContext } from '@/components/website-setting-provider';
import { cn } from '@/lib/utils';
import { ReactNode, useState } from 'react';
export default function Layout ({ children }: {
children: ReactNode
}) {
const [sidebarOpen, setSidebarOpen] = useState(true);
const setting = useSettingContext();
return (
<>
{children}
>
);
}
================================================
FILE: frontend/app/src/app/(main)/nav.tsx
================================================
'use client';
import { logout } from '@/api/auth';
import type { PublicWebsiteSettings } from '@/api/site-settings';
import { useAuth } from '@/components/auth/AuthProvider';
import { Branding } from '@/components/branding';
import { useAllChatEngines } from '@/components/chat-engine/hooks';
import { ChatNewDialog } from '@/components/chat/chat-new-dialog';
import { ChatsHistory } from '@/components/chat/chats-history';
import { useAllKnowledgeBases } from '@/components/knowledge-base/hooks';
import { type NavGroup, SiteNav } from '@/components/site-nav';
import { useBootstrapStatus } from '@/components/system/BootstrapStatusProvider';
import { Avatar, AvatarFallback } from '@/components/ui/avatar';
import { Button } from '@/components/ui/button';
import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger } from '@/components/ui/dropdown-menu';
import { Sidebar, SidebarContent, SidebarFooter, SidebarHeader } from '@/components/ui/sidebar';
import { Skeleton } from '@/components/ui/skeleton';
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip';
import { useHref } from '@/components/use-href';
import { ActivitySquareIcon, AlertTriangleIcon, BinaryIcon, BotMessageSquareIcon, BrainCircuitIcon, CircleDotIcon, CogIcon, ComponentIcon, FileLineChart, HomeIcon, KeyRoundIcon, LibraryBigIcon, LibraryIcon, LogInIcon, MessageCircleQuestionIcon, MessagesSquareIcon, ShuffleIcon } from 'lucide-react';
import NextLink from 'next/link';
import Link from 'next/link';
import { useRouter } from 'next/navigation';
import type { ReactNode } from 'react';
export function SiteSidebar ({ setting }: { setting: PublicWebsiteSettings }) {
return (
);
}
function NavContent () {
const { required, need_migration } = useBootstrapStatus();
const href = useHref();
const auth = useAuth();
const user = auth.me;
const isLoggedIn = !!user; // TODO: wait for server
const disableIfNotAuthenticated = !isLoggedIn ? <> Login to continue> : false;
const groups: NavGroup[] = [
{
items: [
{ custom: true, key: 'new-chat', children: },
{ href: '/', title: 'Home', icon: HomeIcon, exact: true },
{ href: '/c', title: 'Conversations', exact: true, icon: MessagesSquareIcon, disabled: disableIfNotAuthenticated },
{ custom: true, key: 'history', children: },
],
},
];
if (user?.is_superuser) {
groups.push({
title: 'Admin',
items: [
{ href: '/stats/trending', title: 'Dashboard', icon: ActivitySquareIcon },
{
href: '/knowledge-bases',
title: 'Knowledge Bases',
icon: LibraryBigIcon,
details: !required.knowledge_base
? You need to configure at least one Knowledge Base.
: ,
},
{
href: '/chat-engines',
title: 'Chat Engines',
icon: BotMessageSquareIcon,
details: !!need_migration.chat_engines_without_kb_configured?.length
?
Some ChatEngine need to configure KnowledgeBase .
: !required.default_chat_engine
? You need to configure default Chat Engine.
: ,
},
{
parent: true,
key: 'models',
title: 'Models',
icon: ComponentIcon,
details: (!required.default_llm || !required.default_embedding_model) && ,
children: [
{ href: '/llms', title: 'LLMs', icon: BrainCircuitIcon, details: !required.default_llm ? You need to configure at least one Default LLM. : undefined },
{ href: '/embedding-models', title: 'Embedding Models', icon: BinaryIcon, details: !required.default_embedding_model && You need to configure at least one Default Embedding Model. },
{ href: '/reranker-models', title: 'Reranker Models', icon: ShuffleIcon },
],
},
{ href: '/feedbacks', title: 'Feedbacks', icon: MessageCircleQuestionIcon },
{
parent: true,
key: 'evaluation',
title: 'Evaluation',
icon: FileLineChart,
children: [
{ href: '/evaluation/tasks', title: 'Tasks', icon: CircleDotIcon },
{ href: '/evaluation/datasets', title: 'Datasets', icon: LibraryIcon },
],
},
{ href: '/site-settings', title: 'Settings', icon: CogIcon },
],
sectionProps: { className: 'mt-auto mb-0' },
});
}
if (user?.is_superuser) {
groups.push({
title: 'Account',
items: [
{ href: '/api-keys', title: 'API Keys', icon: KeyRoundIcon },
],
});
}
return (
);
}
function NavFooter () {
const href = useHref();
const user = useAuth().me;
const router = useRouter();
if (!user) {
return (
Login
);
}
return (
{/*{user.image && }*/}
{/*{user.image ? : user.name}*/}
{user.email.slice(0, 2)}
{
logout().finally(() => {
router.refresh();
});
}}>
Sign out
{user.email}
);
}
function NavWarningDetails ({ children }: { children?: ReactNode }) {
if (!children) {
return ;
}
return (
{children}
);
}
function CountSpan ({ children }: { children?: ReactNode }) {
return {children} ;
}
function KnowledgeBaseNavDetails () {
const { data: knowledgeBases, isLoading } = useAllKnowledgeBases();
if (isLoading) {
return ;
}
return {knowledgeBases?.length} ;
}
function ChatEnginesNavDetails () {
const { data, isLoading } = useAllChatEngines();
if (isLoading) {
return ;
}
return {data?.length} ;
}
================================================
FILE: frontend/app/src/app/(main)/page.tsx
================================================
'use client';
import { Ask } from '@/components/chat/ask';
import { useAsk } from '@/components/chat/use-ask';
import { withReCaptcha } from '@/components/security-setting-provider';
import { SystemWizardBanner } from '@/components/system/SystemWizardBanner';
import { Button } from '@/components/ui/button';
import DotPattern from '@/components/ui/dot-pattern';
import { useSettingContext } from '@/components/website-setting-provider';
import { cn } from '@/lib/utils';
import NextLink from 'next/link';
const security: { google_recaptcha_site_key: string, google_recaptcha: 'v3' | 'enterprise' | '' } | null = null;
export default function Page () {
const { loading, disabled, setEngine, ask, engine } = useAsk();
const { homepage_title, description, homepage_example_questions, homepage_footer_links } = useSettingContext();
return (
{homepage_title || ''}
{description || ''}
{homepage_example_questions && (
{homepage_example_questions.map((item, index) => (
{
withReCaptcha({
action: 'ask',
siteKey: security?.google_recaptcha_site_key || '',
mode: security?.google_recaptcha,
}, ({ token, action }) => {
ask(item, {
headers: {
'X-Recaptcha-Token': token,
'X-Recaptcha-Action': action,
},
});
});
}}
>
{item}
))}
)}
{homepage_footer_links?.map(link => (
{trimHighlightedLinkText(link.text)}
))}
);
}
function isHighlightedLinkText (text: string) {
return text.startsWith('*') && text.endsWith('*')
}
function trimHighlightedLinkText (text: string) {
if (isHighlightedLinkText(text)) {
return text.slice(1, -1)
}
return text
}
================================================
FILE: frontend/app/src/app/RootProviders.tsx
================================================
'use client';
import type { PublicWebsiteSettings } from '@/api/site-settings';
import type { BootstrapStatus } from '@/api/system';
import { getMe, type MeInfo } from '@/api/users';
import { AuthProvider } from '@/components/auth/AuthProvider';
import { ChatsProvider } from '@/components/chat/chat-hooks';
import { GtagProvider } from '@/components/gtag-provider';
import { BootstrapStatusProvider } from '@/components/system/BootstrapStatusProvider';
import { Toaster } from '@/components/ui/sonner';
import { SettingProvider } from '@/components/website-setting-provider';
import { type ExperimentalFeatures, ExperimentalFeaturesProvider } from '@/experimental/experimental-features-provider';
import { cn } from '@/lib/utils';
import { ThemeProvider } from 'next-themes';
import type { ReactNode } from 'react';
import useSWR from 'swr';
export interface RootProvidersProps {
me: MeInfo | undefined;
children: ReactNode;
settings: PublicWebsiteSettings;
bootstrapStatus: BootstrapStatus;
experimentalFeatures: Partial;
}
export function RootProviders ({ me, settings, bootstrapStatus, experimentalFeatures, children }: RootProvidersProps) {
const { data, isValidating, isLoading, mutate } = useSWR('api.users.me', getMe, {
fallbackData: me,
revalidateOnMount: false,
revalidateOnFocus: false,
errorRetryCount: 0,
});
return (
mutate(data, { revalidate: true })}>
{children}
);
}
================================================
FILE: frontend/app/src/app/api/[[...fallback_placeholder]]/route.ts
================================================
import { BASE_URL } from '@/lib/request';
import type { NextRequest } from 'next/server';
declare global {
interface RequestInit {
duplex?: 'half' | boolean;
}
}
function handler (request: NextRequest) {
const base = BASE_URL;
if (!/^https?:\/\//.test(base)) {
return Promise.reject(new Error(`BASE_URL must be a http(s) url to proxy requests.`));
}
const newUrl = new URL(originalUrl(request), base);
const requestHeaders = new Headers(request.headers);
requestHeaders.delete('Accept-Encoding');
requestHeaders.delete('Host');
requestHeaders.delete('X-Invoke-Output');
requestHeaders.delete('X-Invoke-Path');
requestHeaders.delete('X-Invoke-Query');
requestHeaders.delete('X-Middleware-Invoke');
return fetch(newUrl, {
cache: 'no-cache',
method: request.method,
headers: requestHeaders,
body: request.body,
duplex: request.body instanceof ReadableStream ? 'half' : 'half',
}).then(response => {
console.log('[proxy]', request.method, newUrl.toString(), response.status, response.statusText, response.headers.get('Content-Type'));
return response;
}, error => {
console.error('[proxy]', request.method, newUrl.toString(), error);
return Promise.reject(error);
});
}
function originalUrl (request: NextRequest) {
const url = request.nextUrl;
const usp = new URLSearchParams(url.searchParams);
usp.delete('fallback_placeholder');
const search = usp.toString();
if (search) {
return url.pathname + '?' + search;
} else {
return url.pathname;
}
}
export { handler as GET, handler as POST, handler as DELETE, handler as HEAD, handler as PUT, handler as PATCH, handler as OPTIONS };
export const runtime = 'edge';
export const dynamic = 'force-dynamic';
export const maxDuration = 300;
================================================
FILE: frontend/app/src/app/auth/login/page.tsx
================================================
import { Signin } from '@/components/signin';
import { headers } from 'next/headers';
export default async function Page() {
const referer = (await headers()).get('Referer') ?? undefined;
return (
);
}
================================================
FILE: frontend/app/src/app/chart-theme.css
================================================
:root {
--chart-1: 12 76% 61%;
--chart-2: 173 58% 39%;
--chart-3: 197 37% 24%;
--chart-4: 43 74% 66%;
--chart-5: 27 87% 67%;
}
.dark {
--chart-1: 220 70% 50%;
--chart-2: 160 60% 45%;
--chart-3: 30 80% 55%;
--chart-4: 280 65% 60%;
--chart-5: 340 75% 55%;
}
================================================
FILE: frontend/app/src/app/globals.css
================================================
@tailwind base;
@tailwind components;
@tailwind utilities;
@layer base {
:root {
--header-height: 3rem;
--body-height: calc(100vh - var(--header-height));
--body-padding: 1rem;
--content-height: calc(var(--body-height) - var(--body-padding) * 6);
--sidebar-width: ;
--content-width: calc(100vw - var(--sidebar-width));
--ask-referral-height: 4rem;
--sidebar-background: 0 0% 98%;
--sidebar-foreground: 240 5.3% 26.1%;
--sidebar-primary: 240 5.9% 10%;
--sidebar-primary-foreground: 0 0% 98%;
--sidebar-accent: 240 4.8% 95.9%;
--sidebar-accent-foreground: 240 5.9% 10%;
--sidebar-border: 220 13% 91%;
--sidebar-ring: 217.2 91.2% 59.8%;
}
:root {
--background: 0 0% 100%;
--foreground: 240 10% 3.9%;
--card: 0 0% 100%;
--card-foreground: 240 10% 3.9%;
--popover: 0 0% 100%;
--popover-foreground: 240 10% 3.9%;
--primary: 240 5.9% 10%;
--primary-foreground: 0 0% 98%;
--secondary: 240 4.8% 95.9%;
--secondary-foreground: 240 5.9% 10%;
--muted: 240 4.8% 95.9%;
--muted-foreground: 240 3.8% 46.1%;
--accent: 240 4.8% 95.9%;
--accent-foreground: 240 5.9% 10%;
--destructive: 0 84.2% 60.2%;
--destructive-foreground: 0 0% 98%;
--border: 240 5.9% 90%;
--input: 240 5.9% 90%;
--ring: 240 10% 3.9%;
--radius: 0.5rem;
}
.dark {
--background: 240 10% 3.9%;
--foreground: 0 0% 98%;
--card: 240 10% 3.9%;
--card-foreground: 0 0% 98%;
--popover: 240 10% 3.9%;
--popover-foreground: 0 0% 98%;
--primary: 0 0% 98%;
--primary-foreground: 240 5.9% 10%;
--secondary: 240 3.7% 15.9%;
--secondary-foreground: 0 0% 98%;
--muted: 240 3.7% 15.9%;
--muted-foreground: 240 5% 64.9%;
--accent: 240 3.7% 15.9%;
--accent-foreground: 0 0% 98%;
--destructive: 0 72.2% 50.6%;
--destructive-foreground: 0 0% 98%;
--border: 240 3.7% 15.9%;
--input: 240 3.7% 15.9%;
--ring: 240 4.9% 83.9%;
--sidebar-background: 240 5.9% 10%;
--sidebar-foreground: 240 4.8% 95.9%;
--sidebar-primary: 224.3 76.3% 48%;
--sidebar-primary-foreground: 0 0% 100%;
--sidebar-accent: 240 3.7% 15.9%;
--sidebar-accent-foreground: 240 4.8% 95.9%;
--sidebar-border: 240 3.7% 15.9%;
--sidebar-ring: 217.2 91.2% 59.8%;
}
}
@layer base {
* {
@apply border-border;
}
body {
@apply bg-background text-foreground;
}
}
@layer base {
:root {
--brand1: 221.2 83.2% 53.3%;
--brand1-foreground: 210 40% 98%;
}
.dark {
--brand1: 217.2 91.2% 59.8%;
--brand1-foreground: 222.2 47.4% 11.2%;
}
}
@layer base {
:root {
--warning: 37.7 92.1% 50.2%;
--warning-foreground: 0 0% 98%;
--info: 198.6 88.7% 48.4%;
--info-foreground: 0 0% 98%;
--success: 142.1 70.6% 45.3%;
--success-foreground: 0 0% 98%;
}
.dark {
--warning: 32.1 94.6% 43.7%;
--warning-foreground: 0 0% 98%;
--info: 200.4 98% 39.4%;
--info-foreground: 0 0% 98%;
--success: 142.1 76.2% 36.3%;
--success-foreground: 0 0% 98%;
}
}
.grecaptcha-badge {
visibility: hidden;
}
================================================
FILE: frontend/app/src/app/layout.tsx
================================================
import { getPublicSiteSettings } from '@/api/site-settings';
import { getBootstrapStatus } from '@/api/system';
import { RootProviders } from '@/app/RootProviders';
import { experimentalFeatures } from '@/experimental/experimental-features';
import { auth } from '@/lib/auth';
import { GoogleAnalytics } from '@next/third-parties/google';
import type { Metadata } from 'next';
import { Inter } from 'next/font/google';
import Script from 'next/script';
import { cache, type ReactNode } from 'react';
import './globals.css';
import './chart-theme.css';
const inter = Inter({ subsets: ['latin'] });
const cachedGetSettings = cache(getPublicSiteSettings);
export async function generateMetadata (): Promise {
const { title, description } = await cachedGetSettings();
return {
title,
description,
icons: '/favicon.svg',
};
};
export default async function RootLayout ({
children,
}: Readonly<{
children: ReactNode;
}>) {
const [
me,
settings,
bootstrapStatus,
] = await Promise.all([
auth(),
cachedGetSettings(),
getBootstrapStatus(),
]);
const _experimentalFeatures = experimentalFeatures();
if (!settings.enable_post_verifications) {
_experimentalFeatures.enable_message_post_verification = false;
}
return (
{children}
{settings.ga_id && }
);
}
export const dynamic = 'force-dynamic';
================================================
FILE: frontend/app/src/components/admin-page-heading.tsx
================================================
'use client';
import { Breadcrumb, BreadcrumbItem, BreadcrumbLink, BreadcrumbList, BreadcrumbPage, BreadcrumbSeparator } from '@/components/ui/breadcrumb';
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip';
import { useSettingContext } from '@/components/website-setting-provider';
import { AlertCircleIcon, AlertTriangleIcon, CheckCircleIcon, HelpCircleIcon } from 'lucide-react';
import Link from 'next/link';
import { Fragment, type ReactNode } from 'react';
export interface BreadcrumbItem {
title: ReactNode;
url?: string;
docsUrl?: string;
alert?: {
variant: 'success' | 'warning' | 'destructive';
content: ReactNode;
};
}
export interface TableHeadingProps {
breadcrumbs?: BreadcrumbItem[];
}
export function AdminPageHeading ({ breadcrumbs }: TableHeadingProps) {
const { title: siteTitle } = useSettingContext();
return (
{breadcrumbs && (
{siteTitle}
{breadcrumbs.map((item, index) => (
{item.alert &&
{item.alert.variant === 'success'
?
: item.alert.variant === 'warning'
?
: item.alert.variant === 'destructive'
?
: }
{item.alert.content}
}
{item.url
? {item.title}
: index === breadcrumbs.length - 1
? {item.title}
: {item.title} }
{item.docsUrl
?
: undefined}
))}
)}
);
}
================================================
FILE: frontend/app/src/components/admin-page-layout.tsx
================================================
import type { ReactNode } from 'react';
export function AdminPageLayout ({ children }: { children: ReactNode }) {
return (
{children}
);
}
================================================
FILE: frontend/app/src/components/api-keys/CreateApiKeyForm.tsx
================================================
import { createApiKey, type CreateApiKeyResponse } from '@/api/api-keys';
import { FormInput } from '@/components/form/control-widget';
import { withCreateEntityForm } from '@/components/form/create-entity-form';
import { z } from 'zod';
const schema = z.object({
description: z.string(),
});
export interface CreateApiKeyFormProps {
onCreated?: (data: CreateApiKeyResponse) => void;
}
const FormImpl = withCreateEntityForm(schema, createApiKey, {
submitTitle: 'Create API Key',
submittingTitle: 'Creating API Key...',
});
export function CreateApiKeyForm ({ onCreated }: CreateApiKeyFormProps) {
return (
);
}
================================================
FILE: frontend/app/src/components/auth/AuthProvider.tsx
================================================
import type { MeInfo } from '@/api/users';
import { createContext, type ReactNode, useContext } from 'react';
export interface AuthContextValues {
me: MeInfo | undefined;
isLoading: boolean;
isValidating: boolean;
reload: () => void;
}
const AuthContext = createContext({ me: undefined, isLoading: false, isValidating: false, reload: () => {}, });
export function AuthProvider ({ children, ...context }: AuthContextValues & { children: ReactNode }) {
return (
{children}
);
}
export function useAuth () {
return useContext(AuthContext);
}
================================================
FILE: frontend/app/src/components/auto-scroll/auto-scroll.stories.tsx
================================================
import type { Meta, StoryObj } from '@storybook/react';
import { type RefObject, useEffect, useRef, useState } from 'react';
import { AutoScroll } from './auto-scroll';
import { ManualScrollVoter } from './manual-scroll-voter';
import { useRequestScroll } from './use-request-scroll';
const meta = {
title: 'Components/AutoScroll',
subcomponents: {},
parameters: {
layout: 'centered',
},
tags: ['autodocs'],
argTypes: {},
args: {},
} satisfies Meta;
export default meta;
type Story = StoryObj
export const Container: Story = {
args: {},
render: () => {
const [count, setCount] = useState(0);
const [target, setTarget] = useState(null);
useEffect(() => {
const interval = setInterval(() => {
setCount(count => count + 1);
}, 1000);
return () => {
clearInterval(interval);
};
}, []);
return (
{Array(count).fill(0).map((_, i) => (
))}
);
},
};
export const Document: Story = {
args: {},
render: () => {
const [count, setCount] = useState(0);
const ref = useRef(null);
useEffect(() => {
const interval = setInterval(() => {
setCount(count => count + 1);
}, 1000);
return () => {
clearInterval(interval);
};
}, []);
return (
{Array(count).fill(0).map((_, i) => (
))}
);
},
};
function AutoScrollByObservingChildren ({ target }: { target: RefObject }) {
const requestScroll = useRequestScroll();
useEffect(() => {
const el = target.current;
if (el) {
const mo = new MutationObserver(() => {
requestScroll('bottom');
});
mo.observe(el, { childList: true });
return () => {
mo.disconnect();
};
}
}, [target.current]);
return null;
}
================================================
FILE: frontend/app/src/components/auto-scroll/auto-scroll.tsx
================================================
'use client';
import { type MutableRefObject, type ReactNode, useCallback, useEffect, useRef, useState } from 'react';
import { AutoScrollContext, ScrollEdge, ScrollHandler } from './context';
export interface AutoScrollProps {
/**
* undefined - Window
* null - no scroll target
*/
target?: HTMLElement | null;
children: ReactNode;
edgePixels?: number;
}
export function AutoScroll ({ target, edgePixels = 0, children }: AutoScrollProps) {
const [map, setMap] = useState>(() => new Map());
const [handlers, setHandlers] = useState[]>([]);
const lastScrollPositionInfo = useRef<{ left: number, top: number }>({ left: 0, top: 0 });
const targetRef = useRef(target);
const currentHandlers = useRef(handlers);
useEffect(() => {
currentHandlers.current = handlers;
targetRef.current = target;
});
const registerVoter = useCallback((id: string, defaultScroll: boolean) => {
setMap(map => new Map(map).set(id, defaultScroll));
}, []);
const unregisterVoter = useCallback((id: string) => {
setMap(map => {
map = new Map(map);
map.delete(id);
return map;
});
}, []);
const voteAutoScroll = useCallback((id: string) => {
setMap(map => {
if (map.get(id)) {
return map;
}
return new Map(map).set(id, true);
});
}, []);
const cancelVoteAutoScroll = useCallback((id: string) => {
setMap(map => {
if (map.get(id) === false) {
return map;
}
return new Map(map).set(id, false);
});
}, []);
const useScroll = useCallback(function useScroll (handler: ScrollHandler) {
const currentRef = useRef(handler);
useEffect(() => {
currentRef.current = handler;
});
useEffect(() => {
setHandlers(handlers => [...handlers, currentRef]);
return () => {
setHandlers(handlers => handlers.filter(handler => handler !== currentRef));
};
}, []);
return;
}, []);
useEffect(() => {
if (target === null) {
return;
}
const scrollTarget = target ?? window;
if (scrollTarget) {
const measureTarget = scrollTarget === window ? document.documentElement : scrollTarget;
const eventTarget = scrollTarget;
lastScrollPositionInfo.current = {
left: measureTarget.scrollLeft,
top: measureTarget.scrollTop,
};
const getSize = () => {
if (scrollTarget === window) {
return {
width: window.innerWidth,
height: window.innerHeight,
};
} else {
return measureTarget.getBoundingClientRect();
}
};
const handleScroll = (event: Event) => {
const { scrollLeft, scrollTop, scrollHeight, scrollWidth } = measureTarget;
const { width, height } = getSize();
const x = scrollLeft - lastScrollPositionInfo.current.left;
const y = scrollTop - lastScrollPositionInfo.current.top;
const reachLeft = scrollLeft <= edgePixels;
const reachRight = scrollLeft + width >= scrollWidth - edgePixels;
const reachTop = scrollTop <= edgePixels;
const reachBottom = scrollTop + height >= scrollHeight - edgePixels;
lastScrollPositionInfo.current = {
left: scrollLeft,
top: scrollTop,
};
currentHandlers.current.forEach(ref => {
ref.current({
x, y, reachLeft, reachTop, reachRight, reachBottom,
});
});
};
eventTarget.addEventListener('scroll', handleScroll, { passive: true });
return () => {
eventTarget.removeEventListener('scroll', handleScroll);
};
}
}, [target, edgePixels]);
const votes = Array.from(map.values());
const shouldAutoScroll = votes.length > 0 && votes.indexOf(false) === -1;
const shouldAutoScrollRef = useRef(shouldAutoScroll);
useEffect(() => {
shouldAutoScrollRef.current = shouldAutoScroll;
});
const requestScroll = useCallback((edge: ScrollEdge) => {
const target = targetRef.current;
if (target === null) {
return;
}
const scrollTarget = target ?? document.documentElement;
if (shouldAutoScrollRef.current && scrollTarget) {
switch (edge) {
case 'top':
scrollTarget.scrollTo({ top: 0, behavior: 'smooth' });
break;
case 'left':
scrollTarget.scrollTo({ left: 0, behavior: 'smooth' });
break;
case 'bottom':
scrollTarget.scrollTo({ top: scrollTarget.scrollHeight, behavior: 'smooth' });
break;
case 'right':
scrollTarget.scrollTo({ left: scrollTarget.scrollWidth, behavior: 'smooth' });
break;
}
}
}, []);
return (
{children}
);
}
================================================
FILE: frontend/app/src/components/auto-scroll/context.ts
================================================
'use client';
import { createContext } from 'react';
export type ScrollInfo = {
x: number
y: number
reachLeft: boolean
reachTop: boolean
reachRight: boolean
reachBottom: boolean
}
export type ScrollEdge = 'top' | 'left' | 'right' | 'bottom';
export type ScrollHandler = (info: ScrollInfo) => void;
export interface AutoScrollContextValues {
registerVoter (id: string, defaultScroll: boolean): void;
unregisterVoter (id: string): void;
voteAutoScroll (id: string): void;
cancelVoteAutoScroll (id: string): void;
useScroll (handler: ScrollHandler): void;
requestScroll (edge: ScrollEdge): void;
}
export const AutoScrollContext = createContext({
registerVoter (id: string, defaultScroll: boolean) {},
unregisterVoter (id: string) {},
voteAutoScroll (id: string) {},
cancelVoteAutoScroll (id: string) {},
useScroll () {},
requestScroll () {},
});
================================================
FILE: frontend/app/src/components/auto-scroll/index.ts
================================================
export * from './manual-scroll-voter';
export * from './auto-scroll';
export * from './use-auto-scroll-voter';
export * from './use-request-scroll';
================================================
FILE: frontend/app/src/components/auto-scroll/manual-scroll-voter.tsx
================================================
'use client';
import { useAutoScrollVoter } from './use-auto-scroll-voter';
/**
* When scrolling up, stop auto scroll.
*
* When scrolled down to bottom, start auto scroll.
*/
export function ManualScrollVoter () {
const { useScroll, voteAutoScroll, cancelVoteAutoScroll } = useAutoScrollVoter(true);
useScroll((info) => {
const { y, reachBottom } = info;
if (y > 0 && reachBottom) {
voteAutoScroll();
} else if (y < 0) {
cancelVoteAutoScroll();
}
});
return null;
}
================================================
FILE: frontend/app/src/components/auto-scroll/use-auto-scroll-voter.ts
================================================
'use client';
import { AutoScrollContext } from './context';
import { useContext, useEffect, useId, useMemo } from 'react';
export function useAutoScrollVoter (defaultScroll: boolean) {
const id = useId();
const { useScroll, registerVoter, unregisterVoter, voteAutoScroll, cancelVoteAutoScroll } = useContext(AutoScrollContext);
useEffect(() => {
registerVoter(id, defaultScroll);
return () => {
unregisterVoter(id);
};
}, [id]);
const _voteAutoScroll = useMemo(() => {
return () => voteAutoScroll(id);
}, [id]);
const _cancelVoteAutoScroll = useMemo(() => {
return () => cancelVoteAutoScroll(id);
}, [id]);
return {
useScroll,
voteAutoScroll: _voteAutoScroll,
cancelVoteAutoScroll: _cancelVoteAutoScroll,
};
}
================================================
FILE: frontend/app/src/components/auto-scroll/use-request-scroll.ts
================================================
'use client';
import { useContext } from 'react';
import { AutoScrollContext } from './context';
export function useRequestScroll () {
const { requestScroll } = useContext(AutoScrollContext);
return requestScroll;
}
================================================
FILE: frontend/app/src/components/branding.tsx
================================================
import type { PublicWebsiteSettings } from '@/api/site-settings';
import Link from 'next/link';
export function Branding ({ setting }: { setting: PublicWebsiteSettings }) {
return (
);
}
================================================
FILE: frontend/app/src/components/cells/actions.tsx
================================================
import { DangerousActionButton, type DangerousActionButtonProps } from '@/components/dangerous-action-button';
import { buttonVariants } from '@/components/ui/button';
import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuLabel, DropdownMenuSeparator, DropdownMenuTrigger } from '@/components/ui/dropdown-menu';
import { useDataTable } from '@/components/use-data-table';
import { cn } from '@/lib/utils';
import type { CellContext } from '@tanstack/react-table';
import { EllipsisIcon, Loader2Icon } from 'lucide-react';
import type { AppRouterInstance } from 'next/dist/shared/lib/app-router-context.shared-runtime';
import { useRouter } from 'next/navigation';
import { type Dispatch, type ReactNode, type SetStateAction, type TransitionStartFunction, useState, useTransition } from 'react';
export interface CellAction {
type?: 'button' | 'label' | 'separator';
key?: string | number;
icon?: ReactNode;
title?: ReactNode;
disabled?: boolean;
dangerous?: Pick;
action?: (context: ActionUIContext) => Promise | void;
}
export interface ActionUIContext {
table: ReturnType;
startTransition: TransitionStartFunction;
router: AppRouterInstance;
dropdownOpen: boolean;
setDropdownOpen: Dispatch>;
}
export function actions (items: (row: Row) => CellAction[]) {
// eslint-disable-next-line react/display-name
return (props: CellContext) => {
const [open, setOpen] = useState(false);
const actionItems = items(props.row.original);
return (
{actionItems.map((item, index) => {
if (item.type === 'label') {
return {item.title} ;
} else if (item.type === 'separator') {
return ;
} else {
return ;
}
})}
);
};
}
function Action ({ item, open, setOpen }: { item: CellAction, open: boolean, setOpen: Dispatch> }) {
const table = useDataTable();
const [busy, setBusy] = useState(false);
const [transitioning, startTransition] = useTransition();
const router = useRouter();
const onAction = async () => {
try {
setBusy(true);
await item?.action?.({ startTransition, router, table, dropdownOpen: open, setDropdownOpen: setOpen });
} finally {
setBusy(false);
}
};
let el = (
{
event.preventDefault();
}
: (event) => {
event.preventDefault();
void onAction();
}}
>
{item.icon ? (busy || transitioning) ? : item.icon : null}
{item.title}
);
if (item.dangerous) {
el = (
{el}
);
}
return el;
}
================================================
FILE: frontend/app/src/components/cells/boolean.tsx
================================================
import type { CellContext } from '@tanstack/react-table';
import { CheckIcon, XIcon } from 'lucide-react';
export function boolean (props: CellContext) {
const bool = props.getValue();
if (bool == null) {
return - ;
}
if (bool) {
return (
Yes
);
} else {
return (
No
);
}
}
================================================
FILE: frontend/app/src/components/cells/datetime.tsx
================================================
import type { CellContext } from '@tanstack/react-table';
import { format } from 'date-fns';
export function datetime (props: CellContext) {
const date = props.getValue();
if (!date) {
return '-';
}
if (Number.isNaN(date.getTime())) {
return 'Invalid Date';
}
return format(date, 'yyyy-MM-dd HH:mm:ss');
}
================================================
FILE: frontend/app/src/components/cells/error-message.tsx
================================================
import { PythonViewer } from '@/components/py-viewer';
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger } from '@/components/ui/dialog';
import type { CellContext } from '@tanstack/react-table';
export function errorMessageCell (trimLength = 25) {
return function ErrorMessageCell (context: CellContext) {
return {context.getValue() ?? '-'} ;
};
}
export function AutoErrorMessagePopper ({ trimLength = 25, children }: { trimLength?: number, children: string | null }) {
if (!children || children.length <= trimLength) {
return children;
}
const shortcut = children.slice(0, trimLength);
return (
{shortcut}{'... '}
({children.length + ' characters'})
Error Message
);
}
================================================
FILE: frontend/app/src/components/cells/link.tsx
================================================
import type { CellContext } from '@tanstack/react-table';
import Link from 'next/link';
export interface LinkCellProps {
icon?: React.ReactNode;
url?: (row: Row) => string;
text?: (row: Row) => string;
truncate?: boolean;
truncate_length?: number;
}
const format_link = (url: string, maxLength: number = 30): string => {
if (!url || url.length <= maxLength) return url;
const start = url.substring(0, maxLength / 2);
const end = url.substring(url.length - maxLength / 2);
return `${start}...${end}`;
};
export function link ({ icon, url, text, truncate, truncate_length }: LinkCellProps) {
// eslint-disable-next-line react/display-name
return (context: CellContext) => {
const href_value = url ? url(context.row.original) : String(context.getValue());
const text_value = text ? text(context.row.original) : String(context.getValue());
const display_text = truncate ? format_link(text_value, truncate_length) : text_value;
return
{icon} {display_text}
};
}
================================================
FILE: frontend/app/src/components/cells/metadata.tsx
================================================
import { ThemedStyle } from '@/components/themed-style';
import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/popover';
import type { CellContext } from '@tanstack/react-table';
import JsonView from '@uiw/react-json-view';
import { darkTheme } from '@uiw/react-json-view/dark';
import { lightTheme } from '@uiw/react-json-view/light';
import { AlignLeftIcon, BracesIcon, BracketsIcon } from 'lucide-react';
import type { ReactElement, ReactNode } from 'react';
export const metadataCell = (props: CellContext) => {
const metadata = props.getValue();
if (metadata == null) {
return (nul) ;
}
let icon: ReactElement | null;
let text: ReactNode;
if (typeof metadata === 'object') {
if (metadata instanceof Array) {
icon = ;
text = {`${metadata.length} items`} ;
} else {
icon = ;
text = {`${Object.keys(metadata).length} keys`} ;
}
} else {
const stringValue = String(metadata);
if (stringValue.length < 25) {
return stringValue;
}
icon = ;
text = {typeof metadata} ;
}
return (
{icon} {text}
);
};
================================================
FILE: frontend/app/src/components/cells/mono.tsx
================================================
import type { CellContext } from '@tanstack/react-table';
export const mono = (cell: CellContext) => {String(cell.getValue())} ;
================================================
FILE: frontend/app/src/components/cells/percent.tsx
================================================
import { cn } from '@/lib/utils';
import * as ProgressPrimitive from '@radix-ui/react-progress';
import type { CellContext } from '@tanstack/react-table';
import * as React from 'react';
import { type CSSProperties } from 'react';
interface PercentCellConfig {
multiply?: number;
colorStops?: {
checkpoint: number
color: string
}[];
}
export const percent = (
cell: CellContext,
{ multiply = 100, colorStops }: PercentCellConfig = {},
) => (
{cell.getValue()?.toFixed(2) ?? '-'}
);
function getStyle (value: number, stops: {
checkpoint: number
color: string
}[] | undefined): CSSProperties {
if (!stops || stops.length === 0) {
return {};
}
let from = stops.findLast(stop => stop.checkpoint <= value)!;
let to = stops.find(stop => stop.checkpoint >= value)!;
if (!from || !to) {
return {};
}
if (from.checkpoint === to.checkpoint) {
return {
backgroundColor: from.color,
};
}
const p1 = (value - from.checkpoint) / (to.checkpoint - from.checkpoint) * 100;
return {
backgroundColor: `color-mix(in srgb, ${from.color} ${100 - p1}%, ${to.color} ${p1}%)`,
};
}
================================================
FILE: frontend/app/src/components/cells/reference.tsx
================================================
'use client';
import Link from 'next/link';
export function DatasourceCell ({ id, name }: { id: number, name: string }) {
return {name} ;
}
export function KnowledgeBaseCell ({ id, name }: { id?: number, name?: string }) {
if (id == null) {
return - ;
}
return {name ?? 'Unnamed'};
}
================================================
FILE: frontend/app/src/components/charts/IndexProgressChart.stories.tsx
================================================
import { IndexProgressChart, IndexProgressChartPlaceholder } from '@/components/charts/IndexProgressChart';
import type { Meta, StoryObj } from '@storybook/react';
import type { FC } from 'react';
const valueType = {
name: 'number',
required: false,
} as const;
const meta = {
title: 'Components/Charts/IndexProgressChart',
component: IndexProgressChart,
subcomponents: {
IndexProgressChartPlaceholder: IndexProgressChartPlaceholder as FC,
},
parameters: {
layout: 'centered',
},
tags: ['autodocs'],
decorators: [
(Story) => (
),
],
argTypes: {},
args: {},
} satisfies Meta;
export default meta;
type Story = StoryObj
export const Default: Story = {
args: {
title: 'Title',
description: 'Description',
label: 'Count',
data: {
failed: 23,
completed: 120,
not_started: 2,
},
},
};
export const Placeholder: StoryObj = {
args: {
title: 'Title',
description: 'Description',
label: 'Count',
},
render ({ ...args }) {
return ;
},
};
================================================
FILE: frontend/app/src/components/charts/IndexProgressChart.tsx
================================================
'use client';
import type { IndexProgress } from '@/api/rag';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart';
import { Skeleton } from '@/components/ui/skeleton';
import * as React from 'react';
import { type ReactNode, useMemo } from 'react';
import { Label, Pie, PieChart } from 'recharts';
const color_error = '#ef4444';
const color_succeed = '#22c55e';
const color_in_progress = '#3b82f6';
const color_pending = '#71717a';
const color_blank = '#71717a80';
const color_placeholder = '#71717a40';
const chartConfig = {
total: {
label: 'Total',
},
completed: {
label: 'Completed',
color: color_succeed,
},
pending: {
label: 'Pending',
color: color_pending,
},
running: {
label: 'Running',
color: color_in_progress,
},
failed: {
label: 'Failed',
color: color_error,
},
not_started: {
label: 'Not Started',
color: color_blank,
},
} satisfies ChartConfig;
export function IndexProgressChart ({ title, description, label, data }: { title: string, description?: string, label: ReactNode, data: IndexProgress }) {
const total = React.useMemo(() => {
return Object.values(data).reduce((a, b) => a + b, 0);
}, [data]);
const chartData = useMemo(() => {
return [
{ count: data.completed, state: 'Completed', fill: color_succeed },
{ count: data.failed, state: 'Failed', fill: color_error },
{ count: data.pending, state: 'Pending', fill: color_pending },
{ count: data.running, state: 'Running', fill: color_in_progress },
{ count: data.not_started, state: 'Not started', fill: color_blank },
];
}, []);
return (
{title}
{description && {description} }
}
/>
{
if (viewBox && 'cx' in viewBox && 'cy' in viewBox) {
return (
{total.toLocaleString()}
{label}
);
}
}}
/>
);
}
export interface IndexProgressChartPlaceholderProps {
title: string,
label: ReactNode,
description?: string
}
export function IndexProgressChartPlaceholder ({ title, label, description }: IndexProgressChartPlaceholderProps) {
return (
{title}
{description && {description} }
{
if (viewBox && 'cx' in viewBox && 'cy' in viewBox) {
return (
--
{label}
);
}
}}
/>
);
}
================================================
FILE: frontend/app/src/components/charts/TotalCard.stories.tsx
================================================
import { TotalCard } from '@/components/charts/TotalCard';
import type { Meta, StoryObj } from '@storybook/react';
import { Aperture } from 'lucide-react';
const meta = {
title: 'Components/Charts/TotalCard',
component: TotalCard,
parameters: {
layout: 'centered',
},
tags: ['autodocs'],
decorators: [
(Story) => (
),
],
argTypes: {
total: {
type: 'number',
control: 'select',
options: [42, null, undefined],
},
},
args: {},
} satisfies Meta;
export default meta;
type Story = StoryObj
export const Default: Story = {
args: {
title: 'Title',
isLoading: false,
icon: ,
total: 42,
children: 'Hahaha',
},
};
================================================
FILE: frontend/app/src/components/charts/TotalCard.tsx
================================================
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Skeleton } from '@/components/ui/skeleton';
import { cn } from '@/lib/utils';
import type { ReactNode } from 'react';
const nf = new Intl.NumberFormat('en-US', {});
export interface TotalCardProps {
isLoading: boolean;
title: string,
icon: ReactNode,
total?: number | null | undefined,
children?: ReactNode
}
export function TotalCard ({ isLoading = false, title, icon, total, children }: TotalCardProps) {
return (
{title}
{icon}
{isLoading ? : nf.format(total || 0)}
{children}
);
}
================================================
FILE: frontend/app/src/components/charts/TrendsChart.tsx
================================================
'use client';
import type { TrendResponse } from '@/api/stats';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { ChartContainer, ChartLegend, ChartLegendContent, ChartTooltip, ChartTooltipContent } from '@/components/ui/chart';
import { cn } from '@/lib/utils';
import { addDays, format, isAfter } from 'date-fns';
import { ReactNode, useMemo } from 'react';
import { Bar, BarChart, CartesianGrid, XAxis } from 'recharts';
export function TrendsChart ({
className,
title,
description,
data,
dimensions,
config,
}: {
className?: string,
title: ReactNode,
description: ReactNode,
data: TrendResponse,
dimensions: Exclude[],
config: { [P in Exclude]: { label: string, color: string } }
}) {
const chartData = useMemo(() => {
const dates: any[] = [];
for (let i = data.start_date; !isAfter(i, data.end_date); i = addDays(i, 1)) {
dates.push(i);
}
const map = new Map(data.values.map(item => [String(item.date), item]));
return dates.map(date => ({
...map.get(String(date)),
date: format(date, 'yyyy-MM-dd'),
}));
}, [data]);
return (
{title}
{description}
}
/>
} />
{dimensions.map((dimension) => (
))}
{/**/}
{/* */}
{/*
*/}
{/*
*/}
{/* Trending up by 5.2% this month */}
{/*
*/}
{/*
*/}
{/* January - June 2024*/}
{/*
*/}
{/*
*/}
{/*
*/}
{/* */}
);
}
================================================
FILE: frontend/app/src/components/chat/ask.tsx
================================================
import { MessageInput } from '@/components/chat/message-input';
import { type UseAskReturns } from '@/components/chat/use-ask';
import { SecuritySettingContext, withReCaptcha } from '@/components/security-setting-provider';
import { useContext, useRef } from 'react';
export function Ask ({ className, loading, disabled, ask, engine, setEngine }: { className?: string } & UseAskReturns) {
const ref = useRef(null);
const security = useContext(SecuritySettingContext);
return (
);
}
================================================
FILE: frontend/app/src/components/chat/chat-controller.test.ts
================================================
import { ChatMessageRole } from '@/api/chats';
import type { ChatController } from '@/components/chat/chat-controller';
import type { ChatMessageController } from '@/components/chat/chat-message-controller';
import type { ChatInitialData } from '@/components/chat/chat-stream-state';
import { jest } from '@jest/globals';
import { z } from 'zod';
jest.unstable_mockModule('../../api/chats', () => ({
chatMessageSchema: z.any(),
chatSchema: z.any(),
chat: (...args: any) => currentChat(...args),
}));
let currentChat: any;
afterAll(() => {
// using pnpm patch https://github.com/jestjs/jest/pull/15080/files#diff-c0d5b59e96fdc7ffc98405e8afb46d525505bc7b1c24916b5c8482de5a186c00
jest.unstable_unmockModule('../../api/chats');
});
const exampleData = {
chat: {
id: 'mock',
created_at: new Date,
updated_at: new Date,
browser_id: null,
deleted_at: null,
origin: "None",
engine_options: {
llm: {
condense_question_prompt: '',
text_qa_prompt: '',
refine_prompt: '',
intent_graph_knowledge: '',
normal_graph_knowledge: '',
},
knowledge_graph: {
depth: 0,
enabled: false,
include_meta: false,
with_degree: false,
using_intent_search: false,
},
},
user_id: '0',
engine_id: 1,
title: 'Demo',
},
assistant_message: {
chat_id: 'mock',
content: '',
created_at: new Date(),
error: null,
finished_at: new Date(),
id: 2,
ordinal: 1,
role: ChatMessageRole.assistant,
sources: [],
trace_url: '',
updated_at: new Date(),
user_id: '0',
post_verification_result_url: null,
},
user_message: {
chat_id: 'mock',
content: 'ping',
created_at: new Date(),
error: null,
finished_at: new Date(),
id: 1,
ordinal: 0,
role: ChatMessageRole.user,
sources: [],
trace_url: '',
updated_at: new Date(),
user_id: '0',
post_verification_result_url: null,
},
} satisfies ChatInitialData;
describe('stream protocol', () => {
const onPost = jest.fn();
const onPostInitialized = jest.fn();
const onMessageLoaded = jest.fn();
const onPostError = jest.fn();
const onPostFinished = jest.fn();
const postRejection = jest.fn();
const addListeners = (controller: ChatController) => {
controller.on('post', onPost)
.on('post-initialized', onPostInitialized)
.on('message-loaded', onMessageLoaded)
.on('post-error', onPostError)
.on('post-finished', onPostFinished);
};
const newChatController = async () => {
// for using `jest.unstable_mockModule` mocked module
const { ChatController } = await import('./chat-controller');
const controller = new ChatController();
addListeners(controller);
return controller;
};
test('terminate before server responses', async () => {
const error = new Error('terminate before server response');
currentChat = () => {
throw error;
};
const controller = await newChatController();
await controller.post({ content: 'hi' }).catch(postRejection);
expect(postRejection).toHaveBeenCalledTimes(0);
expect(onPost).toHaveBeenCalledTimes(1);
expect(onPost).toHaveBeenCalledWith({ content: 'hi' });
expect(onPostInitialized).toHaveBeenCalledTimes(0);
expect(onPostError).toHaveBeenCalledTimes(1);
expect(onPostError).toHaveBeenCalledWith(error);
});
test('terminated by stream protocol', async () => {
currentChat = async function* () {
yield {
type: 'data',
value: [exampleData],
};
yield { type: 'text', value: 'pong' };
yield { type: 'error', value: 'terminated' };
};
const controller = await newChatController();
await controller.post({ content: 'hi' }).catch(postRejection);
expect(postRejection).toHaveBeenCalledTimes(0);
expect(onPost).toHaveBeenCalledTimes(1);
expect(onPost).toHaveBeenCalledWith({ content: 'hi' });
expect(onPostInitialized).toHaveBeenCalledTimes(1);
expect(onPostError).toHaveBeenCalledTimes(0);
expect(onMessageLoaded).toHaveBeenCalledTimes(2);
const assistantMessage: ChatMessageController = controller.messages.find(msg => msg.role === 'assistant')!;
expect(assistantMessage.content).toBe('pong');
expect(assistantMessage.message.error).toBe('terminated');
});
test('normal', async () => {
currentChat = async function* () {
yield {
type: 'data',
value: [exampleData],
};
yield { type: 'text', value: 'pong' };
};
const controller = await newChatController();
await controller.post({ content: 'ping' }).catch(postRejection);
expect(postRejection).toHaveBeenCalledTimes(0);
expect(onPostInitialized).toHaveBeenCalledTimes(1);
expect(onMessageLoaded).toHaveBeenCalledTimes(2);
expect(onPostFinished).toHaveBeenCalledTimes(1);
const assistantMessage: ChatMessageController = controller.messages.find(msg => msg.role === 'assistant')!;
expect(assistantMessage.content).toBe('pong');
});
});
================================================
FILE: frontend/app/src/components/chat/chat-controller.ts
================================================
import { chat, type Chat, type ChatMessage, type PostChatParams } from '@/api/chats';
import { BaseChatMessageController, ChatMessageController, LegacyChatMessageController, type OngoingState, StackVMChatMessageController } from '@/components/chat/chat-message-controller';
import { AppChatStreamState, type BaseAnnotation, chatDataPartSchema, fixChatInitialData, type StackVMState } from '@/components/chat/chat-stream-state';
import type { GtagFn } from '@/components/gtag-provider';
import { getErrorMessage } from '@/lib/errors';
import { trigger } from '@/lib/react';
import { type JSONValue, type StreamPart } from 'ai';
import EventEmitter from 'eventemitter3';
export interface ChatControllerEventsMap = BaseAnnotation> {
'created': [Chat];
'updated': [Chat];
'message-loaded': [messageController: BaseChatMessageController];
/**
* Emit instantly when {@link ChatController#post} is called
*/
'post': [params: Omit];
/**
* Emit when server returned chat and chat_message info
*/
'post-initialized': [];
'post-finished': [];
'post-error': [error: unknown];
/**
* Experimental
*/
'ui:input-mount': [HTMLTextAreaElement | HTMLInputElement];
'ui:input-unmount': [HTMLTextAreaElement | HTMLInputElement];
}
export class ChatController = BaseAnnotation> extends EventEmitter> {
public chat: Chat | undefined;
private _messages: Map = new Map();
private _postParams: Omit | undefined = undefined;
private _postError: unknown = undefined;
private _postInitialized: boolean = false;
private _inputElement: HTMLTextAreaElement | HTMLInputElement | null = null;
get postState () {
return {
params: this._postParams,
error: this._postError,
initialized: this._postInitialized,
};
}
constructor (
chat: Chat | undefined = undefined,
messages: ChatMessage[] | undefined = [],
initialPost: Omit | undefined = undefined,
inputElement: HTMLInputElement | HTMLTextAreaElement | null = null,
private readonly _gtagFn: GtagFn = () => {},
) {
super();
if (chat) {
this.chat = chat;
}
for (const message of messages) {
this.upsertMessage(message);
}
if (initialPost) {
this.post(initialPost);
}
this._inputElement = inputElement;
if (inputElement) {
this.emit('ui:input-mount', inputElement);
}
}
get inputElement () {
return this._inputElement;
}
set inputElement (value: HTMLInputElement | HTMLTextAreaElement | null) {
if (this._inputElement) {
if (value) {
if (value !== this._inputElement) {
const old = this._inputElement;
this._inputElement = null;
this.emit('ui:input-unmount', old);
this._inputElement = value;
this.emit('ui:input-mount', value);
}
} else {
const old = this._inputElement;
this._inputElement = null;
this.emit('ui:input-unmount', old);
}
} else {
if (value) {
this._inputElement = value;
this.emit('ui:input-mount', value);
}
}
}
private get _enabledInputElement () {
if (!this._inputElement) {
console.warn('Input element is not exists.');
return;
}
if (this._inputElement.disabled) {
console.warn('Input element is disabled currently.');
return;
}
return this._inputElement;
}
get inputEnabled () {
if (!this._inputElement) {
return false;
}
return !this._inputElement.disabled;
}
get input (): string {
return this._inputElement?.value ?? '';
}
set input (value: string) {
const inputElement = this._enabledInputElement;
if (inputElement) {
trigger(inputElement as HTMLTextAreaElement, HTMLTextAreaElement, value);
}
}
focusInput () {
const inputElement = this._enabledInputElement;
if (inputElement) {
inputElement.focus();
}
}
get messages (): (ChatMessageController | StackVMChatMessageController)[] {
return Array.from(this._messages.values()).sort((a, b) => a.message.ordinal - b.message.ordinal);
}
async post (params: Omit) {
if (this._postParams) {
throw new Error('previous not finished.');
}
if (!params.content.trim()) {
throw new Error('Empty message');
}
this._gtagFn('event', 'tidbai.events.message-start', {
'tidbai_appending_message': !!this.chat?.id,
});
// Initialize post states
this._postParams = params;
this._postError = undefined;
this._postInitialized = false;
this.emit('post', params);
let ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined = undefined;
try {
const streamParts = chat({
chat_id: this.chat?.id,
...params,
});
// Process stream parts and dispatch to ongoingMessageController
for await (let part of streamParts) {
ongoingMessageController = this._processPart(ongoingMessageController, part);
}
// Cleanup post states
if (ongoingMessageController) {
this.upsertMessage(ongoingMessageController.finish());
} else {
console.warn('Empty ongoing message');
}
this._postParams = undefined;
this._postInitialized = false;
this.emit('post-finished');
this._gtagFn('event', 'tidbai.events.message-finish', {});
} catch (error) {
ongoingMessageController?.applyError(getErrorMessage(error));
this._postError = error;
this.emit('post-error', error);
this._gtagFn('event', 'tidbai.events.message-error', {});
}
}
// TODO: wait server implementation
async regenerate (messageId: number) {
throw new Error('not supported.');
}
updateChat (chat: Chat): void {
const newCreated = !this.chat;
this.chat = { ...this.chat, ...chat };
this.emit(newCreated ? 'created' : 'updated', this.chat);
if (process.env.NEXT_PUBLIC_IS_WIDGET) {
if (newCreated) {
window.dispatchEvent(new CustomEvent('tidbainewchat', {
detail: { id: chat.id },
}));
}
}
}
upsertMessage (message: ChatMessage): void {
let controller = this._messages.get(message.id);
if (controller) {
controller.update(message);
} else {
this.createMessage(message);
}
}
_processPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>) {
switch (part.type) {
case 'data':
// Data part contains chat and chat_message info from server. will be sent twice (beginning and finished).
// We will update frontend cached and computed info like message content which is computed from stream text deltas.
ongoingMessageController = this._processDataPart(ongoingMessageController, part);
break;
case 'message_annotations':
// Message annotations part containing current generating state.
this._processMessageAnnotationPart(ongoingMessageController, part);
break;
case 'text':
this._processTextPart(ongoingMessageController, part);
break;
case 'error':
this._processErrorPart(ongoingMessageController, part);
break;
case 'tool_call':
this._processToolCallPart(ongoingMessageController, part);
break;
case 'tool_result':
this._processToolResultPart(ongoingMessageController, part);
break;
case 'finish_step':
case 'finish_message':
// TODO
break;
default:
console.warn('unsupported stream part', part);
}
return ongoingMessageController;
}
private _processDataPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>): ChatMessageController | StackVMChatMessageController {
const { chat, user_message, assistant_message } = chatDataPartSchema.parse(fixChatInitialData(part.value[0]));
this.updateChat(chat);
this.upsertMessage(user_message);
if (!ongoingMessageController) {
ongoingMessageController = this.createMessage(assistant_message, true);
this._postInitialized = true;
this.emit('post-initialized');
} else {
ongoingMessageController.update(assistant_message);
}
return ongoingMessageController;
}
private _processMessageAnnotationPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>) {
assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part);
const annotation = ongoingMessageController.parseAnnotation(part.value[0]);
ongoingMessageController.applyStreamAnnotation(annotation as never);
}
private _processTextPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>) {
if (part.value) { // ignore leading empty chunks.
assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part);
ongoingMessageController.applyDelta(part.value);
}
}
private _processErrorPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>) {
assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part);
ongoingMessageController.applyError(part.value);
}
private _processToolCallPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>) {
assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part);
ongoingMessageController.applyToolCall(part.value);
}
private _processToolResultPart (ongoingMessageController: ChatMessageController | StackVMChatMessageController | undefined, part: ReturnType['parse']>) {
assertNonNull(ongoingMessageController, 'Cannot handle chat stream part: no ongoingMessageController', part);
ongoingMessageController.applyToolResult(part.value);
}
private createMessage (message: ChatMessage, initialOngoingState?: true) {
if (!this.chat?.engine_options) {
throw new Error('Unable to decide which chat engine used.');
}
if (this.chat.engine_options.external_engine_config?.stream_chat_api_url) {
return this.createStackVMMessage(message, initialOngoingState);
} else {
return this.createLegacyMessage(message, initialOngoingState);
}
}
private createLegacyMessage (message: ChatMessage, initialOngoingState?: true | OngoingState) {
const controller = new LegacyChatMessageController(message, initialOngoingState);
this._messages.set(message.id, controller);
this.emit('message-loaded', controller as any);
return controller;
}
private createStackVMMessage (message: ChatMessage, initialOngoingState?: true | OngoingState) {
const controller = new StackVMChatMessageController(message, initialOngoingState);
this._messages.set(message.id, controller);
this.emit('message-loaded', controller as any);
return controller;
}
}
function assertNonNull (value: T, message: string, ...args: any): asserts value is NonNullable {
if (value == null) {
console.warn(message, args);
throw new Error('bad stream');
}
}
================================================
FILE: frontend/app/src/components/chat/chat-hooks.tsx
================================================
import { type Chat, type ChatMessage, ChatMessageRole } from '@/api/chats';
import { isBootstrapStatusPassed } from '@/api/system';
import { ChatController } from '@/components/chat/chat-controller';
import { BaseChatMessageController, ChatMessageController, type ChatMessageControllerAnnotationState, type OngoingState, type OngoingStateHistoryItem } from '@/components/chat/chat-message-controller';
import type { AppChatStreamState, StackVMState } from '@/components/chat/chat-stream-state';
import { useGtagFn } from '@/components/gtag-provider';
import { useBootstrapStatus } from '@/components/system/BootstrapStatusProvider';
import { useLatestRef } from '@/components/use-latest-ref';
import { createContext, type ReactNode, useContext, useEffect, useState } from 'react';
export interface ChatsProviderValues {
chats: Map;
disabled: boolean;
newChat (...args: ConstructorParameters): ChatController;
destroyChat (id: string): void;
}
const ChatsContext = createContext({
chats: new Map(),
disabled: true,
newChat (): ChatController {
throw new Error('not in a chat context');
},
destroyChat () {
throw new Error('not in a chat context');
},
});
const ChatControllerContext = createContext(null);
export function ChatsProvider ({ onChatCreated, children }: { children: ReactNode, /** @deprecated */onChatCreated?: (id: string, chat: Chat, controller: ChatController) => void }) {
const bootstrapStatusRef = useLatestRef(useBootstrapStatus());
const [chats, setChats] = useState(() => new Map);
const newChat: ChatsProviderValues['newChat'] = (...args) => {
if (!isBootstrapStatusPassed(bootstrapStatusRef.current)) {
throw new Error('System check not passed.');
}
const controller = new ChatController(...args);
controller.once('created', (chat) => {
setChats(chats => new Map(chats).set(chat.id, controller));
onChatCreated?.(chat.id, chat, controller);
});
return controller;
};
const destroyChat: ChatsProviderValues['destroyChat'] = (id: string) => {
setChats(chats => {
chats = new Map(chats);
chats.delete(id);
return chats;
});
};
return (
{children}
);
}
export function ChatControllerProvider ({ controller, children }: { controller: ChatController | null, children: ReactNode }) {
return (
{children}
);
}
export function useChats () {
return useContext(ChatsContext);
}
export interface ChatMessageGroup {
user: ChatMessageController;
assistant: ChatMessageController | undefined;
hasFirstAssistantMessage: boolean;
hasLastAssistantMessage: boolean;
hasLastUserMessage: boolean;
}
export function useChatController (
id: string | undefined,
initialChat: Chat | undefined,
initialMessages: ChatMessage[] | undefined,
inputElement: HTMLInputElement | HTMLTextAreaElement | null = null,
) {
const gtagFn = useGtagFn();
const { chats } = useChats();
// Create essential chat controller
const [controller] = useState(() => {
if (id) {
let controller = chats.get(id);
if (!controller) {
controller = new ChatController(initialChat, initialMessages, undefined, inputElement, gtagFn);
chats.set(id, controller);
}
return controller;
} else {
return new ChatController(undefined, undefined, undefined, inputElement, gtagFn);
}
});
useEffect(() => {
controller.inputElement = inputElement;
}, [controller, inputElement]);
return controller;
}
export function useChatInfo (controller: ChatController) {
const [chat, setChat] = useState(controller.chat);
useEffect(() => {
if (controller) {
setChat(controller.chat);
const handleChatUpdated = (chat: Chat) => setChat(chat);
controller
.on('updated', handleChatUpdated)
.on('created', handleChatUpdated);
return () => {
controller
.off('updated', handleChatUpdated)
.off('created', handleChatUpdated);
};
}
}, [controller]);
return chat;
}
export function useChatPostState (controller: ChatController | undefined) {
const [state, setState] = useState(controller?.postState ?? { initialized: false, params: undefined, error: undefined });
useEffect(() => {
if (controller) {
setState(controller.postState);
const handleStateChange = () => {
setState(controller.postState);
};
controller
.on('post', handleStateChange)
.on('post-initialized', handleStateChange)
.on('post-finished', handleStateChange)
.on('post-error', handleStateChange);
return () => {
controller
.off('post', handleStateChange)
.off('post-initialized', handleStateChange)
.off('post-finished', handleStateChange)
.off('post-error', handleStateChange);
};
}
}, [controller]);
return state;
}
export function useChatMessageControllers (controller: ChatController) {
const [messageControllers, setMessageControllers] = useState(() => controller?.messages);
useEffect(() => {
if (controller) {
setMessageControllers(controller.messages);
const handleMessageLoaded = () => {
setMessageControllers(controller.messages);
};
controller.on('message-loaded', handleMessageLoaded);
return () => {
controller.off('message-loaded', handleMessageLoaded);
};
}
}, [controller]);
return messageControllers;
}
export function useChatMessageGroups (controllers: ChatMessageController[]) {
const [chatMessageGroups, setChatMessageGroups] = useState(() => collectMessageGroups(controllers));
useEffect(() => {
setChatMessageGroups(collectMessageGroups(controllers));
}, [controllers]);
return chatMessageGroups;
}
function collectMessageGroups (messageControllers: ChatMessageController[]) {
const groups: ChatMessageGroup[] = [];
let user: ChatMessageController | undefined;
for (let messageController of messageControllers) {
switch (messageController.role) {
case ChatMessageRole.user:
user = messageController;
break;
case ChatMessageRole.assistant:
if (user) {
groups.push({
user,
assistant: messageController,
hasFirstAssistantMessage: false,
hasLastAssistantMessage: false,
hasLastUserMessage: false,
});
} else {
console.warn('No matched user message, drop assistant message', messageController.message.id);
}
break;
}
}
let group = groups.findLast(group => !!group.assistant);
if (group) {
group.hasLastAssistantMessage = true;
}
group = groups.find(group => !!group.assistant);
if (group) {
group.hasFirstAssistantMessage = true;
}
group = groups.findLast(group => !!group.user);
if (group) {
group.hasLastUserMessage = true;
}
return groups;
}
export function useCurrentChatController () {
const controller = useContext(ChatControllerContext);
if (!controller) {
throw new Error('Not in a chat controller provider');
}
return controller;
}
export function useChatMessageField (controller: ChatMessageController, key: K): ChatMessage[K];
export function useChatMessageField (controller: ChatMessageController | undefined, key: K): ChatMessage[K] | undefined;
export function useChatMessageField (controller: ChatMessageController | undefined, key: keyof ChatMessage): any {
const [value, setValue] = useState(controller?.message[key]);
useEffect(() => {
if (controller) {
setValue(controller.message[key]);
const handleUpdate = (message: ChatMessage) => {
setValue(message[key]);
};
controller
.on('update', handleUpdate)
.on('stream-update', handleUpdate)
.on('stream-error', handleUpdate)
.on('stream-finished', handleUpdate);
return () => {
controller
.off('update', handleUpdate)
.off('stream-update', handleUpdate)
.off('stream-error', handleUpdate)
.off('stream-finished', handleUpdate);
};
} else {
setValue(undefined);
}
}, [controller, key]);
return value;
}
export function useChatMessageStreamState (controller: C | undefined): OngoingState> | undefined {
const [state, setState] = useState(controller?.ongoing);
useEffect(() => {
if (controller) {
setState(controller.ongoing);
const handleUpdate = (_: any, state?: OngoingState) => setState(state);
controller
.on('stream-update', handleUpdate)
.on('stream-error', handleUpdate)
.on('stream-finished', handleUpdate);
return () => {
controller
.off('stream-update', handleUpdate)
.off('stream-error', handleUpdate)
.off('stream-finished', handleUpdate);
};
} else {
setState(undefined);
}
}, [controller]);
return state as OngoingState>;
}
export function useChatMessageStreamHistoryStates> (controller: C | undefined): C['ongoingHistory'] {
const [state, setState] = useState(controller?.ongoingHistory);
useEffect(() => {
if (controller) {
setState(controller.ongoingHistory);
const handleUpdate = (_: any, state?: OngoingStateHistoryItem[]) => {
if (state) {
setState(state);
}
};
controller
.on('stream-history-update', handleUpdate)
.on('stream-finished', handleUpdate);
return () => {
controller
.off('stream-history-update', handleUpdate)
.off('stream-finished', handleUpdate);
};
} else {
setState(undefined);
}
}, [controller]);
return state;
}
export function useChatMessageStreamContainsState (controller: ChatMessageController | undefined, state: AppChatStreamState) {
const history = useChatMessageStreamHistoryStates(controller);
const current = useChatMessageStreamState(controller);
// FIXME: what if state not triggered?
if (!current || current.finished) {
return true;
}
return history?.some(item => item.state.state === state) || current?.state === state;
}
================================================
FILE: frontend/app/src/components/chat/chat-message-controller.test.ts
================================================
import { ChatMessageController, LegacyChatMessageController } from '@/components/chat/chat-message-controller';
import { AppChatStreamState } from '@/components/chat/chat-stream-state';
import { createExampleInitialChatMessage } from '@/components/chat/testutils';
import { jest } from '@jest/globals';
describe('stream', () => {
const onUpdate = jest.fn();
const onStreamUpdate = jest.fn();
const onStreamError = jest.fn();
const onStreamFinished = jest.fn();
test('success', () => {
const controller = new LegacyChatMessageController(createExampleInitialChatMessage(), true);
controller.on('update', onUpdate)
.on('stream-update', onStreamUpdate)
.on('stream-error', onStreamError)
.on('stream-finished', onStreamFinished);
// TRACE event should provide langfuse_url
expect(controller.message.trace_url).toBe('');
controller.applyStreamAnnotation({
state: AppChatStreamState.TRACE,
display: 'trace',
context: { langfuse_url: 'fake_url' },
});
expect(controller.message.trace_url).toBe('fake_url');
// SOURCE_NODES event should provide sources
expect(controller.message.sources.length).toBe(0);
controller.applyStreamAnnotation({
state: AppChatStreamState.SOURCE_NODES,
display: 'source_nodes',
context: [{ source_uri: 'uri', id: 0, name: 'foo' }],
});
expect(controller.message.sources).toStrictEqual([{ source_uri: 'uri', id: 0, name: 'foo' }]);
// test applyDelta
expect(controller.message.content).toEqual('');
controller.applyDelta('Hello');
expect(controller.message.content).toEqual('Hello');
controller.applyDelta(' world!');
expect(controller.message.content).toEqual('Hello world!');
expect(onStreamUpdate).toHaveBeenCalledTimes(4);
controller.finish();
expect(controller.ongoing).toBeUndefined();
expect(onStreamFinished).toHaveBeenCalledTimes(1);
expect(onStreamError).toHaveBeenCalledTimes(0);
});
test('error', () => {
const controller = new LegacyChatMessageController(createExampleInitialChatMessage(), true);
controller.on('update', onUpdate)
.on('stream-update', onStreamUpdate)
.on('stream-error', onStreamError)
.on('stream-finished', onStreamFinished);
controller.applyError('error');
expect(controller.ongoing!.finished).toBe(true);
expect(controller.message.error).toBe('error');
expect(onStreamFinished).toHaveBeenCalledTimes(0);
expect(onStreamError).toHaveBeenCalledTimes(1);
});
});
================================================
FILE: frontend/app/src/components/chat/chat-message-controller.ts
================================================
import { type ChatMessage, ChatMessageRole } from '@/api/chats';
import { AppChatStreamState, type BaseAnnotation, type ChatMessageAnnotation, type StackVMState, type StackVMStateAnnotation } from '@/components/chat/chat-stream-state';
import { StackVM } from '@/lib/stackvm';
import EventEmitter from 'eventemitter3';
export interface OngoingState {
finished: boolean;
state: State;
display: string;
message?: string;
}
export interface OngoingStateHistoryItem {
state: OngoingState;
time: Date;
}
export interface ChatMessageControllerEventsMap {
'update': [assistant_message: ChatMessage];
'stream-update': [ongoing_message: ChatMessage, ongoing: OngoingState, delta: string];
'stream-history-update': [ongoing_message: ChatMessage, history: { state: OngoingState, time: Date }[]];
'stream-finished': [ongoing_message: ChatMessage];
'stream-error': [ongoing_message: ChatMessage, ongoing: OngoingState];
'stream-tool-call': [id: string, name: string, args: any];
'stream-tool-result': [id: string, result: any];
}
export abstract class BaseChatMessageController<
State,
Annotation extends BaseAnnotation
> extends EventEmitter> {
protected _message: ChatMessage;
protected _ongoing: OngoingState | undefined;
protected _ongoingHistory: OngoingStateHistoryItem[] | undefined;
public readonly role: ChatMessageRole;
public readonly id: number;
constructor (message: ChatMessage, ongoing: OngoingState | true | undefined) {
super();
this._message = message;
this._ongoing = ongoing === true ? this.createInitialOngoingState() : ongoing;
this._ongoingHistory = ongoing ? [] : undefined;
this.role = message.role;
this.id = message.id;
if (this._message.finished_at == null && !ongoing) {
this._ongoing = this.createUnknownOngoingState();
}
}
// dynamic, usage in react component needs subscription.
get content () {
return this.message?.content ?? '';
}
update (message: ChatMessage) {
this._message = { ...this._message, ...message };
this.emit('update', this._message);
}
applyStreamAnnotation (annotation: Annotation) {
if (!this._ongoing || this._ongoing.finished) {
console.warn('message already finished');
return;
}
const stateChanged = annotation.state !== this._ongoing.state;
let message = this._message;
const ongoing: OngoingState = { ...this._ongoing };
ongoing.state = annotation.state;
ongoing.display = annotation.display || (stateChanged ? '' : ongoing.display);
ongoing.message = stateChanged ? undefined : ongoing.message;
message = this._polishMessage(message, ongoing, annotation);
const lastOngoing = this._ongoing;
this._ongoing = ongoing;
this._message = message;
if (annotation.state === AppChatStreamState.FINISHED) {
this._ongoing.finished = true;
}
this.emit('stream-update', this._message, this._ongoing, '');
if (stateChanged && this._ongoingHistory != null) {
const lastState = this._ongoingHistory[this._ongoingHistory.length - 1];
if (lastOngoing && lastOngoing.display && lastOngoing.state !== lastState?.state.state) {
// Insert new state
this._ongoingHistory = [
...this._ongoingHistory,
{
state: lastOngoing,
time: new Date(),
},
];
this.emit('stream-history-update', this._message, this._ongoingHistory);
}
}
}
applyDelta (delta: string) {
if (!this._ongoing || this._ongoing.finished) {
console.warn('message already finished');
return;
}
this._message = {
...this._message,
content: this._message.content + delta,
};
this.emit('stream-update', this._message, this._ongoing, delta);
}
applyError (error: string) {
if (!this._ongoing || this._ongoing.finished) {
console.warn('message already finished');
console.error('Error in ChatMessageController (on finished message):', error);
return;
}
this._ongoing = {
...this._ongoing,
finished: true,
};
this._message = {
...this._message,
error,
};
this.emit('stream-error', this._message, this._ongoing);
}
applyToolCall ({ toolCallId, toolName, args }: { toolCallId: string, toolName: string, args: any }) {
this.emit('stream-tool-call', toolCallId, toolName, args);
}
applyToolResult ({ toolCallId, result }: { toolCallId: string, result: any }) {
this.emit('stream-tool-result', toolCallId, result);
}
finish () {
this._ongoing = undefined;
this.emit('stream-finished', this._message);
return this._message;
}
get message (): ChatMessage {
return this._message;
}
get ongoing () {
return this._ongoing;
}
get ongoingHistory () {
return this._ongoingHistory;
}
abstract parseAnnotation (raw: unknown): Annotation;
abstract createInitialOngoingState (): OngoingState;
abstract createUnknownOngoingState (): OngoingState;
protected abstract _polishMessage (message: ChatMessage, ongoing: OngoingState, annotation: Annotation): ChatMessage
}
export type ChatMessageController = LegacyChatMessageController | StackVMChatMessageController;
export type ChatMessageControllerAnnotationState = C extends BaseChatMessageController ? State : never;
export class LegacyChatMessageController extends BaseChatMessageController {
readonly version = 'Legacy';
parseAnnotation (raw: unknown): ChatMessageAnnotation {
return raw as ChatMessageAnnotation;
}
createInitialOngoingState (): OngoingState {
return {
state: AppChatStreamState.CONNECTING,
display: 'Connecting to server...',
finished: false,
};
}
createUnknownOngoingState (): OngoingState {
return {
state: AppChatStreamState.UNKNOWN,
display: 'Unknown',
finished: false,
};
}
_polishMessage (message: ChatMessage, ongoing: OngoingState, annotation: ChatMessageAnnotation) {
switch (annotation.state) {
case AppChatStreamState.TRACE:
message = { ...message };
message.trace_url = annotation.context.langfuse_url;
break;
case AppChatStreamState.SOURCE_NODES:
message = { ...message };
message.sources = annotation.context;
break;
case AppChatStreamState.REFINE_QUESTION:
ongoing.message = annotation.message || ongoing.message;
break;
}
return message;
}
}
export class StackVMChatMessageController extends BaseChatMessageController {
readonly version = 'StackVM';
applyToolCall (payload: { toolCallId: string; toolName: string; args: any }) {
super.applyToolCall(payload);
if (this._ongoing) {
this._ongoing = {
...this._ongoing,
state: {
...this._ongoing.state,
toolCalls: [...this._ongoing.state.toolCalls, payload],
},
};
this.emit('stream-update', this._message, this._ongoing, '');
}
}
applyToolResult (payload: { toolCallId: string; result: any }) {
super.applyToolResult(payload);
if (this._ongoing) {
const idx = this._ongoing.state.toolCalls.findIndex(toolCall => toolCall.toolCallId === payload.toolCallId);
if (idx >= 0) {
this._ongoing.state.toolCalls[idx] = {
...this._ongoing.state.toolCalls[idx],
result: payload.result,
};
this._ongoing.state = { ...this._ongoing.state };
this._ongoing = { ...this._ongoing };
this.emit('stream-update', this._message, this._ongoing, '');
}
}
}
parseAnnotation (raw: unknown): StackVMStateAnnotation {
const { state: rawState, task_id, branch, seq_no } = raw as { state: StackVM.State, task_id: string, branch: string, seq_no: number };
const state = StackVM.model.parseState(rawState);
return {
state: { task_id, branch, state, toolCalls: [], seq_no },
display: '[deprecated]',
};
}
createInitialOngoingState (): OngoingState {
return {
state: {
task_id: '',
branch: '',
seq_no: -1,
state: {
variables_refs: {},
variables: {},
errors: [],
current_plan: [],
program_counter: -1,
goal_completed: false,
goal: '',
msgs: [],
plan: {
steps: [],
vars: [],
},
},
toolCalls: [],
},
display: 'Thinking...',
finished: false,
};
}
createUnknownOngoingState (): OngoingState {
return {
state: {
task_id: '',
branch: '',
seq_no: -1,
state: {
variables_refs: {},
variables: {},
errors: ['Unknown state'],
current_plan: [],
program_counter: -1,
goal_completed: false,
goal: '',
msgs: [],
plan: {
steps: [],
vars: [],
},
},
toolCalls: [],
},
display: 'Unknown',
finished: false,
};
}
_polishMessage (message: ChatMessage): ChatMessage {
return message;
}
}
================================================
FILE: frontend/app/src/components/chat/chat-new-dialog.tsx
================================================
import { Ask } from '@/components/chat/ask';
import { useAsk } from '@/components/chat/use-ask';
import { Button } from '@/components/ui/button';
import { Dialog, DialogDescription, DialogHeader, DialogOverlay, DialogPortal, DialogTitle, DialogTrigger } from '@/components/ui/dialog';
import { cn } from '@/lib/utils';
import * as DialogPrimitive from '@radix-ui/react-dialog';
import { CommandIcon, PlusIcon } from 'lucide-react';
import { useEffect, useState } from 'react';
export function ChatNewDialog () {
const [open, setOpen] = useState(false);
const ask = useAsk(() => {
setOpen(false);
});
useEffect(() => {
const handle = (e: KeyboardEvent) => {
if (e.key === 'k' && (e.metaKey || e.ctrlKey) && !(e.shiftKey)) {
setOpen(true);
e.preventDefault();
e.stopPropagation();
}
};
window.addEventListener('keydown', handle);
return () => {
window.removeEventListener('keydown', handle);
};
}, []);
return (
Ask Question Dialog
New Thread
K
);
}
================================================
FILE: frontend/app/src/components/chat/chat-stream-state.ts
================================================
/**
* TRACE = 0
* SOURCE_NODES = 1
* KG_RETRIEVAL = 2
* REFINE_QUESTION = 3
* SEARCH_RELATED_DOCUMENTS = 4
* GENERATE_ANSWER = 5
* FINISHED = 9
*/
import { Chat, ChatMessage, chatMessageSchema, ChatMessageSource, chatSchema } from '@/api/chats';
import { StackVM } from '@/lib/stackvm';
import { z, type ZodType } from 'zod';
export const enum BaseState {
CONNECTING = 'CONNECTING', // only client side
UNKNOWN = 'UNKNOWN',
}
export const enum AppChatStreamState {
CONNECTING = 'CONNECTING', // only client side
TRACE = 'TRACE',
SOURCE_NODES = 'SOURCE_NODES',
KG_RETRIEVAL = 'KG_RETRIEVAL',
REFINE_QUESTION = 'REFINE_QUESTION',
SEARCH_RELATED_DOCUMENTS = 'SEARCH_RELATED_DOCUMENTS',
RERANKING = 'RERANKING',
GENERATE_ANSWER = 'GENERATE_ANSWER',
FINISHED = 'FINISHED',
FAILED = 'FAILED',
UNKNOWN = 'UNKNOWN',
}
export type StackVMState = {
task_id: string;
branch: string;
state: StackVM.model.ParsedState;
seq_no: number;
toolCalls: StackVMToolCall[];
};
export type StackVMToolCall = { toolCallId: string, toolName: string, args: any, result?: any }
export interface BaseAnnotation {
state: S;
display?: string;
}
export interface TraceAnnotation extends BaseAnnotation {
context: { langfuse_url: string };
}
export interface SourceNodesAnnotation extends BaseAnnotation {
context: ChatMessageSource[];
}
export interface RefineQuestionAnnotation extends BaseAnnotation {
message?: string;
}
export type ChatMessageAnnotation =
BaseAnnotation>
| TraceAnnotation
| SourceNodesAnnotation
| RefineQuestionAnnotation;
export interface StackVMStateAnnotation extends BaseAnnotation {
}
export type ChatInitialData = {
chat: Chat;
user_message: ChatMessage;
assistant_message: ChatMessage;
}
// FIXME: Server will not return updated_at
export function fixChatInitialData (data: any) {
if (data.assistant_message) {
if (!data.assistant_message.updated_at) {
data.assistant_message.updated_at = data.assistant_message.created_at;
}
}
return data as any;
}
export const chatDataPartSchema = z.object({
chat: chatSchema,
user_message: chatMessageSchema,
assistant_message: chatMessageSchema,
}) satisfies ZodType;
================================================
FILE: frontend/app/src/components/chat/chat-stream.state.test.ts
================================================
import { chatDataPartSchema, fixChatInitialData } from '@/components/chat/chat-stream-state';
it('pass', () => {
const result = chatDataPartSchema.safeParse(fixChatInitialData({
'chat': {
'updated_at': '2024-07-08T10:28:39',
'id': '019091e3-5cfc-74a3-b5e0-653a73f52af2',
'title': 'What is TiDB?',
'engine_id': 1,
'engine_options': {"llm":{"provider":"openai","openai_chat_model":"gpt-4o","gemini_chat_model":"models/gemini-2.0-flash","reranker_provider":"jinaai","reranker_top_k":10,"intent_graph_knowledge":"Given a list of relationships of a knowledge graph as follows. When there is a conflict in meaning between knowledge relationships, the relationship with the higher `weight` and newer `last_modified_at` value takes precedence.\\n\\nKnowledge sub-queries:\\n\\n{% for sub_query, data in sub_queries.items() %}\\n\\nSub-query: {{ sub_query }}\\n\\n - Entities:\\n\\n{% for entity in data[\'entities\'] %}\\n\\n - Name: {{ entity.name }}\\n - Description: {{ entity.description }}\\n\\n{% endfor %}\\n\\n - Relationships:\\n\\n{% for relationship in data[\'relationships\'] %}\\n\\n - Description: {{ relationship.rag_description }}\\n - Last Modified At: {{ relationship.last_modified_at }}\\n - Meta: {{ relationship.meta | tojson(indent=2) }}\\n\\n{% endfor %}\\n\\n{% endfor %}\\n","normal_graph_knowledge":"Given a list of relationships of a knowledge graph as follows. When there is a conflict in meaning between knowledge relationships, the relationship with the higher `weight` and newer `last_modified_at` value takes precedence.\\n\\n---------------------\\nEntities:\\n\\n{% for entity in entities %}\\n\\n- Name: {{ entity.name }}\\n- Description: {{ entity.description }}\\n\\n{% endfor %}\\n\\n---------------------\\n\\nKnowledge relationships:\\n\\n{% for relationship in relationships %}\\n\\n- Description: {{ relationship.rag_description }}\\n- Weight: {{ relationship.weight }}\\n- Last Modified At: {{ relationship.last_modified_at }}\\n- Meta: {{ relationship.meta | tojson(indent=2) }}\\n\\n{% endfor %}\\n"},"knowledge_graph":{"enabled":true,"depth":2,"include_meta":true,"with_degree":false,"using_intent_search":true}},
'user_id': null,
'browser_id': null,
'created_at': '2024-07-08T10:28:39',
'deleted_at': null,
'origin': null,
},
'user_message': {
'id': 60033,
'created_at': '2024-07-08T10:28:40',
'role': 'user',
'trace_url': null,
'finished_at': null,
'user_id': null,
'updated_at': '2024-07-08T10:28:40',
'ordinal': 1,
'content': 'What is TiDB?',
'error': null,
'sources': [],
'chat_id': '019091e3-5cfc-74a3-b5e0-653a73f52af2',
"post_verification_result_url": null,
},
'assistant_message': {
'id': 60034,
'created_at': '2024-07-08T10:28:40',
'role': 'assistant',
'trace_url': 'https://us.cloud.langfuse.com/trace/fd18e8c4-94b7-4b6a-a6e7-3877f07a3d2d',
'finished_at': '2024-07-08T10:30:12.129128Z',
'user_id': null,
'ordinal': 2,
'content': '### Comprehensive Overview of TiDB\n\n#### What is TiDB?\nTiDB is an open-source distributed SQL database designed to support Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL-compatible and offers horizontal scalability, strong consistency, and high availability[^1][^2]. TiDB aims to provide a one-stop solution for OLTP (Online Transactional Processing), OLAP (Online Analytical Processing), and HTAP services, making it suitable for various use cases that require high availability and strong consistency with large-scale data[^3].\n\n#### Architecture of TiDB\nThe TiDB cluster consists of three main components[^4][^5]:\n1. **TiDB Server**: Handles SQL parsing, query planning, and execution.\n2. **TiKV Server**: Acts as the distributed key-value storage engine, storing the actual data.\n3. **PD (Placement Driver) Server**: Manages cluster metadata, allocates timestamps, and handles data placement and load balancing.\n\nAdditionally, TiDB includes:\n- **TiFlash**: A columnar storage engine for analytical workloads, providing high concurrency for `INSERT` and `UPDATE` operations without impacting OLTP performance[^6].\n- **TiSpark**: A connector that enables Spark to access data stored in TiDB[^7].\n- **TiDB Binlog**: A tool for capturing and replicating data changes[^8].\n- **TiDB Lightning**: A high-performance tool for importing data into TiDB[^9].\n\n#### Key Features of TiDB\n1. **Horizontal Scalability**: TiDB allows for easy horizontal scaling of both computing and storage resources, making it adaptable to changing workloads[^10]. The architecture separates computing from storage, enabling independent scaling[^11].\n2. **High Availability**: TiDB ensures high availability through data replication and the Multi-Raft protocol, guaranteeing data integrity even in the event of failures[^12]. It supports automatic failover when a minority of replicas fail, making it transparent to applications[^13].\n3. **HTAP Capabilities**: TiDB supports both row-based (TiKV) and columnar (TiFlash) storage engines, enabling real-time processing of both transactional and analytical workloads[^14].\n4. **Cloud-Native Design**: TiDB is built for cloud environments, offering flexible scalability, reliability, and security on various cloud platforms[^15]. It integrates seamlessly with Kubernetes and offers a fully-managed service (TiDB Cloud)[^16].\n5. **MySQL Compatibility**: TiDB is compatible with the MySQL 5.7 protocol and ecosystem, allowing for easy migration of applications with minimal code changes[^17]. However, it does not support triggers, stored procedures, and user-defined functions[^18].\n\n#### Ensuring High Availability, Scalability, and Performance\n- **High Availability**: TiDB achieves high availability through its multi-replica architecture and the Multi-Raft protocol, which ensures that data is consistently replicated across multiple nodes[^19]. Transactions can only be committed when data has been successfully written into the majority of replicas[^20].\n- **Scalability**: TiDB\'s architecture allows for flexible and elastic scaling by separating computing from storage. This design enables users to scale out or scale in the computing or storage capacity online as needed[^21].\n- **Performance**: TiDB provides high performance through various optimizations, including the use of TiFlash for analytical workloads and the DeltaTree structure for efficient data modification[^22]. The system also supports distributed transactions using a two-phase commit protocol with optimizations inspired by Google\'s Percolator[^23].\n\n#### Compatibility with MySQL\nTiDB supports most MySQL 5.7 syntax and features, making it highly compatible with MySQL applications[^24]. This compatibility allows users to migrate applications to TiDB without changing a single line of code in many cases[^25]. However, certain features like triggers, stored procedures, and user-defined functions are not supported[^26].\n\n### Conclusion\nTiDB is a robust, scalable, and high-performance distributed SQL database designed for modern data workloads. Its architecture, key features, and compatibility with MySQL make it a versatile solution for various use cases, ensuring high availability, scalability, and performance.\n\n[^1]: [TiDB Overview | PingCAP Docs](https://docs.pingcap.com/tidb/stable/overview)\n[^2]: [TiDB Introduction | PingCAP Docs](https://docs.pingcap.com/tidb/v5.4/overview)\n[^3]: [TiDB Introduction and Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/tidb-faq)\n[^4]: [TiDB Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.5/tidb-architecture)\n[^5]: [TiDB Introduction and Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v6.5/tidb-faq)\n[^6]: [TiDB Introduction and Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.5/tidb-faq)\n[^7]: [TiDB Architecture | PingCAP Docs](https://docs.pingcap.com/tidbcloud/tidb-architecture)\n[^8]: [TiDB Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.5/tidb-architecture)\n[^9]: [TiDB Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.5/tidb-architecture)\n[^10]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v6.5/overview)\n[^11]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v5.4/overview)\n[^12]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/overview)\n[^13]: [TiDB Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.5/tidb-architecture)\n[^14]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v6.5/overview)\n[^15]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/overview)\n[^16]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v6.5/overview)\n[^17]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/overview)\n[^18]: [TiDB Introduction and Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.5/tidb-faq)\n[^19]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v5.4/overview)\n[^20]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v6.5/overview)\n[^21]: [TiDB Key Features | PingCAP Docs](https://docs.pingcap.com/tidb/v7.1/overview)\n[^22]: [TiDB Introduction and Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.5/tidb-faq)\n[^23]: [TiDB Introduction and Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v6.5/tidb-faq)\n[^24]: [TiDB Introduction and Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.5/tidb-faq)\n[^25]: [TiDB Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.5/tidb-architecture)\n[^26]: [TiDB Introduction and Architecture | PingCAP Docs](https://docs.pingcap.com/tidb/v7.5/tidb-faq)',
'error': null,
"post_verification_result_url": null,
'sources': [
{
'id': 8247,
'name': 'Overview',
'source_uri': 'https://docs.pingcap.com/tidb/v7.5/tidb-architecture',
}, {
'id': 8600,
'name': 'TiDB FAQs',
'source_uri': 'https://docs.pingcap.com/tidb/v7.5/tidb-faq',
}, {
'id': 9057,
'name': 'Overview',
'source_uri': 'https://docs.pingcap.com/tidbcloud/tidb-architecture',
}, {
'id': 9373,
'name': 'TiDB Introduction',
'source_uri': 'https://docs.pingcap.com/tidb/v7.1/overview',
}, {
'id': 9865,
'name': 'Overview',
'source_uri': 'https://docs.pingcap.com/tidb/v7.1/tidb-architecture',
}, {
'id': 10191,
'name': 'TiDB FAQs',
'source_uri': 'https://docs.pingcap.com/tidb/v7.1/tidb-faq',
}, {
'id': 10578,
'name': 'TiDB Introduction',
'source_uri': 'https://docs.pingcap.com/tidb/v6.5/overview',
}, {
'id': 11370,
'name': 'TiDB FAQs',
'source_uri': 'https://docs.pingcap.com/tidb/v6.5/tidb-faq',
}, {
'id': 12985,
'name': 'TiDB Introduction',
'source_uri': 'https://docs.pingcap.com/tidb/v5.4/overview',
},
],
'chat_id': '019091e3-5cfc-74a3-b5e0-653a73f52af2',
},
} as any));
if (!result.success) {
throw result.error;
}
});
================================================
FILE: frontend/app/src/components/chat/chats-history.tsx
================================================
import { type Chat, deleteChat, listChats } from '@/api/chats';
import { useAuth } from '@/components/auth/AuthProvider';
import { DangerousActionButton } from '@/components/dangerous-action-button';
import { NextLink } from '@/components/nextjs/NextLink';
import { Button } from '@/components/ui/button';
import { SidebarMenuSkeleton } from '@/components/ui/sidebar';
import { cn } from '@/lib/utils';
import { TrashIcon } from 'lucide-react';
import { usePathname } from 'next/navigation';
import { useEffect } from 'react';
import useSWR from 'swr';
export function ChatsHistory () {
const pathname = usePathname();
const auth = useAuth();
const user = auth.me;
const { data: history, mutate, isLoading, isValidating } = useSWR('api.chats.list?size=8', () => listChats({ size: 8 }), {
revalidateOnMount: false,
keepPreviousData: true,
});
useEffect(() => {
void mutate(() => undefined, { revalidate: true });
}, [user?.id]);
const isActive = (chat: Chat) => pathname === `/c/${chat.id}`;
return (
{isLoading && (
<>
>
)}
{history?.items.map(chat => (
{chat.title}
{
await deleteChat(chat.id).finally(() => mutate(history => history, { revalidate: true }));
}}
dialogTitle={`Are you sure to delete ${chat.title}?`}
dialogDescription="This action cannot be undone."
>
))}
);
}
================================================
FILE: frontend/app/src/components/chat/chats-table.tsx
================================================
'use client';
import { type Chat, deleteChat, listChats } from '@/api/chats';
import { actions } from '@/components/cells/actions';
import { datetime } from '@/components/cells/datetime';
import { link } from '@/components/cells/link';
import { metadataCell } from '@/components/cells/metadata';
import { DataTableRemote } from '@/components/data-table-remote';
import { createColumnHelper } from '@tanstack/table-core';
import { Trash2Icon } from 'lucide-react';
export function ChatsTable () {
return (
);
}
const helper = createColumnHelper();
const columns = [
helper.accessor('title', {
cell: link({ url: chat => `/c/${chat.id}` }),
}),
helper.accessor('origin', {}),
helper.accessor('created_at', { cell: datetime }),
helper.accessor('engine_id', {}),
helper.accessor('engine_options', { cell: metadataCell }),
helper.display({
header: 'Operations',
cell: actions(chat => [
{
key: 'delete',
title: 'Delete',
icon: ,
dangerous: {
dialogTitle: 'Are you sure to delete this chat?',
dialogDescription: 'This action cannot be undone.',
},
action: async ({ table }) => {
await deleteChat(chat.id);
table.reload?.();
},
},
]),
}),
];
================================================
FILE: frontend/app/src/components/chat/conversation-message-groups.scss
================================================
@keyframes highlight {
from {
@apply bg-amber-500/10;
}
to {
@apply bg-transparent;
}
}
.animate-highlight {
animation-name: highlight;
animation-delay: 2s;
animation-duration: 1s;
animation-timing-function: ease;
animation-fill-mode: both;
}
================================================
FILE: frontend/app/src/components/chat/conversation-message-groups.tsx
================================================
import { type ChatMessageGroup, useChatInfo, useChatPostState, useCurrentChatController } from '@/components/chat/chat-hooks';
import { DebugInfo } from '@/components/chat/debug-info';
import { MessageAnnotationHistory } from '@/components/chat/message-annotation-history';
import { StackVMMessageAnnotationHistory } from '@/components/chat/message-annotation-history-stackvm';
import { MessageAnswer } from '@/components/chat/message-answer';
import { MessageAutoScroll } from '@/components/chat/message-auto-scroll';
import { MessageContextSources } from '@/components/chat/message-content-sources';
import { MessageError } from '@/components/chat/message-error';
import { MessageOperations } from '@/components/chat/message-operations';
import { MessageRecommendQuestions } from '@/components/chat/message-recommend-questions';
import { MessageSection } from '@/components/chat/message-section';
import { Button } from '@/components/ui/button';
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from '@/components/ui/collapsible';
import { MessageVerify } from '@/experimental/chat-verify-service/message-verify';
import { cn } from '@/lib/utils';
import { InfoIcon } from 'lucide-react';
import { useEffect, useState } from 'react';
import './conversation-message-groups.scss';
const isWidgetEnv = !!process.env.NEXT_PUBLIC_IS_WIDGET;
export function ConversationMessageGroups ({ groups }: { groups: ChatMessageGroup[] }) {
const controller = useCurrentChatController();
const { params, initialized } = useChatPostState(useCurrentChatController());
useEffect(() => {
if (!isWidgetEnv) {
const scroll = () => {
setTimeout(() => {
window.scrollTo({
left: 0,
top: document.body.scrollHeight,
behavior: 'smooth',
});
}, 100);
};
controller
.on('post', scroll)
.on('post-initialized', scroll);
return () => {
controller
.off('post', scroll)
.off('post-initialized', scroll);
};
}
}, [controller]);
return (
{groups.map((group, index) => (
))}
{!!params && !initialized && (
)}
);
}
function ConversationMessageGroup ({ group, isLastGroup }: { group: ChatMessageGroup, isLastGroup: boolean }) {
const enableDebug = /* !!me && */ !process.env.NEXT_PUBLIC_DISABLE_DEBUG_PANEL;
const { engine_options } = useChatInfo(useCurrentChatController()) ?? {};
const { params } = useChatPostState(useCurrentChatController());
const [debugInfoOpen, setDebugInfoOpen] = useState(false);
const [highlight, setHighlight] = useState(false);
useEffect(() => {
if (group.assistant && location.hash.slice(1) === String(group.assistant.id)) {
setHighlight(true);
document.getElementById(String(group.assistant.id))?.scrollIntoView({ behavior: 'instant', block: 'start' });
}
}, []);
return (
setHighlight(false)}
>
{group.user.content}
{enableDebug &&
Toggle
}
{group.assistant?.version === 'Legacy' && }
{group.assistant?.version === 'StackVM' && }
{!engine_options?.hide_sources &&
}
{group.assistant && }
{group.assistant && }
{group.assistant && }
{/* Only show recommend questions if enabled in engine_options */}
{!params && isLastGroup && group.hasLastAssistantMessage && engine_options?.further_questions !== false && (
)}
);
}
================================================
FILE: frontend/app/src/components/chat/conversation.test.tsx
================================================
describe('Conversation component', () => {
test('button should be disabled when input is empty', () => {
// Create a simple test scenario without the full component complexity
const textarea = document.createElement('textarea');
const button = document.createElement('button');
// Simulate the input validation logic
const validateInput = (value: string) => {
return !value.trim();
};
// Test empty input
textarea.value = '';
button.disabled = validateInput(textarea.value);
expect(button.disabled).toBe(true);
// Test whitespace only
textarea.value = ' ';
button.disabled = validateInput(textarea.value);
expect(button.disabled).toBe(true);
// Test whitespace with tab
textarea.value = ' \t';
button.disabled = validateInput(textarea.value);
expect(button.disabled).toBe(true);
// Test with actual content
textarea.value = 'foo';
button.disabled = validateInput(textarea.value);
expect(button.disabled).toBe(false);
});
});
================================================
FILE: frontend/app/src/components/chat/conversation.tsx
================================================
'use client';
import type { Chat, ChatMessage } from '@/api/chats';
import type { ChatController } from '@/components/chat/chat-controller';
import { ChatControllerProvider, useChatController, useChatMessageControllers, useChatMessageGroups, useChatPostState } from '@/components/chat/chat-hooks';
import { ConversationMessageGroups } from '@/components/chat/conversation-message-groups';
import { MessageInput } from '@/components/chat/message-input';
import { SecuritySettingContext, withReCaptcha } from '@/components/security-setting-provider';
import { useSize } from '@/components/use-size';
import { cn } from '@/lib/utils';
import { type ChangeEvent, type FormEvent, type ReactNode, type Ref, useContext, useImperativeHandle, useState } from 'react';
export interface ConversationProps {
chatId?: string;
className?: string;
open: boolean;
chat: Chat | undefined;
history: ChatMessage[];
/* Only for widgets */
placeholder?: (controller: ChatController, postState: ReturnType) => ReactNode;
preventMutateBrowserHistory?: boolean;
preventShiftMessageInput?: boolean;
newChatRef?: Ref;
}
export function Conversation ({ open, chat, chatId, history, placeholder, preventMutateBrowserHistory = false, preventShiftMessageInput = false, newChatRef, className }: ConversationProps) {
const [inputElement, setInputElement] = useState(null);
const controller = useChatController(chatId, chat, history, inputElement);
const postState = useChatPostState(controller);
const groups = useChatMessageGroups(useChatMessageControllers(controller));
const [input, setInput] = useState('');
const handleInputChange = (e: ChangeEvent) => {
setInput(e.target.value);
};
const { ref, size } = useSize();
const security = useContext(SecuritySettingContext);
const submitWithReCaptcha = async (e: FormEvent) => {
e.preventDefault();
withReCaptcha({
action: 'chat',
siteKey: security?.google_recaptcha_site_key || '',
mode: security?.google_recaptcha,
}, ({ token, action }) => {
controller.post({
content: input,
headers: {
'X-Recaptcha-Token': token,
'X-Recaptcha-Action': action,
},
});
setInput('');
});
};
const disabled = !!postState.params;
const actionDisabled = disabled || !input.trim();
useImperativeHandle(newChatRef, () => {
return controller.post.bind(controller);
}, [controller]);
return (
{!postState.params && !groups.length && placeholder?.(controller, postState)}
{size && open && }
);
}
================================================
FILE: frontend/app/src/components/chat/debug-info.tsx
================================================
import { type ChatMessageGroup, useChatInfo, useChatMessageField, useCurrentChatController } from '@/components/chat/chat-hooks';
import { KnowledgeGraphDebugInfo } from '@/components/chat/knowledge-graph-debug-info';
import { DateFormat } from '@/components/date-format';
import { OptionDetail } from '@/components/option-detail';
// import { MessageLangfuse } from '@/components/chat/message-langfuse';
import { differenceInSeconds } from 'date-fns';
import { WorkflowIcon } from 'lucide-react';
import { useMemo } from 'react';
export interface DebugInfoProps {
group: ChatMessageGroup;
}
export function DebugInfo ({ group }: DebugInfoProps) {
const chat = useChatInfo(useCurrentChatController());
const traceURL = useChatMessageField(group.assistant, 'trace_url');
const createdAt = useChatMessageField(group.assistant, 'created_at');
const finishedAt = useChatMessageField(group.assistant, 'finished_at');
const stackVMUrl = useMemo(() => {
if (traceURL) {
try {
const url = new URL(traceURL);
if (url.host === 'stackvm.tidb.ai') {
const id = url.searchParams.get('task_id');
return `https://stackvm-ui.vercel.app/tasks/${id}`;
}
} catch {
return undefined;
}
}
}, [traceURL]);
return (
{traceURL &&
}
{/*
*/}
{chat && (
)}
);
}
================================================
FILE: frontend/app/src/components/chat/knowledge-graph-debug-info.tsx
================================================
import { getChatMessageSubgraph } from '@/api/chats';
import { useAuth } from '@/components/auth/AuthProvider';
import { type ChatMessageGroup, useChatInfo, useChatMessageStreamState, useCurrentChatController } from '@/components/chat/chat-hooks';
import type { OngoingState } from '@/components/chat/chat-message-controller';
import { AppChatStreamState, type StackVMState } from '@/components/chat/chat-stream-state';
import { NetworkViewer } from '@/components/graph/components/NetworkViewer';
import { useNetwork } from '@/components/graph/useNetwork';
import { PencilIcon } from 'lucide-react';
import Link from 'next/link';
import { useEffect } from 'react';
import useSWR from 'swr';
export function KnowledgeGraphDebugInfo ({ group }: { group: ChatMessageGroup }) {
const { engine_options } = useChatInfo(useCurrentChatController()) ?? {};
const auth = useAuth();
const ongoing = useChatMessageStreamState(group.assistant);
const kbLinked = engine_options?.knowledge_base?.linked_knowledge_bases;
const canEdit = !!auth.me?.is_superuser && kbLinked;
const shouldFetch = (!ongoing || ongoing.finished || couldFetchKnowledgeGraphDebugInfo(ongoing));
const { data: span, isLoading, mutate, error } = useSWR(
shouldFetch && `api.chats.get-message-subgraph?id=${group.user.id}`,
() => getChatMessageSubgraph(group.user.id),
{
revalidateOnReconnect: false,
revalidateOnFocus: false,
revalidateOnMount: false,
},
);
useEffect(() => {
if (shouldFetch && !error && !isLoading && !span) {
mutate(undefined, true);
}
}, [span, isLoading, error, shouldFetch]);
const network = useNetwork(span);
return (
{
if (!canEdit) return null;
if (!kbLinked) return null;
if (kbLinked.length === 1) {
return (
Edit graph
);
}
const placeholder =
Edit graph
;
if (!target) return placeholder;
if (target.type === 'node') {
const node = network.node(target.id);
if (!node?.knowledge_base_id) return placeholder;
return (
Edit graph
);
} else if (target.type === 'link') {
const link = network.node(target.id);
if (!link?.knowledge_base_id) return placeholder;
return (
Edit graph
);
}
return placeholder;
}
}
/>
);
}
function couldFetchKnowledgeGraphDebugInfo (state: OngoingState) {
switch (state.state) {
case AppChatStreamState.GENERATE_ANSWER:
case AppChatStreamState.FINISHED:
case AppChatStreamState.RERANKING:
case AppChatStreamState.SOURCE_NODES:
return true;
default:
return false;
}
}
================================================
FILE: frontend/app/src/components/chat/message-annotation-history-stackvm.tsx
================================================
import { useChatMessageField, useChatMessageStreamHistoryStates, useChatMessageStreamState } from '@/components/chat/chat-hooks';
import { type OngoingState, type OngoingStateHistoryItem, StackVMChatMessageController } from '@/components/chat/chat-message-controller';
import type { StackVMState, StackVMToolCall } from '@/components/chat/chat-stream-state';
import { isNotFinished } from '@/components/chat/utils';
import { DiffSeconds } from '@/components/diff-seconds';
import { RemarkContent } from '@/components/remark-content';
import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/popover';
import { StackVM } from '@/lib/stackvm';
import { cn } from '@/lib/utils';
import { motion, type Target } from 'framer-motion';
import { CheckCircleIcon, ChevronUpIcon, ClockIcon, InfoIcon, Loader2Icon, SearchIcon } from 'lucide-react';
import { useEffect, useMemo, useState } from 'react';
export function StackVMMessageAnnotationHistory ({ message }: { message: StackVMChatMessageController | undefined }) {
const [show, setShow] = useState(true);
const history = useChatMessageStreamHistoryStates(message);
const current = useChatMessageStreamState(message);
const error = useChatMessageField(message, 'error');
const traceUrl = useChatMessageField(message, 'trace_url');
const finished = !isNotFinished(current) || !!error;
const stackVMTaskUrl = useMemo(() => {
if (traceUrl) {
return traceUrl;
}
let id: string | undefined;
if (current) {
id = current.state.task_id;
} else {
id = history?.[0]?.state.state.task_id;
}
if (id) {
return `https://stackvm.tidb.ai/?task_id=${id}`;
}
return undefined;
}, [traceUrl, history, current]);
const stackVmUI = useMemo(() => {
if (!stackVMTaskUrl) {
return undefined;
}
try {
const url = new URL(stackVMTaskUrl);
const taskId = url.searchParams.get('task_id');
if (!taskId) {
return stackVMTaskUrl;
}
return `https://stackvm-ui.vercel.app/tasks/${taskId}`;
} catch {
return stackVMTaskUrl;
}
}, [stackVMTaskUrl]);
useEffect(() => {
if (finished) {
const handler = setTimeout(() => {
setShow(false);
}, 2000);
return () => {
clearTimeout(handler);
};
}
}, [finished]);
if (!history) {
return
;
}
return (
{history?.map((item, index, history) => (
))}
{error && }
{current && !current.finished && }
{stackVmUI && }
setShow(false)}>
Collapse
setShow(true)}
className="flex items-center gap-1 text-xs text-muted-foreground hover:text-foreground transition-colors"
animate={show ? { height: 0, opacity: 0, overflow: 'visible', pointerEvents: 'none', scale: 0.5 } : { height: 'auto', opacity: 1, scale: 1, pointerEvents: 'auto' }}
>
{error ? 'Not finished' : }
);
}
const CheckedCircle = motion(CheckCircleIcon);
const InformationCircle = motion(InfoIcon);
const itemInitial: Target = { opacity: 0.5 };
const itemAnimate: Target = { opacity: 1 };
const itemIconInitial: Target = { color: 'rgb(113 113 122 / 50)' };
const itemSuccessIconAnimate: Target = { color: 'rgb(34 197 94)' };
const itemErrorIconAnimate: Target = { color: 'rgb(239 68 68)' };
function StackVMCheckpoint ({ state, pc }: { state: StackVMState, pc: boolean }) {
const step = useMemo(() => {
return state.state.plan.steps.find(step => step.id === `step:${pc ? state.state.program_counter : state.seq_no}`);
}, [state.state, state.seq_no, pc]);
if (!step) {
return null;
}
switch (step.type) {
case 'reasoning':
return 'Thoughts';
case 'assign':
return `Assign Variables`;
case 'calling':
return `Tool Call`;
case 'jmp':
return `Jump`;
}
}
function StackVMDetails ({ pc, state }: { state: StackVMState, pc: boolean }) {
const step = useMemo(() => {
return state.state.plan.steps.find(step => step.id === `step:${pc ? state.state.program_counter : state.seq_no}`);
}, [state.state, state.seq_no, pc]);
if (!step) {
return null;
}
switch (step.type) {
case 'reasoning':
return {(step as StackVM.model.StepModel<'reasoning'>).parameters.chain_of_thoughts} ;
case 'calling':
return (
} vars={state.state.variables??{}} pc={state.state.program_counter} toolCalls={state.toolCalls} />
);
case 'assign':
return (
{Object.entries(step.parameters).map(([key, value]) => (
{key}:
))}
);
default:
return null;
}
}
function ToolCallInfo ({ vars, step, pc, toolCalls }: { step: StackVM.model.StepModel<'calling'>, vars: Record, pc: number | undefined, toolCalls: StackVMToolCall[] }) {
const result = toolCalls.find(tc => tc.toolCallId === `${pc}`)?.result;
return (
{step.parameters.tool_name}
{' '}
(
{Object.entries(step.parameters.tool_params).map(([key, value]) => (
{key}:
))}
)
{result != null && <>
Result:
{'{'}
{step.output_vars.map((binding) => (
{binding.parameter}:
))}
{'}'}
>}
);
}
function JsonValueViewer ({ value }: { value: unknown }) {
if (value == null) {
return String(value);
}
if (typeof value === 'object') {
let label: string;
if (value instanceof Array) {
label = `array<${value.length} items>`;
} else {
label = `object<${Object.keys(value).length} entries>`;
}
return (
{label}
{JSON.stringify(value, undefined, 2)}
);
} else {
const isText = typeof value === 'string';
const string = String(value);
if (string.length > 25) {
return (
{isText && {'"'} }{string.slice(0, 25) + '...'}{isText && {'"'} }
{string}
);
} else {
return (
{isText && {'"'} }{string}{isText && {'"'} }
);
}
}
}
function MessageAnnotationHistoryItem ({ history, item: { state, time }, index }: { history: OngoingStateHistoryItem[], index: number, item: OngoingStateHistoryItem }) {
return (
{index > 1 && }
{state.display === '[deprecated]' ?