Repository: zhongyu09/openchatbi
Branch: main
Commit: 428f5d88bb12
Files: 133
Total size: 701.7 KB
Directory structure:
gitextract_wkfrsja_/
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.md
│ │ └── feature_request.md
│ └── workflows/
│ ├── docs.yml
│ ├── publish.yml
│ └── runledger.yml
├── .gitignore
├── CONTRIBUTING.md
├── Dockerfile.python-executor
├── LICENSE
├── README.md
├── baselines/
│ └── runledger-openchatbi.json
├── docs/
│ ├── Makefile
│ ├── make.bat
│ └── source/
│ ├── _templates/
│ │ └── layout.html
│ ├── catalog.rst
│ ├── code.rst
│ ├── conf.py
│ ├── config.rst
│ ├── core.rst
│ ├── index.rst
│ ├── llm.rst
│ ├── text2sql.rst
│ ├── timeseries.rst
│ └── tools.rst
├── evals/
│ ├── __init__.py
│ └── runledger/
│ ├── README.md
│ ├── __init__.py
│ ├── agent/
│ │ └── agent.py
│ ├── cases/
│ │ └── t1.yaml
│ ├── cassettes/
│ │ └── t1.jsonl
│ ├── schema.json
│ ├── suite.yaml
│ └── tools.py
├── example/
│ ├── bi.yaml
│ ├── common_columns.csv
│ ├── config.yaml
│ ├── sql_example.yaml
│ ├── table_columns.csv
│ ├── table_info.yaml
│ └── table_selection_example.csv
├── openchatbi/
│ ├── __init__.py
│ ├── agent_graph.py
│ ├── catalog/
│ │ ├── __init__.py
│ │ ├── catalog_loader.py
│ │ ├── catalog_store.py
│ │ ├── factory.py
│ │ ├── helper.py
│ │ ├── retrival_helper.py
│ │ ├── schema_retrival.py
│ │ ├── store/
│ │ │ ├── __init__.py
│ │ │ └── file_system.py
│ │ └── token_service.py
│ ├── code/
│ │ ├── docker_executor.py
│ │ ├── executor_base.py
│ │ ├── local_executor.py
│ │ └── restricted_local_executor.py
│ ├── config.yaml.template
│ ├── config_loader.py
│ ├── constants.py
│ ├── context_config.py
│ ├── context_manager.py
│ ├── graph_state.py
│ ├── llm/
│ │ └── llm.py
│ ├── prompts/
│ │ ├── agent_prompt.md
│ │ ├── extraction_prompt.md
│ │ ├── schema_linking_prompt.md
│ │ ├── sql_dialect/
│ │ │ └── presto.md
│ │ ├── summary_prompt.md
│ │ ├── system_prompt.py
│ │ ├── text2sql_prompt.md
│ │ └── visualization_prompt.md
│ ├── text2sql/
│ │ ├── __init__.py
│ │ ├── data.py
│ │ ├── extraction.py
│ │ ├── generate_sql.py
│ │ ├── schema_linking.py
│ │ ├── sql_graph.py
│ │ ├── text2sql_utils.py
│ │ └── visualization.py
│ ├── text_segmenter.py
│ ├── tool/
│ │ ├── ask_human.py
│ │ ├── mcp_tools.py
│ │ ├── memory.py
│ │ ├── run_python_code.py
│ │ ├── save_report.py
│ │ ├── search_knowledge.py
│ │ └── timeseries_forecast.py
│ └── utils.py
├── pyproject.toml
├── run_streamlit_ui.py
├── run_tests.py
├── sample_api/
│ └── async_api.py
├── sample_ui/
│ ├── async_graph_manager.py
│ ├── memory_ui.py
│ ├── plotly_utils.py
│ ├── simple_ui.py
│ ├── streaming_ui.py
│ ├── streamlit_ui.py
│ └── style.py
├── tests/
│ ├── README.md
│ ├── __init__.py
│ ├── conftest.py
│ ├── context_management/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_agent_graph_integration.py
│ │ ├── test_context_config.py
│ │ ├── test_context_manager.py
│ │ ├── test_edge_cases.py
│ │ ├── test_runner.py
│ │ └── test_state_operations.py
│ ├── test_catalog_loader.py
│ ├── test_catalog_store.py
│ ├── test_config_loader.py
│ ├── test_graph_state.py
│ ├── test_incomplete_tool_calls.py
│ ├── test_memory.py
│ ├── test_plotly_utils.py
│ ├── test_simple_store.py
│ ├── test_text2sql_extraction.py
│ ├── test_text2sql_generate_sql.py
│ ├── test_text2sql_schema_linking.py
│ ├── test_text2sql_visualization.py
│ ├── test_tools_ask_human.py
│ ├── test_tools_run_python_code.py
│ ├── test_tools_search_knowledge.py
│ └── test_utils.py
└── timeseries_forecasting/
├── Dockerfile
├── README.md
├── app.py
├── build_and_run.sh
├── model_handler.py
└── test_forecasting.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
================================================
FILE: .github/workflows/docs.yml
================================================
name: Build and Deploy Documentation
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
pages: write
id-token: write
concurrency:
group: "pages"
cancel-in-progress: false
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
pip install -e ".[docs]"
- name: Build documentation
run: |
cd docs
make html
- name: Setup Pages
uses: actions/configure-pages@v4
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:
path: './docs/build/html'
deploy:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
needs: build
if: github.ref == 'refs/heads/main'
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
================================================
FILE: .github/workflows/publish.yml
================================================
name: Publish to PyPI
on:
release:
types: [published] # Trigger when a release is published
workflow_dispatch: # Allow manual triggering
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.11', '3.12']
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
- name: Install dependencies
run: |
uv sync --all-extras
- name: Run linting
run: |
uv run black --check .
- name: Run tests
run: |
uv run pytest -v --cov=openchatbi --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
name: codecov-umbrella
build:
needs: test
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Set up Python
run: uv python install 3.11
- name: Build package
run: |
uv build
- name: Check build artifacts
run: |
ls -la dist/
uv run twine check dist/*
- name: Upload build artifacts
uses: actions/upload-artifact@v4
with:
name: dist
path: dist/
publish:
needs: build
runs-on: ubuntu-latest
if: github.event_name == 'release'
environment:
name: pypi
url: https://pypi.org/p/openchatbi
permissions:
id-token: write # Required for PyPI trusted publishing
contents: read
steps:
- name: Download build artifacts
uses: actions/download-artifact@v4
with:
name: dist
path: dist/
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
# Uses PyPI trusted publishing, no API token needed
verbose: true
print-hash: true
publish-test:
needs: build
runs-on: ubuntu-latest
if: github.event_name == 'workflow_dispatch' # Only publish to TestPyPI when manually triggered
environment:
name: testpypi
url: https://test.pypi.org/p/openchatbi
permissions:
id-token: write
contents: read
steps:
- name: Download build artifacts
uses: actions/download-artifact@v4
with:
name: dist
path: dist/
- name: Publish to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
verbose: true
print-hash: true
================================================
FILE: .github/workflows/runledger.yml
================================================
name: runledger
on:
workflow_dispatch:
pull_request:
paths:
- "openchatbi/**"
jobs:
runledger:
if: github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'runledger')
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install runledger
python -m pip install .
- name: Run deterministic evals (replay)
run: |
runledger run evals/runledger --mode replay --baseline baselines/runledger-openchatbi.json
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: runledger-artifacts
path: runledger_out/**
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# project spec
openchatbi/config.yaml
memory.db
checkpoints.db
data
hf_model
timeseries_forecasting/hf_model
runledger_out/
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to OpenChatBI
Hi there! Thank you for your interest in contributing to OpenChatBI.
OpenChatBI started as a personal project, with the hope of making it easier for businesses to build their own ChatBI applications with less effort. To achieve this goal, I made it open source, and I greatly appreciate contributions of all kinds.
Whether you’d like to propose a new feature, refactor the code, enhance documentation, or fix bugs, your contributions are always welcome.
================================================
FILE: Dockerfile.python-executor
================================================
FROM python:3.11-slim
# Set working directory
WORKDIR /app
# Install basic packages that might be needed for data analysis
RUN pip install --no-cache-dir \
pandas \
numpy \
matplotlib \
seaborn \
requests \
json5
# Create a directory for code execution
RUN mkdir -p /app/code
# Set up a non-root user for security
RUN useradd -m -u 1000 executor
USER executor
# Set the default command
CMD ["python3"]
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2025 Yu Zhong
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# OpenChatBI
OpenChatBI is an open source, chat-based intelligent BI tool powered by large language models, designed to help users
query, analyze, and visualize data through natural language conversations. Built on LangGraph and LangChain ecosystem,
it provides chat agents and workflows that support natural language to SQL conversion and streamlined data analysis.
Join the Slack channel to discuss: https://join.slack.com/t/openchatbicommunity/shared_invite/zt-3jpzpx9mv-Sk88RxpO4Up0L~YTZYf4GQ
## Core Features
1. **Natural Language Interaction**: Get data analysis results by asking questions in natural language
2. **Automatic SQL Generation**: Convert natural language queries into SQL statements using advanced text2sql workflows
with schema linking and well organized prompt engineering
3. **Data Visualization**: Generate intuitive data visualizations (via plotly)
4. **Data Catalog Management**: Automatically discovers and indexes database table structures, supports flexible catalog
storage backends with vector-based or BM25-based retrieval, and easily maintains business explanations for tables
and columns as well as optimizes Prompts.
5. **Time Series Forecasting**: Forecasting models deployed in-house that can be called as tools
6. **Code Execution**: Execute Python code for data analysis and visualization
7. **Interactive Problem-Solving**: Proactively ask users for more context when information is incomplete
8. **Persistent Memory**: Conversation management and user characteristic memory based on LangGraph checkpointing
9. **MCP Support**: Integration with MCP tools by configuration
10. **Knowledge Base Integration**: Answer complex questions by combining catalog-based knowledge retrieval and external
knowledge base retrieval (via MCP tools)
11. **Web UI Interface**: Provide 2 sample UI: simple and streaming web interfaces using Gradio and Streamlit, easy to
integrate with other web applications
## Roadmap
1. **Anomaly Detection Algorithm**: Time series anomaly detection
2. **Root Cause Analysis Algorithm**: Multi-dimensional drill-down capabilities for anomaly investigation
# Getting started
## Installation & Setup
### Prerequisites
- Python 3.11 or higher
- Access to a supported LLM provider (OpenAI, Anthropic, etc.)
- Data Warehouse (Database) credentials (like Presto, PostgreSQL, MySQL, etc.)
- (Optional) Embedding model for vector-based retrieval - if not available, BM25-based retrieval will be used
- (Optional) Docker - required only for `docker` executor mode
**Note on Chinese Text Segmentation**: For better Chinese text retrieval, `jieba` is used for word segmentation. However, `jieba` is not compatible with Python 3.12+. On Python 3.12 and higher, the system automatically falls back to simple punctuation-based segmentation for Chinese text.
### Installation
1. **Using uv (recommended):**
```bash
git clone git@github.com:zhongyu09/openchatbi
uv sync
```
2. **Using pip:**
```bash
pip install openchatbi
```
3. **For development:**
```bash
git clone git@github.com:zhongyu09/openchatbi
uv sync --group dev
```
Optional: If you want to use `pysqlite3` (newer SQLite builds), you can install it manually. If build fails, install SQLite first:
On macOS, try to install sqlite using Homebrew:
```bash
brew install sqlite
brew info sqlite
export LDFLAGS="-L/opt/homebrew/opt/sqlite/lib"
export CPPFLAGS="-I/opt/homebrew/opt/sqlite/include"
```
On Amazon Linux / RHEL / CentOS:
```bash
sudo yum install sqlite-devel
```
On Ubuntu / Debian:
```bash
sudo apt-get update
sudo apt-get install libsqlite3-dev
```
### Run Demo
Run the demo using an **example dataset** taken from the Spider benchmark. You need to provide "YOUR OPENAI API KEY" or change config to use other LLM providers.
**Note**: The demo example includes embedding model configuration. If you want to run without an embedding model, you can remove the `embedding_model` section in the config - BM25 retrieval will be used automatically.
```bash
cp example/config.yaml openchatbi/config.yaml
sed -i 's/YOUR_API_KEY_HERE/[YOUR OPENAI API KEY]/g' openchatbi/config.yaml
python run_streamlit_ui.py
```
### Configuration
1. **Create configuration file**
Copy the configuration template:
```bash
cp openchatbi/config.yaml.template openchatbi/config.yaml
```
Or create an empty YAML file.
2. **Configure your LLMs:**
```yaml
# Select which provider to use
default_llm: openai
# Define one or more providers
llm_providers:
openai:
default_llm:
class: langchain_openai.ChatOpenAI
params:
api_key: YOUR_API_KEY_HERE
model: gpt-4.1
temperature: 0.02
max_tokens: 8192
# Optional: Embedding model for vector-based retrieval and memory tools
# If not configured, BM25-based retrieval will be used, and the memory tools will not work
embedding_model:
class: langchain_openai.OpenAIEmbeddings
params:
api_key: YOUR_API_KEY_HERE
model: text-embedding-3-large
chunk_size: 1024
```
3. **Configure your data warehouse:**
```yaml
organization: Your Company
dialect: presto
data_warehouse_config:
uri: "presto://user@host:8080/catalog/schema"
include_tables:
- your_table_name
database_name: "catalog.schema"
```
### Running the Application
1. **Invoking LangGraph:**
```bash
export CONFIG_FILE=YOUR_CONFIG_FILE_PATH
```
```python
from openchatbi import get_default_graph
graph = get_default_graph()
graph.invoke({"messages": [{"role": "user", "content": "Show me ctr trends for the past 7 days"}]},
config={"configurable": {"thread_id": "1"}})
```
```
# System-generated SQL
SELECT date, SUM(clicks)/SUM(impression) AS ctr
FROM ad_performance
WHERE date >= CURRENT_DATE - INTERVAL '7' DAY
GROUP BY date
ORDER BY date;
```
2. **Sample Web UI:**
Streamlit based UI:
```bash
streamlit run sample_ui/streamlit_ui.py
```
Run Gradio based UI:
```bash
python sample_ui/streaming_ui.py
```
## Configuration Instructions
The configuration template is provided at `config.yaml.template`. Key configuration sections include:
### Basic Settings
- `organization`: Organization name (e.g., "Your Company")
- `dialect`: Database dialect (e.g., "presto")
- `bi_config_file`: Path to BI configuration file (e.g., "example/bi.yaml")
### Catalog Store Configuration
- `catalog_store`: Configuration for data catalog storage
- `store_type`: Storage type (e.g., "file_system")
- `data_path`: Path to catalog data stored by file system (e.g., "./example")
### Data Warehouse Configuration
- `data_warehouse_config`: Database connection settings
- `uri`: Connection string for your database
- `include_tables`: List of tables to include in catalog, leave empty to include all tables
- `database_name`: Database name for catalog
- `token_service`: Token service URL (for data warehouse that need token authentication like Presto)
- `user_name` / `password`: Token service credentials
### LLM Configuration
Various LLMs are supported based on LangChain; see the LangChain API
Reference (https://python.langchain.com/api_reference/reference.html#integrations) for the full list of integrations that support
`chat_models`. You can configure different LLMs for different tasks:
- `default_llm`: Primary language model for general tasks
- `embedding_model`: (Optional) Model for embedding generation. If not configured, BM25-based text retrieval will be used as fallback, and the memory tools will not work
- `text2sql_llm`: (Optional) Specialized model for SQL generation. If not configured, uses `default_llm`
Multiple providers (optional):
- Configure multiple providers under `llm_providers` and select one with `default_llm: <provider_name>`.
- In `sample_ui/streamlit_ui.py`, a provider dropdown appears when `llm_providers` is configured.
- In `sample_api/async_api.py`, pass `provider` in the `/chat/stream` request body.
Commonly used LLM providers and their corresponding classes and installation commands:
- **Anthropic**: `langchain_anthropic.ChatAnthropic`, `pip install langchain-anthropic`
- **OpenAI**: `langchain_openai.ChatOpenAI`, `pip install langchain-openai`
- **Azure OpenAI**: `langchain_openai.AzureChatOpenAI`, `pip install langchain-openai`
- **Google Vertex AI**: `langchain_google_vertexai.ChatVertexAI`, `pip install langchain-google-vertexai`
- **Bedrock**: `langchain_aws.ChatBedrock`, `pip install langchain-aws`
- **Huggingface**: `langchain_huggingface.ChatHuggingFace`, `pip install langchain-huggingface`
- **Deepseek**: `langchain_deepseek.ChatDeepSeek`, `pip install langchain-deepseek`
- **Ollama**: `langchain_ollama.ChatOllama`, `pip install langchain-ollama`
### Advanced Configuration
OpenChatBI supports sophisticated customization through prompt engineering and catalog management features:
- **Prompt Engineering Configuration**: Customize system prompts, business glossaries, and data warehouse introductions
- **Data Catalog Management**: Configure table metadata, column descriptions, and SQL generation rules
- **Business Rules**: Define table selection criteria and domain-specific SQL constraints
- **Forecasting Service**: Configure the forecasting service url and prompt based on your own deployment
For detailed configuration options and examples, see the [Advanced Features](#advanced-features) section.
## Architecture Overview
OpenChatBI is built using a modular architecture with clear separation of concerns:
1. **LangGraph Workflows**: Core orchestration using state machines for complex multi-step processes
2. **Catalog Management**: Flexible data catalog system with intelligent retrieval (vector-based or BM25 fallback)
3. **Text2SQL Pipeline**: Advanced natural language to SQL conversion with schema linking
4. **Code Execution**: Sandboxed Python execution environment for data analysis
5. **Tool Integration**: Extensible tool system for human interaction and knowledge search
6. **Persistent Memory**: SQLite-based conversation state management
## Technology Stack
- **Frameworks**: LangGraph, LangChain, FastAPI, Gradio/Streamlit
- **Large Language Models**: Azure OpenAI (GPT-4), Anthropic Claude, OpenAI GPT models
- **Text Retrieval**: Vector-based (with embedding models) or BM25-based (fallback without embeddings)
- **Databases**: Presto, Trino, MySQL with SQLAlchemy support
- **Code Execution**: Local Python, RestrictedPython, Docker containerization
- **Development**: Python 3.11+, with modern tooling (Black, Ruff, MyPy, Pytest)
- **Storage**: SQLite for conversation checkpointing, file system catalog storage
### Agent Graph
### Text2SQL Graph
## Project Structure
```
openchatbi/
├── README.md # Project documentation
├── pyproject.toml # Modern Python project configuration
├── Dockerfile.python-executor # Docker image for isolated code execution
├── run_tests.py # Test runner script
├── run_streamlit_ui.py # Streamlit UI launcher
├── openchatbi/ # Core application code
│ ├── __init__.py # Package initialization
│ ├── config.yaml.template # Configuration template
│ ├── config_loader.py # Configuration management
│ ├── constants.py # Application constants
│ ├── agent_graph.py # Main LangGraph workflow
│ ├── graph_state.py # State definition for workflows
│ ├── context_config.py # Context management configuration
│ ├── context_manager.py # Context window and token management
│ ├── text_segmenter.py # Text segmentation with jieba support
│ ├── utils.py # Utility functions and SimpleStore (BM25-based retrieval)
│ ├── catalog/ # Data catalog management
│ │ ├── __init__.py # Package initialization
│ │ ├── catalog_loader.py # Catalog loading logic
│ │ ├── catalog_store.py # Catalog storage interface
│ │ ├── factory.py # Catalog factory patterns
│ │ ├── helper.py # Catalog helper functions
│ │ ├── retrival_helper.py # Retrieval helper utilities
│ │ ├── schema_retrival.py # Schema retrieval logic
│ │ ├── token_service.py # Token service integration
│ │ └── store/ # Catalog storage implementations
│ │ └── file_system.py # File system-based catalog storage
│ ├── code/ # Code execution framework
│ │ ├── __init__.py # Package initialization
│ │ ├── executor_base.py # Base executor interface
│ │ ├── local_executor.py # Local Python execution
│ │ ├── restricted_local_executor.py # RestrictedPython execution
│ │ └── docker_executor.py # Docker-based isolated execution
│ ├── llm/ # LLM integration layer
│ │ ├── __init__.py # Package initialization
│ │ └── llm.py # LLM management and retry logic
│ ├── prompts/ # Prompt templates and engineering
│ │ ├── __init__.py # Package initialization
│ │ ├── agent_prompt.md # Main agent prompts
│ │ ├── extraction_prompt.md # Information extraction prompts
│ │ ├── system_prompt.py # System prompt management
│ │ ├── summary_prompt.md # Summary conversation prompts
│ │ ├── table_selection_prompt.md # Table selection prompts
│ │ ├── text2sql_prompt.md # Text-to-SQL prompts
│ │ └── sql_dialect/ # SQL dialect-specific prompts
│ ├── text2sql/ # Text-to-SQL conversion pipeline
│ │ ├── __init__.py # Package initialization
│ │ ├── data.py # Data and retriever for Text-to-SQL
│ │ ├── extraction.py # Information extraction
│ │ ├── generate_sql.py # SQL generation and execution logic
│ │ ├── schema_linking.py # Schema linking process
│ │ ├── sql_graph.py # SQL generation LangGraph workflow
│ │ ├── text2sql_utils.py # Text2SQL utilities
│ │ └── visualization.py # Data visualization functions
│ └── tool/ # LangGraph tools and functions
│ ├── ask_human.py # Human-in-the-loop interactions
│ ├── memory.py # Memory management tool
│ ├── mcp_tools.py # MCP (Model Context Protocol) integration
│ ├── run_python_code.py # Configurable Python code execution
│ ├── save_report.py # Report saving functionality
│ ├── search_knowledge.py # Knowledge base search
│ └── timeseries_forecast.py # Time series forecasting tool
├── sample_api/ # API implementations
│ └── async_api.py # Asynchronous FastAPI example
├── sample_ui/ # Web interface implementations
│ ├── memory_ui.py # Memory-enhanced UI interface
│ ├── plotly_utils.py # Plotly utilities and helpers
│ ├── simple_ui.py # Simple non-streaming Gradio UI
│ ├── streaming_ui.py # Streaming Gradio UI with real-time updates
│ ├── streamlit_ui.py # Streaming Streamlit UI with enhanced features
│ └── style.py # UI styling and CSS
├── example/ # Example configurations and data
│ ├── bi.yaml # BI configuration example
│ ├── config.yaml # Application config example
│ ├── table_info.yaml # Table information
│ ├── table_columns.csv # Table column registry
│ ├── common_columns.csv # Common column definitions
│ ├── sql_example.yaml # SQL examples for retrieval
│ ├── table_selection_example.csv # Table selection examples
│ └── tracking_orders.sqlite # Sample SQLite database
├── timeseries_forecasting/ # Time series forecasting service
│ ├── README.md # Forecasting service documentation
│ └── ... # Forecasting service implementation
├── tests/ # Test suite
│ ├── __init__.py # Package initialization
│ ├── conftest.py # Test configuration
│ ├── test_*.py # Test modules for various components
│ └── README.md # Testing documentation
├── docs/ # Documentation
│ ├── source/ # Sphinx documentation source
│ ├── build/ # Built documentation
│ ├── Makefile # Documentation build scripts
│ └── make.bat # Windows build script
└── .github/ # GitHub workflows and templates
└── workflows/ # CI/CD workflows
```
## Advanced Features
### Visualization configuration
You can choose rule-based or llm-based visualization or disable visualization.
```yaml
# Options: "rule" (rule-based), "llm" (LLM-based), or null (skip visualization)
visualization_mode: llm
```
### Prompt Engineering
#### Basic Knowledge & Glossary
You can define basic knowledge and glossary in `example/bi.yaml`, for example:
```yaml
basic_knowledge_glossary: |
# Basic Knowledge Introduction
The basic knowledge about your company and its business, including key concepts, metrics, and processes.
# Glossary
Common terms and their definitions used in your business context.
```
#### Data Warehouse Introduction
You can provide a brief introduction of your data warehouse in `example/bi.yaml`, for example:
```yaml
data_warehouse_introduction: |
# Data Warehouse Introduction
This data warehouse is built on Presto and contains various tables related to XXXXX.
The main fact tables include XXXX metrics, while dimension tables include XXXXX.
The data is updated hourly and is used for reporting and analysis purposes.
```
#### Table Selection Rules
You can configure table selection rules in `example/bi.yaml`, for example:
```yaml
table_selection_extra_rule: |
- All tables with is_valid can support both valid and invalid traffics
```
#### Custom SQL Rules
You can define your additional SQL Generation rules for tables in `example/table_info.yaml`, for example:
```yaml
sql_rule: |
### SQL Rules
- All event_date in the table are stored in **UTC**. If the user specifies a timezone (e.g., CET, PST), convert between timezones accordingly.
```
### Catalog Management
#### Introduction
High-quality catalog data is essential for accurate Text2SQL generation and data analysis. OpenChatBI automatically
discovers and indexes data warehouse table structures while providing flexible management for business metadata, column
descriptions, and query optimization rules.
#### Catalog Structure
The catalog system organizes metadata in a hierarchical structure:
**Database Level**
- Top-level container for all tables and schemas
**Table Level**
- `description`: Business functionality and purpose of the table
- `selection_rule`: Guidelines for when and how to use this table in queries
- `sql_rule`: Specific SQL generation rules and constraints for this table
**Column Level**
- **Required Fields**: Essential metadata for each column to enable effective Text2SQL generation
- `column_name`: Technical database column name
- `display_name`: Human-readable name for business users
- `alias`: Alternative names or abbreviations
- `type`: Data type (string, integer, date, etc.)
- `category`: Business category, dimension or metric
- `tag`: Additional labels for filtering and organization
- `description`: Detailed explanation of column purpose and usage
- **Two Types** of Columns
- **Common Columns**: Columns with standardized business meanings shared across tables
- **Table-Specific Columns**: Columns with context-dependent meanings that vary between tables
- **Derived Metrics**: Virtual metrics calculated from existing columns using SQL formulas
- Computed dynamically during query execution rather than stored as physical columns
- Examples: CTR (clicks/impressions), conversion rates, profit margins
- Enable complex business calculations without pre-computing values
#### Loading Catalog from Database
OpenChatBI can automatically discover and load table structures from your data warehouse:
1. **Automatic Discovery**: Connects to your configured data warehouse and scans table schemas
2. **Metadata Extraction**: Extracts column names, data types, and basic structural information
3. **Incremental Updates**: Supports updating catalog data as your database schema evolves
Configure automatic catalog loading in your `config.yaml`:
```yaml
catalog_store:
store_type: file_system
data_path: ./catalog_data
data_warehouse_config:
include_tables:
- your_table_pattern
# Leave empty to include all accessible tables
```
#### File System Catalog Store
The file system catalog store organizes metadata across multiple files for maintainability and version control:
**Core Table Information**
- `table_info.yaml`: Comprehensive table metadata organized hierarchically (database → table → information)
- `type`: Table classification (e.g., "fact" for Fact Tables, "dimension" for Dimension Tables)
- `description`: Business functionality and purpose
- `selection_rule`: Usage guidelines in markdown list format (each line starts with `-`)
- `sql_rule`: SQL generation rules in markdown header format (each rule starts with `####`)
- `derived_metric`: Virtual metrics with calculation formulas, organized by groups:
```md
#### Derived Ratio Metrics
Click-through Rate (alias CTR): SUM(clicks) / SUM(impression)
Conversion Rate (alias CVR): SUM(conversions) / SUM(clicks)
```
**Column Management**
- `table_columns.csv`: Basic column registry with schema `db_name,table_name,column_name`
- `table_spec_columns.csv`: Table-specific column metadata with full schema:
`db_name,table_name,column_name,display_name,alias,type,category,tag,description`
- `common_columns.csv`: Shared column definitions across tables with schema:
`column_name,display_name,alias,type,category,tag,description`
**Query Examples and Training Data**
- `table_selection_example.csv`: Table selection training examples with schema `question,selected_tables`
- `sql_example.yaml`: Query examples organized by database and table structure:
```yaml
your_database:
ad_performance: |
Q: Show me CTR trends for the past 7 days
A: SELECT date, SUM(clicks)/SUM(impressions) AS ctr
FROM ad_performance
WHERE date >= CURRENT_DATE - INTERVAL 7 DAY
GROUP BY date
ORDER BY date;
```
### Time Series Forecasting Service Setup
OpenChatBI can integrate with a time series forecasting service for advanced predictive analytics. Follow these steps to set up the service:
#### 1. Build and Run the Forecasting Service
See detailed instructions in [timeseries_forecasting/README.md](timeseries_forecasting/README.md)
Quick start:
```bash
cd timeseries_forecasting
./build_and_run.sh
```
#### 2. Configure Tool Usage Rules
In your `bi.yaml`, add constraints for the timeseries_forecast tool, e.g. if you are using `timer-base-84m` model:
```yaml
extra_tool_use_rule: |
- timeseries_forecast tool requires at least 96 time points in input data. If there is not enough input data, set input_len to 96 to pad with zeros.
```
#### 3. Configure Service URL
In your `config.yaml`:
```yaml
# Time Series Forecasting Service Configuration
timeseries_forecasting_service_url: "http://localhost:8765"
```
**Important**: Adjust the URL based on your deployment scenario:
- **Local development** (OpenChatBI on host, Forecasting service in Docker): `http://localhost:8765`
- **Remote service**: `http://your-service-host:8765`
#### 4. Verify Service Health
Test the service is accessible:
```bash
curl http://localhost:8765/health
```
Expected response:
```json
{
"status": "healthy",
"model_initialized": true,
"uptime_seconds": 123.45
}
```
### Python Code Execution Configuration
OpenChatBI supports multiple execution environments for running Python code with different security and performance characteristics:
```yaml
# Python Code Execution Configuration
python_executor: local # Options: "local", "restricted_local", "docker"
```
#### Executor Types
- **`local`** (Default)
- **Performance**: Fastest execution
- **Security**: Least secure (code runs in current Python process)
- **Capabilities**: Full Python capabilities and library access
- **Use Case**: Development environments, trusted code execution
- **`restricted_local`**
- **Performance**: Moderate execution speed
- **Security**: Moderate security with RestrictedPython sandboxing
- **Capabilities**: Limited Python features (no imports, file access, etc.)
- **Use Case**: Semi-trusted environments with controlled execution
- **`docker`**
- **Performance**: Slower due to container overhead
- **Security**: Highest security with complete process isolation
- **Capabilities**: Full Python capabilities within isolated container
- **Use Case**: Production environments, untrusted code execution
- **Requirements**: Docker must be installed and running
#### Docker Executor Setup
For production deployments or when running untrusted code, the Docker executor provides complete isolation:
1. **Install Docker**: Download and install Docker Desktop or Docker Engine
2. **Configure executor**: Set `python_executor: docker` in your config
3. **Automatic setup**: OpenChatBI will automatically build the required Docker image
4. **Fallback behavior**: If Docker is unavailable, automatically falls back to local executor
**Docker Executor Features**:
- Pre-installed data science libraries (pandas, numpy, matplotlib, seaborn)
- Network isolation for security
- Automatic container cleanup
- Resource isolation from host system
## Development & Testing
### Code Quality Tools
The project uses modern Python tooling for code quality:
```bash
# Format code
uv run black .
# Lint code
uv run ruff check .
# Type checking
uv run mypy openchatbi/
# Security scanning
uv run bandit -r openchatbi/
```
### Testing
Run the test suite:
```bash
# Run all tests
uv run pytest
# Run with coverage
uv run pytest --cov=openchatbi --cov-report=html
# Run specific test files
uv run pytest tests/test_generate_sql.py
uv run pytest tests/test_agent_graph.py
```
### Pre-commit Hooks
Install pre-commit hooks for automatic code quality checks:
```bash
uv run pre-commit install
```
## Contribution Guidelines
1. Fork the repository
2. Create a feature branch (`git checkout -b feature/fooBar`)
3. Commit your changes (`git commit -am 'Add some fooBar'`)
4. Push to the branch (`git push origin feature/fooBar`)
5. Create a new Pull Request
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details
## Contact & Support
- **Author**: Yu Zhong ([zhongyu8@gmail.com](mailto:zhongyu8@gmail.com))
- **Repository**: [github.com/zhongyu09/openchatbi](https://github.com/zhongyu09/openchatbi)
- **Issues**: [Report bugs and feature requests](https://github.com/zhongyu09/openchatbi/issues)
================================================
FILE: baselines/runledger-openchatbi.json
================================================
{
"aggregates": {
"cases_error": 0,
"cases_fail": 0,
"cases_pass": 1,
"cases_total": 1,
"metrics": {
"cost_usd": {
"max": null,
"mean": null,
"min": null,
"p50": null,
"p95": null
},
"steps": {
"max": null,
"mean": null,
"min": null,
"p50": null,
"p95": null
},
"tokens_in": {
"max": null,
"mean": null,
"min": null,
"p50": null,
"p95": null
},
"tokens_out": {
"max": null,
"mean": null,
"min": null,
"p50": null,
"p95": null
},
"tool_calls": {
"max": 1.0,
"mean": 1.0,
"min": 1.0,
"p50": 1.0,
"p95": 1.0
},
"tool_errors": {
"max": 0.0,
"mean": 0.0,
"min": 0.0,
"p50": 0.0,
"p95": 0.0
},
"wall_ms": {
"max": 1.0,
"mean": 1.0,
"min": 1.0,
"p50": 1.0,
"p95": 1.0
}
},
"pass_rate": 1.0
},
"cases": [
{
"assertions": {
"failed": 0,
"total": 1
},
"cost_usd": null,
"failed_assertions": null,
"failure_reason": null,
"id": "t1",
"replay": {
"cassette_path": "evals/runledger/cassettes/t1.jsonl",
"cassette_sha256": "7e9830609490d140bf09178106dfa647bba4c9ec15859072b5aa2c3ae1659289"
},
"status": "pass",
"steps": null,
"tokens_in": null,
"tokens_out": null,
"tool_calls": 1,
"tool_calls_by_name": {
"search_knowledge": 1
},
"tool_errors": 0,
"tool_errors_by_name": {},
"wall_ms": 1
}
],
"generated_at": "2026-01-03T19:10:00Z",
"run": {
"ci": null,
"exit_status": "success",
"git_sha": null,
"mode": "replay",
"run_id": "baseline"
},
"runledger_version": "0.1.1",
"schema_version": 1,
"suite": {
"agent_command": [
"python",
"evals/runledger/agent/agent.py"
],
"cases_total": 1,
"name": "runledger-openchatbi",
"suite_config_hash": null,
"suite_path": "evals/runledger/suite.yaml",
"tool_mode": "replay"
}
}
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/make.bat
================================================
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
================================================
FILE: docs/source/_templates/layout.html
================================================
{% extends "!layout.html" %}
{% block extrahead %}
{{ super() }}
{% endblock %}
================================================
FILE: docs/source/catalog.rst
================================================
Catalog System
==============
Overview
--------
The catalog system manages metadata for database tables, columns, and business rules.
Catalog Store
-------------
.. automodule:: openchatbi.catalog.catalog_store
:members:
:undoc-members:
:show-inheritance:
Filesystem Implementation
^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: openchatbi.catalog.store.file_system
:members:
:show-inheritance:
Catalog Loader
--------------
.. automodule:: openchatbi.catalog.catalog_loader
:members:
:undoc-members:
:show-inheritance:
Schema Retrieval
----------------
.. automodule:: openchatbi.catalog.schema_retrival
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/source/code.rst
================================================
Code Execution
==============
Code Module
-----------
.. automodule:: openchatbi.code
:members:
:undoc-members:
:show-inheritance:
Executor Base
-------------
.. automodule:: openchatbi.code.executor_base
:members:
:undoc-members:
:show-inheritance:
Local Executor
--------------
.. automodule:: openchatbi.code.local_executor
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/source/conf.py
================================================
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

import os
import sys

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

# Make the repository root importable so autodoc can resolve `openchatbi`.
sys.path.insert(0, os.path.abspath("../.."))

project = "OpenChatBI"
copyright = "2025, Yu Zhong"
author = "Yu Zhong"
release = "0.2.2"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

# Mock dependencies for documentation build
autodoc_mock_imports = []

extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.githubpages",
    "myst_parser",
]

# Set an environment variable to indicate we're building docs, so modules can
# skip heavy initialization at import time.
os.environ["SPHINX_BUILD"] = "1"

# MyST parser configuration
myst_enable_extensions = [
    "colon_fence",
    "deflist",
    "html_admonition",
    "html_image",
]
myst_heading_anchors = 3

templates_path = ["_templates"]
exclude_patterns = []

# Autodoc configuration
autodoc_default_options = {
    "members": True,
    "member-order": "bysource",
    "special-members": "__init__",
    "undoc-members": True,
    "exclude-members": "__weakref__",
}

# Napoleon configuration for Google/NumPy style docstrings
napoleon_google_docstring = True
napoleon_numpy_docstring = True
napoleon_include_init_with_doc = False
napoleon_include_private_with_doc = False
napoleon_include_special_with_doc = True
napoleon_use_admonition_for_examples = False
napoleon_use_admonition_for_notes = False
napoleon_use_admonition_for_references = False
napoleon_use_ivar = False
napoleon_use_param = True
napoleon_use_rtype = True
napoleon_preprocess_types = False
napoleon_type_aliases = None
napoleon_attr_annotations = True

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]

# GitHub Pages configuration
html_baseurl = "https://zhongyu09.github.io/openchatbi/"

# Theme options for RTD theme
html_theme_options = {
    "navigation_depth": 4,
    "collapse_navigation": False,
    "sticky_navigation": True,
    "includehidden": True,
    "titles_only": False,
}
================================================
FILE: docs/source/config.rst
================================================
Configuration
=============
The configuration system consists of two main classes:
- **Config**: Defines the configuration model.
- **ConfigLoader**: Manages loading and accessing configuration.
Config
------
.. autoclass:: openchatbi.config_loader.Config
:exclude-members: organization, dialect, default_llm, embedding_model, text2sql_llm, bi_config, data_warehouse_config, catalog_store, mcp_servers, report_directory, python_executor
ConfigLoader
------------
.. autoclass:: openchatbi.config_loader.ConfigLoader
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/source/core.rst
================================================
Core Module
===========
Main Module
-----------
.. automodule:: openchatbi
:members:
:undoc-members:
:show-inheritance:
Agent Graph
-----------
.. automodule:: openchatbi.agent_graph
:members:
:undoc-members:
:show-inheritance:
State Management
----------------
.. automodule:: openchatbi.graph_state
:members:
:undoc-members:
:show-inheritance:
Utilities
---------
.. automodule:: openchatbi.utils
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/source/index.rst
================================================
OpenChatBI Documentation
========================
`GitHub Repository <https://github.com/zhongyu09/openchatbi>`_
.. include:: ../../README.md
:parser: myst_parser.sphinx_
.. toctree::
:maxdepth: 4
:caption: Documentation:
:titlesonly:
self
.. toctree::
:maxdepth: 2
:caption: API Reference:
Core Module <core>
Configuration <config>
Catalog System <catalog>
Text2SQL System <text2sql>
Code Execution <code>
LLM Integration <llm>
Tools and Utilities <tools>
Time Series Forecasting Service <timeseries>
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
================================================
FILE: docs/source/llm.rst
================================================
LLM Integration
===============
LLM Module
----------
.. automodule:: openchatbi.llm
:members:
:undoc-members:
:show-inheritance:
LLM Implementation
------------------
.. automodule:: openchatbi.llm.llm
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/source/text2sql.rst
================================================
Text2SQL System
===============
Overview
--------
Natural language to SQL conversion pipeline with schema linking and prompt engineering.
SQL Graph
---------
.. automodule:: openchatbi.text2sql.sql_graph
:members:
:undoc-members:
:show-inheritance:
SQL Generation
--------------
.. automodule:: openchatbi.text2sql.generate_sql
:members:
:undoc-members:
:show-inheritance:
Schema Linking
--------------
.. automodule:: openchatbi.text2sql.schema_linking
:members:
:undoc-members:
:show-inheritance:
Information Extraction
----------------------
.. automodule:: openchatbi.text2sql.extraction
:members:
:undoc-members:
:show-inheritance:
Text2SQL Utilities
-------------------
.. automodule:: openchatbi.text2sql.text2sql_utils
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/source/timeseries.rst
================================================
Time Series Forecasting Service
===============================
`GitHub Repository <https://github.com/zhongyu09/openchatbi>`_
.. include:: ../../timeseries_forecasting/README.md
:parser: myst_parser.sphinx_
================================================
FILE: docs/source/tools.rst
================================================
Tools and Utilities
===================
Overview
--------
LangGraph tools for human interaction, code execution, and knowledge search.
Python Code Execution
----------------------
.. automodule:: openchatbi.tool.run_python_code
:members:
:undoc-members:
:show-inheritance:
Human Interaction
-----------------
.. automodule:: openchatbi.tool.ask_human
:members:
:undoc-members:
:show-inheritance:
Memory Management
-----------------
.. automodule:: openchatbi.tool.memory
:members:
:undoc-members:
:show-inheritance:
Knowledge Search
----------------
.. automodule:: openchatbi.tool.search_knowledge
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: evals/__init__.py
================================================
"""Evaluation suites for RunLedger."""
================================================
FILE: evals/runledger/README.md
================================================
# RunLedger eval (OpenChatBI)
This suite is **replay-only** by default. It runs a deterministic CI check using a JSONL adapter that proxies tool calls through RunLedger and replays results from a cassette.
## Run (replay)
```bash
runledger run evals/runledger --mode replay --baseline baselines/runledger-openchatbi.json
```
## Record / update cassette (optional)
If you want to re-record the cassette with real tool outputs, run in record mode in a fully configured OpenChatBI environment (valid `openchatbi/config.yaml`, data warehouse/catalog, LLM keys).
```bash
runledger run evals/runledger --mode record \
--baseline baselines/runledger-openchatbi.json \
--tool-module evals.runledger.tools
```
Notes:
- Tool args are passed as JSON objects; see `evals/runledger/cassettes/t1.jsonl` for the exact shape.
- After recording, promote the new baseline:
```bash
runledger baseline promote \
--from runledger_out/runledger-openchatbi/ \
--to baselines/runledger-openchatbi.json
```
================================================
FILE: evals/runledger/__init__.py
================================================
"""RunLedger eval suite for OpenChatBI."""
================================================
FILE: evals/runledger/agent/agent.py
================================================
import json
import sys
from itertools import count
from typing import Any
from unittest.mock import MagicMock
import builtins
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import StructuredTool
from langgraph.checkpoint.memory import MemorySaver
from pydantic import BaseModel, Field
from openchatbi import config
import openchatbi.agent_graph as agent_graph
_CALL_COUNTER = count(1)
_ORIG_PRINT = builtins.print
def _safe_print(*args: Any, **kwargs: Any) -> None:
"""Suppress stdout prints so JSONL stays clean; allow stderr."""
target = kwargs.get("file")
if target is None or target is sys.stdout:
return
_ORIG_PRINT(*args, **kwargs)
builtins.print = _safe_print
class JsonlChannel:
def __init__(self, stream: Any) -> None:
self._stream = stream
def read(self) -> dict[str, Any] | None:
while True:
line = self._stream.readline()
if not line:
return None
line = line.strip()
if not line:
continue
try:
return json.loads(line)
except json.JSONDecodeError:
continue
@staticmethod
def send(payload: dict[str, Any]) -> None:
sys.stdout.write(json.dumps(payload) + "\n")
sys.stdout.flush()
def _last_user_text(messages: list[Any]) -> str:
    """Return the stripped text of the most recent HumanMessage.

    Falls back to the fixed prompt "OpenChatBI" when no human message exists,
    keeping the stubbed run deterministic.
    """
    human_texts = (
        str(msg.content).strip() for msg in reversed(messages) if isinstance(msg, HumanMessage)
    )
    return next(human_texts, "OpenChatBI")
def _runledger_tool_call(channel: JsonlChannel, name: str, args: dict[str, Any]) -> Any:
    """Proxy one tool invocation through the RunLedger JSONL channel.

    Emits a ``tool_call`` message with a fresh call id, then blocks reading the
    channel until the matching ``tool_result`` arrives. Unrelated messages are
    ignored; EOF before a result or an error result raises RuntimeError.
    """
    call_id = f"c{next(_CALL_COUNTER)}"
    channel.send({"type": "tool_call", "name": name, "call_id": call_id, "args": args})
    while (reply := channel.read()) is not None:
        if reply.get("type") != "tool_result" or reply.get("call_id") != call_id:
            continue
        if reply.get("ok"):
            return reply.get("result")
        raise RuntimeError(reply.get("error") or "Tool error")
    raise RuntimeError("Tool result missing")
class SearchKnowledgeInput(BaseModel):
    """Argument schema for the search_knowledge proxy tool."""

    reasoning: str = Field(description="Reason for searching knowledge")
    query_list: list[str] = Field(description="Query terms")
    knowledge_bases: list[str] = Field(description="Knowledge bases to search")
    with_table_list: bool = Field(default=False, description="Include table list")


class ShowSchemaInput(BaseModel):
    """Argument schema for the show_schema proxy tool."""

    reasoning: str = Field(description="Reason for showing schema")
    tables: list[str] = Field(description="Table names")


class Text2SQLInput(BaseModel):
    """Argument schema for the text2sql proxy tool."""

    reasoning: str = Field(description="Reason for calling text2sql")
    context: str = Field(description="Full context for the SQL graph")


class RunPythonInput(BaseModel):
    """Argument schema for the run_python_code proxy tool."""

    reasoning: str = Field(description="Reason for running python code")
    code: str = Field(description="Python code to execute")


class SaveReportInput(BaseModel):
    """Argument schema for the save_report proxy tool."""

    content: str = Field(description="Report content")
    title: str = Field(description="Report title")
    file_format: str = Field(description="File extension")
def _build_tool_proxies(channel: JsonlChannel) -> dict[str, StructuredTool]:
    """Build StructuredTool proxies that forward every call over *channel*.

    Instead of executing real tools, each proxy serializes its arguments into a
    RunLedger ``tool_call`` message via ``_runledger_tool_call`` so results can
    be recorded to / replayed from a cassette. The arg dict shapes must match
    the cassette entries exactly (see evals/runledger/cassettes/t1.jsonl).
    """

    def search_knowledge(
        reasoning: str,
        query_list: list[str],
        knowledge_bases: list[str],
        with_table_list: bool = False,
    ) -> Any:
        return _runledger_tool_call(
            channel,
            "search_knowledge",
            {
                "reasoning": reasoning,
                "query_list": query_list,
                "knowledge_bases": knowledge_bases,
                "with_table_list": with_table_list,
            },
        )

    def show_schema(reasoning: str, tables: list[str]) -> Any:
        return _runledger_tool_call(
            channel,
            "show_schema",
            {"reasoning": reasoning, "tables": tables},
        )

    def text2sql(reasoning: str, context: str) -> Any:
        return _runledger_tool_call(
            channel,
            "text2sql",
            {"reasoning": reasoning, "context": context},
        )

    def run_python_code(reasoning: str, code: str) -> Any:
        return _runledger_tool_call(
            channel,
            "run_python_code",
            {"reasoning": reasoning, "code": code},
        )

    def save_report(content: str, title: str, file_format: str = "md") -> Any:
        return _runledger_tool_call(
            channel,
            "save_report",
            {"content": content, "title": title, "file_format": file_format},
        )

    # Keys must match the tool names the agent graph references so the
    # monkeypatching in _configure_agent_graph swaps them in transparently.
    return {
        "search_knowledge": StructuredTool.from_function(
            func=search_knowledge,
            name="search_knowledge",
            description="RunLedger proxy for search_knowledge",
            args_schema=SearchKnowledgeInput,
        ),
        "show_schema": StructuredTool.from_function(
            func=show_schema,
            name="show_schema",
            description="RunLedger proxy for show_schema",
            args_schema=ShowSchemaInput,
        ),
        "text2sql": StructuredTool.from_function(
            func=text2sql,
            name="text2sql",
            description="RunLedger proxy for text2sql",
            args_schema=Text2SQLInput,
        ),
        "run_python_code": StructuredTool.from_function(
            func=run_python_code,
            name="run_python_code",
            description="RunLedger proxy for run_python_code",
            args_schema=RunPythonInput,
        ),
        "save_report": StructuredTool.from_function(
            func=save_report,
            name="save_report",
            description="RunLedger proxy for save_report",
            args_schema=SaveReportInput,
        ),
    }
def _stub_llm_call(chat_model: Any, messages: list[Any], **_kwargs: Any) -> AIMessage:
    """Deterministic replacement for the real LLM call.

    First turn (no tool result in history): request exactly one
    ``search_knowledge`` tool call built from the latest user text.
    Second turn (a tool result is present): return a fixed summary reply.
    """
    has_tool_result = any(
        isinstance(msg, ToolMessage) or getattr(msg, "type", None) == "tool" for msg in messages
    )
    if has_tool_result:
        return AIMessage(content="Here is a deterministic summary based on the tool result.", tool_calls=[])
    return AIMessage(
        content="Searching knowledge base.",
        tool_calls=[
            {
                "name": "search_knowledge",
                "args": {
                    "reasoning": "Look up relevant knowledge",
                    "query_list": [_last_user_text(messages)],
                    "knowledge_bases": ["columns"],
                    "with_table_list": False,
                },
                "id": "call_1",
            }
        ],
    )
def _configure_agent_graph(channel: JsonlChannel) -> None:
    """Monkeypatch openchatbi.agent_graph so it runs offline and deterministically.

    Real tools are swapped for RunLedger proxies, optional subsystems (memory
    tools, MCP tools, forecasting health check) are disabled, and the LLM call
    is replaced by the deterministic stub.
    """
    tool_proxies = _build_tool_proxies(channel)
    agent_graph.search_knowledge = tool_proxies["search_knowledge"]
    agent_graph.show_schema = tool_proxies["show_schema"]
    agent_graph.run_python_code = tool_proxies["run_python_code"]
    agent_graph.save_report = tool_proxies["save_report"]
    # get_sql_tools normally derives a tool from a compiled SQL graph; return
    # the proxy directly and stub the graph build with an inert object.
    agent_graph.get_sql_tools = lambda *_args, **_kwargs: tool_proxies["text2sql"]
    agent_graph.build_sql_graph = lambda *_args, **_kwargs: object()
    agent_graph.get_memory_tools = lambda *_args, **_kwargs: []
    agent_graph.create_mcp_tools_sync = lambda *_args, **_kwargs: []
    agent_graph.check_forecast_service_health = lambda: False
    agent_graph.call_llm_chat_model_with_retry = _stub_llm_call
def _bootstrap_config() -> None:
    """Install a minimal in-memory config so the agent graph can be built.

    The LLM slot holds a MagicMock (never actually invoked — the LLM call is
    stubbed in _configure_agent_graph) and the file-system catalog store is
    created with auto-loading disabled so no warehouse access happens.
    """
    config.set(
        {
            "default_llm": MagicMock(),
            "data_warehouse_config": {},
            "catalog_store": {"store_type": "file_system", "auto_load": False},
        }
    )
def main() -> int:
    """Entry point implementing the RunLedger JSONL agent protocol.

    Reads a ``task_start`` message from stdin, configures the stubbed agent
    graph, runs it on the extracted prompt, and emits one ``final_output``
    message on stdout.

    Returns:
        Process exit code: 0 on success, 1 if the handshake message is missing
        or not a ``task_start``.
    """
    channel = JsonlChannel(sys.stdin)
    message = channel.read()
    if not message or message.get("type") != "task_start":
        return 1
    _bootstrap_config()
    _configure_agent_graph(channel)
    # Accept several input key spellings; fall back to a fixed prompt so the
    # replayed run stays deterministic even with an empty payload.
    prompt = ""
    payload = message.get("input", {})
    if isinstance(payload, dict):
        prompt = payload.get("prompt") or payload.get("question") or payload.get("query") or ""
    if not prompt:
        prompt = "OpenChatBI"
    graph = agent_graph.build_agent_graph_sync(
        catalog=config.get().catalog_store,
        checkpointer=MemorySaver(),
        memory_store=None,
        enable_context_management=False,
    )
    result = graph.invoke({"messages": [{"role": "user", "content": prompt}]})
    output = ""
    if isinstance(result, dict) and result.get("messages"):
        output = str(result["messages"][-1].content)
    # Reply shape must satisfy evals/runledger/schema.json (category + reply).
    channel.send(
        {
            "type": "final_output",
            "output": {"category": "bi", "reply": output or "Completed request."},
        }
    )
    return 0
if __name__ == "__main__":
raise SystemExit(main())
================================================
FILE: evals/runledger/cases/t1.yaml
================================================
id: t1
description: "basic BI flow with a single search_knowledge tool call"
input:
prompt: "OpenChatBI"
cassette: cassettes/t1.jsonl
================================================
FILE: evals/runledger/cassettes/t1.jsonl
================================================
{"tool":"search_knowledge","args":{"knowledge_bases":["columns"],"query_list":["OpenChatBI"],"reasoning":"Look up relevant knowledge","with_table_list":false},"ok":true,"result":{"columns":"# Relevant Columns and Description:\n## openchatbi\n- Column Category: metric\n- Display Name: OpenChatBI\n- Description \"Project overview\""}}
================================================
FILE: evals/runledger/schema.json
================================================
{
"type": "object",
"properties": {
"category": {
"type": "string"
},
"reply": {
"type": "string"
}
},
"required": [
"category",
"reply"
]
}
================================================
FILE: evals/runledger/suite.yaml
================================================
suite_name: runledger-openchatbi
agent_command: ["python", "evals/runledger/agent/agent.py"]
mode: replay
cases_path: cases
tool_registry:
- search_knowledge
tool_module: evals.runledger.tools
assertions:
- type: json_schema
schema_path: schema.json
budgets:
max_wall_ms: 20000
max_tool_calls: 5
max_tool_errors: 0
baseline_path: ../../baselines/runledger-openchatbi.json
================================================
FILE: evals/runledger/tools.py
================================================
from __future__ import annotations
from typing import Any
from openchatbi.tool.search_knowledge import search_knowledge
def _invoke_tool(tool, args: dict[str, Any]) -> Any:
return tool.invoke(args)
def _search_knowledge(args: dict[str, Any]) -> Any:
    """Run the real ``search_knowledge`` tool with the cassette-recorded args."""
    return search_knowledge.invoke(args)
# Registry mapping tool names (as listed in suite.yaml's ``tool_registry``)
# to their local wrapper callables; the runledger harness looks tools up here.
TOOLS = {
    "search_knowledge": _search_knowledge,
}
================================================
FILE: example/bi.yaml
================================================
extra_tool_use_rule: |
- Try your best to give appropriate parameters when calling tools.
- timeseries_forecast tool requires at least 96 time points in input data. If there is not enough input data, set input_len to 96 to pad with zeros.
table_selection_extra_rule: |
- When users ask about orders, consider if they need customer information (join with Customers table)
- For product-related queries, check if order information is needed (join with Order_Items)
- Shipment queries often require order and product details (join with multiple tables)
- Invoice questions may need shipment information for complete tracking
text2sql_extra_rule: |
- Use proper JOIN syntax when connecting related tables
- Use LIKE operator for partial string matches in product names or customer names
- Handle NULL values properly in optional fields like details columns
basic_knowledge_glossary: |
# Sales Business System Glossary
## Overview
You're answering questions related to a sales order tracking business system that manages the complete customer order lifecycle from placement to delivery.
## Key Business Concepts
**Customer Management:**
- Customer: Individual or entity who places orders
- Customer Details: Additional information like contact info, preferences, or notes
**Order Processing:**
- Order: A request from a customer to purchase products
- Order Status: Current state - Valid values: "Shipped", "Packing", "On Road"
- Order Item: Individual product within an order (orders can contain multiple items)
- Order Item Status: Status of specific items - Valid values: "Finish", "Payed", "Cancel"
**Product Catalog:**
- Product: Items available for purchase
- Product Details: Specifications, descriptions, or additional product information
**Fulfillment & Shipping:**
- Shipment: Physical delivery package sent to customer
- Shipment Items: Specific order items included in a shipment
- Tracking Number: Unique identifier for package tracking
- Shipment Date: When package was dispatched
**Financial Processing:**
- Invoice: Bill generated for completed orders
- Invoice Number: Unique identifier for billing purposes
- Invoice Date: When billing document was created
## Business Rules
- One order can have multiple items (products)
- One order can be fulfilled through multiple shipments
- Each shipment links to one invoice for billing
- Order items can have different statuses within the same order
- Customers can have multiple orders over time
================================================
FILE: example/common_columns.csv
================================================
column_name,display_name,alias,type,category,tag,description,dimension_table,default
customer_id,Customer ID,cust_id,INTEGER,identifier,customer,Unique identifier for customers,Customers,
customer_name,Customer Name,cust_name,VARCHAR(80),attribute,customer,Name of the customer,Customers,
customer_details,Customer Details,cust_details,VARCHAR(255),attribute,customer,Additional customer information,Customers,
invoice_number,Invoice Number,inv_num,INTEGER,identifier,financial,Unique invoice identifier,Invoices,
invoice_date,Invoice Date,inv_date,DATETIME,temporal,financial,Date the invoice was created,Invoices,
invoice_details,Invoice Details,inv_details,VARCHAR(255),attribute,financial,Additional invoice information,Invoices,
order_item_id,Order Item ID,oi_id,INTEGER,identifier,order,Unique identifier for order items,Order_Items,
product_id,Product ID,prod_id,INTEGER,identifier,product,Unique identifier for products,Products,
order_id,Order ID,ord_id,INTEGER,identifier,order,Unique identifier for orders,Orders,
order_item_status,Order Item Status,oi_status,VARCHAR(10),status,order,Current status of the order item (Finish|Payed|Cancel),Order_Items,
order_item_details,Order Item Details,oi_details,VARCHAR(255),attribute,order,Additional order item information,Order_Items,
order_status,Order Status,ord_status,VARCHAR(10),status,order,Current status of the order (Shipped|Packing|On Road),Orders,
date_order_placed,Order Placed Date,ord_date,DATETIME,temporal,order,Date when the order was placed,Orders,
order_details,Order Details,ord_details,VARCHAR(255),attribute,order,Additional order information,Orders,
product_name,Product Name,prod_name,VARCHAR(80),attribute,product,Name of the product,Products,
product_details,Product Details,prod_details,VARCHAR(255),attribute,product,Additional product information,Products,
shipment_id,Shipment ID,ship_id,INTEGER,identifier,shipment,Unique identifier for shipments,Shipments,
shipment_tracking_number,Tracking Number,track_num,VARCHAR(80),identifier,shipment,Tracking number for shipment,Shipments,
shipment_date,Shipment Date,ship_date,DATETIME,temporal,shipment,Date when the shipment was sent,Shipments,
other_shipment_details,Shipment Details,ship_details,VARCHAR(255),attribute,shipment,Additional shipment information,Shipments,
================================================
FILE: example/config.yaml
================================================
organization: MyCompany
dialect: sqlite
bi_config_file: example/bi.yaml
python_executor: docker
# Visualization configuration
visualization_mode: llm
# Catalog store configuration
catalog_store:
store_type: file_system
data_path: ./example
# Data warehouse configuration
data_warehouse_config:
# sqlite from spider->tracking_orders dataset
uri: "sqlite:///example/tracking_orders.sqlite"
database_name: ""
# LLM configurations
# Use OpenAI LLM, replace YOUR_API_KEY_HERE with your actual API key
default_llm:
class: langchain_openai.ChatOpenAI
params:
api_key: YOUR_API_KEY_HERE
model: gpt-4.1
temperature: 0.01
max_tokens: 8192
embedding_model:
class: langchain_openai.OpenAIEmbeddings
params:
api_key: YOUR_API_KEY_HERE
model: text-embedding-3-large
chunk_size: 1024
# If you cannot access to OpenAI or other cloud LLM provider,
# uncomment the following lines instead to use Ollama local LLM
#default_llm:
# class: langchain_ollama.ChatOllama
# params:
# model: gpt-oss:20b
# temperature: 0.01
# num_predict: 8192
================================================
FILE: example/sql_example.yaml
================================================
'':
Customers: |
Q: Show me all customers with their names and details
A: SELECT customer_id, customer_name, customer_details
FROM Customers
ORDER BY customer_name
Invoices: |
Q: List all invoices from the last 30 days
A: SELECT invoice_number, invoice_date, invoice_details
FROM Invoices
WHERE invoice_date >= DATE('now', '-30 days')
ORDER BY invoice_date DESC
Order_Items: |
Q: Show me all items in order 123
A: SELECT oi.order_item_id, p.product_name, oi.order_item_status, oi.order_item_details
FROM Order_Items oi
JOIN Products p ON oi.product_id = p.product_id
WHERE oi.order_id = 123
Orders: |
Q: Find all pending orders with customer information
A: SELECT o.order_id, c.customer_name, o.order_status, o.date_order_placed
FROM Orders o
JOIN Customers c ON o.customer_id = c.customer_id
WHERE o.order_status = 'pending'
ORDER BY o.date_order_placed
Products: |
Q: Search for products containing 'laptop' in the name
A: SELECT product_id, product_name, product_details
FROM Products
WHERE product_name LIKE '%laptop%'
ORDER BY product_name
Shipment_Items: |
Q: Show which order items are in shipment 456
A: SELECT si.shipment_id, si.order_item_id, p.product_name
FROM Shipment_Items si
JOIN Order_Items oi ON si.order_item_id = oi.order_item_id
JOIN Products p ON oi.product_id = p.product_id
WHERE si.shipment_id = 456
Shipments: |
Q: Track all shipments for order 789
A: SELECT shipment_id, shipment_tracking_number, shipment_date, other_shipment_details
FROM Shipments
WHERE order_id = 789
ORDER BY shipment_date
================================================
FILE: example/table_columns.csv
================================================
db_name,table_name,column_name
,Customers,customer_id
,Customers,customer_name
,Customers,customer_details
,Invoices,invoice_number
,Invoices,invoice_date
,Invoices,invoice_details
,Order_Items,order_item_id
,Order_Items,product_id
,Order_Items,order_id
,Order_Items,order_item_status
,Order_Items,order_item_details
,Orders,order_id
,Orders,customer_id
,Orders,order_status
,Orders,date_order_placed
,Orders,order_details
,Products,product_id
,Products,product_name
,Products,product_details
,Shipment_Items,shipment_id
,Shipment_Items,order_item_id
,Shipments,shipment_id
,Shipments,order_id
,Shipments,invoice_number
,Shipments,shipment_tracking_number
,Shipments,shipment_date
,Shipments,other_shipment_details
================================================
FILE: example/table_info.yaml
================================================
? ''
: Customers:
description: 'Contains customer information including unique ID, name, and additional details'
selection_rule: 'Select when queries involve customer information, customer names, or need to join orders with customer data'
sql_rule: 'Use customer_id as primary key for joins. Always include customer_name when displaying customer information'
Invoices:
description: 'Stores invoice information with unique invoice numbers, dates, and details'
selection_rule: 'Select when queries involve billing, invoice tracking, or financial reporting'
sql_rule: 'Use invoice_number as primary key. Filter by invoice_date for temporal queries'
Order_Items:
description: 'Links products to orders with individual item status and details'
selection_rule: 'Select when queries need product details within orders or item-level status tracking'
sql_rule: 'Always join with Products table via product_id and Orders table via order_id for complete information'
Orders:
description: 'Main order table containing order status, placement date, and customer relationships'
selection_rule: 'Select when queries involve order status, order history, or customer order relationships'
sql_rule: 'Use order_id as primary key. Join with Customers via customer_id for customer information'
Products:
description: 'Product catalog containing product names, IDs, and detailed product information'
selection_rule: 'Select when queries involve product information, product searches, or inventory-related questions'
sql_rule: 'Use product_id as primary key. Use LIKE operator for product_name searches'
Shipment_Items:
description: 'Junction table linking shipments to specific order items'
selection_rule: 'Select when queries need to track which specific items are in which shipments'
sql_rule: 'Always join with both Shipments and Order_Items tables. No primary key - composite key of shipment_id and order_item_id'
Shipments:
description: 'Shipment tracking information including tracking numbers, dates, and shipment details'
selection_rule: 'Select when queries involve shipping, delivery tracking, or fulfillment information'
sql_rule: 'Use shipment_id as primary key. Join with Orders via order_id and Invoices via invoice_number for complete shipping context'
================================================
FILE: example/table_selection_example.csv
================================================
question,selected_tables
"Show me all customers","[""Customers""]"
"What orders were placed today?","[""Orders""]"
"List all products and their details","[""Products""]"
"Show me customer orders with their details","[""Customers"", ""Orders""]"
"What products are in each order?","[""Orders"", ""Order_Items"", ""Products""]"
"Show shipment tracking information","[""Shipments""]"
"Which items are in each shipment?","[""Shipments"", ""Shipment_Items"", ""Order_Items""]"
"Show order status and customer information","[""Orders"", ""Customers""]"
"What invoices were created this month?","[""Invoices""]"
"Show complete order fulfillment chain","[""Orders"", ""Order_Items"", ""Products"", ""Shipments"", ""Invoices""]"
================================================
FILE: openchatbi/__init__.py
================================================
"""OpenChatBI core module initialization."""
import os
from langgraph.graph.state import CompiledStateGraph
from openchatbi.config_loader import ConfigLoader
# Global configuration instance
config = ConfigLoader()
# Skip config loading during documentation build
if not os.environ.get("SPHINX_BUILD"):
config.load()
else:
config.set({})
def get_default_graph():
    """
    Build the synchronous mode of the agent graph using default catalog in config.

    Returns:
        CompiledStateGraph: Compiled agent graph ready for execution, or
        None during a Sphinx documentation build.
    """
    if os.environ.get("SPHINX_BUILD"):
        return None

    # Imported lazily so that merely importing the package stays cheap.
    from langgraph.checkpoint.memory import MemorySaver

    from openchatbi.agent_graph import build_agent_graph_sync
    from openchatbi.tool.memory import get_sync_memory_store

    saver = MemorySaver()
    store = get_sync_memory_store()
    return build_agent_graph_sync(config.get().catalog_store, checkpointer=saver, memory_store=store)
================================================
FILE: openchatbi/agent_graph.py
================================================
"""Main agent graph construction and execution logic."""
import datetime
import logging
import traceback
from collections.abc import Callable
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import StructuredTool
from langchain_openai.chat_models.base import BaseChatOpenAI
from langgraph.constants import START
from langgraph.errors import GraphInterrupt
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer, Send, interrupt
from pydantic import BaseModel, Field
from openchatbi import config
from openchatbi.catalog import CatalogStore
from openchatbi.constants import datetime_format
from openchatbi.context_config import get_context_config
from openchatbi.context_manager import ContextManager
from openchatbi.graph_state import AgentState, InputState, OutputState
from openchatbi.llm.llm import call_llm_chat_model_with_retry, get_llm
from openchatbi.prompts.system_prompt import get_agent_prompt_template
from openchatbi.text2sql.sql_graph import build_sql_graph
from openchatbi.tool.ask_human import AskHuman
from openchatbi.tool.mcp_tools import create_mcp_tools_sync, get_mcp_tools_async
from openchatbi.tool.memory import get_memory_tools
from openchatbi.tool.run_python_code import run_python_code
from openchatbi.tool.save_report import save_report
from openchatbi.tool.search_knowledge import search_knowledge, show_schema
from openchatbi.tool.timeseries_forecast import check_forecast_service_health, timeseries_forecast
from openchatbi.utils import log, recover_incomplete_tool_calls
logger = logging.getLogger(__name__)
def get_mcp_servers():
    """Get MCP servers from config, falling back to an empty list (e.g. in tests)."""
    try:
        servers = config.get().mcp_servers
    except ValueError:
        # Config not loaded (typical in unit tests) -- run without MCP servers.
        servers = []
    return servers
def ask_human(state: AgentState) -> dict[str, Any]:
    """Node function to ask human for additional information or clarification.

    Args:
        state (AgentState): The current graph state containing messages and context.

    Returns:
        dict: Updated state with human feedback as a tool message and user input.
    """
    # The last message is the AI turn carrying the single AskHuman tool call.
    call = state["messages"][-1].tool_calls[0]
    call_args = call["args"]
    # Pause the graph and wait for the user's answer (optionally with buttons).
    feedback = interrupt({"text": call_args["question"], "buttons": call_args.get("options", None)})
    return {
        "messages": [{"tool_call_id": call["id"], "type": "tool", "content": feedback}],
        "history_messages": [AIMessage(call_args["question"]), HumanMessage(feedback)],
        "user_input": feedback,
    }
class CallSQLGraphInput(BaseModel):
    """Input schema for the Text2SQL tool call.

    The Field descriptions below are surfaced to the LLM as part of the tool
    schema, so they double as prompt instructions -- do not edit casually.
    """

    # Why the agent decided to invoke Text2SQL (useful for tracing/debugging).
    reasoning: str = Field(
        description="Explanation of why Text2SQL tool is needed",
    )
    # Free-form context block the SQL subgraph receives as its input message.
    context: str = Field(
        description="""The full context pass to Text2SQL tool, make sure do not miss any potential information that related to user's question.
Following the format: History Conversation: (user and assistant history dialog)
Information: (the knowledge you retrival that is relevant, like metrics and dimensions)
User's latest question:""",
    )
# Description for SQL tools.
# This text is shown to the agent LLM as the tool description, so its wording
# directly influences when and how the tool gets called -- treat it as a prompt.
TEXT2SQL_TOOL_DESCRIPTION = """Text2SQL tool to generate and execute SQL query and build visualization DSL for UI
based on user's question and context.
Returns:
str: A formatted response containing SQL, data, and visualization status.
Important notes:
- If user want to change the visualization chart type or style, add the requirement in the question
- Make sure to provide question in English
"""
def _format_sql_response(sql_graph_response: dict) -> str:
"""Format SQL graph response into a standardized string format.
Args:
sql_graph_response: The response dictionary from the SQL graph
Returns:
str: Formatted response string
"""
sql = sql_graph_response.get("sql", "")
data = sql_graph_response.get("data", "")
visualization_dsl = sql_graph_response.get("visualization_dsl", {})
response_parts = []
if sql:
response_parts.append(f"SQL Query:\n```sql\n{sql}\n```")
if data:
response_parts.append(f"\nQuery Results (CSV format):\n```csv\n{data}\n```")
# Include visualization status
if visualization_dsl and "error" not in visualization_dsl:
chart_type = visualization_dsl.get("chart_type", "unknown")
response_parts.append(
f"\nVisualization Created: {chart_type} chart has been automatically generated and will be displayed in the UI."
)
elif visualization_dsl and "error" in visualization_dsl:
response_parts.append(f"\nVisualization Error: {visualization_dsl['error']}")
return "\n\n".join(response_parts) if response_parts else "No results returned."
def get_sql_tools(sql_graph: CompiledStateGraph, sync_mode: bool = False) -> Callable:
    """Create SQL generation tool from compiled SQL graph.

    Args:
        sql_graph (CompiledStateGraph): The compiled SQL generation subgraph.
        sync_mode (bool): Whether to create synchronous or asynchronous tools

    Returns:
        function: Tool function for SQL generation.
    """

    def call_sql_graph_sync(reasoning: str, context: str) -> str:
        """Sync node function for Text2SQL tool"""
        log(f"Call SQL graph (sync) with reasoning: {reasoning}, context: {context}")
        try:
            graph_result = sql_graph.invoke({"messages": context})
            return _format_sql_response(graph_result)
        except GraphInterrupt as interrupt_err:
            # Interrupts must propagate so the outer graph can pause correctly.
            log(f"Sql graph interrupted:\n{repr(interrupt_err)}")
            raise interrupt_err
        except Exception as err:
            log(f"Run sql graph error:\n{repr(err)}")
            traceback.print_exc()
            return "Error occurred when calling Text2SQL tool."

    async def call_sql_graph_async(reasoning: str, context: str) -> str:
        """Async node function for Text2SQL tool"""
        log(f"Call SQL graph (async) with reasoning: {reasoning}, context: {context}")
        try:
            graph_result = await sql_graph.ainvoke({"messages": context})
            return _format_sql_response(graph_result)
        except GraphInterrupt as interrupt_err:
            # Interrupts must propagate so the outer graph can pause correctly.
            log(f"Sql graph interrupted:\n{repr(interrupt_err)}")
            raise interrupt_err
        except Exception as err:
            log(f"Run sql graph error:\n{repr(err)}")
            traceback.print_exc()
            return "Error occurred when calling Text2SQL tool."

    # Both variants share the tool name, schema and description; only the
    # callable kind (func vs coroutine) differs.
    selected_func = call_sql_graph_sync if sync_mode else None
    selected_coro = None if sync_mode else call_sql_graph_async
    return StructuredTool.from_function(
        func=selected_func,
        coroutine=selected_coro,
        name="text2sql",
        description=TEXT2SQL_TOOL_DESCRIPTION,
        args_schema=CallSQLGraphInput,
        return_direct=False,
    )
def agent_llm_call(llm: BaseChatModel, tools: list, context_manager: ContextManager = None) -> Callable:
    """Create llm call function to generate reasoning and determine next node based on tool calls in LLM response.

    Args:
        llm (BaseChatModel): The LLM for agent decision-making.
        tools: List of tools.
        context_manager: Optional context manager for handling long conversations.

    Returns:
        function: function that processes state and determines next node.
    """
    # OpenAI models support strict tool calling
    if isinstance(llm, BaseChatOpenAI):
        llm_with_tools = llm.bind_tools(tools, strict=True)
    else:
        llm_with_tools = llm.bind_tools(tools)

    def _call_model(state: AgentState):
        # First, check and recover any incomplete tool calls
        recovery_ops = recover_incomplete_tool_calls(state)
        if recovery_ops:
            # Route back to this node so the LLM continues after the repair.
            return {"messages": recovery_ops, "agent_next_node": "llm_node"}
        messages = state["messages"]
        final_messages = []
        # Only the latest human turn is copied into the history channel here.
        if isinstance(messages[-1], HumanMessage):
            final_messages.append(messages[-1])
        # Apply context management if available (before processing)
        if context_manager:
            original_count = len(messages)
            # NOTE(review): manage_context_messages appears to mutate `messages`
            # in place (trimming/summarizing) -- confirm against its implementation.
            context_manager.manage_context_messages(messages)
            if len(messages) != original_count:
                logger.info(f"Context management: modified messages from {original_count} to {len(messages)}")
        # Inject the current timestamp into the agent system prompt template.
        system_prompt = get_agent_prompt_template().replace(
            "[time_field_placeholder]", datetime.datetime.now().strftime(datetime_format)
        )
        response = call_llm_chat_model_with_retry(
            llm_with_tools,
            ([SystemMessage(system_prompt)] + messages),
            streaming_tokens=True,
            bound_tools=tools,
            parallel_tool_call=True,
        )
        if isinstance(response, AIMessage):
            tool_calls = response.tool_calls
            print("Tool Call:", ", ".join(tool["name"] for tool in tool_calls))
            if tool_calls:
                # Group tool calls by type for parallel routing
                ask_human_calls = [call for call in tool_calls if call["name"] == "AskHuman"]
                normal_tool_calls = [call for call in tool_calls if call["name"] != "AskHuman"]
                # Create Send objects for parallel routing
                sends = []
                if ask_human_calls:
                    # Create message with only AskHuman calls
                    ask_human_msg = AIMessage(content=response.content, tool_calls=ask_human_calls)
                    sends.append(Send("ask_human", {"messages": [ask_human_msg]}))
                if normal_tool_calls:
                    # Create message with only normal tool calls
                    tool_msg = AIMessage(content=response.content, tool_calls=normal_tool_calls)
                    sends.append(Send("use_tool", {"messages": [tool_msg]}))
                return {"messages": [response], "history_messages": final_messages, "sends": sends}
            else:
                # No tool calls: the response content is the final answer for this turn.
                final_messages.append(AIMessage(response.content))
                return {
                    "messages": [response],
                    "final_answer": response.content,
                    "history_messages": final_messages,
                    "agent_next_node": END,
                }
        elif response is None:
            # The retry helper exhausted its attempts -- surface a friendly failure.
            return {
                "messages": [AIMessage("Sorry, the LLM service is currently unavailable.")],
                "history_messages": final_messages,
                "agent_next_node": END,
            }
        else:
            # Unexpected response type: pass it through and end the turn.
            return {"messages": [response], "history_messages": final_messages, "agent_next_node": END}

    return _call_model
def _build_graph_core(
    catalog: CatalogStore,
    sync_mode: bool,
    checkpointer: Checkpointer,
    memory_store: BaseStore,
    memory_tools: list[Callable] | None,
    mcp_tools: list,
    enable_context_management: bool = True,
    llm_provider: str | None = None,
) -> CompiledStateGraph:
    """Core graph building logic shared by both sync and async versions.

    Args:
        catalog: Catalog store containing schema information
        sync_mode: Whether to use synchronous mode for tools and operations
        checkpointer: The Checkpointer for state persistence
        memory_store: The BaseStore to use for long-term memory
        memory_tools: List of memory tools (manage_memory_tool, search_memory_tool)
        mcp_tools: Pre-initialized MCP tools
        enable_context_management: Whether to enable context management
        llm_provider: Optional LLM provider name; None selects the default LLM

    Returns:
        CompiledStateGraph: Compiled agent graph ready for execution
    """
    # Build the Text2SQL subgraph and wrap it as a tool for the agent LLM.
    sql_graph = build_sql_graph(catalog, checkpointer, memory_store, llm_provider=llm_provider)
    call_sql_graph_tool = get_sql_tools(sql_graph=sql_graph, sync_mode=sync_mode)
    # Use provided memory tools or create them
    if not memory_tools:
        memory_tools = get_memory_tools(get_llm(llm_provider), sync_mode=sync_mode, store=memory_store)
    log(str(mcp_tools))
    normal_tools = [
        search_knowledge,
        show_schema,
        call_sql_graph_tool,
        run_python_code,
        save_report,
    ]
    if memory_tools:
        normal_tools.extend(memory_tools)
    # Only expose the forecasting tool when its backing service reports healthy.
    if check_forecast_service_health():
        normal_tools.append(timeseries_forecast)
    else:
        logger.warning("Time series forecasting service is not healthy. Skipping timeseries_forecast tool.")
    normal_tools.extend(mcp_tools)
    # Initialize context manager if enabled
    context_manager = None
    if enable_context_management:
        context_manager = ContextManager(llm=get_llm(llm_provider), config=get_context_config())
    tool_node = ToolNode(normal_tools)
    # Define the agent graph
    graph = StateGraph(AgentState, input_schema=InputState, output_schema=OutputState)
    # Add nodes to the graph.
    # AskHuman is bound to the LLM as a tool but routed to its own node below,
    # not executed by the ToolNode.
    graph.add_node("llm_node", agent_llm_call(get_llm(llm_provider), normal_tools + [AskHuman], context_manager))
    graph.add_node("ask_human", ask_human)
    graph.add_node("use_tool", tool_node)
    # Add edges between nodes
    graph.add_edge(START, "llm_node")
    graph.add_edge("ask_human", "llm_node")
    graph.add_edge("use_tool", "llm_node")

    # Add conditional routing from llm node
    def route_tools(state: AgentState):
        # Only use sends if the last message came from the llm node (has tool_calls)
        last_message = state["messages"][-1] if state["messages"] else None
        if (
            last_message
            and isinstance(last_message, AIMessage)
            and last_message.tool_calls
            and "sends" in state
            and state["sends"]
        ):
            return state["sends"]  # Return Send objects for parallel execution
        elif "agent_next_node" in state:
            return state["agent_next_node"]  # Return single node name
        else:
            return END

    graph.add_conditional_edges(
        "llm_node",
        route_tools,
        # mapping of paths to node names (for single routing)
        {
            "llm_node": "llm_node",
            "ask_human": "ask_human",
            "use_tool": "use_tool",
            END: END,
        },
    )
    graph = graph.compile(name="agent_graph", checkpointer=checkpointer, store=memory_store)
    return graph
def build_agent_graph_sync(
    catalog: CatalogStore,
    checkpointer: Checkpointer = None,
    memory_store: BaseStore = None,
    enable_context_management: bool = True,
    llm_provider: str | None = None,
) -> CompiledStateGraph:
    """Build the main agent graph with all nodes and edges (sync version).

    Args:
        catalog: Catalog store containing schema information.
        checkpointer: The Checkpointer for state persistence (short memory). If None, no short memory.
        memory_store: The BaseStore to use for long-term memory. If None, will auto assign according to sync_mode.
        enable_context_management: Whether to enable context management for long conversations.
        llm_provider: Optional LLM provider name; None selects the default LLM.

    Returns:
        CompiledStateGraph: Compiled agent graph ready for execution.
    """
    # MCP tools must be materialized synchronously on this code path.
    sync_mcp_tools = create_mcp_tools_sync(get_mcp_servers())
    return _build_graph_core(
        catalog=catalog,
        sync_mode=True,
        checkpointer=checkpointer,
        memory_store=memory_store,
        memory_tools=None,  # sync version always builds its own memory tools
        mcp_tools=sync_mcp_tools,
        enable_context_management=enable_context_management,
        llm_provider=llm_provider,
    )
async def build_agent_graph_async(
    catalog: CatalogStore,
    checkpointer: Checkpointer = None,
    memory_store: BaseStore = None,
    memory_tools: list[Callable] = None,
    enable_context_management: bool = True,
    llm_provider: str | None = None,
) -> CompiledStateGraph:
    """Build the main agent graph with all nodes and edges (async version).

    Mirrors build_agent_graph_sync, but initializes MCP tools through the
    async API so it is safe to call from within a running event loop.

    Args:
        catalog: Catalog store containing schema information.
        checkpointer: The Checkpointer for state persistence (short memory). If None, no short memory.
        memory_store: The BaseStore to use for long-term memory. If None, will auto assign according to sync_mode.
        memory_tools: List of memory tools (manage_memory_tool, search_memory_tool). If None, creates async tools.
        enable_context_management: Whether to enable context management for long conversations.
        llm_provider: Optional LLM provider name; None selects the default LLM.

    Returns:
        CompiledStateGraph: Compiled agent graph ready for execution.
    """
    async_mcp_tools = await get_mcp_tools_async(get_mcp_servers())
    return _build_graph_core(
        catalog=catalog,
        sync_mode=False,
        checkpointer=checkpointer,
        memory_store=memory_store,
        memory_tools=memory_tools,
        mcp_tools=async_mcp_tools,
        enable_context_management=enable_context_management,
        llm_provider=llm_provider,
    )
================================================
FILE: openchatbi/catalog/__init__.py
================================================
"""Data catalog management module for OpenChatBI."""
from openchatbi.catalog.catalog_loader import (
DataCatalogLoader,
load_catalog_from_data_warehouse,
)
from openchatbi.catalog.catalog_store import CatalogStore
from openchatbi.catalog.factory import create_catalog_store
# Public API of the catalog package. ``create_catalog_store`` is imported above
# but was previously missing from ``__all__``; include it so star-imports and
# API docs expose the factory alongside the store and loader.
__all__ = [
    "CatalogStore",
    "DataCatalogLoader",
    "create_catalog_store",
    "load_catalog_from_data_warehouse",
]
================================================
FILE: openchatbi/catalog/catalog_loader.py
================================================
import logging
from typing import Any
from sqlalchemy import MetaData, inspect
from sqlalchemy.engine import Engine
from .catalog_store import CatalogStore
logger = logging.getLogger(__name__)
class DataCatalogLoader:
"""
The loader to load data catalog from data warehouse metadata and save to catalog store.
"""
def __init__(self, engine: Engine, include_tables: list[str] | None = None):
"""
Initialize catalog loader.
Args:
engine (Engine): SQLAlchemy engine instance
include_tables (Optional[List[str]]): List of table names to include, None for all
"""
self.engine = engine
self.include_tables = include_tables
self.metadata = MetaData()
self.inspector = inspect(engine)
def get_tables_and_columns(self) -> dict[str, list[dict[str, Any]]]:
"""
Extract table and column metadata including comments using SQLAlchemy inspector.
Returns:
Dict[str, List[Dict[str, Any]]]: Dictionary mapping table names to list of column information
"""
try:
tables_columns = {}
# Get all table names
table_names = self.inspector.get_table_names()
# Filter to specific tables if configured
if self.include_tables:
table_names = [name for name in table_names if name in self.include_tables]
logger.info(f"Found {len(table_names)} tables to process")
for table_name in table_names:
try:
# Get column information for the table
columns = self.inspector.get_columns(table_name)
column_list = []
for column in columns:
is_common_column = column not in ("id", "name", "type", "status")
column_info = {
"column_name": column["name"],
"display_name": "",
"alias": "",
"type": str(column["type"]),
"category": "",
"tag": "",
"description": column.get("comment", "") or "",
"dimension_table": "",
"default": str(column.get("default", "")) if column.get("default") is not None else "",
"is_common": is_common_column,
}
column_list.append(column_info)
tables_columns[table_name] = column_list
logger.debug(f"Processed table {table_name} with {len(column_list)} columns")
except Exception as e:
logger.error(f"Failed to process table {table_name}: {e}")
continue
logger.info(f"Successfully processed {len(tables_columns)} tables")
return tables_columns
except Exception as e:
logger.error(f"Failed to get tables and columns from data warehouse: {e}")
return {}
def get_table_indexes(self, table_name: str) -> list[dict[str, Any]]:
"""
Get index information for a specific table.
Args:
table_name (str): Name of the table
Returns:
List[Dict[str, Any]]: List of index information
"""
try:
indexes = self.inspector.get_indexes(table_name)
return indexes
except Exception as e:
logger.warning(f"Failed to get indexes for table {table_name}: {e}")
return []
def get_foreign_keys(self, table_name: str) -> list[dict[str, Any]]:
"""
Get foreign key information for a specific table.
Args:
table_name (str): Name of the table
Returns:
List[Dict[str, Any]]: List of foreign key information
"""
try:
foreign_keys = self.inspector.get_foreign_keys(table_name)
return foreign_keys
except Exception as e:
logger.warning(f"Failed to get foreign keys for table {table_name}: {e}")
return []
def save_to_catalog_store(
    self, catalog_store: CatalogStore, database_name: str | None = None, update: bool = False
) -> bool:
    """
    Extract warehouse metadata and save to catalog store.

    Args:
        catalog_store (CatalogStore): Target catalog store to load data to
        database_name (Optional[str]): Database name in catalog, defaults to 'default'
        update (bool): Update existing catalog store to sync with data warehouse.
            NOTE(review): this flag is currently never read by the body — confirm intent.

    Returns:
        bool: True if every table loaded successfully, False otherwise
    """
    try:
        if database_name is None:
            database_name = "default"

        # Pull the full table -> column metadata from the warehouse.
        tables_columns = self.get_tables_and_columns()
        if not tables_columns:
            logger.warning("No tables found in data warehouse")
            return True

        total_count = len(tables_columns)
        success_count = 0
        for table_name, columns in tables_columns.items():
            try:
                # Table comments are optional; swallow errors from engines
                # that do not support them.
                table_comment = ""
                try:
                    comment_result = self.inspector.get_table_comment(table_name)
                    if comment_result:
                        table_comment = comment_result.get("text", "")
                except Exception:
                    # Some databases don't support table comments
                    pass

                information = {"description": table_comment, "selection_rule": "", "sql_rule": ""}
                if catalog_store.save_table_information(table_name, information, columns, database_name):
                    success_count += 1
                    logger.info(f"Successfully loaded table: {database_name}.{table_name}")
                else:
                    logger.error(f"Failed to load table: {database_name}.{table_name}")

                # init null SQL examples
                catalog_store.save_table_sql_examples(
                    table_name, [{"question": "null", "answer": "null"}], database_name
                )
            except Exception as e:
                logger.error(f"Error loading table {table_name}: {e}")

        # init empty table selection examples
        catalog_store.save_table_selection_examples([("", [])])

        logger.info(f"Load completed: {success_count}/{total_count} tables loaded successfully")
        return success_count == total_count
    except Exception as e:
        logger.error(f"Failed to load data warehouse to catalog store: {e}")
        return False
def load_catalog_from_data_warehouse(catalog_store: CatalogStore) -> bool:
    """
    Load catalog data from data warehouse using SQLAlchemy based on data warehouse config (URI)
    Main entry point for catalog loading.

    Args:
        catalog_store (CatalogStore): Target catalog store

    Returns:
        bool: True if load was successful, False otherwise
    """
    # Bug fix: initialize up front so the except handler can always reference
    # it. Previously, if get_data_warehouse_config() raised, the handler's
    # f-string hit an unbound local and the NameError masked the real error.
    database_uri = None
    try:
        data_warehouse_config = catalog_store.get_data_warehouse_config()
        database_uri = data_warehouse_config.get("uri")
        include_tables = data_warehouse_config.get("include_tables")
        database_name = data_warehouse_config.get("database_name", "default")

        engine = catalog_store.get_sql_engine()
        loader = DataCatalogLoader(engine, include_tables)
        return loader.save_to_catalog_store(catalog_store, database_name)
    except Exception as e:
        logger.error(f"Failed to import catalog from data warehouse URI {database_uri}: {e}")
        return False
================================================
FILE: openchatbi/catalog/catalog_store.py
================================================
from abc import ABC, abstractmethod
from typing import Any
from sqlalchemy import Engine
class CatalogStore(ABC):
    """
    Abstract base class defining the storage interface for data catalog (database, table, column definitions, descriptions, and additional prompts).
    Common columns which have the same meaning across tables are stored centrally to avoid data duplication.

    Column attributes:
        - column_name: the name of the column
        - display_name: the display name of the column
        - type: the data type of the column
        - category: dimension or metric
        - description: the description of the column
        - is_common: is common column or not
    """

    @abstractmethod
    def get_data_warehouse_config(self) -> dict:
        """
        Get the data warehouse configuration

        Returns:
            dict: Data warehouse configuration
        """
        pass

    @abstractmethod
    def get_sql_engine(self) -> Engine:
        """
        Get the SQLAlchemy engine for the catalog

        Returns:
            Engine: SQLAlchemy engine
        """
        pass

    @abstractmethod
    def get_database_list(self) -> list[str]:
        """
        Get a list of all databases

        Returns:
            List[str]: List of database names
        """
        pass

    @abstractmethod
    def get_table_list(self, database: str | None = None) -> list[str]:
        """
        Get a list of all tables in the specified database; if database is None, return all tables

        Args:
            database (Optional[str]): Database name

        Returns:
            List[str]: List of table names
        """
        pass

    @abstractmethod
    def get_column_list(self, table: str | None = None, database: str | None = None) -> list[dict[str, Any]]:
        """
        Get all column information for the specified table; if table is None, return all common columns in the catalog

        Args:
            table (Optional[str]): Table name
            database (Optional[str]): Database name

        Returns:
            List[Dict[str, Any]]: List of column information, each column contains name, type, description, etc.
        """
        pass

    @abstractmethod
    def get_table_information(self, table: str, database: str | None = None) -> dict[str, Any]:
        """
        Get the information for the specified table

        Args:
            table (str): Table name
            database (Optional[str]): Database name

        Returns:
            Dict[str, Any]: Table information, including description text, selection rules, etc.
        """
        pass

    @abstractmethod
    def get_sql_examples(
        self, table: str | None = None, database: str | None = None
    ) -> list[tuple[str, str, list[str]]]:
        """
        Get SQL examples

        Args:
            table (Optional[str]): Table name
            database (Optional[str]): Database name

        Returns:
            List[Tuple[str, str, List[str]]]: List of SQL examples, each example is a Tuple-3 as (question, SQL, full_table_names)
        """
        pass

    @abstractmethod
    def get_table_selection_examples(self) -> list[tuple[str, list[str]]]:
        """
        Get table selection examples

        Returns:
            List[Tuple[str, List[str]]]: List of table selection examples, each example is a Tuple-2 as (question, selected tables)
        """
        pass

    @abstractmethod
    def save_table_information(
        self,
        table: str,
        information: dict[str, Any],
        columns: list[dict[str, Any]],
        database: str | None = None,
        update_existing: bool = False,
    ) -> bool:
        """
        Save the information and columns for a table

        Args:
            table (str): Table name
            information (Dict[str, Any]): Table information
            columns (List[Dict[str, Any]]): List of column information, each column dict contains at least
                column_name, type, category, description
            database (Optional[str]): Database name
            update_existing (bool): Update existing table and column information

        Returns:
            bool: Whether the save was successful
        """
        pass

    @abstractmethod
    def save_table_sql_examples(self, table: str, examples: list[dict[str, str]], database: str | None = None) -> bool:
        """
        Save SQL examples for a table

        Args:
            table (str): Table name
            examples (List[Dict[str, str]]): List of SQL examples
            database (Optional[str]): Database name

        Returns:
            bool: Whether the save was successful
        """
        pass

    @abstractmethod
    def save_table_selection_examples(self, examples: list[tuple[str, list[str]]]) -> bool:
        """
        Save table selection examples

        Args:
            examples (List[Tuple[str, List[str]]]): List of table selection examples

        Returns:
            bool: Whether the save was successful
        """
        pass

    @abstractmethod
    def check_exists(self) -> bool:
        """
        Check if the catalog store has existing data/content

        Returns:
            bool: True if catalog store has existing data, False if empty or missing essential files
        """
        pass
def split_db_table_name(table: str, database: str | None = None) -> tuple[str, str, str]:
"""
Split full table name into db name and table name
Args:
table (str): if database is None, should be full table name like `db.table`, otherwise should be only table name
database (Optional[str]): Database name
Returns:
Tuple[str, str, str]: full_table_name, db_name, table_name
"""
full_table_name = table
if database is not None and "." not in table:
full_table_name = f"{database}.{table}"
if "." in full_table_name:
db_name, table_name = full_table_name.rsplit(".", 1)
else:
db_name = ""
table_name = full_table_name
return full_table_name, db_name, table_name
================================================
FILE: openchatbi/catalog/factory.py
================================================
import logging
import os
from openchatbi.catalog.catalog_loader import load_catalog_from_data_warehouse
from openchatbi.catalog.catalog_store import CatalogStore
from openchatbi.catalog.store.file_system import FileSystemCatalogStore
logger = logging.getLogger(__name__)
# Factory function for creating CatalogStore instances
def create_catalog_store(
    store_type: str, auto_load: bool = True, data_warehouse_config: dict | None = None, **kwargs
) -> CatalogStore:
    """
    Create a CatalogStore instance

    Args:
        store_type (str): Storage type, supports 'file_system'
        auto_load (bool): Whether to autoload from database if catalog files don't exist
        data_warehouse_config (Optional[dict]): Data warehouse configuration dictionary
        **kwargs: Other parameters (e.g. data_path for 'file_system')

    Returns:
        CatalogStore: CatalogStore instance

    Raises:
        ValueError: If the storage type is not supported
    """
    if store_type == "file_system":
        data_path = kwargs.get("data_path", "data")
        # Convert relative path to absolute path rooted at the package parent.
        # Bug fix: os.path.isabs() is portable; startswith("/") missed
        # Windows-style absolute paths like "C:\\data".
        if not os.path.isabs(data_path):
            data_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), data_path)
        catalog_store = FileSystemCatalogStore(data_path, data_warehouse_config)

        # Check if autoload is enabled and if catalog files are missing
        if auto_load:
            _auto_load_catalog_if_needed(catalog_store)

        return catalog_store
    else:
        raise ValueError(f"Unsupported storage type: {store_type}")
def _auto_load_catalog_if_needed(catalog_store: CatalogStore) -> None:
    """
    Autoload the catalog from the data warehouse when catalog files are missing or empty.

    Args:
        catalog_store (CatalogStore): The catalog store instance

    Raises:
        Exception: If the autoload attempt fails.
    """
    # Nothing to do when the store already holds catalog data.
    if catalog_store.check_exists():
        return

    logger.info("Catalog files missing or empty, attempting to load from data warehouse...")
    try:
        warehouse_config = catalog_store.get_data_warehouse_config()
        if not warehouse_config:
            logger.warning("No data warehouse configuration found, skipping autoload")
            return
        if not warehouse_config.get("uri"):
            logger.warning("No data warehouse URI found in configuration, skipping autoload")
            return

        # load catalog from data warehouse
        if load_catalog_from_data_warehouse(catalog_store):
            logger.info("Successfully loaded catalog from data warehouse")
        else:
            logger.error("Failed to load catalog from data warehouse")
            # Intentionally caught by the handler below so the failure is
            # logged once more and re-raised with the cause chained.
            raise Exception("Failed to load catalog from data warehouse")
    except Exception as e:
        logger.warning(f"Autoload from data warehouse failed: {e}")
        raise Exception("Failed to load catalog from data warehouse") from e
================================================
FILE: openchatbi/catalog/helper.py
================================================
from typing import Any
import requests
from sqlalchemy import Engine, create_engine
from openchatbi.catalog.token_service import apply_token_for_user
from openchatbi.utils import log
def get_requests_session(token: str, header_extra_params: dict) -> requests.Session:
    """Create HTTP session with bearer token authentication."""
    merged_headers = {"Authorization": f"Bearer {token}"}
    if header_extra_params:
        merged_headers.update(header_extra_params)
    session = requests.Session()
    session.headers.update(merged_headers)
    return session
def create_sqlalchemy_engine_instance(data_warehouse_config: dict[str, Any]) -> Engine:
    """
    Create SQLAlchemy engine instance from data warehouse config

    Args:
        data_warehouse_config: Config dict with 'uri' and optional 'token_service'
            (plus 'user_name', 'password', 'header_extra_params' for Presto auth)

    Returns:
        Configured SQLAlchemy engine
    """
    database_uri = data_warehouse_config.get("uri")
    engine_args = {"echo": True}

    # Handle Presto authentication
    if "presto" in database_uri and "token_service" in data_warehouse_config:
        token_service = data_warehouse_config.get("token_service")
        user_name = data_warehouse_config.get("user_name")
        password = data_warehouse_config.get("password")
        header_extra_params = data_warehouse_config.get("header_extra_params", {})

        token = apply_token_for_user(token_service, user_name, password)
        # Security fix: never log the raw credential; log only a masked suffix.
        masked_token = f"***{token[-4:]}" if token else "<empty>"
        log(f"Applied presto token: {masked_token} for user: {user_name}")
        engine_args["connect_args"] = {
            "protocol": "https",
            "requests_session": get_requests_session(token, header_extra_params),
        }
        database_uri = database_uri.format(user_name=user_name)

    engine = create_engine(database_uri, **engine_args)
    return engine
================================================
FILE: openchatbi/catalog/retrival_helper.py
================================================
"""Helper functions for building column retrieval systems."""
from rank_bm25 import BM25Okapi
from openchatbi.llm.llm import get_embedding_model
from openchatbi.text_segmenter import _segmenter
from openchatbi.utils import create_vector_db, log
def get_columns_metadata(catalog):
    """Extract column metadata for indexing.

    Args:
        catalog: Catalog store instance.

    Returns:
        tuple: (columns, col_dict, column_tokens, embedding_keys)
    """
    columns = catalog.get_column_list()
    col_dict = {}
    column_tokens = []
    embedding_keys = []
    for column in columns:
        name = column["column_name"]
        col_dict[name] = column
        # Concatenate every searchable text field for BM25 tokenization.
        searchable_text = " ".join(
            [
                column.get("column_name", ""),
                column.get("display_name", ""),
                column.get("alias", ""),
                column.get("tag", ""),
                column.get("description", ""),
            ]
        )
        column_tokens.append([tok for tok in _segmenter.cut(searchable_text) if tok not in ("_", " ")])
        embedding_keys.append(f"{name}: {column['display_name']}")
    return columns, col_dict, column_tokens, embedding_keys
def build_column_tables_mapping(catalog):
    """Build a mapping of column names to the list of tables that contain them."""
    mapping = {}
    for table_name in catalog.get_table_list():
        for column in catalog.get_column_list(table_name):
            mapping.setdefault(column["column_name"], []).append(table_name)
    return mapping
def build_columns_retriever(catalog, vector_db_path: str = None):
    """Build BM25 and vector retrievers for columns.

    Args:
        catalog: Catalog store instance.
        vector_db_path: Path to the vector database file.

    Returns:
        tuple: (bm25, vector_db, columns, col_dict)
    """
    columns, col_dict, column_tokens, embedding_keys = get_columns_metadata(catalog)

    # Lexical index over the tokenized column metadata.
    bm25_index = BM25Okapi(column_tokens)

    # Semantic index over "name: display name" keys, cosine distance.
    log("Building vector database for columns...")
    vector_store = create_vector_db(
        embedding_keys,
        get_embedding_model(),
        metadatas=columns,
        collection_name="columns",
        collection_metadata={"hnsw:space": "cosine"},
        chroma_db_path=vector_db_path,
    )
    return bm25_index, vector_store, columns, col_dict
================================================
FILE: openchatbi/catalog/schema_retrival.py
================================================
"""Schema and column retrieval functionality for finding relevant database structures."""
import os
import re
import Levenshtein
from openchatbi import config
from openchatbi.catalog.retrival_helper import build_column_tables_mapping, build_columns_retriever
from openchatbi.text_segmenter import _segmenter
from openchatbi.utils import log
# Skip build during documentation build
if not os.environ.get("SPHINX_BUILD"):
    try:
        _catalog_store = config.get().catalog_store
    except ValueError:
        # NOTE(review): presumably config.get() raises ValueError when the
        # configuration has not been initialized — fall back to no store.
        _catalog_store = None
else:
    _catalog_store = None

# Module-level retrieval indexes, built once at import time when a catalog
# store is configured; otherwise left as inert placeholders so the module
# still imports cleanly (e.g. under Sphinx or before config load).
if _catalog_store:
    bm25, vector_db, columns, col_dict = build_columns_retriever(_catalog_store, config.get().vector_db_path)
    column_tables_mapping = build_column_tables_mapping(_catalog_store)
else:
    bm25, vector_db, columns, col_dict = None, None, [], {}
    column_tables_mapping = {}
def column_retrieval(query, db, k=10, threshold=0.5, filter=None):
    """Retrieves relevant columns based on a similarity search.

    Args:
        query (str): The query string to search for.
        db: The vector database to search in.
        k (int, optional): The number of top results to return. Defaults to 10.
        threshold (float, optional): The similarity threshold for filtering results. Defaults to 0.5.
        filter (dict, optional): A filter to apply to the search. Defaults to None.

    Returns:
        list: List of relevant column names.
    """
    log(f"Get the top relevant columns for query: {query}")
    scored_keys = db.similarity_search_with_score(query, k=k, filter=filter)
    # Keep only hits below the distance threshold (lower score = closer match).
    column_names = []
    for key, score in scored_keys:
        if score < threshold:
            column_names.append(key.metadata["column_name"])
    log(f"Filtered relevant columns: {column_names}")
    return column_names
def merge_list(list1, list2):
    """Return the deduplicated union of two lists (element order not guaranteed)."""
    return list(set(list1) | set(list2))
def edit_distance_score(key1, key2):
    """Calculate normalized edit distance score between two strings.

    Returns:
        float: Score between 0 (identical) and 1 (completely different);
        1 when both strings are empty.
    """
    longest = max(len(key1), len(key2))
    if longest == 0:
        return 1
    return Levenshtein.distance(key1, key2) / longest
def edit_distance_search(keywords_list, top_k=10, threshold=0.5):
    """Searches for columns using edit distance similarity.

    Args:
        keywords_list (list): List of keywords to search for.
        top_k (int, optional): The number of top results to return per keyword. Defaults to 10.
        threshold (float, optional): The maximum edit distance score to consider. Defaults to 0.5.

    Returns:
        list: List of relevant column names (deduplicated, unordered).
    """
    # Normalize keywords by lowercasing and stripping common id/name suffixes.
    keys = {re.sub(r"(_id|_name| id| name)$", "", key.lower()) for key in keywords_list}
    matched_columns = set()
    for key in keys:
        candidate_scores = {}
        for column_name, row in col_dict.items():
            column_name_score = edit_distance_score(
                key, re.sub(r"(_id|_name| id| name)$", "", row.get("column_name", ""))
            )
            display_score = edit_distance_score(
                key, re.sub(r"(_id|_name| id| name)$", "", row.get("display_name", "").lower())
            )
            if column_name_score < threshold or display_score < threshold:
                candidate_scores[column_name] = min(column_name_score, display_score)
        # Bug fix: lower edit distance means MORE similar, so keep the top_k
        # smallest scores. The original sorted with reverse=True, which kept
        # the least similar candidates that slipped under the threshold.
        best_matches = sorted(candidate_scores.items(), key=lambda item: item[1])[:top_k]
        matched_columns.update(name for name, _ in best_matches)
    return list(matched_columns)
def bm25_search(query_list, top_k=5, score_threshold=0.5):
    """Performs a BM25 search on columns based on the query.

    Args:
        query_list (list): List of query terms.
        top_k (int, optional): The number of top results to return. Defaults to 5.
        score_threshold (float, optional): The minimum BM25 score to consider. Defaults to 0.5.

    Returns:
        list: List of relevant column names.
    """
    tokens = [tok for tok in _segmenter.cut(" ".join(query_list)) if tok not in ("_", " ")]
    scores = bm25.get_scores(tokens)
    # Rank all columns by score, slice the best top_k, then drop weak hits.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return [
        columns[idx]["column_name"]
        for idx, score in ranked[:top_k]
        if not (score_threshold and score < score_threshold)
    ]
def get_relevant_columns(keywords_list, dimensions, metrics):
    """Get the most relevant columns for given keywords, dimensions, and metrics.

    Combines multiple retrieval methods (BM25, edit distance, vector
    similarity) and merges the hits into a single deduplicated list.

    Args:
        keywords_list (list): General keywords to search for.
        dimensions (list): Dimension-specific keywords.
        metrics (list): Metric-specific keywords.

    Returns:
        list: Relevant column names.
    """
    # 1. BM25 search for general keywords
    relevant = bm25_search(keywords_list, top_k=len(keywords_list) * 4)

    # 2. Edit distance search for exact matches across every term
    all_terms = keywords_list + dimensions + metrics
    relevant = merge_list(relevant, edit_distance_search(all_terms, top_k=len(all_terms), threshold=0.3))

    # 3. Vector similarity search for dimensions
    if dimensions:
        relevant = merge_list(
            relevant, column_retrieval(" ".join(dimensions), vector_db, k=10, filter={"category": "dimension"})
        )

    # 4. Vector similarity search for metrics (slightly looser threshold)
    if metrics:
        relevant = merge_list(
            relevant,
            column_retrieval(" ".join(metrics), vector_db, k=10, threshold=0.55, filter={"category": "metric"}),
        )

    log(f"Relevant columns: {relevant}")
    return relevant
================================================
FILE: openchatbi/catalog/store/__init__.py
================================================
"""Catalog store implementations."""
from .file_system import FileSystemCatalogStore
================================================
FILE: openchatbi/catalog/store/file_system.py
================================================
"""File system-based catalog store implementation."""
import csv
import logging
import os
import re
import traceback
from typing import Any
import yaml
from sqlalchemy import Engine
from ..catalog_store import CatalogStore, split_db_table_name
from ..helper import create_sqlalchemy_engine_instance
logger = logging.getLogger(__name__)
class FileSystemCatalogStore(CatalogStore):
    """File system-based data catalog storage implementation.

    Stores catalog data in CSV and YAML files on the local filesystem.
    """

    # Root directory holding all catalog files.
    data_path: str
    # Per-file paths derived from data_path in __init__.
    table_info_file: str
    sql_example_file: str
    table_selection_example_file: str
    table_columns_file: str
    common_columns_file: str
    table_spec_columns_file: str
    # Lazily-populated caches; None means "not loaded yet".
    _table_info_cache: dict | None
    _table_columns_cache: dict | None
    _common_columns_cache: dict | None
    _table_spec_columns_cache: dict | None
    _sql_example_cache: dict | None
    _table_selection_example_cache: dict | None
    # Warehouse connection settings and the engine built from them.
    _data_warehouse_config: dict
    _sql_engine: Engine
def __init__(self, data_path: str, data_warehouse_config: dict):
    """Initialize filesystem catalog store.

    Args:
        data_path (str): Directory absolute path for storing catalog files.
        data_warehouse_config (dict): Data warehouse configuration dictionary with keys:
            - uri (str): Database connection URI
            - include_tables (Optional[List[str]]): List of tables to include, if None include all
            - database_name (Optional[str]): Database name to use in catalog

    Raises:
        ValueError: If data_path is empty/not a string, or data_warehouse_config is not a dict.
        RuntimeError: If the data directory cannot be created.
    """
    if not isinstance(data_path, str) or not data_path.strip():
        raise ValueError("data_path must be a non-empty string")
    if data_warehouse_config is None:
        data_warehouse_config = {}
    elif not isinstance(data_warehouse_config, dict):
        raise ValueError("data_warehouse_config must be a dictionary")

    # Bug fix: build every file path from the stripped path so they stay
    # consistent with self.data_path (the original joined the raw argument,
    # so a path with stray whitespace produced files outside the directory
    # that was actually created below).
    self.data_path = data_path.strip()
    self.table_info_file = os.path.join(self.data_path, "table_info.yaml")
    self.sql_example_file = os.path.join(self.data_path, "sql_example.yaml")
    self.table_selection_example_file = os.path.join(self.data_path, "table_selection_example.csv")
    self.table_columns_file = os.path.join(self.data_path, "table_columns.csv")
    self.common_columns_file = os.path.join(self.data_path, "common_columns.csv")
    self.table_spec_columns_file = os.path.join(self.data_path, "table_spec_columns.csv")

    # Ensure directory exists with proper error handling
    # (PermissionError is a subclass of OSError, so OSError covers both).
    try:
        os.makedirs(self.data_path, exist_ok=True)
    except OSError as e:
        raise RuntimeError(f"Failed to create data directory '{self.data_path}': {e}") from e

    # Initialize cache
    self._table_info_cache = None
    self._table_columns_cache = None
    self._common_columns_cache = None
    self._table_spec_columns_cache = None
    self._sql_example_cache = None
    self._table_selection_example_cache = None

    self._data_warehouse_config = data_warehouse_config
    try:
        self._sql_engine = create_sqlalchemy_engine_instance(data_warehouse_config)
    except Exception as e:
        logger.warning(f"Failed to create SQL engine: {e}. Some catalog operations may not work.")
        self._sql_engine = None
def _clear_cache(self) -> None:
    """Reset every cached view so subsequent reads reload from disk."""
    for cache_attr in (
        "_table_info_cache",
        "_table_columns_cache",
        "_common_columns_cache",
        "_table_spec_columns_cache",
        "_sql_example_cache",
        "_table_selection_example_cache",
    ):
        setattr(self, cache_attr, None)
    logger.debug("Cleared all caches")
def get_data_warehouse_config(self) -> dict:
    """Return the data warehouse configuration dict supplied at construction time."""
    return self._data_warehouse_config
def get_sql_engine(self) -> Engine:
    """Return the SQLAlchemy engine created from the warehouse config.

    Raises:
        RuntimeError: If engine creation failed during initialization.
    """
    engine = self._sql_engine
    if engine is None:
        raise RuntimeError("SQL engine is not available. Check data warehouse configuration.")
    return engine
def _validate_table_name(self, table: str) -> bool:
"""
Validate table name
Args:
table (str): Table name
Returns:
bool: Whether the table name is valid
Raises:
ValueError: If table name is invalid
"""
if not table or not isinstance(table, str):
raise ValueError("Table name must be a non-empty string")
# Check for invalid characters (allow dots for db.table format)
invalid_chars = ["/", "\\", "*", "?", "<", ">", "|", '"', "'"]
if any(char in table for char in invalid_chars):
raise ValueError(f"Table name contains invalid characters: {table}")
return True
def _validate_column_data(self, columns: list[dict[str, Any]]) -> bool:
"""
Validate column data format
Args:
columns (List[Dict[str, Any]]): List of column information
Returns:
bool: Whether the column data is valid
Raises:
ValueError: If column data is invalid
"""
if not isinstance(columns, list):
raise ValueError("Columns must be a list")
required_fields = {"column_name", "type"}
for i, column in enumerate(columns):
if not isinstance(column, dict):
raise ValueError(f"Column {i} must be a dictionary")
# Check required fields
missing_fields = required_fields - set(column.keys())
if missing_fields:
raise ValueError(f"Column {i} missing required fields: {missing_fields}")
# Validate column_name
column_name = column.get("column_name")
if not isinstance(column_name, str) or not column_name.strip():
raise ValueError(f"Column {i}: column_name must be a non-empty string")
# Validate type
column_type = column.get("type")
if not isinstance(column_type, str) or not column_type.strip():
raise ValueError(f"Column {i}: type must be a non-empty string")
return True
def _validate_table_information(self, information: dict[str, Any]) -> bool:
"""
Validate table information format
Args:
information (Dict[str, Any]): Table information
Returns:
bool: Whether the table information is valid
Raises:
ValueError: If table information is invalid
"""
if not isinstance(information, dict):
raise ValueError("Table information must be a dictionary")
# Validate optional string fields
string_fields = ["description", "selection_rule"]
for field in string_fields:
if field in information:
value = information[field]
if value is not None and not isinstance(value, str):
raise ValueError(f"Table information field '{field}' must be a string or None")
return True
def _validate_sql_examples(self, examples: list[dict[str, str]]) -> bool:
"""
Validate SQL examples format
Args:
examples (List[Dict[str, str]]): List of SQL examples
Returns:
bool: Whether the SQL examples are valid
Raises:
ValueError: If SQL examples are invalid
"""
if not isinstance(examples, list):
raise ValueError("Examples must be a list")
required_fields = {"question", "answer"}
for i, example in enumerate(examples):
if not isinstance(example, dict):
raise ValueError(f"Example {i} must be a dictionary")
# Check required fields
missing_fields = required_fields - set(example.keys())
if missing_fields:
raise ValueError(f"Example {i} missing required fields: {missing_fields}")
# Validate fields are non-empty strings
for field in required_fields:
value = example.get(field)
if not isinstance(value, str) or not value.strip():
raise ValueError(f"Example {i}: {field} must be a non-empty string")
return True
@staticmethod
def _load_yaml_file(file_path: str) -> dict:
    """
    Load a YAML file.

    Args:
        file_path (str): File path

    Returns:
        Dict: Parsed YAML content; {} when the file is missing, empty, or unreadable
    """
    if not os.path.exists(file_path):
        logger.debug(f"YAML file does not exist: {file_path}")
        return {}
    try:
        with open(file_path, encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
        logger.debug(f"Successfully loaded YAML file: {file_path}")
        return data
    except Exception as e:
        logger.error(f"Failed to load YAML file {file_path}: {e}")
        # Bug fix: format_exc() logs the active exception's traceback;
        # format_stack() only logged the current call stack.
        logger.error(traceback.format_exc())
        return {}
@staticmethod
def _load_csv_file(file_path: str) -> list[dict[str, str]]:
    """
    Load a CSV file.

    Args:
        file_path (str): File path

    Returns:
        List[Dict[str, str]]: Rows as dictionaries; [] when the file is
        missing or cannot be parsed
    """
    if not os.path.exists(file_path):
        logger.debug(f"CSV file does not exist: {file_path}")
        return []
    try:
        with open(file_path, encoding="utf-8") as f:
            result = list(csv.DictReader(f))
        logger.debug(f"Successfully loaded CSV file: {file_path} with {len(result)} rows")
        return result
    except Exception as e:
        logger.error(f"Failed to load CSV file {file_path}: {e}")
        # Bug fix: log the exception traceback (format_exc), not the
        # current call stack (format_stack).
        logger.error(traceback.format_exc())
        return []
@staticmethod
def _save_yaml_file(file_path: str, data: dict) -> bool:
    """
    Save a YAML file.

    Args:
        file_path (str): File path
        data (Dict): Data to save

    Returns:
        bool: Whether the save was successful
    """
    try:
        with open(file_path, "w", encoding="utf-8") as f:
            yaml.dump(data, f, default_flow_style=False, allow_unicode=True)
        return True
    except Exception as e:
        logger.error(f"Failed to save YAML file {file_path}: {e}")
        # Bug fix: format_exc() captures the exception traceback;
        # format_stack() only captured the current call stack.
        logger.error(traceback.format_exc())
        return False
@staticmethod
def _save_csv_file(file_path: str, data: list[dict[str, str]], headers: list[str] = None) -> bool:
    """
    Save a CSV file.

    Args:
        file_path (str): File path
        data (List[Dict[str, str]]): List of rows as dictionaries
        headers (List[str]): Header names in order; any keys present in the
            data but not listed are appended. When None, a sorted union of
            all row keys is used.

    Returns:
        bool: Whether the save was successful
    """
    try:
        if not data:
            return True

        # Union of keys across all rows so no column is dropped.
        all_headers = set()
        for row in data:
            all_headers.update(row.keys())

        if headers is None:
            # Bug fix: the original passed fieldnames=None straight to
            # DictWriter, which always failed (after truncating the file).
            fieldnames = sorted(all_headers)
        else:
            # Bug fix: copy before extending — the original appended missing
            # keys to the caller's list in place.
            fieldnames = list(headers)
            for key in all_headers:
                if key not in fieldnames:
                    fieldnames.append(key)

        with open(file_path, "w", encoding="utf-8", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(data)
        return True
    except Exception as e:
        logger.error(f"Failed to save CSV file {file_path}: {e}")
        # Bug fix: log the exception traceback, not the current call stack.
        logger.error(traceback.format_exc())
        return False
def _load_tables(self) -> dict[str, list[str]]:
    """Load the column names of each table from table_columns.csv.

    Returns:
        Dict[str, List[str]]: Column names keyed by full table name ("db.table").
    """
    rows = self._load_csv_file(self.table_columns_file)
    required = ("db_name", "table_name", "column_name")
    table_dict: dict[str, list[str]] = {}
    for row in rows:
        # Skip rows missing any of the identifying fields.
        if all(field in row for field in required):
            full_name = f"{row['db_name']}.{row['table_name']}"
            table_dict.setdefault(full_name, []).append(row["column_name"])
    return table_dict
def _load_common_columns(self) -> dict[str, dict[str, Any]]:
    """Load common column definitions from common_columns.csv, keyed by column name."""
    column_dict = {}
    for row in self._load_csv_file(self.common_columns_file):
        # Only keep rows with both a name and a type.
        if row.get("column_name") and row.get("type"):
            # Drop artifact entries whose key is the empty string.
            column_dict[row["column_name"]] = {key: value for key, value in row.items() if key != ""}
    return column_dict
def _load_table_spec_columns(self) -> dict[str, dict[str, Any]]:
    """
    Load info of table-specific columns.

    Returns:
        Dict[str, Dict[str, Any]]: Table-specific column information,
        keyed by "full_table_name:column_name"
    """
    column_dict = {}
    for row in self._load_csv_file(self.table_spec_columns_file):
        # Require the identifying fields and a non-empty column name.
        if "db_name" in row and "table_name" in row and "column_name" in row and row["column_name"]:
            full_table_name = f"{row['db_name']}.{row['table_name']}"
            # Drop artifact entries whose key is the empty string.
            info = {key: value for key, value in row.items() if key != ""}
            column_dict[f"{full_table_name}:{row['column_name']}"] = info
    return column_dict
def _parse_example_text(self, example_text: str) -> list[tuple[str, str]]:
"""
Parse example text, format is Q: ... A: ...
Args:
example_text (str): Example text
Returns:
List[Tuple[str, str]]: List of parsed question-answer pairs
"""
examples = []
lines = example_text.strip().split("\n")
question = ""
answer = ""
current_type = None
for line in lines:
if line.startswith("Q:"):
# If there is already a complete question-answer pair, add it to the results
if question and answer:
examples.append((question.strip(), answer.strip()))
question = ""
answer = ""
question = line[2:]
current_type = "Q"
elif line.startswith("A:"):
answer = line[2:]
current_type = "A"
else:
# Continue adding to the current type
if current_type == "Q":
question += "\n" + line
elif current_type == "A":
answer += "\n" + line
# Add the last question-answer pair
if question and answer:
examples.append((question.strip(), answer.strip()))
return examples
def get_database_list(self) -> list[str]:
    """Return the distinct database names present in the catalog."""
    names = {split_db_table_name(table)[1] for table in self._get_all_table_schema()}
    return list(names)
def _get_all_table_schema(self) -> dict[str, list[str]]:
"""
Get all tables schema (columns of table)
Returns:
Dict[str, List[str]]: Tables schema (columns) dict, keyed by table name
"""
if self._table_columns_cache is None:
self._table_columns_cache = self._load_tables()
# Return a deep copy to prevent external modifications
return {k: v.copy() for k, v in self._table_columns_cache.items()}
def get_table_list(self, database: str | None = None) -> list[str]:
tables = self._get_all_table_schema()
if database is None:
return list(tables.keys())
# Filter by database
filtered_tables = []
for full_table_name in tables.keys():
_, db_name, table_name = split_db_table_name(full_table_name)
if db_name == database:
filtered_tables.append(full_table_name)
return filtered_tables
def _get_common_columns(self) -> dict[str, dict[str, Any]]:
"""
Get information of all common columns
Returns:
Dict[str, Dict[str, Any]]: Dictionary of columns information, keyed by column name
"""
if self._common_columns_cache is None:
self._common_columns_cache = self._load_common_columns()
# Return a deep copy to prevent external modifications
return {k: v.copy() for k, v in self._common_columns_cache.items()}
def _get_table_spec_columns(self) -> dict[str, dict[str, Any]]:
"""
Get information of all table specific columns
Returns:
Dict[str, Dict[str, Any]]: Dictionary of table specific columns information, keyed by "full_table_name:column_name"
"""
if self._table_spec_columns_cache is None:
self._table_spec_columns_cache = self._load_table_spec_columns()
# Return a deep copy to prevent external modifications
return {k: v.copy() for k, v in self._table_spec_columns_cache.items()}
def get_column_list(self, table: str | None = None, database: str | None = None) -> list[dict[str, Any]]:
_common_columns = self._get_common_columns()
if table is None:
return list(_common_columns.values())
# Get the full table name
full_table_name, db_name, table_name = split_db_table_name(table, database)
# Filter table columns
tables_dict = self._get_all_table_schema()
if full_table_name not in tables_dict:
return []
table_columns = tables_dict[full_table_name]
# If no columns found, return empty list
if not table_columns:
return []
# Filter and return column details
result = []
_table_spec_columns = self._get_table_spec_columns()
for column in table_columns:
# check if the column is table specific
key = f"{full_table_name}:{column}"
if key in _table_spec_columns:
column_info = _table_spec_columns[key]
column_info["is_common"] = False
result.append(column_info)
else:
column_info = _common_columns.get(column)
if column_info:
column_info["is_common"] = True
result.append(column_info)
return result
def get_table_information(self, table: str, database: str | None = None) -> dict[str, Any]:
    """Return the table_info.yaml entry for a table, or {} when unknown."""
    _, db_name, table_name = split_db_table_name(table, database)
    if self._table_info_cache is None:
        self._table_info_cache = self._load_yaml_file(self.table_info_file)
    cache = self._table_info_cache
    if db_name in cache and table_name in cache[db_name]:
        # Copy so callers cannot mutate the cached entry.
        return cache[db_name][table_name].copy()
    return {}
def get_sql_examples(
self, table: str | None = None, database: str | None = None
) -> list[tuple[str, str, list[str]]]:
if self._sql_example_cache is None:
self._sql_example_cache = self._load_yaml_file(self.sql_example_file)
if table is None:
# If no table specified, return all examples
examples = []
for db_name, tables in self._sql_example_cache.items():
for table_name, example_text in tables.items():
qa_pairs = self._parse_example_text(example_text)
examples.extend([(q, a, [f"{db_name}.{table_name}"]) for (q, a) in qa_pairs])
return examples
full_table_name, db_name, table_name = split_db_table_name(table, database)
# Find examples that include this table
examples = []
# Check the fact section
if db_name in self._sql_example_cache:
if table_name in self._sql_example_cache[db_name]:
# Parse example text, format is Q: ... A: ...
qa_pairs = self._parse_example_text(self._sql_example_cache[db_name][table_name])
examples.extend([(q, a, [full_table_name]) for (q, a) in qa_pairs])
return examples
@staticmethod
def _load_table_selection_examples_from_csv(file_path: str) -> list[tuple[str, list[str]]]:
examples = []
try:
with open(file_path, encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
question = row.get("question", "").strip()
selected_tables = row.get("selected_tables", "").strip()
if question and selected_tables:
table_list = [p.strip() for p in re.split(r"[ ,\n]", selected_tables) if p.strip()]
examples.append((question, table_list))
except (FileNotFoundError, PermissionError, UnicodeDecodeError) as e:
logger.warning(f"Failed to load table selection examples from {file_path}: {e}")
return examples
def get_table_selection_examples(self) -> list[tuple[str, list[str]]]:
    """Return cached table-selection examples, loading them from CSV on first use."""
    if self._table_selection_example_cache is None:
        loaded = self._load_table_selection_examples_from_csv(self.table_selection_example_file)
        self._table_selection_example_cache = loaded
    return self._table_selection_example_cache
def save_table_information(
    self,
    table: str,
    information: dict[str, Any],
    columns: list[dict[str, Any]],
    database: str | None = None,
    update_existing: bool = False,
) -> bool:
    """Persist table metadata and its columns to the catalog files.

    Args:
        table (str): Table name (optionally "db.table").
        information (Dict[str, Any]): Table-level metadata for table_info.yaml.
        columns (List[Dict[str, Any]]): Column metadata dicts.
        database (Optional[str]): Database name used when `table` is unqualified.
        update_existing (bool): Overwrite an existing table entry when True.

    Returns:
        bool: True on success, False on any failure.

    Raises:
        Errors raised by the _validate_* helpers propagate to the caller.
    """
    # Validate input data (let validation errors propagate)
    self._validate_table_name(table)
    self._validate_table_information(information)
    self._validate_column_data(columns)
    try:
        full_table_name, db_name, table_name = split_db_table_name(table, database)
        table_info = self._load_yaml_file(self.table_info_file)
        # Save columns first so table info is written only when columns succeed
        if not self._save_columns(table_name, columns, db_name, update_existing):
            logger.error(f"Failed to save columns for table {full_table_name}")
            return False
        # Save table information (ensure proper structure)
        if db_name not in table_info:
            table_info[db_name] = {}
        if update_existing or table_name not in table_info[db_name]:
            table_info[db_name][table_name] = information
        success = self._save_yaml_file(self.table_info_file, table_info)
        if success:
            logger.info(f"Successfully saved table information for {full_table_name}")
            # Clear cache to ensure consistency
            self._clear_cache()
        return success
    except Exception as e:
        logger.error(f"Unexpected error when saving table information: {e}")
        # format_exc() logs the exception traceback; format_stack() only showed
        # the handler's own call stack and omitted the actual failure location.
        logger.error(traceback.format_exc())
        return False
def _save_columns(
    self, table_name: str, columns: list[dict[str, Any]], db_name: str = "", update_existing: bool = False
) -> bool:
    """
    Save columns information to common_columns.csv and columns of tables to table_columns.csv

    Args:
        table_name (str): Table name
        columns (List[Dict[str, Any]]): List of column information
        db_name (str): Database name
        update_existing (bool): Update existing column information

    Returns:
        bool: Whether the save was successful
    """
    full_table_name, db_name, table_name = split_db_table_name(table_name, db_name)
    # Load existing data
    tables_data = self._load_csv_file(self.table_columns_file)
    common_columns_dict = self._load_common_columns()
    table_spec_columns_dict = self._load_table_spec_columns()
    # Create a set of existing table-column combinations
    existing_table_columns = set()
    for row in tables_data:
        if "db_name" in row and "table_name" in row and "column_name" in row:
            key = f"{row['db_name']}.{row['table_name']}:{row['column_name']}"
            existing_table_columns.add(key)
    # Update table_columns.csv and track new columns to add
    for column in columns:
        # Columns without a name cannot be keyed; skip silently.
        if "column_name" not in column:
            continue
        column_name = column["column_name"]
        is_common_column = column.get("is_common", False)
        key = f"{full_table_name}:{column_name}"
        # Stringify values for CSV storage; drop the in-memory-only "is_common" flag.
        column_info = {k: str(v) for k, v in column.items() if k != "is_common"}
        if not is_common_column:
            # Table-specific columns carry their table identity in the CSV row.
            column_info["db_name"] = db_name
            column_info["table_name"] = table_name
        # New column of the table -> add to table_columns.csv
        if key not in existing_table_columns:
            tables_data.append({"db_name": db_name, "table_name": table_name, "column_name": column_name})
            existing_table_columns.add(key)
            if is_common_column:
                # Handle common_columns.csv - avoid duplicates
                if column_name not in common_columns_dict:
                    # Add new columns to columns_data
                    logger.info(f"Add new column column {column_name}")
                    common_columns_dict[column_name] = column_info
            else:
                table_spec_columns_dict[key] = column_info
        # Apply updates to existing columns in columns_data
        elif update_existing:
            if is_common_column:
                common_columns_dict[column_name] = column_info
            else:
                table_spec_columns_dict[key] = column_info
    # Save updated data
    tables_success = self._save_csv_file(
        self.table_columns_file, tables_data, ["db_name", "table_name", "column_name"]
    )
    common_columns_success = self._save_csv_file(
        self.common_columns_file,
        list(common_columns_dict.values()),
        ["column_name", "display_name", "alias", "type", "category", "tag", "description"],
    )
    table_spec_columns_success = self._save_csv_file(
        self.table_spec_columns_file,
        list(table_spec_columns_dict.values()),
        ["db_name", "table_name", "column_name", "display_name", "alias", "type", "category", "tag", "description"],
    )
    success = tables_success and common_columns_success and table_spec_columns_success
    if success:
        # Clear cache to ensure consistency
        self._clear_cache()
        logger.debug(f"Successfully saved columns for table {table_name}")
    return success
def save_table_sql_examples(self, table: str, examples: list[dict[str, str]], database: str | None = None) -> bool:
    """Persist Q/A SQL examples for one table into sql_example.yaml.

    Args:
        table (str): Table name (optionally "db.table").
        examples (List[Dict[str, str]]): Dicts with "question" and "answer" keys.
        database (Optional[str]): Database name used when `table` is unqualified.

    Returns:
        bool: True on success, False on any failure.

    Raises:
        Errors raised by the _validate_* helpers propagate to the caller.
    """
    # Validate input data (let validation errors propagate)
    self._validate_table_name(table)
    self._validate_sql_examples(examples)
    try:
        full_table_name, db_name, table_name = split_db_table_name(table, database)
        sql_examples = self._load_yaml_file(self.sql_example_file)
        # Ensure database exists in structure
        if db_name not in sql_examples:
            sql_examples[db_name] = {}
        # Render examples in the "Q: ... A: ..." format parsed by _parse_example_text
        example_text = ""
        for example in examples:
            example_text += f"Q: {example['question']}\nA: {example['answer']}\n\n"
        sql_examples[db_name][table_name] = example_text.strip()
        success = self._save_yaml_file(self.sql_example_file, sql_examples)
        if success:
            logger.info(f"Successfully saved {len(examples)} examples for table {full_table_name}")
            # Update cache
            self._sql_example_cache = sql_examples
        return success
    except Exception as e:
        logger.error(f"Unexpected error when saving table examples: {e}")
        # format_exc() logs the exception traceback; format_stack() only showed
        # the handler's own call stack and omitted the actual failure location.
        logger.error(traceback.format_exc())
        return False
def save_table_selection_examples(self, examples: list[tuple[str, list[str]]]) -> bool:
    """Persist (question, selected_tables) examples to the table-selection CSV.

    The table list is serialized as a ", "-joined string so that
    _load_table_selection_examples_from_csv (which splits on spaces, commas,
    and newlines) round-trips it. Previously the raw Python list was written,
    producing cells like "['a.b', 'c.d']" that did not parse back cleanly.
    """
    example_data = []
    for question, tables in examples:
        example_data.append({"question": question, "selected_tables": ", ".join(tables)})
    save_success = self._save_csv_file(
        self.table_selection_example_file, example_data, ["question", "selected_tables"]
    )
    if save_success:
        logger.info(f"Successfully saved {len(examples)} table selection examples.")
    return save_success
def check_exists(self) -> bool:
    """Return True when the essential catalog CSV files exist with real content.

    A file of size <= 1 byte is treated as empty/header-only.
    """
    try:
        # Check if essential catalog files exist and have content
        files_missing = (
            not os.path.exists(self.table_columns_file)
            or not os.path.exists(self.common_columns_file)
            or os.path.getsize(self.table_columns_file) <= 1  # Empty or just header
            or os.path.getsize(self.common_columns_file) <= 1
        )
        return not files_missing
    except Exception as e:
        logger.warning(f"Error checking catalog existence: {e}")
        # format_exc() logs the exception traceback; format_stack() only showed
        # the handler's own call stack.
        logger.error(traceback.format_exc())
        return False
================================================
FILE: openchatbi/catalog/token_service.py
================================================
"""Token service for authentication with external services."""
import json
import requests
class TokenService:
    """Service for managing authentication tokens.

    Handles token application, validation, and authentication
    with external services.
    """

    # Class-level defaults; __init__ sets user_name/password per instance, and
    # callers (see apply_token_for_user) assign base_url before apply_token().
    base_url = None
    token = None
    user_name = None
    password = None

    def __init__(self, user_name: str, password: str):
        """Initialize token service with the given credentials."""
        self.user_name = user_name
        self.password = password

    def apply_token(self):
        """Apply for authentication token using credentials.

        Stores the token from the JSON response on self.token
        (None when the response carries no "token" field).
        """
        response = requests.post(
            self.base_url + "/apply_token",
            data=json.dumps({"user_name": self.user_name, "password": self.password}),
            timeout=30,  # without a timeout, requests.post can hang indefinitely
        )
        resp_json = response.json()
        self.token = resp_json.get("token")
def apply_token_for_user(token_url: str, user_name: str, password: str):
    """Apply for a token on behalf of a user and return it.

    Args:
        token_url (str): Base URL for token service.
        user_name (str): The user name.
        password (str): The password.

    Returns:
        The token returned by the service (None when the service yields none).
    """
    service = TokenService(user_name, password)
    service.base_url = token_url
    service.apply_token()
    return service.token
================================================
FILE: openchatbi/code/docker_executor.py
================================================
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
import docker
from docker.errors import ContainerError
from openchatbi.code.executor_base import ExecutorBase
def check_docker_status() -> tuple[bool, str]:
    """
    Check Docker installation and status without initializing DockerExecutor.

    Returns:
        Tuple[bool, str]: (is_available, status_message)
    """
    try:
        # CLI present on PATH?
        if not shutil.which("docker"):
            return False, "Docker is not installed. Please install Docker."
        # Probe the daemon via `docker info`.
        probe = subprocess.run(["docker", "info"], capture_output=True, text=True, timeout=10)
        if probe.returncode == 0:
            return True, "Docker is installed and running"
        if "Cannot connect to the Docker daemon" in probe.stderr:
            return False, "Docker is installed but not running. Please start the Docker daemon."
        return False, f"Docker is not available: {probe.stderr.strip()}"
    except subprocess.TimeoutExpired:
        return False, "Docker command timed out. Docker may not be running properly."
    except FileNotFoundError:
        return False, "Docker command not found. Please install Docker."
    except Exception as e:
        return False, f"Error checking Docker status: {str(e)}"
class DockerExecutor(ExecutorBase):
    """Docker-based Python code executor for isolated execution."""

    def __init__(self, variable: dict = None):
        """Initialize the executor, verifying Docker and building the image if needed.

        Args:
            variable (dict): Variables injected into executed code.

        Raises:
            RuntimeError: If Docker is missing, not running, or inaccessible.
        """
        super().__init__(variable)
        self.image_name = "python-executor"
        self.dockerfile_path = Path(__file__).parent.parent.parent / "Dockerfile.python-executor"
        # Check Docker installation and status
        self._check_docker_availability()
        try:
            self.client = docker.from_env()
            # Build Docker image if it doesn't exist
            self._ensure_image_exists()
        except Exception as e:
            self._handle_docker_error(e)

    @staticmethod
    def _check_docker_availability():
        """Check if Docker is installed and available.

        Raises:
            RuntimeError: With a specific message per failure mode.
        """
        # Check if Docker CLI is installed
        if not shutil.which("docker"):
            raise RuntimeError("Docker is not installed. Please install Docker and ensure it's in your system PATH.")
        # Check if Docker daemon is running
        try:
            result = subprocess.run(["docker", "info"], capture_output=True, text=True, timeout=10)
            if result.returncode != 0:
                if "Cannot connect to the Docker daemon" in result.stderr:
                    raise RuntimeError(
                        "Docker is installed but not running. Please start the Docker daemon and try again."
                    )
                else:
                    raise RuntimeError(
                        f"Docker is not available. Please check Docker installation and status. "
                        f"Error: {result.stderr.strip()}"
                    )
        except subprocess.TimeoutExpired:
            raise RuntimeError("Docker command timed out. Please check if Docker is running properly.")
        except FileNotFoundError:
            raise RuntimeError("Docker command not found. Please install Docker and ensure it's in your system PATH.")

    @staticmethod
    def _handle_docker_error(error: Exception):
        """Translate Docker client errors into actionable RuntimeErrors."""
        error_str = str(error).lower()
        if "connection aborted" in error_str and "no such file or directory" in error_str:
            raise RuntimeError("Docker is not running. Please start the Docker daemon and try again.")
        elif "permission denied" in error_str:
            raise RuntimeError(
                "Permission denied accessing Docker. Please ensure your user has Docker permissions "
                "or try running with appropriate privileges."
            )
        elif "docker daemon" in error_str or "connection refused" in error_str:
            raise RuntimeError("Cannot connect to Docker daemon. Please start the Docker daemon and try again.")
        else:
            raise RuntimeError(
                f"Failed to initialize Docker client. Please ensure Docker is installed and running. "
                f"Error: {str(error)}"
            )

    def _ensure_image_exists(self):
        """Build Docker image if it doesn't exist."""
        try:
            self.client.images.get(self.image_name)
        except docker.errors.ImageNotFound:
            print(f"Building Docker image '{self.image_name}'...")
            self.client.images.build(
                path=str(self.dockerfile_path.parent),
                dockerfile=self.dockerfile_path.name,
                tag=self.image_name,
                rm=True,
            )
            print(f"Docker image '{self.image_name}' built successfully.")

    def run_code(self, code: str) -> tuple[bool, str]:
        """Execute Python code in a Docker container.

        Returns:
            tuple[bool, str]: (success, container output or error message).
        """
        try:
            # Create a temporary file with the code
            with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
                # Prepend variable definitions. repr() is used for ALL values so
                # strings containing quotes/backslashes are escaped correctly;
                # the previous f'{key} = "{value}"' form produced broken code
                # for such strings.
                variable_code = "".join(f"{key} = {value!r}\n" for key, value in self._variable.items())
                full_code = variable_code + "\n" + code
                f.write(full_code)
                temp_file_path = f.name
            try:
                # Run the code in a Docker container
                container = self.client.containers.run(
                    self.image_name,
                    command=["python3", f"/app/{os.path.basename(temp_file_path)}"],
                    volumes={temp_file_path: {"bind": f"/app/{os.path.basename(temp_file_path)}", "mode": "ro"}},
                    remove=True,
                    detach=False,
                    stdout=True,
                    stderr=True,
                    network_mode="none",  # Disable network access for security
                )
                # With detach=False, run() returns the container logs as bytes.
                output = container.decode("utf-8")
                return True, output
            except ContainerError as e:
                # Container exited with non-zero code
                error_output = e.stderr if e.stderr else str(e)
                return False, f"Container execution failed: {error_output}"
            except Exception as e:
                return False, f"Docker execution error: {str(e)}"
        finally:
            # Clean up temporary file (guarded: tempfile creation may have failed)
            if "temp_file_path" in locals() and os.path.exists(temp_file_path):
                try:
                    os.unlink(temp_file_path)
                except (OSError, PermissionError) as e:
                    # Log but don't fail the operation for cleanup issues
                    print(f"Warning: Failed to clean up temporary file {temp_file_path}: {e}")

    def __del__(self):
        """Clean up Docker client on deletion."""
        try:
            if hasattr(self, "client") and self.client is not None:
                self.client.close()
        except Exception:
            # Ignore cleanup errors during object destruction
            pass
================================================
FILE: openchatbi/code/executor_base.py
================================================
from typing import Any
class ExecutorBase:
    """Base class for executing python code."""

    # Variables exposed to executed code (name -> value).
    _variable: dict

    def __init__(self, variable: dict = None):
        """Initialize with an optional dict of variables to expose to executed code."""
        if variable is None:
            self._variable = {}
        else:
            self._variable = variable

    def run_code(self, code: str) -> tuple[bool, str]:
        """Execute python code.

        Returns:
            tuple[bool, str]: (success, output or error message).
            Note: the annotation was previously written as `(bool, str)`,
            which is a tuple expression, not a valid type.
        """
        raise NotImplementedError()

    def set_variable(self, key: str, value: Any) -> None:
        """Set variable."""
        self._variable[key] = value
================================================
FILE: openchatbi/code/local_executor.py
================================================
import sys
from io import StringIO
from openchatbi.code.executor_base import ExecutorBase
class LocalExecutor(ExecutorBase):
    """Execute Python code in-process, capturing stdout.

    NOTE(review): despite the name, passing the real `__builtins__` gives NO
    sandboxing — executed code has full interpreter access. Use
    RestrictedLocalExecutor or DockerExecutor for untrusted code.
    """

    def run_code(self, code: str) -> tuple[bool, str]:
        """Run `code` with exec(), returning (success, captured stdout or error).

        Fixes the previous `-> str` annotation (a tuple is returned), and
        exposes `self._variable` to the executed code for consistency with
        the Docker and RestrictedPython executors, which both inject it.
        """
        exec_globals = {"__builtins__": __builtins__}
        # Make configured variables visible to the executed code.
        exec_globals.update(self._variable)
        original_stdout = sys.stdout
        output_buffer = StringIO()
        sys.stdout = output_buffer
        try:
            exec(code, exec_globals, exec_globals)
            return True, output_buffer.getvalue()
        except Exception as e:
            return False, str(e)
        finally:
            # Always restore stdout, even when exec raises.
            sys.stdout = original_stdout
================================================
FILE: openchatbi/code/restricted_local_executor.py
================================================
import sys
from io import StringIO
from RestrictedPython import compile_restricted, safe_globals, utility_builtins
from RestrictedPython.Guards import safe_builtins, safer_getattr
from openchatbi.code.executor_base import ExecutorBase
class RestrictedLocalExecutor(ExecutorBase):
    """Execute Python code under RestrictedPython, capturing printed output."""

    def run_code(self, code: str) -> tuple[bool, str]:
        """Compile `code` with compile_restricted and exec it in a guarded namespace.

        Returns:
            tuple[bool, str]: (success, captured stdout or error message).
        """
        try:
            # compile restricted code
            byte_code = compile_restricted(code, "", "exec")
            if byte_code is None:
                return False, "Failed to compile restricted code"
            restricted_locals = {}
            restricted_globals = safe_globals.copy()
            # Set up restricted environment with necessary functions
            restricted_globals.update(safe_builtins)
            restricted_globals["_getattr_"] = safer_getattr
            restricted_globals["__builtins__"] = utility_builtins
            # Add variable definitions to the restricted locals
            for key, value in self._variable.items():
                restricted_locals[key] = value
            # Capture print output
            original_stdout = sys.stdout
            output_buffer = StringIO()
            sys.stdout = output_buffer
            # Use the standard print function for RestrictedPython
            # NOTE(review): RestrictedPython conventionally expects `_print_` to be
            # a PrintCollector factory; supplying plain print relies on the stdout
            # redirect above to capture output — confirm against the installed
            # RestrictedPython version.
            restricted_globals["_print_"] = lambda *args, **kwargs: print(*args, **kwargs)
            exec(byte_code, restricted_globals, restricted_locals)
            output = output_buffer.getvalue()
            return True, output
        except Exception as e:
            return False, str(e)
        finally:
            # Restore stdout only if execution got far enough to save it.
            if "original_stdout" in locals():
                sys.stdout = original_stdout
================================================
FILE: openchatbi/config.yaml.template
================================================
organization: The Company
dialect: presto
bi_config_file: example/bi.yaml
# Python Code Execution Configuration
# Options: "local", "restricted_local", "docker"
# - local: Run code in the current Python process (fastest, least secure)
# - restricted_local: Run code with RestrictedPython (moderate security, some limitations)
# - docker: Run code in isolated Docker containers (slowest, most secure, requires Docker to be installed)
python_executor: local
# Visualization configuration
# Options: "rule" (rule-based), "llm" (LLM-based), or null (skip visualization)
# visualization_mode: llm
# Context management configuration
# Controls how conversation context is managed and compressed when it becomes too long
context_config:
# Enable/disable context management entirely
enabled: true
# Token limit that triggers context management (when conversation exceeds this, compression starts)
summary_trigger_tokens: 12000
# Number of recent messages to always preserve in full (never compress these)
keep_recent_messages: 20
# Historical tool output compression limits
max_tool_output_length: 2000 # Max length for historical tool outputs
max_sql_result_rows: 50 # Max rows to keep in CSV results
max_code_output_lines: 50 # Max lines for code execution output
# Conversation summarization settings
enable_summarization: true # Enable conversation summarization
enable_conversation_summary: true # Enable detailed conversation summary
summary_max_messages: 50 # Max messages to include in summary context
# Content preservation settings
preserve_tool_errors: true # Always preserve error messages in full
preserve_recent_sql: true # Preserve SQL content (less aggressive compression)
# Time Series Forecasting Service Configuration
# URL for the time series forecasting service endpoint, adjust based on your deployment scenario:
# - Local development (OpenChatBI on host, Forecasting service in Docker): "http://localhost:8765"
# - Remote service: "http://your-service-host:8765"
timeseries_forecasting_service_url: "http://localhost:8765"
# Catalog store configuration
catalog_store:
store_type: file_system
data_path: ./example
# Data warehouse configuration
data_warehouse_config:
uri: "presto://{user_name}@domain:8080/db/default"
include_tables:
- null # null means include all tables, or specify yaml list
database_name: "db.default" # database name to use in catalog
token_service: "https://tokens-domain:8080/v1"
user_name: TOKEN_SERVICE_USER_NAME
password: TOKEN_SERVICE_PASSWORD
# Vector database (chroma) path
# vector_db_path: ./.chroma_db
# LLM configurations (multiple providers)
#
# 1) Define providers under `llm_providers`
# 2) Select which one to use by setting `default_llm: <provider_name>`
default_llm: openai
llm_providers:
openai:
default_llm:
class: langchain_openai.ChatOpenAI
params:
api_key: YOUR_API_KEY_HERE
model: gpt-4.1
temperature: 0.01
max_tokens: 8192
embedding_model:
class: langchain_openai.OpenAIEmbeddings
params:
api_key: YOUR_API_KEY_HERE
model: text-embedding-3-large
chunk_size: 1024
# Optional
text2sql_llm:
class: langchain_openai.ChatOpenAI
params:
api_key: YOUR_API_KEY_HERE
model: gpt-4.1
temperature: 0.0
max_tokens: 8192
# anthropic:
# default_llm:
# class: langchain_anthropic.ChatAnthropic
# params:
# api_key: YOUR_API_KEY_HERE
# model: claude-3-5-sonnet-latest
# MCP (Model Context Protocol) server configurations
mcp_servers:
# File system MCP server (stdio transport)
- name: filesystem
transport: stdio
command: ["npx", "-y", "@modelcontextprotocol/server-filesystem"]
args: ["--path", "/tmp"]
enabled: false
timeout: 30
# Example HTTP-based MCP server (streamable_http transport)
- name: weather
transport: streamable_http
url: "http://localhost:8000/mcp/"
headers:
Authorization: "Bearer YOUR_TOKEN"
enabled: false
timeout: 30
================================================
FILE: openchatbi/config_loader.py
================================================
import importlib
import os
from importlib.util import find_spec
from typing import Any
from unittest.mock import MagicMock
from langchain_core.language_models import BaseChatModel
from pydantic import BaseModel
from openchatbi.catalog.factory import create_catalog_store
from openchatbi.utils import log
class LLMProviderConfig(BaseModel):
    """Resolved LLM objects for a single provider."""

    # Allow non-pydantic field types (LangChain models; MagicMock for tests).
    model_config = {"arbitrary_types_allowed": True}

    # Chat model used for general tasks.
    default_llm: BaseChatModel | MagicMock
    # Optional embedding model for vector retrieval.
    embedding_model: BaseModel | MagicMock | None = None
    # Optional dedicated model for text-to-SQL generation.
    text2sql_llm: BaseChatModel | MagicMock | None = None
class Config(BaseModel):
    """Configuration model for the OpenChatBI application.

    Attributes:
        organization (str): Organization name. Defaults to "The Company".
        dialect (str): SQL dialect to use. Defaults to "presto".
        default_llm (BaseChatModel): Default language model for general tasks.
        embedding_model (BaseModel): Language model for embedding generation.
        text2sql_llm (Optional[BaseChatModel]): Language model specifically for text-to-SQL tasks.
        bi_config (Dict[str, Any]): BI configuration loaded from YAML file. Defaults to empty dict.
        data_warehouse_config (Dict[str, Any]): Data warehouse configuration. Defaults to empty dict.
    """

    # Allow non-pydantic field types (LangChain models; MagicMock for tests).
    model_config = {"arbitrary_types_allowed": True}

    # General Configurations
    organization: str = "The Company"
    dialect: str = "presto"

    # LLM Configurations (MagicMock permitted so tests can inject fakes)
    default_llm: BaseChatModel | MagicMock
    embedding_model: BaseModel | MagicMock | None = None
    text2sql_llm: BaseChatModel | MagicMock | None = None

    # Multiple LLM providers (optional); llm_provider names the selected entry
    llm_provider: str | None = None
    llm_providers: dict[str, LLMProviderConfig] = {}

    # BI Configuration
    bi_config: dict[str, Any] = {}

    # Data Warehouse Configuration
    data_warehouse_config: dict[str, Any] = {}

    # Catalog Store (store object created by the catalog factory)
    catalog_store: Any = None

    # Path to the vector database file
    vector_db_path: str | None = None

    # MCP Servers Configuration
    mcp_servers: list[dict[str, Any]] = []

    # Report Configuration
    report_directory: str = "./data"

    # Code Execution Configuration
    python_executor: str = "local"  # Options: "local", "restricted_local", "docker"

    # Visualization Configuration
    visualization_mode: str | None = "rule"  # Options: "rule", "llm", None (skip visualization)

    # Context Management Configuration
    context_config: dict[str, Any] = {}

    # Time Series Service Configuration
    timeseries_forecasting_service_url: str = "http://localhost:8765"

    @classmethod
    def from_dict(cls, config: dict[str, Any]) -> "Config":
        """Creates a Config instance from a dictionary.

        Args:
            config (Dict[str, Any]): Dictionary containing configuration values.

        Returns:
            Config: A new Config instance with the provided values.
        """
        return cls(**config)
class ConfigLoader:
    """Singleton class to load and manage configuration settings for OpenChatBI.

    This class provides methods to load, get, and set configuration parameters
    for the application, including LLM models, SQL dialect, and other settings.
    """

    # Singleton instance and the currently loaded configuration.
    _instance = None
    _config: Config | None = None

    def __new__(cls):
        # Classic singleton: every construction returns the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    # Names of config keys that hold LLM model objects.
    llm_configs = ["default_llm", "embedding_model", "text2sql_llm"]
def get(self) -> Config:
    """Get the current configuration.

    Returns:
        Config: The current configuration instance.

    Raises:
        ValueError: If the configuration has not been loaded.
    """
    current = self._config
    if current is None:
        raise ValueError("Configuration has not been loaded. Please call load() or set() first.")
    return current
def load(self, config_file: str = None) -> None:
    """Load configuration from a YAML file.

    Args:
        config_file (str, optional): Path to configuration file. Uses CONFIG_FILE
            environment variable or 'openchatbi/config.yaml' if not provided.

    Raises:
        ImportError: If pyyaml is not installed.
        ValueError: If the file contains invalid YAML.
        RuntimeError: If the file cannot be read for any other reason.
    """
    if config_file is None:
        config_file = os.getenv("CONFIG_FILE", "openchatbi/config.yaml")
    if not find_spec("yaml"):
        raise ImportError("Please install pyyaml to use this feature.")
    import yaml

    try:
        with open(config_file, encoding="utf-8") as file:
            parsed = yaml.safe_load(file)
        # An empty YAML file parses to None; normalize to an empty dict.
        config_data = {} if parsed is None else parsed
    except FileNotFoundError:
        # Missing file is tolerated: leave the configuration unset.
        log(f"Configuration file not found: {config_file}, leave config un-loaded.")
        return
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML in configuration file {config_file}: {e}")
    except Exception as e:
        raise RuntimeError(f"Failed to read configuration file {config_file}: {e}")
    self._process_config_dict(config_data)
    self._config = Config.from_dict(config_data)
def _process_config_dict(self, config_data: dict[str, Any]) -> None:
"""
Processes a configuration dictionary.
"""
self._process_llm_providers(config_data)
providers = config_data.get("llm_providers", {})
selected_provider = None
default_llm_value = config_data.get("default_llm")
if isinstance(default_llm_value, str):
# Simplified multi-provider config: default_llm:
if not providers:
raise ValueError("default_llm is a provider name but llm_providers is missing.")
selected_provider = default_llm_value
elif providers:
# Backwards-compat: allow selecting provider via llm_provider
legacy_provider = config_data.get("llm_provider")
if isinstance(legacy_provider, str):
selected_provider = legacy_provider
elif "default_llm" not in config_data:
# Pick the first provider in config order for backwards-compatible YAML behavior
selected_provider = next(iter(providers.keys()), None)
elif isinstance(default_llm_value, dict):
raise ValueError(
"When using llm_providers, set default_llm to a provider name (e.g. default_llm: openai), "
"not a class config."
)
if providers:
if not selected_provider or selected_provider not in providers:
raise ValueError(f"Unknown LLM provider '{selected_provider}'. Available: {sorted(providers.keys())}")
# Store selected provider for runtime lookups (UI/API can still override per-request)
config_data["llm_provider"] = selected_provider
# Populate top-level LLM objects for legacy call sites
config_data["default_llm"] = providers[selected_provider].default_llm
config_data.setdefault("embedding_model", providers[selected_provider].embedding_model)
config_data.setdefault("text2sql_llm", providers[selected_provider].text2sql_llm)
elif "default_llm" not in config_data:
raise ValueError("Missing LLM config key: default_llm")
if not config_data.get("embedding_model"):
log("WARN: Missing LLM config key: embedding_model, will use BM25 based retrival only")
if "data_warehouse_config" not in config_data:
raise ValueError("Missing Data Warehouse config key: data_warehouse_config")
# Load BI configuration
if "bi_config_file" in config_data:
bi_config = self.load_bi_config(config_data["bi_config_file"])
bi_config.update(config_data.get("bi_config", {}))
config_data["bi_config"] = bi_config
if "catalog_store" in config_data:
if "store_type" not in config_data["catalog_store"]:
raise ValueError("catalog_store must have a store_type field.")
catalog_store = create_catalog_store(
**config_data["catalog_store"],
auto_load=config_data["catalog_store"].get("auto_load", True),
data_warehouse_config=config_data.get("data_warehouse_config"),
)
else:
log("Catalog store config key `catalog_store` not found. Using default file system store.")
catalog_store = create_catalog_store(
store_type="file_system",
auto_load=True,
data_warehouse_config=config_data.get("data_warehouse_config"),
)
config_data["catalog_store"] = catalog_store
for config_key in self.llm_configs:
config_item = config_data.get(config_key)
if not isinstance(config_item, dict) or "class" not in config_item:
continue
config_data[config_key] = self._instantiate_from_config_dict(config_item, config_key=config_key)
def _instantiate_from_config_dict(self, config_item: dict[str, Any], *, config_key: str) -> Any:
try:
class_path = config_item["class"]
if "." not in class_path:
raise ValueError(f"Invalid class path format: {class_path}")
module_name, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_name)
llm_cls = getattr(module, class_name)
params = config_item.get("params", {})
return llm_cls(**params)
except (ImportError, AttributeError, ValueError, TypeError) as e:
raise RuntimeError(f"Failed to load {config_key} class '{config_item.get('class', '')}': {e}") from e
def _process_llm_providers(self, config_data: dict[str, Any]) -> None:
"""Resolve llm_providers into instantiated provider configs (if present)."""
raw_providers = config_data.get("llm_providers")
if not raw_providers:
return
if not isinstance(raw_providers, dict):
raise ValueError("llm_providers must be a mapping of provider_name -> config")
providers: dict[str, LLMProviderConfig] = {}
for provider_name, provider_cfg in raw_providers.items():
if isinstance(provider_cfg, LLMProviderConfig):
providers[str(provider_name)] = provider_cfg
continue
if not isinstance(provider_cfg, dict):
raise ValueError(f"llm_providers.{provider_name} must be a mapping")
resolved_cfg: dict[str, Any] = dict(provider_cfg)
for config_key in self.llm_configs:
config_item = resolved_cfg.get(config_key)
if not isinstance(config_item, dict) or "class" not in config_item:
continue
resolved_cfg[config_key] = self._instantiate_from_config_dict(
config_item, config_key=f"llm_providers.{provider_name}.{config_key}"
)
if "default_llm" not in resolved_cfg or resolved_cfg["default_llm"] is None:
raise ValueError(f"llm_providers.{provider_name} missing default_llm")
providers[str(provider_name)] = LLMProviderConfig(**resolved_cfg)
config_data["llm_providers"] = providers
def load_bi_config(self, bi_config_file: str) -> dict[str, Any]:
"""Load BI configuration from a YAML file.
Args:
bi_config_file (str): Path to the BI configuration file.
Defaults to 'example/bi.yaml'.
Returns:
Dict[str, Any]: The loaded BI configuration as a dictionary.
Raises:
ImportError: If pyyaml is not installed.
FileNotFoundError: If the BI configuration file cannot be found.
"""
if not find_spec("yaml"):
raise ImportError("Please install pyyaml to use this feature.")
import yaml
bi_config_data = {}
try:
with open(bi_config_file, encoding="utf-8") as file:
bi_config_data = yaml.safe_load(file) or {}
except FileNotFoundError:
log(f"Warning: BI config file '{bi_config_file}' not found. Ignore load BI config from yaml file.")
except yaml.YAMLError as e:
log(f"Warning: Invalid YAML in BI config file '{bi_config_file}': {e}. Using empty config.")
except Exception as e:
log(f"Warning: Failed to read BI config file '{bi_config_file}': {e}. Using empty config.")
return bi_config_data
def set(self, config: dict[str, Any]) -> None:
"""Set the configuration from a dictionary.
Args:
config (Dict[str, Any]): Dictionary containing configuration values.
"""
self._process_config_dict(config)
self._config = Config.from_dict(config)
================================================
FILE: openchatbi/constants.py
================================================
"""Constants used throughout the OpenChatBI application."""
# Date/time format strings
datetime_format = "%Y-%m-%d %H:%M:%S"
date_format = "%Y-%m-%d"
datetime_format_ms = "%Y-%m-%d %H:%M:%S.%f"
datetime_format_ms_T = "%Y-%m-%dT%H:%M:%S.%fZ"
# SQL execution status codes
SQL_NA = "SQL_NA"
SQL_SUCCESS = "SQL_SUCCESS"
SQL_EXECUTE_TIMEOUT = "SQL_CHECK_TIMEOUT"
SQL_SYNTAX_ERROR = "SQL_SYNTAX_ERROR"
SQL_UNKNOWN_ERROR = "SQL_UNKNOWN_ERROR"
MCP_TOOL_DEFAULT_TIMEOUT_SECONDS = 60
================================================
FILE: openchatbi/context_config.py
================================================
"""Configuration for context management settings."""
from dataclasses import dataclass
from openchatbi import config
@dataclass
class ContextConfig:
    """Configuration class for context management settings."""

    # Enable/disable context management entirely.
    enabled: bool = True
    # Token limits for triggering context management (estimated, not exact).
    summary_trigger_tokens: int = 12000
    # Message retention (how many recent messages to always preserve)
    keep_recent_messages: int = 20
    # Historical tool output compression limits
    max_tool_output_length: int = 2000  # Max length for historical tool outputs
    max_sql_result_rows: int = 50  # Max rows to keep in CSV results
    max_code_output_lines: int = 50  # Max lines for code execution output
    # Conversation summarization
    enable_summarization: bool = True
    enable_conversation_summary: bool = True
    summary_max_messages: int = 50  # Max messages to include in summary context
    # Content preservation settings
    preserve_tool_errors: bool = True  # Always preserve error messages in full
    preserve_recent_sql: bool = True  # Preserve SQL content (less aggressive compression)
def get_context_config() -> ContextConfig:
    """Get the current context configuration.

    Reads the optional ``context_config`` mapping from the main config system
    and overlays its values on the defaults. Falls back to a default
    ContextConfig when the config system is unavailable or unconfigured.

    Returns:
        ContextConfig: The current context configuration
    """
    try:
        overrides = getattr(config.get(), "context_config", None)
        if overrides:
            # Start from defaults and apply only keys ContextConfig knows about.
            result = ContextConfig()
            for name, value in overrides.items():
                if hasattr(result, name):
                    setattr(result, name, value)
            return result
    except (ImportError, ValueError, AttributeError):
        # Config system not loaded or malformed -> use defaults below.
        pass
    return ContextConfig()
def update_context_config(**kwargs) -> ContextConfig:
    """Update context configuration with new values.

    Args:
        **kwargs: Configuration parameters to update

    Returns:
        ContextConfig: Updated configuration
    """
    # Fetch the current config, then overwrite only recognized attributes.
    ctx_config = get_context_config()
    for name, value in kwargs.items():
        if hasattr(ctx_config, name):
            setattr(ctx_config, name, value)
    return ctx_config
================================================
FILE: openchatbi/context_manager.py
================================================
"""Context management utilities for handling long conversations."""
import json
import re
import uuid
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from openchatbi.context_config import ContextConfig, get_context_config
from openchatbi.llm.llm import call_llm_chat_model_with_retry
from openchatbi.prompts.system_prompt import get_summary_prompt_template
from openchatbi.utils import log
class ContextManager:
    """Manages conversation context to prevent token limit issues."""

    def __init__(self, llm: BaseChatModel, config: ContextConfig | None = None):
        """Initialize context manager.

        Args:
            llm: Language model for summarization
            config: Context configuration. If None, uses default config.
        """
        self.llm = llm
        self.config = config or get_context_config()

    # ============================================================================
    # PUBLIC API METHODS
    # ============================================================================
    def manage_context_messages(self, messages: list) -> None:
        """Main context management function that directly modifies messages list.

        Two-stage strategy: first compress old tool outputs; if the estimate is
        still over the trigger threshold, summarize the historical messages.

        Args:
            messages: The list of messages to manage (modified in place)
        """
        if not self.config.enabled:
            return
        if not messages:
            return
        # Check if we need to manage context
        estimated_tokens = self.estimate_message_tokens(messages)
        if estimated_tokens <= self.config.summary_trigger_tokens:
            return  # No action needed
        log(f"Context management triggered: {estimated_tokens} tokens > {self.config.summary_trigger_tokens}")
        # Apply historical tool message compression directly
        self._compress_historical_tool_messages(messages)
        # Check if we still need summarization after compression
        remaining_tokens = self.estimate_message_tokens(messages)
        if remaining_tokens > self.config.summary_trigger_tokens and self.config.enable_summarization:
            self._apply_conversation_summarization(messages)
        log("Context management completed")

    # ============================================================================
    # TOKEN ESTIMATION METHODS
    # ============================================================================
    @staticmethod
    def estimate_tokens(text: str) -> int:
        """Rough token estimation (1 token ≈ 4 characters for most languages)."""
        return len(text) // 4

    def estimate_message_tokens(self, messages: list[BaseMessage]) -> int:
        """Estimate total tokens in a list of messages."""
        total = 0
        for msg in messages:
            total += self.estimate_tokens(str(msg.content))
            # Add tokens for metadata and structure
            total += 50
        return total

    # ============================================================================
    # TOOL OUTPUT TRIMMING METHODS
    # ============================================================================
    def trim_tool_output(self, content: str, tool_name: str = "") -> str:
        """Trim tool output to manageable size while preserving key information.

        Dispatches to structured (SQL/CSV) or code-output trimming based on
        content markers; falls back to a head+tail generic truncation.
        """
        if len(content) <= self.config.max_tool_output_length:
            return content
        # Preserve full error messages if configured
        if self.config.preserve_tool_errors and ("Error:" in content or "Traceback" in content):
            return content
        # For SQL results, preserve structure
        if "```sql" in content or "```csv" in content:
            return self._trim_structured_output(content)
        # For code execution results
        if "```python" in content or "Traceback" in content:
            return self._trim_code_output(content)
        # Generic trimming
        max_len = self.config.max_tool_output_length
        trimmed = content[: max_len // 2] + "\n\n... [Output truncated] ...\n\n" + content[-max_len // 2 :]
        return trimmed

    def _trim_structured_output(self, content: str) -> str:
        """Trim SQL/CSV output while preserving structure."""
        parts = []
        # Extract SQL query (always keep)
        sql_match = re.search(r"```sql\n(.*?)\n```", content, re.DOTALL)
        if sql_match:
            parts.append(f"```sql\n{sql_match.group(1)}\n```")
        # Extract and trim CSV data
        csv_match = re.search(r"```csv\n(.*?)\n```", content, re.DOTALL)
        if csv_match:
            csv_data = csv_match.group(1)
            lines = csv_data.split("\n")
            max_rows = self.config.max_sql_result_rows
            if len(lines) > max_rows:  # Keep header + first half + last quarter
                keep_start = max_rows // 2
                keep_end = max_rows // 4
                trimmed_csv = "\n".join(
                    lines[: keep_start + 1]
                    + [f"... [{len(lines) - keep_start - keep_end - 1} rows omitted] ..."]
                    + lines[-keep_end:]
                )
                parts.append(f"```csv\n{trimmed_csv}\n```")
            else:
                parts.append(f"```csv\n{csv_data}\n```")
        # Keep visualization info
        viz_match = re.search(r"Visualization Created:.*", content)
        if viz_match:
            parts.append(viz_match.group(0))
        return "\n\n".join(parts)

    def _trim_code_output(self, content: str) -> str:
        """Trim Python code execution output."""
        # Keep error messages (full) if configured
        if self.config.preserve_tool_errors and ("Traceback" in content or "Error:" in content):
            return content
        lines = content.split("\n")
        max_lines = self.config.max_code_output_lines
        if len(lines) <= max_lines:
            return content
        # Keep first half and last quarter
        keep_start = max_lines // 2
        keep_end = max_lines // 4
        return "\n".join(lines[:keep_start] + ["... [Output truncated] ..."] + lines[-keep_end:])

    # ============================================================================
    # CONVERSATION SUMMARIZATION METHODS
    # ============================================================================
    def summarize_conversation(self, messages: list[BaseMessage]) -> str:
        """Create a summary of conversation history.

        Returns an empty string when summarization is disabled or there is
        nothing to summarize; a "[Summary generation failed]" sentinel on error.
        """
        if not self.config.enable_conversation_summary:
            return ""
        # Filter out system messages for summarization
        # Note: The messages passed in are already historical messages (split point already calculated)
        messages_to_summarize = []
        for msg in messages:
            if not isinstance(msg, SystemMessage):
                messages_to_summarize.append(msg)
        if not messages_to_summarize:
            return ""
        # Create summarization prompt
        conversation_text = self._format_messages_for_summary(messages_to_summarize)
        # Get the summary prompt template from the file and replace placeholder
        summary_prompt = get_summary_prompt_template().replace("[conversation_text]", conversation_text)
        try:
            response = call_llm_chat_model_with_retry(
                self.llm, [HumanMessage(content=summary_prompt)], parallel_tool_call=False
            )
            if isinstance(response, AIMessage):
                return f"[Conversation Summary]: {response.content}"
            return "[Summary generation failed]"
        except Exception as e:
            log(f"Failed to generate conversation summary: {e}")
            return "[Summary generation failed]"

    def _truncate_text(self, text: str, truncate_len: int = 500) -> str:
        # do not truncate Conversation Summary
        if text.startswith("[Conversation Summary]"):
            return text
        if len(text) > truncate_len:
            return text[:truncate_len] + "... [truncated]"
        return text

    def _truncate_text_or_list(self, content) -> list[str]:
        # Normalize str-or-list message content into a list of truncated strings.
        # NOTE(review): list items that are neither str nor dict, and dict items
        # with types other than "text"/"tool_use", are silently dropped; a dict
        # without a "type" key would raise KeyError — confirm upstream guarantees.
        results = []
        if isinstance(content, str):
            results.append(self._truncate_text(content))
        elif isinstance(content, list):
            for item in content:
                if isinstance(item, str):
                    results.append(self._truncate_text(item))
                elif isinstance(item, dict):
                    if item["type"] == "text":
                        results.append(self._truncate_text(item["text"]))
                    elif item["type"] == "tool_use":
                        results.append(json.dumps(item))
        return results

    def _format_messages_for_summary(self, messages: list[BaseMessage]) -> str:
        """Format messages for summary generation."""
        formatted = []
        max_messages = self.config.summary_max_messages
        # Limit messages for summary context
        for msg in messages[-max_messages:]:
            if isinstance(msg, HumanMessage):
                formatted.append(f" {msg.content} ")
            elif isinstance(msg, AIMessage):
                content = msg.content or ""
                formatted.append("")
                formatted.extend(self._truncate_text_or_list(content))
                formatted.append("")
            elif isinstance(msg, ToolMessage):
                formatted.append(
                    f" tool_call_id: {msg.tool_call_id}, "
                    f"tool: {msg.name}, "
                    f"status: {msg.status}, "
                    f"result: {self._truncate_text_or_list(msg.content)} "
                )
        return "\n".join(formatted)

    # ============================================================================
    # CONTEXT MANAGEMENT IMPLEMENTATION METHODS
    # ============================================================================
    def _compress_historical_tool_messages(self, messages: list[BaseMessage]) -> None:
        """Compress historical (not recent) tool messages in place."""
        # Find a safe split point
        recent_start_index = self._find_safe_split_point(messages)
        # Find tool messages in historical part (before recent_start_index) that need compression
        for i in range(recent_start_index):
            msg = messages[i]
            if isinstance(msg, ToolMessage):
                original_content = str(msg.content)
                # Apply intelligent filtering for tool message compression
                if self._should_compress_historical_tool_message(msg, original_content):
                    trimmed_content = self.trim_tool_output(original_content)
                    if len(trimmed_content) < len(original_content):
                        # Update message content directly
                        # NOTE(review): the rebuilt ToolMessage carries only content,
                        # tool_call_id and id — the original `name`/`status` fields are
                        # dropped; confirm downstream consumers don't rely on them.
                        messages[i] = ToolMessage(
                            content=trimmed_content,
                            tool_call_id=msg.tool_call_id,
                            id=msg.id,  # Keep original ID to preserve position
                        )
                        log(
                            f"Compressed historical tool message: {len(original_content)} -> {len(trimmed_content)} chars"
                        )

    def _apply_conversation_summarization(self, messages: list[BaseMessage]) -> None:
        """Apply conversation summarization by modifying messages list in place."""
        if not self.config.enable_conversation_summary:
            return
        # Find a safe split point that doesn't separate AI messages with tool calls from their ToolMessages
        recent_start_index = self._find_safe_split_point(messages)
        if recent_start_index == 0:
            return  # No historical messages to summarize
        historical_messages = messages[:recent_start_index]
        recent_messages = messages[recent_start_index:]
        # Skip re-summarizing when history is already a single summary message.
        if len(historical_messages) == 1:
            msg = historical_messages[0]
            # NOTE(review): assumes msg.content is a str here; AIMessage content can
            # also be a list of blocks — confirm summaries are always plain strings.
            if isinstance(msg, AIMessage) and msg.content.startswith("[Conversation Summary]"):
                return
        # Generate summary
        summary_text = self.summarize_conversation(historical_messages)
        if summary_text:
            # Rebuild messages list in place: summary + recent
            new_messages = [AIMessage(content=summary_text, id=str(uuid.uuid4()))] + recent_messages
            # Clear and repopulate the list in place
            messages.clear()
            messages.extend(new_messages)
            log(f"Applied conversation summary, removed {len(historical_messages)} historical messages")

    def _find_safe_split_point(self, messages: list[BaseMessage]) -> int:
        """Find a safe split point that start at HumanMessage

        Returns the index where recent messages should start (everything before this index is historical).
        """
        if len(messages) <= self.config.keep_recent_messages:
            return 0  # Keep all messages as recent
        # If keep_recent_messages is 0, return all messages as historical
        if self.config.keep_recent_messages <= 0:
            return len(messages)
        # Start from the naive split point
        naive_split = len(messages) - self.config.keep_recent_messages
        # Find the nearest HumanMessage
        for i in range(naive_split, -1, -1):
            msg = messages[i]
            # Precedence: `and` binds tighter, so dicts must also have role == "user".
            if isinstance(msg, HumanMessage) or isinstance(msg, dict) and msg["role"] == "user":
                return i  # Split before this HumanMessage
        return naive_split

    # ============================================================================
    # CONTENT ANALYSIS HELPER METHODS
    # ============================================================================
    def _should_compress_historical_tool_message(self, tool_msg: ToolMessage, content: str) -> bool:
        """Determine if a historical tool message should be compressed.

        Args:
            tool_msg: The tool message to evaluate
            content: The content of the tool message

        Returns:
            bool: True if the message should be compressed
        """
        # Don't compress if content is already short
        if len(content) <= self.config.max_tool_output_length:
            return False
        # Always preserve error messages if configured
        if self.config.preserve_tool_errors and self._is_error_content(content):
            return False
        # Don't compress recent SQL results if configured
        if self.config.preserve_recent_sql and self._is_sql_content(content):
            return False
        # Compress large outputs from specific tools more aggressively
        if self._is_data_query_result(content):
            return True
        # Compress Python execution results but preserve errors
        if self._is_python_execution_result(content):
            return not self._is_error_content(content)
        # Default: compress if content is long
        return True

    def _is_error_content(self, content: str) -> bool:
        """Check if content contains error information."""
        error_indicators = [
            "error:",
            "Error:",
            "ERROR:",
            "exception:",
            "Exception:",
            "EXCEPTION:",
            "traceback",
            "Traceback",
            "TRACEBACK",
            "failed",
            "Failed",
            "FAILED",
            "KeyError",
            "ValueError",
            "TypeError",
            "AttributeError",
            "FileNotFoundError",
            "ConnectionError",
        ]
        return any(indicator in content for indicator in error_indicators)

    def _is_sql_content(self, content: str) -> bool:
        """Check if content contains SQL query results."""
        sql_indicators = [
            "```sql",
            "query results",
            "sql query:",
            "select ",
            "insert ",
            "update ",
            "delete ",
            "create table",
            "alter table",
        ]
        content_lower = content.lower()
        return any(indicator in content_lower for indicator in sql_indicators)

    def _is_data_query_result(self, content: str) -> bool:
        """Check if content is a data query result that can be safely compressed."""
        indicators = [
            "```csv",
            "query results",
            "rows returned",
            "records found",
            "records in the database",
            "found records",
            "csv format",
        ]
        content_lower = content.lower()
        return any(indicator in content_lower for indicator in indicators)

    def _is_python_execution_result(self, content: str) -> bool:
        """Check if content is Python code execution result."""
        indicators = [
            "```python",
            "execution completed",
            "output:",
            "result:",
            "print(",
        ]
        return any(indicator.lower() in content.lower() for indicator in indicators)
================================================
FILE: openchatbi/graph_state.py
================================================
"""State classes for OpenChatBI graph execution."""
from typing import Annotated, Any
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph import MessagesState
from langgraph.types import Send
def add_history_messages(left: list, right: list):
    """Reducer for the `history_messages` state channel.

    Concatenates accumulated messages with newly added ones. When there is
    no accumulated value yet (empty or None), the new list is returned as-is.
    """
    if not left:
        return right
    return left + right
class AgentState(MessagesState):
    """State for the main agent graph execution.

    Extends MessagesState with additional fields for routing and responses.
    """

    # Cross-turn history, merged by the add_history_messages reducer.
    history_messages: Annotated[list[HumanMessage | AIMessage], add_history_messages]
    # Name of the next graph node chosen by the router.
    agent_next_node: str
    # langgraph Send directives — presumably for dispatching parallel branches; confirm at call sites.
    sends: list[Send]
    # Most recently generated SQL statement.
    sql: str
    # Final answer text to return to the user.
    final_answer: str
class SQLGraphState(MessagesState):
    """State for SQL generation subgraph.

    Contains rewritten question, table selection, extracted entities, and generated SQL.
    """

    # User question after rewriting/normalization.
    rewrite_question: str
    # Candidate tables selected for the query.
    tables: list[dict[str, Any]]
    # Entities extracted from the question (names/values to bind into SQL).
    info_entities: dict[str, Any]
    # Generated SQL statement.
    sql: str
    # Number of SQL generation retries attempted so far.
    sql_retry_count: int
    # Raw result/status string from executing the SQL.
    sql_execution_result: str
    schema_info: dict[str, Any]  # Data schema analysis results
    data: str  # CSV data for display
    # Errors from prior SQL attempts, fed back into regeneration.
    previous_sql_errors: list[dict[str, Any]]
    # Visualization DSL spec built from the query result.
    visualization_dsl: dict[str, Any]
class InputState(MessagesState):
    """Input state schema for the main graph."""

    # No extra fields beyond MessagesState.
    pass
class OutputState(MessagesState):
    """Output state schema for the main graph."""

    # No extra fields beyond MessagesState.
    pass
class SQLOutputState(MessagesState):
    """Output state schema for the SQL generation subgraph."""

    # User question after rewriting/normalization.
    rewrite_question: str
    # Tables the query was generated against.
    tables: list[dict[str, Any]]
    # Generated SQL statement.
    sql: str
    schema_info: dict[str, Any]  # Data schema analysis results
    data: str  # CSV data for display
    # Visualization DSL spec built from the query result.
    visualization_dsl: dict[str, Any]
================================================
FILE: openchatbi/llm/llm.py
================================================
import time
import traceback
from langchain_core.language_models import BaseChatModel
from langchain_core.runnables.base import RunnableBinding
from langchain_core.tools import StructuredTool
from openchatbi import config
from openchatbi.tool.ask_human import AskHuman
from openchatbi.utils import log
def list_llm_providers() -> list[str]:
    """List configured LLM provider names (if any)."""
    try:
        cfg = config.get()
    except ValueError:
        # Config not loaded yet -> nothing to list.
        return []
    providers = getattr(cfg, "llm_providers", None) or {}
    return sorted(providers)
def _get_provider_config(provider: str | None):
    """Resolve a provider name to its configured provider entry.

    Falls back to the globally selected `llm_provider` when no explicit name
    is given. Returns None when providers are not in use.

    Raises:
        ValueError: If the resolved name is not among configured providers.
    """
    cfg = config.get()
    providers = getattr(cfg, "llm_providers", None) or {}
    selected = provider or getattr(cfg, "llm_provider", None)
    if not selected:
        return None
    if selected not in providers:
        raise ValueError(f"Unknown llm_provider '{selected}'. Available: {sorted(providers.keys())}")
    return providers[selected]
def get_embedding_model(provider: str | None = None):
    """Get embedding model from config (optionally scoped to a provider)."""
    provider_cfg = _get_provider_config(provider)
    if provider_cfg:
        model = getattr(provider_cfg, "embedding_model", None)
        if model is not None:
            return model
    # No provider scoping, or the provider defines no embedding model.
    return config.get().embedding_model
def get_default_llm(provider: str | None = None):
    """Get default LLM from config (optionally scoped to a provider)."""
    provider_cfg = _get_provider_config(provider)
    # Provider-scoped default wins; otherwise fall back to the global default.
    return provider_cfg.default_llm if provider_cfg else config.get().default_llm
def get_llm(provider: str | None = None):
    """Get the chat model to use (alias for `get_default_llm`).

    Args:
        provider: Optional provider name to scope the lookup.
    """
    return get_default_llm(provider)
def get_text2sql_llm(provider: str | None = None):
    """Get text2sql LLM from config (optionally scoped to a provider)."""
    cfg = _get_provider_config(provider)
    if not cfg:
        # No provider scoping: global text2sql model, else the global default.
        return config.get().text2sql_llm or get_default_llm()
    # Provider-scoped: dedicated text2sql model, else the provider's default.
    return cfg.text2sql_llm or cfg.default_llm
def _invalid_tool_names(valid_tools, tool_calls) -> str:
invalid_tools = []
for tool in tool_calls:
if tool["name"] not in valid_tools:
invalid_tools.append(tool["name"])
return ",".join(invalid_tools)
def call_llm_chat_model_with_retry(
    chat_model: BaseChatModel, messages, streaming_tokens=False, bound_tools=None, parallel_tool_call=False
):
    """Calls a language model chat endpoint with retry logic.

    Retries up to 3 times if there are errors or invalid tool calls. On an
    invalid response, a corrective user message is appended before retrying so
    the model can self-correct.

    Args:
        chat_model: The chat model to invoke.
        messages (list): List of messages to send to the model.
        streaming_tokens (bool, optional): flag to indicate whether or not to show streaming tokens in UI.
        bound_tools (list, optional): List of valid tool names that can be called.
        parallel_tool_call (bool, optional): whether or not to call multiple tools in parallel.

    Returns:
        AIMessage or None: The model response or None if all retries failed.
    """
    # Work on a copy so corrective retry messages never mutate the caller's list.
    new_messages = list(messages)
    # Normalize bound_tools entries (plain names, StructuredTool instances, or the
    # AskHuman class) into a flat list of valid tool names.
    valid_tools = []
    if bound_tools:
        for tool in bound_tools:
            if isinstance(tool, str):
                valid_tools.append(tool)
            elif isinstance(tool, StructuredTool):
                valid_tools.append(tool.name)
            elif tool == AskHuman:
                valid_tools.append("AskHuman")
    # Fallback: no explicit bound_tools — read tool names off the model binding.
    elif isinstance(chat_model, RunnableBinding) and "tools" in chat_model.kwargs:
        valid_tools += [tool["name"] for tool in chat_model.kwargs["tools"] if "name" in tool]
    # Extra hint appended to corrective messages only when AskHuman is available.
    extra_prompt = (
        " Please select the `AskHuman` tool if you need to confirm with user." if "AskHuman" in valid_tools else ""
    )
    response = None
    retry = 0
    # retry 3 times
    while retry < 3:
        start_time = time.time()
        try:
            log(f"Call LLM chat model with retry {retry} times.")
            response = chat_model.invoke(new_messages, config={"metadata": {"streaming_tokens": streaming_tokens}})
            run_time = int(time.time() - start_time)
            log(f"LLM response after {run_time} seconds.")
        except Exception:
            # Any invocation failure consumes one retry; log context and try again.
            run_time = int(time.time() - start_time)
            retry += 1
            log(f"LLM response error after {run_time} seconds, retry {retry} times.")
            log("===== Messages:")
            log(str(messages))
            traceback.print_exc()
            continue
        if response.tool_calls:
            # Reject multi-tool responses when parallel tool calling is disabled.
            if len(response.tool_calls) > 1 and not parallel_tool_call:
                retry += 1
                log(f"More than one tool {response.tool_calls}, retry {retry} times.")
                new_messages += [{"role": "user", "content": "You should only response with one tool call."}]
                response = None
                continue
            # Reject responses that reference tools outside the allowed set.
            invalid_tools = _invalid_tool_names(valid_tools, response.tool_calls)
            if invalid_tools:
                retry += 1
                log(f"Invalid tool {invalid_tools}, retry {retry} times.")
                new_messages += [
                    {
                        "role": "user",
                        "content": f"You should not use tool that does not exist:`{invalid_tools}`."
                        f"Available tools are: {valid_tools}. Please choose a valid tool and try again."
                        f"{extra_prompt}",
                    }
                ]
                response = None
                continue
        # Valid response (no tool calls, or all tool calls valid) — stop retrying.
        break
    return response
================================================
FILE: openchatbi/prompts/agent_prompt.md
================================================
You are a helpful BI assistant that can answer user's question.
Use the instructions below and the tools available to you to assist the user.
# Capabilities:
1. Answer general question.
2. Answer question based on knowledge base.
3. Answer questions regarding data queries by calling the SQL graph to write SQL to answer the question.
4. Answer questions that need data analysis by writing and executing Python code.
# Guidelines:
- You should be concise, direct, and to the point.
- Do not fabricate information; if you don't know, just say you don't know.
- Summarize the information you found to answer the question.
- When data analysis results include "Visualization Created" message, acknowledge that an interactive chart has been automatically generated and focus on interpreting the data insights rather than creating additional charts.
# Tool usage policy
- If you cannot answer the question, call tools that are available.
- For `run_python_code` tool, you can use these libs when writing python code: pandas numpy matplotlib seaborn requests json5
- IMPORTANT: DO NOT create charts/visualizations with Python code if the text2sql tool response already indicates "Visualization Created". The interactive chart is automatically generated and displayed in the UI. Simply summarize the results without duplicating the visualization.
- If the user provides personalized information to remember, or wants to forget or correct something mentioned before, use the `manage_memory` tool to save, delete or update the long term memory
- If the question is related to user information, characteristic or preference, proactively use `search_memory` tool to get the long term memory
- If the question is not clear, or some information is missing, ask the user to clarify by calling AskHuman tool.
- When generating reports, analysis results, or data summaries that users might want to save or share, use the `save_report` tool to save the content to a file and provide a download link.
- **When text2sql tool returns empty SQL**: This indicates the current data capabilities cannot support the requested query. Explain to the user that the requested data or analysis is not available in the current system, and suggest alternative queries that might be supported based on available data sources.
## Knowledge Search Optimization
- **AVOID excessive knowledge searches** for data queries that contain standard business terms already covered in your basic knowledge
- **ONLY search knowledge** when:
- User asks about unfamiliar business terms, metrics, or dimensions not in basic knowledge
- Question contains ambiguous terminology that needs clarification
- Need to understand complex business relationships or derived metrics
- User explicitly asks "what is [term]" or requests definitions
- **SKIP knowledge search** for straightforward data queries since `text2sql` tool will handle it
- **Prioritize direct SQL execution** over knowledge lookup for routine data analysis requests
[extra_tool_use_rule]
# Basic Business Knowledge:
[basic_knowledge_glossary]
# Realtime Environment
Current time is [time_field_placeholder] (format 'yyyy-MM-dd HH:mm:ss')
Review current state and decide what to do next.
If the information is sufficient to answer the question, generate the well summarized final answer.
================================================
FILE: openchatbi/prompts/extraction_prompt.md
================================================
You are a specialized language expert responsible for analyzing user questions and extracting structured information for business intelligence queries.
Your task is to process natural language questions and convert them into structured data that can be used for SQL generation and data analysis.
# Context
You will be provided with:
- Business knowledge glossary of [organization]
- User question
- Chat history (if available)
[basic_knowledge_glossary]
# Core Processing Steps
## Step 1: Information Extraction
Extract and categorize the following information from the user's question and context:
### 1.1 Keywords (Required Array)
Extract all relevant business terms, including:
- Dimension names and aliases
- Metric names and aliases
- Entity types (exclude specific IDs/values)
**Example**: "Show revenue for order 10001" → Extract: ["revenue", "order"] (exclude "10001")
### 1.2 Dimensions (Required Array)
Identify categorical data fields that can be used for grouping or filtering:
- Database column names (e.g., "order_id", "country", "site_id")
- Distinguish between ID fields (numeric identifiers) and name fields (text labels)
### 1.3 Metrics (Optional Array)
Identify measurable quantities that can be aggregated:
- Numeric values that can be summed, averaged, counted, etc.
- For derived metrics (defined in glossary), extract all component parts
- Example: For "click-through rate", extract ["click-through rate", "clicks", "impressions"]
### 1.4 Time Range (Optional)
**start_time** and **end_time**: Convert relative time expressions to absolute timestamps if the question is related to date/time like trends, aggregated metric, etc.
- Format: `'%Y-%m-%d %H:%M:%S'`
- Handle expressions like "yesterday", "last 7 days", "from X to Y"
- Default to "last 7 days" if no time range and granularity specified
- Specific default if user mentioned granularity:
- Weekly -> "last 12 weeks"
- Monthly -> "last 12 months"
- Yearly -> "Full data"
**Example**:
```
Question: "show top 10 ads by CTR yesterday" (today = 2025-05-11)
start_time: "2025-05-10 00:00:00"
end_time: "2025-05-10 23:59:59"
```
### 1.5 Timezone (Optional)
Extract timezone information using this priority:
1. Explicit mention in current question (e.g., "in CET", "EST time")
2. Previously mentioned timezone in conversation history
3. Reset timezone requests → "UTC"
**Common formats**: "America/New_York", "CET", "UTC", "Europe/London"
## Step 2: Filter Conditions
Generate SQL-compatible filter expressions:
**Rules**:
- **Text matching**: Use `LIKE '%text%'` for partial name matches
- **Exact IDs**: Use `=` for numeric identifiers
- **Missing context**: Generate `AskHuman` tool call for clarification
**Examples**:
- "profile 1234" → `["profile_id=1234"]`
- "exam sites" → `["site_name LIKE '%exam%'"]`
- "the site" (no context) → Ask for clarification
## Step 3: Question Rewriting
Transform the original question into a clear, comprehensive query specification.
**Process**:
1. **Analysis**: Break down each component of the user's request
2. **Verification**: Confirm all elements are understood and unambiguous
3. **Rewrite**: Create detailed, explicit version with no ambiguity
**Enhancement Rules**:
- Add metric definitions in brackets: "CTR" → "click-through rate (clicks/impressions)"
- Include default time range if none specified
- Include visualization preference if provided by user
- Preserve user intent while adding necessary context
- Use conversation history to fill gaps
# Knowledge Search Decision
Before extracting information, determine if knowledge search is needed:
## When to Search Knowledge (use `search_knowledge` tool):
- **Unfamiliar terms**: Business-specific jargon, custom metrics, or domain acronyms not in basic knowledge
- **Ambiguous terminology**: Terms that could have multiple meanings in business context
- **Complex derived metrics**: Multi-component calculations requiring formula understanding
- **Explicit requests**: User asks "what is [term]" or requests definitions
## When to Skip Knowledge Search (proceed with JSON extraction):
- **Standard business terms**: Common metrics (revenue, orders, users, clicks, CTR, conversion rate)
- **Basic dimensions**: Standard fields (date, time, location, category, status, id)
- **Clear data requests**: Simple queries with well-understood terminology
- **Routine analytics**: Top N, totals, averages, trends with common business terms
**Decision rule**: Only search knowledge if you encounter terms that are NOT covered in your basic business knowledge or if terminology is genuinely ambiguous in the business context.
# Output Format
Return a JSON object with the following structure:
```json
{
"reasoning": "Step-by-step analysis of user input and decision-making process",
"keywords": ["array", "of", "extracted", "keywords"],
"dimensions": ["array", "of", "dimension", "names"],
"metrics": ["array", "of", "metric", "names"],
"filter": ["array", "of", "sql", "expressions"],
"start_time": "YYYY-MM-DD HH:MM:SS",
"end_time": "YYYY-MM-DD HH:MM:SS",
"timezone": "timezone_identifier",
"rewrite_question": "Complete and detailed question rewrite"
}
```
# Quality Guidelines
## Data Consistency
- If a dimension appears in filters, include it in the dimensions array
- Extract all aliases for derived metrics as defined in the glossary
## Accuracy Rules
- **No fabrication**: Only use information present in context or glossary
- **Prioritization**: Current question takes precedence over chat history
- **Completeness**: Use chat history to fill gaps when current question lacks detail
## Output Formatting
- **Standard response**: JSON wrapped in ```json code blocks
- **Clarification needed**: Generate `AskHuman` tool call instead of JSON
- **Required fields**: Always include `reasoning`, `keywords`, `dimensions`, `filter`, `rewrite_question`
# Comprehensive Example
**Input Question**: "Show me site 1001's CTR trend from 2024-04-01 to 2024-04-10"
**Expected Output**:
```json
{
"reasoning": "User wants to analyze click-through rate trends for a specific site. Breaking down the request: 1) Site identifier: 1001 (numeric ID), 2) Metric: CTR (click-through rate, calculated as clicks/impressions), 3) Analysis type: trend (time-based progression), 4) Time range: 2024-04-01 to 2024-04-10 (9-day period). Since it's a short time range, hourly granularity is most appropriate for trend analysis. All components are clear and complete.",
"keywords": ["site", "click-through rate", "CTR", "clicks", "impressions", "trend"],
"dimensions": ["site_id"],
"metrics": ["click-through rate", "clicks", "impressions"],
"filter": ["site_id=1001"],
"start_time": "2024-04-01 00:00:00",
"end_time": "2024-04-11 00:00:00",
"rewrite_question": "Show me the hourly click-through rate (calculated as clicks/impressions) trend for site_id = 1001 from 2024-04-01 to 2024-04-10"
}
```
# Special Cases
## Case 1: Insufficient Information
**Input**: "Show me revenue trends for the site"
**Action**: Generate `AskHuman` tool call requesting site identification
## Case 2: Conversation Context Usage
**Previous**: "Let's analyze site ABC performance"
**Current**: "Show me CTR for last week"
**Result**: Inherit site "ABC" context
## Case 3: Timezone Handling
**Input**: "Yesterday's metrics in EST"
**Result**: Extract timezone="America/New_York", calculate yesterday in EST
# Environment Variables
- Current date: `[time_field_placeholder]`
================================================
FILE: openchatbi/prompts/schema_linking_prompt.md
================================================
You are a language expert and professional SQL engineer tasked with analyzing questions from [organization] users and selecting the appropriate table to write SQL.
- You need to analyze the user's question, find the possible dimensions and metrics, and then select the tables and all required columns related to the query.
- I will give you the business knowledge introduction and the glossaries of [organization] for reference.
- I will give you the data warehouse introduction about how these tables are generated and organized.
- I will give you the candidate tables and their schema, read the table description and rule carefully to understand the purpose and capability of the table, and select the appropriate tables and columns.
[basic_knowledge_glossary]
[data_warehouse_introduction]
# Candidate Tables
I found the following tables and their relevant columns and descriptions that might contain the data the user is looking for.
[tables]
# Examples
Here are some examples of questions and selected tables related to the user's question
[examples]
# General Rules
- Must follow the table description and rule to select the table first
- If it is not clear which table to select, you can check the columns in the table to find the columns most related to the question
- The "Candidate Tables" contain all the tables and columns you can use, NEVER make up columns or tables.
- VERY IMPORTANT: the columns you outputted **MUST** be contained in the table you selected, as described in the "# Candidate Tables" section.
- If the question is asking about the metadata of an entity only, you should find a suitable dimension table
- If the question needs to join the fact table with the dimension table, you should also output the dimension table
- If there are very similar questions in examples, you can refer to the selected tables in examples.
- If there are multiple tables that both meet the requirements, you should select the most relevant one.
- Select and output multiple tables when single table do not contain all fields and need join from multiple tables.
# Output Format
You should output a JSON object, it should include:
- tables: JSON array of selected tables and columns
- table: The selected table
- columns: The columns in the table that are related to the question
- reasoning: The reasoning behind the table selection
Strictly only output the format of JSON below, and do not output any extra description content.
## Example
```json
{
"reasoning": "the reason you select the two tables and columns",
"tables": [
{
"table": "table_name1",
"columns": ["column1", "column2", "column3"]
},
{
"table": "table_name2",
"columns": ["column4", "column5"]
}]
}
```
================================================
FILE: openchatbi/prompts/sql_dialect/presto.md
================================================
# Rules for Presto SQL
- Use 'LIKE' instead of 'ILIKE' in the Presto SQL.
- If there is a 'GROUP BY' clause in the user query, you can use serial number(1,2,3..) instead of the column names in the 'GROUP BY' clause.
- When filter Array type dimension, use ARRAYS_OVERLAP, e.g. ARRAYS_OVERLAP(states, ARRAY['CA'])
- If you have to write two SQL statements, ensure to separate them with a semicolon `;`.
- If there is no 'limit' or 'top' count mentioned in user question, default use "LIMIT 10000" in the SQL query.
- If the SQL you provide includes fuzzy matching filters (e.g., 'name LIKE ...'), you should apply a GROUP BY clause for this dimension to handle cases where multiple rows have similar names.
- If a table name is referenced multiple times in the SQL query (e.g., table_name.column), assign an alias to the table to simplify the query, such as (a.column).
- If you need to write an SQL statement that involves division between two columns, use CAST to convert the numerator to a double type and ensure the denominator is not zero. For example: CASE WHEN SUM(column2)=0 THEN 0 ELSE CAST(SUM(column1) AS DOUBLE) / SUM(column2) END
## Datetime filter related rules
- Please use "INTERVAL '7' DAY" instead of "INTERVAL '1' WEEK" in the presto sql you given.
- Do not use "DATE_SUB" or "DATE_ADD", You can use only datetime calculation like "NOW() - INTERVAL '1' DAY".
- Use `NOW()` instead of `CURRENT_DATE` when you need to get the current date.
- If user ask date range measured in days or months, you need to make sure the end date is included, example: "from 2025-03-12 to 2025-03-16", the condition should be `WHERE event_date >= timestamp '2025-03-12 00:00:00' and event_date < timestamp '2025-03-17 00:00:00'`
- If user ask for time range measured in hours, the end time is not included, example: "from 2025-03-12 00:00:00 to 2025-03-16 00:00:00", the condition should be `WHERE event_date >= timestamp '2025-03-12 00:00:00' and event_date < timestamp '2025-03-16 00:00:00'`
- If user ask for daily breakdown, make sure the whole day are correctly filtered, e.g. "last 7 days per day" should be `WHERE event_date >= DATE_TRUNC('day', (NOW() - INTERVAL '7' DAY)) AND event_date < DATE_TRUNC('day', NOW())`
## Rules for Timezone
### 1. Default Timezone
All event_date in the table are stored in **UTC**. If the user specifies a timezone (e.g., CET, PST), convert between timezones accordingly.
### 2. Timezone Conversion Syntax
- Use `AT TIME ZONE` to convert event_date to other timezone. Example, to convert to CET: `event_date_expr AT TIME ZONE 'CET'`
- Use the `with_timezone` function to define a constant timestamp with timezone. Example, 2025-05-06 00:00 at CET: `with_timezone(timestamp '2025-05-06 00:00:00', 'CET')`
### 3. WHERE Clause Conversion
- If the user query includes absolute(constant) filters with a specific timezone, convert the timestamp with user timezone to UTC. Keep relative time filters unchanged.
- Example: `WHERE event_date >= timestamp '2025-01-01 00:00:00' and event_date < NOW() - INTERVAL '1' DAY` ->
`WHERE event_date >= with_timezone(timestamp '2025-01-01 00:00:00', 'CET') AT TIME ZONE 'UTC' AND event_date < NOW() - INTERVAL '1' DAY`
- Instruction when user ask for daily breakdown with timezone
- If ask for relative date, the filter condition should use the date as "trunc date at that timezone first, then convert to UTC"
- Example: "last 7 days per day in NY time" should be `AND event_date >= DATE_TRUNC('day', (NOW() - INTERVAL '7' DAY) AT TIME ZONE 'America/New_York') AT TIME ZONE 'UTC' AND event_date < DATE_TRUNC('day', NOW() AT TIME ZONE 'America/New_York') AT TIME ZONE 'UTC'`
- If ask for absolute(constant) date, the filter condition should convert the 00:00 timestamp with user timezone to UTC
- Example: "during 2025-05-04 and 2025-05-14 per day in NY time" should be `AND event_date >= with_timezone(timestamp '2025-05-04 00:00:00', 'America/New_York') AT TIME ZONE 'UTC' AND event_date < with_timezone(timestamp '2025-05-15 00:00:00', 'America/New_York') AT TIME ZONE 'UTC'`
### 4. SELECT Clause Conversion
- If applying a function f to a datetime column (e.g.date_trunc, format_datetime), convert the event_date from UTC to the user’s timezone before applying f, then cast the result as TIMESTAMP.
- Example: `SELECT f(event_date) AS event_date`-> `SELECT CAST(f(event_date AT TIME ZONE 'CET') AS TIMESTAMP) AS event_date`
### 5. Full Example
- User Question: "Show me hourly pv using table fact_table from 2025-01-01 to yesterday in CET"
- Generated SQL:
```
SELECT
CAST(date_trunc('hour', event_date AT TIME ZONE 'CET') AS TIMESTAMP) AS event_date,
SUM(pv) AS "PV"
FROM fact_table
WHERE
event_date >= with_timezone(timestamp '2025-01-01 00:00:00', 'CET') AT TIME ZONE 'UTC'
AND event_date < (NOW() - INTERVAL '1' DAY)
```
## Rules for Array Dimension
- Filtering: Use ARRAYS_OVERLAP
- When filter value in Array type dimension , use ARRAYS_OVERLAP, e.g. ARRAYS_OVERLAP(states, ARRAY['CA'])
- Flattening Arrays: Use CROSS JOIN UNNEST
- When flattening Array type dimension items from a row into multiple rows, use `CROSS JOIN UNNEST(COALESCE(NULLIF(id_array, ARRAY[]), ARRAY[-1])) AS t(id)`; MAKE SURE the unnested alias has format like `t(id)`
- Additionally, when the subquery uses CROSS JOIN UNNEST, do not sum the metrics for the total array items without group by the unnested id.
- Avoid UNNEST if the user didn’t request it
- If the user refers to an array-type dimension without specifying any particular item, or does not ask to expand it into individual elements, you should not use CROSS JOIN UNNEST.
================================================
FILE: openchatbi/prompts/summary_prompt.md
================================================
Create a concise summary of this conversation for continuing the data analysis work. Focus on:
1. **User's Main Questions and Objectives**: What data insights or analysis the user is seeking, business questions they want answered
2. **Key Data Analysis Results**: Important findings from SQL queries, relevant tables/columns discovered, key metrics or patterns identified
3. **Tools and Data Sources Overview**:
- Key databases/tables accessed
- Main analysis tools used (SQL, Python, etc.)
- Data export formats generated
4. **Business Context**: Important business concepts, domain-specific terms, data definitions that were clarified
5. **Conversation Flow**: List ALL user messages with corresponding response summaries. These are critical for understanding user feedback and changing intent:
- User message -> Response summary (what was done/analyzed)
- User feedback -> How the analysis was adjusted
- Follow-up questions -> Additional insights provided
6. **Current Progress**: What analysis was completed, any ongoing tasks, user feedback on results or requested modifications
Here's an example of how your output should be structured:
1. **User's Main Questions and Objectives**:
[User's data analysis goals and business questions]
2. **Key Data Analysis Results**:
- [Important SQL query results]
- [Key tables and relationships discovered]
- [Metrics and insights found]
3. **Tools and Data Sources Overview**:
- [Databases: customer_db, sales_warehouse]
- [Analysis tools: SQL, Python pandas]
- [Exports: CSV reports, dashboard charts]
4. **Business Context**:
- [Business concepts and terminology]
- [Data field definitions]
- [Domain knowledge gained]
5. **Conversation Flow**:
- [User message 1: original request] -> [Response: SQL query executed, found X insights]
- [User message 2: clarification about metric Y] -> [Response: adjusted analysis, discovered Z pattern]
- [User message 3: follow-up question] -> [Response: additional exploration, provided interpretation]
6. **Current Progress**:
[What was completed, ongoing work, and user feedback]
Conversation to summarize:
[conversation_text]
Please provide your summary based on the conversation so far, following this structure and ensuring precision and thoroughness in your response.
================================================
FILE: openchatbi/prompts/system_prompt.py
================================================
"""System prompt templates and business configuration."""
import importlib.resources
from openchatbi import config
# Global cache variables for lazy loading (only for file I/O operations)
# None means "not loaded yet"; each accessor below fills its cache on first use
# and reset_cache() clears them all.
_dialect_rules_cache = None  # dict mapping dialect name -> rules markdown
_agent_prompt_template_cache = None
_extraction_prompt_template_cache = None
_table_selection_prompt_template_cache = None
_text2sql_prompt_template_cache = None
_visualization_prompt_template_cache = None
_summary_prompt_template_cache = None
def get_basic_knowledge():
    """Return the basic knowledge glossary from config ("" when unavailable)."""
    try:
        bi_cfg = config.get().bi_config
    except ValueError:
        # Config not initialized — fall back to an empty glossary.
        return ""
    return bi_cfg.get("basic_knowledge_glossary", "")
def get_data_warehouse_introduction():
    """Return the data warehouse introduction from config ("" when unavailable)."""
    try:
        bi_cfg = config.get().bi_config
    except ValueError:
        # Config not initialized — no introduction to provide.
        return ""
    return bi_cfg.get("data_warehouse_introduction", "")
def get_agent_extra_tool_use_rule():
    """Return the agent's extra tool-use rule from config ("" when unavailable)."""
    try:
        bi_cfg = config.get().bi_config
    except ValueError:
        # Config not initialized — no extra rule.
        return ""
    return bi_cfg.get("extra_tool_use_rule", "")
def get_organization():
    """Return the organization name from config, defaulting to "The Company"."""
    try:
        cfg = config.get()
    except ValueError:
        # Config not initialized — use a generic placeholder name.
        return "The Company"
    return cfg.organization
def get_dialect_rules():
    """Return SQL dialect rules keyed by dialect name, lazily loaded and cached."""
    global _dialect_rules_cache
    if _dialect_rules_cache is not None:
        return _dialect_rules_cache
    rules = {}
    dialect_dir = importlib.resources.files("openchatbi.prompts.sql_dialect")
    for entry in dialect_dir.iterdir():
        # Only markdown files define dialect rules.
        if not (entry.is_file() and entry.name.endswith(".md")):
            continue
        with entry.open() as fh:
            # Key by file name without the ".md" suffix.
            rules[entry.name[:-3]] = fh.read()
    _dialect_rules_cache = rules
    return _dialect_rules_cache
def get_agent_prompt_template() -> str:
    """Return the agent prompt template with placeholders filled, cached."""
    global _agent_prompt_template_cache
    if _agent_prompt_template_cache is None:
        resource = importlib.resources.files("openchatbi.prompts").joinpath("agent_prompt.md")
        with resource.open("r") as fh:
            raw = fh.read()
        # Substitute business-specific placeholders sourced from config.
        raw = raw.replace("[organization]", get_organization())
        raw = raw.replace("[basic_knowledge_glossary]", get_basic_knowledge())
        _agent_prompt_template_cache = raw.replace("[extra_tool_use_rule]", get_agent_extra_tool_use_rule())
    return _agent_prompt_template_cache
def get_extraction_prompt_template() -> str:
    """Return the extraction prompt template with placeholders filled, cached."""
    global _extraction_prompt_template_cache
    if _extraction_prompt_template_cache is None:
        resource = importlib.resources.files("openchatbi.prompts").joinpath("extraction_prompt.md")
        with resource.open("r") as fh:
            raw = fh.read()
        raw = raw.replace("[organization]", get_organization())
        _extraction_prompt_template_cache = raw.replace("[basic_knowledge_glossary]", get_basic_knowledge())
    return _extraction_prompt_template_cache
def get_table_selection_prompt_template() -> str:
    """Return the table selection (schema linking) prompt template, cached."""
    global _table_selection_prompt_template_cache
    if _table_selection_prompt_template_cache is None:
        resource = importlib.resources.files("openchatbi.prompts").joinpath("schema_linking_prompt.md")
        with resource.open("r") as fh:
            raw = fh.read()
        raw = raw.replace("[organization]", get_organization())
        _table_selection_prompt_template_cache = raw.replace("[basic_knowledge_glossary]", get_basic_knowledge())
    return _table_selection_prompt_template_cache
def get_text2sql_prompt_template() -> str:
    """Return the text2sql prompt template with placeholders filled, cached."""
    global _text2sql_prompt_template_cache
    if _text2sql_prompt_template_cache is None:
        resource = importlib.resources.files("openchatbi.prompts").joinpath("text2sql_prompt.md")
        with resource.open("r") as fh:
            raw = fh.read()
        raw = raw.replace("[organization]", get_organization())
        raw = raw.replace("[basic_knowledge_glossary]", get_basic_knowledge())
        _text2sql_prompt_template_cache = raw.replace(
            "[data_warehouse_introduction]", get_data_warehouse_introduction()
        )
    return _text2sql_prompt_template_cache
def get_visualization_prompt_template() -> str:
    """Return the visualization prompt template, read once and cached."""
    global _visualization_prompt_template_cache
    if _visualization_prompt_template_cache is None:
        resource = importlib.resources.files("openchatbi.prompts").joinpath("visualization_prompt.md")
        # No placeholder substitution needed for this template.
        _visualization_prompt_template_cache = resource.read_text()
    return _visualization_prompt_template_cache
def get_summary_prompt_template() -> str:
    """Return the conversation summary prompt template, read once and cached."""
    global _summary_prompt_template_cache
    if _summary_prompt_template_cache is None:
        resource = importlib.resources.files("openchatbi.prompts").joinpath("summary_prompt.md")
        # No placeholder substitution needed for this template.
        _summary_prompt_template_cache = resource.read_text()
    return _summary_prompt_template_cache
def get_text2sql_dialect_prompt_template(dialect: str) -> str:
    """Return the text2sql prompt specialized for one SQL dialect."""
    base = get_text2sql_prompt_template()
    if not base:
        # Minimal fallback template when the packaged prompt is unavailable.
        base = "Generate SQL query for the given question in [dialect] dialect."
    rules = get_dialect_rules().get(dialect, "")
    return base.replace("[dialect]", dialect).replace("[sql_dialect_rules]", rules)
def reset_cache():
    """Reset all cached prompt values so the next access re-reads them. Useful for testing."""
    global _dialect_rules_cache, _agent_prompt_template_cache
    global _extraction_prompt_template_cache, _table_selection_prompt_template_cache
    global _text2sql_prompt_template_cache, _visualization_prompt_template_cache
    global _summary_prompt_template_cache
    # Clear every cache in one unpacking assignment.
    (
        _dialect_rules_cache,
        _agent_prompt_template_cache,
        _extraction_prompt_template_cache,
        _table_selection_prompt_template_cache,
        _text2sql_prompt_template_cache,
        _visualization_prompt_template_cache,
        _summary_prompt_template_cache,
    ) = (None,) * 7
================================================
FILE: openchatbi/prompts/text2sql_prompt.md
================================================
You are a professional SQL engineer, your task is to transform user query into [dialect] SQL.
- I will give you the business knowledge introduction and the glossaries of [organization] for reference.
- I will give you the selected tables, you need to analyze the user query, read the table description, schema, constraints and examples carefully to write [dialect] SQL to answer user's question.
- You are a read-only analytics assistant. NEVER generate DELETE, DROP, UPDATE, or INSERT statements.
[basic_knowledge_glossary]
# Tables
[table_schema]
# Examples
[examples]
# Rules for [dialect] SQL
[sql_dialect_rules]
# Rules for Task
- I will provide you with data schema definition and the explanation and usage scenario of each field.
- You can only use the tables listed in "# Tables".
- You can only use the metrics, dimension, columns from the schema I provided.
- You should only use the display name as alias in query if provided in schema.
- Never create or assume additional tables or columns, even if they were mentioned in history message.
- Do not use any id or date in example SQL.
- Do not output any explanations or comment.
- If the query asks for a metric or field not explicitly defined in the table schema, do not generate a SQL query with an invented field, instead, you should output "NULL".
- You can only answer when you are very confident, otherwise, please output "NULL"
# Output format(case sensitive)
```sql
```
# Realtime Environment
Current time is [time_field_placeholder] (format 'yyyy-MM-dd HH:mm:ss')
Based on the Tables, Columns, take your time to think user query carefully, transform it into [dialect] SQL and reply following Output format.
================================================
FILE: openchatbi/prompts/visualization_prompt.md
================================================
You are a data visualization expert. Analyze the user's question and data to recommend the most appropriate chart type.
## User Question
[question]
## Data Schema
- Columns: [columns]
- Numeric columns: [numeric_columns]
- Categorical columns: [categorical_columns]
- DateTime columns: [datetime_columns]
- Row count: [row_count]
## Data Sample
[data_sample]
## Available Chart Types
1. **line** - For trends over time, time series data
2. **bar** - For comparing categories, discrete comparisons
3. **pie** - For showing proportions, parts of a whole (best for <= 6 categories)
4. **scatter** - For showing relationships between two numeric variables
5. **histogram** - For showing distribution of a single numeric variable
6. **box** - For showing statistical distribution, outliers, quartiles
7. **heatmap** - For showing correlation or intensity across two dimensions
8. **table** - For detailed data examination, small datasets, or when charts aren't suitable
## Analysis Guidelines
Consider:
- The user's intent and question keywords
- Data types and structure
- Number of data points and categories
- What insights the user is likely seeking
## Response Format
Respond with ONLY the chart type name (line, bar, pie, scatter, histogram, box, heatmap, or table). No explanation needed.
================================================
FILE: openchatbi/text2sql/__init__.py
================================================
"""Text-to-SQL conversion module for OpenChatBI."""
================================================
FILE: openchatbi/text2sql/data.py
================================================
import os
from openchatbi import config
from openchatbi.text2sql.text2sql_utils import init_sql_example_retriever, init_table_selection_example_dict
# Resolve the catalog store once at import time. During a Sphinx documentation
# build the config may be unavailable, so initialization is skipped entirely.
_catalog_store = None
if not os.environ.get("SPHINX_BUILD"):
    try:
        _catalog_store = config.get().catalog_store
    except ValueError:
        # Config not loaded yet; fall through to the inert placeholders below.
        _catalog_store = None

if _catalog_store:
    sql_example_retriever, sql_example_dicts = init_sql_example_retriever(_catalog_store, config.get().vector_db_path)
    table_selection_retriever, table_selection_example_dict = init_table_selection_example_dict(
        _catalog_store, config.get().vector_db_path
    )
else:
    # No catalog available: expose empty retrievers so importing this module
    # still succeeds and callers can detect the missing setup.
    sql_example_retriever, sql_example_dicts = None, {}
    table_selection_retriever, table_selection_example_dict = None, {}
================================================
FILE: openchatbi/text2sql/extraction.py
================================================
"""Information extraction module for text2sql processing."""
import traceback
from collections.abc import Callable
from datetime import date
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from openchatbi.graph_state import SQLGraphState
from openchatbi.llm.llm import call_llm_chat_model_with_retry
from openchatbi.prompts.system_prompt import get_basic_knowledge, get_extraction_prompt_template
from openchatbi.utils import extract_json_from_answer, get_text_from_content, log
def generate_extraction_prompt() -> str:
    """Build the system prompt for information extraction.

    Substitutes today's date and the basic knowledge glossary into the
    extraction prompt template.

    Returns:
        str: The extraction prompt with all placeholders filled in.
    """
    today = date.today().strftime("%Y-%m-%d")
    return (
        get_extraction_prompt_template()
        .replace("[time_field_placeholder]", today)
        .replace("[basic_knowledge_glossary]", get_basic_knowledge())
    )
def parse_extracted_info_json(llm_answer_content: Any) -> dict[str, Any]:
    """Extract and parse a JSON object out of an LLM response payload.

    Args:
        llm_answer_content: Raw LLM response content (plain text or structured parts).

    Returns:
        dict: The parsed JSON object, or an empty dict when parsing fails.
    """
    try:
        return extract_json_from_answer(get_text_from_content(llm_answer_content))
    except Exception:
        # Parsing failures are non-fatal: log the traceback and return no entities.
        log(traceback.format_exc())
        return {}
def information_extraction(llm: BaseChatModel) -> Callable:
    """Create a graph node that extracts structured information from questions.

    Args:
        llm (BaseChatModel): Language model (with tools bound) used for extraction.

    Returns:
        function: Node function that extracts information from the conversation.
    """

    def _extract(state: SQLGraphState):
        """Run information extraction over the conversation held in ``state``.

        Args:
            state (SQLGraphState): Current SQL graph state with messages.

        Returns:
            dict: State update carrying the LLM response and, when no tool call
                was issued, the rewritten question and extracted entities.
        """
        messages = state["messages"]
        log(f"information_extraction: {messages[-1].content}")
        chat_input = (
            [SystemMessage(generate_extraction_prompt())]
            + messages
            + [HumanMessage("Please extract the information according to the context.")]
        )
        response = call_llm_chat_model_with_retry(llm, chat_input, ["search_knowledge", "AskHuman"])
        if not response:
            # LLM call failed after retries; emit an empty JSON payload.
            return {"messages": [AIMessage(role="system", content="{}")]}
        log(response)
        if response.tool_calls:
            # A tool was requested; routing is handled by the conditional edge.
            return {"messages": [response]}
        parsed = parse_extracted_info_json(response.content)
        return {
            "messages": [response],
            "rewrite_question": parsed.get("rewrite_question"),
            "info_entities": parsed,
        }

    return _extract
def information_extraction_conditional_edges(state: SQLGraphState):
    """Determine next node after information extraction.

    Args:
        state (SQLGraphState): Current SQL graph state.

    Returns:
        str: Next node ('ask_human', 'search_knowledge', 'next', or 'end').
    """
    last_message = state["messages"][-1]
    tool_calls = None
    if isinstance(last_message, AIMessage):
        tool_calls = last_message.tool_calls
    log(f"tool_calls: {tool_calls}")
    if tool_calls:
        # Only the first tool call is considered for routing.
        tool_name = tool_calls[0]["name"]
        if tool_name == "AskHuman":
            return "ask_human"
        if tool_name == "search_knowledge":
            return "search_knowledge"
        # Fix: use the module-wide log helper instead of a stray print().
        log(f"Unknown tool call: {tool_name}")
        return "end"
    # No tool call: continue only when extraction produced a rewritten question.
    return "next" if "rewrite_question" in state else "end"
================================================
FILE: openchatbi/text2sql/generate_sql.py
================================================
import datetime
from collections.abc import Callable
from typing import Any
import pandas as pd
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from sqlalchemy import text
from sqlalchemy.exc import DatabaseError, OperationalError, ProgrammingError, TimeoutError
from openchatbi.catalog import CatalogStore
from openchatbi.constants import (
SQL_EXECUTE_TIMEOUT,
SQL_NA,
SQL_SUCCESS,
SQL_SYNTAX_ERROR,
SQL_UNKNOWN_ERROR,
datetime_format,
)
from openchatbi.graph_state import SQLGraphState
from openchatbi.prompts.system_prompt import get_text2sql_dialect_prompt_template
from openchatbi.text2sql.data import sql_example_dicts, sql_example_retriever
from openchatbi.text2sql.visualization import VisualizationService
from openchatbi.utils import get_text_from_content, log
# Template for the per-table "Columns" section of the text2sql system prompt;
# the single placeholder is filled with one Column(...) line per column.
COLUMN_PROMPT_TEMPLATE = """### Columns
Column(Name, Type, Display Name, Description):
[
{}
]
"""
def create_sql_nodes(
    llm: BaseChatModel, catalog: CatalogStore, dialect: str, visualization_mode: str | None = "rule"
) -> tuple[Callable, Callable, Callable, Callable]:
    """Creates the four SQL processing nodes for LangGraph.

    Args:
        llm (BaseChatModel): The language model to use for SQL generation.
        catalog (CatalogStore): The catalog store containing schema information.
        dialect (str): The SQL dialect to use (e.g., 'presto', 'mysql').
        visualization_mode (str | None): Visualization analysis mode ("rule", "llm", or None to skip).

    Returns:
        tuple: Four node functions (generate_sql_node, execute_sql_node, regenerate_sql_node,
            generate_visualization_node)
    """
    # Initialize visualization service based on configuration; the LLM is only
    # handed over when chart selection is delegated to the model.
    visualization_service = VisualizationService(llm if visualization_mode == "llm" else None)

    def _get_column_prompt(column: dict[str, Any]) -> str:
        """Format one catalog column as a Column(...) line for the prompt."""
        alias_prompt = f"alias({column['alias']})" if "alias" in column and column["alias"] else ""
        return (
            f""" Column("{column['column_name']}", {column['type']}, {column['display_name']},"""
            f""" "{alias_prompt}{column['description']}"),"""
        )

    def _get_table_schema_prompt(tables_columns: list[dict[str, Any]]) -> str:
        """Generates a prompt string for table schemas, including table description,
        columns, derived metrics and rules when writing SQL.

        Args:
            tables_columns (List[Dict[str, Any]]): List of tables with selected columns.

        Returns:
            str: Formatted table schema prompt string.
        """
        schema_prompt = []
        for table_dict in tables_columns:
            table_name = table_dict["table"]
            # TODO: consider restricting the prompt to the columns selected in
            # table_dict["columns"] instead of listing the full catalog column list.
            table_info = catalog.get_table_information(table_name)
            single_table_schema_prompt = f"## Table {table_name}\n{table_info['description']}\n"
            columns = catalog.get_column_list(table_name)
            single_table_schema_prompt += COLUMN_PROMPT_TEMPLATE.format(
                "\n".join([_get_column_prompt(column) for column in columns])
            )
            single_table_schema_prompt += table_info.get("derived_metric", "")
            single_table_schema_prompt += table_info["sql_rule"]
            schema_prompt.append(single_table_schema_prompt)
        return "\n".join(schema_prompt)

    def _get_relevant_sql_examples_prompt(question, tables_columns: list[dict[str, Any]]) -> str:
        """Retrieves relevant SQL examples based on the question and selected tables.

        Args:
            question (str): The natural language question.
            tables_columns (List[Dict[str, Any]]): List of selected tables with selected columns.

        Returns:
            str: Formatted string of relevant SQL examples.
        """
        tables = [d["table"] for d in tables_columns]
        relevant_questions = sql_example_retriever.invoke(question)
        # Keep only examples whose referenced tables are all among the selected ones.
        examples = []
        for relevant_document in relevant_questions:
            question = relevant_document.page_content
            example_sql, used_tables = sql_example_dicts[question]
            if all(table in tables for table in used_tables):
                examples.append(f"\nQ: {question}\nA: {example_sql}\n\n")
        log(f"Examples using selected tables: {examples}")
        return "\n".join(examples)

    def _analyze_dataframe_schema(df: pd.DataFrame) -> dict[str, Any]:
        """Analyze DataFrame to understand column types and characteristics.

        Returns:
            dict: Schema info (column lists by kind, dtypes, row count, unique
                counts), or {"error": ...} when analysis fails.
        """
        try:
            schema_info = {
                "columns": list(df.columns),
                "column_types": {},
                "row_count": len(df),
                "numeric_columns": [],
                "categorical_columns": [],
                "datetime_columns": [],
            }
            for col in df.columns:
                dtype = str(df[col].dtype)
                schema_info["column_types"][col] = dtype
                # Classify column types
                if df[col].dtype in ["int64", "float64", "int32", "float32"]:
                    schema_info["numeric_columns"].append(col)
                elif df[col].dtype == "object":
                    # Probe the first few values to see if the column parses as datetime.
                    try:
                        pd.to_datetime(df[col].head(10))
                        schema_info["datetime_columns"].append(col)
                    except Exception:
                        # Fix: was a bare `except:`, which also swallowed
                        # KeyboardInterrupt/SystemExit.
                        schema_info["categorical_columns"].append(col)
            # Calculate unique value counts for categorical columns
            schema_info["unique_counts"] = {}
            for col in schema_info["categorical_columns"]:
                schema_info["unique_counts"][col] = df[col].nunique()
            return schema_info
        except Exception as e:
            return {"error": f"Failed to analyze data schema: {str(e)}"}

    def _execute_sql(sql: str) -> tuple[dict, str]:
        """Executes the generated SQL query and returns the result with schema analysis.

        Args:
            sql (str): The SQL query to execute.

        Returns:
            Tuple[dict, str]: A tuple containing (schema_info, CSV string).
        """
        with catalog.get_sql_engine().connect() as connection:
            result = connection.execute(text(sql))
            # Fetch all rows from the result
            rows = result.fetchall()
            # Get column names
            columns = list(result.keys())
            # Create DataFrame for analysis
            df = pd.DataFrame(rows, columns=columns)
            # Analyze data schema
            schema_info = _analyze_dataframe_schema(df)
            # Format as CSV
            csv_data = df.to_csv(index=False)
            connection.commit()
        return schema_info, csv_data

    def generate_sql_node(state: SQLGraphState) -> dict:
        """First node: Generates initial SQL query based on the state.

        Args:
            state (SQLGraphState): The current SQL graph state containing the question and tables.

        Returns:
            dict: Updated state with generated SQL query.
        """
        if "rewrite_question" not in state:
            log("Missing rewrite question, skipping SQL generation.")
            return {}
        if "tables" not in state or len(state["tables"]) == 0:
            log("Missing tables, skipping SQL generation.")
            return {}
        question = state["rewrite_question"]
        tables_columns = state["tables"]
        system_prompt = (
            get_text2sql_dialect_prompt_template(dialect)
            .replace("[table_schema]", _get_table_schema_prompt(tables_columns))
            .replace("[examples]", _get_relevant_sql_examples_prompt(question, tables_columns))
            .replace("[time_field_placeholder]", datetime.datetime.now().strftime(datetime_format))
        )
        user_prompt = f"""Generate a SQL query for the question: {question}"""
        messages = [SystemMessage(system_prompt)] + list(state["messages"]) + [HumanMessage(user_prompt)]
        response = llm.invoke(messages)
        response_content = get_text_from_content(response.content)
        # Strip markdown code fences from the model output.
        sql_query = response_content.replace("```sql", "").replace("```", "").strip()
        if not sql_query or sql_query.lower() == "null":
            log(f"Generated SQL query is empty. LLM output: {response.content}")
            # Surface the raw model answer so downstream nodes/users see why.
            return {
                "messages": [AIMessage(response_content)],
                "sql": sql_query,
                "sql_retry_count": 0,
                "sql_execution_result": "",
                "previous_sql_errors": [],
            }
        return {"sql": sql_query, "sql_retry_count": 0, "sql_execution_result": "", "previous_sql_errors": []}

    def execute_sql_node(state: SQLGraphState) -> dict:
        """Second node: Executes the SQL query and returns result or error.

        Args:
            state (SQLGraphState): The current SQL graph state containing the SQL query.

        Returns:
            dict: Updated state with execution result or error information.
        """
        sql_query = state.get("sql", "").strip()
        if not sql_query:
            return {"sql_execution_result": SQL_NA, "messages": [AIMessage("No SQL query to execute")]}
        try:
            schema_info, csv_result = _execute_sql(sql_query)
            result = f"```sql\n{sql_query}\n```\nSQL Result:\n```csv\n{csv_result}\n```"
            return {
                "sql_execution_result": SQL_SUCCESS,
                "schema_info": schema_info,
                "data": csv_result,
                "messages": [AIMessage(result)],
            }
        except (OperationalError, TimeoutError) as e:
            # Connectivity/timeout errors are terminal -- retrying the same SQL
            # will not help, so no error is queued for regeneration.
            log(f"Database connection/timeout error: {str(e)}")
            error_result = (
                f"```sql\n{sql_query}\n```\nDatabase Connection Timeout: {str(e)}\nPlease check database connectivity."
            )
            return {"sql_execution_result": SQL_EXECUTE_TIMEOUT, "messages": [AIMessage(error_result)]}
        except Exception as e:
            error_type = "Unexpected error"
            if isinstance(e, ProgrammingError):
                error_type = "SQL syntax error"
            elif isinstance(e, DatabaseError):
                error_type = "Database error"
            log(f"{error_type}: {str(e)}")
            # Record the failed attempt so regenerate_sql_node can learn from it.
            previous_errors = list(state.get("previous_sql_errors", []))
            previous_errors.append({"sql": sql_query, "error": f"{error_type}: {str(e)}", "error_type": error_type})
            return {
                "sql_execution_result": SQL_UNKNOWN_ERROR if error_type == "Unexpected error" else SQL_SYNTAX_ERROR,
                "previous_sql_errors": previous_errors,
            }

    def regenerate_sql_node(state: SQLGraphState) -> dict:
        """Third node: Regenerates SQL based on previous errors.

        Args:
            state (SQLGraphState): The current SQL graph state containing error information.

        Returns:
            dict: Updated state with regenerated SQL query.
        """
        question = state["rewrite_question"]
        tables = state["tables"]
        previous_errors = state.get("previous_sql_errors", [])
        retry_count = state.get("sql_retry_count", 0) + 1
        system_prompt = (
            get_text2sql_dialect_prompt_template(dialect)
            .replace("[table_schema]", _get_table_schema_prompt(tables))
            .replace("[examples]", _get_relevant_sql_examples_prompt(question, tables))
            .replace("[time_field_placeholder]", datetime.datetime.now().strftime(datetime_format))
        )
        user_prompt = f"""Generate a SQL query for the question: {question}"""
        if previous_errors:
            # Feed every failed attempt back to the model for self-correction.
            user_prompt += "\n\nPrevious attempts failed with errors:"
            for i, error_info in enumerate(previous_errors, 1):
                user_prompt += f"\n\nAttempt {i}:\nSQL: {error_info['sql']}\nError: {error_info['error']}"
            user_prompt += "\n\nPlease analyze the errors above and generate a corrected SQL query."
        messages = [SystemMessage(system_prompt)] + list(state["messages"]) + [HumanMessage(user_prompt)]
        response = llm.invoke(messages)
        response_content = get_text_from_content(response.content)
        sql_query = response_content.replace("```sql", "").replace("```", "").strip()
        if not sql_query:
            log(f"Generated SQL query is empty. LLM output: {response.content}")
            error_result = f"Failed to regenerate valid SQL after {retry_count} attempts."
            return {
                "messages": [AIMessage(error_result)],
                "sql": "",
                "sql_retry_count": retry_count,
                "sql_execution_result": SQL_NA,
            }
        return {"sql": sql_query, "sql_retry_count": retry_count, "sql_execution_result": ""}

    def generate_visualization_node(state: SQLGraphState) -> dict:
        """Fourth node: Generates visualization DSL based on successful SQL execution result.

        Args:
            state (SQLGraphState): The current SQL graph state containing query data and results.

        Returns:
            dict: Updated state with visualization DSL.
        """
        execution_result = state.get("sql_execution_result", "")
        if execution_result != SQL_SUCCESS:
            # No visualization for failed queries
            return {"visualization_dsl": {}}
        question = state.get("rewrite_question", "")
        schema_info = state.get("schema_info", {})
        data = state.get("data", "")
        if not question or not schema_info or not data or not visualization_mode:
            return {"visualization_dsl": {}}
        try:
            # Generate visualization DSL using configured service
            viz_dsl = visualization_service.generate_visualization(question, schema_info, data)
            # Handle case where visualization is skipped
            if viz_dsl is None:
                return {"visualization_dsl": {}}
            # Append a short note about the chosen chart to the last AI message.
            messages = list(state.get("messages", []))
            if messages and hasattr(messages[-1], "content"):
                current_content = messages[-1].content
                viz_info = f"\n\n**Visualization Generated**: {viz_dsl.chart_type.title()} chart with {len(viz_dsl.data_columns)} column(s)"
                messages[-1] = AIMessage(current_content + viz_info)
            return {"visualization_dsl": viz_dsl.to_dict(), "messages": messages}
        except Exception as e:
            log(f"Visualization generation error: {str(e)}")
            return {"visualization_dsl": {"error": f"Failed to generate visualization: {str(e)}"}}

    return generate_sql_node, execute_sql_node, regenerate_sql_node, generate_visualization_node
def should_retry_sql(state: SQLGraphState) -> str:
    """Conditional edge function to determine if SQL should be retried.

    Args:
        state (SQLGraphState): Current state

    Returns:
        str: Next node name - "regenerate_sql" if retry needed, "end" if done
    """
    execution_result = state.get("sql_execution_result", "")
    retry_count = state.get("sql_retry_count", 0)
    max_retries = 3
    # Success and timeout are both terminal; timeouts are not retried.
    if execution_result in (SQL_SUCCESS, SQL_EXECUTE_TIMEOUT):
        return "end"
    if retry_count < max_retries:
        return "regenerate_sql"
    # Fix: the original re-checked `retry_count >= max_retries` here, which is
    # always true once the branch above fails -- the redundant guard is removed.
    # NOTE(review): mutating `state` inside a conditional edge only changes this
    # in-memory dict; confirm LangGraph actually persists these updates.
    previous_errors = state.get("previous_sql_errors", [])
    if previous_errors:
        last_error = previous_errors[-1]
        error_result = f"```sql\n{last_error['sql']}\n```\n{last_error['error']}\nFailed to generate valid SQL after {max_retries} attempts."
    else:
        error_result = f"Failed to generate valid SQL after {max_retries} attempts."
    # Update state with final error message
    state["messages"] = [AIMessage(error_result)]
    state["sql_execution_result"] = SQL_NA
    return "end"
def should_execute_sql(state: SQLGraphState) -> str:
    """Conditional edge function deciding whether the generated SQL should run.

    Args:
        state (SQLGraphState): Current state

    Returns:
        str: Next node name - "execute_sql" when a SQL query exists, "end" otherwise
    """
    # An empty/missing "sql" entry means generation failed or was skipped.
    return "execute_sql" if state.get("sql", "") else "end"
================================================
FILE: openchatbi/text2sql/schema_linking.py
================================================
"""Schema linking module for table and column selection in text2sql."""
from datetime import datetime
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from openchatbi.catalog import CatalogStore
from openchatbi.catalog.schema_retrival import col_dict, column_tables_mapping, get_relevant_columns
from openchatbi.constants import datetime_format
from openchatbi.graph_state import SQLGraphState
from openchatbi.prompts.system_prompt import get_table_selection_prompt_template
from openchatbi.text2sql.data import table_selection_example_dict, table_selection_retriever
from openchatbi.utils import extract_json_from_answer, log
def schema_linking(llm: BaseChatModel, catalog: CatalogStore):
    """Create function for schema linking: select appropriate tables and columns for a question.

    Args:
        llm (BaseChatModel): Language model for table selection.
        catalog (CatalogStore): Catalog store with schema information.

    Returns:
        function: Node function for schema linking based on question.
    """

    def _get_related_tables_and_columns(keywords_list, dimensions, metrics, start_time=None, invalid_table=None):
        """Retrieves tables and columns related to the given keywords, dimensions, and metrics.

        Args:
            keywords_list (list): List of keywords extracted from the question.
            dimensions (list): List of dimensions mentioned in the question.
            metrics (list): List of metrics mentioned in the question.
            start_time (str, optional): Start time for filtering tables.
            invalid_table (list, optional): List of tables to exclude.

        Returns:
            dict: Dictionary mapping table names to their information and related columns.
        """
        # 1. Get the top similar columns
        relevant_columns = get_relevant_columns(keywords_list, dimensions, metrics)
        # 2. Get all the related tables
        candidate_tables = set()
        for column in relevant_columns:
            candidate_tables.update(column_tables_mapping.get(column, []))
        if start_time:
            try:
                start_time = datetime.strptime(start_time, datetime_format)
            except ValueError:
                # Unparseable start time: drop the time filter entirely.
                start_time = None
        # 3. Get all the table's related columns
        related_table_column_dict = {}
        for table_name in candidate_tables:
            # Fix: guard against invalid_table=None (the default) -- the original
            # evaluated `table_name in invalid_table`, raising TypeError for None.
            if invalid_table and table_name in invalid_table:
                continue
            table_info = catalog.get_table_information(table_name)
            if not table_info:
                continue
            # Skip tables whose data only starts after the requested time range.
            if start_time and "start_time" in table_info:
                if datetime.strptime(table_info.get("start_time"), datetime_format) > start_time:
                    continue
            columns = []
            for column_name in relevant_columns:
                # Membership check first so the dict copy is only made when needed.
                if table_name not in column_tables_mapping.get(column_name, []):
                    continue
                columns.append(col_dict[column_name].copy())
            related_table_column_dict[table_name] = (table_info, columns)
        return related_table_column_dict

    def _example_retrieval(query, candidate_tables):
        """Retrieves example questions and their selected tables that match the candidate tables.

        Args:
            query (str): The natural language question.
            candidate_tables (list): List of candidate table names.

        Returns:
            dict: Dictionary mapping example questions to their selected tables.
        """
        similar_questions = table_selection_retriever.invoke(query)
        valid_examples = {}
        for question_doc in similar_questions:
            question = question_doc.page_content
            if not question:
                continue
            expected_tables = table_selection_example_dict[question]
            # Keep only the example tables that are actually candidates here.
            expected_tables = [table for table in expected_tables if table in candidate_tables]
            if expected_tables:
                valid_examples[question] = expected_tables
        return valid_examples

    def _build_table_selection_prompt(related_table_column_dict, similar_examples):
        """Builds a prompt for table selection based on related tables and examples.

        Args:
            related_table_column_dict (dict): Dictionary of tables with their information and columns.
            similar_examples (dict): Dictionary of example questions and their selected tables.

        Returns:
            str: Formatted prompt for table selection.
        """
        similar_examples = [
            f"- Question: {example} Selected Tables: [{','.join(selected_tables)}]"
            for example, selected_tables in similar_examples.items()
        ]
        table_column_descs = []
        for table_name, (table_info, columns) in related_table_column_dict.items():
            columns_desc = "\n".join(
                [
                    f"- {column['category']}({column['column_name']}, {column['display_name']}, \"{column['description']}\")"
                    for column in columns
                ]
            )
            desc_part = f"\n### Table Description: \n{table_info['description']}"
            rule_part = f"\n### Rule: \n{table_info.get('selection_rule')}" if table_info.get("selection_rule") else ""
            table_desc = (
                f"\n## Table: {table_name} {desc_part} {rule_part}"
                "\n### Columns: \nCategory(Name, Display Name, Description): "
                f"\n{columns_desc}"
                ""
            )
            table_column_descs.append(table_desc)
        # Build the LLM prompt
        prompt = (
            get_table_selection_prompt_template()
            .replace("[tables]", "\n\n".join(table_column_descs))
            .replace("[examples]", "\n".join(similar_examples))
        )
        return prompt

    def _verify_table(selected_tables, candidate_tables):
        """Verifies that selected tables are valid candidates.

        Args:
            selected_tables (list): List of tables selected by the model.
            candidate_tables (list): List of candidate tables.

        Returns:
            bool: True if all selected tables are valid candidates.
        """
        if not selected_tables:
            return False
        for table in selected_tables:
            if table.get("table") not in candidate_tables:
                return False
        return True

    def _call_llm_select(llm: BaseChatModel, system_prompt, messages, question, candidate_tables):
        """Calls the language model to select appropriate tables for the question.

        Retries up to 3 times if the LLM's answer is invalid.

        Args:
            llm (BaseChatModel): The language model to use.
            system_prompt (str): The system prompt for table selection.
            messages (list): List of previous messages.
            question (str): The natural language question.
            candidate_tables (list): List of candidate tables.

        Returns:
            dict: Dictionary containing selected tables, or {} on failure.
        """
        log("Selecting appropriate tables...")
        prompt = f"""Please select the appropriate tables for the question: {question}"""
        messages.append(HumanMessage(prompt))
        retry_flag = True
        retry_cnt = 1
        while retry_flag:
            try:
                log("Ask LLM to select the table...")
                response = llm.invoke([SystemMessage(system_prompt)] + messages)
                result = extract_json_from_answer(response.content)
                selected_tables = result.get("tables")
                log(result)
                if _verify_table(selected_tables, candidate_tables):
                    return {"tables": selected_tables}
                else:
                    # Tell the model which selection was rejected so the retry improves.
                    messages.append(
                        HumanMessage(
                            f'The selected table {",".join([table.get("table") for table in result.get("tables")])} is not valid. '
                            f"Do not select this table, please try again."
                        )
                    )
                    retry_cnt += 1
                    if retry_cnt > 3:
                        retry_flag = False
                    if retry_flag:
                        log(
                            f"The selected table {','.join([table.get('table') for table in result.get('tables')])} is not in the candidate tables."
                        )
                        log("Retry Table Selection...")
            except Exception as e:
                # Covers malformed JSON and a None "tables" entry; retry up to 3 times.
                log(str(e))
                retry_cnt += 1
                if retry_cnt > 3:
                    retry_flag = False
        return {}

    def _select(state: SQLGraphState) -> dict:
        """Node entry point: link the rewritten question to catalog tables."""
        if not state.get("rewrite_question"):
            log("Missing rewrite question, skipping schema linking.")
            return {}
        messages = state["messages"]
        question = state["rewrite_question"]
        info_entities = state["info_entities"]
        keywords_list = info_entities.get("keywords", [])
        dimensions = info_entities.get("dimensions", [])
        metrics = info_entities.get("metrics", [])
        start_time = info_entities.get("start_time")
        invalid_table = []
        log("Retrieving related table schema...")
        # 1. Get related tables and columns
        related_table_column_dict = _get_related_tables_and_columns(
            keywords_list, dimensions, metrics, start_time, invalid_table
        )
        candidate_tables = related_table_column_dict.keys()
        # 2. Get the similar examples
        similar_examples = _example_retrieval(" ".join(keywords_list), related_table_column_dict.keys())
        # 3. Build tables prompt
        system_prompt = _build_table_selection_prompt(related_table_column_dict, similar_examples)
        # 4. Call LLM to select the table
        return _call_llm_select(llm, system_prompt, messages, question, candidate_tables)

    return _select
================================================
FILE: openchatbi/text2sql/sql_graph.py
================================================
"""SQL generation graph construction and execution."""
from langchain_openai.chat_models.base import BaseChatOpenAI
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer, interrupt

from openchatbi import config
from openchatbi.catalog import CatalogStore
from openchatbi.constants import SQL_EXECUTE_TIMEOUT, SQL_SUCCESS
from openchatbi.graph_state import InputState, SQLGraphState, SQLOutputState
from openchatbi.llm.llm import get_llm, get_text2sql_llm
from openchatbi.text2sql.extraction import information_extraction, information_extraction_conditional_edges
from openchatbi.text2sql.generate_sql import create_sql_nodes, should_execute_sql
from openchatbi.text2sql.schema_linking import schema_linking
from openchatbi.tool.ask_human import AskHuman
from openchatbi.tool.search_knowledge import search_knowledge
def ask_human(state):
    """Node function to ask human for additional information or clarification.

    Args:
        state (SQLGraphState): The current SQL graph state containing messages and context.

    Returns:
        dict: Updated state with human feedback as a tool message and user input.
    """
    tool_call = state["messages"][-1].tool_calls[0]
    args = tool_call["args"]
    # Pause the graph and wait for the user's answer (optionally with buttons).
    user_feedback = interrupt({"text": args["question"], "buttons": args.get("options", None)})
    return {
        "messages": [{"tool_call_id": tool_call["id"], "type": "tool", "content": user_feedback}],
        "user_input": user_feedback,
    }
def should_generate_visualization_or_retry(state: SQLGraphState) -> str:
    """Conditional edge function to determine next action after execute_sql.

    Args:
        state (SQLGraphState): Current state

    Returns:
        str: Next node name - "generate_visualization" if SQL succeeded,
            "regenerate_sql" if retry needed, "end" if done
    """
    execution_result = state.get("sql_execution_result", "")
    retry_count = state.get("sql_retry_count", 0)
    max_retries = 3
    if execution_result == SQL_SUCCESS:
        return "generate_visualization"
    # Fix: compare against the SQL_EXECUTE_TIMEOUT constant; the original used
    # the literal string "SQL_EXECUTE_TIMEOUT", inconsistent with
    # should_retry_sql in generate_sql.py, so timeouts could be retried.
    if retry_count < max_retries and execution_result != SQL_EXECUTE_TIMEOUT:
        return "regenerate_sql"
    return "end"
def build_sql_graph(
    catalog: CatalogStore, checkpointer: Checkpointer, memory_store: BaseStore, llm_provider: str | None = None
) -> CompiledStateGraph:
    """Build SQL generation graph with all nodes and edges.

    Args:
        catalog: Catalog store containing schema information.
        checkpointer: The Checkpointer for state persistence (short memory). If None, no short memory.
        memory_store: The BaseStore to use for long-term memory. If None, no long-term memory.
        llm_provider: Optional provider name passed to get_llm/get_text2sql_llm;
            None selects the default provider.

    Returns:
        CompiledStateGraph: Compiled SQL graph ready for execution.
    """
    tools = [search_knowledge, AskHuman]
    search_tool_node = ToolNode([search_knowledge])
    default_llm = get_llm(llm_provider)
    # OpenAI-compatible chat models additionally get strict tool schemas and a
    # JSON response format bound; other providers get plain tool binding.
    if isinstance(default_llm, BaseChatOpenAI):
        llm_with_tools = default_llm.bind_tools(tools, strict=True).bind(response_format={"type": "json_object"})
    else:
        llm_with_tools = default_llm.bind_tools(tools)
    # Create SQL processing nodes with visualization configuration
    generate_sql_node, execute_sql_node, regenerate_sql_node, generate_visualization_node = create_sql_nodes(
        get_text2sql_llm(llm_provider),
        catalog,
        dialect=config.get().dialect,
        visualization_mode=config.get().visualization_mode,
    )
    # Define the SQL generation graph
    graph = StateGraph(SQLGraphState, input_schema=InputState, output_schema=SQLOutputState)
    # Add nodes to the graph
    graph.add_node("search_knowledge", search_tool_node)
    graph.add_node("ask_human", ask_human)
    graph.add_node("information_extraction", information_extraction(llm_with_tools))
    graph.add_node("table_selection", schema_linking(default_llm, catalog))
    graph.add_node("generate_sql", generate_sql_node)
    graph.add_node("execute_sql", execute_sql_node)
    graph.add_node("regenerate_sql", regenerate_sql_node)
    graph.add_node("generate_visualization", generate_visualization_node)
    # Add basic edges: tool/human nodes loop back into extraction.
    graph.add_edge(START, "information_extraction")
    graph.add_edge("ask_human", "information_extraction")
    graph.add_edge("search_knowledge", "information_extraction")
    graph.add_edge("table_selection", "generate_sql")
    # Add conditional routing from information extraction
    graph.add_conditional_edges(
        "information_extraction",
        information_extraction_conditional_edges,
        # mapping of paths to node names
        {
            "ask_human": "ask_human",
            "search_knowledge": "search_knowledge",
            "next": "table_selection",
            "end": END,
        },
    )
    # Add conditional edges for generate_sql
    graph.add_conditional_edges(
        "generate_sql",
        should_execute_sql,
        {
            "execute_sql": "execute_sql",
            "end": END,
        },
    )
    # Add conditional edges for regenerate_sql
    graph.add_conditional_edges(
        "regenerate_sql",
        should_execute_sql,
        {
            "execute_sql": "execute_sql",
            "end": END,
        },
    )
    # Add conditional edges for execute_sql - either retry, generate visualization, or end
    graph.add_conditional_edges(
        "execute_sql",
        should_generate_visualization_or_retry,
        {
            "generate_visualization": "generate_visualization",
            "regenerate_sql": "regenerate_sql",
            "end": END,
        },
    )
    # Add edge from visualization to end
    graph.add_edge("generate_visualization", END)
    graph = graph.compile(name="text2sql_graph", checkpointer=checkpointer, store=memory_store)
    return graph
================================================
FILE: openchatbi/text2sql/text2sql_utils.py
================================================
"""Utility functions for text2sql retrieval systems."""
from openchatbi.llm.llm import get_embedding_model
from openchatbi.utils import create_vector_db
def init_sql_example_retriever(catalog, vector_db_path: str | None = None):
    """Initialize SQL example retriever from catalog.

    Builds an MMR retriever over the natural-language questions of the
    catalog's SQL examples so that similar questions can be looked up later.

    Args:
        catalog: Catalog store containing SQL examples.
        vector_db_path: Path to the vector database file; forwarded to
            create_vector_db as chroma_db_path.

    Returns:
        tuple: (retriever, sql_example_dict) where sql_example_dict maps each
            example question to its (sql, table) pair.
    """
    sql_examples = catalog.get_sql_examples()
    # Index by question text; a later duplicate of the same question wins.
    sql_example_dict = {q: (sql, table) for q, sql, table in sql_examples}
    texts = list(sql_example_dict.keys())
    if not texts:
        texts = [""]  # Empty text as fallback so the vector store is never empty
    vector_db = create_vector_db(
        texts,
        get_embedding_model(),
        collection_name="text2sql",
        collection_metadata={"hnsw:space": "cosine"},
        chroma_db_path=vector_db_path,
    )
    # MMR keeps the top-10 hits diverse out of the 30 nearest candidates.
    retriever = vector_db.as_retriever(
        search_type="mmr", search_kwargs={"distance_metric": "cosine", "fetch_k": 30, "k": 10}
    )
    return retriever, sql_example_dict
def init_table_selection_example_dict(catalog, vector_db_path: str | None = None):
    """Initialize table selection example retriever from catalog.

    Args:
        catalog: Catalog store containing table selection examples.
        vector_db_path: Path to the vector database file; forwarded to
            create_vector_db as chroma_db_path.

    Returns:
        tuple: (retriever, table_selection_example_dict) where the dict maps
            each example question to its expected table list.
    """
    examples = catalog.get_table_selection_examples()
    # Index by question text; a later duplicate of the same question wins.
    table_selection_example_dict = {q: tables for q, tables in examples}
    texts = list(table_selection_example_dict.keys())
    if not texts:
        texts = [""]  # Empty text as fallback so the vector store is never empty
    vector_db = create_vector_db(
        texts,
        get_embedding_model(),
        collection_name="table_selection_example",
        collection_metadata={"hnsw:space": "cosine"},
        chroma_db_path=vector_db_path,
    )
    # MMR keeps the top-10 hits diverse out of the 30 nearest candidates.
    retriever = vector_db.as_retriever(
        search_type="mmr", search_kwargs={"distance_metric": "cosine", "fetch_k": 30, "k": 10}
    )
    return retriever, table_selection_example_dict
================================================
FILE: openchatbi/text2sql/visualization.py
================================================
"""Visualization generation for SQL query results using Plotly."""
from dataclasses import dataclass
from enum import Enum
from io import StringIO
from typing import Any
import pandas as pd
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage
from openchatbi.prompts.system_prompt import get_visualization_prompt_template
class ChartType(Enum):
    """Supported chart types for data visualization.

    Enum values are the lowercase strings used in the LLM response mapping
    (VisualizationService.CHART_TYPE_MAPPING) and in the emitted DSL.
    """

    LINE = "line"
    BAR = "bar"
    PIE = "pie"
    SCATTER = "scatter"
    HISTOGRAM = "histogram"
    BOX = "box"
    HEATMAP = "heatmap"
    TABLE = "table"
@dataclass
class VisualizationConfig:
    """Configuration for generating visualization DSL."""

    # Target chart type.
    chart_type: ChartType
    # Column bindings; None means the field is not used for this chart.
    x_column: str | None = None
    y_column: str | None = None
    color_column: str | None = None
    size_column: str | None = None
    # Optional display titles.
    title: str | None = None
    x_title: str | None = None
    y_title: str | None = None
    show_legend: bool = True
    # Figure dimensions in pixels; None lets the renderer decide.
    width: int | None = None
    height: int | None = None
@dataclass
class VisualizationDSL:
    """Plotly-friendly DSL describing one chart over tabular data."""

    chart_type: str
    data_columns: list[str]
    config: dict[str, Any]
    layout: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for JSON serialization."""
        # Emit the four DSL fields under their attribute names.
        return {name: getattr(self, name) for name in ("chart_type", "data_columns", "config", "layout")}
class VisualizationService:
    """Service class to handle visualization generation with configurable analysis method.

    The chart type is picked either by an LLM (when one is supplied) or by
    keyword/data-shape heuristics; the result is rendered as a
    Plotly-friendly VisualizationDSL.
    """

    # Chart type mapping for LLM responses
    CHART_TYPE_MAPPING = {
        "line": ChartType.LINE,
        "bar": ChartType.BAR,
        "pie": ChartType.PIE,
        "scatter": ChartType.SCATTER,
        "histogram": ChartType.HISTOGRAM,
        "box": ChartType.BOX,
        "heatmap": ChartType.HEATMAP,
        "table": ChartType.TABLE,
    }

    def __init__(self, llm: BaseChatModel | None = None):
        """Initialize visualization service.

        Args:
            llm: BaseChatModel LLM instance, will skip using LLM if None
        """
        self.llm = llm

    def _get_chart_type_by_rule(self, question: str, schema_info: dict[str, Any]) -> ChartType:
        """Recommend chart type based on user question and data schema using rules.

        Question keywords take precedence; otherwise the decision falls back
        to the shape of the data (column types and row count).
        """
        question_lower = question.lower()
        # Get data characteristics
        numeric_cols = schema_info.get("numeric_columns", [])
        categorical_cols = schema_info.get("categorical_columns", [])
        datetime_cols = schema_info.get("datetime_columns", [])
        row_count = schema_info.get("row_count", 0)
        # Question-based heuristics
        if any(keyword in question_lower for keyword in ["trend", "over time", "timeline", "time series"]):
            return ChartType.LINE
        elif any(keyword in question_lower for keyword in ["distribution", "frequency", "histogram"]):
            return ChartType.HISTOGRAM
        elif any(keyword in question_lower for keyword in ["correlation", "relationship", "scatter"]):
            return ChartType.SCATTER
        elif any(keyword in question_lower for keyword in ["proportion", "percentage", "share", "pie"]):
            return ChartType.PIE
        elif any(keyword in question_lower for keyword in ["compare", "comparison", "vs", "versus", "bar"]):
            return ChartType.BAR
        elif any(keyword in question_lower for keyword in ["summary", "range", "quartile", "box"]):
            return ChartType.BOX
        # Data-based heuristics
        if len(datetime_cols) > 0 and len(numeric_cols) > 0:
            # Time column plus metrics: treat as a time series.
            return ChartType.LINE
        elif len(categorical_cols) == 1 and len(numeric_cols) == 1:
            unique_count = schema_info.get("unique_counts", {}).get(categorical_cols[0], 0)
            if unique_count <= 10:
                # Few categories: pie up to 6 slices, otherwise bar.
                return ChartType.PIE if unique_count <= 6 else ChartType.BAR
            else:
                return ChartType.BAR
        elif len(numeric_cols) == 2:
            return ChartType.SCATTER
        elif len(numeric_cols) == 1 and len(categorical_cols) == 0:
            return ChartType.HISTOGRAM
        elif row_count <= 20:  # small result sets read well as a plain table
            return ChartType.TABLE
        else:
            return ChartType.BAR

    def generate_visualization_dsl(
        self, question: str, schema_info: dict[str, Any], chart_type: ChartType | None = None
    ) -> VisualizationDSL:
        """Generate visualization DSL based on question and schema info.

        Args:
            question: User's question or intent.
            schema_info: Pre-analyzed schema info; expected to carry
                "columns", "numeric_columns", "categorical_columns" and
                "datetime_columns" (or an "error" entry on failure).
            chart_type: Optional explicit chart type; when None the
                rule-based recommendation is used.

        Returns:
            VisualizationDSL: Chart description with column bindings,
            Plotly config and layout.
        """
        if "error" in schema_info:
            # Return table view for error cases
            return VisualizationDSL(
                chart_type="table",
                data_columns=["error"],
                config={"error": schema_info["error"]},
                layout={"title": "Data Analysis Error"},
            )
        # Determine chart type
        if chart_type is None:
            chart_type = self._get_chart_type_by_rule(question, schema_info)
        columns = schema_info["columns"]
        numeric_cols = schema_info["numeric_columns"]
        categorical_cols = schema_info["categorical_columns"]
        datetime_cols = schema_info["datetime_columns"]
        # Generate DSL based on chart type
        if chart_type == ChartType.LINE:
            # Prefer a datetime x-axis, then categorical, then the first column.
            x_col = datetime_cols[0] if datetime_cols else (categorical_cols[0] if categorical_cols else columns[0])
            # For line charts, include all numeric columns for multiple metrics
            y_cols = numeric_cols if numeric_cols else [columns[-1]]
            data_columns = [x_col] + y_cols
            # Support multiple y-axis columns
            config = {"x": x_col, "mode": "lines+markers"}
            if len(y_cols) == 1:
                config["y"] = y_cols[0]
                title = f"Line Chart: {y_cols[0]} over {x_col}"
            else:
                config["y"] = y_cols  # Multiple metrics
                title = f"Line Chart: {', '.join(y_cols)} over {x_col}"
            return VisualizationDSL(
                chart_type="line",
                data_columns=data_columns,
                config=config,
                layout={"title": title, "xaxis_title": x_col, "yaxis_title": "Value"},
            )
        elif chart_type == ChartType.BAR:
            x_col = categorical_cols[0] if categorical_cols else columns[0]
            # For bar charts, include all numeric columns for multiple metrics
            y_cols = numeric_cols if numeric_cols else [columns[-1]]
            data_columns = [x_col] + y_cols
            config = {"x": x_col}
            if len(y_cols) == 1:
                config["y"] = y_cols[0]
                title = f"Bar Chart: {y_cols[0]} by {x_col}"
            else:
                config["y"] = y_cols  # Multiple metrics
                title = f"Bar Chart: {', '.join(y_cols)} by {x_col}"
            return VisualizationDSL(
                chart_type="bar",
                data_columns=data_columns,
                config=config,
                layout={"title": title, "xaxis_title": x_col, "yaxis_title": "Value"},
            )
        elif chart_type == ChartType.PIE:
            # One label column plus one value column.
            label_col = categorical_cols[0] if categorical_cols else columns[0]
            value_col = numeric_cols[0] if numeric_cols else columns[-1]
            return VisualizationDSL(
                chart_type="pie",
                data_columns=[label_col, value_col],
                config={"labels": label_col, "values": value_col},
                layout={"title": f"Pie Chart: {value_col} by {label_col}"},
            )
        elif chart_type == ChartType.SCATTER:
            # First two numeric columns; fall back to first/last columns.
            x_col = numeric_cols[0] if len(numeric_cols) > 0 else columns[0]
            y_col = numeric_cols[1] if len(numeric_cols) > 1 else columns[-1]
            return VisualizationDSL(
                chart_type="scatter",
                data_columns=[x_col, y_col],
                config={"x": x_col, "y": y_col, "mode": "markers"},
                layout={"title": f"Scatter Plot: {y_col} vs {x_col}", "xaxis_title": x_col, "yaxis_title": y_col},
            )
        elif chart_type == ChartType.HISTOGRAM:
            col = numeric_cols[0] if numeric_cols else columns[0]
            return VisualizationDSL(
                chart_type="histogram",
                data_columns=[col],
                config={"x": col, "nbins": 20},
                layout={"title": f"Histogram: Distribution of {col}", "xaxis_title": col, "yaxis_title": "Frequency"},
            )
        elif chart_type == ChartType.BOX:
            y_col = numeric_cols[0] if numeric_cols else columns[0]
            # Optional grouping by the first categorical column.
            x_col = categorical_cols[0] if categorical_cols else None
            config = {"y": y_col}
            if x_col:
                config["x"] = x_col
            return VisualizationDSL(
                chart_type="box",
                data_columns=[col for col in [x_col, y_col] if col],
                config=config,
                layout={
                    "title": f"Box Plot: {y_col}" + (f" by {x_col}" if x_col else ""),
                    "xaxis_title": x_col if x_col else "",
                    "yaxis_title": y_col,
                },
            )
        else:  # TABLE or fallback (includes HEATMAP, which has no dedicated branch here)
            return VisualizationDSL(
                chart_type="table", data_columns=columns, config={"columns": columns}, layout={"title": "Data Table"}
            )

    def _llm_recommend_chart_type(self, question: str, schema_info: dict[str, Any], data_sample: str) -> ChartType:
        """Use LLM to recommend chart type based on question and data analysis.

        Args:
            question: User's question or intent
            schema_info: Data schema information
            data_sample: Sample of the data

        Returns:
            ChartType: Recommended chart type
        """
        try:
            # Plain placeholder substitution on the prompt template.
            prompt = (
                get_visualization_prompt_template()
                .replace("[question]", question)
                .replace("[columns]", str(schema_info.get("columns", [])))
                .replace("[numeric_columns]", str(schema_info.get("numeric_columns", [])))
                .replace("[categorical_columns]", str(schema_info.get("categorical_columns", [])))
                .replace("[datetime_columns]", str(schema_info.get("datetime_columns", [])))
                .replace("[row_count]", str(schema_info.get("row_count", 0)))
                .replace("[data_sample]", data_sample)
            )
            # Call LLM with the formatted prompt
            response = self.llm.invoke([HumanMessage(content=prompt)])
            chart_type_str = response.content.strip().lower()
            # Unrecognized responses map to TABLE.
            return self.CHART_TYPE_MAPPING.get(chart_type_str, ChartType.TABLE)
        except Exception:
            # Fallback to rule-based recommendation on LLM errors
            return self._get_chart_type_by_rule(question, schema_info)

    def generate_visualization(
        self, question: str, schema_info: dict[str, Any], csv_data: str, chart_type: ChartType | None = None
    ) -> VisualizationDSL | None:
        """Generate visualization using the configured analysis method.

        Args:
            question: User's question or intent
            schema_info: Pre-analyzed schema information
            csv_data: CSV data string for LLM analysis if needed
            chart_type: Optional specific chart type to use

        Returns:
            VisualizationDSL or None: Generated visualization configuration, or None if skipped
        """
        # Use existing DSL generation if chart type is already specified
        if chart_type is not None:
            return self.generate_visualization_dsl(question, schema_info, chart_type)
        # Determine chart type based on configured method
        if self.llm:
            if "error" in schema_info:
                return VisualizationDSL(
                    chart_type="table",
                    data_columns=["error"],
                    config={"error": schema_info["error"]},
                    layout={"title": "Data Analysis Error"},
                )
            # Prepare data sample for LLM analysis
            try:
                df = pd.read_csv(StringIO(csv_data))
                data_sample = df.head(3).to_string() if len(df) > 0 else "No data available"
            except Exception:
                # Unparsable CSV still gets an LLM call with a placeholder sample.
                data_sample = "Unable to parse data"
            chart_type = self._llm_recommend_chart_type(question, schema_info, data_sample)
        # Generate DSL using determined or recommended chart type
        # (chart_type stays None without an LLM, so the rule fallback inside
        # generate_visualization_dsl applies).
        return self.generate_visualization_dsl(question, schema_info, chart_type)
================================================
FILE: openchatbi/text_segmenter.py
================================================
"""Text segmentation utility with jieba support."""
import re
import string
import sys
# Try to import jieba; fall back to simple segmentation if not available.
# Note: jieba is not compatible with Python 3.12+, so the import is only
# attempted on older interpreters; _jieba_available records the outcome.
_jieba_available = False
if sys.version_info < (3, 12):
    try:
        import jieba
        _jieba_available = True
    except ImportError:
        _jieba_available = False
class TextSegmenter:
    """Tokenizer that prefers jieba for Chinese input, with a simple fallback.

    jieba is only used when it could be imported at module load (it does not
    support Python 3.12+); in every other case the input is split on
    punctuation and whitespace.
    """

    def __init__(self, use_jieba: bool = True):
        """Create a segmenter.

        Args:
            use_jieba: Whether to use jieba for Chinese text segmentation.
                Defaults to True; silently treated as False when jieba is
                unavailable on this interpreter.
        """
        self.use_jieba = use_jieba and _jieba_available
        # Any run of English/Chinese punctuation or whitespace is a boundary.
        chinese_punctuation = ",。!?;:" "''()【】《》〈〉「」『』〔〕"
        boundary_chars = string.punctuation + chinese_punctuation + " \t\n\r"
        self.split_pattern = "[" + re.escape(boundary_chars) + "]+"

    @staticmethod
    def _contains_chinese(text: str) -> bool:
        """Return True if any character lies in the CJK Unified Ideographs block.

        Args:
            text: Input text to check

        Returns:
            True if text contains Chinese characters, False otherwise
        """
        for char in text:
            if "\u4e00" <= char <= "\u9fff":
                return True
        return False

    def _simple_cut(self, text: str) -> list[str]:
        """Split on punctuation/whitespace, dropping empty tokens.

        Args:
            text: Input text to be segmented

        Returns:
            List of tokens
        """
        if not text:
            return []
        return [piece for piece in re.split(self.split_pattern, text) if piece.strip()]

    def cut(self, text: str) -> list[str]:
        """Segment text into tokens.

        Chinese text goes through jieba when it is enabled and available;
        everything else is split on punctuation and whitespace.

        Args:
            text: Input text to be segmented

        Returns:
            List of tokens
        """
        if not text:
            return []
        if self.use_jieba and self._contains_chinese(text):
            # jieba provides proper Chinese word boundaries.
            return list(jieba.cut(text))
        return self._simple_cut(text)
class SimpleSegmenter:
    """Lightweight tokenizer that splits text on punctuation and whitespace.

    Provides basic segmentation without any external dependencies. Kept for
    backward compatibility; prefer TextSegmenter for better Chinese support.
    """

    def __init__(self):
        # Any run of English/Chinese punctuation or whitespace is a boundary.
        cjk_punctuation = ",。!?;:" "''()【】《》〈〉「」『』〔〕"
        boundary_chars = string.punctuation + cjk_punctuation + " \t\n\r"
        self.split_pattern = "[" + re.escape(boundary_chars) + "]+"

    def cut(self, text: str) -> list[str]:
        """Segment text into tokens by splitting on punctuation and whitespace.

        Args:
            text: Input text to be segmented

        Returns:
            List of tokens
        """
        if not text:
            return []
        # Drop the empty strings re.split produces at boundaries.
        return [piece for piece in re.split(self.split_pattern, text) if piece.strip()]
# Global instance - use TextSegmenter with jieba support
# (automatically falls back to simple splitting when jieba is unavailable).
_segmenter = TextSegmenter()
================================================
FILE: openchatbi/tool/ask_human.py
================================================
"""Tool for asking human clarification when information is ambiguous."""
from pydantic import BaseModel, Field
class AskHuman(BaseModel):
    """Ask user for clarification when data is missing or ambiguous.

    Use this tool ONLY when you are STRONGLY certain that information is
    ambiguous or missing. First try to solve the question with available
    user input before calling this tool.
    """

    # NOTE(review): this model presumably serves as a tool schema, so the
    # docstring and Field descriptions above/below are model-facing text --
    # keep their wording stable.
    question: str = Field(description="Question to ask the user for clarification")
    options: list[str] = Field(description="Options for user to choose (max 3). Empty if not a choice question.")
================================================
FILE: openchatbi/tool/mcp_tools.py
================================================
"""MCP (Model Context Protocol) tools integration for OpenChatBI.
This module provides integration with MCP servers using langchain-mcp-adapters,
allowing the agent to use external tools through the Model Context Protocol.
"""
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from langchain_core.tools import StructuredTool
from langchain_mcp_adapters.client import MultiServerMCPClient
from pydantic import BaseModel, Field
from openchatbi.constants import MCP_TOOL_DEFAULT_TIMEOUT_SECONDS
logger = logging.getLogger(__name__)
def make_tool_sync_compatible(tool: StructuredTool, timeout: int) -> StructuredTool:
    """Make an async-only StructuredTool compatible with sync invocation.

    This wraps the async coroutine with a sync function that runs it in an
    event loop. Uses asyncio.get_running_loop()/asyncio.run() instead of the
    deprecated asyncio.get_event_loop() (deprecated since Python 3.10 when no
    loop is running).

    Args:
        tool: The StructuredTool to make sync-compatible
        timeout: Timeout in seconds for tool execution when dispatching to a
            worker thread (only applies while a loop is already running)

    Returns:
        StructuredTool with sync compatibility
    """
    if tool.func is not None:
        # Tool already has sync support
        return tool
    if tool.coroutine is None:
        # Tool has no async function either, can't help
        return tool

    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
        """Synchronous wrapper for async tool function."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: run the coroutine directly.
            return asyncio.run(tool.coroutine(*args, **kwargs))  # type: ignore
        # We're in an async context, so run_until_complete would deadlock.
        # Execute the coroutine in a separate thread with its own loop.
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(asyncio.run, tool.coroutine(*args, **kwargs))  # type: ignore
            return future.result(timeout=timeout)

    # Create a new StructuredTool exposing both sync and async entry points.
    return StructuredTool(
        name=tool.name,
        description=tool.description,
        args_schema=tool.args_schema,
        func=sync_wrapper,
        coroutine=tool.coroutine,
    )
class MCPServerConfig(BaseModel):
    """Configuration for MCP server connection."""

    name: str = Field(description="Name of the MCP server")
    transport: str = Field(default="stdio", description="Transport type: stdio, sse, or streamable_http")
    # For stdio transport
    # (command may also carry extra argv entries; see create_mcp_tools_async)
    command: list[str] = Field(default_factory=list, description="Command to start the MCP server")
    args: list[str] = Field(default_factory=list, description="Arguments for the MCP server")
    env: dict[str, str] = Field(default_factory=dict, description="Environment variables")
    # For HTTP transports (sse, streamable_http)
    url: str = Field(default="", description="URL for HTTP-based transports")
    headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers")
    # Common settings
    enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
    timeout: int = Field(default=MCP_TOOL_DEFAULT_TIMEOUT_SECONDS, description="Connection timeout in seconds")
async def create_mcp_tools_async(server_configs: list[dict[str, Any]]) -> list[StructuredTool]:
    """Create MCP tools asynchronously from server configurations.

    This function processes MCP server configurations, establishes connections to enabled
    servers, retrieves available tools, and makes them sync-compatible with proper
    timeout configuration.

    Args:
        server_configs: List of MCP server configuration dictionaries containing
            server connection details, transport settings, and timeouts

    Returns:
        List of LangChain StructuredTool instances with mcp_ prefixes and sync compatibility
    """
    if not server_configs:
        return []
    # Filter enabled servers and convert to MCPServerConfig
    enabled_servers = {}
    max_timeout = MCP_TOOL_DEFAULT_TIMEOUT_SECONDS  # Default from constants
    for config_dict in server_configs:
        try:
            config = MCPServerConfig(**config_dict)
            if not config.enabled:
                continue
            server_name = config.name
            # Track the maximum timeout across all servers
            # (one shared timeout is applied to the whole client below).
            max_timeout = max(max_timeout, config.timeout)
            # Build server configuration for MultiServerMCPClient
            if config.transport == "stdio":
                if not config.command:
                    logger.warning(f"MCP server {server_name}: command required for stdio transport")
                    continue
                # First command entry is the executable; any extra entries are
                # prepended to the configured args.
                enabled_servers[server_name] = {
                    "transport": "stdio",
                    "command": config.command[0] if config.command else "",
                    "args": config.command[1:] + config.args if len(config.command) > 1 else config.args,
                    "env": config.env,
                }
            elif config.transport in ["sse", "streamable_http"]:
                if not config.url:
                    logger.warning(f"MCP server {server_name}: url required for {config.transport} transport")
                    continue
                server_config: dict[str, Any] = {
                    "transport": config.transport,
                    "url": config.url,
                }
                if config.headers:
                    server_config["headers"] = config.headers
                enabled_servers[server_name] = server_config
            else:
                logger.warning(f"MCP server {server_name}: unsupported transport {config.transport}")
                continue
        except Exception as e:
            # A single bad entry must not prevent the other servers loading.
            logger.error(f"Invalid MCP server configuration: {e}")
            continue
    if not enabled_servers:
        logger.info("No enabled MCP servers found")
        return []
    try:
        # Create MultiServerMCPClient and get tools with timeout
        client = MultiServerMCPClient(enabled_servers)
        tools = await asyncio.wait_for(client.get_tools(), timeout=max_timeout)
        logger.info(f"Successfully loaded {len(tools)} MCP tools from {len(enabled_servers)} servers")
        # Add server prefix to tool names and make sync-compatible
        prefixed_tools = []
        for tool in tools:
            # Get server name from tool metadata or guess from tool name
            original_name = tool.name
            if not original_name.startswith("mcp_"):
                tool.name = f"mcp_{original_name}"
            # Make tool sync-compatible with configured timeout
            sync_compatible_tool = make_tool_sync_compatible(tool, timeout=max_timeout)
            prefixed_tools.append(sync_compatible_tool)
        return prefixed_tools
    except Exception as e:
        # Best-effort: a failed client initialization disables MCP tools.
        logger.error(f"Failed to initialize MCP client: {e}")
        return []
def create_mcp_tools_sync(server_configs: list[dict[str, Any]]) -> list[StructuredTool]:
    """Create MCP tools from server configurations synchronously.

    Runs the async initialization inside a dedicated thread with its own
    event loop so it never clashes with an already-running async context.

    Args:
        server_configs: List of MCP server configuration dictionaries

    Returns:
        List of LangChain StructuredTool instances with sync compatibility
    """
    if not server_configs:
        return []

    def _init_in_fresh_loop() -> list[StructuredTool]:
        # Give this worker thread its own event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(create_mcp_tools_async(server_configs))
        except Exception as e:
            logger.error(f"Failed to create MCP tools in sync mode: {e}")
            return []
        finally:
            loop.close()

    try:
        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(_init_in_fresh_loop).result(timeout=MCP_TOOL_DEFAULT_TIMEOUT_SECONDS)
    except Exception as e:
        logger.error(f"MCP tools sync initialization failed: {e}")
        return []
# Global variable to store async-initialized tools
# (cache shared by get_mcp_tools_async and reset_mcp_tools_cache).
_async_mcp_tools = None
async def get_mcp_tools_async(server_configs: list[dict[str, Any]]) -> list[StructuredTool]:
    """Get MCP tools asynchronously, using cached version if available.

    Args:
        server_configs: List of MCP server configuration dictionaries

    Returns:
        List of cached or newly created LangChain StructuredTool instances
    """
    global _async_mcp_tools
    # Serve the cached tools when they were already created.
    if _async_mcp_tools is not None:
        return _async_mcp_tools
    _async_mcp_tools = await create_mcp_tools_async(server_configs)
    return _async_mcp_tools
def reset_mcp_tools_cache() -> None:
    """Reset the async MCP tools cache.

    The next call to get_mcp_tools_async() will re-create the tools.
    """
    global _async_mcp_tools
    _async_mcp_tools = None
================================================
FILE: openchatbi/tool/memory.py
================================================
import functools
import sys
from typing import Any
try:
    # Prefer pysqlite3 (ships a newer bundled SQLite) when installed.
    import pysqlite3 as sqlite3
except ImportError:  # pragma: no cover
    import sqlite3
# Make sure langgraph sqlite connector uses the same sqlite module.
sys.modules["sqlite3"] = sqlite3
from langchain.tools import StructuredTool
from langchain_core.language_models import BaseChatModel
from langchain_openai.chat_models.base import BaseChatOpenAI
from langgraph.store.sqlite import SqliteStore
from langgraph.store.sqlite.aio import AsyncSqliteStore
from langmem import (
create_manage_memory_tool,
create_memory_store_manager,
create_search_memory_tool,
)
from openchatbi import config
try:
from pydantic import BaseModel, ConfigDict
except ImportError:
ConfigDict = None
# Module-level singletons, created lazily by the getter functions below.
# Use AsyncSqliteStore for async operations; the context manager object must
# be kept alive so cleanup can call its __aexit__.
async_memory_store = None
async_store_context_manager = None
# Sync SqliteStore counterpart for non-async callers.
sync_memory_store = None
# langmem memory store manager singleton.
memory_manager = None
# Define profile structure
class UserProfile(BaseModel):
    """Represents the full representation of a user."""

    # All fields are optional; the memory manager fills in what it extracts.
    name: str | None = None
    language: str | None = None
    timezone: str | None = None
    # Domain-specific terms the user tends to use -- presumably free text;
    # verify against the extraction instructions.
    jargon: str | None = None
def get_sync_memory_store() -> SqliteStore | None:
    """Get or lazily create the process-wide synchronous memory store.

    Returns:
        The shared SqliteStore, or None when no embedding model is
        configured (memory features are effectively disabled then).
    """
    global sync_memory_store
    embedding_model = config.get().embedding_model
    if not embedding_model:
        # Without an embedding model the store cannot index memories.
        return None
    if sync_memory_store is None:
        # For backwards compatibility and sync operations
        # check_same_thread=False: the connection may be used across threads.
        conn = sqlite3.connect("memory.db", check_same_thread=False)
        # isolation_level=None puts sqlite3 in autocommit mode.
        conn.isolation_level = None
        sync_memory_store = SqliteStore(
            conn,
            index={
                "dims": 1536,  # assumes 1536-dim embeddings -- TODO confirm against the configured model
                "embed": embedding_model,
                "fields": ["text"],  # specify which fields to embed
            },
        )
        try:
            sync_memory_store.setup()
        except Exception:
            # Best-effort: setup is presumably idempotent-but-raising when the
            # tables already exist. NOTE(review): consider logging here.
            pass
    return sync_memory_store
async def get_async_memory_store() -> AsyncSqliteStore | None:
    """Get or create the async memory store.

    Returns:
        The shared AsyncSqliteStore, or None when no embedding model is
        configured.
    """
    global async_memory_store, async_store_context_manager
    embedding_model = config.get().embedding_model
    if not embedding_model:
        return None
    if async_memory_store is None:
        # AsyncSqliteStore.from_conn_string returns an async context manager
        async_store_context_manager = AsyncSqliteStore.from_conn_string(
            "memory.db",
            index={
                "dims": 1536,  # assumes 1536-dim embeddings -- TODO confirm against the configured model
                "embed": embedding_model,
                "fields": ["text"],  # specify which fields to embed
            },
        )
        # Entered manually; cleanup_async_memory_store() performs the __aexit__.
        async_memory_store = await async_store_context_manager.__aenter__()
    return async_memory_store
async def cleanup_async_memory_store() -> None:
    """Cleanup async memory store resources.

    Exits the async context manager entered in get_async_memory_store and
    resets both module-level singletons so a later call recreates the store.
    """
    global async_memory_store, async_store_context_manager
    if async_memory_store is not None and async_store_context_manager is not None:
        try:
            await async_store_context_manager.__aexit__(None, None, None)
        except Exception as e:
            # Best-effort shutdown; errors are reported but not raised.
            print(f"Error cleaning up async memory store: {e}")
        finally:
            async_memory_store = None
            async_store_context_manager = None
async def setup_async_memory_store() -> Any:
    """Setup async memory store for langmem.

    Triggers lazy creation via get_async_memory_store(); the created store is
    discarded here, so this coroutine effectively returns None despite the
    Any annotation.
    """
    await get_async_memory_store()
def fix_schema_for_openai(schema: dict) -> None:
    """Patch a JSON schema in place so OpenAI structured output accepts it.

    Marks every declared property as required, then walks the whole schema
    and forces any truthy ``additionalProperties`` on object nodes to False.
    Since Pydantic 2.11, arbitrary dictionary schemas always receive
    ``additionalProperties: True``, which OpenAI rejects. This workaround can
    be removed once the upstream patch lands:
    https://github.com/langchain-ai/langchain/pull/32879
    """
    schema["required"] = list(schema.get("properties", {}))
    # Iterative depth-first traversal over nested dicts and lists.
    pending = [schema]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            if node.get("type") == "object" and node.get("additionalProperties"):
                node["additionalProperties"] = False
            pending.extend(node.values())
        elif isinstance(node, list):
            pending.extend(node)
def get_memory_manager() -> Any:
    """Return the lazily created singleton langmem memory store manager."""
    global memory_manager
    if memory_manager is not None:
        return memory_manager
    # First use: build the manager around the default LLM and profile schema.
    memory_manager = create_memory_store_manager(
        config.get().default_llm,
        schemas=[UserProfile],
        instructions="Extract user profile information",
        enable_inserts=False,
    )
    return memory_manager
class StructuredToolWithRequired(StructuredTool):
    """StructuredTool wrapper whose tool-call schema is patched for OpenAI.

    Overrides tool_call_schema so that fix_schema_for_openai runs as the
    json_schema_extra hook, marking all properties required and disabling
    additionalProperties (see fix_schema_for_openai).
    """

    def __init__(self, orig_tool: StructuredTool):
        # Copy the identifying fields of the wrapped tool verbatim.
        name = getattr(orig_tool, "name", None)
        super().__init__(
            name=name,
            description=orig_tool.description,
            args_schema=orig_tool.args_schema,
            func=orig_tool.func,
            coroutine=orig_tool.coroutine,
        )

    @functools.cached_property
    def tool_call_schema(self) -> "ArgsSchema":
        # Compute the base schema once, then install the OpenAI fixup hook on
        # its pydantic model_config. cached_property keeps this a one-time cost.
        tcs = super().tool_call_schema
        try:
            if tcs.model_config:
                tcs.model_config["json_schema_extra"] = fix_schema_for_openai
            elif ConfigDict is not None:
                tcs.model_config = ConfigDict(json_schema_extra=fix_schema_for_openai)
        except Exception:
            # Best-effort: fall back to the unpatched schema on any failure.
            pass
        return tcs
def get_memory_tools(
    llm: BaseChatModel, sync_mode: bool = False, store: Any | None = None
) -> list[StructuredTool] | None:
    """Build the langmem manage/search memory tools bound to a store.

    Returns None when no store was supplied and none can be created for the
    requested mode. The tool namespaces carry a ``{user_id}`` placeholder that
    langmem fills in at call time.
    """
    # Resolve a store when the caller did not provide one; only sync mode has
    # an implicit process-wide store.
    if not store and sync_mode:
        store = get_sync_memory_store()
    if not store:
        return None
    # create langmem manage memory tool with {user_id} template
    namespace = ("memories", "{user_id}")
    tools: list[StructuredTool] = [
        create_manage_memory_tool(namespace=namespace, store=store),
        create_search_memory_tool(namespace=namespace, store=store),
    ]
    if isinstance(llm, BaseChatOpenAI):
        # OpenAI structured output needs the required/additionalProperties fix.
        tools = [StructuredToolWithRequired(tool) for tool in tools]
    return tools
async def get_async_memory_tools(llm: BaseChatModel) -> list[StructuredTool]:
    """Build memory tools backed by the shared async store."""
    return get_memory_tools(llm, sync_mode=False, store=await get_async_memory_store())
================================================
FILE: openchatbi/tool/run_python_code.py
================================================
"""Tool for running python code."""
from langchain.tools import tool
from pydantic import BaseModel, Field
from openchatbi.code.docker_executor import DockerExecutor, check_docker_status
from openchatbi.code.local_executor import LocalExecutor
from openchatbi.code.restricted_local_executor import RestrictedLocalExecutor
from openchatbi.config_loader import ConfigLoader
from openchatbi.utils import log
class PythonCodeInput(BaseModel):
    # Argument schema for the run_python_code tool; the Field descriptions
    # are surfaced to the LLM, so treat them as part of the tool contract.
    reasoning: str = Field(description="Reason for using this run python code tool")
    code: str = Field(description="The python code to execute")
def _create_executor():
    """Create the Python code executor selected by configuration.

    Falls back to LocalExecutor when the configuration is not loaded, when
    Docker is unavailable, or when the configured type is unknown.
    """
    config_loader = ConfigLoader()
    try:
        executor_type = config_loader.get().python_executor.lower()
    except ValueError:
        # Configuration not loaded, use default local executor
        log("Configuration not loaded, using default LocalExecutor")
        return LocalExecutor()
    log(f"Creating executor of type: {executor_type}")
    if executor_type == "restricted_local":
        log("Creating RestrictedLocalExecutor")
        return RestrictedLocalExecutor()
    if executor_type == "local":
        log("Creating LocalExecutor")
        return LocalExecutor()
    if executor_type == "docker":
        # Verify the Docker daemon is reachable before committing to it.
        is_available, status_message = check_docker_status()
        if is_available:
            log("Docker is available, creating DockerExecutor")
            return DockerExecutor()
        log(f"Docker is not available ({status_message}), falling back to LocalExecutor")
        return LocalExecutor()
    log(f"Unknown executor type '{executor_type}', using LocalExecutor as fallback")
    return LocalExecutor()
@tool("run_python_code", args_schema=PythonCodeInput, return_direct=False, infer_schema=True)
def run_python_code(reasoning: str, code: str) -> str:
"""Run python code string. Note: Only print outputs are visible, function return values will be ignored. Use print statements to see results.
Returns:
str: The print outputs of the python code
"""
log(f"Run Python Code, Reasoning: {reasoning}")
try:
executor = _create_executor()
log(f"Using {executor.__class__.__name__} for code execution")
success, output = executor.run_code(code)
if success:
return output
else:
return f"Error: {output}"
except Exception as e:
log(f"Failed to create executor: {e}")
# Fallback to LocalExecutor if configuration fails
log("Falling back to LocalExecutor")
executor = LocalExecutor()
success, output = executor.run_code(code)
if success:
return output
else:
return f"Error: {output}"
================================================
FILE: openchatbi/tool/save_report.py
================================================
"""Tool for saving reports to files."""
import datetime
from pathlib import Path
from langchain.tools import tool
from pydantic import BaseModel, Field
from openchatbi import config
from openchatbi.utils import log
class SaveReportInput(BaseModel):
    """Input schema for the save_report tool."""

    # Raw report body; written verbatim to the output file.
    content: str = Field(description="The content of the report to save")
    # Sanitized and embedded in the generated filename.
    title: str = Field(description="The title of the report (will be used in filename)")
    # Must be one of the formats validated inside save_report.
    file_format: str = Field(
        description="The file format/extension, only support 'md', 'csv', 'txt', 'json', 'html', 'xml'"
    )
@tool("save_report", args_schema=SaveReportInput, return_direct=False, infer_schema=True)
def save_report(content: str, title: str, file_format: str = "md") -> str:
    """Save a report to a file with timestamp and title in filename.

    Args:
        content: The content of the report to save
        title: The title of the report (will be used in filename)
        file_format: The file format/extension, only support 'md', 'csv', 'txt', 'json', 'html', 'xml'

    Returns:
        str: Success message with download link or error message

    Raises:
        ValueError: If file_format is not one of the supported extensions.
    """
    allowed_formats = {"md", "csv", "txt", "json", "html", "xml"}
    if file_format not in allowed_formats:
        raise ValueError(f"Unsupported file format: {file_format}")
    try:
        # Get report directory from config
        report_dir = config.get().report_directory
        # Create directory if it doesn't exist
        Path(report_dir).mkdir(parents=True, exist_ok=True)
        # Generate timestamp for filename
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # Clean title for filename (keep alphanumerics, spaces, hyphens only)
        clean_title = "".join(c for c in title if c.isalnum() or c in (" ", "-")).rstrip()
        clean_title = clean_title.replace(" ", "_")
        # Create filename
        filename = f"{timestamp}_{clean_title}.{file_format}"
        file_path = Path(report_dir) / filename
        # Write content to file
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(content)
        log(f"Report saved: {file_path}")
        # BUG FIX: interpolate the actual filename into the download link;
        # the previous literal produced a dead URL. Matches the route served
        # by get_report_download_response(filename) in openchatbi.utils.
        download_url = f"/api/download/report/{filename}"
        return f"Report saved successfully! Download link: {download_url}"
    except Exception as e:
        error_msg = f"Failed to save report: {str(e)}"
        log(error_msg)
        return error_msg
================================================
FILE: openchatbi/tool/search_knowledge.py
================================================
"""Tools for searching knowledge bases and schema information."""
from langchain.tools import tool
from pydantic import BaseModel, Field
from openchatbi import config
from openchatbi.catalog.schema_retrival import col_dict, column_tables_mapping, get_relevant_columns
from openchatbi.utils import log
class SearchInput(BaseModel):
    """Input schema for knowledge search tool."""

    # Explanation the LLM must supply for why it is searching.
    reasoning: str = Field(description="Reason for using this search tool")
    # Each entry is tokenized on spaces downstream (search_column_from_catalog).
    query_list: list[str] = Field(description="Query terms to search (max 5, avoid duplicates)")
    knowledge_bases: list[str] = Field(
        description="""Knowledge bases to search, options are:
- `"columns"`: The description, alias of columns, including dimensions and metrics.
- `"business"`: The business knowledge."""
    )
    # When True, each rendered column also lists the tables that contain it.
    with_table_list: bool = Field(
        description="Include table list for columns (only set to True when user asks about table-column relationships)"
    )
@tool("search_knowledge", args_schema=SearchInput, return_direct=False, infer_schema=True)
def search_knowledge(
    reasoning: str, query_list: list[str], knowledge_bases: list[str], with_table_list: bool = False
) -> dict[str, str]:
    """Search relevant knowledge from knowledge bases.

    Returns:
        Dict[str, str]: Search results for each knowledge base.
    """
    log(f"Search knowledge, query_list={query_list}, knowledge_bases={knowledge_bases}, reasoning={reasoning}")
    results: dict[str, str] = {}
    # NOTE(review): only the "columns" knowledge base is handled here; a
    # "business" request currently yields no entry — confirm that is intended.
    if "columns" in knowledge_bases:
        matched_columns = search_column_from_catalog(query_list, with_table_list)
        results["columns"] = f"# Relevant Columns and Description:\n{matched_columns}"
    return results
class ShowSchemaInput(BaseModel):
    """Input schema for show schema tool."""

    # Explanation the LLM must supply for why schemas are needed.
    reasoning: str = Field(description="Reason for showing schema")
    # Fully-qualified table names, looked up one-by-one in the catalog store.
    tables: list[str] = Field(description="Full table names to show (max 5)")
@tool("show_schema", args_schema=ShowSchemaInput, return_direct=False, infer_schema=True)
def show_schema(reasoning: str, tables: list[str]) -> list[str]:
    """Show table schemas including description, columns, and derived metrics.

    Returns:
        list[str]: Schema information for each table.
    """
    log(f"Show schema, tables={tables}, reasoning={reasoning}")
    # Delegate entirely to the catalog lookup; unknown tables are skipped there.
    return list_table_from_catalog(tables)
def search_column_from_catalog(query_list: list[str], with_table_list: bool) -> str:
    """Search columns from catalog based on query list.

    Args:
        query_list: Query strings; each is tokenized on single spaces.
        with_table_list: Whether to include related tables per column.

    Returns:
        str: Newline-joined rendered column descriptions.
    """
    relevant_column_set = set()
    for keywords in query_list:
        # Hoisted: tokenize once instead of splitting the same string three times.
        tokens = keywords.split(" ")
        relevant_column_set.update(get_relevant_columns(tokens, tokens, tokens))
    column_results = render_column_result(relevant_column_set, with_table_list)
    return "\n".join(column_results)
def list_table_from_catalog(tables: list[str]) -> list[str]:
    """Get table information from catalog.

    Args:
        tables: Full table names; names missing from the catalog are skipped.

    Returns:
        list[str]: One formatted description block per known table.
    """
    result = []
    catalog_store = config.get().catalog_store
    for table_name in tables:
        table_info = catalog_store.get_table_information(table_name)
        if not table_info:
            continue
        table_desc = f"Table: `{table_name}` \n# Description: {table_info['description']}\n"
        columns = catalog_store.get_column_list(table_name)
        column_names = [info["column_name"] for info in columns]
        column_results = render_column_result(column_names)
        table_desc += "# Columns:\n"
        table_desc += "\n".join(column_results)
        if table_info.get("derived_metric"):
            # FIX: start the derived-metrics header on its own line; previously it
            # was concatenated directly onto the last column's text.
            table_desc += "\n## Derived metrics:\n"
            table_desc += table_info["derived_metric"]
        result.append(table_desc)
    return result
def render_column_result(column_list: list[str], with_table_list: bool = False) -> list[str]:
    """Render each known column as a formatted markdown-style description."""
    rendered: list[str] = []
    for name in column_list:
        # Silently skip columns absent from the catalog dictionary.
        if name not in col_dict:
            continue
        column = col_dict[name]
        pieces = [
            f"## {column['column_name']}",
            f"\n- Column Category: {column['category']}",
            f"\n- Display Name: {column['display_name']} ",
            f"\n- Description \"{column['description']}\"",
        ]
        if with_table_list:
            related = column_tables_mapping.get(name, [])
            pieces.append(f"\n- Related Tables: {related}")
        rendered.append("".join(pieces))
    return rendered
================================================
FILE: openchatbi/tool/timeseries_forecast.py
================================================
"""Tool for time series forecasting."""
import logging
from typing import Any
import requests
from langchain.tools import tool
from pydantic import BaseModel, Field
from openchatbi import config
from openchatbi.utils import log
logger = logging.getLogger(__name__)
class TimeseriesForecastInput(BaseModel):
    """Input schema for time series forecasting tool."""

    # Explanation the LLM must supply for why forecasting is needed.
    reasoning: str = Field(description="Reason for using time series forecasting and what insights you expect to gain")
    # Either a plain list of numbers or dict rows; dict rows are forecast on target_column.
    input_data: list[float | int | dict[str, Any]] = Field(
        description="Time series data as list of numbers or structured data with timestamps and values"
    )
    # Bounded 1-200 by pydantic validation (ge/le).
    forecast_window: int = Field(
        default=24, description="Number of future time points to predict (1-200)", ge=1, le=200
    )
    frequency: str = Field(default="hourly", description="Time series frequency: hourly, daily, weekly, monthly, etc.")
    # Sent to the service as "input_len" when provided.
    input_length: int | None = Field(
        default=None, description="Optional limit on input data length to use for prediction"
    )
    # Only forwarded to the service when it differs from the default "value".
    target_column: str = Field(
        default="value", description="Column name to forecast for structured data (default: 'value')"
    )
def _check_service_health(service_url: str) -> bool:
    """Return True when the forecasting service's /health endpoint reports a loaded model."""
    try:
        response = requests.get(f"{service_url}/health", timeout=5)
        if response.status_code != 200:
            return False
        payload = response.json()
        return payload.get("model_initialized", False)
    except requests.exceptions.RequestException:
        # Connection refused, DNS failure, timeout, bad JSON, etc. => unhealthy.
        return False
def check_forecast_service_health() -> bool:
    """Check the configured forecasting service; False when config is unavailable."""
    try:
        # config.get() raises ValueError when configuration is not loaded
        # (e.g., in tests), which we treat as "service unavailable".
        return _check_service_health(config.get().timeseries_forecasting_service_url)
    except ValueError:
        return False
def _call_timeseries_service(
    service_url: str,
    input_data: list[float | int | dict[str, Any]],
    forecast_window: int,
    frequency: str,
    input_length: int | None = None,
    target_column: str = "value",
) -> dict[str, Any]:
    """POST the forecast request to the service and normalize failures into error dicts."""
    # Optional fields are only included when they deviate from server defaults.
    payload: dict[str, Any] = {
        "input": input_data,
        "forecast_window": forecast_window,
        "frequency": frequency,
    }
    if input_length is not None:
        payload["input_len"] = input_length
    if target_column != "value":
        payload["target_column"] = target_column
    try:
        response = requests.post(f"{service_url}/predict", json=payload, timeout=30)
        if response.status_code == 200:
            return response.json()
        # Non-200: keep the status code so the formatter can give targeted advice.
        return {
            "error": f"Service returned status {response.status_code}: {response.text}",
            "status": "http_error",
            "status_code": response.status_code,
        }
    except requests.exceptions.Timeout:
        return {"error": "Request timeout - forecasting service took too long to respond", "status": "error"}
    except requests.exceptions.RequestException as e:
        return {"error": f"Failed to connect to forecasting service: {str(e)}", "status": "error"}
    except Exception as e:
        return {"error": f"Unexpected error: {str(e)}", "status": "error"}
def _format_forecast_result(result: dict[str, Any], reasoning: str, input_data_length: int) -> str:
"""Format the forecasting result for the agent."""
if result.get("status") == "error":
return f"""Time Series Forecasting Error: {result.get('error', 'Unknown error occurred')}
Please check:
1. Time series forecasting service is running (docker run -p 8765:8765 timeseries-forecasting)
2. Model load successfully
3. Try again if timeout"""
elif result.get("status") == "http_error":
if result.get("status_code") == 400:
return f"""Time Series Forecasting Error: {result.get('error', 'Unknown error occurred')}
Please check:
1. Input data format is correct
2. input_len is set to larger when the input data length is not enough
3. Forecast window is reasonable (1-200)"""
else:
return f"""Time Series Forecasting Error: {result.get('error', 'Unknown error occurred')}"""
predictions = result.get("predictions", [])
forecast_window = result.get("forecast_window", len(predictions))
frequency = result.get("frequency", "unknown")
if not predictions:
return "No predictions were generated. Please check your input data."
# Calculate basic statistics
sum_predictions = sum(predictions)
avg_prediction = sum_predictions / len(predictions) if predictions else 0
min_prediction = min(predictions) if predictions else 0
max_prediction = max(predictions) if predictions else 0
# Create formatted response
response_parts = [
"✅ Time Series Forecasting Completed",
"",
"Forecast Summary:",
f" • Input data points: {input_data_length}",
f" • Forecast window: {forecast_window} {frequency.lower()} periods",
"",
"Predictions:",
f" • Average forecast: {avg_prediction:.2f}",
f" • Sum: {sum_predictions:.2f}",
f" • Range: {min_prediction:.2f} to {max_prediction:.2f}",
f" • Total periods forecasted: {len(predictions)}",
"",
"Detailed Forecast Values:",
]
for i, pred in enumerate(predictions):
period_label = f"Period {i + 1}"
response_parts.append(f" • {period_label}: {pred:.2f}")
return "\n".join(response_parts)
@tool("timeseries_forecast", args_schema=TimeseriesForecastInput, return_direct=False, infer_schema=True)
def timeseries_forecast(
    reasoning: str,
    input_data: list[float | int | dict[str, Any]],
    forecast_window: int = 24,
    frequency: str = "hourly",
    input_length: int | None = None,
    target_column: str = "value",
) -> str:
    """Forecast future values for time series data using advanced deep learning models.

    This tool uses state-of-the-art deep learning models (currently transformer based) to predict future values based on historical time series data.
    Perfect for sales forecasting, demand planning, trend analysis, and business intelligence.

    Args:
        reasoning: Explanation of why forecasting is needed and what insights are expected
        input_data: Historical time series data as list of numbers or structured data with timestamps
        forecast_window: Number of future time points to predict (1-200, default: 24)
        frequency: Time series frequency - hourly, daily, weekly, monthly, etc.
        input_length: Optional limit on how much historical data to use for prediction
        target_column: Column name to forecast for structured data (default: 'value')

    Returns:
        str: Formatted forecast results with predictions, statistics, and interpretation guidance

    Examples:
        - Sales forecasting: Predict next month's daily sales based on historical data
        - Demand planning: Forecast product demand for inventory management
        - Financial planning: Predict revenue, costs, or other financial metrics
        - Operational planning: Forecast website traffic, resource usage, etc.
    """
    service_url = config.get().timeseries_forecasting_service_url
    log(f"Time Series Forecast: {reasoning}")
    log(f"Input data points: {len(input_data)}, Forecast window: {forecast_window}, Frequency: {frequency}")
    # Guard clauses: reject unusable inputs before touching the network.
    if not input_data:
        return "Error: Input data cannot be empty. Please provide historical time series data."
    if len(input_data) < 3:
        return "Error: Need at least 3 data points for reliable forecasting. Please provide more historical data."
    if not _check_service_health(service_url):
        return """Time Series Forecasting Service Unavailable. The time series forecasting service is not running or not in service. """
    service_result = _call_timeseries_service(
        service_url,
        input_data,
        forecast_window,
        frequency,
        input_length,
        target_column,
    )
    return _format_forecast_result(service_result, reasoning, len(input_data))
================================================
FILE: openchatbi/utils.py
================================================
"""Utility functions for OpenChatBI."""
import json
import sys
import uuid
from pathlib import Path
from typing import Any
from fastapi import HTTPException
from fastapi.responses import FileResponse
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk, RemoveMessage, ToolMessage
from langchain_core.vectorstores import VectorStore
from rank_bm25 import BM25Okapi
from regex import regex
from openchatbi.graph_state import AgentState
from openchatbi.text_segmenter import _segmenter
def log(args) -> None:
    """Write a debug message to stderr, flushing immediately."""
    sys.stderr.write(f"{args}\n")
    sys.stderr.flush()
def get_text_from_content(content: str | list[str | dict]) -> str:
"""Extract text from various content formats.
Args:
content: String, list of strings, or list of dicts with 'text' key.
Returns:
str: Extracted text content.
"""
if isinstance(content, str):
return content
elif isinstance(content, list):
if isinstance(content[0], str):
return "".join(content)
elif isinstance(content[0], dict):
return "".join([item.get("text", "") for item in content])
return ""
def get_text_from_message_chunk(chunk: AIMessageChunk) -> str:
    """Extract content from an AIMessageChunk.

    Args:
        chunk (AIMessageChunk): The message chunk to extract text from.

    Returns:
        str: Extracted text content or empty string.
    """
    # Guard: must be a chunk, must expose content, content must be truthy.
    is_usable = isinstance(chunk, AIMessageChunk) and hasattr(chunk, "content") and bool(chunk.content)
    return get_text_from_content(chunk.content) if is_usable else ""
def extract_json_from_answer(answer: str) -> dict:
    """Extract the first JSON object from a string answer.

    Args:
        answer (str): String that may contain JSON objects.

    Returns:
        dict: Parsed JSON object or empty dict if none found.
    """
    # (?R) recurses into nested braces, so whole balanced objects are matched.
    brace_pattern = regex.compile(r"\{(?:[^{}]+|(?R))*\}")
    candidates = brace_pattern.findall(answer)
    if not candidates:
        return json.loads("{}")
    return json.loads(candidates[0])
def get_report_download_response(filename: str) -> FileResponse:
    """Get FileResponse for downloading a report file.

    Args:
        filename: The filename of the report to download

    Returns:
        FileResponse: Response object for file download

    Raises:
        HTTPException: 403 for paths outside the report directory, 404 when the
            file is missing, 400 for non-regular files, 500 on unexpected errors.
    """
    try:
        # Import config here to avoid circular imports
        from openchatbi import config

        # Get report directory from config
        report_dir = config.get().report_directory
        file_path = Path(report_dir) / filename
        # SECURITY FIX: verify containment BEFORE any existence checks, so a
        # traversal filename ("../../etc/passwd") cannot probe the filesystem
        # through differing 404/403 responses.
        try:
            file_path.resolve().relative_to(Path(report_dir).resolve())
        except ValueError:
            raise HTTPException(status_code=403, detail="Access denied") from None
        # Check if file exists and is a regular file
        if not file_path.exists():
            raise HTTPException(status_code=404, detail="Report file not found")
        if not file_path.is_file():
            raise HTTPException(status_code=400, detail="Invalid file path")
        # Determine media type based on file extension
        media_type_map = {
            ".md": "text/markdown",
            ".csv": "text/csv",
            ".txt": "text/plain",
            ".json": "application/json",
            ".html": "text/html",
            ".xml": "application/xml",
        }
        file_extension = file_path.suffix.lower()
        media_type = media_type_map.get(file_extension, "application/octet-stream")
        return FileResponse(path=str(file_path), media_type=media_type, filename=filename)
    except HTTPException:
        # Re-raise intentional HTTP errors unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to download report: {str(e)}") from e
def _create_chroma_from_texts(
    texts: list[str],
    embedding,
    collection_name: str,
    metadatas,
    collection_metadata: dict,
    chroma_dir: str,
):
    """Build a persisted Chroma collection from raw texts (thin wrapper over Chroma.from_texts)."""
    return Chroma.from_texts(
        texts,
        embedding,
        collection_name=collection_name,
        collection_metadata=collection_metadata,
        persist_directory=chroma_dir,
        metadatas=metadatas,
    )
def create_vector_db(
    texts: list[str],
    embedding=None,
    collection_name: str = "langchain",
    metadatas=None,
    collection_metadata: dict = None,
    chroma_db_path: str = None,
) -> VectorStore:
    """Create or reuse a Chroma vector database.

    If the persisted collection already holds exactly the same set of texts
    (order-insensitive), it is reused as-is; otherwise it is cleared and rebuilt.
    Without an embedding function, a BM25-based SimpleStore is returned instead.

    Args:
        texts (List[str]): Text documents to index.
        embedding: Embedding function to use.
        collection_name (str): Name of the collection.
        metadatas: Metadata for each document.
        collection_metadata (dict): Collection-level metadata.
        chroma_db_path (str): Path to chroma database file.

    Returns:
        Chroma: Vector database instance (or SimpleStore when no embedding given).
    """
    # fallback to Simple vector store using BM25 if no embedding model configured
    if not embedding:
        return SimpleStore(texts, metadatas)
    chroma_dir = chroma_db_path or "./.chroma_db"
    client = Chroma(
        collection_name,
        persist_directory=chroma_dir,
        embedding_function=embedding,
        collection_metadata=collection_metadata,
    )
    use_cache = False
    existing_docs = None
    try:
        # Try to get documents to check if collection exists and has content
        existing_docs = client.get()
        if not existing_docs["documents"]:
            print(f"Init new client from text for {collection_name}...")
        else:
            # Check if cached texts match the input texts
            cached_texts = existing_docs["documents"]
            # Compare texts: check count first (cheap), then content
            if len(cached_texts) != len(texts):
                print(
                    f"Texts count mismatch for {collection_name} "
                    f"(cached: {len(cached_texts)}, input: {len(texts)}). Recreating collection..."
                )
            else:
                # Compare content by sorting both lists to handle order differences
                sorted_cached = sorted(cached_texts)
                sorted_input = sorted(texts)
                if sorted_cached != sorted_input:
                    print(f"Cache content mismatch for {collection_name}. Recreating collection...")
                else:
                    print(f"Re-use collection for {collection_name}")
                    use_cache = True
    except Exception:
        # If collection doesn't exist or any error, create new one
        print(f"Init new client from text for {collection_name}...")
    if not use_cache:
        # Clear existing collection before recreating to avoid data duplication
        if existing_docs and existing_docs["documents"]:
            try:
                client.reset_collection()
                print(f"Cleared existing collection {collection_name} before recreating...")
            except Exception as e:
                # If reset fails, log and continue with recreation
                print(f"Warning: Failed to clear collection {collection_name}: {e}")
        client = _create_chroma_from_texts(
            texts, embedding, collection_name, metadatas, collection_metadata, chroma_dir
        )
    return client
def recover_incomplete_tool_calls(state: AgentState) -> list:
    """Recover from incomplete tool calls by creating message operations to insert ToolMessages correctly.

    When the graph execution is interrupted (e.g., by kill or app restart)
    during tool execution, the state can end up with AIMessage containing
    tool_calls but no corresponding ToolMessage responses. This function
    detects such cases and creates the necessary message operations to insert
    failure ToolMessages in the correct position (right after the AIMessage).

    Args:
        state (AgentState): The current graph state containing messages.

    Returns:
        list: Message operations to insert recovery ToolMessages, or empty list if no recovery needed.
    """
    messages = state.get("messages", [])
    if not messages:
        return []
    # Find the last AIMessage with tool_calls (scan backwards from the end)
    last_ai_message = None
    last_ai_index = -1
    for i in range(len(messages) - 1, -1, -1):
        if isinstance(messages[i], AIMessage) and messages[i].tool_calls:
            last_ai_message = messages[i]
            last_ai_index = i
            break
    if not last_ai_message:
        return []
    # Check if there are any ToolMessages after this AIMessage
    tool_call_ids = {call["id"] for call in last_ai_message.tool_calls}
    handled_tool_call_ids = set()
    # Look for ToolMessages that respond to these tool calls
    for msg in messages[last_ai_index + 1 :]:
        if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
            handled_tool_call_ids.add(msg.tool_call_id)
    # Find unhandled tool calls
    unhandled_tool_call_ids = tool_call_ids - handled_tool_call_ids
    if not unhandled_tool_call_ids:
        return []  # All tool calls have responses
    # Create failure ToolMessages for unhandled tool calls
    recovery_messages = []
    for tool_call in last_ai_message.tool_calls:
        if tool_call["id"] in unhandled_tool_call_ids:
            failure_msg = ToolMessage(
                content=f"Tool `{tool_call['name']}` execution was interrupted due to system restart or process termination. Please retry the operation.",
                tool_call_id=tool_call["id"],
            )
            recovery_messages.append(failure_msg)
    # Build operations to insert recovery messages in correct position:
    # remove everything after the AIMessage, add the recovery ToolMessages,
    # then re-append the removed messages so ordering stays valid.
    operations = []
    messages_after_ai = messages[last_ai_index + 1 :]
    # Collect IDs that will be removed
    removed_ids = set()
    # If there are messages after the AIMessage, we need to remove them first
    if messages_after_ai:
        for msg in messages_after_ai:
            operations.append(RemoveMessage(id=msg.id))
            removed_ids.add(msg.id)
    # Add recovery messages (they will be inserted right after the AIMessage)
    operations.extend(recovery_messages)
    # Re-add the messages that were after the AIMessage (if any)
    # CRITICAL: Must regenerate Message ids if matches a RemoveMessage to prevent RemoveMessage from being cancelled
    if messages_after_ai:
        for msg in messages_after_ai:
            # Only regenerate ID if this message's ID was removed
            if msg.id in removed_ids:
                # Create a copy with new ID to prevent the RemoveMessage from being discarded
                new_msg = msg.model_copy(update={"id": str(uuid.uuid4())})
                operations.append(new_msg)
            else:
                # Keep original message as-is if ID wasn't removed
                operations.append(msg)
    log(f"Recovered {len(recovery_messages)} incomplete tool calls")
    return operations
class SimpleStore(VectorStore):
    """Simple vector store using BM25 for text retrieval without embeddings.

    Invariant: texts, metadatas, ids, documents, and tokenized_corpus are
    parallel lists that must stay index-aligned; bm25 is rebuilt whenever
    the corpus changes (and is None while the corpus is empty).
    """

    def __init__(
        self,
        texts: list[str],
        metadatas: list[dict] | None = None,
        ids: list[str] | None = None,
    ):
        """Initialize SimpleStore with texts.

        Args:
            texts: List of text documents to store.
            metadatas: Optional list of metadata dicts for each document.
            ids: Optional list of IDs for each document.
        """
        self.texts = texts
        self.metadatas = metadatas or [{} for _ in texts]
        self.ids = ids or [str(uuid.uuid4()) for _ in texts]
        # Create Document objects
        self.documents = [
            Document(id=doc_id, page_content=text, metadata=meta)
            for doc_id, text, meta in zip(self.ids, self.texts, self.metadatas)
        ]
        # Tokenize texts and create BM25 index
        self.tokenized_corpus = [self._tokenize(text) for text in texts]
        # BM25Okapi doesn't support empty corpus, so set to None if empty
        self.bm25 = BM25Okapi(self.tokenized_corpus) if texts else None

    def _tokenize(self, text: str) -> list[str]:
        """Tokenize text for BM25 indexing using TextSegmenter.

        Args:
            text: Text to tokenize.

        Returns:
            List of tokens.
        """
        return _segmenter.cut(text)

    def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> list[Document]:
        """Search for documents similar to the query using BM25.

        Args:
            query: Query text.
            k: Number of documents to return.
            **kwargs: Additional arguments (unused).

        Returns:
            List of most similar Document objects.
        """
        # Non-empty texts implies bm25 is not None (see __init__/add_texts/delete)
        if not self.texts:
            return []
        # Tokenize query
        tokenized_query = self._tokenize(query)
        # Get BM25 scores
        scores = self.bm25.get_scores(tokenized_query)
        # Get top-k indices (highest score first)
        top_k = min(k, len(scores))
        top_k_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
        # Return corresponding documents
        return [self.documents[i] for i in top_k_indices]

    def similarity_search_with_score(self, query: str, k: int = 4, **kwargs: Any) -> list[tuple[Document, float]]:
        """Search for documents similar to the query with BM25 scores.

        Args:
            query: Query text.
            k: Number of documents to return.
            **kwargs: Additional arguments (unused).

        Returns:
            List of (Document, score) tuples.
        """
        if not self.texts:
            return []
        # Tokenize query
        tokenized_query = self._tokenize(query)
        # Get BM25 scores
        scores = self.bm25.get_scores(tokenized_query)
        # Get top-k items
        top_k = min(k, len(scores))
        top_k_items = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:top_k]
        # Return (Document, score) tuples
        return [(self.documents[i], score) for i, score in top_k_items]

    def _select_relevance_score_fn(self):
        """Return relevance score function for BM25.

        BM25 scores are already relevance scores, so return identity function.
        NOTE(review): raw BM25 scores are unbounded (not normalized to [0, 1]).
        """
        return lambda score: score

    def add_texts(
        self,
        texts: list[str],
        metadatas: list[dict] | None = None,
        *,
        ids: list[str] | None = None,
        **kwargs: Any,
    ) -> list[str]:
        """Add texts to the store.

        Args:
            texts: Texts to add.
            metadatas: Optional metadata for each text.
            ids: Optional IDs for each text.
            **kwargs: Additional arguments (unused).

        Returns:
            List of IDs of added texts.
        """
        if metadatas is None:
            metadatas = [{} for _ in texts]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]
        # Add to existing data (keeps the parallel lists aligned)
        self.texts.extend(texts)
        self.metadatas.extend(metadatas)
        self.ids.extend(ids)
        # Create new Document objects
        new_documents = [
            Document(id=doc_id, page_content=text, metadata=meta) for doc_id, text, meta in zip(ids, texts, metadatas)
        ]
        self.documents.extend(new_documents)
        # Update BM25 index (BM25Okapi has no incremental add; rebuild)
        new_tokenized = [self._tokenize(text) for text in texts]
        self.tokenized_corpus.extend(new_tokenized)
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        return ids

    def delete(self, ids: list[str] | None = None, **kwargs: Any) -> bool | None:
        """Delete documents by IDs.

        Args:
            ids: List of document IDs to delete.
            **kwargs: Additional arguments (unused).

        Returns:
            True if deletion successful, False otherwise.
        """
        if ids is None:
            return False
        # Find indices to delete
        indices_to_delete = [i for i, doc_id in enumerate(self.ids) if doc_id in ids]
        if not indices_to_delete:
            return False
        # Remove items in reverse order to maintain indices
        for idx in sorted(indices_to_delete, reverse=True):
            del self.texts[idx]
            del self.metadatas[idx]
            del self.ids[idx]
            del self.documents[idx]
            del self.tokenized_corpus[idx]
        # Rebuild BM25 index (None when the corpus becomes empty)
        if self.tokenized_corpus:
            self.bm25 = BM25Okapi(self.tokenized_corpus)
        else:
            self.bm25 = None
        return True

    def get_by_ids(self, ids: list[str], /) -> list[Document]:
        """Get documents by their IDs.

        Args:
            ids: List of document IDs to retrieve.

        Returns:
            List of Document objects (unknown IDs are silently skipped).
        """
        id_to_doc = {doc.id: doc for doc in self.documents}
        return [id_to_doc[doc_id] for doc_id in ids if doc_id in id_to_doc]

    @classmethod
    def from_texts(
        cls,
        texts: list[str],
        embedding: Any = None,  # Unused but required by interface
        metadatas: list[dict] | None = None,
        *,
        ids: list[str] | None = None,
        **kwargs: Any,
    ) -> "SimpleStore":
        """Create SimpleStore from texts.

        Args:
            texts: List of texts.
            embedding: Unused (SimpleStore doesn't use embeddings).
            metadatas: Optional metadata for each text.
            ids: Optional IDs for each text.
            **kwargs: Additional arguments (unused).

        Returns:
            SimpleStore instance.
        """
        return cls(texts, metadatas, ids)

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> list[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of `Document` objects to return.
            fetch_k: Number of `Document` objects to fetch to pass to MMR algorithm.
            lambda_mult: Number between `0` and `1` that determines the degree
                of diversity among the results with `0` corresponding
                to maximum diversity and `1` to minimum diversity.
            **kwargs: Arguments to pass to the search method.

        Returns:
            List of `Document` objects selected by maximal marginal relevance.
        """
        if not self.texts:
            return []
        # Get initial candidates using BM25 similarity search
        candidates = self.similarity_search_with_score(query, k=fetch_k, **kwargs)
        if not candidates:
            return []
        if len(candidates) <= k:
            return [doc for doc, _ in candidates]
        # Normalize BM25 scores to [0, 1] for proper MMR calculation
        scores = [score for _, score in candidates]
        min_score = min(scores) if scores else 0
        max_score = max(scores) if scores else 1
        score_range = max_score - min_score if max_score > min_score else 1
        normalized_candidates = [(doc, (score - min_score) / score_range) for doc, score in candidates]
        # MMR implementation following standard algorithm
        selected = []
        remaining = list(range(len(normalized_candidates)))
        # Select documents iteratively using MMR formula
        while len(selected) < k and remaining:
            best_mmr_score = float("-inf")
            best_idx = -1
            best_remaining_idx = -1
            for i, doc_idx in enumerate(remaining):
                candidate_doc, relevance_score = normalized_candidates[doc_idx]
                # Calculate maximum similarity to already selected documents
                max_similarity = 0.0
                if selected:
                    max_similarity = max(
                        self._calculate_similarity(candidate_doc, normalized_candidates[sel_idx][0])
                        for sel_idx in selected
                    )
                # Standard MMR formula: λ * Sim(q, d) - (1-λ) * max(Sim(d, s)) for s in selected
                mmr_score = lambda_mult * relevance_score - (1 - lambda_mult) * max_similarity
                if mmr_score > best_mmr_score:
                    best_mmr_score = mmr_score
                    best_idx = doc_idx
                    best_remaining_idx = i
            if best_idx != -1:
                selected.append(best_idx)
                remaining.pop(best_remaining_idx)
        return [normalized_candidates[idx][0] for idx in selected]

    def _calculate_similarity(self, doc1: Document, doc2: Document) -> float:
        """Calculate similarity between two documents using Jaccard similarity.

        Args:
            doc1: First document.
            doc2: Second document.

        Returns:
            Similarity score between 0 and 1 (higher means more similar).
        """
        tokens1 = set(self._tokenize(doc1.page_content))
        tokens2 = set(self._tokenize(doc2.page_content))
        # Calculate Jaccard similarity
        intersection = len(tokens1 & tokens2)
        union = len(tokens1 | tokens2)
        return intersection / union if union > 0 else 0.0
================================================
FILE: pyproject.toml
================================================
[project]
name = "openchatbi"
version = "0.2.2"
description = "OpenChatBI - Natural language business intelligence powered by LLMs for intuitive data analysis and SQL generation"
authors = [
{ name = "Yu Zhong", email = "zhongyu8@gmail.com" },
]
license = { text = "MIT" }
readme = "README.md"
keywords = [
"business intelligence",
"bi",
"analytics",
"llm",
"gpt",
"ai",
"machine learning",
"nlp",
"text2sql",
"agent",
"query data",
"talk to data",
"analyze data",
"data agent",
"database",
"langchain",
"langgraph",
"natural language",
"conversational ai",
"timeseries",
"forecasting",
"prediction"
]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Intended Audience :: End Users/Desktop",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Database",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Office/Business",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Operating System :: OS Independent",
"License :: OSI Approved :: MIT License",
]
requires-python = ">=3.11,<4.0"
dependencies = [
"requests>=2.31.0,<3.0.0",
"langgraph>=0.4.7,<1.0.0",
"langchain-openai>=0.3.18,<1.0.0",
"langchain-anthropic>=0.3.13,<1.0.0",
"langchain-community>=0.3.27,<1.0.0",
"langgraph-checkpoint-sqlite>=2.0.11",
"langchain-chroma>=0.2.5",
"langchain-mcp-adapters>=0.1.9,<0.2.0",
"langmem>=0.0.29",
"sqlalchemy>=2.0.41,<3.0.0",
"sqlalchemy-trino>=0.5.0",
"aiosqlite>=0.21.0",
"pyhive[presto]>=0.7.0",
"rank-bm25>=0.2.2,<1.0.0",
"python-levenshtein>=0.27.1",
"gradio>=5.43.1,<6.0.0",
"streamlit>=1.49.1,<2.0.0",
"RestrictedPython>=8.0,<9.0",
"docker>=7.0.0,<8.0.0",
"pandas>=2.2.0,<3.0.0",
"numpy>=2.3.0,<3.0.0",
"matplotlib>=3.10.6,<4.0.0",
"seaborn>=0.13.0,<1.0.0",
"plotly>=5.17.0,<6.0.0",
"json5>=0.10.0,<1.0.0",
"jieba>=0.42.1", # Note: jieba is not compatible with Python 3.12+
]
[project.urls]
Homepage = "https://github.com/zhongyu09/openchatbi"
Repository = "https://github.com/zhongyu09/openchatbi"
Documentation = "https://github.com/zhongyu09/openchatbi/tree/main"
"Bug Tracker" = "https://github.com/zhongyu09/openchatbi/issues"
[project.optional-dependencies]
docs = [
"sphinx>=8.2.3,<9.0.0",
"sphinx-rtd-theme>=3.0.0,<4.0.0",
"sphinx-autodoc-typehints>=2.5.0,<3.0.0",
"myst_parser",
"autodoc-pydantic",
]
test = [
"pytest>=7.4.0,<9.0.0",
"pytest-mock>=3.14.0,<4.0.0",
"pytest-asyncio>=0.23.8,<1.0.0",
"pytest-sugar>=1.0.0,<2.0.0",
"pytest-cov>=6.0.0,<7.0.0",
"aioresponses>=0.7.7,<1.0.0",
"responses>=0.25.3,<1.0.0",
"langsmith[pytest]>=0.4.8,<1.0.0",
"openevals>=0.1.0,<1.0.0",
]
dev = [
"openchatbi[test,docs]",
"black>=24.10.0,<25.0.0",
"mypy>=1.13.0,<2.0.0",
"ruff>=0.8.0,<1.0.0",
"pre-commit>=4.0.1,<5.0.0",
"bandit>=1.8.6,<2.0.0",
"types-setuptools>=75.6.0.20241126",
"twine>=6.0.0,<7.0.0",
]
[tool.uv]
managed = true
dev-dependencies = [
"openchatbi[dev]",
]
[build-system]
requires = ["hatchling>=1.26.0"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["openchatbi"]
[tool.hatch.build.targets.sdist]
include = [
"/openchatbi",
"/tests",
"/README.md",
"/LICENSE",
]
[tool.hatch.metadata]
allow-direct-references = true
[tool.black]
line-length = 120
target-version = ["py311"]
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''
skip-string-normalization = false
skip-magic-trailing-comma = false
preview = false
[tool.ruff]
line-length = 120
target-version = "py311"
exclude = [
".git",
".mypy_cache",
".tox",
".venv",
"_build",
"buck-out",
"build",
"dist",
]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"C", # flake8-comprehensions
"B", # flake8-bugbear
"UP", # pyupgrade
]
ignore = [
"E501", # line too long, handled by black
"B008", # do not perform function calls in argument defaults
"C901", # too complex
]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"tests/**/*" = ["B011"]
[tool.mypy]
python_version = "3.11"
strict = true
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
warn_unreachable = true
ignore_missing_imports = true
show_error_codes = true
[tool.pytest.ini_options]
minversion = "7.0"
addopts = [
"--strict-markers",
"--strict-config",
"--cov=openchatbi",
"--cov-report=term-missing",
"--cov-report=html",
"--cov-report=xml",
]
testpaths = ["tests"]
markers = [
"unit: Unit tests",
"integration: Integration tests",
"slow: Slow tests that may take several seconds",
"requires_db: Tests that require database connection",
"requires_llm: Tests that require LLM service",
"asyncio: Asynchronous tests"
]
filterwarnings = [
"error",
"ignore::UserWarning",
"ignore::DeprecationWarning",
]
[tool.coverage.run]
source = ["openchatbi"]
omit = [
"*/tests/*",
"*/test_*.py",
"setup.py",
]
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if self.debug:",
"if settings.DEBUG",
"raise AssertionError",
"raise NotImplementedError",
"if 0:",
"if __name__ == .__main__.:",
"class .*\\bProtocol\\):",
"@(abc\\.)?abstractmethod",
]
================================================
FILE: run_streamlit_ui.py
================================================
#!/usr/bin/env python3
"""
Launch script for the Streamlit-based OpenChatBI interface.
Usage:
python run_streamlit_ui.py
This will start the Streamlit server on http://localhost:8501
"""
import os
import subprocess
import sys
def main():
    """Launch the Streamlit UI"""
    # Run from the project root so the relative UI path resolves.
    project_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(project_dir)

    print("🚀 Starting OpenChatBI Streamlit UI...")
    print("📍 URL: http://localhost:8501")
    print("⏹️ Press Ctrl+C to stop the server")
    print("-" * 50)

    streamlit_cmd = [
        sys.executable,
        "-m",
        "streamlit",
        "run",
        "sample_ui/streamlit_ui.py",
        "--server.port=8501",
        "--server.address=localhost",
    ]
    try:
        # Blocks until the server stops or the user interrupts.
        subprocess.run(streamlit_cmd, check=True)
    except KeyboardInterrupt:
        print("\n👋 Stopping Streamlit server...")
    except subprocess.CalledProcessError as e:
        print(f"❌ Error starting Streamlit: {e}")
        print("\n💡 Make sure Streamlit is installed:")
        print(" pip install streamlit")
    except FileNotFoundError:
        print("❌ Python or Streamlit not found")
        print("\n💡 Make sure Python and Streamlit are installed:")
        print(" pip install streamlit")


if __name__ == "__main__":
    main()
================================================
FILE: run_tests.py
================================================
#!/usr/bin/env python3
"""Test runner script for OpenChatBI."""
import argparse
import subprocess
import sys
def run_command(cmd, description):
    """Run a command, echo its captured output, and report success.

    Args:
        cmd: Command and arguments as a list (executed without a shell).
        description: Human-readable label printed in the progress banner.

    Returns:
        True if the command exited with status 0, False otherwise.
    """
    # Bug fix: f"\\n..." printed a literal backslash-n; use a real newline.
    print(f"\n{'=' * 60}")
    print(f"Running: {description}")
    print(f"Command: {' '.join(cmd)}")
    print(f"{'=' * 60}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.stdout:
        print("STDOUT:")
        print(result.stdout)
    if result.stderr:
        print("STDERR:")
        print(result.stderr)
    if result.returncode != 0:
        print(f"❌ {description} failed with return code {result.returncode}")
        return False
    else:
        print(f"✅ {description} passed")
        return True
def main():
    """Parse command-line options and run the selected checks.

    Runs tests, lint, and/or type checks depending on flags, then exits
    with status 0 when every selected check passed, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Run OpenChatBI tests")
    parser.add_argument("--unit", action="store_true", help="Run only unit tests")
    parser.add_argument("--integration", action="store_true", help="Run only integration tests")
    parser.add_argument("--coverage", action="store_true", help="Run with coverage report")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
    parser.add_argument("--fast", action="store_true", help="Skip slow tests")
    parser.add_argument("--lint", action="store_true", help="Run linting checks")
    parser.add_argument("--type-check", action="store_true", help="Run type checking")
    parser.add_argument("--all", action="store_true", help="Run all checks (tests, lint, type-check)")
    parser.add_argument("--file", help="Run specific test file")
    args = parser.parse_args()
    # Determine test command
    base_cmd = ["uv", "run", "pytest"]
    if args.verbose:
        base_cmd.append("-v")
    if args.coverage:
        base_cmd.extend(["--cov=openchatbi", "--cov-report=html", "--cov-report=term-missing"])
    # Marker filters are mutually exclusive; the first matching flag wins.
    if args.unit:
        base_cmd.extend(["-m", "unit"])
    elif args.integration:
        base_cmd.extend(["-m", "integration"])
    elif args.fast:
        base_cmd.extend(["-m", "not slow"])
    if args.file:
        base_cmd.append(f"tests/{args.file}")
    success = True
    # Run tests unless only lint/type-check was requested
    if not args.lint and not args.type_check:
        success &= run_command(base_cmd, "Unit Tests")
    # Run linting if requested
    if args.lint or args.all:
        lint_commands = [
            (["uv", "run", "black", "--check", "."], "Black formatting check"),
            (["uv", "run", "isort", "--check-only", "."], "Import sorting check"),
            (["uv", "run", "ruff", "check", "."], "Ruff linting"),
            (["uv", "run", "bandit", "-r", "openchatbi/"], "Security scanning"),
        ]
        for cmd, desc in lint_commands:
            success &= run_command(cmd, desc)
    # Run type checking if requested
    if args.type_check or args.all:
        success &= run_command(["uv", "run", "mypy", "openchatbi/"], "Type checking")
    # Run all tests if --all is specified
    if args.all:
        test_commands = [
            (["uv", "run", "pytest", "-m", "unit", "-v"], "Unit Tests"),
            (["uv", "run", "pytest", "-m", "integration", "-v"], "Integration Tests"),
            (["uv", "run", "pytest", "--cov=openchatbi", "--cov-report=html"], "Coverage Report"),
        ]
        for cmd, desc in test_commands:
            success &= run_command(cmd, desc)
    # Print summary.  Bug fix: f"\\n..." printed a literal backslash-n.
    print(f"\n{'=' * 60}")
    if success:
        print("🎉 All checks passed!")
        sys.exit(0)
    else:
        print("❌ Some checks failed!")
        sys.exit(1)


if __name__ == "__main__":
    main()
================================================
FILE: sample_api/async_api.py
================================================
"""Async API for streaming chat responses from OpenChatBI."""
import asyncio
from typing import Any
from collections import defaultdict
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel
from openchatbi import config
from openchatbi.agent_graph import build_agent_graph_async
from openchatbi.utils import get_report_download_response
# Session state storage: session_id -> state
sessions = defaultdict(dict)
# Graphs keyed by provider name; "__default__" holds the no-provider graph.
graphs: dict[str, Any] = {}
# Serializes lazy graph construction across concurrent requests.
graphs_lock = asyncio.Lock()
async def get_or_build_graph(provider: str | None):
    """Return the agent graph for *provider*, lazily building and caching it."""
    key = provider or "__default__"
    if key in graphs:
        # Fast path: already built, no lock needed.
        return graphs[key]
    async with graphs_lock:
        # Re-check inside the lock: another task may have built it meanwhile.
        if key not in graphs:
            graphs[key] = await build_agent_graph_async(config.get().catalog_store, llm_provider=provider)
        return graphs[key]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan events."""
    # Startup: eagerly build the default-provider graph so the first request
    # does not pay the construction cost.
    graphs["__default__"] = await build_agent_graph_async(config.get().catalog_store)
    yield
    # Shutdown: drop every cached graph.
    graphs.clear()


app = FastAPI(lifespan=lifespan)
class UserRequest(BaseModel):
    """Request model for streaming chat."""

    # Natural-language message from the user.
    input: str
    # Identifier used for per-user memory isolation.
    user_id: str | None = "default"
    # Identifier distinguishing separate conversations of one user.
    session_id: str | None = "default"
    # Optional LLM provider name; None selects the default graph.
    provider: str | None = None
@app.post("/chat/stream")
async def chat_stream(req: UserRequest):
    """Stream chat responses from the agent graph.

    Returns:
        A plain-text StreamingResponse yielding incremental answer text.

    Raises:
        HTTPException: 400 when graph construction rejects the provider
            (ValueError from get_or_build_graph).
    """
    user_id = req.user_id or "default"
    session_id = req.session_id or "default"
    provider = req.provider
    # Create user-session ID just like in UI
    user_session_id = f"{user_id}-{session_id}"
    stream_input = {"messages": [("user", req.input)]}
    # Fix: renamed from `config` — it shadowed the imported openchatbi.config module.
    run_config = {"configurable": {"thread_id": user_session_id, "user_id": user_id}}
    try:
        graph = await get_or_build_graph(provider)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    async def event_generator():
        """Yield incremental answer text from the graph's event stream."""
        async for _namespace, event_type, event_value in graph.astream(
            stream_input, config=run_config, stream_mode=["updates", "messages"], subgraphs=True
        ):
            text = ""
            if event_type == "messages":
                # Token-level chunk emitted by the LLM.
                message_chunk = event_value[0]
                if isinstance(message_chunk, AIMessageChunk):
                    text = message_chunk.content
            elif event_value.get("llm_node") and event_value["llm_node"].get("final_answer"):
                # State update carrying the final answer.
                text = event_value["llm_node"]["final_answer"]
            if text:
                yield text

    return StreamingResponse(event_generator(), media_type="text/plain")
@app.get("/user/{user_id}/memories")
async def get_user_memories(user_id: str):
    """Get all memories for a specific user.

    Returns:
        Dict with the user id, total memory count, and the memory items.

    Raises:
        HTTPException: 500 when the store cannot be accessed or searched.
    """
    try:
        # Import required modules for memory access
        import json

        from openchatbi.tool.memory import get_async_memory_store

        # Get the async memory store
        memory_store = await get_async_memory_store()
        memories = []
        namespace = ("memories", user_id)
        try:
            # Search for all memories for this user
            search_results = memory_store.search(namespace)
            for item in search_results:
                # Parse the memory data; stored values may be raw bytes or objects
                try:
                    content = json.loads(item.value.decode("utf-8")) if isinstance(item.value, bytes) else item.value
                except (json.JSONDecodeError, AttributeError):
                    content = str(item.value)
                memory_data = {
                    "key": item.key,
                    "content": content,
                    "namespace": str(namespace),
                    "created_at": getattr(item, "created_at", "Unknown"),
                    "updated_at": getattr(item, "updated_at", "Unknown"),
                }
                memories.append(memory_data)
            return {"user_id": user_id, "total_memories": len(memories), "memories": memories}
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error retrieving memories: {str(e)}") from e
    except HTTPException:
        # Bug fix: the outer `except Exception` previously re-caught the
        # HTTPException raised just above and replaced its detail with the
        # generic "Failed to access" message; propagate it unchanged instead.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to access memory store: {str(e)}") from e
@app.get("/api/download/report/{filename}")
async def download_report(filename: str):
    """Download a saved report file.

    The `{filename}` path parameter selects which stored report to return;
    the route template must declare it so FastAPI binds the argument.
    """
    return get_report_download_response(filename)
if __name__ == "__main__":
    import uvicorn

    # Serve the API on all interfaces for local development.
    uvicorn.run(app, host="0.0.0.0", port=8000)
================================================
FILE: sample_ui/async_graph_manager.py
================================================
"""Common AsyncGraphManager for UIs."""
from typing import Any
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from openchatbi import config
from openchatbi.agent_graph import build_agent_graph_async
from openchatbi.tool.memory import cleanup_async_memory_store, get_async_memory_store, setup_async_memory_store
from openchatbi.utils import log
class AsyncGraphManager:
    """Manages the async graph and checkpointer lifecycle"""

    def __init__(self):
        # SQLite-backed checkpointer shared by every graph built here.
        self.checkpointer = None
        self.graph = None  # Default graph (backwards compatible)
        # Provider name -> built graph; "__default__" when no provider given.
        self.graphs: dict[str, Any] = {}
        # Async context manager from AsyncSqliteSaver.from_conn_string;
        # kept so cleanup() can __aexit__ it later.
        self._context_manager = None
        self._memory_store = None
        self._initialized = False

    async def initialize(self):
        """Initialize the graph and checkpointer"""
        # Idempotent: repeated calls after success are no-ops.
        if self._initialized:
            return
        try:
            # Setup async memory store
            await setup_async_memory_store()
            # Initialize checkpointer (enter the context manager manually;
            # cleanup() performs the matching __aexit__).
            self._context_manager = AsyncSqliteSaver.from_conn_string("checkpoints.db")
            self.checkpointer = await self._context_manager.__aenter__()
            # Cache store for graph builds
            self._memory_store = await get_async_memory_store()
            # Mark initialized before get_graph(), which re-checks this flag.
            self._initialized = True
            # Build default graph for backwards compatibility
            self.graph = await self.get_graph()
            log("Graph initialized successfully")
        except Exception as e:
            # Roll back the flag so a later call can retry initialization.
            self._initialized = False
            log(f"Failed to initialize graph: {e}")
            raise

    async def get_graph(self, llm_provider: str | None = None):
        """Get or build a graph for the requested LLM provider."""
        if not self._initialized:
            await self.initialize()
        key = llm_provider or "__default__"
        # Return the cached graph when this provider was built before.
        if key in self.graphs:
            return self.graphs[key]
        graph = await build_agent_graph_async(
            config.get().catalog_store,
            checkpointer=self.checkpointer,
            memory_store=self._memory_store,
            memory_tools=None,  # Let graph builder create provider-appropriate tools
            llm_provider=llm_provider,
        )
        self.graphs[key] = graph
        return graph

    async def cleanup(self):
        """Cleanup resources"""
        # Only tear down when initialize() actually opened the checkpointer.
        if self.checkpointer is not None and self._context_manager is not None:
            try:
                await self._context_manager.__aexit__(None, None, None)
                await cleanup_async_memory_store()
                log("Graph cleaned up successfully")
            except Exception as e:
                log(f"Error during cleanup: {e}")
            finally:
                # Reset all state so the manager can be re-initialized.
                self.checkpointer = None
                self.graph = None
                self.graphs = {}
                self._context_manager = None
                self._memory_store = None
                self._initialized = False
================================================
FILE: sample_ui/memory_ui.py
================================================
"""Memory listing UI for OpenChatBI using FastAPI and Gradio."""
import json
from typing import Any
import gradio as gr
import uvicorn
from fastapi import FastAPI
from sample_ui.style import custom_css
def get_thread_memory_store() -> tuple[Any, Any]:
    """Create a thread-safe memory store connection.

    Returns:
        A ``(store, conn)`` pair: the SqliteStore and its underlying sqlite3
        connection.  The caller is responsible for closing ``conn``.
    """
    try:
        # Prefer pysqlite3 (bundles a newer SQLite) when available.
        import pysqlite3 as sqlite3
    except ImportError:
        import sqlite3
    from langgraph.store.sqlite import SqliteStore

    from openchatbi import config

    conn = sqlite3.connect("memory.db", check_same_thread=False)
    conn.isolation_level = None  # Use autocommit mode to avoid transaction conflicts
    store = SqliteStore(conn, index={"dims": 1536, "embed": config.get().embedding_model, "fields": ["text"]})
    try:
        store.setup()
    except Exception:
        pass  # Store might already be set up
    return store, conn
def list_all_memories() -> list[dict[str, Any]]:
    """
    Retrieve all memories from the memory store.
    Returns:
        List of memory items with their metadata
    """
    try:
        memory_store, conn = get_thread_memory_store()
        collected: list[dict[str, Any]] = []
        try:
            # A partial namespace prefix matches every user's memory items.
            for item in memory_store.search(("memories",), limit=1000):
                collected.append(
                    {
                        "namespace": item.namespace,
                        "key": item.key,
                        "value": item.value,
                        "created_at": getattr(item, "created_at", "Unknown"),
                        "updated_at": getattr(item, "updated_at", "Unknown"),
                    }
                )
        except Exception as e:
            return [{"error": f"Failed to retrieve memories: {str(e)}"}]
        finally:
            conn.close()
        return collected
    except Exception as e:
        return [{"error": f"Failed to access memory store: {str(e)}"}]
def format_memories_for_display(memories: list[dict[str, Any]]) -> str:
    """
    Format memories for display in the Gradio interface.
    Args:
        memories: List of memory items
    Returns:
        Formatted string for display
    """
    if not memories:
        return "No memories found."
    if len(memories) == 1 and "error" in memories[0]:
        return f"Error: {memories[0]['error']}"
    formatted = []
    for i, memory in enumerate(memories, 1):
        if "error" in memory:
            formatted.append(f"**Error:** {memory['error']}")
            continue
        formatted.append(f"## Memory {i}")
        formatted.append(f"**Namespace:** {memory['namespace']}")
        formatted.append(f"**Key:** {memory['key']}")
        # Format the value nicely
        value = memory["value"]
        if isinstance(value, dict):
            try:
                value_str = json.dumps(value, indent=2)
                formatted.append(f"**Content:**\n```json\n{value_str}\n```")
            except (TypeError, ValueError):
                # Bug fix: the bare `except:` also swallowed KeyboardInterrupt
                # and SystemExit; json.dumps only raises TypeError/ValueError.
                formatted.append(f"**Content:** {str(value)}")
        else:
            formatted.append(f"**Content:** {str(value)}")
        formatted.append(f"**Created:** {memory['created_at']}")
        formatted.append(f"**Updated:** {memory['updated_at']}")
        formatted.append("---")
    return "\n".join(formatted)
def refresh_memories() -> str:
    """Refresh and return formatted memories.

    Returns the markdown string produced by format_memories_for_display
    (return annotation corrected from list[list[str]]).
    """
    memories = list_all_memories()
    return format_memories_for_display(memories)
def delete_memory_by_key(namespace_str: str, key: str) -> str:
    """
    Delete a memory by namespace and key.
    Args:
        namespace_str: String representation of namespace (e.g., "('memories', 'user1')")
        key: Memory key to delete
    Returns:
        Status message
    """
    try:
        import ast

        store, connection = get_thread_memory_store()
        try:
            # literal_eval safely turns the displayed namespace text back into a tuple.
            namespace = ast.literal_eval(namespace_str)
            store.delete(namespace, key)
            return f"Successfully deleted memory: {key} from namespace {namespace}"
        finally:
            connection.close()
    except Exception as e:
        return f"Failed to delete memory: {str(e)}"
# ---------- FastAPI ----------
app = FastAPI()
# ---------- Gradio UI ----------
# Two-column layout: memory list on the left, refresh/delete controls on the right.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🧠 Memory Store Viewer")
    gr.Markdown("View and manage long-term memories stored in the OpenChatBI system.")
    with gr.Row():
        with gr.Column(scale=3):
            # Populated at import time; refreshed by the button handlers below.
            memories_display = gr.Markdown(value=refresh_memories(), elem_id="memories-display")
        with gr.Column(scale=1):
            gr.Markdown("### Actions")
            refresh_btn = gr.Button("🔄 Refresh Memories", variant="primary")
            gr.Markdown("### Delete Memory")
            namespace_input = gr.Textbox(
                label="Namespace",
                placeholder="('memories', 'user_id')",
                info="Copy the exact namespace from the memory list",
            )
            key_input = gr.Textbox(
                label="Key", placeholder="memory_key", info="Copy the exact key from the memory list"
            )
            delete_btn = gr.Button("🗑️ Delete Memory", variant="stop")
            delete_status = gr.Textbox(label="Status", interactive=False)
    # Event handlers
    refresh_btn.click(fn=refresh_memories, outputs=[memories_display])
    # After a delete, chain a refresh so the list reflects the removal.
    delete_btn.click(fn=delete_memory_by_key, inputs=[namespace_input, key_input], outputs=[delete_status]).then(
        fn=refresh_memories, outputs=[memories_display]
    )
# ---------- Application Startup ----------
# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/memory")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8001)
================================================
FILE: sample_ui/plotly_utils.py
================================================
"""Plotly utilities for generating charts from visualization DSL."""
from io import StringIO
from typing import Any
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
def create_plotly_chart(data_csv: str, visualization_dsl: dict[str, Any]) -> go.Figure:
    """Create a plotly chart from CSV data and visualization DSL.

    Args:
        data_csv: CSV string containing the data
        visualization_dsl: Dictionary containing chart configuration

    Returns:
        Plotly Figure object
    """
    if not data_csv or not visualization_dsl:
        return create_empty_chart("No data available")
    if "error" in visualization_dsl:
        return create_empty_chart(f"Visualization error: {visualization_dsl['error']}")
    # Dispatch table mapping DSL chart types to their builder functions.
    builders = {
        "line": create_line_chart,
        "bar": create_bar_chart,
        "pie": create_pie_chart,
        "scatter": create_scatter_chart,
        "histogram": create_histogram_chart,
        "box": create_box_chart,
        "table": create_table_chart,
    }
    try:
        df = pd.read_csv(StringIO(data_csv))
        if df.empty:
            return create_empty_chart("No data to visualize")
        chart_type = visualization_dsl.get("chart_type", "table")
        builder = builders.get(chart_type)
        if builder is None:
            return create_empty_chart(f"Unsupported chart type: {chart_type}")
        return builder(df, visualization_dsl.get("config", {}), visualization_dsl.get("layout", {}))
    except Exception as e:
        return create_empty_chart(f"Chart generation error: {str(e)}")
def create_line_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:
    """Create a line chart."""
    x_col = config.get("x")
    y_col = config.get("y")
    color_col = config.get("color")
    if not x_col or x_col not in df.columns:
        return create_empty_chart("Missing required x column for line chart")
    if isinstance(y_col, list):
        # Multiple metrics: reshape to long format so each gets its own series.
        if any(col not in df.columns for col in y_col):
            return create_empty_chart("Some y columns missing from data")
        long_df = df.melt(id_vars=[x_col], value_vars=y_col, var_name="metric", value_name="value")
        fig = px.line(long_df, x=x_col, y="value", color="metric")
    elif not y_col or y_col not in df.columns:
        return create_empty_chart("Missing required y column for line chart")
    elif color_col and color_col in df.columns:
        # Optional grouping column splits the data into one line per value.
        fig = px.line(df, x=x_col, y=y_col, color=color_col)
    else:
        fig = px.line(df, x=x_col, y=y_col)
    fig.update_layout(**layout)
    return fig
def create_bar_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:
    """Create a bar chart.

    Supports either a single y column or a list of y columns (melted into one
    series per metric).  An optional "color" column groups bars, matching the
    behavior of create_line_chart.
    """
    x_col = config.get("x")
    y_col = config.get("y")
    color_col = config.get("color")
    if not x_col or x_col not in df.columns:
        return create_empty_chart("Missing required x column for bar chart")
    # Handle multiple y columns case
    if isinstance(y_col, list):
        if not all(col in df.columns for col in y_col):
            return create_empty_chart("Some y columns missing from data")
        # Melt the dataframe to long format for multiple series
        melted_df = df.melt(id_vars=[x_col], value_vars=y_col, var_name="metric", value_name="value")
        fig = px.bar(melted_df, x=x_col, y="value", color="metric")
    else:
        # Single y column
        if not y_col or y_col not in df.columns:
            return create_empty_chart("Missing required y column for bar chart")
        # Consistency fix: honor the optional "color" config like the line
        # chart does instead of silently ignoring it.
        if color_col and color_col in df.columns:
            fig = px.bar(df, x=x_col, y=y_col, color=color_col)
        else:
            fig = px.bar(df, x=x_col, y=y_col)
    fig.update_layout(**layout)
    return fig
def create_pie_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:
    """Create a pie chart."""
    labels_col = config.get("labels")
    values_col = config.get("values")
    has_labels = bool(labels_col) and labels_col in df.columns
    has_values = bool(values_col) and values_col in df.columns
    if not (has_labels and has_values):
        return create_empty_chart("Missing required columns for pie chart")
    fig = px.pie(df, names=labels_col, values=values_col)
    fig.update_layout(**layout)
    return fig
def create_scatter_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:
    """Create a scatter plot."""
    x_col = config.get("x")
    y_col = config.get("y")
    has_x = bool(x_col) and x_col in df.columns
    has_y = bool(y_col) and y_col in df.columns
    if not (has_x and has_y):
        return create_empty_chart("Missing required columns for scatter plot")
    fig = px.scatter(df, x=x_col, y=y_col)
    fig.update_layout(**layout)
    return fig
def create_histogram_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:
    """Create a histogram."""
    x_col = config.get("x")
    # Default bin count matches the original behavior when unspecified.
    bin_count = config.get("nbins", 20)
    if not x_col or x_col not in df.columns:
        return create_empty_chart("Missing required column for histogram")
    fig = px.histogram(df, x=x_col, nbins=bin_count)
    fig.update_layout(**layout)
    return fig
def create_box_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:
    """Create a box plot."""
    y_col = config.get("y")
    x_col = config.get("x")
    if not y_col or y_col not in df.columns:
        return create_empty_chart("Missing required column for box plot")
    # x is optional: when present and valid it groups the boxes.
    use_x = bool(x_col) and x_col in df.columns
    fig = px.box(df, x=x_col, y=y_col) if use_x else px.box(df, y=y_col)
    fig.update_layout(**layout)
    return fig
def create_table_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:
    """Create a table display.

    Shows at most the first 100 rows.  Columns requested in the config that
    do not exist in the data are dropped from both header and cells.
    """
    requested = config.get("columns", list(df.columns))
    # Bug fix: the header previously listed every requested column while the
    # cells silently skipped missing ones, misaligning header and data.
    columns = [col for col in requested if col in df.columns]
    # Limit to first 100 rows for display
    display_df = df.head(100)
    fig = go.Figure(
        data=[
            go.Table(
                header=dict(values=columns, fill_color="lightblue", align="left"),
                cells=dict(
                    values=[display_df[col] for col in columns],
                    fill_color="white",
                    align="left",
                ),
            )
        ]
    )
    fig.update_layout(**layout)
    return fig
def create_empty_chart(message: str) -> go.Figure:
    """Create an empty chart with a message."""
    fig = go.Figure()
    # Center the message in paper coordinates so it shows regardless of axes.
    fig.add_annotation(
        x=0.5,
        y=0.5,
        xref="paper",
        yref="paper",
        xanchor="center",
        yanchor="middle",
        text=message,
        showarrow=False,
        font=dict(size=16),
    )
    # Hide both axes entirely; only the annotation should be visible.
    hidden_axis = dict(showgrid=False, showticklabels=False, zeroline=False)
    fig.update_layout(
        title="Chart Generation Issue",
        xaxis=hidden_axis,
        yaxis=dict(hidden_axis),
    )
    return fig
def visualization_dsl_to_gradio_plot(data_csv: str, visualization_dsl: dict[str, Any]) -> tuple[go.Figure, str]:
    """Convert visualization DSL to Gradio-compatible plotly figure.

    Args:
        data_csv: CSV string containing the data
        visualization_dsl: Dictionary containing chart configuration

    Returns:
        Tuple of (plotly figure, description string)
    """
    fig = create_plotly_chart(data_csv, visualization_dsl)
    if not visualization_dsl:
        return fig, "Data table view"
    chart_type = visualization_dsl.get("chart_type", "unknown")
    fallback_title = f"{chart_type.title()} Chart"
    layout_title = visualization_dsl.get("layout", {}).get("title", fallback_title)
    return fig, f"Generated {chart_type} visualization: {layout_title}"
def create_inline_chart_markdown(data_csv: str, visualization_dsl: dict[str, Any]) -> str:
    """Create a simplified markdown representation of the chart for inline display.
    This creates a text-based summary with a clickable link to show the interactive chart.
    """
    if not data_csv or not visualization_dsl:
        return "📊 *No visualization data available*"
    if "error" in visualization_dsl:
        return f"⚠️ *Visualization error: {visualization_dsl['error']}*"
    try:
        from io import StringIO

        import pandas as pd

        df = pd.read_csv(StringIO(data_csv))
        chart_type = visualization_dsl.get("chart_type", "table")
        title = visualization_dsl.get("layout", {}).get("title", f"{chart_type.title()} Chart")
        # Header: title, chart metadata, and a pointer to the chart panel.
        lines = [
            f"📊 **{title}**",
            "",
            f"*Chart Type: {chart_type.title()}* | *Data Points: {len(df)} rows, {len(df.columns)} columns*",
            "",
            "✨ **Interactive chart will appear automatically in the chart panel →**",
            "",
        ]
        if len(df) > 0:
            # Show the first few rows so the text summary is self-contained.
            preview = df.head(3)
            lines += ["**Sample Data:**", "```", preview.to_string(index=False)]
            hidden_rows = len(df) - 3
            if hidden_rows > 0:
                lines.append(f"... and {hidden_rows} more rows")
            lines.append("```\n")
        return "\n".join(lines)
    except Exception as e:
        return f"⚠️ *Chart generation error: {str(e)}*"
================================================
FILE: sample_ui/simple_ui.py
================================================
"""Simple web UI for OpenChatBI using FastAPI and Gradio."""
from collections import defaultdict
import gradio as gr
import uvicorn
from fastapi import FastAPI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.types import Command
from openchatbi import config
from openchatbi.agent_graph import build_agent_graph_sync
from openchatbi.tool.memory import get_sync_memory_store
from openchatbi.utils import get_report_download_response, log
from sample_ui.style import custom_css
# Session state storage: session_id -> state
session_interrupt = defaultdict(bool)
# Use SqliteSaver for persistence
sqlite_checkpointer_cm = SqliteSaver.from_conn_string("checkpoints.db")
# NOTE(review): the context manager is entered manually and never exited;
# the SQLite connection lives for the process lifetime.
sqlite_checkpointer = sqlite_checkpointer_cm.__enter__()
graph = build_agent_graph_sync(
    config.get().catalog_store, checkpointer=sqlite_checkpointer, memory_store=get_sync_memory_store()
)
# ---------- FastAPI ----------
app = FastAPI()
# ---------- Gradio UI ----------
def chat_fn(message: str, history: list[tuple[str, str]], user_id: str = "default", session_id: str = "default") -> str:
    """Chat function for Gradio interface.

    Args:
        message: The user's new message (or the resume value when the previous
            turn ended in a human-in-the-loop interrupt).
        history: Gradio chat history (unused here; conversation state lives in
            the LangGraph checkpointer keyed by thread_id).
        user_id: Identifier used for memory isolation between users.
        session_id: Identifier used to separate conversations of one user.

    Returns:
        str: The assistant's reply, or the interrupt prompt text when the
        graph pauses for human input.
    """
    user_session_id = f"{user_id}-{session_id}"
    # Named run_config (not `config`) to avoid shadowing the imported
    # `openchatbi.config` module at function scope.
    run_config = {"configurable": {"thread_id": user_session_id, "user_id": user_id}}
    if session_interrupt[user_session_id]:
        # Previous turn paused at an interrupt: resume that run with the reply
        inputs = Command(resume=message)
    else:
        inputs = {"messages": [{"role": "user", "content": message}]}
    # Use synchronous call
    result = graph.invoke(inputs, config=run_config)
    state = graph.get_state(run_config)
    if state.interrupts:
        log(f"state.interrupts: {state.interrupts}")
        output_content = state.interrupts[0].value.get("text")
        session_interrupt[user_session_id] = True
    else:
        session_interrupt[user_session_id] = False
        output_content = result["messages"][-1].content
    return output_content
# Create Gradio interface with custom CSS and theme
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💬 OpenChatBI Agent Chatbot")
    with gr.Row():
        with gr.Column(scale=4):
            # Main chat transcript and message entry box
            chatbot = gr.Chatbot(elem_id="chatbot", label="", bubble_full_width=False, height=600)
            msg = gr.Textbox(placeholder="Type a message and press Enter", label="Input", elem_id="msg")
        with gr.Column(scale=1):
            # Per-request identifiers forwarded to chat_fn for state isolation
            user_box = gr.Textbox(value="default", label="User ID", interactive=True)
            session_box = gr.Textbox(value="default", label="Session ID", interactive=True)
            gr.Markdown(
                """
                **Instructions**
                - Type a message and press Enter to send
                - User ID is used for memory isolation between users
                - Session ID can be used to differentiate between conversations
                """,
                elem_id="description",
            )

    def respond(
        message: str, chat_history: list[tuple[str, str]], user_id: str, session_id: str
    ) -> tuple[str, list[tuple[str, str]]]:
        """Handle response in Gradio chat interface.

        Returns the cleared input box value and the updated chat history.
        """
        response = chat_fn(message, chat_history, user_id, session_id)
        chat_history.append((message, response))
        return "", chat_history

    # Enter in the textbox sends the message; outputs clear the box and refresh the chat
    msg.submit(respond, [msg, chatbot, user_box, session_box], [msg, chatbot])
# ---------- API Endpoints ----------
@app.get("/api/download/report/{filename}")
def download_report(filename: str):
    """Download a saved report file.

    Args:
        filename: Report file name captured from the URL path parameter.

    Returns:
        A FastAPI file/response object produced by get_report_download_response.
    """
    return get_report_download_response(filename)
# ---------- Application Startup ----------
# Mount Gradio app to FastAPI under /ui; API endpoints stay on the same server
app = gr.mount_gradio_app(app, demo, path="/ui")

if __name__ == "__main__":
    try:
        uvicorn.run(app, host="0.0.0.0", port=8000)
    finally:
        # Cleanup checkpointer: close the SqliteSaver opened at import time
        sqlite_checkpointer_cm.__exit__(None, None, None)
================================================
FILE: sample_ui/streaming_ui.py
================================================
"""Gradio-based Streaming UI for OpenChatBI with real-time chat interface."""
import asyncio
import sys
from collections import defaultdict
from contextlib import asynccontextmanager
import gradio as gr
try:
import pysqlite3 as sqlite3
except ImportError: # pragma: no cover
import sqlite3
from fastapi import FastAPI
from langchain_core.messages import AIMessage
sys.modules["sqlite3"] = sqlite3
from langgraph.types import Command
from openchatbi.utils import get_report_download_response, get_text_from_message_chunk, log
from sample_ui.async_graph_manager import AsyncGraphManager
from sample_ui.plotly_utils import create_inline_chart_markdown, visualization_dsl_to_gradio_plot
from sample_ui.style import custom_css
# Session state storage: user_session_id -> whether the last turn ended in an interrupt
session_interrupt = defaultdict(bool)
# Global event loop for async operations (similar to Streamlit approach)
global_event_loop = None
# Global graph manager (similar to Streamlit approach); initialized in lifespan
graph_manager = AsyncGraphManager()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the shared graph manager on startup and dispose it on shutdown."""
    # Startup
    await graph_manager.initialize()
    yield
    # Shutdown
    await graph_manager.cleanup()


# ---------- FastAPI ----------
app = FastAPI(lifespan=lifespan)
# ---------- Gradio UI functions ----------
def get_or_create_event_loop():
    """Return the module-wide event loop, creating a fresh one when missing or closed."""
    global global_event_loop
    loop = global_event_loop
    if loop is None or loop.is_closed():
        # Lazily build a new loop and install it for the current thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        global_event_loop = loop
    return loop
async def _async_respond_helper(message, chat_history, user_id, session_id):
    """
    Helper async function that contains the actual async logic.
    This will be run in an independent event loop.
    Collects all responses and returns them as a list.

    Args:
        message: User input text (or resume value after an interrupt).
        chat_history: Gradio chat history; the last entry is mutated in place.
        user_id: Identifier used for memory isolation.
        session_id: Identifier used to separate conversations.

    Returns:
        list: Tuples of (input_box_value, chat_history, plot_figure,
        chart_panel_update) to be yielded by the caller.
    """
    responses = []  # Collect all yield values
    user_session_id = f"{user_id}-{session_id}"
    full_response = ""
    plot_figure = None
    chart_panel_update = gr.update()
    if session_interrupt[user_session_id]:
        # Previous turn paused at an interrupt: resume that run
        stream_input = Command(resume=message)
    else:
        stream_input = {"messages": [{"role": "user", "content": message}]}
    config = {"configurable": {"thread_id": user_session_id, "user_id": user_id}}
    # Ensure graph is available
    if not graph_manager._initialized:
        try:
            await graph_manager.initialize()
        except Exception as e:
            log(f"Failed to initialize graph: {e}")
            chat_history[-1] = (chat_history[-1][0], f"Error: Failed to initialize system - {str(e)}")
            responses.append(("", chat_history, plot_figure, chart_panel_update))
            return responses
    data_csv = None
    # Asynchronously iterate through LangGraph stream
    async for _namespace, event_type, event_value in graph_manager.graph.astream(
        stream_input, config=config, stream_mode=["updates", "messages"], subgraphs=True, debug=True
    ):
        token = ""
        if event_type == "messages":
            chunk = event_value[0]
            metadata = event_value[1]
            # Keep llm node messages only to avoid duplicates
            if metadata["langgraph_node"] != "llm_node" or not metadata.get("streaming_tokens", False):
                continue
            token = get_text_from_message_chunk(chunk)
        else:
            # Process intermediate graph node updates; each branch renders a
            # progress token for the chat transcript.
            if event_value.get("llm_node"):
                message_obj = event_value["llm_node"].get("messages")[0]
                if message_obj and isinstance(message_obj, AIMessage) and message_obj.tool_calls:
                    token = f"\nUse tool: {', '.join(tool['name'] for tool in message_obj.tool_calls)}\n"
                else:
                    token = "\n"
            elif event_value.get("information_extraction"):
                message_obj = event_value["information_extraction"].get("messages")[0]
                if message_obj.tool_calls:
                    token = f"Use tool: {message_obj.tool_calls[0]['name']}\n"
                else:
                    token = f"Rewrite question: {event_value['information_extraction'].get('rewrite_question')}\n"
            elif event_value.get("table_selection"):
                token = f"Selected tables: {event_value['table_selection'].get('tables')}\n"
            elif event_value.get("generate_sql"):
                token = f"SQL: \n ```sql \n{event_value['generate_sql'].get('sql')}\n```\n"
            elif event_value.get("execute_sql"):
                token = "Running SQL...\n"
                # Remember the result CSV so a later visualization event can plot it
                data_csv = event_value["execute_sql"].get("data")
            elif event_value.get("regenerate_sql"):
                token = f"SQL: \n ```sql \n{event_value['regenerate_sql'].get('sql')}\n```\n"
            elif event_value.get("generate_visualization"):
                visualization_dsl = event_value["generate_visualization"].get("visualization_dsl")
                # Check for visualization data in the final state and embed in response
                if visualization_dsl and "error" not in visualization_dsl and data_csv:
                    try:
                        plot_figure, plot_description = visualization_dsl_to_gradio_plot(data_csv, visualization_dsl)
                        # Add markdown representation to the chat
                        chart_markdown = create_inline_chart_markdown(data_csv, visualization_dsl)
                        full_response += f"\n\n{chart_markdown}"
                        chat_history[-1] = (chat_history[-1][0], full_response)
                        # Auto-show chart panel when plot is generated
                        chart_panel_update = gr.update(visible=True)
                        responses.append(("", chat_history, plot_figure, chart_panel_update))
                    except Exception as e:
                        log(f"Visualization generation error: {str(e)}")
                        full_response += f"\n\n⚠️ Visualization error: {str(e)}"
                        chat_history[-1] = (chat_history[-1][0], full_response)
                        responses.append(("", chat_history, plot_figure, chart_panel_update))
        # Update chat history with new tokens and collect response
        if token:
            full_response += token
            chat_history[-1] = (chat_history[-1][0], full_response)
            responses.append(("", chat_history, plot_figure, chart_panel_update))
    # Get final state and check for visualization data
    state = await graph_manager.graph.aget_state(config)
    final_state_values = state.values
    if state.interrupts:
        log(f"state.interrupts: {state.interrupts}")
        output_content = state.interrupts[0].value.get("text")
        if "buttons" in state.interrupts[0].value:
            output_content += str(state.interrupts[0].value.get("buttons"))
        full_response += output_content
        chat_history[-1] = (chat_history[-1][0], full_response)
        # Mark the session so the next user message resumes the paused run
        session_interrupt[user_session_id] = True
        responses.append(("", chat_history, plot_figure, chart_panel_update))
    else:
        session_interrupt[user_session_id] = False
    return responses
def respond(message, chat_history, user_id, session_id="default"):
    """
    Synchronous callback for Gradio Chatbot with streaming updates.
    This function processes user input and streams responses from the LangGraph agent.
    Yields: (message_input, chat_history, plot_figure, chart_panel_visibility)
    tuples; the async work runs on a dedicated, reusable event loop.
    """
    # Add a placeholder in chat history
    chat_history.append((message, ""))
    plot_figure = None
    chart_panel_update = gr.update()
    yield "", chat_history, plot_figure, chart_panel_update  # Stream updates to UI
    # Get or create independent event loop
    loop = get_or_create_event_loop()
    # Run the async helper in the independent event loop
    try:
        responses = loop.run_until_complete(_async_respond_helper(message, chat_history, user_id, session_id))
        # Yield all collected responses
        for response in responses:
            yield response
    except Exception as e:
        log(f"Error in respond: {e}")
        import traceback

        traceback.print_exc()
        # Surface the failure in the chat transcript instead of crashing the UI
        chat_history[-1] = (chat_history[-1][0], f"Error: {str(e)}")
        yield "", chat_history, plot_figure, chart_panel_update
# ---------- Memory Management Functions ----------
def list_user_memories(user_id: str) -> str:
    """List all memories for a specific user.

    Opens a fresh SQLite connection to memory.db so the callback can run on
    whatever thread Gradio dispatches it to, reads the LangGraph store under
    the ("memories", <user_id>) namespace, and renders the items as markdown.

    Args:
        user_id: The user whose memory namespace is listed.

    Returns:
        str: Markdown-formatted memory listing, or a human-readable
        error/empty message on failure.
    """
    try:
        import json

        try:
            import pysqlite3 as sqlite3
        except ImportError:
            import sqlite3

        from langgraph.store.sqlite import SqliteStore

        from openchatbi import config

        # Create a new connection in this thread to avoid SQLite threading issues
        conn = sqlite3.connect("memory.db", check_same_thread=False)
        conn.isolation_level = None  # Use autocommit mode to avoid transaction conflicts
        thread_memory_store = SqliteStore(
            conn, index={"dims": 1536, "embed": config.get().embedding_model, "fields": ["text"]}
        )
        try:
            thread_memory_store.setup()
        except Exception:
            pass  # Store might already be set up
        memories = []
        namespace = ("memories", user_id)
        try:
            # Use search with namespace to find all items for this user
            items = thread_memory_store.search(namespace, limit=1000)
            for item in items:
                memories.append(
                    {
                        "key": item.key,
                        "value": item.value,
                        "created_at": getattr(item, "created_at", "Unknown"),
                        "updated_at": getattr(item, "updated_at", "Unknown"),
                    }
                )
        except Exception as e:
            return f"No memories found for user {user_id} or error: {str(e)}"
        finally:
            conn.close()
        if not memories:
            return f"No memories found for user {user_id}"
        formatted = [f"## Memories for User: {user_id}\n"]
        for i, memory in enumerate(memories, 1):
            formatted.append(f"### Memory {i}")
            formatted.append(f"**Key:** {memory['key']}")
            value = memory["value"]
            if isinstance(value, dict):
                try:
                    value_str = json.dumps(value, indent=2)
                    formatted.append(f"**Content:**\n```json\n{value_str}\n```")
                except (TypeError, ValueError):
                    # json.dumps raises TypeError for non-serializable members
                    # and ValueError for circular refs - fall back to str() for both
                    formatted.append(f"**Content:** {str(value)}")
            else:
                formatted.append(f"**Content:** {str(value)}")
            formatted.append(f"**Created:** {memory['created_at']}")
            formatted.append(f"**Updated:** {memory['updated_at']}")
            formatted.append("---")
        return "\n".join(formatted)
    except Exception as e:
        return f"Error accessing memories: {str(e)}"
# ---------- Gradio UI Blocks ----------
# Create Gradio interface with custom CSS and theme
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💬 OpenChatBI Agent Chatbot with Streaming & On-Demand Visualization")
    with gr.Tabs():
        with gr.TabItem("💬 Chat"):
            with gr.Row():
                with gr.Column(scale=4):
                    # Main transcript; raw HTML/markdown allowed for inline charts
                    chatbot = gr.Chatbot(
                        elem_id="chatbot",
                        label="Chat",
                        bubble_full_width=False,
                        height=500,
                        show_label=False,
                        sanitize_html=False,
                        render_markdown=True,
                    )
                    msg = gr.Textbox(placeholder="Type a message and press Enter", label="Input", elem_id="msg")
                # Hidden chart panel; made visible when a plot is produced
                with gr.Column(scale=2, visible=False) as chart_panel:
                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.Markdown("### 📊 Interactive Chart")
                        with gr.Column(scale=1):
                            hide_chart_btn = gr.Button("✖️ Hide", elem_id="hide-chart-btn", size="sm")
                    plot = gr.Plot(label="", visible=True, show_label=False)
                with gr.Column(scale=1):
                    user_box = gr.Textbox(value="default", label="User ID", interactive=True)
                    session_box = gr.Textbox(value="default", label="Session ID", interactive=True)
                    show_chart_btn = gr.Button("📊 Show Chart Panel", variant="secondary")
                    gr.Markdown(
                        """
                        **Instructions**
                        - Type a data question and press Enter
                        - Supports streaming output (real-time display)
                        - Click chart links in chat to view interactive charts
                        - Use 'Show Chart Panel' to make panel visible
                        - Session ID can be used to differentiate between conversations
                        """,
                        elem_id="description",
                    )

            def show_chart_panel():
                """Show the chart panel."""
                return gr.update(visible=True)

            def hide_chart_panel():
                """Hide the chart panel."""
                return gr.update(visible=False)

            # Register async submit handler for message input with plot output
            msg.submit(respond, [msg, chatbot, user_box, session_box], [msg, chatbot, plot, chart_panel])
            show_chart_btn.click(show_chart_panel, outputs=[chart_panel])
            hide_chart_btn.click(hide_chart_panel, outputs=[chart_panel])
        with gr.TabItem("🧠 Memory Store"):
            gr.Markdown("### Long-term Memory Viewer")
            gr.Markdown("View memories stored for each user in the system.")
            with gr.Row():
                with gr.Column(scale=3):
                    memory_display = gr.Markdown(
                        value="Enter a User ID and click 'Load Memories' to view stored memories.",
                        elem_id="memory-display",
                    )
                with gr.Column(scale=1):
                    memory_user_input = gr.Textbox(label="User ID", placeholder="default", value="default")
                    load_memories_btn = gr.Button("🔍 Load Memories", variant="primary")
            # Event handler for loading memories
            load_memories_btn.click(fn=list_user_memories, inputs=[memory_user_input], outputs=[memory_display])
# ---------- API Endpoints ----------
@app.get("/api/download/report/{filename}")
async def download_report(filename: str):
    """Download a saved report file.

    Args:
        filename: Report file name captured from the URL path parameter.

    Returns:
        A FastAPI file/response object produced by get_report_download_response.
    """
    return get_report_download_response(filename)
# ---------- Application Startup ----------
# Mount Gradio app to FastAPI under /ui; API endpoints stay on the same server
app = gr.mount_gradio_app(app, demo, path="/ui")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
================================================
FILE: sample_ui/streamlit_ui.py
================================================
"""Streamlit-based Streaming UI for OpenChatBI with collapsible thinking sections."""
import asyncio
import sys
import traceback
import uuid
from pathlib import Path
import plotly.graph_objects as go
try:
import pysqlite3 as sqlite3
except ImportError: # pragma: no cover
import sqlite3
import streamlit as st
from langchain_core.messages import AIMessage
sys.modules["sqlite3"] = sqlite3
from langgraph.types import Command
from openchatbi import config as openchatbi_config
from openchatbi.llm.llm import list_llm_providers
from openchatbi.utils import get_text_from_message_chunk, log
from sample_ui.plotly_utils import visualization_dsl_to_gradio_plot
from sample_ui.async_graph_manager import AsyncGraphManager
# Configuration
st.set_page_config(page_title="OpenChatBI - Streamlit Interface", page_icon="💬", layout="wide")

# Initialize session state; Streamlit reruns the script on every interaction,
# so each key is created only once per browser session.
if "messages" not in st.session_state:
    st.session_state.messages = []
if "graph_manager" not in st.session_state:
    st.session_state.graph_manager = AsyncGraphManager()
if "session_interrupts" not in st.session_state:
    # user_session_id -> whether the last turn ended in an interrupt
    st.session_state.session_interrupts = {}
if "event_loop" not in st.session_state:
    # Dedicated asyncio loop reused across reruns for the async graph calls
    st.session_state.event_loop = None
def _extract_final_answer(final_response: str) -> str:
    """Return the trailing answer text after the last tool-usage marker line.

    Scans the response bottom-up, collecting lines until one containing a
    tool-usage marker is found. Returns "" when nothing was collected (i.e.
    the last line is a tool marker or the input is empty).
    """
    collected = []
    for line in reversed(final_response.split("\n")):
        if "Use tool:" in line or "Using tools:" in line or "Using tool:" in line:
            break
        collected.append(line)
    if not collected:
        return ""
    # Lines were gathered in reverse; restore original order
    collected.reverse()
    return "\n".join(collected).strip()


async def process_user_message_stream(
    message: str, user_id: str, session_id: str, llm_provider: str | None, thinking_container, response_container
):
    """
    Process user message through the OpenChatBI graph with real-time updates
    Updates the thinking_container and response_container as processing happens

    Args:
        message: User input text (or resume value after an interrupt).
        user_id: Identifier used for memory isolation.
        session_id: Identifier used to separate conversations.
        llm_provider: Optional provider name for routing this session's LLM calls.
        thinking_container: Streamlit container receiving the live thinking log.
        response_container: Streamlit container receiving the final answer/plot.

    Returns:
        tuple: (final_response, plot_figure, thinking_steps,
        chronological_content, final_answer_text)
    """
    thinking_steps = []
    final_response = ""
    plot_figure = None
    # Initialize graph if needed
    if not st.session_state.graph_manager._initialized:
        await st.session_state.graph_manager.initialize()
    graph = await st.session_state.graph_manager.get_graph(llm_provider)
    user_session_id = f"{user_id}-{session_id}"
    # Check for interrupts: resume a paused run instead of starting a new one
    if st.session_state.session_interrupts.get(user_session_id, False):
        stream_input = Command(resume=message)
    else:
        stream_input = {"messages": [{"role": "user", "content": message}]}
    config = {"configurable": {"thread_id": user_session_id, "user_id": user_id}}
    data_csv = None
    # Use empty container for real-time updates
    thinking_placeholder = thinking_container.empty()
    # Build content chronologically - all events in time order
    base_content = "🔄 **Processing...**\n\n"
    chronological_content = ""  # All content in time order

    def update_display():
        # Re-render the whole thinking log; the closure reads the latest content
        full_content = base_content + chronological_content
        thinking_placeholder.markdown(full_content)

    # Initial display
    update_display()
    # Stream through the graph
    async for _namespace, event_type, event_value in graph.astream(
        stream_input, config=config, stream_mode=["updates", "messages"], subgraphs=True, debug=True
    ):
        if event_type == "messages":
            chunk = event_value[0]
            metadata = event_value[1]
            # Keep llm node messages only to avoid duplicates
            if metadata["langgraph_node"] != "llm_node" or not metadata.get("streaming_tokens", False):
                continue
            token = get_text_from_message_chunk(chunk)
            if token:
                final_response += token
                # First token of the answer: add the response header to the log
                if len(final_response) == len(token):
                    chronological_content += "\n**🤖 AI Response:** "
                chronological_content += token
                update_display()
        else:
            # Process tool calls and intermediate steps
            step_description = ""
            if event_value.get("llm_node"):
                message_obj = event_value["llm_node"].get("messages")[0]
                if message_obj and isinstance(message_obj, AIMessage) and message_obj.tool_calls:
                    step_description = f"🛠️ Using tools: {', '.join(tool['name'] for tool in message_obj.tool_calls)}"
            elif event_value.get("information_extraction"):
                message_obj = event_value["information_extraction"].get("messages")[0]
                if message_obj and message_obj.tool_calls:
                    step_description = f"🛠️ Using tool: {message_obj.tool_calls[0]['name']}"
                else:
                    rewrite_q = event_value["information_extraction"].get("rewrite_question")
                    if rewrite_q:
                        step_description = f"📝 Rewriting question: {rewrite_q}"
            elif event_value.get("table_selection"):
                tables = event_value["table_selection"].get("tables")
                if tables:
                    step_description = f"🗂️ Selected tables: {tables}"
            elif event_value.get("generate_sql"):
                sql = event_value["generate_sql"].get("sql")
                if sql:
                    step_description = f"💾 Generated SQL:\n```sql\n{sql}\n```"
            elif event_value.get("execute_sql"):
                step_description = "⚡ Executing SQL query..."
                # Remember the result CSV so a later visualization event can plot it
                data_csv = event_value["execute_sql"].get("data")
            elif event_value.get("regenerate_sql"):
                sql = event_value["regenerate_sql"].get("sql")
                if sql:
                    step_description = f"🔄 Regenerated SQL:\n```sql\n{sql}\n```"
            elif event_value.get("generate_visualization"):
                visualization_dsl = event_value["generate_visualization"].get("visualization_dsl")
                if visualization_dsl and "error" not in visualization_dsl and data_csv:
                    try:
                        plot_figure, plot_description = visualization_dsl_to_gradio_plot(data_csv, visualization_dsl)
                        step_description = f"📊 Generated visualization: {plot_description}"
                    except Exception as e:
                        step_description = f"⚠️ Visualization error: {str(e)}"
            if step_description:
                thinking_steps.append(step_description)
                # Append new step to chronological content in time order
                step_number = len(thinking_steps)
                # Ensure proper spacing before new step
                if chronological_content and not chronological_content.endswith("\n\n"):
                    chronological_content += "\n\n"
                chronological_content += f"**Step {step_number}:** {step_description}\n\n"
                update_display()
    # Check for interrupts in final state
    state = await graph.aget_state(config)
    if state.interrupts:
        log(f"State interrupts: {state.interrupts}")
        output_content = state.interrupts[0].value.get("text", "")
        if "buttons" in state.interrupts[0].value:
            output_content += str(state.interrupts[0].value.get("buttons"))
        final_response += output_content
        # Append interrupt content to chronological content
        chronological_content += output_content
        update_display()
        st.session_state.session_interrupts[user_session_id] = True
    else:
        st.session_state.session_interrupts[user_session_id] = False
    # Final update - add completion message to chronological content
    if not chronological_content.endswith("\n\n"):
        chronological_content += "\n\n"
    chronological_content += "✅ **Analysis complete!**"
    update_display()
    # Extract the final answer once (this scan was previously duplicated inline)
    final_answer_text = _extract_final_answer(final_response) if final_response else ""
    # Display the final answer outside the thinking section
    if final_answer_text:
        with response_container:
            processed_final_answer_text = process_download_links(final_answer_text)
            render_content_with_downloads(processed_final_answer_text)
    # Final update to response container - only show plot if available (text response is in thinking container)
    with response_container:
        if plot_figure:
            st.plotly_chart(plot_figure, use_container_width=True, key=str(uuid.uuid4()))
    return final_response, plot_figure, thinking_steps, chronological_content, final_answer_text
def get_available_reports() -> list[str]:
    """Return the sorted names of all regular files in the configured report directory."""
    try:
        # Import config here to avoid circular imports
        from openchatbi import config

        report_dir = Path(config.get().report_directory)
        if not report_dir.exists():
            return []
        # Collect every regular file name, alphabetically sorted
        return sorted(entry.name for entry in report_dir.iterdir() if entry.is_file())
    except Exception as e:
        st.error(f"Error accessing reports: {str(e)}")
        return []
def get_report_file_content(filename: str) -> tuple[bytes | None, str | None]:
    """Get report file content for download.

    Args:
        filename: Name of the file inside the configured report directory.

    Returns:
        tuple: (file_content_bytes, mime_type) or (None, None) if error
    """
    try:
        # Import config here to avoid circular imports
        from openchatbi import config

        report_dir = Path(config.get().report_directory)
        file_path = report_dir / filename
        # Security check FIRST - reject paths that resolve outside the report
        # directory before any existence check, so traversal attempts cannot
        # even probe which files exist on disk.
        try:
            file_path.resolve().relative_to(report_dir.resolve())
        except ValueError:
            st.error("Access denied to file")
            return None, None
        if not file_path.exists() or not file_path.is_file():
            st.error(f"Report file not found: {filename}")
            return None, None
        # Determine MIME type from the file extension
        mime_type_map = {
            ".md": "text/markdown",
            ".csv": "text/csv",
            ".txt": "text/plain",
            ".json": "application/json",
            ".html": "text/html",
            ".xml": "application/xml",
        }
        file_extension = file_path.suffix.lower()
        mime_type = mime_type_map.get(file_extension, "application/octet-stream")
        # Read file content as raw bytes
        with open(file_path, "rb") as f:
            content = f.read()
        return content, mime_type
    except Exception as e:
        st.error(f"Error reading report file: {str(e)}")
        return None, None
def process_download_links(content: str) -> str:
    """Process download links in content and replace them with Streamlit-compatible ones.

    Args:
        content: Message content that may contain download links

    Returns:
        str: Content with each download link replaced by a
        ``[DOWNLOAD_LINK:<filename>]`` placeholder that
        render_content_with_downloads later turns into a download button.
    """
    import re

    if not content:
        return content
    # Pattern to match both full URLs and path-only download links
    # Matches: http://localhost:8501/api/download/report/filename.ext or /api/download/report/filename.ext
    download_pattern = r"(?:https?://[^/\s]+)?/api/download/report/([^)\s\]<>]+)"

    def replace_link(match):
        # Emit a placeholder carrying just the captured filename
        return f"[DOWNLOAD_LINK:{match.group(1)}]"

    return re.sub(download_pattern, replace_link, content)
def render_content_with_downloads(content: str) -> None:
    """Render content and replace download placeholders with actual download buttons.

    Args:
        content: Markdown text possibly containing ``[DOWNLOAD_LINK:<filename>]``
            placeholders produced by process_download_links.
    """
    import re

    # Split content by download placeholders; with one capture group,
    # re.split alternates text (even indices) and filenames (odd indices)
    download_pattern = r"\[DOWNLOAD_LINK:([^)]+)\]"
    parts = re.split(download_pattern, content)
    for i, part in enumerate(parts):
        if i % 2 == 0:
            # Regular markdown content
            if part.strip():
                st.markdown(part)
        else:
            # Download link filename
            filename = part
            file_content, mime_type = get_report_file_content(filename)
            if file_content is not None:
                st.download_button(
                    label=f"📥 Download {filename}",
                    data=file_content,
                    file_name=filename,
                    mime=mime_type,
                    # Include a content hash so repeated filenames get unique widget keys
                    key=f"inline_download_{filename}_{hash(content)}",
                )
            else:
                st.error(f"❌ Could not load report: {filename}")
def display_message_with_thinking(
    role: str, content: str, thinking_steps: list[str] = None, plot_figure: go.Figure = None
):
    """Display a message with collapsible thinking section

    Args:
        role: Chat role, e.g. "user" or "assistant".
        content: Message text to render.
        thinking_steps: Optional per-step descriptions (assistant messages only).
        plot_figure: Optional Plotly figure rendered below the message.
    """
    with st.chat_message(role):
        if thinking_steps and role == "assistant":
            # Create thinking section with all content inside
            with st.expander("💭 AI Thinking Process", expanded=False):
                for i, step in enumerate(thinking_steps, 1):
                    st.markdown(f"**Step {i}:** {step}")
                if content:
                    st.markdown("**🤖 AI Response:**")
                    render_content_with_downloads(content)
                st.success("✅ Analysis complete")
        # For non-assistant messages, display content normally
        elif content and role != "assistant":
            render_content_with_downloads(content)
        # Display plot if available (outside thinking container)
        if plot_figure:
            # Random key keeps Streamlit widget ids unique across reruns
            st.plotly_chart(plot_figure, use_container_width=True, key=str(uuid.uuid4()))
# Main UI
st.title("💬 OpenChatBI - Streamlit UI")
st.markdown("*AI-powered Business Intelligence Chat with Thinking*")

# Sidebar for configuration
with st.sidebar:
    st.header("⚙️ Configuration")
    user_id = st.text_input("User ID", value="default", help="Unique identifier for the user session")
    session_id = st.text_input("Session ID", value="default", help="Session identifier for conversation continuity")
    # Optional multi-provider support
    llm_provider = None
    provider_options = list_llm_providers()
    if provider_options:
        try:
            # Pre-select the provider from global config when configured
            default_provider = getattr(openchatbi_config.get(), "llm_provider", None)
        except Exception:
            default_provider = None
        default_index = provider_options.index(default_provider) if default_provider in provider_options else 0
        llm_provider = st.selectbox(
            "LLM Provider",
            options=provider_options,
            index=default_index,
            help="Select which configured LLM provider to use for this session",
        )
    st.markdown("---")
    st.markdown(
        """
    **💡 How to use:**
    - Type your business questions
    - Watch the AI thinking process in collapsible sections
    - View generated charts and analyses
    - Use different session IDs for separate conversations
    """
    )
    if st.button("🗑️ Clear Chat History"):
        st.session_state.messages = []
        st.rerun()
    st.markdown("---")
    st.markdown("### 📁 Report Downloads")
    # Get available reports
    available_reports = get_available_reports()
    if available_reports:
        selected_report = st.selectbox(
            "Select a report to download:", options=[""] + available_reports, help="Choose a report file to download"
        )
        if selected_report and st.button("📥 Download Report"):
            file_content, mime_type = get_report_file_content(selected_report)
            if file_content is not None:
                st.download_button(
                    label=f"💾 Save {selected_report}",
                    data=file_content,
                    file_name=selected_report,
                    mime=mime_type,
                    key=f"download_{selected_report}",
                )
                st.success(f"✅ {selected_report} is ready for download!")
    else:
        st.info("No reports available for download.")
# Display chat history (Streamlit reruns the script, so the whole transcript
# is re-rendered from st.session_state.messages on every interaction)
for msg in st.session_state.messages:
    if msg["type"] == "chronological_message":
        # Display chronological content in expander - all collapsed after completion
        with st.chat_message(msg["role"]):
            with st.expander("💭 AI Thinking Process", expanded=False):
                st.markdown(msg["chronological_content"])
            # Extract and display final answer text outside thinking
            if msg.get("final_answer"):
                render_content_with_downloads(msg["final_answer"])
            # Display plot if available (outside thinking container)
            if msg.get("plot_figure"):
                st.plotly_chart(msg["plot_figure"], use_container_width=True, key=str(uuid.uuid4()))
    elif msg["type"] == "thinking_message":
        display_message_with_thinking(
            msg["role"], msg["content"], msg.get("thinking_steps", []), msg.get("plot_figure")
        )
    else:
        # Plain text or standalone plot messages
        with st.chat_message(msg["role"]):
            if msg["type"] == "text":
                render_content_with_downloads(msg["content"])
            elif msg["type"] == "plot" and msg.get("plot_figure"):
                st.plotly_chart(msg["plot_figure"], use_container_width=True, key=str(uuid.uuid4()))
# Chat input
if prompt := st.chat_input("Ask me anything about your data..."):
    # Add user message
    st.session_state.messages.append({"role": "user", "type": "text", "content": prompt})
    # Display user message immediately
    with st.chat_message("user"):
        st.markdown(prompt)
    # Process assistant response with real-time streaming
    with st.chat_message("assistant"):
        # Create thinking and response containers
        thinking_expander = st.expander("💭 AI Thinking Process...", expanded=True)
        thinking_container = thinking_expander.container()
        response_container = st.container()
        # Process the message asynchronously with real-time updates
        try:
            # Reuse the same event loop to avoid binding issues
            if st.session_state.event_loop is None or st.session_state.event_loop.is_closed():
                st.session_state.event_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(st.session_state.event_loop)
            loop = st.session_state.event_loop
            final_response, plot_figure, thinking_steps, full_chronological_content, final_answer = (
                loop.run_until_complete(
                    process_user_message_stream(
                        prompt, user_id, session_id, llm_provider, thinking_container, response_container
                    )
                )
            )
            # No need to create another expander - content is already shown in real-time
            # Process download links in the content before storing
            processed_chronological_content = process_download_links(full_chronological_content)
            processed_final_answer = process_download_links(final_answer) if final_answer else final_answer
            # Store the complete message with the processed content
            st.session_state.messages.append(
                {
                    "role": "assistant",
                    "type": "chronological_message",
                    "chronological_content": processed_chronological_content,
                    "final_answer": processed_final_answer,
                    "plot_figure": plot_figure,
                }
            )
            # Trigger rerun to collapse the thinking section
            st.rerun()
        except Exception as e:
            traceback.print_exc()
            st.error(f"❌ Error processing request: {str(e)}")
            # Persist the error in history so it survives the next rerun
            error_content = f"❌ Error: {str(e)}"
            processed_error_content = process_download_links(error_content)
            st.session_state.messages.append({"role": "assistant", "type": "text", "content": processed_error_content})
# Cleanup on session end
def cleanup_session():
    """Cleanup resources when session ends.

    Shuts down the shared graph manager on the session's event loop, then
    closes the loop. Errors are logged, never raised, since this runs at exit.
    """
    if "graph_manager" in st.session_state:
        try:
            # Use the same event loop for cleanup
            if st.session_state.event_loop and not st.session_state.event_loop.is_closed():
                loop = st.session_state.event_loop
                loop.run_until_complete(st.session_state.graph_manager.cleanup())
                loop.close()
                st.session_state.event_loop = None
        except Exception as e:
            log(f"Error during session cleanup: {e}")


# Register cleanup (this is a simplified approach - in production you might want more robust cleanup)
import atexit

atexit.register(cleanup_session)
================================================
FILE: sample_ui/style.py
================================================
# Custom CSS for styling the chat interface
custom_css = """
#chatbot {
height: 600px !important;
font-family: "Inter", "Helvetica Neue", sans-serif;
}
#chatbot .wrap.svelte-1cl84sx {
background: #f5f7fa;
border-radius: 12px;
padding: 8px;
}
.message.user {
background-color: #d1e9ff !important;
border-radius: 12px 12px 0 12px;
margin: 4px 0;
padding: 10px 14px;
font-size: 15px;
}
.message.bot {
background-color: #ffffff !important;
border-radius: 12px 12px 12px 0;
margin: 4px 0;
padding: 10px 14px;
font-size: 15px;
box-shadow: 0px 1px 3px rgba(0,0,0,0.08);
}
#msg {
font-family: "Inter", "Helvetica Neue", sans-serif;
font-size: 15px;
}
#description {
font-family: "Inter", "Helvetica Neue", sans-serif;
font-size: 14px;
color: #374151;
line-height: 1.6;
}
"""
================================================
FILE: tests/README.md
================================================
# OpenChatBI Test Suite
This directory contains comprehensive unit tests for the OpenChatBI project. The test suite is built using pytest and follows modern Python testing best practices.
## Test Structure
```
tests/
├── __init__.py # Test package initialization
├── conftest.py # Shared fixtures and configuration
├── README.md # This file
│
├── Core Module Tests
├── test_config_loader.py # Configuration loading tests
├── test_graph_state.py # State management tests
├── test_utils.py # Utility function tests
│
├── Catalog System Tests
├── test_catalog_store.py # Catalog store interface tests
├── test_catalog_loader.py # Database catalog loading tests
│
├── Text2SQL Pipeline Tests
├── test_text2sql_extraction.py # Information extraction tests
├── test_text2sql_generate_sql.py # SQL generation tests
├── test_text2sql_schema_linking.py # Schema linking tests
├── test_text2sql_visualization.py # Data visualization tests
│
├── Tool Tests
├── test_tools_ask_human.py # Human interaction tool tests
├── test_tools_run_python_code.py # Python code execution tests
├── test_tools_search_knowledge.py # Knowledge search tests
│
├── Additional Module Tests
├── test_memory.py # Memory management tests
├── test_plotly_utils.py # Plotly utilities tests
├── test_incomplete_tool_calls.py # Incomplete tool call handling tests
│
└── Context Management Tests
└── context_management/ # Context management test suite (see context_management/README.md)
```
## Running Tests
### Prerequisites
Ensure you have the development dependencies installed:
```bash
# Using uv (recommended)
uv sync --group dev
# Or using pip
pip install -e ".[dev]"
```
### Basic Test Execution
```bash
# Run all tests
uv run pytest
# Run tests with verbose output
uv run pytest -v
# Run specific test file
uv run pytest tests/test_config_loader.py
# Run specific test class
uv run pytest tests/test_config_loader.py::TestConfigLoader
# Run specific test method
uv run pytest tests/test_config_loader.py::TestConfigLoader::test_load_config_from_file
```
### Test Coverage
```bash
# Run tests with coverage report
uv run pytest --cov=openchatbi --cov-report=html --cov-report=term-missing
# View HTML coverage report
open htmlcov/index.html
```
### Test Categories
```bash
# Run only fast unit tests (exclude slow integration tests)
uv run pytest -m "not slow"
# Run tests for specific components
uv run pytest tests/test_catalog* -k "catalog"
uv run pytest tests/test_text2sql* -k "text2sql"
uv run pytest tests/test_tools* -k "tools"
# Run context management tests
uv run pytest tests/context_management/
# Run memory and utility tests
uv run pytest tests/test_memory.py tests/test_plotly_utils.py
# Run incomplete tool call tests
uv run pytest tests/test_incomplete_tool_calls.py
```
## Test Configuration
### Environment Variables
The test suite uses several environment variables that can be set to customize test behavior:
- `OPENCHATBI_TEST_MODE=true` - Enables test mode
- `OPENCHATBI_CONFIG_PATH` - Path to test configuration file
- `PYTEST_TIMEOUT=300` - Test timeout in seconds
### Fixtures
The `conftest.py` file provides shared fixtures used across tests:
#### Core Fixtures
- `test_config` - Test configuration dictionary
- `temp_dir` - Temporary directory for test files
- `mock_llm` - Mocked language model for testing
- `sample_agent_state` - Sample AgentState for testing
#### Catalog Fixtures
- `mock_catalog_store` - Mocked catalog store with sample data
- `mock_database_engine` - Mocked database engine
- `sample_table_info` - Sample table metadata
#### Database Fixtures
- `mock_presto_connection` - Mocked Presto database connection
- `mock_token_service` - Mocked authentication token service
## Writing Tests
### Test Naming Conventions
Follow these naming conventions for consistency:
```python
# Test files
test_<module_name>.py
# Test classes
class TestModuleName:
# Test methods
def test_specific_functionality(self):
def test_error_condition_handling(self):
def test_edge_case_scenario(self):
```
### Test Categories
Use pytest marks to categorize tests:
```python
import pytest
@pytest.mark.unit
def test_basic_functionality():
"""Unit test for basic functionality."""
pass
@pytest.mark.integration
def test_database_integration():
"""Integration test with database."""
pass
@pytest.mark.slow
def test_performance_benchmark():
"""Slow performance test."""
pass
@pytest.mark.parametrize("input,expected", [
("test1", "result1"),
("test2", "result2"),
])
def test_multiple_scenarios(input, expected):
"""Test multiple input/output scenarios."""
pass
```
### Mocking Best Practices
Use proper mocking for external dependencies:
```python
from unittest.mock import Mock, patch, MagicMock
# Mock external services
@patch('openchatbi.module.external_service')
def test_with_external_service(mock_service):
mock_service.return_value = "expected_result"
# Test implementation
# Mock LLM responses
def test_llm_integration(mock_llm):
mock_llm.invoke.return_value = AIMessage(content="Mock response")
# Test implementation
```
### Async Test Support
For testing async functionality:
```python
import pytest
import asyncio
@pytest.mark.asyncio
async def test_async_functionality():
"""Test asynchronous operations."""
result = await some_async_function()
assert result is not None
```
## Test Data
### Sample Data Files
Test data is managed through fixtures and temporary files:
```python
def test_with_sample_data(temp_dir):
"""Test using temporary sample data."""
# Create test data file
data_file = temp_dir / "test_data.csv"
data_file.write_text("col1,col2\nval1,val2")
# Test with the data
assert data_file.exists()
```
### Mock Responses
Common mock responses are defined in fixtures:
```python
# SQL generation mock response
mock_llm.invoke.return_value = AIMessage(
content="SELECT COUNT(*) FROM test_table;"
)
# Catalog search mock response
mock_catalog.search_tables.return_value = [
{"table_name": "users", "description": "User data"}
]
```
## Continuous Integration
### GitHub Actions
Tests run automatically on:
- Pull requests
- Pushes to main branch
- Scheduled runs (daily)
### Test Matrix
Tests run against multiple configurations:
- Python versions: 3.11+
- Dependencies: Minimum and latest versions
## Debugging Tests
### Common Issues
1. **Import Errors**
```bash
# Ensure package is installed in development mode
pip install -e .
```
2. **Missing Dependencies**
```bash
# Install test dependencies
pip install -e ".[test]"
```
3. **Configuration Issues**
```bash
# Set test environment variables
export OPENCHATBI_TEST_MODE=true
```
### Debug Output
Enable debug output for failing tests:
```bash
# Run with debug output
uv run pytest -v -s --tb=long
# Run with pdb on failures
uv run pytest --pdb
# Run with coverage debug
uv run pytest --cov-report=term-missing -v
```
## Performance Testing
### Benchmarks
Performance tests are marked with `@pytest.mark.slow`:
```bash
# Run performance tests
uv run pytest -m slow
# Skip performance tests
uv run pytest -m "not slow"
```
### Memory Profiling
For memory usage testing:
```bash
# Install memory profiler
pip install memory-profiler
# Run with memory profiling
uv run pytest --profile-mem
```
## Contributing
### Adding New Tests
1. Create test file following naming conventions
2. Import required fixtures from `conftest.py`
3. Write comprehensive test cases covering:
- Happy path scenarios
- Error conditions
- Edge cases
- Performance considerations
4. Use appropriate mocking for external dependencies
5. Add docstrings explaining test purpose
6. Run tests locally before submitting PR
### Test Review Guidelines
When reviewing test PRs:
- Ensure adequate test coverage
- Verify mock usage is appropriate
- Check for test independence
- Validate error case handling
- Confirm performance test categorization
## Resources
- [Pytest Documentation](https://docs.pytest.org/)
- [Python unittest.mock](https://docs.python.org/3/library/unittest.mock.html)
- [Coverage.py Documentation](https://coverage.readthedocs.io/)
- [pytest-asyncio](https://pytest-asyncio.readthedocs.io/)
================================================
FILE: tests/__init__.py
================================================
"""Test package for OpenChatBI."""
================================================
FILE: tests/conftest.py
================================================
"""Pytest configuration and shared fixtures."""
import tempfile
from collections.abc import Generator
from pathlib import Path
from typing import Any
from unittest.mock import Mock
import pytest
from langchain_core.language_models import FakeListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from sqlalchemy import create_engine
from openchatbi.catalog.store.file_system import FileSystemCatalogStore
from openchatbi.config_loader import ConfigLoader
from openchatbi.graph_state import AgentState
@pytest.fixture(scope="session")
def test_config() -> dict[str, Any]:
    """Session-scoped configuration dictionary shared across the suite."""
    config: dict[str, Any] = {
        "organization": "TestOrg",
        "dialect": "presto",
        "bi_config_file": "test_bi.yaml",
        "catalog_store": {"store_type": "file_system", "data_path": "./test_data"},
        "default_llm": {
            "class": "langchain_core.language_models.FakeListChatModel",
            "params": {"responses": ["Test response"]},
        },
    }
    return config
@pytest.fixture
def temp_dir() -> Generator[Path, None, None]:
    """Yield a Path to a throwaway directory removed when the test ends."""
    tmp = tempfile.TemporaryDirectory()
    try:
        yield Path(tmp.name)
    finally:
        # Explicit cleanup mirrors the context-manager exit.
        tmp.cleanup()
@pytest.fixture
def mock_llm() -> FakeListChatModel:
    """Deterministic fake chat model that cycles through canned responses."""
    canned_responses = [
        "SELECT COUNT(*) FROM test_table;",
        "This is a test SQL query.",
        "Test analysis result.",
    ]
    return FakeListChatModel(responses=canned_responses)
@pytest.fixture
def sample_agent_state() -> AgentState:
    """Minimal AgentState with one user message and a pre-filled SQL answer."""
    state = AgentState(
        messages=[HumanMessage(content="Test query")],
        sql="SELECT * FROM test_table;",
        agent_next_node="sql_generation",
        final_answer="Test data results",
    )
    return state
@pytest.fixture
def mock_catalog_store(temp_dir: Path) -> FileSystemCatalogStore:
    """Mock catalog store fixture.

    Writes sample catalog files (table_info.yaml plus three CSV files) into
    a temporary directory and returns a FileSystemCatalogStore over them.
    """
    # Create test data files
    test_data_dir = temp_dir / "test_data"
    test_data_dir.mkdir(exist_ok=True)
    # Create sample table_info.yaml describing two fact tables.
    # NOTE(review): YAML nesting indentation reconstructed — confirm it matches
    # the structure FileSystemCatalogStore expects.
    tables_info_file = test_data_dir / "table_info.yaml"
    tables_info_file.write_text(
        """test:
  test_table:
    type: fact
    description: A test table for unit tests
  user_data:
    type: fact
    description: User information table"""
    )
    # Create sample table_columns.csv (db/table/column triples)
    tables_file = test_data_dir / "table_columns.csv"
    tables_file.write_text(
        """db_name,table_name,column_name
test,test_table,id
test,test_table,name
test,user_data,user_id"""
    )
    # Create sample table_spec_columns.csv (per-column type and description)
    columns_file = test_data_dir / "table_spec_columns.csv"
    columns_file.write_text(
        """db_name,table_name,column_name,type,display_name,description
test,test_table,id,bigint,Id,Primary key
test,test_table,name,varchar,Name,User name
test,user_data,user_id,bigint,User Id,User identifier"""
    )
    # Create sample common_columns.csv (columns shared by every table)
    common_columns_file = test_data_dir / "common_columns.csv"
    common_columns_file.write_text(
        """column_name,type,display_name,description
status,varchar,Status,Record status
created_at,timestamp,Created At,Creation timestamp
updated_at,timestamp,Updated At,Last update timestamp"""
    )
    # Mock data warehouse config pointing at an in-memory SQLite database.
    data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"}
    return FileSystemCatalogStore(data_path=str(test_data_dir), data_warehouse_config=data_warehouse_config)
@pytest.fixture
def mock_database_engine():
    """In-memory SQLite engine pre-populated with one `test_table` row.

    Returns:
        sqlalchemy.Engine: engine bound to an in-memory SQLite database
        containing `test_table(id INTEGER, name TEXT)` with one row.
    """
    # Local import: only this fixture needs `text`.
    from sqlalchemy import text

    engine = create_engine("sqlite:///:memory:")
    # Create test tables.
    with engine.connect() as conn:
        # SQLAlchemy 2.0 raises ArgumentError for raw SQL strings passed to
        # Connection.execute(); text() works on both 1.4 and 2.x. The original
        # `conn.commit()` call already assumes 2.0-style "commit as you go".
        conn.execute(text("CREATE TABLE test_table (id INTEGER, name TEXT)"))
        conn.execute(text("INSERT INTO test_table VALUES (1, 'Test User')"))
        conn.commit()
    return engine
@pytest.fixture
def sample_table_info() -> dict[str, Any]:
    """Metadata for a single sample table, keyed by table name."""
    column_specs = [
        {"name": "id", "type": "bigint", "description": "Primary key"},
        {"name": "name", "type": "varchar", "description": "User name"},
    ]
    return {
        "test_table": {
            "columns": column_specs,
            "description": "A test table for unit tests",
            "sql_rule": "Always filter by active status",
        }
    }
@pytest.fixture
def sample_messages() -> list:
    """Short alternating human/AI message history."""
    history = [
        HumanMessage(content="What's the user count?"),
        AIMessage(content="I'll help you get the user count from the database."),
        HumanMessage(content="Show me the SQL query"),
    ]
    return history
@pytest.fixture(autouse=True)
def reset_config_loader():
    """Give every test a pristine ConfigLoader singleton, before and after."""

    def _clear_singleton():
        # Drop both the cached instance and its loaded config.
        ConfigLoader._instance = None
        ConfigLoader._config = None

    _clear_singleton()
    yield
    _clear_singleton()
@pytest.fixture
def mock_config():
    """Install a minimal configuration into ConfigLoader and return it."""
    from unittest.mock import MagicMock

    settings = {
        "organization": "Test Company",
        "dialect": "presto",
        "default_llm": MagicMock(),
        "embedding_model": MagicMock(),
        "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
        "report_directory": "./data",
        "python_executor": "local",
        "visualization_mode": "rule",
        "context_config": {},
    }
    loader = ConfigLoader()
    loader.set(settings)
    return loader.get()
@pytest.fixture(autouse=True)
def setup_test_env(monkeypatch, temp_dir):
    """Point OpenChatBI at a per-test config path and enable test mode."""
    env_overrides = {
        "OPENCHATBI_CONFIG_PATH": str(temp_dir / "config.yaml"),
        "OPENCHATBI_TEST_MODE": "true",
    }
    for key, value in env_overrides.items():
        monkeypatch.setenv(key, value)
class MockTokenService:
    """Stand-in for the real token service; always hands out one fixed token."""

    def __init__(self):
        # Static credential for tests that just need a plausible token string.
        self.token = "mock_token_12345"

    def get_token(self) -> str:
        """Return the canned authentication token."""
        return self.token
@pytest.fixture
def mock_token_service() -> MockTokenService:
    """Provide a fresh MockTokenService instance."""
    service = MockTokenService()
    return service
@pytest.fixture
def sample_sql_examples() -> list:
    """Question/SQL pairs usable as few-shot examples."""
    examples = [
        {"question": "How many users are there?", "sql": "SELECT COUNT(*) FROM users;", "tables": ["users"]},
        {
            "question": "What's the average age?",
            "sql": "SELECT AVG(age) FROM users WHERE age IS NOT NULL;",
            "tables": ["users"],
        },
    ]
    return examples
@pytest.fixture
def mock_presto_connection():
    """DB-API-style mocked Presto connection whose cursor yields two rows."""
    cursor = Mock()
    # Canned result set: (table_name, description) pairs.
    cursor.fetchall.return_value = [("table1", "Test table 1"), ("table2", "Test table 2")]
    cursor.description = [("table_name",), ("description",)]

    connection = Mock()
    connection.cursor.return_value = cursor
    connection.execute.return_value = cursor
    return connection
================================================
FILE: tests/context_management/README.md
================================================
# Context Management Test Suite
This directory contains comprehensive tests for the context management functionality in OpenChatBI.
## Test Structure
### 📁 Test Files
- **`test_context_manager.py`** - Unit tests for the `ContextManager` class
- **`test_context_config.py`** - Tests for context configuration management
- **`test_agent_graph_integration.py`** - Integration tests for agent graph with context management
- **`test_edge_cases.py`** - Edge case handling
- **`test_state_operations.py`** - Tests for state operations and message processing
- **`conftest.py`** - Shared pytest fixtures and configuration
- **`test_runner.py`** - Custom test runner script
## 🧪 Test Categories
### Unit Tests (`test_context_manager.py`)
Tests core functionality of the `ContextManager` class:
- ✅ Token estimation and message token calculation
- ✅ Tool output trimming (generic, SQL, Python code)
- ✅ Conversation summarization
- ✅ Context management with sliding window
- ✅ Tool wrapper functionality
- ✅ Configuration-based behavior
**Key test cases:**
- `test_trim_sql_output()` - Tests intelligent SQL result trimming
- `test_conversation_summary_success()` - Tests LLM-based summarization
- `test_manage_context_with_summarization()` - Tests full context management flow
### Configuration Tests (`test_context_config.py`)
Tests configuration management and validation:
- ✅ Default configuration values
- ✅ Custom configuration creation
- ✅ Configuration updates
- ✅ Edge cases (zero/negative values)
- ✅ Different configuration presets
**Key test cases:**
- `test_update_context_config_multiple_values()` - Tests configuration updates
- `test_production_optimized_config()` - Tests realistic production settings
### Integration Tests (`test_agent_graph_integration.py`)
Tests integration with the agent graph system:
- ✅ Agent router with context management
- ✅ Graph building with/without context management
- ✅ Tool wrapping in graph context
- ✅ Full conversation flow testing
- ✅ System message preservation
**Key test cases:**
- `test_agent_router_with_context_manager()` - Tests router integration
- `test_full_context_management_flow()` - Tests end-to-end functionality
### State Operations Tests (`test_state_operations.py`)
Tests state manipulation and message processing operations:
- ✅ Message trimming and truncation logic
- ✅ State updates and modifications
- ✅ Message type handling and conversion
- ✅ Context state preservation during operations
- ✅ Error handling in state operations
**Key test cases:**
- `test_trim_messages_by_token_count()` - Tests message trimming logic
- `test_state_message_processing()` - Tests state message operations
- `test_context_state_updates()` - Tests context state modifications
### Edge Cases (`test_edge_cases.py`)
Tests system behavior under stress and edge conditions:
- ✅ Unicode and encoding edge cases
- ✅ Malformed input handling
**Key test cases:**
- `test_sql_output_edge_cases()` - SQL edge cases
- `test_extremely_nested_or_complex_structures()` - Complex data structures
## 🚀 Running Tests
### Using the Test Runner
```bash
# Run all tests
python tests/context_management/test_runner.py
# Run only unit tests
python tests/context_management/test_runner.py --type unit
# Run with coverage reporting
python tests/context_management/test_runner.py --coverage
```
### Using Pytest Directly
```bash
# Run all context management tests
pytest tests/context_management/
# Run specific test file
pytest tests/context_management/test_context_manager.py
# Run with verbose output
pytest tests/context_management/ -v
# Run with coverage
pytest tests/context_management/ --cov=openchatbi.context_manager --cov-report=html
```
## 📊 Test Markers
Tests are organized using pytest markers:
- `@pytest.mark.integration` - Integration tests
- `@pytest.mark.slow` - Slow-running tests (can be excluded)
## 🎯 Test Coverage Areas
### Core Functionality
- [x] Token estimation and management
- [x] Message processing and trimming
- [x] Conversation summarization
- [x] Context compression strategies
### Tool Output Management
- [x] SQL output trimming with structure preservation
- [x] Python code output handling
- [x] Error message preservation
- [x] Generic output trimming
### Configuration Management
- [x] Default and custom configurations
- [x] Configuration validation
- [x] Runtime configuration updates
- [x] Edge case configurations
### Integration Points
- [x] Agent router integration
- [x] Graph building integration
- [x] Tool wrapper integration
- [x] LLM service integration
### Edge Cases
- [x] Unicode and encoding issues
- [x] Malformed input handling
## 🧩 Fixtures
### Common Fixtures (in `conftest.py`)
- `mock_llm` - Mock language model for testing
- `standard_config` - Standard test configuration
- `minimal_config` - Minimal configuration for edge testing
- `sample_conversation` - Sample conversation data
- `large_sql_output` - Large SQL output for trimming tests
- `error_output` - Sample error output for preservation tests
## 🔧 Extending Tests
### Adding New Test Cases
1. Choose the appropriate test file based on the functionality
2. Use existing fixtures from `conftest.py`
3. Follow the naming convention: `test_feature_description()`
4. Add appropriate markers for categorization
### Adding New Fixtures
Add shared fixtures to `conftest.py` if they'll be used across multiple test files.
## 🐛 Debugging Tests
### Common Issues
1. **Mock LLM failures**: Ensure proper mocking of LLM responses
2. **Configuration conflicts**: Use isolated config instances
3. **Memory leaks in large tests**: Force garbage collection with `gc.collect()`
### Debugging Tools
```bash
# Run with debugging output
pytest tests/context_management/ -v -s
# Run single test with full traceback
pytest tests/context_management/test_name.py::test_function -v --tb=long
# Profile test performance
pytest tests/context_management/ --profile
```
## 📋 Test Results
Expected test results:
- **Total tests**: ~100+ test cases across 5 test files
- **Coverage target**: >95% for context management modules
- **State operations tests**: All message processing should work correctly
- **Edge cases**: All should handle gracefully without crashes
================================================
FILE: tests/context_management/__init__.py
================================================
"""Context management test package."""
# Test package initialization
================================================
FILE: tests/context_management/conftest.py
================================================
"""Pytest configuration and fixtures for context management tests."""
from unittest.mock import Mock
import pytest
from langchain_core.messages import AIMessage
from openchatbi.context_config import ContextConfig
@pytest.fixture
def mock_llm():
    """Mock LLM shared by the context-management tests; bind_tools is a no-op."""
    fake = Mock()
    # bind_tools returns the mock itself so chained calls keep working.
    fake.bind_tools = Mock(return_value=fake)
    return fake
@pytest.fixture
def mock_llm_with_summary_response():
    """Mock LLM that returns a summary response."""
    # NOTE(review): currently identical to `mock_llm` — no summary response is
    # configured here; tests presumably set `llm.invoke.return_value`
    # themselves. Confirm whether a canned summary should be attached.
    llm = Mock()
    llm.bind_tools = Mock(return_value=llm)
    return llm
@pytest.fixture
def standard_config():
    """ContextConfig with realistic, generous limits for most tests."""
    options = dict(
        enabled=True,
        summary_trigger_tokens=12000,
        keep_recent_messages=10,
        max_tool_output_length=2000,
        max_sql_result_rows=20,
        max_code_output_lines=50,
        enable_summarization=True,
        enable_conversation_summary=True,
        preserve_tool_errors=True,
    )
    return ContextConfig(**options)
@pytest.fixture
def minimal_config():
    """ContextConfig with tiny limits so trimming triggers quickly."""
    options = dict(
        enabled=True,
        summary_trigger_tokens=800,
        keep_recent_messages=3,
        max_tool_output_length=200,
        max_sql_result_rows=5,
        max_code_output_lines=10,
    )
    return ContextConfig(**options)
@pytest.fixture
def disabled_config():
    """ContextConfig with context management switched off entirely."""
    return ContextConfig(
        enabled=False,
        summary_trigger_tokens=12000,
        keep_recent_messages=10,
        max_tool_output_length=2000,
    )
@pytest.fixture
def sample_conversation():
    """Sample conversation for testing.

    Mixes human turns, AI replies, an AI message carrying a tool call, and
    tool results, so context-management logic sees every message type.
    """
    from langchain_core.messages import HumanMessage, ToolMessage

    return [
        HumanMessage(content="Can you analyze our sales data?"),
        AIMessage(content="I'll help you analyze the sales data. Let me query the database."),
        ToolMessage(content="Query executed successfully. Found 1000 records.", tool_call_id="query_1"),
        HumanMessage(content="What are the top trends?"),
        AIMessage(content="Based on the data, I can see several key trends..."),
        HumanMessage(content="Can you create a visualization?"),
        AIMessage(
            content="I'll create a chart for you.",
            # Tool call id must match the ToolMessage below.
            tool_calls=[{"name": "create_chart", "args": {"type": "bar"}, "id": "chart_1"}],
        ),
        ToolMessage(content="Chart created successfully.", tool_call_id="chart_1"),
    ]
@pytest.fixture
def large_sql_output():
    """Large SQL output for testing trimming.

    Produces 100 CSV rows wrapped in the SQL-tool output format (query
    block, results block, visualization note) so row-trimming logic has
    realistic structure to preserve.
    """
    csv_data = "id,name,value,date\n"
    # 100 synthetic rows; day-of-month cycles 1..30 to remain a valid date.
    csv_data += "\n".join([f"{i},Customer_{i},{i*100},2023-01-{i%30+1:02d}" for i in range(100)])
    return f"""SQL Query:
```sql
SELECT id, name, value, date
FROM customers
ORDER BY value DESC
LIMIT 100;
```
Query Results (CSV format):
```csv
{csv_data}
```
Visualization Created: bar chart has been automatically generated and will be displayed in the UI."""
@pytest.fixture
def large_python_output():
    """Large Python code execution output (100 progress lines)."""
    lines = ["Processing data..."]
    lines.extend(f"Step {i}: Processing record {i} - Status: OK" for i in range(100))
    lines.append("Processing complete!")
    return "\n".join(lines)
@pytest.fixture
def error_output():
    """Sample error output.

    A realistic Python traceback plus a human-readable error line, used to
    verify that error text is preserved (not trimmed away).
    """
    return """Traceback (most recent call last):
File "/app/analysis.py", line 42, in analyze_sales
df = pd.read_csv('nonexistent_file.csv')
File "/usr/local/lib/python3.9/site-packages/pandas/io/parsers/readers.py", line 912, in read_csv
return _read(filepath_or_buffer, kwds)
FileNotFoundError: [Errno 2] No such file or directory: 'nonexistent_file.csv'
Error: Could not load the sales data file. Please check that the file exists and is accessible."""
# Pytest configuration
def pytest_configure(config):
    """Register the custom markers used by this test package."""
    marker_lines = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "integration: marks tests as integration tests",
    )
    for marker in marker_lines:
        config.addinivalue_line("markers", marker)
def pytest_collection_modifyitems(config, items):
    """Auto-tag collected tests: integration by node id, slow by name hints."""
    slow_hints = ("large", "stress", "concurrent")
    for item in items:
        node_id = item.nodeid.lower()
        # Anything whose node id mentions "integration" is an integration test.
        if "integration" in node_id:
            item.add_marker(pytest.mark.integration)
        # Certain name patterns flag long-running tests.
        if any(hint in node_id for hint in slow_hints):
            item.add_marker(pytest.mark.slow)
================================================
FILE: tests/context_management/test_agent_graph_integration.py
================================================
"""Integration tests for agent graph with context management."""
from unittest.mock import Mock, patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import StructuredTool
from openchatbi.agent_graph import _build_graph_core, agent_llm_call, build_agent_graph_async, build_agent_graph_sync
from openchatbi.context_config import ContextConfig
from openchatbi.context_manager import ContextManager
from openchatbi.graph_state import AgentState
class TestAgentGraphIntegration:
"""Integration tests for agent graph with context management."""
@pytest.fixture
def mock_catalog(self):
"""Mock catalog store for testing."""
catalog = Mock()
catalog.get_schema = Mock(return_value={"tables": []})
return catalog
@pytest.fixture
def mock_llm(self):
"""Mock LLM for testing."""
llm = Mock()
llm.bind_tools = Mock(return_value=llm)
return llm
@pytest.fixture
def mock_tools(self):
"""Mock tools for testing."""
def mock_tool_func(query: str) -> str:
return "Mock tool result"
tool = StructuredTool.from_function(func=mock_tool_func, name="mock_tool", description="Mock tool for testing")
return [tool]
@pytest.fixture
def test_config(self):
"""Test configuration for context management."""
return ContextConfig(
enabled=True,
summary_trigger_tokens=800,
keep_recent_messages=3,
max_tool_output_length=100,
)
    def test_agent_llm_node_with_context_manager(self, mock_llm, mock_tools, test_config):
        """Test agent llm_node with context manager integration.

        Patches the retrying LLM call so no real model runs, then feeds a
        state with oversized messages through the node built with a
        ContextManager attached.
        """
        context_manager = ContextManager(llm=mock_llm, config=test_config)
        # Mock LLM response
        mock_response = AIMessage(content="Test response", tool_calls=[])
        with patch("openchatbi.agent_graph.call_llm_chat_model_with_retry", return_value=mock_response):
            llm_node_func = agent_llm_call(mock_llm, mock_tools, context_manager)
            # Create test state with long messages to trigger context management
            # (lengths chosen to exceed the fixture's small token limits).
            long_messages = [
                HumanMessage(content="A" * 500),  # Long message
                AIMessage(content="B" * 500),  # Long message
                ToolMessage(content="C" * 200, tool_call_id="123"),  # Long tool output
                HumanMessage(content="Recent question"),
            ]
            state = AgentState(messages=long_messages)
            result = llm_node_func(state)
            # Should have processed the state and returned the mocked AIMessage.
            assert "messages" in result
            assert isinstance(result["messages"][0], AIMessage)
    def test_agent_llm_node_without_context_manager(self, mock_llm, mock_tools):
        """Test agent llm_node without context manager.

        context_manager=None must still produce a working node that returns
        the (mocked) LLM response unchanged.
        """
        mock_response = AIMessage(content="Test response", tool_calls=[])
        # Patch the retrying LLM call so no real model is invoked.
        with patch("openchatbi.agent_graph.call_llm_chat_model_with_retry", return_value=mock_response):
            llm_node_func = agent_llm_call(mock_llm, mock_tools, context_manager=None)
            state = AgentState(messages=[HumanMessage(content="Test")])
            result = llm_node_func(state)
            assert "messages" in result
            assert isinstance(result["messages"][0], AIMessage)
    def test_build_graph_core_with_context_management(self, mock_catalog, mock_llm):
        """Test core graph building with context management enabled.

        Every tool import and sub-graph builder is patched out, so only the
        wiring inside `_build_graph_core` itself is exercised.
        """

        def create_mock_tool(name):
            # Factory producing a StructuredTool stand-in with a recognizable name.
            def mock_func(input_str: str) -> str:
                return f"Mock {name} result"

            return StructuredTool.from_function(func=mock_func, name=name, description=f"Mock {name} tool")

        # Mock all the tool imports directly
        with (
            patch("openchatbi.agent_graph.search_knowledge", create_mock_tool("search_knowledge")),
            patch("openchatbi.agent_graph.show_schema", create_mock_tool("show_schema")),
            patch("openchatbi.agent_graph.run_python_code", create_mock_tool("run_python_code")),
            patch("openchatbi.agent_graph.save_report", create_mock_tool("save_report")),
            patch("openchatbi.agent_graph.timeseries_forecast", create_mock_tool("timeseries_forecast")),
            patch("openchatbi.agent_graph.get_sql_tools") as mock_get_sql_tools,
            patch("openchatbi.agent_graph.build_sql_graph") as mock_sql_graph,
            patch("openchatbi.agent_graph.get_memory_tools") as mock_memory_tools,
            patch("openchatbi.agent_graph.create_mcp_tools_sync") as mock_mcp_tools,
            patch("openchatbi.agent_graph.get_llm", return_value=mock_llm),
        ):
            # Setup function-based mocks (these are called, not just referenced).
            mock_get_sql_tools.return_value = create_mock_tool("call_sql_graph_tool")
            mock_sql_graph.return_value = Mock()
            mock_memory_tools.return_value = (
                create_mock_tool("manage_memory_tool"),
                create_mock_tool("search_memory_tool"),
            )
            mock_mcp_tools.return_value = []
            graph = _build_graph_core(
                catalog=mock_catalog,
                sync_mode=True,
                checkpointer=None,
                memory_store=None,
                memory_tools=None,
                mcp_tools=[],
                enable_context_management=True,
            )
            # Should create a compiled graph
            assert graph is not None
            # Verify that SQL graph was initialized
            mock_sql_graph.assert_called_once()
    def test_build_graph_core_without_context_management(self, mock_catalog, mock_llm):
        """Test core graph building with context management disabled.

        Mirrors the enabled-path test; only `enable_context_management`
        differs, and graph construction must still succeed.
        """

        def create_mock_tool(name):
            # Factory producing a StructuredTool stand-in with a recognizable name.
            def mock_func(input_str: str) -> str:
                return f"Mock {name} result"

            return StructuredTool.from_function(func=mock_func, name=name, description=f"Mock {name} tool")

        # Mock all the tool imports directly - same pattern as with context management
        with (
            patch("openchatbi.agent_graph.search_knowledge", create_mock_tool("search_knowledge")),
            patch("openchatbi.agent_graph.show_schema", create_mock_tool("show_schema")),
            patch("openchatbi.agent_graph.run_python_code", create_mock_tool("run_python_code")),
            patch("openchatbi.agent_graph.save_report", create_mock_tool("save_report")),
            patch("openchatbi.agent_graph.timeseries_forecast", create_mock_tool("timeseries_forecast")),
            patch("openchatbi.agent_graph.get_sql_tools") as mock_get_sql_tools,
            patch("openchatbi.agent_graph.build_sql_graph") as mock_sql_graph,
            patch("openchatbi.agent_graph.get_memory_tools") as mock_memory_tools,
            patch("openchatbi.agent_graph.create_mcp_tools_sync") as mock_mcp_tools,
            patch("openchatbi.agent_graph.get_llm", return_value=mock_llm),
        ):
            # Setup function-based mocks (these are called, not just referenced).
            mock_get_sql_tools.return_value = create_mock_tool("call_sql_graph_tool")
            mock_sql_graph.return_value = Mock()
            mock_memory_tools.return_value = (
                create_mock_tool("manage_memory_tool"),
                create_mock_tool("search_memory_tool"),
            )
            mock_mcp_tools.return_value = []
            graph = _build_graph_core(
                catalog=mock_catalog,
                sync_mode=True,
                checkpointer=None,
                memory_store=None,
                memory_tools=None,
                mcp_tools=[],
                enable_context_management=False,
            )
            # Should still create a compiled graph
            assert graph is not None
def test_build_agent_graph_sync_with_context_management(self, mock_catalog):
    """Test sync graph building with context management."""
    with (
        patch("openchatbi.agent_graph.create_mcp_tools_sync") as mcp_tools_mock,
        patch("openchatbi.agent_graph._build_graph_core") as build_core_mock,
    ):
        build_core_mock.return_value = Mock()
        mcp_tools_mock.return_value = []

        graph = build_agent_graph_sync(catalog=mock_catalog, enable_context_management=True)

        # The flag must be forwarded to _build_graph_core unchanged.
        build_core_mock.assert_called_once()
        assert build_core_mock.call_args[1]["enable_context_management"] is True
        # And the resulting graph must be handed back to the caller.
        assert graph is not None
@pytest.mark.asyncio
async def test_build_agent_graph_async_with_context_management(self, mock_catalog):
    """Test async graph building with context management."""
    with (
        patch("openchatbi.agent_graph.get_mcp_tools_async") as mcp_tools_mock,
        patch("openchatbi.agent_graph._build_graph_core") as build_core_mock,
    ):
        build_core_mock.return_value = Mock()
        # Async MCP tool discovery is stubbed out entirely.
        mcp_tools_mock.return_value = []

        graph = await build_agent_graph_async(catalog=mock_catalog, enable_context_management=True)

        # The flag must be forwarded to _build_graph_core unchanged.
        build_core_mock.assert_called_once()
        assert build_core_mock.call_args[1]["enable_context_management"] is True
        # And the resulting graph must be handed back to the caller.
        assert graph is not None
@patch("openchatbi.agent_graph.call_llm_chat_model_with_retry")
def test_full_context_management_flow(self, mock_llm_call, mock_catalog):
"""Test full context management flow in agent graph."""
# Mock LLM responses
mock_llm_call.side_effect = [
AIMessage(content="Response 1"),
AIMessage(content="Summary of conversation"), # For summarization
AIMessage(content="Final response"),
]
context_manager = ContextManager(
llm=Mock(),
config=ContextConfig(
enabled=True,
summary_trigger_tokens=80,
keep_recent_messages=2,
),
)
# Create many messages to trigger context management
messages = []
for i in range(10):
messages.extend(
[
HumanMessage(content=f"Question {i}" * 10), # Make messages longer
AIMessage(content=f"Response {i}" * 10),
ToolMessage(content=f"Tool result {i}" * 20, tool_call_id=f"tool_{i}"),
]
)
# Test context management
original_count = len(messages)
context_manager.manage_context_messages(messages)
managed_messages = messages
# Should have fewer messages than input
assert len(managed_messages) < original_count
# Should preserve recent messages
assert any("Question 9" in str(msg.content) for msg in managed_messages if hasattr(msg, "content"))
class TestContextManagementEdgeCases:
    """Test edge cases for context management in agent graph."""

    def test_empty_message_handling(self):
        """An empty message list passes through unchanged."""
        manager = ContextManager(llm=Mock(), config=ContextConfig(enabled=True))
        history = []
        manager.manage_context_messages(history)
        assert history == []

    def test_state_message_type_validation(self):
        """Only valid state message types survive context management."""
        manager = ContextManager(llm=Mock(), config=ContextConfig(enabled=True))
        # Graph state never holds a SystemMessage, so none appears here.
        history = [
            HumanMessage(content="A" * 100),  # long message
            AIMessage(content="B" * 100),  # long message
            HumanMessage(content="Recent question"),
        ]
        with patch(
            "openchatbi.context_manager.call_llm_chat_model_with_retry", return_value=AIMessage(content="Summary")
        ):
            manager.manage_context_messages(history)
        allowed = {HumanMessage, AIMessage, ToolMessage}
        assert all(type(msg) in allowed for msg in history), "Should only contain valid state message types"

    def test_context_management_with_tool_calls(self):
        """AI messages that carry tool calls are preserved."""
        manager = ContextManager(llm=Mock(), config=ContextConfig(enabled=True))
        calling_msg = AIMessage(
            content="I'll help you with that.",
            tool_calls=[{"name": "search_tool", "args": {"query": "test"}, "id": "call_123"}],
        )
        history = [calling_msg, HumanMessage(content="Follow up")]
        manager.manage_context_messages(history)
        # At least one AIMessage must survive, and its tool calls with it.
        ai_messages = [msg for msg in history if isinstance(msg, AIMessage)]
        assert len(ai_messages) > 0
        assert any(hasattr(msg, "tool_calls") and msg.tool_calls for msg in ai_messages)

    @patch("openchatbi.context_manager.call_llm_chat_model_with_retry")
    def test_summarization_failure_fallback(self, mock_llm_call):
        """When summarization raises, the manager falls back to a sliding window."""
        # Simulate an unavailable LLM backend.
        mock_llm_call.side_effect = Exception("LLM unavailable")
        manager = ContextManager(llm=Mock(), config=ContextConfig(enabled=True))
        # Enough long messages to trigger summarization (no SystemMessage in state).
        history = [
            HumanMessage(content="A" * 100),  # Long messages to trigger
            AIMessage(content="B" * 100),
            HumanMessage(content="C" * 100),
            AIMessage(content="D" * 100),
            HumanMessage(content="Recent"),
        ]
        manager.manage_context_messages(history)
        result = history
        # Fallback must not grow the conversation.
        assert len(result) <= len(history)
        # Recent messages survive, and only valid state types remain.
        assert any("Recent" in str(msg.content) for msg in result if hasattr(msg, "content"))
        allowed = {HumanMessage, AIMessage, ToolMessage}
        assert all(type(msg) in allowed for msg in result), "Should only contain valid state message types"
================================================
FILE: tests/context_management/test_context_config.py
================================================
"""Unit tests for context configuration."""
from openchatbi.context_config import ContextConfig, get_context_config, update_context_config
class TestContextConfig:
    """Test cases for ContextConfig class."""

    def test_default_config_values(self):
        """A freshly constructed config carries the documented defaults."""
        cfg = ContextConfig()
        # Numeric defaults
        assert cfg.enabled is True
        assert cfg.summary_trigger_tokens == 12000
        assert cfg.keep_recent_messages == 20
        assert cfg.max_tool_output_length == 2000
        assert cfg.max_sql_result_rows == 50
        assert cfg.max_code_output_lines == 50
        # Boolean feature flags
        assert cfg.enable_summarization is True
        assert cfg.enable_conversation_summary is True
        assert cfg.preserve_tool_errors is True
        assert cfg.preserve_recent_sql is True

    def test_custom_config_values(self):
        """Explicit keyword overrides win; everything else keeps its default."""
        cfg = ContextConfig(
            enabled=False,
            summary_trigger_tokens=8000,
            keep_recent_messages=5,
            max_tool_output_length=1000,
            enable_summarization=False,
        )
        assert cfg.enabled is False
        assert cfg.summary_trigger_tokens == 8000
        assert cfg.keep_recent_messages == 5
        assert cfg.max_tool_output_length == 1000
        assert cfg.enable_summarization is False
        # Untouched fields fall back to defaults.
        assert cfg.max_sql_result_rows == 50
        assert cfg.preserve_tool_errors is True

    def test_config_validation_logic(self):
        """Sanity-check logical relationships between default values."""
        cfg = ContextConfig()
        # Recent-message window is positive but not absurdly large.
        assert 0 < cfg.keep_recent_messages < 100
        # All output limits are positive.
        assert cfg.max_tool_output_length > 0
        assert cfg.max_sql_result_rows > 0
        assert cfg.max_code_output_lines > 0
        # The summarization trigger is positive.
        assert cfg.summary_trigger_tokens > 0

    def test_get_context_config(self):
        """get_context_config returns a ContextConfig instance."""
        assert isinstance(get_context_config(), ContextConfig)

    def test_update_context_config_single_value(self):
        """Updating one field leaves the others alone."""
        _ = get_context_config().summary_trigger_tokens  # snapshot, kept for parity
        updated = update_context_config(summary_trigger_tokens=15000)
        assert updated.summary_trigger_tokens == 15000
        # Fields that were not updated match the current global config.
        assert updated.keep_recent_messages == get_context_config().keep_recent_messages

    def test_update_context_config_multiple_values(self):
        """Several fields can be updated in a single call."""
        updated = update_context_config(
            summary_trigger_tokens=20000,
            keep_recent_messages=15,
            enable_summarization=False,
            max_tool_output_length=3000,
        )
        assert updated.summary_trigger_tokens == 20000
        assert updated.keep_recent_messages == 15
        assert updated.enable_summarization is False
        assert updated.max_tool_output_length == 3000

    def test_update_context_config_invalid_attribute(self):
        """Unknown keyword names are ignored rather than raising."""
        cfg = update_context_config(invalid_attribute=123)
        assert not hasattr(cfg, "invalid_attribute")

    def test_update_context_config_returns_copy(self):
        """update_context_config hands back a config carrying the new values."""
        _ = get_context_config()  # original config, kept for parity with intent
        updated = update_context_config(summary_trigger_tokens=30000)
        # The returned object reflects the update.
        assert updated.summary_trigger_tokens == 30000
class TestContextConfigPresets:
    """Test different configuration presets for common scenarios."""

    def test_minimal_context_config(self):
        """Minimal preset: management on, all summarization features off."""
        cfg = ContextConfig(
            enabled=True,
            enable_summarization=False,
            enable_conversation_summary=False,
            max_tool_output_length=500,
        )
        assert cfg.enabled is True
        assert cfg.enable_summarization is False
        assert cfg.enable_conversation_summary is False

    def test_aggressive_compression_config(self):
        """Aggressive preset: low thresholds across the board."""
        cfg = ContextConfig(
            enabled=True,
            summary_trigger_tokens=6000,
            keep_recent_messages=5,
            max_tool_output_length=1000,
            max_sql_result_rows=10,
            max_code_output_lines=20,
            enable_summarization=True,
        )
        assert cfg.summary_trigger_tokens == 6000
        assert cfg.keep_recent_messages == 5
        assert cfg.max_tool_output_length == 1000
        assert cfg.max_sql_result_rows == 10
        assert cfg.max_code_output_lines == 20

    def test_development_debug_config(self):
        """Debug preset: generous limits, errors always kept in full."""
        cfg = ContextConfig(
            enabled=True,
            summary_trigger_tokens=40000,
            keep_recent_messages=20,
            max_tool_output_length=10000,  # don't trim much while debugging
            preserve_tool_errors=True,  # always preserve error output
        )
        assert cfg.preserve_tool_errors is True

    def test_production_optimized_config(self):
        """Production preset: balanced limits with summarization enabled."""
        cfg = ContextConfig(
            enabled=True,
            summary_trigger_tokens=10000,
            keep_recent_messages=8,
            max_tool_output_length=1500,
            max_sql_result_rows=15,
            enable_summarization=True,
            preserve_tool_errors=True,
        )
        assert cfg.summary_trigger_tokens == 10000
        assert cfg.enable_summarization is True
        assert cfg.preserve_tool_errors is True
class TestContextConfigEdgeCases:
    """Test edge cases and boundary conditions for context configuration."""

    def test_zero_values(self):
        """Test configuration with zero values."""
        config = ContextConfig(
            summary_trigger_tokens=0,
            keep_recent_messages=0,
            max_tool_output_length=0,
        )
        # Should accept zero values (might cause issues in practice)
        assert config.keep_recent_messages == 0
        assert config.summary_trigger_tokens == 0
        assert config.max_tool_output_length == 0

    def test_very_large_values(self):
        """Test configuration with very large values."""
        config = ContextConfig(
            summary_trigger_tokens=900000,
            keep_recent_messages=1000,
            max_tool_output_length=100000,
        )
        # Pin all three configured fields, not just one.
        assert config.keep_recent_messages == 1000
        assert config.summary_trigger_tokens == 900000
        assert config.max_tool_output_length == 100000

    def test_inconsistent_token_limits(self):
        """Test configuration where the summary trigger far exceeds any practical budget.

        The config accepts such values without validation; downstream logic
        must cope with an effectively unreachable trigger.
        """
        # This test previously had an empty body and silently passed.
        config = ContextConfig(summary_trigger_tokens=10_000_000)
        assert config.summary_trigger_tokens == 10_000_000

    def test_all_features_disabled(self):
        """Test configuration with all features disabled."""
        config = ContextConfig(
            enabled=False,
            enable_summarization=False,
            enable_conversation_summary=False,
        )
        assert config.enabled is False
        assert config.enable_summarization is False
        assert config.enable_conversation_summary is False

    def test_config_serialization(self):
        """Test that config can be converted to/from dict (if needed)."""
        config = ContextConfig(summary_trigger_tokens=15000, enable_summarization=True)
        # Project the fields into a plain dict representation.
        config_dict = {
            "summary_trigger_tokens": config.summary_trigger_tokens,
            "enable_summarization": config.enable_summarization,
            "enabled": config.enabled,
        }
        assert config_dict["summary_trigger_tokens"] == 15000
        assert config_dict["enable_summarization"] is True
        assert config_dict["enabled"] is True

    def test_config_immutability_simulation(self):
        """Test that config behaves consistently across operations."""
        config1 = ContextConfig(summary_trigger_tokens=10000)
        config2 = ContextConfig(summary_trigger_tokens=10000)
        # Two configs built from the same values agree field-for-field.
        assert config1.summary_trigger_tokens == config2.summary_trigger_tokens
        assert config1.enabled == config2.enabled

    def test_realistic_configuration_scenarios(self):
        """Test realistic configuration scenarios."""
        # Small dataset scenario: plain defaults.
        small_config = ContextConfig()
        # Interactive analysis scenario (unused large_config removed).
        interactive_config = ContextConfig(
            keep_recent_messages=50,  # Keep more context for back-and-forth
            preserve_tool_errors=True,
            max_tool_output_length=5000,  # Don't trim too aggressively
        )
        assert interactive_config.keep_recent_messages > small_config.keep_recent_messages
        assert interactive_config.preserve_tool_errors is True
================================================
FILE: tests/context_management/test_context_manager.py
================================================
"""Unit tests for ContextManager class."""
from unittest.mock import Mock, patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from openchatbi.context_config import ContextConfig
from openchatbi.context_manager import ContextManager
class TestContextManager:
    """Test cases for ContextManager class.

    Covers token estimation, tool-output trimming (generic / SQL / code /
    error outputs), conversation summarization, and end-to-end
    manage_context_messages behavior. The manager mutates the message list
    in place, so tests re-read the same list after each call.
    """

    @pytest.fixture
    def mock_llm(self):
        """Mock LLM for testing.

        Yields a bare Mock while call_llm_chat_model_with_retry (as resolved
        by openchatbi.context_manager) is patched to return a canned summary.
        """
        llm = Mock()
        # Mock response for summarization
        llm_response = AIMessage(content="This is a test summary of the conversation.")
        with patch("openchatbi.context_manager.call_llm_chat_model_with_retry", return_value=llm_response):
            yield llm

    @pytest.fixture
    def default_config(self):
        """Default context configuration for testing.

        Thresholds are deliberately tiny (900 trigger tokens, 200-char tool
        output) so small fixtures can trip the trimming/summarization paths.
        """
        return ContextConfig(
            enabled=True,
            summary_trigger_tokens=900,
            keep_recent_messages=3,
            max_tool_output_length=200,
            max_sql_result_rows=5,
            max_code_output_lines=10,
            enable_conversation_summary=True,
            enable_summarization=True,
        )

    @pytest.fixture
    def context_manager(self, mock_llm, default_config):
        """Context manager instance for testing."""
        return ContextManager(llm=mock_llm, config=default_config)

    def test_token_estimation(self, context_manager):
        """Test token estimation functionality."""
        # Test basic token estimation — the estimator uses the ~4-chars-per-token heuristic.
        short_text = "Hello world"
        assert context_manager.estimate_tokens(short_text) == len(short_text) // 4
        # Test longer text
        long_text = "This is a longer text that should have more tokens estimated."
        assert context_manager.estimate_tokens(long_text) > context_manager.estimate_tokens(short_text)

    def test_message_token_estimation(self, context_manager):
        """Test token estimation for messages."""
        messages = [
            HumanMessage(content="Hello"),
            AIMessage(content="Hi there!"),
            ToolMessage(content="Tool result", tool_call_id="123"),
        ]
        total_tokens = context_manager.estimate_message_tokens(messages)
        assert total_tokens > 0
        # Should include content tokens plus metadata overhead
        assert total_tokens > sum(len(str(msg.content)) // 4 for msg in messages)

    def test_trim_short_tool_output(self, context_manager):
        """Test trimming tool output that's already short enough."""
        # Below max_tool_output_length (200) — must come back untouched.
        short_output = "This is a short output."
        result = context_manager.trim_tool_output(short_output)
        assert result == short_output

    def test_trim_long_generic_output(self, context_manager):
        """Test trimming long generic tool output."""
        long_output = "A" * 500  # Much longer than max_tool_output_length (200)
        result = context_manager.trim_tool_output(long_output)
        assert len(result) < len(long_output)
        assert "... [Output truncated] ..." in result
        # Head and tail of the original output are both retained.
        assert result.startswith("A")
        assert result.endswith("A")

    def test_trim_sql_output(self, context_manager):
        """Test trimming SQL output with structured data."""
        sql_output = """SQL Query:
```sql
SELECT * FROM users WHERE age > 18;
```
Query Results (CSV format):
```csv
id,name,age,email
1,John,25,john@example.com
2,Jane,30,jane@example.com
3,Bob,22,bob@example.com
4,Alice,28,alice@example.com
5,Charlie,35,charlie@example.com
6,Diana,27,diana@example.com
7,Eve,31,eve@example.com
```
Visualization Created: bar chart has been automatically generated."""
        result = context_manager.trim_tool_output(sql_output)
        # Should preserve SQL query
        assert "SELECT * FROM users WHERE age > 18;" in result
        # Should preserve visualization info
        assert "Visualization Created:" in result
        # Should trim CSV data (7 rows > max_sql_result_rows=5) but keep structure
        assert "```csv" in result
        assert "rows omitted" in result

    def test_trim_code_output(self, context_manager):
        """Test trimming Python code execution output."""
        # Test long output without errors — 50 lines > max_code_output_lines (10).
        long_code_output = "\n".join([f"Line {i}: Some output here" for i in range(50)])
        result = context_manager.trim_tool_output(long_code_output)
        assert len(result.split("\n")) < 50
        assert "... [Output truncated] ..." in result

    def test_preserve_error_output(self, context_manager):
        """Test that error outputs are preserved when configured."""
        error_output = """Traceback (most recent call last):
File "test.py", line 1, in
print(undefined_variable)
NameError: name 'undefined_variable' is not defined"""
        # With preserve_tool_errors=True (default in test config)
        result = context_manager.trim_tool_output(error_output)
        assert result == error_output  # Should be preserved in full
        # Test with preserve_tool_errors=False
        context_manager.config.preserve_tool_errors = False
        result = context_manager.trim_tool_output(error_output)
        # Should still preserve because it's an error, but could be trimmed based on length
        # NOTE(review): no assertion follows — the False branch is exercised but unchecked.

    # Tool output trimming disable test removed - trimming is always enabled now

    def test_conversation_summary_disabled(self, context_manager):
        """Test conversation summary when disabled."""
        context_manager.config.enable_conversation_summary = False
        messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
        summary = context_manager.summarize_conversation(messages)
        # Disabled summarization yields an empty summary string.
        assert summary == ""

    @patch("openchatbi.context_manager.call_llm_chat_model_with_retry")
    def test_conversation_summary_success(self, mock_llm_call, context_manager):
        """Test successful conversation summarization."""
        # Mock successful LLM response
        mock_response = AIMessage(content="Summary: User asked about data analysis.")
        mock_llm_call.return_value = mock_response
        messages = [
            HumanMessage(content="Can you analyze our sales data?"),
            AIMessage(content="I'll help you analyze the sales data."),
            ToolMessage(content="Query results: 100 records", tool_call_id="123"),
            HumanMessage(content="What are the trends?"),
            AIMessage(content="The trends show increasing sales."),
            HumanMessage(content="Recent question"),  # This should be excluded from summary
        ]
        summary = context_manager.summarize_conversation(messages)
        # Summary is prefixed with a fixed marker and wraps the LLM's text.
        assert summary.startswith("[Conversation Summary]:")
        assert "Summary: User asked about data analysis." in summary
        mock_llm_call.assert_called_once()

    @patch("openchatbi.context_manager.call_llm_chat_model_with_retry")
    def test_conversation_summary_failure(self, mock_llm_call, context_manager):
        """Test conversation summary when LLM call fails."""
        # Mock LLM failure
        mock_llm_call.side_effect = Exception("LLM service unavailable")
        # Need more messages than keep_recent_messages (3) to trigger summarization
        messages = [
            HumanMessage(content="First message"),
            AIMessage(content="First response"),
            HumanMessage(content="Second message"),
            AIMessage(content="Second response"),
            HumanMessage(content="Third message"),
            AIMessage(content="Third response"),
        ]
        summary = context_manager.summarize_conversation(messages)
        # Failure path returns a fixed sentinel string instead of raising.
        assert summary == "[Summary generation failed]"

    def test_manage_context_disabled(self, context_manager):
        """Test context management when disabled."""
        context_manager.config.enabled = False
        messages = [HumanMessage(content="Test")]
        context_manager.manage_context_messages(messages)
        result = messages
        assert result == messages  # Should return unchanged

    def test_manage_context_empty_messages(self, context_manager):
        """Test context management with empty message list."""
        messages = []
        context_manager.manage_context_messages(messages)
        result = messages
        assert result == []

    def test_manage_context_tool_message_trimming(self, context_manager):
        """Test that tool messages are trimmed during context management."""
        long_content = "A" * 500
        # Add enough messages to trigger context management, with ToolMessage in historical part
        # keep_recent_messages=3, so we need more than 3 messages after the ToolMessage
        messages = [
            HumanMessage(content="This is a long question that helps reach the token threshold " * 10),
            ToolMessage(content=long_content, tool_call_id="123"),  # This should be in historical part
            AIMessage(content="This is a long response that helps reach the token threshold " * 10),
            HumanMessage(content="Another long question to increase token count " * 10),
            AIMessage(content="Response " * 20),
            HumanMessage(content="Final question"),
            AIMessage(content="Final response"),
        ]
        context_manager.manage_context_messages(messages)
        result = messages
        # Find the tool message in results
        tool_msg = next(msg for msg in result if isinstance(msg, ToolMessage))
        # Historical tool output must be shortened and carry the truncation marker.
        assert len(str(tool_msg.content)) < len(long_content)
        assert "... [Output truncated] ..." in str(tool_msg.content)

    @patch("openchatbi.context_manager.call_llm_chat_model_with_retry")
    def test_manage_context_with_summarization(self, mock_llm_call, context_manager):
        """Test context management triggering summarization."""
        # Mock successful summarization
        mock_response = AIMessage(content="Conversation summary here.")
        mock_llm_call.return_value = mock_response
        # Create many messages to trigger summarization
        messages = []
        for i in range(10):
            messages.extend(
                [
                    HumanMessage(content=f"Question {i}"),
                    AIMessage(content=f"Response {i}" * 100),  # Long responses to increase token count
                ]
            )
        original_length = len(messages)
        context_manager.manage_context_messages(messages)
        result = messages
        # Should have fewer messages than input
        assert len(result) < original_length
        # Should contain summary message
        assert any("[Conversation Summary]:" in str(msg.content) for msg in result if hasattr(msg, "content"))
        # Verify LLM was called for summarization
        mock_llm_call.assert_called_once()

    # Tool wrapper tests removed - we now handle context at state level

    def test_format_messages_for_summary(self, context_manager):
        """Test message formatting for summary generation."""
        messages = [
            HumanMessage(content="User question"),
            AIMessage(content="AI response"),
            ToolMessage(content="Tool result with some data", tool_call_id="123"),
            SystemMessage(content="System message"),  # Should be excluded
        ]
        formatted = context_manager._format_messages_for_summary(messages)
        # NOTE(review): role markers appear to have been lost in extraction here —
        # the '""' membership check below is vacuously true; confirm the original
        # markup (e.g. role tags) against the repository source.
        assert " User question " in formatted
        assert "" in formatted and "AI response" in formatted
        assert "tool_result" in formatted
        assert "System message" not in formatted  # System messages excluded

    def test_format_long_ai_message_for_summary(self, context_manager):
        """Test that long AI messages are truncated in summary formatting."""
        long_content = "A" * 1000
        messages = [AIMessage(content=long_content)]
        formatted = context_manager._format_messages_for_summary(messages)
        # Truncated formatting must be shorter than the untruncated rendering.
        assert len(formatted) < len(f"Assistant: {long_content}")
        assert "... [truncated]" in formatted
# Tool wrapping tests removed - we now handle context at state level instead of wrapping tools
# Pytest fixtures and test data
# Module-level fixture: canned SQL-tool output mirroring the real structure
# (SQL block, CSV results block, then a visualization notice).
@pytest.fixture
def sample_sql_output():
    """Sample SQL output for testing."""
    return """SQL Query:
```sql
SELECT customer_id, SUM(amount) as total
FROM orders
WHERE order_date >= '2023-01-01'
GROUP BY customer_id
ORDER BY total DESC;
```
Query Results (CSV format):
```csv
customer_id,total
1001,15420.50
1002,12350.75
1003,11200.00
1004,9875.25
1005,8650.00
1006,7500.50
1007,6200.75
1008,5800.00
1009,4950.25
1010,4200.00
```
Visualization Created: bar chart has been automatically generated and will be displayed in the UI."""
# Module-level fixture: a realistic Python traceback plus error summary, used
# to exercise the preserve-errors path of tool-output trimming.
@pytest.fixture
def sample_error_output():
    """Sample error output for testing."""
    return """Traceback (most recent call last):
File "/app/code.py", line 15, in analyze_data
result = df.groupby('nonexistent_column').sum()
File "/usr/local/lib/python3.9/site-packages/pandas/core/groupby/groupby.py", line 1647, in sum
return self._cython_transform("sum", numeric_only=numeric_only, **kwargs)
KeyError: 'nonexistent_column'
Error: Column 'nonexistent_column' not found in DataFrame. Available columns: ['customer_id', 'order_date', 'amount', 'product_id']"""
================================================
FILE: tests/context_management/test_edge_cases.py
================================================
"""Edge cases for context management."""
from unittest.mock import Mock, patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from openchatbi.context_config import ContextConfig
from openchatbi.context_manager import ContextManager
class TestContextManagementEdgeCases:
"""Edge cases and boundary conditions for context management."""
@pytest.fixture
def edge_case_config(self):
    """Tight limits so small fixtures trip the management paths."""
    settings = dict(
        enabled=True,
        summary_trigger_tokens=800,
        keep_recent_messages=2,
        max_tool_output_length=50,
    )
    return ContextConfig(**settings)
@pytest.fixture
def context_manager(self, edge_case_config):
    """A ContextManager wired to a mock LLM and the edge-case config."""
    manager = ContextManager(llm=Mock(), config=edge_case_config)
    return manager
def test_empty_and_none_inputs(self, context_manager):
    """Empty lists and None elements are handled gracefully."""
    # An empty list passes through untouched.
    empty = []
    context_manager.manage_context_messages(empty)
    assert empty == []
    # None elements are filtered out by the caller before management.
    raw = [HumanMessage(content="Test"), None, AIMessage(content="Response")]
    cleaned = [msg for msg in raw if msg is not None]
    context_manager.manage_context_messages(cleaned)
    # Both real messages survive.
    assert len(cleaned) == 2
def test_malformed_messages(self, context_manager):
    """Messages with None content either pass through or fail with a clear error."""
    try:
        history = [HumanMessage(content=None)]
        context_manager.manage_context_messages(history)
        # Accepting the message is fine as long as the list survives.
        assert isinstance(history, list)
    except Exception as err:
        # A validation error mentioning the content field is also acceptable.
        assert "content" in str(err).lower()
def test_extremely_long_single_message(self, context_manager):
    """A single message far beyond the context limit is kept intact."""
    huge_message = HumanMessage(content="A" * 100000)  # well past any configured limit
    history = [huge_message]
    context_manager.manage_context_messages(history)
    # Context management never trims the content of an individual human message.
    assert len(history) == 1
    assert isinstance(history[0], HumanMessage)
def test_tool_message_without_tool_call_id(self, context_manager):
    """Tool messages with an empty tool_call_id are tolerated or rejected cleanly."""
    try:
        history = [ToolMessage(content="Result", tool_call_id="")]
        context_manager.manage_context_messages(history)
        assert isinstance(history, list)
    except Exception:
        # LangChain may validate tool_call_id and raise; that is acceptable.
        pass
def test_circular_references_in_content(self, context_manager):
    """Test handling of complex content that might cause issues."""
    # Content with special characters and formatting.
    # NOTE(review): the "HTML:" line appears to have lost its tags during
    # extraction — confirm the literal against the repository source.
    special_content = (
        """
Content with:
- Unicode: 🚀 中文 العربية
- Code blocks: ```python\nprint("hello")\n```
- JSON: {"key": "value", "nested": {"array": [1,2,3]}}
- HTML: content
- URLs: https://example.com/path?param=value
- Very long line: """
        + "X" * 1000
    )
    message = HumanMessage(content=special_content)
    messages = [message]
    context_manager.manage_context_messages(messages)
    result = messages
    # The message should survive management, unchanged in count and type.
    assert len(result) == 1
    assert isinstance(result[0], HumanMessage)
def test_zero_configuration_values(self):
    """All-zero limits must not crash the manager."""
    zeroed = ContextConfig(
        enabled=True,
        summary_trigger_tokens=0,
        keep_recent_messages=0,
        max_tool_output_length=0,
    )
    manager = ContextManager(llm=Mock(), config=zeroed)
    history = [HumanMessage(content="Test")]
    # Should handle zero values gracefully.
    manager.manage_context_messages(history)
    assert isinstance(history, list)
def test_negative_configuration_values(self):
    """Negative limits must not crash the manager (may behave as disabled)."""
    negatives = ContextConfig(
        enabled=True,
        summary_trigger_tokens=-50,
        keep_recent_messages=-5,
        max_tool_output_length=-10,
    )
    manager = ContextManager(llm=Mock(), config=negatives)
    history = [HumanMessage(content="Test")]
    # Should handle negative values gracefully.
    manager.manage_context_messages(history)
    assert isinstance(history, list)
def test_unicode_and_encoding_edge_cases(self, context_manager):
    """Unicode content in several scripts and emoji survives management."""
    history = [
        HumanMessage(content="English text"),
        HumanMessage(content="中文内容测试"),
        HumanMessage(content="العربية"),
        HumanMessage(content="Русский текст"),
        HumanMessage(content="🚀🎉💡🔥"),  # emoji-only content
        HumanMessage(content="Mixed: Hello 世界 🌍"),
        ToolMessage(content="Unicode tool result: café naïve résumé", tool_call_id="unicode_1"),
    ]
    context_manager.manage_context_messages(history)
    # All Unicode content is handled; only valid message types remain.
    assert len(history) > 0
    assert all(isinstance(msg, (HumanMessage, AIMessage, ToolMessage)) for msg in history)
def test_extremely_nested_or_complex_structures(self, context_manager):
    """Oversized nested tool outputs in the historical region get trimmed."""
    # Build a very large tool output from a deeply nested structure.
    payload = {"level1": {"level2": {"level3": {"data": ["item"] * 1000}}}}
    big_output = str(payload) * 100
    # With keep_recent_messages=2, everything before the last two messages
    # is historical — so the tool message below sits in the historical part.
    history = [
        ToolMessage(content=big_output, tool_call_id="complex_1"),  # Historical part
        HumanMessage(content="Question 1"),
        AIMessage(content="Response 1"),
        HumanMessage(content="Recent question"),  # Recent part starts here
    ]
    context_manager.manage_context_messages(history)
    # The historical tool output should have been shortened.
    trimmed = next(msg for msg in history if isinstance(msg, ToolMessage))
    assert len(str(trimmed.content)) < len(big_output)
def test_sql_output_edge_cases(self, context_manager):
    """SQL output trimming must cope with empty, tiny, and malformed results."""
    # SQL with no result rows
    empty_sql_output = """SQL Query:
```sql
SELECT * FROM users WHERE id = -1;
```
Query Results (CSV format):
```csv
id,name
```"""
    # SQL with a single-value result
    single_row_sql = """SQL Query:
```sql
SELECT COUNT(*) as total FROM users;
```
Query Results (CSV format):
```csv
total
42
```"""
    # Output that only superficially resembles a SQL result
    malformed_sql = """Something that looks like SQL but isn't:
```sql
INVALID QUERY HERE
```
Random text after"""
    for sql_output in (empty_sql_output, single_row_sql, malformed_sql):
        batch = [ToolMessage(content=sql_output, tool_call_id="sql_test")]
        context_manager.manage_context_messages(batch)
        # Every variant must survive management as a single ToolMessage.
        assert len(batch) == 1
        assert isinstance(batch[0], ToolMessage)
def test_conversation_state_consistency(self, context_manager):
    """Managed conversations keep valid message types and tool-call pairing."""
    # A conversation mixing plain turns with tool-call turns
    # (no SystemMessage belongs in graph state).
    convo = [
        HumanMessage(content="Question 1"),
        AIMessage(content="Response 1"),
        ToolMessage(content="Tool result 1", tool_call_id="tool_1"),
        HumanMessage(content="Question 2"),
        AIMessage(
            content="Response 2 with tool calls",
            tool_calls=[{"name": "test_tool", "args": {"param": "value"}, "id": "call_1"}],
        ),
        ToolMessage(content="Tool result 2", tool_call_id="call_1"),
        HumanMessage(content="Final question"),
    ]
    # Stub out the summarization LLM call so no real model is needed.
    with patch(
        "openchatbi.context_manager.call_llm_chat_model_with_retry", return_value=AIMessage(content="Summary")
    ):
        context_manager.manage_context_messages(convo)
    # Only state-valid message types may remain after management.
    allowed = {HumanMessage, AIMessage, ToolMessage}
    assert all(type(msg) in allowed for msg in convo), "Should only contain valid state message types"
    # No orphaned tool messages: each must be preceded by some AIMessage.
    for position, msg in enumerate(convo):
        if isinstance(msg, ToolMessage):
            preceding_ai = [m for m in convo[:position] if isinstance(m, AIMessage)]
            assert len(preceding_ai) > 0, "Tool message should have corresponding AI message"
================================================
FILE: tests/context_management/test_runner.py
================================================
"""Test runner script for context management tests."""
import argparse
import subprocess
import sys
from pathlib import Path
def run_tests(test_type="all", verbose=False, coverage=False):
    """Run context management tests.

    Args:
        test_type: Type of tests to run ('all', 'unit', 'integration', 'edge_cases')
        verbose: Enable verbose output
        coverage: Enable coverage reporting

    Returns:
        bool: True if pytest exited with code 0; False for a failing run,
        an unknown ``test_type``, or a keyboard interrupt.
    """
    # Use the interpreter running this script rather than whatever "python"
    # resolves to on PATH, so tests run in the same environment/venv.
    cmd = [sys.executable, "-m", "pytest"]
    # Test directory: this file's own package directory.
    test_dir = Path(__file__).parent
    # Select test files based on the requested suite.
    if test_type == "all":
        cmd.append(str(test_dir))
    elif test_type == "unit":
        cmd.extend([str(test_dir / "test_context_manager.py"), str(test_dir / "test_context_config.py")])
    elif test_type == "integration":
        cmd.append(str(test_dir / "test_agent_graph_integration.py"))
    elif test_type == "edge_cases":
        cmd.extend([str(test_dir / "test_edge_cases.py"), str(test_dir / "test_state_operations.py")])
    else:
        print(f"Unknown test type: {test_type}")
        return False
    # Add verbose flag
    if verbose:
        cmd.append("-v")
    # Add coverage reporting for the context-management modules
    if coverage:
        cmd.extend(
            [
                "--cov=openchatbi.context_manager",
                "--cov=openchatbi.context_config",
                "--cov-report=html",
                "--cov-report=term-missing",
            ]
        )
    # Add other useful flags
    cmd.extend(
        [
            "--tb=short",  # Shorter traceback format
            "-x",  # Stop on first failure
            "--strict-markers",  # Strict marker checking
        ]
    )
    print(f"Running command: {' '.join(cmd)}")
    print("-" * 50)
    # Run the tests; check=False so a failing suite doesn't raise here.
    try:
        result = subprocess.run(cmd, check=False)
        return result.returncode == 0
    except KeyboardInterrupt:
        print("\nTests interrupted by user")
        return False
def main():
    """Command-line entry point: parse args, run the suite, exit with status."""
    parser = argparse.ArgumentParser(description="Run context management tests")
    parser.add_argument(
        "--type",
        "-t",
        choices=["all", "unit", "integration", "edge_cases"],
        default="all",
        help="Type of tests to run (default: all)",
    )
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output")
    parser.add_argument("--coverage", "-c", action="store_true", help="Enable coverage reporting")
    opts = parser.parse_args()
    passed = run_tests(test_type=opts.type, verbose=opts.verbose, coverage=opts.coverage)
    if passed:
        print("\n✅ All tests passed!")
    else:
        print("\n❌ Some tests failed!")
    # Exit status mirrors the suite result for CI consumption.
    sys.exit(0 if passed else 1)
# Allow invoking this test-runner module directly from the command line.
if __name__ == "__main__":
    main()
================================================
FILE: tests/context_management/test_state_operations.py
================================================
"""Tests for message-based context management operations."""
from unittest.mock import Mock, patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from openchatbi.context_config import ContextConfig
from openchatbi.context_manager import ContextManager
class TestMessageBasedContextManagement:
    """Test message-based context management with direct modification.

    Every test calls ``ContextManager.manage_context_messages(messages)``,
    which mutates the passed-in list in place; assertions inspect the same
    list afterwards rather than a return value.
    """

    @pytest.fixture
    def test_config(self):
        """Configuration for testing message operations."""
        return ContextConfig(
            enabled=True,
            summary_trigger_tokens=300,  # Lower threshold to trigger management
            keep_recent_messages=3,
            max_tool_output_length=200,
            preserve_tool_errors=True,
            preserve_recent_sql=True,
        )

    @pytest.fixture
    def context_manager(self, test_config):
        """Context manager for testing (LLM mocked out)."""
        mock_llm = Mock()
        return ContextManager(llm=mock_llm, config=test_config)

    def test_no_operations_when_disabled(self, context_manager):
        """Test that no operations are performed when context management is disabled."""
        context_manager.config.enabled = False
        messages = [HumanMessage(content="Test", id="test_1")]
        original_messages = messages.copy()
        context_manager.manage_context_messages(messages)
        assert messages == original_messages  # Should be unchanged

    def test_no_operations_when_under_limit(self, context_manager):
        """Test that no operations are performed when context is under token limit."""
        # Short messages that won't trigger context management
        messages = [HumanMessage(content="Hi", id="human_1"), AIMessage(content="Hello", id="ai_1")]
        original_messages = messages.copy()
        context_manager.manage_context_messages(messages)
        assert messages == original_messages  # Should be unchanged

    def test_historical_tool_compression(self, context_manager):
        """Test compression of historical tool messages."""
        # Disable conversation summarization to test only tool compression
        # NOTE(review): two different flag names are set here — confirm which
        # attribute ContextConfig actually defines; the other is a no-op.
        context_manager.config.enable_conversation_summary = False
        context_manager.config.enable_summarization = False
        # Create messages with large historical tool outputs
        messages = [
            HumanMessage(content="Query data", id="human_1"),
            AIMessage(content="Running query", id="ai_1"),
            # Large historical tool message (should be compressed)
            ToolMessage(content="A" * 1000, tool_call_id="query_1", id="tool_1_historical"),  # Large content
            HumanMessage(content="More analysis", id="human_2"),
            AIMessage(content="Analyzing", id="ai_2"),
            # Another large historical tool message
            ToolMessage(content="B" * 800, tool_call_id="query_2", id="tool_2_historical"),  # Large content
            # Recent messages (should be preserved)
            HumanMessage(content="Recent question", id="human_recent"),
            AIMessage(content="Recent response", id="ai_recent"),
            ToolMessage(content="Recent result", tool_call_id="recent_1", id="tool_recent"),
        ]
        original_count = len(messages)
        context_manager.manage_context_messages(messages)
        # Should have same number of messages but some content should be compressed
        assert len(messages) == original_count
        # Check that historical tool messages are compressed
        historical_tool_msgs = [
            msg
            for msg in messages
            if isinstance(msg, ToolMessage) and msg.id in ["tool_1_historical", "tool_2_historical"]
        ]
        for msg in historical_tool_msgs:
            assert len(str(msg.content)) < 1000, "Historical tool messages should be compressed"

    def test_error_message_preservation(self, context_manager):
        """Test that error messages are preserved even if they're historical."""
        error_content = """Traceback (most recent call last):
File "test.py", line 1, in
raise ValueError("Test error")
ValueError: Test error"""
        messages = [
            HumanMessage(content="Run code", id="human_1"),
            AIMessage(content="Executing", id="ai_1"),
            # Historical error message (should be preserved)
            ToolMessage(content=error_content, tool_call_id="code_1", id="error_tool_historical"),
            # Recent messages
            HumanMessage(content="What happened?", id="human_recent"),
            AIMessage(content="There was an error", id="ai_recent"),
        ]
        original_error_content = messages[2].content
        context_manager.manage_context_messages(messages)
        # Error message should be preserved (preserve_tool_errors=True in fixture)
        error_msg = next(msg for msg in messages if msg.id == "error_tool_historical")
        assert error_msg.content == original_error_content, "Error messages should be preserved"

    def test_sql_content_preservation(self, context_manager):
        """Test that SQL content is preserved when configured."""
        sql_content = """SQL Query:
```sql
SELECT * FROM users WHERE active = 1;
```
Query Results (CSV format):
```csv
id,name,email
1,John,john@example.com
2,Jane,jane@example.com
```"""
        messages = [
            HumanMessage(content="Get user data", id="human_1"),
            AIMessage(content="Querying users", id="ai_1"),
            # Historical SQL result (should be preserved if preserve_recent_sql=True)
            ToolMessage(content=sql_content, tool_call_id="sql_1", id="sql_tool_historical"),
            # Recent messages
            HumanMessage(content="Analyze results", id="human_recent"),
            AIMessage(content="Analyzing", id="ai_recent"),
        ]
        # Test with SQL preservation enabled
        context_manager.config.preserve_recent_sql = True
        original_sql_content = messages[2].content
        context_manager.manage_context_messages(messages)
        # SQL should be preserved when preserve_recent_sql=True
        sql_msg = next(msg for msg in messages if msg.id == "sql_tool_historical")
        assert sql_msg.content == original_sql_content, "SQL content should be preserved when configured"

    @patch("openchatbi.context_manager.call_llm_chat_model_with_retry")
    def test_conversation_summarization(self, mock_llm_call, context_manager):
        """Test conversation summarization with message modification."""
        # Mock LLM response for summarization
        mock_llm_call.return_value = AIMessage(content="Summary of the conversation")
        # Create a long conversation that will trigger summarization
        # (each turn is repeated x10 to exceed summary_trigger_tokens=300)
        messages = []
        # Add many historical messages
        for i in range(20):
            messages.extend(
                [
                    HumanMessage(content=f"Question {i}" * 10, id=f"human_{i}"),
                    AIMessage(content=f"Response {i}" * 10, id=f"ai_{i}"),
                ]
            )
        # Add recent messages
        messages.extend(
            [
                HumanMessage(content="Recent question", id="human_recent"),
                AIMessage(content="Recent response", id="ai_recent"),
                ToolMessage(content="Recent result", tool_call_id="recent_1", id="tool_recent"),
            ]
        )
        original_count = len(messages)
        context_manager.manage_context_messages(messages)
        # Should have fewer messages due to summarization
        assert len(messages) < original_count
        # Should have a summary message
        summary_msgs = [msg for msg in messages if isinstance(msg, AIMessage) and "Summary" in str(msg.content)]
        assert len(summary_msgs) > 0, "Should create a summary message"

    def test_content_type_detection(self, context_manager):
        """Test content type detection methods.

        Exercises the private helpers ``_is_error_content``,
        ``_is_sql_content`` and ``_is_data_query_result`` directly.
        """
        # Test error content detection
        error_contents = [
            "Error: Something went wrong",
            "Traceback (most recent call last):\n File test.py",
            "ValueError: Invalid input",
            "Connection failed with status 500",
        ]
        for content in error_contents:
            assert context_manager._is_error_content(content), f"Should detect error in: {content[:50]}"
        # Test SQL content detection
        sql_contents = [
            "```sql\nSELECT * FROM users;\n```",
            "Query results: 100 rows returned",
            "SQL Query:\nSELECT id FROM table",
        ]
        for content in sql_contents:
            assert context_manager._is_sql_content(content), f"Should detect SQL in: {content[:50]}"
        # Test data query result detection
        data_contents = [
            "```csv\nid,name\n1,test\n```",
            "Query Results (CSV format):",
            "Found 500 records in the database",
        ]
        for content in data_contents:
            assert context_manager._is_data_query_result(content), f"Should detect data result in: {content[:50]}"

    def test_should_compress_logic(self, context_manager):
        """Test the logic for determining whether to compress historical tool messages."""
        # Short content should not be compressed
        short_msg = ToolMessage(content="Short", tool_call_id="test", id="short")
        assert not context_manager._should_compress_historical_tool_message(short_msg, "Short")
        # Long non-error content should be compressed
        long_content = "A" * 1000
        long_msg = ToolMessage(content=long_content, tool_call_id="test", id="long")
        assert context_manager._should_compress_historical_tool_message(long_msg, long_content)
        # Long error content should not be compressed (if preserve_tool_errors=True)
        error_content = "Error: " + "A" * 1000
        error_msg = ToolMessage(content=error_content, tool_call_id="test", id="error")
        context_manager.config.preserve_tool_errors = True
        assert not context_manager._should_compress_historical_tool_message(error_msg, error_content)
        # But should be compressed if preserve_tool_errors=False
        context_manager.config.preserve_tool_errors = False
        assert context_manager._should_compress_historical_tool_message(error_msg, error_content)

    def test_recent_messages_always_preserved(self, context_manager):
        """Test that recent messages are always preserved regardless of content."""
        # Create messages where recent ones are large but should still be preserved
        messages = []
        # Historical messages
        for i in range(10):
            messages.extend(
                [
                    HumanMessage(content=f"Historical {i}", id=f"hist_human_{i}"),
                    ToolMessage(content="A" * 500, tool_call_id=f"hist_{i}", id=f"hist_tool_{i}"),
                ]
            )
        # Recent messages (including large tool output)
        messages.extend(
            [
                HumanMessage(content="Recent question", id="recent_human"),
                AIMessage(content="Recent response", id="recent_ai"),
                ToolMessage(content="B" * 1000, tool_call_id="recent", id="recent_tool"),  # Large but recent
            ]
        )
        original_count = len(messages)
        context_manager.manage_context_messages(messages)
        # Recent messages should be preserved (even if content gets compressed due to summarization)
        recent_ids = ["recent_human", "recent_ai", "recent_tool"]
        remaining_recent = [msg for msg in messages if hasattr(msg, "id") and msg.id in recent_ids]
        # All recent message IDs should still be present (even if summarization occurred)
        assert len(remaining_recent) >= 2, "Most recent messages should be preserved"

    def test_message_order_preservation(self, context_manager):
        """Test that message ordering is preserved during context management."""
        # Disable conversation summarization to test only tool compression
        # NOTE(review): same dual-flag situation as test_historical_tool_compression.
        context_manager.config.enable_conversation_summary = False
        context_manager.config.enable_summarization = False
        # Create messages with specific order
        messages = [
            HumanMessage(content="Question 1", id="human_1"),
            AIMessage(content="Response 1", id="ai_1"),
            ToolMessage(content="A" * 1000, tool_call_id="tool_1", id="tool_1"),  # Will be compressed
            HumanMessage(content="Question 2", id="human_2"),
            AIMessage(content="Response 2", id="ai_2"),
            ToolMessage(content="B" * 1000, tool_call_id="tool_2", id="tool_2"),  # Will be compressed
            HumanMessage(content="Recent question", id="human_recent"),  # Recent, should not be compressed
            AIMessage(content="Recent response", id="ai_recent"),  # Recent
            ToolMessage(
                content="C" * 1000, tool_call_id="tool_recent", id="tool_recent"
            ),  # Recent, should not be compressed
        ]
        original_order = [msg.id for msg in messages if hasattr(msg, "id")]
        context_manager.manage_context_messages(messages)
        # Extract the IDs in the new order
        result_order = [msg.id for msg in messages if hasattr(msg, "id")]
        # The order should be preserved
        assert result_order == original_order, "Message order should be preserved"
        # Verify that historical tool messages were actually compressed
        historical_tools = [msg for msg in messages if isinstance(msg, ToolMessage) and msg.id in ["tool_1", "tool_2"]]
        for msg in historical_tools:
            assert len(str(msg.content)) < 1000, "Historical tool messages should be compressed"
================================================
FILE: tests/test_catalog_loader.py
================================================
"""Tests for catalog loader functionality."""
from unittest.mock import Mock, patch
import pytest
from openchatbi.catalog.catalog_loader import DataCatalogLoader, load_catalog_from_data_warehouse
class TestDataCatalogLoader:
    """Test DataCatalogLoader functionality.

    All tests patch ``openchatbi.catalog.catalog_loader.inspect`` so no real
    database connection is ever made.
    """

    @pytest.fixture
    def mock_engine(self):
        """Mock SQLAlchemy engine."""
        engine = Mock()
        return engine

    def test_catalog_loader_initialization(self, mock_engine):
        """Test DataCatalogLoader initialization."""
        with patch("openchatbi.catalog.catalog_loader.inspect") as mock_inspect:
            mock_inspect.return_value = Mock()
            loader = DataCatalogLoader(engine=mock_engine, include_tables=["table1", "table2"])
            assert loader.engine == mock_engine
            assert loader.include_tables == ["table1", "table2"]

    def test_catalog_loader_without_include_tables(self, mock_engine):
        """Test DataCatalogLoader without include tables."""
        with patch("openchatbi.catalog.catalog_loader.inspect") as mock_inspect:
            mock_inspect.return_value = Mock()
            loader = DataCatalogLoader(engine=mock_engine, include_tables=None)
            assert loader.include_tables is None

    def test_get_tables_and_columns(self, mock_engine):
        """Test getting tables and columns metadata."""
        # Mock inspector: two tables exist, one column per table.
        mock_inspector = Mock()
        mock_inspector.get_table_names.return_value = ["table1", "table2"]
        mock_inspector.get_columns.return_value = [
            {"name": "col1", "type": "VARCHAR(50)", "comment": "Test column", "default": None, "primary_key": False}
        ]
        with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector):
            loader = DataCatalogLoader(engine=mock_engine, include_tables=["table1"])
            result = loader.get_tables_and_columns()
            # Only the included table is returned, with normalized column dicts.
            assert "table1" in result
            assert len(result["table1"]) == 1
            assert result["table1"][0]["column_name"] == "col1"

    def test_get_table_indexes(self, mock_engine):
        """Test getting table indexes."""
        mock_inspector = Mock()
        mock_inspector.get_indexes.return_value = [{"name": "idx_test", "column_names": ["col1"]}]
        with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector):
            loader = DataCatalogLoader(engine=mock_engine)
            result = loader.get_table_indexes("table1")
            assert len(result) == 1
            assert result[0]["name"] == "idx_test"

    def test_get_foreign_keys(self, mock_engine):
        """Test getting foreign keys."""
        mock_inspector = Mock()
        mock_inspector.get_foreign_keys.return_value = [
            {"name": "fk_test", "constrained_columns": ["col1"], "referred_table": "ref_table"}
        ]
        with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector):
            loader = DataCatalogLoader(engine=mock_engine)
            result = loader.get_foreign_keys("table1")
            assert len(result) == 1
            assert result[0]["name"] == "fk_test"

    def test_save_to_catalog_store_success(self, mock_engine):
        """Test saving to catalog store successfully."""
        mock_catalog_store = Mock()
        mock_catalog_store.save_table_information.return_value = True
        mock_catalog_store.save_table_sql_examples.return_value = True
        mock_catalog_store.save_table_selection_examples.return_value = True
        mock_inspector = Mock()
        mock_inspector.get_table_names.return_value = ["table1"]
        mock_inspector.get_columns.return_value = [
            {"name": "col1", "type": "VARCHAR(50)", "comment": "Test column", "default": None, "primary_key": False}
        ]
        mock_inspector.get_table_comment.return_value = {"text": "Test table"}
        with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector):
            loader = DataCatalogLoader(engine=mock_engine, include_tables=["table1"])
            result = loader.save_to_catalog_store(mock_catalog_store, "test_db")
            # `is True` (not `== True`): assert an actual boolean result,
            # not merely a truthy value (PEP 8 / E712).
            assert result is True
            mock_catalog_store.save_table_information.assert_called()
            mock_catalog_store.save_table_sql_examples.assert_called()
            mock_catalog_store.save_table_selection_examples.assert_called()

    def test_save_to_catalog_store_failure(self, mock_engine):
        """Test handling catalog store save failures."""
        mock_catalog_store = Mock()
        mock_catalog_store.save_table_information.return_value = False
        mock_inspector = Mock()
        mock_inspector.get_table_names.return_value = ["table1"]
        mock_inspector.get_columns.return_value = []
        with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector):
            loader = DataCatalogLoader(engine=mock_engine)
            result = loader.save_to_catalog_store(mock_catalog_store)
            # Must be the boolean False, not just any falsy value.
            assert result is False

    def test_load_catalog_from_data_warehouse(self):
        """Test main entry point for catalog loading."""
        mock_catalog_store = Mock()
        mock_catalog_store.get_data_warehouse_config.return_value = {
            "uri": "test://user@host/db",
            "include_tables": ["table1"],
            "database_name": "test_db",
        }
        mock_catalog_store.get_sql_engine.return_value = Mock()
        with patch("openchatbi.catalog.catalog_loader.DataCatalogLoader") as mock_loader_class:
            mock_loader = Mock()
            mock_loader.save_to_catalog_store.return_value = True
            mock_loader_class.return_value = mock_loader
            result = load_catalog_from_data_warehouse(mock_catalog_store)
            assert result is True
            mock_loader.save_to_catalog_store.assert_called_once()

    def test_error_handling_in_get_tables_and_columns(self, mock_engine):
        """Test error handling in get_tables_and_columns method."""
        mock_inspector = Mock()
        mock_inspector.get_table_names.side_effect = Exception("Database error")
        with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector):
            loader = DataCatalogLoader(engine=mock_engine)
            result = loader.get_tables_and_columns()
            # Inspection failures are swallowed and an empty mapping returned.
            assert result == {}
================================================
FILE: tests/test_catalog_store.py
================================================
"""Tests for catalog store functionality."""
import pytest
from openchatbi.catalog.catalog_store import CatalogStore
from openchatbi.catalog.store.file_system import FileSystemCatalogStore
class TestCatalogStore:
    """Test base CatalogStore functionality."""

    def test_catalog_store_is_abstract(self):
        """CatalogStore is abstract and must not be instantiable directly."""
        with pytest.raises(TypeError):
            CatalogStore()

    def test_catalog_store_interface_methods(self):
        """CatalogStore declares every expected interface method."""
        for method_name in (
            "get_table_list",
            "get_column_list",
            "get_table_information",
            "get_data_warehouse_config",
            "get_sql_engine",
            "save_table_information",
        ):
            assert hasattr(CatalogStore, method_name)
class TestFileSystemCatalogStore:
    """Test FileSystemCatalogStore functionality.

    Uses the ``temp_dir`` and ``mock_catalog_store`` fixtures from conftest;
    all stores are backed by an in-memory SQLite URI so no warehouse is needed.
    """

    def test_filesystem_store_initialization(self, temp_dir):
        """Test FileSystemCatalogStore initialization."""
        data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"}
        data_path = str(temp_dir)
        store = FileSystemCatalogStore(data_path=data_path, data_warehouse_config=data_warehouse_config)
        assert store.data_path == data_path
        assert isinstance(store, CatalogStore)

    def test_get_tables_from_csv(self, mock_catalog_store):
        """Test getting tables from CSV file."""
        tables = mock_catalog_store.get_table_list()
        assert isinstance(tables, list)
        assert len(tables) >= 1

    def test_get_columns_from_csv(self, mock_catalog_store):
        """Test getting columns from CSV file."""
        columns = mock_catalog_store.get_column_list("test_table", "test")
        assert isinstance(columns, list)
        if columns:
            # Accept either naming convention for the column metadata keys.
            column = columns[0]
            assert "column_name" in column or "name" in column
            assert "data_type" in column or "type" in column

    def test_get_table_info(self, mock_catalog_store):
        """Test getting table information."""
        table_info = mock_catalog_store.get_table_information("test.test_table")
        assert isinstance(table_info, dict)

    def test_get_tables_file_not_found(self, temp_dir):
        """Test handling when tables file doesn't exist."""
        empty_dir = temp_dir / "empty"
        empty_dir.mkdir()
        data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"}
        store = FileSystemCatalogStore(data_path=str(empty_dir), data_warehouse_config=data_warehouse_config)
        # Should handle missing file gracefully
        tables = store.get_table_list()
        assert isinstance(tables, list)

    def test_get_columns_file_not_found(self, temp_dir):
        """Test handling when columns file doesn't exist."""
        empty_dir = temp_dir / "empty"
        empty_dir.mkdir()
        data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"}
        store = FileSystemCatalogStore(data_path=str(empty_dir), data_warehouse_config=data_warehouse_config)
        # Should handle missing file gracefully
        columns = store.get_column_list("nonexistent_table")
        assert isinstance(columns, list)

    def test_get_tables_malformed_csv(self, temp_dir):
        """Test handling malformed CSV files."""
        # Create malformed CSV. Use real newlines so the file actually
        # contains ragged multi-line rows; the previous "\\n" escape wrote a
        # literal backslash-n, producing a single-line file instead.
        malformed_csv = temp_dir / "table_columns.csv"
        malformed_csv.write_text("invalid,csv,format\nno,proper\nheaders")
        data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"}
        store = FileSystemCatalogStore(data_path=str(temp_dir), data_warehouse_config=data_warehouse_config)
        # Should handle malformed CSV gracefully
        tables = store.get_table_list()
        assert isinstance(tables, list)

    def test_get_tables_pandas_error(self, temp_dir):
        """Test handling pandas errors."""
        data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"}
        store = FileSystemCatalogStore(data_path=str(temp_dir), data_warehouse_config=data_warehouse_config)
        # Should handle pandas errors gracefully
        tables = store.get_table_list()
        assert isinstance(tables, list)

    def test_get_table_schema(self, mock_catalog_store):
        """Test getting complete table schema."""
        # Use get_table_information instead of get_table_schema
        schema = mock_catalog_store.get_table_information("test.test_table")
        assert isinstance(schema, dict)

    def test_search_tables(self, mock_catalog_store):
        """Test searching for tables by keyword."""
        # This method might not exist in current implementation
        # but it's a common catalog feature
        if hasattr(mock_catalog_store, "search_tables"):
            results = mock_catalog_store.search_tables("test")
            assert isinstance(results, list)

    def test_get_all_table_names(self, mock_catalog_store):
        """Test getting all table names."""
        tables = mock_catalog_store.get_table_list()
        # get_table_list() returns list of strings (table names), not dictionaries
        assert isinstance(tables, list)
        # Verify all items are strings
        for table_name in tables:
            assert isinstance(table_name, str)

    def test_case_insensitive_table_lookup(self, mock_catalog_store):
        """Test case-insensitive table lookups."""
        # Test with different cases
        test_cases = ["test_table", "TEST_TABLE", "Test_Table"]
        for table_name in test_cases:
            columns = mock_catalog_store.get_column_list(table_name)
            assert isinstance(columns, list)

    def test_data_path_validation(self):
        """Test data path validation."""
        data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"}
        # Test with None path
        with pytest.raises((ValueError, TypeError)):
            FileSystemCatalogStore(data_path=None, data_warehouse_config=data_warehouse_config)
        # Test with empty string
        with pytest.raises((ValueError, FileNotFoundError)):
            FileSystemCatalogStore(data_path="", data_warehouse_config=data_warehouse_config)

    def test_concurrent_access(self, mock_catalog_store):
        """Test concurrent access to catalog store."""
        import threading
        import time

        results = []
        errors = []

        def worker():
            """Read tables and columns; record counts or any raised error."""
            try:
                tables = mock_catalog_store.get_table_list()
                results.append(len(tables))
                time.sleep(0.01)
                columns = mock_catalog_store.get_column_list("test_table", "test")
                results.append(len(columns))
            except Exception as e:
                errors.append(e)

        # Create multiple threads
        threads = []
        for _ in range(5):
            thread = threading.Thread(target=worker)
            threads.append(thread)
        # Start all threads
        for thread in threads:
            thread.start()
        # Wait for completion
        for thread in threads:
            thread.join()
        # Should not have errors from concurrent access
        assert len(errors) == 0
        assert len(results) > 0
================================================
FILE: tests/test_config_loader.py
================================================
"""Tests for configuration loading functionality."""
from unittest.mock import MagicMock, patch
import pytest
import yaml
from openchatbi.config_loader import Config, ConfigLoader
class TestConfigLoader:
"""Test configuration loading functionality."""
def test_config_initialization(self):
    """Config keeps the exact objects it was constructed with."""
    from unittest.mock import MagicMock

    llm = MagicMock()
    embedding = MagicMock()
    cfg = Config(organization="TestOrg", dialect="presto", default_llm=llm, embedding_model=embedding)
    assert cfg.organization == "TestOrg"
    assert cfg.dialect == "presto"
    assert cfg.default_llm == llm
    assert cfg.embedding_model == embedding
def test_config_from_dict(self):
    """Config.from_dict builds an equivalent Config from a plain dict."""
    from unittest.mock import MagicMock

    llm = MagicMock()
    embedding = MagicMock()
    payload = {
        "organization": "TestOrg",
        "dialect": "mysql",
        "default_llm": llm,
        "embedding_model": embedding,
    }
    cfg = Config.from_dict(payload)
    assert cfg.organization == "TestOrg"
    assert cfg.dialect == "mysql"
    assert cfg.default_llm == llm
    assert cfg.embedding_model == embedding
def test_config_loader_initialization(self):
    """ConfigLoader can be constructed without loading anything."""
    ConfigLoader()
    # Intentionally no assertion on _config: ConfigLoader state may carry
    # over from previously executed tests in the same session.
def test_load_config_from_file(self, temp_dir):
    """Test loading configuration from YAML file.

    Writes a full config YAML (with class-path + params LLM/embedding
    specs), patches the langchain constructors so no network client is
    created, then verifies the loader instantiated and stored them.
    """
    config_data = {
        "organization": "TestOrg",
        "dialect": "presto",
        "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
        "embedding_model": {
            "class": "langchain_openai.OpenAIEmbeddings",
            "params": {"model": "text-embedding-ada-002"},
        },
        "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
    }
    config_file = temp_dir / "test_config.yaml"
    with open(config_file, "w") as f:
        yaml.dump(config_data, f)
    loader = ConfigLoader()
    # Patch the classes named in the YAML so loader.load() builds mocks
    # instead of real API clients.
    with (
        patch("langchain_openai.ChatOpenAI") as mock_openai,
        patch("langchain_openai.OpenAIEmbeddings") as mock_embeddings,
    ):
        # Create a proper mock that satisfies BaseChatModel interface
        from langchain_core.language_models import BaseChatModel

        mock_llm_instance = MagicMock(spec=BaseChatModel)
        mock_embedding_instance = MagicMock()
        mock_openai.return_value = mock_llm_instance
        mock_embeddings.return_value = mock_embedding_instance
        loader.load(str(config_file))
        config = loader.get()
        assert config.organization == "TestOrg"
        assert config.dialect == "presto"
        # The loader must hand back the exact instances the constructors returned.
        assert config.default_llm == mock_llm_instance
        assert config.embedding_model == mock_embedding_instance
def test_load_config_missing_file(self):
    """A missing config file is logged and skipped, leaving config unset."""
    loader = ConfigLoader()
    loader._config = None  # start from a clean, unloaded state
    # load() no longer raises FileNotFoundError for a bad path
    loader.load("/nonexistent/path.yaml")
    # get() should still report that nothing has been loaded
    with pytest.raises(ValueError, match="Configuration has not been loaded"):
        loader.get()
def test_load_config_invalid_yaml(self, temp_dir):
    """Malformed YAML should surface as a ValueError from load()."""
    broken_file = temp_dir / "invalid_config.yaml"
    broken_file.write_text("invalid: yaml: content: [")
    with pytest.raises(ValueError, match="Invalid YAML in configuration file"):
        ConfigLoader().load(str(broken_file))
def test_load_config_with_bi_config_file(self, temp_dir):
    """A referenced BI config file should be loaded into config.bi_config."""
    bi_payload = {"metrics": ["revenue", "users"], "dimensions": ["date", "region"]}
    bi_path = temp_dir / "bi_config.yaml"
    with open(bi_path, "w") as fh:
        yaml.dump(bi_payload, fh)

    raw_config = {
        "organization": "TestOrg",
        "bi_config_file": str(bi_path),
        "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
        "embedding_model": {
            "class": "langchain_openai.OpenAIEmbeddings",
            "params": {"model": "text-embedding-ada-002"},
        },
        "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
    }
    yaml_path = temp_dir / "test_config.yaml"
    with open(yaml_path, "w") as fh:
        yaml.dump(raw_config, fh)

    loader = ConfigLoader()
    with (
        patch("langchain_openai.ChatOpenAI") as chat_cls,
        patch("langchain_openai.OpenAIEmbeddings") as embed_cls,
    ):
        chat_cls.return_value = MagicMock()
        embed_cls.return_value = MagicMock()
        loader.load(str(yaml_path))
        loaded = loader.get()
        assert loaded.bi_config["metrics"] == ["revenue", "users"]
        assert loaded.bi_config["dimensions"] == ["date", "region"]
def test_load_config_with_catalog_store(self, temp_dir):
    """A catalog_store section should yield a usable catalog store object."""
    raw_config = {
        "organization": "TestOrg",
        "catalog_store": {"store_type": "file_system", "data_path": str(temp_dir / "catalog_data")},
        "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
        "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
        "embedding_model": {
            "class": "langchain_openai.OpenAIEmbeddings",
            "params": {"model": "text-embedding-ada-002"},
        },
    }
    yaml_path = temp_dir / "test_config.yaml"
    with open(yaml_path, "w") as fh:
        yaml.dump(raw_config, fh)

    loader = ConfigLoader()
    with (
        patch("langchain_openai.ChatOpenAI") as chat_cls,
        patch("langchain_openai.OpenAIEmbeddings") as embed_cls,
    ):
        chat_cls.return_value = MagicMock()
        embed_cls.return_value = MagicMock()
        loader.load(str(yaml_path))
        loaded = loader.get()
        # Only check that some store implementing the catalog API was built;
        # the concrete class is the factory's business.
        assert loaded.catalog_store is not None
        assert hasattr(loaded.catalog_store, "get_table_list")
def test_load_config_with_llm_configs(self, temp_dir):
    """Distinct LLM entries should be instantiated in declaration order."""
    raw_config = {
        "organization": "TestOrg",
        "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4", "temperature": 0.1}},
        "embedding_model": {
            "class": "langchain_openai.OpenAIEmbeddings",
            "params": {"model": "text-embedding-ada-002"},
        },
        "text2sql_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-3.5-turbo"}},
        "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
    }
    yaml_path = temp_dir / "test_config.yaml"
    with open(yaml_path, "w") as fh:
        yaml.dump(raw_config, fh)

    loader = ConfigLoader()
    with (
        patch("langchain_openai.ChatOpenAI") as chat_cls,
        patch("langchain_openai.OpenAIEmbeddings") as embed_cls,
    ):
        # the loader type-checks against BaseChatModel, so spec the mocks
        from langchain_core.language_models import BaseChatModel

        default_llm_mock = MagicMock(spec=BaseChatModel)
        text2sql_llm_mock = MagicMock(spec=BaseChatModel)
        embedding_mock = MagicMock()
        # ChatOpenAI is constructed twice: default_llm first, then text2sql_llm
        chat_cls.side_effect = [default_llm_mock, text2sql_llm_mock]
        embed_cls.return_value = embedding_mock
        loader.load(str(yaml_path))
        loaded = loader.get()
        assert loaded.default_llm is default_llm_mock
        assert loaded.embedding_model is embedding_mock
        assert loaded.text2sql_llm is text2sql_llm_mock
def test_load_config_with_llm_providers_selected_by_default_llm(self, temp_dir):
    """A string default_llm should select the matching llm_providers entry."""
    raw_config = {
        "organization": "TestOrg",
        "dialect": "presto",
        "default_llm": "openai",
        "llm_providers": {
            "openai": {
                "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
                "embedding_model": {
                    "class": "langchain_openai.OpenAIEmbeddings",
                    "params": {"model": "text-embedding-ada-002"},
                },
            },
            "anthropic": {
                "default_llm": {"class": "langchain_anthropic.ChatAnthropic", "params": {"model": "claude"}},
            },
        },
        "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
    }
    yaml_path = temp_dir / "test_config.yaml"
    with open(yaml_path, "w") as fh:
        yaml.dump(raw_config, fh)

    loader = ConfigLoader()
    with (
        patch("langchain_openai.ChatOpenAI") as chat_cls,
        patch("langchain_openai.OpenAIEmbeddings") as embed_cls,
        patch("langchain_anthropic.ChatAnthropic") as anthropic_cls,
    ):
        from langchain_core.language_models import BaseChatModel

        openai_llm = MagicMock(spec=BaseChatModel)
        anthropic_llm = MagicMock(spec=BaseChatModel)
        embedding_mock = MagicMock()
        chat_cls.return_value = openai_llm
        anthropic_cls.return_value = anthropic_llm
        embed_cls.return_value = embedding_mock
        loader.load(str(yaml_path))
        loaded = loader.get()
        # the "openai" provider block wins; "anthropic" stays registered but unused
        assert loaded.llm_provider == "openai"
        assert loaded.default_llm is openai_llm
        assert loaded.embedding_model is embedding_mock
        assert set(loaded.llm_providers.keys()) == {"openai", "anthropic"}
def test_set_config(self):
    """set() should accept a plain dict and populate the config like load()."""
    source = {
        "organization": "SetOrg",
        "dialect": "postgresql",
        "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
        "embedding_model": {
            "class": "langchain_openai.OpenAIEmbeddings",
            "params": {"model": "text-embedding-ada-002"},
        },
        "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
    }
    loader = ConfigLoader()
    with (
        patch("langchain_openai.ChatOpenAI") as chat_cls,
        patch("langchain_openai.OpenAIEmbeddings") as embed_cls,
    ):
        chat_cls.return_value = MagicMock()
        embed_cls.return_value = MagicMock()
        loader.set(source)
        loaded = loader.get()
        assert loaded.organization == "SetOrg"
        assert loaded.dialect == "postgresql"
def test_get_config_not_loaded(self):
    """get() before any load should raise a descriptive ValueError."""
    loader = ConfigLoader()
    loader._config = None  # force the unloaded state
    with pytest.raises(ValueError, match="Configuration has not been loaded"):
        loader.get()
def test_load_bi_config_missing_file(self, temp_dir):
    """A missing BI config file should quietly yield an empty dict."""
    missing_path = temp_dir / "nonexistent_bi.yaml"
    # no exception expected — absence is treated as "no BI config"
    assert ConfigLoader().load_bi_config(str(missing_path)) == {}
def test_catalog_store_missing_store_type(self, temp_dir):
    """A catalog_store section without store_type should be rejected."""
    raw_config = {
        "organization": "TestOrg",
        "catalog_store": {
            "data_path": "/test/path"
            # store_type deliberately omitted
        },
        "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
        "embedding_model": {
            "class": "langchain_openai.OpenAIEmbeddings",
            "params": {"model": "text-embedding-ada-002"},
        },
        "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
    }
    yaml_path = temp_dir / "test_config.yaml"
    with open(yaml_path, "w") as fh:
        yaml.dump(raw_config, fh)
    with pytest.raises(ValueError, match="catalog_store must have a store_type field"):
        ConfigLoader().load(str(yaml_path))
================================================
FILE: tests/test_graph_state.py
================================================
"""Tests for graph state management."""
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from openchatbi.graph_state import AgentState, InputState, OutputState
class TestAgentState:
    """Exercises AgentState construction and dict-style access."""

    def test_agent_state_with_data(self):
        """AgentState should expose every field it was built with."""
        history = [HumanMessage(content="Test message")]
        query = "SELECT * FROM test_table;"
        next_node = "sql_generation"
        answer = "Here is your data"
        state = AgentState(messages=history, sql=query, agent_next_node=next_node, final_answer=answer)
        assert state["messages"] == history
        assert state["sql"] == query
        assert state["agent_next_node"] == next_node
        assert state["final_answer"] == answer

    def test_agent_state_message_types(self):
        """Human, AI and tool messages should all be accepted."""
        history = [
            HumanMessage(content="User question"),
            AIMessage(content="AI response"),
            ToolMessage(content="Tool result", tool_call_id="test_id"),
        ]
        state = AgentState(messages=history)
        assert len(state["messages"]) == 3
        for message, expected in zip(state["messages"], (HumanMessage, AIMessage, ToolMessage)):
            assert isinstance(message, expected)

    def test_agent_state_immutability(self):
        """Building a new state must not disturb the one it derives from."""
        first = AgentState(
            messages=[HumanMessage(content="Original")],
            sql="SELECT 1;",
            agent_next_node="original_node",
            final_answer="Original answer",
        )
        second = AgentState(
            messages=first["messages"] + [AIMessage(content="Response")],
            sql="SELECT 2;",
            agent_next_node="updated_node",
            final_answer="Updated answer",
        )
        # the source state keeps its original contents
        assert len(first["messages"]) == 1
        assert first["sql"] == "SELECT 1;"
        assert first["agent_next_node"] == "original_node"
        assert first["final_answer"] == "Original answer"
        # the derived state carries the replacements
        assert len(second["messages"]) == 2
        assert second["sql"] == "SELECT 2;"
        assert second["agent_next_node"] == "updated_node"
        assert second["final_answer"] == "Updated answer"
class TestInputState:
    """Exercises InputState construction."""

    def test_input_state_creation(self):
        """InputState should hold the messages it was given."""
        history = [HumanMessage(content="Input message")]
        assert InputState(messages=history)["messages"] == history

    def test_input_state_empty_messages(self):
        """An empty message list is preserved as-is."""
        assert InputState(messages=[])["messages"] == []
class TestOutputState:
    """Exercises OutputState construction."""

    def test_output_state_creation(self):
        """OutputState should hold the messages it was given."""
        history = [AIMessage(content="Output message")]
        assert OutputState(messages=history)["messages"] == history

    def test_output_state_with_multiple_messages(self):
        """A full conversation history survives round-tripping through the state."""
        history = [
            HumanMessage(content="Question"),
            AIMessage(content="Answer"),
            HumanMessage(content="Follow-up"),
            AIMessage(content="Final response"),
        ]
        state = OutputState(messages=history)
        assert len(state["messages"]) == 4
        assert state["messages"] == history
class TestStateIntegration:
    """Checks hand-offs between the input, agent and output state shapes."""

    def test_input_to_agent_state_conversion(self):
        """Messages from an InputState seed a fresh AgentState."""
        incoming = [HumanMessage(content="User input")]
        input_state = InputState(messages=incoming)
        agent_state = AgentState(messages=input_state["messages"], sql="", agent_next_node="", final_answer="")
        assert agent_state["messages"] == incoming
        assert agent_state["sql"] == ""

    def test_agent_to_output_state_conversion(self):
        """Only the message history flows from AgentState to OutputState."""
        conversation = [HumanMessage(content="Question"), AIMessage(content="Generated response")]
        agent_state = AgentState(
            messages=conversation,
            sql="SELECT * FROM test_table;",
            agent_next_node="output",
            final_answer="Generated response",
        )
        output_state = OutputState(messages=agent_state["messages"])
        assert output_state["messages"] == conversation

    def test_state_serialization_compatibility(self):
        """A state rebuilt from its own dict form matches the original."""
        source = AgentState(
            messages=[HumanMessage(content="Test"), AIMessage(content="Response")],
            sql="SELECT COUNT(*) FROM table1;",
            agent_next_node="final",
            final_answer="Count results",
        )
        fields = ("messages", "sql", "agent_next_node", "final_answer")
        # round-trip through a plain dict, as a serializer would
        rebuilt = AgentState(**{key: source[key] for key in fields})
        for key in fields:
            assert rebuilt[key] == source[key]
================================================
FILE: tests/test_incomplete_tool_calls.py
================================================
"""Tests for incomplete tool call recovery functionality."""
from unittest.mock import Mock
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from openchatbi.agent_graph import agent_llm_call
from openchatbi.graph_state import AgentState
from openchatbi.utils import recover_incomplete_tool_calls
class TestIncompleteToolCallRecovery:
    """Test cases for recover_incomplete_tool_calls function."""

    def test_no_messages(self):
        """An empty history yields no recovery operations."""
        assert recover_incomplete_tool_calls(AgentState(messages=[])) == []

    def test_no_tool_calls(self):
        """Plain chat messages without tool calls need no recovery."""
        history = [HumanMessage(content="Hello"), AIMessage(content="Hi there!")]
        assert recover_incomplete_tool_calls(AgentState(messages=history)) == []

    def test_complete_tool_calls(self):
        """Nothing to recover when every tool call already has a response."""
        history = [
            HumanMessage(content="Search for data"),
            AIMessage(
                content="I'll search for that data.",
                tool_calls=[{"name": "search", "args": {"query": "data"}, "id": "call_1"}],
            ),
            ToolMessage(content="Search completed", tool_call_id="call_1"),
        ]
        assert recover_incomplete_tool_calls(AgentState(messages=history)) == []

    def test_incomplete_single_tool_call(self):
        """A single unanswered tool call produces one failure ToolMessage."""
        history = [
            HumanMessage(content="Search for data"),
            AIMessage(
                content="I'll search for that data.",
                tool_calls=[{"name": "search", "args": {"query": "data"}, "id": "call_1"}],
            ),
        ]
        operations = recover_incomplete_tool_calls(AgentState(messages=history))
        assert isinstance(operations, list)
        assert len(operations) == 1  # just the recovery message
        recovery = operations[0]
        assert recovery.tool_call_id == "call_1"
        assert "interrupted" in recovery.content.lower()

    def test_incomplete_multiple_tool_calls(self):
        """Every unanswered call in the last AI message gets a failure message."""
        history = [
            HumanMessage(content="Search and analyze"),
            AIMessage(
                content="I'll search and analyze.",
                tool_calls=[
                    {"name": "search", "args": {"query": "data"}, "id": "call_1"},
                    {"name": "analyze", "args": {"data": "result"}, "id": "call_2"},
                ],
            ),
        ]
        operations = recover_incomplete_tool_calls(AgentState(messages=history))
        assert isinstance(operations, list)
        assert len(operations) == 2  # one recovery message per dangling call
        assert {op.tool_call_id for op in operations} == {"call_1", "call_2"}
        for op in operations:
            assert isinstance(op, ToolMessage)
            assert "interrupted" in op.content.lower()

    def test_partial_incomplete_tool_calls(self):
        """Mixed complete/incomplete calls: the existing response is removed and re-added."""
        history = [
            HumanMessage(content="Search and analyze"),
            AIMessage(
                content="I'll search and analyze.",
                tool_calls=[
                    {"name": "search", "args": {"query": "data"}, "id": "call_1"},
                    {"name": "analyze", "args": {"data": "result"}, "id": "call_2"},
                ],
            ),
            ToolMessage(content="Search completed", tool_call_id="call_1"),
            # no ToolMessage for call_2
        ]
        operations = recover_incomplete_tool_calls(AgentState(messages=history))
        assert isinstance(operations, list)
        # RemoveMessage + recovery message + re-added original message
        assert len(operations) == 3
        assert "RemoveMessage" in str(type(operations[0]))
        assert isinstance(operations[1], ToolMessage)
        assert isinstance(operations[2], ToolMessage)
        # the recovery message targets the dangling call_2
        assert operations[1].tool_call_id == "call_2"
        assert "interrupted" in operations[1].content.lower()
        # the original response for call_1 is appended back unchanged
        assert operations[2].tool_call_id == "call_1"
        assert operations[2].content == "Search completed"

    def test_multiple_ai_messages_with_tool_calls(self):
        """Only the most recent AIMessage with tool calls is inspected."""
        history = [
            HumanMessage(content="First task"),
            AIMessage(content="Doing first task.", tool_calls=[{"name": "task1", "args": {}, "id": "old_call"}]),
            ToolMessage(content="Task 1 done", tool_call_id="old_call"),
            HumanMessage(content="Second task"),
            AIMessage(content="Doing second task.", tool_calls=[{"name": "task2", "args": {}, "id": "new_call"}]),
            # no ToolMessage for new_call
        ]
        operations = recover_incomplete_tool_calls(AgentState(messages=history))
        assert isinstance(operations, list)
        assert len(operations) == 1  # just the recovery message
        assert operations[0].tool_call_id == "new_call"
        assert "interrupted" in operations[0].content.lower()

    def test_llm_node_integration_with_recovery(self):
        """agent_llm_call should emit recovery operations and loop back to llm_node."""
        node = agent_llm_call(Mock(), [])
        history = [
            HumanMessage(content="Search for data"),
            AIMessage(
                content="I'll search for that data.",
                tool_calls=[{"name": "search", "args": {"query": "data"}, "id": "call_1"}],
            ),
        ]
        result = node(AgentState(messages=history))
        # the node reports both the message operations and the next hop
        assert "messages" in result
        assert "agent_next_node" in result
        assert result["agent_next_node"] == "llm_node"
        operations = result["messages"]
        assert len(operations) == 1  # only the recovery message is needed
        assert isinstance(operations[0], ToolMessage)
        assert operations[0].tool_call_id == "call_1"
================================================
FILE: tests/test_memory.py
================================================
"""Tests for memory tool functionality."""
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch
import pytest
from langchain_core.language_models import FakeListChatModel
from langchain_openai import ChatOpenAI
# Check if pysqlite3 is available, if not skip these tests
pysqlite3 = pytest.importorskip("pysqlite3", reason="pysqlite3 not available")
from openchatbi.tool.memory import (
StructuredToolWithRequired,
UserProfile,
cleanup_async_memory_store,
fix_schema_for_openai,
get_async_memory_store,
get_async_memory_tools,
get_memory_manager,
get_memory_tools,
get_sync_memory_store,
setup_async_memory_store,
)
class TestUserProfile:
    """Test UserProfile model functionality."""

    def test_user_profile_basic_initialization(self):
        """All four fields should round-trip through the constructor."""
        profile = UserProfile(name="John Doe", language="English", timezone="UTC", jargon="Technical")
        assert profile.name == "John Doe"
        assert profile.language == "English"
        assert profile.timezone == "UTC"
        assert profile.jargon == "Technical"

    def test_user_profile_optional_fields(self):
        """Every field is optional and defaults to None."""
        empty = UserProfile()
        assert empty.name is None
        assert empty.language is None
        assert empty.timezone is None
        assert empty.jargon is None

    def test_user_profile_partial_initialization(self):
        """Unset fields stay None when only some values are provided."""
        partial = UserProfile(name="Jane Smith", language="Spanish")
        assert partial.name == "Jane Smith"
        assert partial.language == "Spanish"
        assert partial.timezone is None
        assert partial.jargon is None

    def test_user_profile_serialization(self):
        """model_dump should include set values and None for the rest."""
        dumped = UserProfile(name="Test User", timezone="EST").model_dump()
        assert dumped["name"] == "Test User"
        assert dumped["timezone"] == "EST"
        assert dumped["language"] is None
        assert dumped["jargon"] is None
class TestMemoryStoreManagement:
    """Test memory store management functions."""

    @pytest.fixture(autouse=True)
    def setup_test_env(self, tmp_path: Path):
        """Reset module-level store singletons and point at a temp database."""
        self.temp_db_path = tmp_path / "test_memory.db"
        import openchatbi.tool.memory as memory_module

        # wipe any globals left behind by earlier tests
        memory_module.sync_memory_store = None
        memory_module.async_memory_store = None
        memory_module.async_store_context_manager = None

    @patch("openchatbi.tool.memory.sqlite3.connect")
    @patch("openchatbi.tool.memory.config.get")
    def test_get_sync_memory_store(self, config_mock, connect_mock):
        """Sync store creation should build a SqliteStore and run its setup."""
        config_mock.return_value.embedding_model = Mock()
        connect_mock.return_value = Mock()
        with patch("openchatbi.tool.memory.SqliteStore") as store_cls:
            sqlite_store = Mock()
            store_cls.return_value = sqlite_store
            result = get_sync_memory_store()
            assert result is sqlite_store
            store_cls.assert_called_once()
            sqlite_store.setup.assert_called_once()

    @pytest.mark.asyncio
    @patch("openchatbi.tool.memory.AsyncSqliteStore.from_conn_string")
    @patch("openchatbi.tool.memory.config.get")
    async def test_get_async_memory_store(self, config_mock, from_conn_mock):
        """Async store creation should enter the async context manager."""
        config_mock.return_value.embedding_model = Mock()
        ctx = AsyncMock()
        backing_store = Mock()
        ctx.__aenter__.return_value = backing_store
        from_conn_mock.return_value = ctx
        result = await get_async_memory_store()
        assert result is backing_store
        from_conn_mock.assert_called_once()
        ctx.__aenter__.assert_called_once()

    @pytest.mark.asyncio
    @patch("openchatbi.tool.memory.async_memory_store", new=Mock())
    @patch("openchatbi.tool.memory.async_store_context_manager")
    async def test_cleanup_async_memory_store(self, ctx_mock):
        """Cleanup should exit the stored async context manager."""
        ctx_mock.__aexit__ = AsyncMock()
        await cleanup_async_memory_store()
        ctx_mock.__aexit__.assert_called_once_with(None, None, None)

    @pytest.mark.asyncio
    @patch("openchatbi.tool.memory.get_async_memory_store")
    async def test_setup_async_memory_store(self, get_store_mock):
        """Setup should obtain the async store and return nothing."""
        get_store_mock.return_value = Mock()
        result = await setup_async_memory_store()
        get_store_mock.assert_called_once()
        assert result is None
class TestMemoryTools:
    """Test memory tools creation and management."""

    @patch("openchatbi.tool.memory.create_manage_memory_tool")
    @patch("openchatbi.tool.memory.create_search_memory_tool")
    @patch("openchatbi.tool.memory.get_sync_memory_store")
    def test_get_memory_tools_sync_mode(self, get_store_mock, search_factory, manage_factory):
        """Sync mode should build manage/search tools against the sync store."""
        fake_llm = FakeListChatModel(responses=["test"])
        backing_store = Mock()
        get_store_mock.return_value = backing_store
        manage_stub, search_stub = Mock(), Mock()
        manage_factory.return_value = manage_stub
        search_factory.return_value = search_stub
        tools = get_memory_tools(fake_llm, sync_mode=True)
        assert tools[0] is manage_stub
        assert tools[1] is search_stub
        # both factories get the per-user namespace and the shared store
        manage_factory.assert_called_once_with(namespace=("memories", "{user_id}"), store=backing_store)
        search_factory.assert_called_once_with(namespace=("memories", "{user_id}"), store=backing_store)

    @patch("openchatbi.tool.memory.create_manage_memory_tool")
    @patch("openchatbi.tool.memory.create_search_memory_tool")
    @patch("openchatbi.tool.memory.config.get")
    def test_get_memory_tools_with_openai_llm(self, config_mock, search_factory, manage_factory):
        """OpenAI LLMs should have their tools wrapped with the required-args shim."""
        openai_llm = Mock(spec=ChatOpenAI)
        config_mock.return_value.embedding_model = Mock()
        manage_factory.return_value = Mock()
        search_factory.return_value = Mock()
        with patch("openchatbi.tool.memory.StructuredToolWithRequired") as wrapper_cls:
            wrapped_manage, wrapped_search = Mock(), Mock()
            wrapper_cls.side_effect = [wrapped_manage, wrapped_search]
            tools = get_memory_tools(openai_llm, sync_mode=True)
            assert tools[0] is wrapped_manage
            assert tools[1] is wrapped_search
            assert wrapper_cls.call_count == 2

    @pytest.mark.asyncio
    @patch("openchatbi.tool.memory.get_async_memory_store")
    @patch("openchatbi.tool.memory.get_memory_tools")
    @patch("openchatbi.tool.memory.config.get")
    async def test_get_async_memory_tools(self, config_mock, get_tools_mock, get_store_mock):
        """Async tool creation should delegate to get_memory_tools with the async store."""
        fake_llm = FakeListChatModel(responses=["test"])
        backing_store = Mock()
        get_store_mock.return_value = backing_store
        config_mock.return_value.embedding_model = Mock()
        manage_stub, search_stub = Mock(), Mock()
        get_tools_mock.return_value = (manage_stub, search_stub)
        manage_tool, search_tool = await get_async_memory_tools(fake_llm)
        assert manage_tool is manage_stub
        assert search_tool is search_stub
        get_store_mock.assert_called_once()
        get_tools_mock.assert_called_once_with(fake_llm, sync_mode=False, store=backing_store)
class TestMemoryManager:
    """Test memory manager functionality."""

    @patch("openchatbi.tool.memory.create_memory_store_manager")
    @patch("openchatbi.tool.memory.config.get")
    def test_get_memory_manager(self, config_mock, create_mock):
        """Manager creation should use the default LLM and the UserProfile schema."""
        default_llm = Mock()
        config_mock.return_value.default_llm = default_llm
        manager_stub = Mock()
        create_mock.return_value = manager_stub
        assert get_memory_manager() is manager_stub
        create_mock.assert_called_once_with(
            default_llm,
            schemas=[UserProfile],
            instructions="Extract user profile information",
            enable_inserts=False,
        )

    @patch("openchatbi.tool.memory.memory_manager", new=Mock())
    @patch("openchatbi.tool.memory.create_memory_store_manager")
    @patch("openchatbi.tool.memory.config.get")
    def test_get_memory_manager_singleton(self, config_mock, create_mock):
        """An already-initialized manager should be reused, not recreated."""
        import openchatbi.tool.memory as memory_module

        existing = Mock()
        memory_module.memory_manager = existing
        # the existing manager comes back untouched and no factory call happens
        assert get_memory_manager() is existing
        create_mock.assert_not_called()
class TestSchemaFixer:
    """Test schema fixing functionality for OpenAI compatibility."""

    def test_fix_schema_for_openai_basic(self):
        """All top-level properties should become required, in order."""
        schema = {
            "properties": {
                "field1": {"type": "string"},
                "field2": {"type": "number"},
            }
        }
        fix_schema_for_openai(schema)
        assert schema["required"] == ["field1", "field2"]

    def test_fix_schema_for_openai_nested_object(self):
        """Nested objects should lose additionalProperties and be marked required."""
        schema = {
            "properties": {
                "nested": {"type": "object", "additionalProperties": True, "properties": {"inner": {"type": "string"}}}
            }
        }
        fix_schema_for_openai(schema)
        assert schema["required"] == ["nested"]
        assert schema["properties"]["nested"]["additionalProperties"] is False

    def test_fix_schema_for_openai_with_arrays(self):
        """Array item schemas should also have additionalProperties disabled."""
        schema = {"properties": {"items": {"type": "array", "items": {"type": "object", "additionalProperties": True}}}}
        fix_schema_for_openai(schema)
        assert schema["required"] == ["items"]
        assert schema["properties"]["items"]["items"]["additionalProperties"] is False
class TestStructuredToolWithRequired:
    """Test StructuredToolWithRequired wrapper functionality."""

    @staticmethod
    def _fake_tool():
        """Build a stand-in original tool with the attributes the wrapper reads."""
        tool = Mock()
        tool.name = "test_tool"
        tool.description = "Test description"
        tool.args_schema = Mock()
        tool.func = Mock()
        tool.coroutine = None
        return tool

    def test_structured_tool_with_required_initialization(self):
        """The wrapper should forward name/description to StructuredTool.__init__."""
        original = self._fake_tool()
        with patch("openchatbi.tool.memory.StructuredTool.__init__", return_value=None) as init_mock:
            StructuredToolWithRequired(original)
            init_mock.assert_called_once()
            kwargs = init_mock.call_args.kwargs
            assert kwargs["name"] == "test_tool"
            assert kwargs["description"] == "Test description"

    def test_tool_call_schema_property(self):
        """tool_call_schema should return the parent schema with json_schema_extra set."""
        original = self._fake_tool()
        with patch("openchatbi.tool.memory.StructuredTool.__init__", return_value=None):
            wrapper = StructuredToolWithRequired(original)
            parent_schema = Mock()
            parent_schema.model_config = {}
            with patch("openchatbi.tool.memory.StructuredTool.tool_call_schema", new_callable=lambda: parent_schema):
                result = wrapper.tool_call_schema
                assert result is parent_schema
                assert "json_schema_extra" in parent_schema.model_config
================================================
FILE: tests/test_plotly_utils.py
================================================
"""Tests for plotly utilities in the UI."""
import plotly.graph_objects as go
import pytest
from sample_ui.plotly_utils import (
create_empty_chart,
create_plotly_chart,
visualization_dsl_to_gradio_plot,
)
@pytest.fixture
def sample_csv_data():
    """Two months of per-product, per-region sales as raw CSV text."""
    csv_text = """product,sales,region,month
Widget A,10000,North,Jan
Widget B,15000,South,Jan
Widget C,8000,East,Jan
Widget A,12000,North,Feb
Widget B,18000,South,Feb
Widget C,9000,East,Feb"""
    return csv_text
@pytest.fixture
def sample_line_dsl():
    """Visualization DSL describing a month-vs-sales line chart."""
    return dict(
        chart_type="line",
        data_columns=["month", "sales"],
        config={"x": "month", "y": "sales", "mode": "lines+markers"},
        layout={"title": "Sales Over Time", "xaxis_title": "Month", "yaxis_title": "Sales"},
    )
@pytest.fixture
def sample_bar_dsl():
    """Visualization DSL describing a region-vs-sales bar chart."""
    return dict(
        chart_type="bar",
        data_columns=["region", "sales"],
        config={"x": "region", "y": "sales"},
        layout={"title": "Sales by Region", "xaxis_title": "Region", "yaxis_title": "Sales"},
    )
@pytest.fixture
def sample_pie_dsl():
    """DSL describing a pie chart of the sales share per product."""
    return dict(
        chart_type="pie",
        data_columns=["product", "sales"],
        config=dict(labels="product", values="sales"),
        layout=dict(title="Sales Distribution by Product"),
    )
class TestPlotlyChartCreation:
    """Chart-builder tests covering every supported chart type."""

    def test_create_line_chart_success(self, sample_csv_data, sample_line_dsl):
        """A valid line DSL yields a titled figure with at least one trace."""
        figure = create_plotly_chart(sample_csv_data, sample_line_dsl)
        assert isinstance(figure, go.Figure)
        assert len(figure.data) > 0
        assert figure.layout.title.text == "Sales Over Time"

    def test_create_line_chart_with_color(self, sample_csv_data):
        """A ``color`` column splits the line chart into one series per product."""
        spec = {
            "chart_type": "line",
            "data_columns": ["month", "sales", "product"],
            "config": {"x": "month", "y": "sales", "color": "product"},
            "layout": {"title": "Sales Over Time by Product", "xaxis_title": "Month", "yaxis_title": "Sales"},
        }
        figure = create_plotly_chart(sample_csv_data, spec)
        assert isinstance(figure, go.Figure)
        assert len(figure.data) > 0  # expect one trace per product
        assert figure.layout.title.text == "Sales Over Time by Product"

    def test_create_line_chart_with_multiple_y_columns(self):
        """A list-valued ``y`` draws one trace per metric."""
        csv_text = "\n".join(
            [
                "date,revenue,profit,users",
                "2023-01-01,50000,15000,1000",
                "2023-02-01,55000,18000,1100",
                "2023-03-01,60000,20000,1200",
            ]
        )
        spec = {
            "chart_type": "line",
            "data_columns": ["date", "revenue", "profit"],
            "config": {"x": "date", "y": ["revenue", "profit"]},
            "layout": {"title": "Multiple Metrics Over Time", "xaxis_title": "Date", "yaxis_title": "Value"},
        }
        figure = create_plotly_chart(csv_text, spec)
        assert isinstance(figure, go.Figure)
        assert len(figure.data) > 0  # expect one trace per metric
        assert figure.layout.title.text == "Multiple Metrics Over Time"

    def test_create_bar_chart_success(self, sample_csv_data, sample_bar_dsl):
        """A valid bar DSL yields a titled figure with data."""
        figure = create_plotly_chart(sample_csv_data, sample_bar_dsl)
        assert isinstance(figure, go.Figure)
        assert len(figure.data) > 0
        assert figure.layout.title.text == "Sales by Region"

    def test_create_pie_chart_success(self, sample_csv_data, sample_pie_dsl):
        """A valid pie DSL yields a titled figure with data."""
        figure = create_plotly_chart(sample_csv_data, sample_pie_dsl)
        assert isinstance(figure, go.Figure)
        assert len(figure.data) > 0
        assert figure.layout.title.text == "Sales Distribution by Product"

    def test_create_scatter_chart(self, sample_csv_data):
        """Scatter charts render from an x/y/mode config."""
        spec = {
            "chart_type": "scatter",
            "data_columns": ["sales", "region"],
            "config": {"x": "sales", "y": "region", "mode": "markers"},
            "layout": {"title": "Sales Scatter Plot"},
        }
        figure = create_plotly_chart(sample_csv_data, spec)
        assert isinstance(figure, go.Figure)
        assert len(figure.data) > 0

    def test_create_histogram_chart(self, sample_csv_data):
        """Histograms honour the ``nbins`` option."""
        spec = {
            "chart_type": "histogram",
            "data_columns": ["sales"],
            "config": {"x": "sales", "nbins": 10},
            "layout": {"title": "Sales Distribution"},
        }
        figure = create_plotly_chart(sample_csv_data, spec)
        assert isinstance(figure, go.Figure)
        assert len(figure.data) > 0

    def test_create_box_chart(self, sample_csv_data):
        """Box plots render one distribution per x category."""
        spec = {
            "chart_type": "box",
            "data_columns": ["sales", "region"],
            "config": {"y": "sales", "x": "region"},
            "layout": {"title": "Sales Distribution by Region"},
        }
        figure = create_plotly_chart(sample_csv_data, spec)
        assert isinstance(figure, go.Figure)
        assert len(figure.data) > 0

    def test_create_table_chart(self, sample_csv_data):
        """Table charts produce a single trace of type ``table``."""
        spec = {
            "chart_type": "table",
            "data_columns": ["product", "sales", "region", "month"],
            "config": {"columns": ["product", "sales", "region", "month"]},
            "layout": {"title": "Data Table"},
        }
        figure = create_plotly_chart(sample_csv_data, spec)
        assert isinstance(figure, go.Figure)
        assert len(figure.data) > 0
        assert figure.data[0].type == "table"
class TestErrorHandling:
    """Chart creation must degrade gracefully on bad input — each case
    should yield a placeholder figure rather than raise."""

    def test_empty_data(self):
        """Empty CSV text still yields a figure (error placeholder)."""
        figure = create_plotly_chart("", {})
        assert isinstance(figure, go.Figure)

    def test_invalid_csv_data(self, sample_bar_dsl):
        """Malformed CSV falls back to an error figure."""
        figure = create_plotly_chart("invalid,csv\ndata", sample_bar_dsl)
        assert isinstance(figure, go.Figure)

    def test_missing_columns(self, sample_csv_data):
        """A DSL referencing unknown columns falls back to an error figure."""
        bad_spec = {
            "chart_type": "line",
            "data_columns": ["nonexistent_col"],
            "config": {"x": "nonexistent_col", "y": "another_missing_col"},
            "layout": {"title": "Invalid Chart"},
        }
        figure = create_plotly_chart(sample_csv_data, bad_spec)
        assert isinstance(figure, go.Figure)

    def test_unsupported_chart_type(self, sample_csv_data):
        """Unknown chart types fall back to an error figure."""
        bad_spec = {"chart_type": "unsupported_type", "data_columns": ["sales"], "config": {}, "layout": {}}
        figure = create_plotly_chart(sample_csv_data, bad_spec)
        assert isinstance(figure, go.Figure)

    def test_visualization_dsl_error(self):
        """A DSL carrying an ``error`` field yields an error figure."""
        figure = create_plotly_chart("some,data\n1,2", {"error": "Failed to generate visualization"})
        assert isinstance(figure, go.Figure)
class TestVisualizationDslToGradioPlot:
    """Tests for the top-level DSL → Gradio plot entry point."""

    def test_successful_conversion(self, sample_csv_data, sample_line_dsl):
        """A valid line DSL returns a figure plus a matching description."""
        figure, description = visualization_dsl_to_gradio_plot(sample_csv_data, sample_line_dsl)
        assert isinstance(figure, go.Figure)
        assert isinstance(description, str)
        assert "line" in description.lower()
        assert "Sales Over Time" in description

    def test_empty_dsl(self, sample_csv_data):
        """An empty DSL falls back to a table rendering."""
        figure, description = visualization_dsl_to_gradio_plot(sample_csv_data, {})
        assert isinstance(figure, go.Figure)
        assert isinstance(description, str)
        assert "table" in description.lower()

    def test_no_data(self, sample_line_dsl):
        """Missing data still produces a figure and a description."""
        figure, description = visualization_dsl_to_gradio_plot("", sample_line_dsl)
        assert isinstance(figure, go.Figure)
        assert isinstance(description, str)
class TestCreateEmptyChart:
    """Tests for the error-placeholder chart."""

    def test_create_empty_chart(self):
        """The placeholder titles itself and embeds the message as an annotation."""
        note = "Test error message"
        figure = create_empty_chart(note)
        assert isinstance(figure, go.Figure)
        assert figure.layout.title.text == "Chart Generation Issue"
        annotations = figure.layout.annotations
        assert len(annotations) > 0
        assert annotations[0].text == note
@pytest.fixture
def sample_time_series_data():
    """Five monthly revenue/users observations, rendered as raw CSV text."""
    rows = [
        "date,revenue,users",
        "2023-01-01,50000,1000",
        "2023-02-01,55000,1100",
        "2023-03-01,60000,1200",
        "2023-04-01,52000,1050",
        "2023-05-01,58000,1150",
    ]
    return "\n".join(rows)
class TestIntegrationScenarios:
    """End-to-end visualization scenarios combining several chart types."""

    def test_sales_dashboard_scenario(self, sample_csv_data):
        """Bar and pie charts render from the same dataset."""
        specs = [
            {"chart_type": "bar", "config": {"x": "product", "y": "sales"}, "layout": {"title": "Sales by Product"}},
            {
                "chart_type": "pie",
                "config": {"labels": "region", "values": "sales"},
                "layout": {"title": "Sales by Region"},
            },
        ]
        for spec in specs:
            # Derive the column list directly from the config values.
            spec["data_columns"] = list(spec["config"].values())
            figure = create_plotly_chart(sample_csv_data, spec)
            assert isinstance(figure, go.Figure)
            assert len(figure.data) > 0

    def test_time_series_scenario(self, sample_time_series_data):
        """A time-series line chart reports its type and title in the description."""
        spec = {
            "chart_type": "line",
            "data_columns": ["date", "revenue"],
            "config": {"x": "date", "y": "revenue", "mode": "lines+markers"},
            "layout": {"title": "Revenue Trend Over Time", "xaxis_title": "Date", "yaxis_title": "Revenue"},
        }
        figure, description = visualization_dsl_to_gradio_plot(sample_time_series_data, spec)
        assert isinstance(figure, go.Figure)
        assert "line" in description.lower()
        assert "Revenue Trend Over Time" in description

    def test_multiple_metrics_scenario(self, sample_time_series_data):
        """A scatter chart correlates two metrics from the same series."""
        spec = {
            "chart_type": "scatter",
            "data_columns": ["revenue", "users"],
            "config": {"x": "revenue", "y": "users", "mode": "markers"},
            "layout": {"title": "Revenue vs Users Correlation", "xaxis_title": "Revenue", "yaxis_title": "Users"},
        }
        figure = create_plotly_chart(sample_time_series_data, spec)
        assert isinstance(figure, go.Figure)
        assert len(figure.data) > 0
        assert figure.layout.title.text == "Revenue vs Users Correlation"
================================================
FILE: tests/test_simple_store.py
================================================
"""Unit tests for SimpleStore."""
import pytest
from openchatbi.utils import SimpleStore
class TestSimpleStore:
    """Test suite for SimpleStore class.

    Exercises the BM25-backed in-memory vector-store replacement:
    construction, keyword search, mutation (add/delete), retriever
    adaptation, MMR search, and internal similarity scoring.
    """

    @pytest.fixture
    def sample_texts(self):
        """Sample texts for testing."""
        return [
            "Python is a programming language",
            "Machine learning is a subset of AI",
            "Deep learning uses neural networks",
            "Natural language processing works with text",
        ]

    @pytest.fixture
    def sample_metadatas(self):
        """Sample metadata for testing (one dict per sample text)."""
        return [
            {"category": "programming"},
            {"category": "ai"},
            {"category": "ai"},
            {"category": "nlp"},
        ]

    @pytest.fixture
    def simple_store(self, sample_texts):
        """Create a SimpleStore instance for testing."""
        return SimpleStore(sample_texts)

    def test_initialization_basic(self, sample_texts):
        """Test basic initialization: texts/documents populated, BM25 built."""
        store = SimpleStore(sample_texts)
        assert len(store.texts) == len(sample_texts)
        assert store.texts == sample_texts
        assert len(store.documents) == len(sample_texts)
        assert store.bm25 is not None

    def test_initialization_with_metadata_and_ids(self, sample_texts, sample_metadatas):
        """Test initialization with metadata and custom IDs."""
        ids = ["id1", "id2", "id3", "id4"]
        store = SimpleStore(sample_texts, sample_metadatas, ids)
        assert store.texts == sample_texts
        assert store.metadatas == sample_metadatas
        assert store.ids == ids
        # Check documents are created correctly (positional pairing of
        # text, metadata and id must be preserved)
        for doc, text, meta, doc_id in zip(store.documents, sample_texts, sample_metadatas, ids):
            assert doc.page_content == text
            assert doc.metadata == meta
            assert doc.id == doc_id

    def test_similarity_search(self, simple_store):
        """Test similarity search functionality."""
        query = "programming"
        results = simple_store.similarity_search(query, k=2)
        assert len(results) == 2
        # Best match should be the only text mentioning Python/programming
        assert "Python" in results[0].page_content
        # Test k parameter: k larger than the corpus is capped
        results = simple_store.similarity_search(query, k=10)
        assert len(results) == 4  # Should return all documents

    def test_similarity_search_with_score(self, simple_store):
        """Test similarity search with scores."""
        query = "programming"
        results = simple_store.similarity_search_with_score(query, k=2)
        assert len(results) == 2
        for doc, score in results:
            assert hasattr(doc, "page_content")
            assert isinstance(score, (int, float))
            assert score >= 0
        # Scores should be in descending order
        scores = [score for _, score in results]
        assert scores == sorted(scores, reverse=True)

    def test_empty_store(self):
        """Test empty store operations: no BM25 index, searches return []."""
        store = SimpleStore([])
        assert store.bm25 is None
        assert store.similarity_search("test", k=5) == []
        assert store.similarity_search_with_score("test", k=5) == []

    def test_add_texts(self, simple_store):
        """Test adding texts with and without metadata."""
        initial_count = len(simple_store.texts)
        new_texts = ["Data science is important", "Statistics is fundamental"]
        new_metadatas = [{"type": "test"}, {"type": "example"}]
        # Add with metadata and custom IDs
        custom_ids = ["custom_1", "custom_2"]
        returned_ids = simple_store.add_texts(new_texts, metadatas=new_metadatas, ids=custom_ids)
        assert returned_ids == custom_ids
        assert len(simple_store.texts) == initial_count + len(new_texts)
        assert all(text in simple_store.texts for text in new_texts)
        # Check metadata was added correctly
        added_docs = [doc for doc in simple_store.documents if doc.id in custom_ids]
        assert len(added_docs) == 2
        assert added_docs[0].metadata == {"type": "test"}
        # Verify BM25 index is updated (newly added text is retrievable)
        results = simple_store.similarity_search("data science", k=1)
        assert "data" in results[0].page_content.lower() or "science" in results[0].page_content.lower()

    def test_delete(self):
        """Test deleting documents."""
        texts = ["Text A", "Text B", "Text C", "Text D"]
        ids = ["id1", "id2", "id3", "id4"]
        store = SimpleStore(texts, ids=ids)
        # Delete specific IDs
        result = store.delete(["id2", "id3"])
        assert result is True
        assert len(store.texts) == 2
        assert store.texts == ["Text A", "Text D"]
        assert store.ids == ["id1", "id4"]
        # Delete non-existent IDs
        result = store.delete(["nonexistent"])
        assert result is False
        # Delete with None
        result = store.delete(None)
        assert result is False
        # Delete all remaining documents; the BM25 index is dropped too
        result = store.delete(["id1", "id4"])
        assert result is True
        assert len(store.texts) == 0
        assert store.bm25 is None

    def test_get_by_ids(self, sample_texts):
        """Test retrieving documents by IDs."""
        ids = ["id1", "id2", "id3", "id4"]
        store = SimpleStore(sample_texts, ids=ids)
        # Get existing IDs
        docs = store.get_by_ids(["id1", "id3"])
        assert len(docs) == 2
        assert docs[0].id == "id1"
        assert docs[0].page_content == sample_texts[0]
        # Get non-existent IDs
        docs = store.get_by_ids(["nonexistent"])
        assert len(docs) == 0
        # Mixed existent and non-existent: unknown IDs are silently skipped
        docs = store.get_by_ids(["id1", "nonexistent", "id3"])
        assert len(docs) == 2

    def test_from_texts(self, sample_texts, sample_metadatas):
        """Test creating store using from_texts class method."""
        ids = ["id1", "id2", "id3", "id4"]
        # embedding=None: SimpleStore is lexical (BM25), no embedder needed
        store = SimpleStore.from_texts(sample_texts, embedding=None, metadatas=sample_metadatas, ids=ids)
        assert isinstance(store, SimpleStore)
        assert store.texts == sample_texts
        assert store.metadatas == sample_metadatas
        assert store.ids == ids

    def test_as_retriever(self, simple_store):
        """Test creating a retriever from the store."""
        retriever = simple_store.as_retriever(search_kwargs={"k": 2})
        results = retriever.invoke("programming")
        assert len(results) <= 2
        assert all(hasattr(doc, "page_content") for doc in results)

    def test_chinese_and_mixed_language(self):
        """Test search with Chinese and mixed language texts."""
        from openchatbi.text_segmenter import _jieba_available

        mixed_texts = [
            "Python programming language",
            "机器学习很重要",
            "Deep learning neural networks",
            "数据科学分析",
        ]
        store = SimpleStore(mixed_texts)
        # Search in English
        en_results = store.similarity_search("programming", k=1)
        assert "Python" in en_results[0].page_content
        # Search in Chinese - result depends on jieba availability
        cn_results = store.similarity_search("机器学习", k=2)
        assert len(cn_results) > 0
        # If jieba is available, expect better Chinese matching
        if _jieba_available:
            assert "机器学习" in cn_results[0].page_content
        else:
            # Without jieba, just verify results are returned
            # (Chinese text may not be perfectly tokenized)
            assert any("机器学习" in doc.page_content for doc in cn_results) or any(
                "数据科学" in doc.page_content for doc in cn_results
            )

    def test_max_marginal_relevance_search(self, simple_store):
        """Test max_marginal_relevance_search method."""
        query = "programming language"
        # Test basic MMR search
        results = simple_store.max_marginal_relevance_search(query, k=2, fetch_k=4, lambda_mult=0.5)
        assert len(results) == 2
        assert all(hasattr(doc, "page_content") for doc in results)
        # Test relevance-focused search (lambda_mult = 1.0)
        results_relevant = simple_store.max_marginal_relevance_search(query, k=3, fetch_k=4, lambda_mult=1.0)
        assert len(results_relevant) == 3
        # Test diversity-focused search (lambda_mult = 0.0)
        results_diverse = simple_store.max_marginal_relevance_search(query, k=3, fetch_k=4, lambda_mult=0.0)
        assert len(results_diverse) == 3
        # Verify different lambda values produce different results
        # (unless there are ties in scoring)
        assert len(results_relevant) == len(results_diverse)
        # Test with k >= fetch_k: result count is bounded by fetch_k
        results = simple_store.max_marginal_relevance_search(query, k=5, fetch_k=3, lambda_mult=0.5)
        assert len(results) == 3  # Should return fetch_k documents
        # Test empty query
        results = simple_store.max_marginal_relevance_search("", k=2)
        assert len(results) <= 2
        # Test empty store
        empty_store = SimpleStore([])
        results = empty_store.max_marginal_relevance_search(query, k=2)
        assert results == []

    def test_calculate_similarity(self, simple_store):
        """Test _calculate_similarity method (Jaccard over token sets)."""
        # Get two documents
        doc1 = simple_store.documents[0]  # "Python is a programming language"
        doc2 = simple_store.documents[1]  # "Machine learning is a subset of AI"
        doc3 = simple_store.documents[0]  # Same as doc1
        # Test similarity between different documents
        similarity_diff = simple_store._calculate_similarity(doc1, doc2)
        assert 0.0 <= similarity_diff <= 1.0
        # Test similarity between identical documents
        similarity_same = simple_store._calculate_similarity(doc1, doc3)
        assert similarity_same == 1.0
        # Test with empty documents
        from langchain_core.documents import Document

        empty_doc1 = Document(page_content="", metadata={})
        empty_doc2 = Document(page_content="", metadata={})
        similarity_empty = simple_store._calculate_similarity(empty_doc1, empty_doc2)
        assert similarity_empty == 0.0  # Empty sets have 0 Jaccard similarity
================================================
FILE: tests/test_text2sql_extraction.py
================================================
"""Tests for text2sql information extraction functionality."""
import json
from datetime import date
from unittest.mock import Mock, patch
from langchain_core.messages import AIMessage, HumanMessage
from openchatbi.graph_state import SQLGraphState
from openchatbi.text2sql.extraction import (
generate_extraction_prompt,
information_extraction,
information_extraction_conditional_edges,
parse_extracted_info_json,
)
class TestText2SQLExtraction:
    """Test text2sql information extraction functionality.

    Covers prompt templating, JSON parsing of LLM output, the
    information_extraction node factory, and its conditional edges.
    """

    def test_generate_extraction_prompt(self):
        """Prompt generation substitutes the date and knowledge placeholders."""
        prompt = generate_extraction_prompt()
        # Should replace time placeholder with today's date
        today_str = date.today().strftime("%Y-%m-%d")
        assert today_str in prompt
        # Placeholders must be fully substituted
        assert "[basic_knowledge_glossary]" not in prompt
        assert "[time_field_placeholder]" not in prompt

    def test_parse_extracted_info_json_valid(self):
        """Valid fenced JSON from the LLM is parsed into a dict."""
        json_response = {
            "keywords": ["revenue", "sales"],
            "dimensions": ["date", "region"],
            "metrics": ["total_revenue"],
            "filters": [],
        }
        # Mock LLM response with JSON
        llm_content = f"```json\n{json.dumps(json_response)}\n```"
        with patch("openchatbi.text2sql.extraction.get_text_from_content", return_value=llm_content):
            with patch("openchatbi.text2sql.extraction.extract_json_from_answer", return_value=json_response):
                result = parse_extracted_info_json(llm_content)
                assert result == json_response
                assert "keywords" in result
                assert "dimensions" in result

    def test_parse_extracted_info_json_invalid(self):
        """A parse failure yields an empty dict rather than raising."""
        invalid_content = "Not valid JSON content"
        with patch("openchatbi.text2sql.extraction.get_text_from_content", return_value=invalid_content):
            with patch("openchatbi.text2sql.extraction.extract_json_from_answer", side_effect=Exception("Parse error")):
                result = parse_extracted_info_json(invalid_content)
                assert result == {}

    def test_information_extraction_function_creation(self):
        """The factory returns a callable graph node."""
        mock_llm = Mock()
        extraction_func = information_extraction(mock_llm)
        assert callable(extraction_func)

    def test_information_extraction_successful(self):
        """A well-formed extraction populates info_entities and rewrite_question."""
        mock_llm = Mock()
        extracted_info = {
            "rewrite_question": "What is the total revenue by region?",
            "keywords": ["revenue", "total"],
            "dimensions": ["region"],
            "metrics": ["revenue"],
            "filters": [],
        }
        mock_response = AIMessage(content=json.dumps(extracted_info))
        with patch("openchatbi.text2sql.extraction.call_llm_chat_model_with_retry", return_value=mock_response):
            with patch("openchatbi.text2sql.extraction.parse_extracted_info_json", return_value=extracted_info):
                extraction_func = information_extraction(mock_llm)
                state = SQLGraphState(
                    messages=[HumanMessage(content="Show me revenue by region")], question="Show me revenue by region"
                )
                result = extraction_func(state)
                assert "info_entities" in result
                assert result["rewrite_question"] == "What is the total revenue by region?"

    def test_information_extraction_empty_response(self):
        """An empty LLM response is handled gracefully (empty entities)."""
        mock_llm = Mock()
        mock_response = AIMessage(content="")
        with patch("openchatbi.text2sql.extraction.call_llm_chat_model_with_retry", return_value=mock_response):
            with patch("openchatbi.text2sql.extraction.parse_extracted_info_json", return_value={}):
                extraction_func = information_extraction(mock_llm)
                state = SQLGraphState(messages=[HumanMessage(content="Test question")], question="Test question")
                result = extraction_func(state)
                assert "info_entities" in result
                assert result["info_entities"] == {}

    def test_information_extraction_conditional_edges_success(self):
        """With a rewrite_question present the graph proceeds to 'next'."""
        state = SQLGraphState(
            messages=[HumanMessage(content="Test question")],
            question="Test question",
            rewrite_question="What is the total revenue by region?",
            info_entities={"keywords": ["revenue"], "dimensions": ["date"]},
        )
        assert information_extraction_conditional_edges(state) == "next"

    def test_information_extraction_conditional_edges_failure(self):
        """Empty info_entities routes to 'end'."""
        state = SQLGraphState(
            messages=[HumanMessage(content="Test question")], question="Test question", info_entities={}
        )
        assert information_extraction_conditional_edges(state) == "end"

    def test_information_extraction_conditional_edges_missing(self):
        """Absent info_entities also routes to 'end'."""
        state = SQLGraphState(messages=[HumanMessage(content="Test question")], question="Test question")
        assert information_extraction_conditional_edges(state) == "end"

    def test_information_extraction_with_retry_on_failure(self):
        """Extraction results survive the retry wrapper unchanged."""
        mock_llm = Mock()
        extracted_info = {
            "rewrite_question": "Test question",
            "keywords": ["test"],
            "dimensions": [],
            "metrics": [],
            "filters": [],
        }
        mock_response = AIMessage(content=json.dumps(extracted_info))
        with patch("openchatbi.text2sql.extraction.call_llm_chat_model_with_retry", return_value=mock_response):
            with patch("openchatbi.text2sql.extraction.parse_extracted_info_json", return_value=extracted_info):
                extraction_func = information_extraction(mock_llm)
                state = SQLGraphState(messages=[HumanMessage(content="Test question")], question="Test question")
                result = extraction_func(state)
                assert "info_entities" in result
                assert result["info_entities"]["keywords"] == ["test"]

    def test_information_extraction_time_period_detection(self):
        """Extracted time fields (start_time) are carried into info_entities."""
        mock_llm = Mock()
        extracted_info = {
            "rewrite_question": "Show data for the last 7 days",
            "keywords": ["data"],
            "dimensions": ["date"],
            "metrics": [],
            "filters": [],
            "start_time": "2024-01-01",
        }
        mock_response = AIMessage(content=json.dumps(extracted_info))
        with patch("openchatbi.text2sql.extraction.call_llm_chat_model_with_retry", return_value=mock_response):
            with patch("openchatbi.text2sql.extraction.parse_extracted_info_json", return_value=extracted_info):
                extraction_func = information_extraction(mock_llm)
                state = SQLGraphState(
                    messages=[HumanMessage(content="Test question")], question="Show data for last 7 days"
                )
                result = extraction_func(state)
                assert "info_entities" in result
                assert "start_time" in result["info_entities"]

    def test_information_extraction_error_handling(self):
        """LLM errors propagate: information_extraction has no try/except.

        Fixed: the previous try/except pattern with ``assert False`` swallowed
        its own AssertionError in the broad ``except Exception`` handler,
        producing a misleading failure message when no exception was raised.
        ``pytest.raises`` expresses the intent directly.
        """
        import pytest

        mock_llm = Mock()
        with patch(
            "openchatbi.text2sql.extraction.call_llm_chat_model_with_retry", side_effect=Exception("LLM error")
        ):
            extraction_func = information_extraction(mock_llm)
            state = SQLGraphState(messages=[HumanMessage(content="Test question")], question="Test question")
            with pytest.raises(Exception, match="LLM error"):
                extraction_func(state)
================================================
FILE: tests/test_text2sql_generate_sql.py
================================================
"""Tests for text2sql SQL generation functionality."""
from unittest.mock import Mock, patch
import pytest
from langchain_core.messages import AIMessage
from openchatbi.graph_state import SQLGraphState
from openchatbi.text2sql.generate_sql import create_sql_nodes, should_execute_sql, should_retry_sql
class TestText2SQLGenerateSQL:
"""Test text2sql SQL generation functionality."""
@pytest.fixture
def mock_llm(self):
    """LLM stub whose invoke() always yields a fixed SELECT statement."""
    fake_llm = Mock()
    fake_llm.invoke.return_value = AIMessage(content="SELECT * FROM users")
    return fake_llm
@pytest.fixture
def mock_catalog(self):
    """Catalog stub exposing table metadata and a fake SQL engine.

    The engine's connect() returns a context manager whose connection
    yields two rows with columns ``id`` and ``name``.
    """
    from unittest.mock import MagicMock

    catalog = Mock()
    catalog.get_table_information.return_value = {
        "description": "User data table",
        "sql_rule": "",
        "derived_metric": "",
    }
    catalog.get_column_list.return_value = [
        {
            "column_name": "user_id",
            "type": "bigint",
            "display_name": "User ID",
            "description": "Unique user identifier",
            "alias": "",
        }
    ]
    # Result set returned by the fake connection.
    result_proxy = Mock()
    result_proxy.fetchall.return_value = [("1", "John"), ("2", "Jane")]
    result_proxy.keys.return_value = ["id", "name"]
    connection = Mock()
    connection.execute.return_value = result_proxy
    # MagicMock supplies working __enter__/__exit__ for the `with` block.
    connect_ctx = MagicMock()
    connect_ctx.__enter__.return_value = connection
    connect_ctx.__exit__.return_value = None
    engine = Mock()
    engine.connect.return_value = connect_ctx
    catalog.get_sql_engine.return_value = engine
    return catalog
def test_create_sql_nodes(self, mock_llm, mock_catalog):
    """The factory returns four callable graph nodes."""
    nodes = create_sql_nodes(mock_llm, mock_catalog, "presto")
    generate, execute, regenerate, visualize = nodes
    assert callable(generate)
    assert callable(execute)
    assert callable(regenerate)
    assert callable(visualize)
def test_generate_sql_node_success(self, mock_llm, mock_catalog):
    """The generate node emits SQL when question and tables are present."""
    generate_node = create_sql_nodes(mock_llm, mock_catalog, "presto")[0]
    state = SQLGraphState(
        messages=[],
        question="Show all users",
        rewrite_question="Show all users",
        tables=[{"table": "users", "columns": []}],
    )
    with patch("openchatbi.text2sql.generate_sql.sql_example_retriever") as retriever:
        retriever.invoke.return_value = []  # no few-shot examples
        output = generate_node(state)
    assert "sql" in output
    assert output["sql"] == "SELECT * FROM users"
def test_generate_sql_node_missing_rewrite_question(self, mock_llm, mock_catalog):
    """Without a rewrite_question the generate node returns nothing."""
    generate_node = create_sql_nodes(mock_llm, mock_catalog, "presto")[0]
    # rewrite_question deliberately omitted from the state
    state = SQLGraphState(
        messages=[],
        question="Show all users",
    )
    assert generate_node(state) == {}
def test_generate_sql_node_missing_tables(self, mock_llm, mock_catalog):
    """With an empty table list the generate node returns nothing."""
    generate_node = create_sql_nodes(mock_llm, mock_catalog, "presto")[0]
    state = SQLGraphState(
        messages=[],
        question="Show all users",
        rewrite_question="Show all users",
        tables=[],  # deliberately empty
    )
    assert generate_node(state) == {}
def test_execute_sql_node_success(self, mock_llm, mock_catalog):
    """Executing valid SQL reports SQL_SUCCESS and returns data."""
    from openchatbi.constants import SQL_SUCCESS

    execute_node = create_sql_nodes(mock_llm, mock_catalog, "presto")[1]
    output = execute_node(SQLGraphState(messages=[], sql="SELECT * FROM users"))
    assert "sql_execution_result" in output
    assert output["sql_execution_result"] == SQL_SUCCESS
    assert "data" in output
def test_execute_sql_node_empty_sql(self, mock_llm, mock_catalog):
    """An empty SQL string is reported as SQL_NA."""
    from openchatbi.constants import SQL_NA

    execute_node = create_sql_nodes(mock_llm, mock_catalog, "presto")[1]
    output = execute_node(SQLGraphState(messages=[], sql=""))
    assert "sql_execution_result" in output
    assert output["sql_execution_result"] == SQL_NA
def test_execute_sql_node_syntax_error(self, mock_llm, mock_catalog):
    """A ProgrammingError from the engine maps to SQL_SYNTAX_ERROR."""
    from sqlalchemy.exc import ProgrammingError

    from openchatbi.constants import SQL_SYNTAX_ERROR

    execute_node = create_sql_nodes(mock_llm, mock_catalog, "presto")[1]
    # Make the fake connection raise a syntax error on execute()
    connection = mock_catalog.get_sql_engine.return_value.connect.return_value.__enter__.return_value
    connection.execute.side_effect = ProgrammingError("", "", "Syntax error")
    output = execute_node(SQLGraphState(messages=[], sql="SELECT * FRON users"))  # deliberate typo
    assert "sql_execution_result" in output
    assert output["sql_execution_result"] == SQL_SYNTAX_ERROR
    assert "previous_sql_errors" in output
def test_regenerate_sql_node_success(self, mock_llm, mock_catalog):
    """Regeneration emits new SQL and increments the retry counter."""
    regenerate_node = create_sql_nodes(mock_llm, mock_catalog, "presto")[2]
    state = SQLGraphState(
        messages=[],
        question="Show all users",
        rewrite_question="Show all users",
        tables=[{"table": "users", "columns": []}],
        previous_sql_errors=[
            {"sql": "SELECT * FRON users", "error": "Syntax error: FRON", "error_type": "SQL syntax error"}
        ],
        sql_retry_count=1,
    )
    with patch("openchatbi.text2sql.generate_sql.sql_example_retriever") as retriever:
        retriever.invoke.return_value = []
        output = regenerate_node(state)
    assert "sql" in output
    assert "sql_retry_count" in output
    assert output["sql_retry_count"] == 2
def test_should_retry_sql_success(self):
    """No retry after a successful execution: route to 'end'."""
    from openchatbi.constants import SQL_SUCCESS

    state = SQLGraphState(sql_execution_result=SQL_SUCCESS, sql_retry_count=1)
    assert should_retry_sql(state) == "end"
def test_should_retry_sql_timeout(self):
    """Timeouts are terminal: route to 'end' rather than retrying."""
    from openchatbi.constants import SQL_EXECUTE_TIMEOUT

    state = SQLGraphState(sql_execution_result=SQL_EXECUTE_TIMEOUT, sql_retry_count=1)
    assert should_retry_sql(state) == "end"
def test_should_retry_sql_retry_needed(self):
    """A syntax error under the retry limit routes to regeneration."""
    state = SQLGraphState(sql_execution_result="SYNTAX_ERROR", sql_retry_count=1)
    assert should_retry_sql(state) == "regenerate_sql"
def test_should_retry_sql_max_retries_reached(self):
    """Once the retry budget is exhausted the graph routes to 'end'."""
    state = SQLGraphState(sql_execution_result="SYNTAX_ERROR", sql_retry_count=3)
    assert should_retry_sql(state) == "end"
def test_should_execute_sql_with_sql(self):
    """Non-empty SQL should be routed to the execution node."""
    state = SQLGraphState(sql="SELECT * FROM users")
    assert should_execute_sql(state) == "execute_sql"
def test_should_execute_sql_without_sql(self):
    """An empty SQL string should skip execution entirely."""
    state = SQLGraphState(sql="")
    assert should_execute_sql(state) == "end"
def test_sql_generation_with_examples(self, mock_llm, mock_catalog):
    """Generation should succeed when a similar SQL example is retrieved."""
    generate_node = create_sql_nodes(mock_llm, mock_catalog, "presto")[0]
    state = SQLGraphState(
        messages=[],
        question="Show user count",
        rewrite_question="Show user count",
        tables=[{"table": "users", "columns": []}],
    )
    # Stub a retrieved example document plus its question->(sql, tables) entry.
    example_doc = Mock()
    example_doc.page_content = "How many users are there?"
    example_map = {"How many users are there?": ("SELECT COUNT(*) FROM users", ["users"])}
    with patch("openchatbi.text2sql.generate_sql.sql_example_retriever") as retriever, patch(
        "openchatbi.text2sql.generate_sql.sql_example_dicts", example_map
    ):
        retriever.invoke.return_value = [example_doc]
        result = generate_node(state)
    assert "sql" in result
def test_sql_error_handling_database_error(self, mock_llm, mock_catalog):
    """A database connection failure is reported as an execution timeout."""
    from sqlalchemy.exc import OperationalError
    from openchatbi.constants import SQL_EXECUTE_TIMEOUT

    execute_node = create_sql_nodes(mock_llm, mock_catalog, "presto")[1]
    # Make the catalog's SQL connection raise on execute.
    connection = mock_catalog.get_sql_engine.return_value.connect.return_value.__enter__.return_value
    connection.execute.side_effect = OperationalError("", "", "Connection failed")
    result = execute_node(SQLGraphState(messages=[], sql="SELECT * FROM users"))
    assert "sql_execution_result" in result
    assert result["sql_execution_result"] == SQL_EXECUTE_TIMEOUT
def test_regenerate_sql_empty_response(self, mock_llm, mock_catalog):
    """An empty LLM answer should yield empty SQL marked as not available."""
    from openchatbi.constants import SQL_NA

    mock_llm.invoke.return_value = AIMessage(content="")
    regenerate_node = create_sql_nodes(mock_llm, mock_catalog, "presto")[2]
    state = SQLGraphState(
        messages=[],
        question="Show all users",
        rewrite_question="Show all users",
        tables=[{"table": "users", "columns": []}],
        previous_sql_errors=[],
        sql_retry_count=1,
    )
    with patch("openchatbi.text2sql.generate_sql.sql_example_retriever") as retriever:
        retriever.invoke.return_value = []
        result = regenerate_node(state)
    assert "sql" in result
    assert result["sql"] == ""
    assert result["sql_execution_result"] == SQL_NA
================================================
FILE: tests/test_text2sql_schema_linking.py
================================================
"""Tests for text2sql schema linking functionality."""
from unittest.mock import Mock, patch
import pytest
from langchain_core.messages import AIMessage
from openchatbi.graph_state import SQLGraphState
from openchatbi.text2sql.schema_linking import schema_linking
class TestText2SQLSchemaLinking:
    """Test text2sql schema linking functionality."""

    @staticmethod
    def _column_meta(name, display_name, description, category="dimension"):
        """Build a single col_dict entry in the shape schema_linking expects."""
        return {
            "column_name": name,
            "category": category,
            "display_name": display_name,
            "description": description,
        }

    @pytest.fixture
    def mock_llm(self):
        """LLM stub whose answer always selects the users table."""
        llm = Mock()
        llm.invoke.return_value = AIMessage(
            content='{"tables": [{"table": "users", "reason": "Contains user data"}]}'
        )
        return llm

    @pytest.fixture
    def mock_catalog(self):
        """Catalog stub describing one user-data table."""
        catalog = Mock()
        catalog.get_table_information.return_value = {
            "description": "User data table",
            "selection_rule": "Use for user-related queries",
        }
        return catalog

    def test_select_table_function_creation(self, mock_llm, mock_catalog):
        """schema_linking should hand back a callable selector."""
        assert callable(schema_linking(mock_llm, mock_catalog))

    def test_select_table_success(self, mock_llm, mock_catalog):
        """A valid LLM table choice is propagated into the state update."""
        col_meta = {
            "user_id": self._column_meta("user_id", "User ID", "Unique user identifier"),
            "name": self._column_meta("name", "Name", "User full name"),
            "email": self._column_meta("email", "Email", "User email address"),
        }
        col_tables = {"user_id": ["users", "profiles"], "name": ["users"], "email": ["users", "contacts"]}
        with patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as get_columns, patch(
            "openchatbi.text2sql.schema_linking.column_tables_mapping", col_tables
        ), patch("openchatbi.text2sql.schema_linking.col_dict", col_meta), patch(
            "openchatbi.text2sql.schema_linking.table_selection_retriever"
        ) as retriever, patch(
            "openchatbi.text2sql.schema_linking.table_selection_example_dict", {}
        ), patch(
            "openchatbi.text2sql.schema_linking.extract_json_from_answer"
        ) as extract:
            get_columns.return_value = ["user_id", "name", "email"]
            retriever.invoke.return_value = []
            extract.return_value = {"tables": [{"table": "users", "reason": "Contains user data"}]}
            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user information",
                rewrite_question="Show user information",
                info_entities={
                    "keywords": ["user", "information"],
                    "dimensions": ["name", "email"],
                    "metrics": [],
                },
            )
            result = select_func(state)
        assert "tables" in result
        assert len(result["tables"]) == 1
        assert result["tables"][0]["table"] == "users"

    def test_select_table_missing_rewrite_question(self, mock_llm, mock_catalog):
        """Without a rewritten question the selector returns no update."""
        select_func = schema_linking(mock_llm, mock_catalog)
        # rewrite_question deliberately omitted from the state.
        state = SQLGraphState(messages=[], question="Show user information")
        assert select_func(state) == {}

    def test_select_table_with_examples(self, mock_llm, mock_catalog):
        """Similar retrieved examples may lead to multiple selected tables."""
        col_meta = {
            "user_id": self._column_meta("user_id", "User ID", "Unique user identifier"),
            "revenue": self._column_meta("revenue", "Revenue", "Total revenue amount", category="metric"),
        }
        example_doc = Mock()
        example_doc.page_content = "What is user revenue?"
        with patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as get_columns, patch(
            "openchatbi.text2sql.schema_linking.column_tables_mapping",
            {"user_id": ["users"], "revenue": ["sales"]},
        ), patch("openchatbi.text2sql.schema_linking.col_dict", col_meta), patch(
            "openchatbi.text2sql.schema_linking.table_selection_retriever"
        ) as retriever, patch(
            "openchatbi.text2sql.schema_linking.table_selection_example_dict",
            {"What is user revenue?": ["users", "sales"]},
        ), patch(
            "openchatbi.text2sql.schema_linking.extract_json_from_answer"
        ) as extract:
            get_columns.return_value = ["user_id", "revenue"]
            retriever.invoke.return_value = [example_doc]
            extract.return_value = {"tables": [{"table": "users"}, {"table": "sales"}]}
            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user revenue",
                rewrite_question="Show user revenue",
                info_entities={
                    "keywords": ["user", "revenue"],
                    "dimensions": ["user_id"],
                    "metrics": ["revenue"],
                },
            )
            result = select_func(state)
        assert "tables" in result
        assert len(result["tables"]) == 2

    def test_select_table_invalid_table_selection(self, mock_llm, mock_catalog):
        """An LLM choice outside the candidate list yields an empty update."""
        col_meta = {"user_id": self._column_meta("user_id", "User ID", "Unique user identifier")}
        with patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as get_columns, patch(
            "openchatbi.text2sql.schema_linking.column_tables_mapping", {"user_id": ["users"]}
        ), patch("openchatbi.text2sql.schema_linking.col_dict", col_meta), patch(
            "openchatbi.text2sql.schema_linking.table_selection_retriever"
        ) as retriever, patch(
            "openchatbi.text2sql.schema_linking.table_selection_example_dict", {}
        ), patch(
            "openchatbi.text2sql.schema_linking.extract_json_from_answer"
        ) as extract:
            get_columns.return_value = ["user_id"]
            retriever.invoke.return_value = []
            # Return an invalid table that is not in the candidate list.
            extract.return_value = {"tables": [{"table": "invalid_table"}]}
            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user info",
                rewrite_question="Show user info",
                info_entities={"keywords": ["user"], "dimensions": ["user_id"], "metrics": []},
            )
            # Should return empty dict when an invalid table is selected.
            assert select_func(state) == {}

    def test_select_table_retry_mechanism(self, mock_llm, mock_catalog):
        """An invalid first pick should be retried and the valid pick kept."""
        col_meta = {"user_id": self._column_meta("user_id", "User ID", "Unique user identifier")}
        with patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as get_columns, patch(
            "openchatbi.text2sql.schema_linking.column_tables_mapping", {"user_id": ["users"]}
        ), patch("openchatbi.text2sql.schema_linking.col_dict", col_meta), patch(
            "openchatbi.text2sql.schema_linking.table_selection_retriever"
        ) as retriever, patch(
            "openchatbi.text2sql.schema_linking.table_selection_example_dict", {}
        ), patch(
            "openchatbi.text2sql.schema_linking.extract_json_from_answer"
        ) as extract:
            get_columns.return_value = ["user_id"]
            retriever.invoke.return_value = []
            # First answer is invalid, second is valid.
            extract.side_effect = [
                {"tables": [{"table": "invalid_table"}]},
                {"tables": [{"table": "users"}]},
            ]
            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user info",
                rewrite_question="Show user info",
                info_entities={"keywords": ["user"], "dimensions": ["user_id"], "metrics": []},
            )
            result = select_func(state)
        assert "tables" in result
        assert result["tables"][0]["table"] == "users"

    def test_select_table_with_time_filter(self, mock_llm, mock_catalog):
        """Tables stay selectable when the query window begins after table start_time."""
        mock_catalog.get_table_information.return_value = {
            "description": "User data table",
            "selection_rule": "Use for user queries",
            "start_time": "2024-01-01",
        }
        col_meta = {"user_id": self._column_meta("user_id", "User ID", "Unique user identifier")}
        with patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as get_columns, patch(
            "openchatbi.text2sql.schema_linking.column_tables_mapping", {"user_id": ["users"]}
        ), patch("openchatbi.text2sql.schema_linking.col_dict", col_meta), patch(
            "openchatbi.text2sql.schema_linking.table_selection_retriever"
        ) as retriever, patch(
            "openchatbi.text2sql.schema_linking.table_selection_example_dict", {}
        ), patch(
            "openchatbi.text2sql.schema_linking.extract_json_from_answer"
        ) as extract:
            get_columns.return_value = ["user_id"]
            retriever.invoke.return_value = []
            extract.return_value = {"tables": [{"table": "users"}]}
            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show recent user info",
                rewrite_question="Show recent user info",
                info_entities={
                    "keywords": ["user"],
                    "dimensions": ["user_id"],
                    "metrics": [],
                    "start_time": "2024-06-01",  # Later than table start_time
                },
            )
            result = select_func(state)
        assert "tables" in result
        assert result["tables"][0]["table"] == "users"

    def test_select_table_llm_error_handling(self, mock_llm, mock_catalog):
        """LLM failures are handled gracefully and reported as an empty update."""
        mock_llm.invoke.side_effect = Exception("LLM service error")
        col_meta = {"user_id": self._column_meta("user_id", "User ID", "Unique user identifier")}
        with patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as get_columns, patch(
            "openchatbi.text2sql.schema_linking.column_tables_mapping", {"user_id": ["users"]}
        ), patch("openchatbi.text2sql.schema_linking.col_dict", col_meta), patch(
            "openchatbi.text2sql.schema_linking.table_selection_retriever"
        ) as retriever:
            get_columns.return_value = ["user_id"]
            retriever.invoke.return_value = []
            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user info",
                rewrite_question="Show user info",
                info_entities={"keywords": ["user"], "dimensions": ["user_id"], "metrics": []},
            )
            assert select_func(state) == {}

    def test_select_table_max_retries_exceeded(self, mock_llm, mock_catalog):
        """Persistently invalid picks end in an empty update after max retries."""
        col_meta = {"user_id": self._column_meta("user_id", "User ID", "Unique user identifier")}
        with patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as get_columns, patch(
            "openchatbi.text2sql.schema_linking.column_tables_mapping", {"user_id": ["users"]}
        ), patch("openchatbi.text2sql.schema_linking.col_dict", col_meta), patch(
            "openchatbi.text2sql.schema_linking.table_selection_retriever"
        ) as retriever, patch(
            "openchatbi.text2sql.schema_linking.table_selection_example_dict", {}
        ), patch(
            "openchatbi.text2sql.schema_linking.extract_json_from_answer"
        ) as extract:
            get_columns.return_value = ["user_id"]
            retriever.invoke.return_value = []
            # Every attempt returns an invalid table.
            extract.return_value = {"tables": [{"table": "invalid_table"}]}
            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user info",
                rewrite_question="Show user info",
                info_entities={"keywords": ["user"], "dimensions": ["user_id"], "metrics": []},
            )
            assert select_func(state) == {}
================================================
FILE: tests/test_text2sql_visualization.py
================================================
"""Tests for text2sql visualization functionality."""
import pytest
from openchatbi.text2sql.visualization import ChartType, VisualizationConfig, VisualizationDSL, VisualizationService
class TestVisualizationService:
    """Tests for the VisualizationService class."""

    def test_generate_visualization_dsl_basic(self):
        """Mixed categorical/numeric data should produce a bar-chart DSL."""
        schema = {
            "columns": ["name", "age", "salary", "department"],
            "row_count": 4,
            "numeric_columns": ["age", "salary"],
            "categorical_columns": ["name", "department"],
            "datetime_columns": [],
            "unique_counts": {"name": 4, "department": 2},
        }
        dsl = VisualizationService().generate_visualization_dsl("Compare salary by department", schema)
        assert dsl.chart_type == "bar"
        # The first categorical column ("name") should be picked up.
        assert "name" in dsl.data_columns
        # Both numeric columns should be carried along.
        assert "age" in dsl.data_columns
        assert "salary" in dsl.data_columns

    def test_get_chart_type_by_rule_with_datetime(self):
        """Datetime columns combined with a trend question suggest a line chart."""
        schema = {
            "columns": ["date", "sales", "region"],
            "numeric_columns": ["sales"],
            "categorical_columns": ["region"],
            "datetime_columns": ["date"],
            "row_count": 3,
        }
        svc = VisualizationService()
        assert svc._get_chart_type_by_rule("Show sales trend over time", schema) == ChartType.LINE

    def test_generate_visualization_dsl_error_handling(self):
        """Schema analysis failures fall back to a table DSL with an error entry."""
        dsl = VisualizationService().generate_visualization_dsl(
            "Show data", {"error": "Failed to analyze data schema"}
        )
        assert dsl.chart_type == "table"
        assert "error" in dsl.config

    def test_get_chart_type_by_rule_line_chart(self):
        """Trend wording in the question should map to a line chart."""
        schema = {
            "numeric_columns": ["sales"],
            "categorical_columns": ["region"],
            "datetime_columns": ["date"],
            "row_count": 10,
        }
        svc = VisualizationService()
        assert svc._get_chart_type_by_rule("Show me the sales trend over time", schema) == ChartType.LINE

    def test_get_chart_type_by_rule_pie_chart(self):
        """Percentage-breakdown wording should map to a pie chart."""
        schema = {
            "numeric_columns": ["count"],
            "categorical_columns": ["department"],
            "datetime_columns": [],
            "row_count": 5,
            "unique_counts": {"department": 4},
        }
        svc = VisualizationService()
        result = svc._get_chart_type_by_rule("What is the percentage breakdown by department?", schema)
        assert result == ChartType.PIE

    def test_get_chart_type_by_rule_bar_chart(self):
        """Comparison wording should map to a bar chart."""
        schema = {
            "numeric_columns": ["sales"],
            "categorical_columns": ["region"],
            "datetime_columns": [],
            "row_count": 10,
            "unique_counts": {"region": 4},
        }
        svc = VisualizationService()
        assert svc._get_chart_type_by_rule("Compare sales by region", schema) == ChartType.BAR

    def test_get_chart_type_by_rule_scatter_plot(self):
        """Relationship wording with two numeric columns should map to scatter."""
        schema = {
            "numeric_columns": ["age", "salary"],
            "categorical_columns": ["name"],
            "datetime_columns": [],
            "row_count": 10,
        }
        svc = VisualizationService()
        result = svc._get_chart_type_by_rule("Show relationship between age and salary", schema)
        assert result == ChartType.SCATTER

    def test_get_chart_type_by_rule_histogram(self):
        """Distribution wording over a single numeric column maps to a histogram."""
        schema = {"numeric_columns": ["age"], "categorical_columns": [], "datetime_columns": [], "row_count": 100}
        svc = VisualizationService()
        assert svc._get_chart_type_by_rule("What is the distribution of ages?", schema) == ChartType.HISTOGRAM

    def test_get_chart_type_by_rule_data_based_priority(self):
        """Data shape (few categories + numeric) outranks the row count heuristic."""
        schema = {
            "numeric_columns": ["value"],
            "categorical_columns": ["category"],
            "datetime_columns": [],
            "row_count": 15,
            "unique_counts": {"category": 5},  # Small number of categories
        }
        svc = VisualizationService()
        # PIE wins over TABLE despite the row count, thanks to the column mix.
        assert svc._get_chart_type_by_rule("Show all records", schema) == ChartType.PIE

    def test_generate_visualization_dsl_line_chart(self):
        """Line-chart DSL should place datetime on x and the metric on y."""
        schema = {
            "columns": ["date", "sales"],
            "numeric_columns": ["sales"],
            "categorical_columns": [],
            "datetime_columns": ["date"],
            "row_count": 3,
        }
        dsl = VisualizationService().generate_visualization_dsl("Show sales trend over time", schema)
        assert dsl.chart_type == "line"
        assert "date" in dsl.data_columns
        assert "sales" in dsl.data_columns
        assert dsl.config["x"] == "date"
        assert dsl.config["y"] == "sales"

    def test_generate_visualization_dsl_bar_chart(self):
        """Bar-chart DSL should place the category on x and the metric on y."""
        schema = {
            "columns": ["region", "sales"],
            "numeric_columns": ["sales"],
            "categorical_columns": ["region"],
            "datetime_columns": [],
            "row_count": 4,
            "unique_counts": {"region": 4},
        }
        dsl = VisualizationService().generate_visualization_dsl("Compare sales by region", schema)
        assert dsl.chart_type == "bar"
        assert "region" in dsl.data_columns
        assert "sales" in dsl.data_columns
        assert dsl.config["x"] == "region"
        assert dsl.config["y"] == "sales"

    def test_generate_visualization_dsl_pie_chart(self):
        """Pie-chart DSL should wire labels and values when the type is forced."""
        schema = {
            "columns": ["department", "count"],
            "numeric_columns": ["count"],
            "categorical_columns": ["department"],
            "datetime_columns": [],
            "row_count": 4,
            "unique_counts": {"department": 4},
        }
        dsl = VisualizationService().generate_visualization_dsl(
            "Show percentage breakdown by department", schema, ChartType.PIE
        )
        assert dsl.chart_type == "pie"
        assert dsl.config["labels"] == "department"
        assert dsl.config["values"] == "count"

    def test_generate_visualization_dsl_empty_data(self):
        """With no rows or columns, the DSL should degrade to a table."""
        schema = {
            "columns": [],
            "numeric_columns": [],
            "categorical_columns": [],
            "datetime_columns": [],
            "row_count": 0,
        }
        dsl = VisualizationService().generate_visualization_dsl("Show data", schema)
        assert dsl.chart_type == "table"
        assert "columns" in dsl.config

    def test_visualization_config_dataclass(self):
        """VisualizationConfig should store explicit fields and keep defaults."""
        cfg = VisualizationConfig(
            chart_type=ChartType.BAR, x_column="category", y_column="value", title="Test Chart"
        )
        assert cfg.chart_type == ChartType.BAR
        assert cfg.x_column == "category"
        assert cfg.y_column == "value"
        assert cfg.title == "Test Chart"
        assert cfg.show_legend is True  # default value

    def test_visualization_dsl_to_dict(self):
        """to_dict should mirror every DSL section verbatim."""
        dsl = VisualizationDSL(
            chart_type="bar",
            data_columns=["x", "y"],
            config={"x": "category", "y": "value"},
            layout={"title": "Test Chart"},
        )
        as_dict = dsl.to_dict()
        assert as_dict["chart_type"] == "bar"
        assert as_dict["data_columns"] == ["x", "y"]
        assert as_dict["config"]["x"] == "category"
        assert as_dict["layout"]["title"] == "Test Chart"
class TestChartType:
    """Tests for ChartType enum."""

    def test_chart_type_values(self):
        """Each enum member should map to its lowercase string value."""
        expected = {
            ChartType.LINE: "line",
            ChartType.BAR: "bar",
            ChartType.PIE: "pie",
            ChartType.SCATTER: "scatter",
            ChartType.HISTOGRAM: "histogram",
            ChartType.BOX: "box",
            ChartType.HEATMAP: "heatmap",
            ChartType.TABLE: "table",
        }
        for member, value in expected.items():
            assert member.value == value
@pytest.fixture
def sample_csv_data():
    """Fixture providing sample CSV data for testing."""
    rows = [
        "product,sales,region,quarter",
        "Widget A,10000,North,Q1",
        "Widget B,15000,South,Q1",
        "Widget C,8000,East,Q1",
        "Widget A,12000,North,Q2",
        "Widget B,18000,South,Q2",
        "Widget C,9000,East,Q2",
    ]
    return "\n".join(rows)
@pytest.fixture
def sample_time_series_data():
    """Fixture providing sample time series data for testing."""
    rows = [
        "date,revenue,users",
        "2023-01-01,50000,1000",
        "2023-02-01,55000,1100",
        "2023-03-01,60000,1200",
        "2023-04-01,52000,1050",
        "2023-05-01,58000,1150",
    ]
    return "\n".join(rows)
class TestVisualizationIntegration:
    """Integration tests for visualization functionality."""

    def test_complete_workflow_line_chart(self, sample_time_series_data):
        """Recommendation plus DSL generation should agree on a line chart."""
        question = "Show revenue trend over time"
        # Schema info mirroring the time-series fixture.
        schema = {
            "columns": ["date", "revenue", "users"],
            "numeric_columns": ["revenue", "users"],
            "categorical_columns": [],
            "datetime_columns": ["date"],
            "row_count": 5,
        }
        svc = VisualizationService()
        recommended = svc._get_chart_type_by_rule(question, schema)
        dsl = svc.generate_visualization_dsl(question, schema, recommended)
        assert recommended == ChartType.LINE
        assert dsl.chart_type == "line"
        assert "date" in dsl.data_columns
        assert "revenue" in dsl.data_columns

    def test_complete_workflow_bar_chart(self, sample_csv_data):
        """DSL generation over categorical sales data yields a usable chart."""
        question = "Compare sales by product"
        # Schema info mirroring the sample CSV fixture.
        schema = {
            "columns": ["product", "sales", "region", "quarter"],
            "numeric_columns": ["sales"],
            "categorical_columns": ["product", "region", "quarter"],
            "datetime_columns": [],
            "row_count": 6,
            "unique_counts": {"product": 3, "region": 3, "quarter": 2},
        }
        dsl = VisualizationService().generate_visualization_dsl(question, schema)
        # Either chart type is acceptable given the heuristics.
        assert dsl.chart_type in ["bar", "line"]
        assert len(dsl.data_columns) >= 2
        assert dsl.layout.get("title") is not None
================================================
FILE: tests/test_tools_ask_human.py
================================================
"""Tests for ask_human tool functionality."""
import pytest
from pydantic import ValidationError
from openchatbi.tool.ask_human import AskHuman
class TestAskHuman:
    """Test AskHuman model functionality."""

    def test_ask_human_basic_initialization(self):
        """Constructing with question and options should store both verbatim."""
        prompt = "What time period should I analyze?"
        choices = ["Last 7 days", "Last 30 days", "Last year"]
        model = AskHuman(question=prompt, options=choices)
        assert model.question == prompt
        assert model.options == choices

    def test_ask_human_empty_options(self):
        """An empty options list should be accepted as-is."""
        model = AskHuman(question="Simple question?", options=[])
        assert model.question == "Simple question?"
        assert model.options == []

    def test_ask_human_validation_error(self):
        """Omitting required fields should raise pydantic validation errors."""
        with pytest.raises(ValidationError):
            AskHuman()  # Missing required fields
        with pytest.raises(ValidationError):
            AskHuman(question="Test")  # Missing options field

    def test_ask_human_serialization(self):
        """model_dump should round-trip the field values."""
        model = AskHuman(question="Which analysis method?", options=["Statistical", "Machine Learning"])
        dumped = model.model_dump()
        assert dumped["question"] == "Which analysis method?"
        assert dumped["options"] == ["Statistical", "Machine Learning"]
================================================
FILE: tests/test_tools_run_python_code.py
================================================
"""Tests for run_python_code tool functionality."""
from unittest.mock import patch
from openchatbi.tool.run_python_code import run_python_code
class TestRunPythonCode:
"""Test run_python_code tool functionality."""
def test_run_python_code_basic(self):
    """Successful execution output should be returned unchanged."""
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (True, "Hello, World!\n")
        output = run_python_code.run(
            {"reasoning": "Testing basic print functionality", "code": "print('Hello, World!')"}
        )
    assert isinstance(output, str)
    assert "Hello, World!" in output
def test_run_python_code_with_variables(self):
    """Output of variable arithmetic should be forwarded to the caller."""
    snippet = """
x = 10
y = 20
result = x + y
print(f"Result: {result}")
"""
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (True, "Result: 30\n")
        output = run_python_code.run({"reasoning": "Testing variable operations", "code": snippet})
    assert isinstance(output, str)
    assert "Result: 30" in output
def test_run_python_code_data_analysis(self):
    """Data-analysis style code should surface its computed output."""
    snippet = """
import math
data = [1, 2, 3, 4, 5]
mean = sum(data) / len(data)
print(f"Mean: {mean}")
"""
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (True, "Mean: 3.0\n")
        output = run_python_code.run({"reasoning": "Performing data analysis calculations", "code": snippet})
    assert isinstance(output, str)
    assert "Mean: 3.0" in output
def test_run_python_code_matplotlib_plot(self):
    """Plotting code should report its success message back to the caller."""
    snippet = """
import matplotlib.pyplot as plt
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]
plt.plot(x, y)
plt.title('Sample Plot')
print('Plot created successfully')
"""
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (True, "Plot created successfully\n")
        output = run_python_code.run({"reasoning": "Creating a matplotlib visualization", "code": snippet})
    assert isinstance(output, str)
    assert "Plot created successfully" in output
def test_run_python_code_syntax_error(self):
    """Syntax failures should be wrapped in an Error-prefixed message."""
    broken = "print('Hello World'"  # Missing closing parenthesis
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (False, "SyntaxError: unexpected EOF while parsing")
        output = run_python_code.run({"reasoning": "Testing error handling for syntax errors", "code": broken})
    assert isinstance(output, str)
    assert "Error:" in output
    assert "SyntaxError" in output
def test_run_python_code_runtime_error(self):
    """Runtime failures should be wrapped in an Error-prefixed message."""
    snippet = """
x = 10
y = 0
result = x / y
print(result)
"""
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (False, "ZeroDivisionError: division by zero")
        output = run_python_code.run({"reasoning": "Testing error handling for runtime errors", "code": snippet})
    assert isinstance(output, str)
    assert "Error:" in output
    assert "ZeroDivisionError" in output
def test_run_python_code_import_error(self):
    """Import failures should be wrapped in an Error-prefixed message."""
    snippet = """
import nonexistent_module
print('This should not print')
"""
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (
            False,
            "ModuleNotFoundError: No module named 'nonexistent_module'",
        )
        output = run_python_code.run({"reasoning": "Testing error handling for import errors", "code": snippet})
    assert isinstance(output, str)
    assert "Error:" in output
    assert "ModuleNotFoundError" in output
def test_run_python_code_multiline_output(self):
    """Every line of multi-line output should be preserved."""
    snippet = """
for i in range(3):
    print(f"Line {i + 1}")
"""
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (True, "Line 1\nLine 2\nLine 3\n")
        output = run_python_code.run({"reasoning": "Testing multiple output lines", "code": snippet})
    assert isinstance(output, str)
    for expected in ("Line 1", "Line 2", "Line 3"):
        assert expected in output
def test_run_python_code_with_sql_data(self):
    """Aggregation over SQL-shaped row dicts should surface its result."""
    snippet = """
data = [
    {'name': 'Alice', 'age': 30, 'salary': 50000},
    {'name': 'Bob', 'age': 25, 'salary': 45000},
    {'name': 'Charlie', 'age': 35, 'salary': 55000}
]
total_salary = sum(row['salary'] for row in data)
print(f"Total salary: ${total_salary}")
"""
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (True, "Total salary: $150000\n")
        output = run_python_code.run({"reasoning": "Processing SQL query results", "code": snippet})
    assert isinstance(output, str)
    assert "Total salary: $150000" in output
def test_run_python_code_empty_code(self):
    """Empty code should come back as an empty string, not an error."""
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (True, "")
        output = run_python_code.run({"reasoning": "Testing empty code handling", "code": ""})
    assert isinstance(output, str)
    assert output == ""
def test_run_python_code_whitespace_only(self):
    """Whitespace-only code should come back as an empty string."""
    with patch("openchatbi.tool.run_python_code.LocalExecutor") as executor_cls:
        executor_cls.return_value.run_code.return_value = (True, "")
        output = run_python_code.run({"reasoning": "Testing whitespace-only code", "code": " \n \t \n "})
    assert isinstance(output, str)
    assert output == ""
def test_run_python_code_with_comments(self):
"""Test Python code execution with comments."""
reasoning = "Testing code with comments"
code = """
# This is a comment
x = 5 # Another comment
print(f"Value: {x}") # Final comment
"""
with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
mock_instance = mock_executor.return_value
mock_instance.run_code.return_value = (True, "Value: 5\n")
result = run_python_code.run({"reasoning": reasoning, "code": code})
assert isinstance(result, str)
assert "Value: 5" in result
def test_run_python_code_security_restrictions(self):
"""Test Python code with potentially restricted operations."""
reasoning = "Testing security restrictions"
code = """
# Attempting file operations
try:
with open('/etc/passwd', 'r') as f:
content = f.read()
print("File read successfully")
except Exception as e:
print(f"Security restriction: {e}")
"""
with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
mock_instance = mock_executor.return_value
mock_instance.run_code.return_value = (
False,
"PermissionError: [Errno 13] Permission denied: '/etc/passwd'",
)
result = run_python_code.run({"reasoning": reasoning, "code": code})
assert isinstance(result, str)
assert "Error:" in result
def test_run_python_code_timeout_handling(self):
"""Test Python code execution timeout scenarios."""
reasoning = "Testing timeout handling"
code = """
import time
time.sleep(10) # Long running operation
print("This might timeout")
"""
with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
mock_instance = mock_executor.return_value
mock_instance.run_code.return_value = (False, "TimeoutError: Code execution timed out")
result = run_python_code.run({"reasoning": reasoning, "code": code})
assert isinstance(result, str)
assert "Error:" in result
assert "TimeoutError" in result
def test_run_python_code_memory_limit(self):
"""Test Python code execution with memory limitations."""
reasoning = "Testing memory limit handling"
code = """
# Creating a large list that might exceed memory limits
large_list = [0] * (10**8)
print(f"Created list with {len(large_list)} elements")
"""
with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
mock_instance = mock_executor.return_value
mock_instance.run_code.return_value = (False, "MemoryError: Unable to allocate memory")
result = run_python_code.run({"reasoning": reasoning, "code": code})
assert isinstance(result, str)
assert "Error:" in result
assert "MemoryError" in result
def test_run_python_code_return_values(self):
"""Test that return values are not captured (only prints)."""
reasoning = "Testing return value handling"
code = """
def calculate():
return 42
result = calculate()
print(f"Function returned: {result}")
# The return value itself should not be captured
calculate()
"""
with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
mock_instance = mock_executor.return_value
mock_instance.run_code.return_value = (True, "Function returned: 42\n")
result = run_python_code.run({"reasoning": reasoning, "code": code})
assert isinstance(result, str)
assert "Function returned: 42" in result
# Should not contain the raw return value
assert result.strip() == "Function returned: 42"
def test_run_python_code_exception_details(self):
"""Test detailed exception information."""
reasoning = "Testing detailed exception handling"
code = """
def faulty_function():
raise ValueError("This is a custom error message")
faulty_function()
"""
with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
mock_instance = mock_executor.return_value
mock_instance.run_code.return_value = (False, "ValueError: This is a custom error message")
result = run_python_code.run({"reasoning": reasoning, "code": code})
assert isinstance(result, str)
assert "Error:" in result
assert "ValueError" in result
assert "custom error message" in result
def test_run_python_code_executor_selection(self):
"""Test that LocalExecutor is properly instantiated and used."""
reasoning = "Testing executor instantiation"
code = "print('Executor test')"
with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor_class:
mock_instance = mock_executor_class.return_value
mock_instance.run_code.return_value = (True, "Executor test\n")
result = run_python_code.run({"reasoning": reasoning, "code": code})
# Verify LocalExecutor was instantiated
mock_executor_class.assert_called_once()
# Verify run_code was called with the correct code
mock_instance.run_code.assert_called_once_with(code)
assert isinstance(result, str)
assert "Executor test" in result
================================================
FILE: tests/test_tools_search_knowledge.py
================================================
"""Tests for search_knowledge tool functionality."""
from unittest.mock import patch
import pytest
from openchatbi.tool.search_knowledge import search_knowledge, show_schema
class TestSearchKnowledge:
    """Test search_knowledge and show_schema tool functionality."""

    @staticmethod
    def _run_search(reasoning, queries, catalog_result, with_tables=False):
        """Invoke search_knowledge with a patched column search.

        Returns (tool result, mock of search_column_from_catalog) so tests can
        inspect both the output and the call arguments.
        """
        with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search:
            mock_search.return_value = catalog_result
            result = search_knowledge.run(
                {
                    "reasoning": reasoning,
                    "query_list": queries,
                    "knowledge_bases": ["columns"],
                    "with_table_list": with_tables,
                }
            )
        return result, mock_search

    @staticmethod
    def _run_show_schema(reasoning, tables, schemas):
        """Invoke show_schema with a patched table listing.

        Returns (tool result, mock of list_table_from_catalog).
        """
        with patch("openchatbi.tool.search_knowledge.list_table_from_catalog") as mock_list:
            mock_list.return_value = schemas
            result = show_schema.run({"reasoning": reasoning, "tables": tables})
        return result, mock_list

    def test_search_knowledge_basic(self):
        """Test basic knowledge search functionality."""
        queries = ["user", "information"]
        result, mock_search = self._run_search(
            "Looking for user information", queries, "user_id: User identifier\nuser_name: User name"
        )
        assert isinstance(result, dict)
        assert "columns" in result
        assert "User identifier" in result["columns"]
        mock_search.assert_called_once_with(queries, False)

    def test_search_knowledge_table_matching(self):
        """Test knowledge search with table matching."""
        queries = ["user", "metrics"]
        result, mock_search = self._run_search(
            "Finding table relationships",
            queries,
            "user_id: Unique identifier\nmetrics_value: Metric value",
            with_tables=True,
        )
        assert isinstance(result, dict)
        assert "columns" in result
        mock_search.assert_called_once_with(queries, True)

    def test_search_knowledge_empty_query(self):
        """Test knowledge search with empty query."""
        queries = []
        result, mock_search = self._run_search("Testing empty search", queries, "")
        assert isinstance(result, dict)
        assert "columns" in result
        mock_search.assert_called_once_with(queries, False)

    def test_search_knowledge_no_matches(self):
        """Test knowledge search with no matches."""
        result, _ = self._run_search("Testing no matches", ["nonexistent"], "")
        assert isinstance(result, dict)
        assert "columns" in result
        # With no catalog hits only the section header remains.
        assert result["columns"] == "# Relevant Columns and Description:\n"

    def test_search_knowledge_multiple_matches(self):
        """Test knowledge search with multiple matches."""
        result, _ = self._run_search(
            "Finding multiple matches",
            ["user", "data", "profile"],
            "user_id: User ID\nuser_name: Name\nprofile_data: Profile",
        )
        assert isinstance(result, dict)
        assert "columns" in result
        assert "user_id" in result["columns"]
        assert "profile_data" in result["columns"]

    def test_search_knowledge_with_synonyms(self):
        """Test knowledge search with synonym matching."""
        result, _ = self._run_search(
            "Testing synonym search",
            ["customer", "client", "user"],
            "customer_id: Customer identifier\nclient_name: Client name",
        )
        assert isinstance(result, dict)
        assert "columns" in result

    def test_search_knowledge_case_insensitive(self):
        """Test case insensitive knowledge search."""
        result, _ = self._run_search(
            "Testing case sensitivity", ["USER", "Data", "PROFILE"], "user_data: User information"
        )
        assert isinstance(result, dict)
        assert "columns" in result

    def test_search_knowledge_partial_matches(self):
        """Test knowledge search with partial matches."""
        result, _ = self._run_search("Testing partial matching", ["usr", "prof"], "user_profile: User profile data")
        assert isinstance(result, dict)

    def test_search_knowledge_error_handling(self):
        """Test knowledge search error handling."""
        with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search:
            mock_search.side_effect = Exception("Search error")
            # Catalog failures propagate out of the tool.
            with pytest.raises(Exception):
                search_knowledge.run(
                    {
                        "reasoning": "Testing error handling",
                        "query_list": ["test"],
                        "knowledge_bases": ["columns"],
                        "with_table_list": False,
                    }
                )

    def test_show_schema_basic(self):
        """Test basic schema display functionality."""
        tables = ["user_data"]
        result, mock_list = self._run_show_schema(
            "Showing basic schema",
            tables,
            ["Table: user_data\n# Description: User information\n# Columns:\nuser_id: User ID"],
        )
        assert isinstance(result, list)
        assert len(result) == 1
        assert "user_data" in result[0]
        mock_list.assert_called_once_with(tables)

    def test_show_schema_detailed_info(self):
        """Test detailed schema information."""
        result, _ = self._run_show_schema(
            "Showing detailed schema",
            ["user_data", "metrics"],
            [
                "Table: user_data\n# Columns: user_id, name, email",
                "Table: metrics\n# Columns: metric_id, value, timestamp",
            ],
        )
        assert isinstance(result, list)
        assert len(result) == 2
        assert any("user_data" in schema for schema in result)
        assert any("metrics" in schema for schema in result)

    def test_show_schema_nonexistent_table(self):
        """Test schema display for nonexistent table."""
        result, _ = self._run_show_schema("Testing nonexistent table", ["nonexistent_table"], [])
        assert isinstance(result, list)
        assert len(result) == 0

    def test_show_schema_table_error(self):
        """Test schema display error handling."""
        with patch("openchatbi.tool.search_knowledge.list_table_from_catalog") as mock_list:
            mock_list.side_effect = Exception("Table access error")
            with pytest.raises(Exception):
                show_schema.run({"reasoning": "Testing schema errors", "tables": ["error_table"]})

    def test_show_schema_complex_table(self):
        """Test schema display for complex table structure."""
        result, _ = self._run_show_schema(
            "Showing complex schema",
            ["complex_table"],
            [
                "Table: complex_table\n# Description: Complex data structure\n# Columns:\nid: Primary key\ndata: JSON data\ncreated_at: Timestamp"
            ],
        )
        assert isinstance(result, list)
        assert "complex_table" in result[0]
        assert "Primary key" in result[0]

    def test_search_knowledge_with_metrics(self):
        """Test knowledge search focusing on metrics."""
        result, _ = self._run_search(
            "Finding metrics columns",
            ["revenue", "clicks", "impressions"],
            "revenue: Revenue amount\nclicks: Click count\nimpressions: Impression count",
        )
        assert isinstance(result, dict)
        assert "revenue" in result["columns"]
        assert "clicks" in result["columns"]

    def test_search_knowledge_contextual_search(self):
        """Test contextual knowledge search."""
        result, _ = self._run_search(
            "Contextual search for user behavior",
            ["user", "behavior", "tracking"],
            "user_behavior: User activity tracking\ntracking_id: Tracking identifier",
        )
        assert isinstance(result, dict)
        assert "behavior" in result["columns"]

    def test_search_knowledge_with_aggregations(self):
        """Test knowledge search for aggregation columns."""
        result, _ = self._run_search(
            "Finding aggregation metrics",
            ["sum", "count", "average"],
            "total_count: Count aggregation\naverage_value: Average calculation",
        )
        assert isinstance(result, dict)

    def test_show_schema_with_examples(self):
        """Test schema display with usage examples."""
        result, _ = self._run_show_schema(
            "Showing schema with examples",
            ["example_table"],
            [
                "Table: example_table\n# Description: Example usage\n## Derived metrics:\nSELECT COUNT(*) FROM example_table"
            ],
        )
        assert isinstance(result, list)
        assert "example_table" in result[0]
        assert "Derived metrics" in result[0]

    def test_search_knowledge_performance(self):
        """Test knowledge search performance characteristics."""
        result, _ = self._run_search(
            "Testing search performance",
            ["performance", "speed", "optimization"],
            "performance_metric: Performance measurement",
        )
        assert isinstance(result, dict)
        # Just ensure it completes without performance issues

    def test_search_knowledge_special_characters(self):
        """Test knowledge search with special characters."""
        result, _ = self._run_search(
            "Testing special character handling",
            ["user@domain", "data-point", "metric_value"],
            "user_email: User email address\ndata_point: Data measurement",
        )
        assert isinstance(result, dict)

    def test_search_knowledge_unicode_support(self):
        """Test knowledge search with unicode characters."""
        result, _ = self._run_search(
            "Testing unicode support", ["utilización", "données", "用户"], "user_data: International user data"
        )
        assert isinstance(result, dict)

    def test_knowledge_integration_with_state(self):
        """Test knowledge search integration with agent state."""
        # The tool should be callable in the context of agent state.
        result, _ = self._run_search(
            "Testing state integration", ["state", "integration"], "state_data: Application state information"
        )
        assert isinstance(result, dict)
        assert "columns" in result
================================================
FILE: tests/test_utils.py
================================================
"""Tests for utility functions."""
import io
from unittest.mock import patch
import pytest
from openchatbi.utils import log
class TestUtilityFunctions:
    """Test utility functions."""

    @staticmethod
    def _capture(*messages):
        """Log each message with sys.stderr patched; return everything written."""
        buffer = io.StringIO()
        with patch("sys.stderr", buffer):
            for message in messages:
                log(message)
        return buffer.getvalue()

    def test_log_function_basic(self):
        """Test basic logging functionality."""
        # log writes to stderr, not stdout.
        output = self._capture("Test message")
        assert "Test message" in output

    def test_log_function_multiple_messages(self):
        """Test logging with multiple messages."""
        output = self._capture("First message", "Second message")
        assert "First message" in output
        assert "Second message" in output

    def test_log_function_empty_message(self):
        """Test logging with empty message."""
        output = self._capture("")
        # Should handle empty messages gracefully
        assert output is not None

    def test_log_function_none_message(self):
        """Test logging with None message."""
        output = self._capture(None)
        # Should handle None messages gracefully
        assert "None" in output or output == ""

    def test_log_function_complex_objects(self):
        """Test logging with complex objects."""
        sample_dict = {"key": "value", "number": 42}
        sample_list = [1, 2, 3, "string"]
        output = self._capture(sample_dict, sample_list)
        assert "key" in output or str(sample_dict) in output
        assert "string" in output or str(sample_list) in output

    def test_log_function_with_exception(self):
        """Test logging exception information."""
        try:
            raise ValueError("Test exception")
        except ValueError as e:
            output = self._capture(f"Exception occurred: {e}")
            assert "Exception occurred" in output
            assert "Test exception" in output

    @patch("sys.stderr")
    def test_log_function_stderr_error(self, mock_stderr):
        """Test logging when stderr has issues."""
        mock_stderr.write.side_effect = OSError("stderr error")
        # Current implementation raises exception when stderr fails - this is expected
        with pytest.raises(OSError, match="stderr error"):
            log("Test message")

    def test_log_function_unicode_handling(self):
        """Test logging with unicode characters."""
        output = self._capture("Test with émojis: 🚀 and spéciál characters: ñáéíóú")
        # Should handle unicode characters properly
        assert len(output) > 0

    def test_log_function_large_message(self):
        """Test logging with very large messages."""
        # 10KB message
        output = self._capture("A" * 10000)
        assert len(output) > 0
        assert "A" in output

    def test_log_function_newline_handling(self):
        """Test logging with messages containing newlines."""
        output = self._capture("Line 1\\nLine 2\\nLine 3")
        assert "Line 1" in output
        assert "Line 2" in output
        assert "Line 3" in output

    def test_log_function_timestamp_format(self):
        """Test that log includes timestamp information."""
        output = self._capture("Timestamp test")
        # Check if output contains timestamp-like format (basic check)
        # The actual implementation might vary
        assert len(output) > len("Timestamp test")

    def test_log_function_concurrent_calls(self):
        """Test logging with concurrent-like calls."""
        import threading
        import time

        buffer = io.StringIO()

        def log_worker(message):
            log(f"Worker: {message}")
            time.sleep(0.01)  # Small delay

        # Patch stderr for all threads
        with patch("sys.stderr", buffer):
            # Create multiple threads (simulating concurrency)
            workers = [threading.Thread(target=log_worker, args=(f"message_{i}",)) for i in range(5)]
            # Start all threads
            for worker in workers:
                worker.start()
            # Wait for all threads
            for worker in workers:
                worker.join()
        output = buffer.getvalue()
        # Should handle concurrent access gracefully
        assert len(output) > 0
================================================
FILE: timeseries_forecasting/Dockerfile
================================================
FROM python:3.10-slim

# Install only essential build tools
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    wget \
    build-essential \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean

# Install Python dependencies for time series forecasting
RUN pip3 --no-cache-dir install \
    fastapi==0.120.4 \
    uvicorn==0.38.0 \
    transformers==4.40.1 \
    torch==2.9.0 \
    numpy==2.2.6 \
    pandas==2.3.3 \
    pydantic==2.12.3

# Set working directory
WORKDIR /home/model-server

# Copy the model. The source must be inside the build context: Docker rejects
# paths outside it, so "COPY ../hf_model" would fail the build. The README
# builds from the timeseries_forecasting/ directory, where hf_model/ lives.
COPY hf_model /home/model-server/hf_model

# Copy application files
COPY app.py model_handler.py /home/model-server/

# Set environment variables
ENV PYTHONPATH=/home/model-server
ENV PYTHONUNBUFFERED=1

# Expose port
EXPOSE 8765

# Define entrypoint and default command
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8765"]
================================================
FILE: timeseries_forecasting/README.md
================================================
# Transformer Time Series Forecasting Service
A Docker-based time series forecasting service using Transformer based models for accurate time series prediction. This service provides a FastAPI-based REST API for easy integration with various applications.
## Features
- **Transformer Model Integration**: Uses state-of-the-art Transformer models for time series forecasting
- **FastAPI Backend**: Modern, fast web framework with automatic API documentation
- **Docker Support**: Fully containerized service for easy deployment
- **Flexible Input**: Supports both simple numeric arrays and structured data with timestamps
- **Multiple Forecast Horizons**: Configure prediction length from 1 to 200 time steps
- **GPU Support**: Automatic GPU detection and utilization when available
## Prerequisites
- Docker installed and running
- Transformer model files (compatible with Hugging Face transformers library)
## Quick Start
### 1. Download Transformer Model
Download a pre-trained model from Hugging Face and place it in the `hf_model` directory. For example, use the recommended `timer-base-84m` model from https://huggingface.co/thuml/timer-base-84m:
> **Note**: The `timer-base-84m` model requires at least 96 time points in the input data. When integrating with OpenChatBI, add this restriction to your `extra_tool_use_rule` in bi.yaml:
> ```
> - timeseries_forecast tool requires at least 96 time points in input data. If no enough input data, set input_len to 96 to pad with zeros.
> ```
```bash
huggingface-cli download thuml/timer-base-84m --local-dir hf_model
```
### 2. Build and Run
```bash
cd timeseries_forecasting
chmod +x build_and_run.sh
./build_and_run.sh
```
The service will be available at:
- **Predictions**: `http://localhost:8765/predict`
- **Health Check**: `http://localhost:8765/health`
- **API Documentation**: `http://localhost:8765/docs`
- **Model Info**: `http://localhost:8765/model/info`
### 3. Make a Prediction
```bash
curl -X POST http://localhost:8765/predict \
-H "Content-Type: application/json" \
-d '{
"input": [100, 102, 98, 105, 107, 103, 99, 101, 104, 106],
"input_len": 100,
"forecast_window": 5,
"frequency": "H"
}'
```
### 4. Test the Service
Run the comprehensive test suite:
```bash
python test_forecasting.py --url http://localhost:8765
```
## API Reference
### Prediction Endpoint
**POST** `/predict`
#### Request Format
```json
{
"input": [...], // Time series data (required)
"forecast_window": 24, // Number of future points to predict (default: 24, max: 200)
"frequency": "H", // Frequency: "H" (hourly), "D" (daily), etc. (default: "H")
"input_len": null, // Limit input length, if provided, will use it to truncate input or pad zero (optional)
"target_column": "value" // Column name for structured data (default: "value")
}
```
#### Input Data Formats
**Simple Numeric Array:**
```json
{
"input": [100, 102, 98, 105, 107, 103, 99, 101],
"input_len": 100,
"forecast_window": 12
}
```
**Structured Data with Timestamps:**
```json
{
"input": [
{"timestamp": "2024-01-01T00:00:00", "value": 100},
{"timestamp": "2024-01-01T01:00:00", "value": 102},
{"timestamp": "2024-01-01T02:00:00", "value": 98}
],
"input_len": 100,
"forecast_window": 24,
"target_column": "value"
}
```
#### Response Format
```json
{
"predictions": [101.5, 103.2, 99.8, ...],
"forecast_window": 24,
"frequency": "H",
"status": "success"
}
```
## Configuration
### Environment Variables
- `PYTHONPATH`: Python path for modules (default: /home/model-server)
- `PYTHONUNBUFFERED`: Disable Python output buffering (default: 1)
### Docker Run Options
```bash
# Basic run
docker run -p 8765:8765 timeseries-forecasting
# With volume mount for models
docker run -p 8765:8765 \
-v /path/to/model:/app/hf_model \
timeseries-forecasting
# With custom environment variables
docker run -p 8765:8765 \
-e PYTHONPATH=/home/model-server \
timeseries-forecasting
```
## Testing
### Service Tests
Run the test script to validate the service:
```bash
# Make test script executable
chmod +x test_forecasting.py
# Install test dependencies
pip install requests numpy
# Run tests
python test_forecasting.py --url http://localhost:8765
```
## Model Information
- **Recommended Models**: https://huggingface.co/thuml/timer-base-84m
- **Model Type**: Transformer-based Causal Language Model for Time Series
- **Framework**: Hugging Face Transformers
- **Architecture**: AutoModelForCausalLM
- **Device Support**: Automatic GPU/CPU detection
- **Capabilities**: Univariate time series forecasting with automatic normalization
## Troubleshooting
### Common Issues
1. **Service Not Starting**
- Check if port 8765 is available: `lsof -i :8765`
- Verify Docker has sufficient memory allocated (minimum 4GB recommended)
- Check logs: `docker logs time-series-forecasting-service`
2. **Model Loading Errors**
- Ensure model files are properly copied during build
- Check available disk space (models can be several GB)
- Verify Hugging Face transformers library compatibility
3. **Prediction Errors**
- Validate input data format
- Check forecast horizon is reasonable
- Ensure input data has sufficient length
### Debug Mode
Enable debug logging:
```bash
docker run -p 8765:8765 \
-e PYTHONPATH=/home/model-server \
-e LOGGING_LEVEL=DEBUG \
timeseries-forecasting
```
## Performance
- **Cold Start**: ~10 seconds (model loading)
- **Inference Time**: ~100-300ms per request (varies by input size and model)
- **Memory Usage**: ~2-4GB (depending on input size and model)
- **Concurrent Requests**: Supported (configure workers)
## Limitations
- Maximum forecast window: 200 time points
- Univariate forecasting (single time series)
- Requires minimum input data for reliable predictions, timer-base-84m needs at least 96 time points
- Model-specific context length limitations may apply
================================================
FILE: timeseries_forecasting/app.py
================================================
"""app.py: FastAPI application for Transformer time series forecasting."""
import logging
import time
from typing import Any
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from starlette.requests import Request
from model_handler import TransformerModelHandler, get_model_handler
# Configure logging: INFO-level to the root logger; this module logs via `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create FastAPI app that serves the forecasting endpoints defined below.
app = FastAPI(
    title="Transformer Time Series Forecasting API",
    description="A REST API for time series forecasting using Transformer model",
    version="1.0.0",
)
# Request models
class ForecastRequest(BaseModel):
    """Request model for forecasting.

    Accepts either a plain numeric series or a list of row dicts whose target
    field is named by ``target_column``. ``forecast_window`` is validated to
    the service limit of 1..200 future points.
    """

    # Raw series: numbers, or structured rows (dicts) carrying the target column.
    input: list[float | int | dict[str, Any]] = Field(
        ...,
        description="Time series data as list of numbers or structured data",
        example=[100, 102, 98, 105, 107, 103, 99, 101],
    )
    # Number of future points to predict, bounded to 1..200.
    forecast_window: int = Field(default=24, ge=1, le=200, description="Number of future points to predict")
    # Optional length cap; the handler truncates or zero-pads the input to it.
    input_len: int | None = Field(default=None, description="Optional input length limit")
    # NOTE(review): default here is "hourly" while the README examples send "H" —
    # confirm the model handler accepts both spellings.
    frequency: str = Field(default="hourly", description="Frequency of the time series (hourly, daily, etc.)")
    # Field name to read when input rows are dicts.
    target_column: str = Field(default="value", description="Column name for structured data")
class ForecastResponse(BaseModel):
    """Response model for forecasting.

    Mirrors the success dict returned by ``model_handler.predict`` (see the
    ``/predict`` endpoint, which constructs this with ``ForecastResponse(**result)``).
    """

    # Forecasted values, one per requested future point.
    predictions: list[float] = Field(description="Forecasted values")
    # Number of predictions actually produced.
    forecast_window: int = Field(description="Number of predictions")
    # Echo of the requested time series frequency.
    frequency: str = Field(description="Time series frequency")
    # "success" on the happy path.
    status: str = Field(description="Response status")
class ErrorResponse(BaseModel):
    """Error response model returned with 4xx/5xx statuses by ``/predict``."""

    # Human-readable error message.
    error: str = Field(description="Error message")
    # Always "error" for this model.
    status: str = Field(description="Response status")
# Global variables
# Singleton model handler; assigned in startup_event, None until startup runs.
model_handler: TransformerModelHandler | None = None
# Epoch seconds when the app started; used by /health to report uptime.
startup_time: float | None = None
@app.on_event("startup")
async def startup_event():
    """Initialize the model handler when the application starts.

    Records the startup timestamp (for /health uptime) and attempts model
    initialization; failures are logged but do not prevent the app from serving.
    """
    global model_handler, startup_time

    startup_time = time.time()
    logger.info("Starting Transformer Forecasting API...")
    try:
        # Obtain the handler singleton and load the model weights.
        model_handler = get_model_handler()
        if model_handler.initialize():
            logger.info("Model initialized successfully")
        else:
            logger.error("Failed to initialize model")
    except Exception as e:
        logger.error(f"Startup failed: {str(e)}")
@app.get("/health")
async def health_check():
    """Health check endpoint: status, model readiness, and uptime in seconds."""
    if startup_time:
        uptime = time.time() - startup_time
    else:
        # Startup hook has not recorded a timestamp yet.
        uptime = 0
    initialized = model_handler.initialized if model_handler else False
    return {
        "status": "healthy",
        "model_initialized": initialized,
        "uptime_seconds": round(uptime, 2),
    }
@app.get("/ping")
async def ping():
    """Liveness probe: always replies with {"status": "ok"}."""
    return dict(status="ok")
@app.post(
    "/predict",
    response_model=ForecastResponse | ErrorResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Bad Request"},
        422: {"model": ErrorResponse, "description": "Validation Error"},
        500: {"model": ErrorResponse, "description": "Internal Error"},
    },
)
async def predict(request: ForecastRequest):
    """
    Main forecasting endpoint.

    Args:
        request: Forecast request containing time series data and parameters

    Returns:
        ForecastResponse with predictions on success, or a JSONResponse
        carrying an ErrorResponse payload with the matching HTTP status.
    """
    try:
        logger.info(f"Received prediction request: {len(request.input)} data points, horizon={request.forecast_window}")
        # The model is loaded once at startup; refuse requests until it is ready.
        if not model_handler or not model_handler.initialized:
            raise HTTPException(status_code=500, detail="Model not initialized")
        # Validate input
        if len(request.input) == 0:
            raise HTTPException(status_code=400, detail="Input data cannot be empty")
        # Make prediction
        result = model_handler.predict(
            time_series_data=request.input,
            forecast_window=request.forecast_window,
            input_len=request.input_len,
            frequency=request.frequency,
            target_column=request.target_column,
        )
        # The handler reports failures as an error dict instead of raising;
        # translate that into the matching HTTP status (defaults to 500).
        if result.get("status") == "error":
            raise HTTPException(status_code=result.get("code", 500), detail=result.get("error", "Prediction failed"))
        logger.info(f"Prediction successful: {len(result['predictions'])} predictions generated")
        return ForecastResponse(**result)
    except HTTPException as e:
        # Fix: use e.detail rather than str(e) -- str() of a starlette
        # HTTPException is "<status_code>: <detail>", which leaked the status
        # code into the error field of the JSON body.
        return JSONResponse(
            status_code=e.status_code, content=ErrorResponse(error=str(e.detail), status="error").model_dump()
        )
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        return JSONResponse(status_code=500, content=ErrorResponse(error=str(e), status="error").model_dump())
@app.get("/model/info")
async def model_info():
    """Expose basic information about the loaded model, or an error dict."""
    handler = model_handler
    if not handler or not handler.initialized:
        return {"error": "Model not initialized", "status": "error"}
    config_repr = str(handler.config) if handler.config else None
    return {
        "model_path": handler.model_path,
        "device": str(handler.device),
        "initialized": handler.initialized,
        "config": config_repr,
    }
@app.get("/")
async def root():
    """Root endpoint describing the API and its available routes."""
    endpoints = {
        "predict": "/predict",
        "health": "/health",
        "ping": "/ping",
        "model_info": "/model/info",
        "docs": "/docs",
    }
    return {
        "name": "Transformer Time Series Forecasting API",
        "version": "1.0.0",
        "description": "REST API for time series forecasting using Transformer model",
        "endpoints": endpoints,
    }
# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render HTTPExceptions raised outside explicit handlers as JSON errors."""
    payload = {"status": "error", "message": exc.detail}
    return JSONResponse(status_code=exc.status_code, content=payload)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the error, hide internals from the client."""
    logger.error(f"Unhandled exception: {str(exc)}")
    payload = {"status": "error", "message": "Internal server error"}
    return JSONResponse(status_code=500, content=payload)
if __name__ == "__main__":
    # For development: auto-reload on code changes; binds all interfaces on
    # port 8765 (matches the Docker port mapping used by build_and_run.sh).
    uvicorn.run("app:app", host="0.0.0.0", port=8765, reload=True, log_level="info")
================================================
FILE: timeseries_forecasting/build_and_run.sh
================================================
#!/bin/bash
# Build and run script for time series forecasting service.
#
# Steps:
#   1. Copy the local Hugging Face model into the build context.
#   2. Build the Docker image.
#   3. (Re)start the service container on port 8765.
#   4. Verify the container stays up and print initial logs.
set -e

echo "=== Building Timeseries Forecasting Docker Container ==="

# Check if the hf_model model directory exists
MODEL_DIR="../hf_model"
if [ ! -d "$MODEL_DIR" ]; then
    echo "Error: Hugging face model directory not found at $MODEL_DIR"
    echo "Please ensure the model is downloaded and available at this location"
    exit 1
fi
echo "✓ Found Hugging face model at: $MODEL_DIR"

# Refresh the model copy inside the build context (quote paths for safety)
rm -rf hf_model
cp -r "$MODEL_DIR" .

# Build the Docker image.
# Fix: under `set -e` a failing command aborts the script immediately, so the
# old `if [ $? -eq 0 ]` pattern could never reach its else branch and the
# failure message was unreachable. Test the command itself instead.
echo "Building Docker image..."
if docker build -t timeseries-forecasting .; then
    echo "✓ Docker image built successfully"
else
    echo "✗ Failed to build Docker image"
    exit 1
fi

# Check if container is already running
CONTAINER_NAME="time-series-forecasting-service"
if [ "$(docker ps -q -f name="$CONTAINER_NAME")" ]; then
    echo "Stopping existing container..."
    docker stop "$CONTAINER_NAME"
    docker rm "$CONTAINER_NAME"
fi

echo "=== Starting Timeseries Forecasting Service ==="

# Run the container detached, mapping the API port to the host
if docker run -d \
    --name "$CONTAINER_NAME" \
    -p 8765:8765 \
    timeseries-forecasting; then
    echo "✓ Container started successfully"
    echo ""
    echo "Service endpoints:"
    echo "  - Predictions: http://localhost:8765/predict"
    echo "  - Health Check: http://localhost:8765/health"
    echo "  - API Docs: http://localhost:8765/docs"
    echo ""
    echo "Container logs:"
    echo "  docker logs -f $CONTAINER_NAME"
    echo ""
    echo "To test the service:"
    echo "  python test_forecasting.py"
    echo ""
    echo "To stop the service:"
    echo "  docker stop $CONTAINER_NAME"
else
    echo "✗ Failed to start container"
    exit 1
fi

# Wait a moment and check if container is still running
sleep 5
if [ "$(docker ps -q -f name="$CONTAINER_NAME")" ]; then
    echo "✓ Service is running"
    # Show few logs
    echo ""
    echo "=== Initial Service Logs ==="
    docker logs "$CONTAINER_NAME" | head -n 50
else
    echo "✗ Service failed to start"
    echo "Checking logs..."
    docker logs "$CONTAINER_NAME"
    exit 1
fi
================================================
FILE: timeseries_forecasting/model_handler.py
================================================
"""model_handler.py: Transformer based model handler for time series forecasting."""
import logging
from typing import Any
import numpy as np
import pandas as pd
import torch
from transformers import AutoConfig, AutoModelForCausalLM
# Configure logging
logging.basicConfig(level=logging.INFO)  # INFO-level root logger for the module
logger = logging.getLogger(__name__)  # module-level logger used by the handler class
class TransformerModelHandler:
    """
    Transformer based Model handler for time series forecasting.

    Pipeline: preprocess (extract values, z-score normalize, tensorize)
    -> inference (autoregressive generation) -> postprocess (denormalize,
    trim to the forecast window).
    """

    def __init__(self, model_path: str = "hf_model"):
        """Initialize the model handler; weights are loaded lazily.

        Args:
            model_path: Directory containing the Hugging Face model files.
        """
        logger.info("Initializing Transformer Model Handler")
        self.model_path = model_path
        self.model = None  # set by initialize()
        self.config = None  # AutoConfig instance, set by initialize()
        self.device = None  # torch.device, set by initialize()
        self.initialized = False

    def initialize(self) -> bool:
        """
        Initialize model: load config and weights, move to device, eval mode.

        Returns:
            bool: True if initialization successful
        """
        try:
            logger.info("Starting model initialization")
            # Prefer GPU when available
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")
            logger.info(f"Loading model from: {self.model_path}")
            # trust_remote_code: the checkpoint ships custom modeling code
            self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path, config=self.config, trust_remote_code=True
            )
            self.model.to(self.device)
            self.model.eval()
            self.initialized = True
            logger.info("Transformer model loaded successfully")
            logger.info(f"Model config: {self.config}")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            self.initialized = False
            return False

    def preprocess(
        self,
        time_series_data: list,
        forecast_window: int = 24,
        input_len: int | None = None,
        frequency: str = "hourly",
        target_column: str = "value",
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        """
        Transform raw input into model input data.

        Args:
            time_series_data: Input time series data (list of numbers, or
                list of dicts for structured records)
            forecast_window: Number of future points to predict
            input_len: Optional input length limit (pad or truncate)
            frequency: Frequency of the time series
            target_column: Column name for structured data

        Returns:
            Tuple of (processed_tensor, metadata) where metadata carries the
            normalization stats needed by postprocess().

        Raises:
            ValueError: If structured input has no usable numeric column.
        """
        try:
            logger.info(f"Input data length: {len(time_series_data) if isinstance(time_series_data, list) else 'N/A'}")
            logger.info(f"Forecast window: {forecast_window}")
            # Extract a 1-D numeric array from the supported input shapes
            if isinstance(time_series_data, list):
                if len(time_series_data) > 0 and isinstance(time_series_data[0], dict):
                    # Structured records, e.g. {"timestamp": ..., "value": ...}
                    df = pd.DataFrame(time_series_data)
                    if target_column in df.columns:
                        values = df[target_column].values
                    else:
                        # Fall back to the first numeric column
                        numeric_cols = df.select_dtypes(include=[np.number]).columns
                        if len(numeric_cols) > 0:
                            values = df[numeric_cols[0]].values
                        else:
                            # Fix: the previous fallback called float(x) on the
                            # dicts themselves, raising an opaque TypeError
                            # (surfaced as HTTP 500). Raise ValueError so
                            # predict() maps this to a 400 with a clear message.
                            raise ValueError(
                                f"No numeric column found in structured input "
                                f"(target_column='{target_column}' not present)"
                            )
                else:
                    # Handle simple numeric list
                    values = np.array([float(x) for x in time_series_data])
            else:
                values = np.array(time_series_data)
            # Handle input length constraint
            if input_len is not None:
                if input_len > len(values):
                    # Left-pad with zeros if input is shorter than required
                    values = np.pad(values, (input_len - len(values), 0), mode="constant", constant_values=0)
                elif input_len < len(values):
                    # Take the last input_len values
                    values = values[-input_len:]
            # Normalize the data (simple z-score normalization); mean/std are
            # stored in metadata so postprocess() can invert the transform.
            mean_val = np.mean(values)
            std_val = np.std(values)
            if std_val > 0:
                normalized_values = (values - mean_val) / std_val
            else:
                # Constant series: only center, to avoid division by zero
                normalized_values = values - mean_val
            # Convert to a (1, seq_len) float tensor on the model device
            tensor = torch.tensor(normalized_values, dtype=torch.float32).unsqueeze(0)
            tensor = tensor.to(self.device)
            # Store metadata for post-processing
            metadata = {
                "mean": mean_val,
                "std": std_val,
                "forecast_window": forecast_window,
                "frequency": frequency,
                "original_length": len(values),  # NOTE: length after pad/truncate
            }
            logger.info(f"Preprocessed tensor shape: {tensor.shape}")
            return tensor, metadata
        except Exception as e:
            logger.error(f"Preprocessing failed: {str(e)}")
            raise e

    def inference(self, input_tensor: torch.Tensor, metadata: dict[str, Any]) -> torch.Tensor:
        """
        Run inference on the model.

        Args:
            input_tensor: Preprocessed input tensor
            metadata: Preprocessing metadata (reads "forecast_window")

        Returns:
            Model output tensor

        Raises:
            RuntimeError: If the model has not been initialized.
        """
        try:
            if not self.initialized:
                raise RuntimeError("Model not initialized")
            with torch.no_grad():
                forecast_window = metadata.get("forecast_window", 24)
                # Autoregressive generation of forecast_window new points
                output = self.model.generate(input_tensor, max_new_tokens=forecast_window)
                logger.info(f"Model output shape: {output.shape}")
                return output
        except ValueError as e:
            # Kept separate so predict() can map ValueErrors to HTTP 400
            logger.error(f"Inference failed due to ValueError: {str(e)}")
            raise e
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise e

    def postprocess(self, output_tensor: torch.Tensor, metadata: dict[str, Any]) -> list[float]:
        """
        Transform model output to final prediction format.

        Args:
            output_tensor: Raw model output
            metadata: Preprocessing metadata (mean/std for denormalization)

        Returns:
            Final predictions as list, trimmed to the forecast window
        """
        try:
            # Drop the batch dimension if present
            if output_tensor.dim() > 1:
                predictions = output_tensor[0].cpu().numpy()
            else:
                predictions = output_tensor.cpu().numpy()
            # Invert the z-score normalization applied in preprocess()
            mean_val = metadata.get("mean", 0)
            std_val = metadata.get("std", 1)
            if std_val > 0:
                denormalized_predictions = predictions * std_val + mean_val
            else:
                denormalized_predictions = predictions + mean_val
            # Trim to the requested horizon
            forecast_window = metadata.get("forecast_window", 24)
            result = denormalized_predictions[:forecast_window].tolist()
            logger.info(f"Final predictions length: {len(result)}")
            return result
        except Exception as e:
            logger.error(f"Postprocessing failed: {str(e)}")
            raise e

    def predict(
        self,
        time_series_data: list,
        forecast_window: int = 24,
        input_len: int | None = None,
        frequency: str = "hourly",
        target_column: str = "value",
    ) -> dict[str, Any]:
        """
        Main prediction method.

        Args:
            time_series_data: Input time series data
            forecast_window: Number of future points to predict
            input_len: Optional input length limit
            frequency: Frequency of the time series
            target_column: Column name for structured data

        Returns:
            Dictionary containing predictions and metadata on success, or an
            error dict with "status": "error" (and "code": 400 for bad input).
        """
        try:
            # Lazily (re)initialize if needed
            if not self.initialized:
                if not self.initialize():
                    raise RuntimeError("Failed to initialize model")
            # Preprocess input
            input_tensor, metadata = self.preprocess(
                time_series_data, forecast_window, input_len, frequency, target_column
            )
            # Run inference
            output_tensor = self.inference(input_tensor, metadata)
            # Postprocess output
            predictions = self.postprocess(output_tensor, metadata)
            # Format result
            result = {
                "predictions": predictions,
                "forecast_window": metadata.get("forecast_window", 24),
                "frequency": metadata.get("frequency", "hourly"),
                "status": "success",
            }
            return result
        except ValueError as e:
            # Bad input data: signal a client error (HTTP 400) to the caller
            logger.error(f"Prediction failed due to ValueError: {str(e)}")
            return {"error": str(e), "code": 400, "status": "error"}
        except Exception as e:
            logger.error(f"Prediction failed: {str(e)}")
            return {"error": str(e), "status": "error"}
# Global model handler instance
_model_handler = None


def get_model_handler() -> TransformerModelHandler:
    """Return the process-wide model handler, creating it on first use."""
    global _model_handler
    if _model_handler is not None:
        return _model_handler
    # First call: build and cache the singleton
    _model_handler = TransformerModelHandler()
    return _model_handler
================================================
FILE: timeseries_forecasting/test_forecasting.py
================================================
#!/usr/bin/env python3
"""test_forecasting.py: Test script for Timer forecasting service."""
import time
from datetime import datetime, timedelta
import numpy as np
import requests
from requests.exceptions import RequestException
class TimeseriesForecastingTester:
    """Integration-test driver for the Timer forecasting HTTP service.

    Each ``test_*`` method prints its outcome and returns True/False;
    ``run_all_tests`` waits for the service to come up, then aggregates
    the individual results.
    """

    def __init__(self, base_url="http://localhost:8765"):
        """Initialize the tester.

        Args:
            base_url: Root URL of the running forecasting service.
        """
        self.base_url = base_url
        self.predictions_endpoint = f"{base_url}/predict"
        self.health_endpoint = f"{base_url}/health"

    def generate_sample_data(self, length=100, frequency="H"):
        """Generate sample time series data for testing.

        Builds a synthetic series: linear trend + sinusoidal seasonality
        (daily cycle for hourly data, weekly otherwise) + Gaussian noise,
        centered around 100.

        Args:
            length: Number of points to generate.
            frequency: "H" for hourly data, anything else for daily.

        Returns:
            list[float]: The generated series.
        """
        t = np.arange(length)
        # Add trend
        trend = 0.1 * t
        # Add seasonality (daily pattern for hourly data)
        if frequency == "H":
            seasonality = 5 * np.sin(2 * np.pi * t / 24)
        else:
            seasonality = 3 * np.sin(2 * np.pi * t / 7)  # Weekly pattern for daily data
        # Add noise
        noise = np.random.normal(0, 1, length)
        # Combine components
        values = 100 + trend + seasonality + noise
        return values.tolist()

    def test_basic_forecasting(self):
        """Test basic time series forecasting with a plain numeric list."""
        print("\n=== Testing Basic Forecasting ===")
        # Generate sample data
        sample_data = self.generate_sample_data(length=168, frequency="H")  # 1 week of hourly data
        # Prepare request payload
        payload = {
            "input": sample_data,
            "forecast_window": 24,  # Forecast next 24 hours
            "frequency": "H",
            "input_len": 168,  # Use last week of data
        }
        # Send request
        try:
            response = requests.post(
                self.predictions_endpoint, json=payload, headers={"Content-Type": "application/json"}, timeout=30
            )
            if response.status_code == 200:
                result = response.json()
                print("✓ Basic forecasting successful")
                print(f"  - Input length: {len(sample_data)}")
                print(f"  - Forecast Window: {payload['forecast_window']}")
                print(f"  - Predictions length: {len(result.get('predictions', []))}")
                print(f"  - Sample predictions: {result.get('predictions', [])[:5]}")
                return True
            else:
                print(f"✗ Request failed with status: {response.status_code}")
                print(f"  Response: {response.text}")
                return False
        except RequestException as e:
            # Consistency: use the imported RequestException alias throughout
            print(f"✗ Request failed: {str(e)}")
            return False

    def test_structured_data(self):
        """Test forecasting with structured data (timestamps + values)."""
        print("\n=== Testing Structured Data Forecasting ===")
        # Generate structured data with timestamps
        start_time = datetime.now() - timedelta(days=7)
        structured_data = []
        for i in range(168):  # 1 week of hourly data
            timestamp = start_time + timedelta(hours=i)
            value = 100 + 0.1 * i + 5 * np.sin(2 * np.pi * i / 24) + np.random.normal(0, 1)
            structured_data.append({"timestamp": timestamp.isoformat(), "value": value})
        # Prepare request payload
        payload = {
            "input": structured_data,
            "forecast_window": 48,  # Forecast next 48 hours
            "frequency": "H",
            "target_column": "value",
        }
        # Send request
        try:
            response = requests.post(
                self.predictions_endpoint, json=payload, headers={"Content-Type": "application/json"}, timeout=30
            )
            if response.status_code == 200:
                result = response.json()
                print("✓ Structured data forecasting successful")
                print(f"  - Input records: {len(structured_data)}")
                print(f"  - Forecast Window: {payload['forecast_window']}")
                print(f"  - Predictions length: {len(result.get('predictions', []))}")
                return True
            else:
                print(f"✗ Request failed with status: {response.status_code}")
                print(f"  Response: {response.text}")
                return False
        except RequestException as e:
            print(f"✗ Request failed: {str(e)}")
            return False

    def test_different_windows(self):
        """Test forecasting with different forecast windows."""
        print("\n=== Testing Different Forecast Horizons ===")
        sample_data = self.generate_sample_data(length=100)
        windows = [1, 12, 24, 48, 72]
        for window in windows:
            payload = {"input": sample_data, "forecast_window": window, "frequency": "H"}
            try:
                response = requests.post(
                    self.predictions_endpoint, json=payload, headers={"Content-Type": "application/json"}, timeout=30
                )
                if response.status_code == 200:
                    result = response.json()
                    predictions_len = len(result.get("predictions", []))
                    print(f"✓ Window {window}: {predictions_len} predictions")
                else:
                    print(f"✗ Window {window}: Failed with status {response.status_code}")
                    return False
            except RequestException as e:
                print(f"✗ Window {window}: Request failed - {str(e)}")
                return False
        return True

    def test_error_handling(self):
        """Test error handling with invalid inputs (empty list, bad JSON)."""
        print("\n=== Testing Error Handling ===")
        # Test empty input -> the service should answer 400
        try:
            response = requests.post(
                self.predictions_endpoint, json={"input": []}, headers={"Content-Type": "application/json"}, timeout=10
            )
            print(f"Empty input: Status {response.status_code}")
            if response.status_code != 400:
                print("✗ Empty input: Expected 400 status code")
                return False
        except RequestException:
            print("Empty input: exception occurred not expected")
            return False
        # Test invalid JSON -> FastAPI should reject with 422
        try:
            response = requests.post(
                self.predictions_endpoint, data="invalid json", headers={"Content-Type": "application/json"}, timeout=10
            )
            print(f"Invalid JSON: Status {response.status_code}")
            if response.status_code != 422:
                # Fixed copy-paste: these messages previously said "Empty input"
                print("✗ Invalid JSON: Expected 422 status code")
                return False
        except RequestException:
            print("Invalid JSON: exception occurred not expected")
            return False
        return True

    def test_health_check(self):
        """Test service health check."""
        print("\n=== Testing Service Health ===")
        try:
            # Test health endpoint
            response = requests.get(self.health_endpoint, timeout=5)
            if response.status_code == 200:
                result = response.json()
                print("✓ Service health check passed")
                print(f"  - Model initialized: {result.get('model_initialized', 'Unknown')}")
                print(f"  - Uptime: {result.get('uptime_seconds', 'Unknown')} seconds")
                return True
            else:
                print(f"✗ Health check failed: {response.status_code}")
                return False
        except RequestException as e:
            print(f"Health check failed: {str(e)}")
            return False

    def run_all_tests(self):
        """Run all tests.

        Polls /health for up to 30 seconds until the model is initialized,
        then runs every test method and reports the pass count.

        Returns:
            bool: True when every test passed.
        """
        print("=" * 50)
        print("TIMER FORECASTING SERVICE TESTS")
        print("=" * 50)
        # Wait for service to be ready
        print("Waiting for service to be ready...")
        for _i in range(30):  # Wait up to 30 seconds
            try:
                response = requests.get(self.health_endpoint, timeout=2)
                if response.status_code == 200:
                    result = response.json()
                    if result.get("model_initialized", False):
                        print("✓ Service is ready")
                        break
            except RequestException:
                pass
            time.sleep(1)
        else:
            # for/else: loop exhausted without a break -> never became ready
            print("✗ Service not ready after 30 seconds")
            return False
        # Run tests
        tests = [
            self.test_health_check,
            self.test_basic_forecasting,
            self.test_structured_data,
            self.test_different_windows,
            self.test_error_handling,
        ]
        passed = 0
        total = len(tests)
        for test in tests:
            try:
                if test():
                    passed += 1
            except Exception as e:
                # A crashing test counts as a failure but doesn't stop the suite
                print(f"✗ Test {test.__name__} failed with exception: {str(e)}")
        print("\n" + "=" * 50)
        print(f"TESTS COMPLETED: {passed}/{total} passed")
        print("=" * 50)
        return passed == total
def main():
    """Parse CLI arguments and run the full test suite against the service.

    Exits with status 0 when every test passes, 1 otherwise, so CI can
    gate on this script.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Test Timer forecasting service")
    # Fix: the help text previously claimed the default was port 8080,
    # while the actual default is 8765 (the port the service binds to).
    parser.add_argument(
        "--url", default="http://localhost:8765", help="Base URL of the service (default: http://localhost:8765)"
    )
    args = parser.parse_args()
    tester = TimeseriesForecastingTester(args.url)
    success = tester.run_all_tests()
    exit(0 if success else 1)
if __name__ == "__main__":
    # Allow running directly: `python test_forecasting.py --url http://host:8765`
    main()