Repository: zhongyu09/openchatbi Branch: main Commit: 428f5d88bb12 Files: 133 Total size: 701.7 KB Directory structure: gitextract_wkfrsja_/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── feature_request.md │ └── workflows/ │ ├── docs.yml │ ├── publish.yml │ └── runledger.yml ├── .gitignore ├── CONTRIBUTING.md ├── Dockerfile.python-executor ├── LICENSE ├── README.md ├── baselines/ │ └── runledger-openchatbi.json ├── docs/ │ ├── Makefile │ ├── make.bat │ └── source/ │ ├── _templates/ │ │ └── layout.html │ ├── catalog.rst │ ├── code.rst │ ├── conf.py │ ├── config.rst │ ├── core.rst │ ├── index.rst │ ├── llm.rst │ ├── text2sql.rst │ ├── timeseries.rst │ └── tools.rst ├── evals/ │ ├── __init__.py │ └── runledger/ │ ├── README.md │ ├── __init__.py │ ├── agent/ │ │ └── agent.py │ ├── cases/ │ │ └── t1.yaml │ ├── cassettes/ │ │ └── t1.jsonl │ ├── schema.json │ ├── suite.yaml │ └── tools.py ├── example/ │ ├── bi.yaml │ ├── common_columns.csv │ ├── config.yaml │ ├── sql_example.yaml │ ├── table_columns.csv │ ├── table_info.yaml │ └── table_selection_example.csv ├── openchatbi/ │ ├── __init__.py │ ├── agent_graph.py │ ├── catalog/ │ │ ├── __init__.py │ │ ├── catalog_loader.py │ │ ├── catalog_store.py │ │ ├── factory.py │ │ ├── helper.py │ │ ├── retrival_helper.py │ │ ├── schema_retrival.py │ │ ├── store/ │ │ │ ├── __init__.py │ │ │ └── file_system.py │ │ └── token_service.py │ ├── code/ │ │ ├── docker_executor.py │ │ ├── executor_base.py │ │ ├── local_executor.py │ │ └── restricted_local_executor.py │ ├── config.yaml.template │ ├── config_loader.py │ ├── constants.py │ ├── context_config.py │ ├── context_manager.py │ ├── graph_state.py │ ├── llm/ │ │ └── llm.py │ ├── prompts/ │ │ ├── agent_prompt.md │ │ ├── extraction_prompt.md │ │ ├── schema_linking_prompt.md │ │ ├── sql_dialect/ │ │ │ └── presto.md │ │ ├── summary_prompt.md │ │ ├── system_prompt.py │ │ ├── text2sql_prompt.md │ │ └── visualization_prompt.md │ ├── text2sql/ │ │ ├── __init__.py │ │ ├── data.py │ │ ├── extraction.py │ │ ├── generate_sql.py │ │ ├── schema_linking.py │ │ ├── sql_graph.py │ │ ├── text2sql_utils.py │ │ └── visualization.py │ ├── text_segmenter.py │ ├── tool/ │ │ ├── ask_human.py │ │ ├── mcp_tools.py │ │ ├── memory.py │ │ ├── run_python_code.py │ │ ├── save_report.py │ │ ├── search_knowledge.py │ │ └── timeseries_forecast.py │ └── utils.py ├── pyproject.toml ├── run_streamlit_ui.py ├── run_tests.py ├── sample_api/ │ └── async_api.py ├── sample_ui/ │ ├── async_graph_manager.py │ ├── memory_ui.py │ ├── plotly_utils.py │ ├── simple_ui.py │ ├── streaming_ui.py │ ├── streamlit_ui.py │ └── style.py ├── tests/ │ ├── README.md │ ├── __init__.py │ ├── conftest.py │ ├── context_management/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_agent_graph_integration.py │ │ ├── test_context_config.py │ │ ├── test_context_manager.py │ │ ├── test_edge_cases.py │ │ ├── test_runner.py │ │ └── test_state_operations.py │ ├── test_catalog_loader.py │ ├── test_catalog_store.py │ ├── test_config_loader.py │ ├── test_graph_state.py │ ├── test_incomplete_tool_calls.py │ ├── test_memory.py │ ├── test_plotly_utils.py │ ├── test_simple_store.py │ ├── test_text2sql_extraction.py │ ├── test_text2sql_generate_sql.py │ ├── test_text2sql_schema_linking.py │ ├── test_text2sql_visualization.py │ ├── test_tools_ask_human.py │ ├── test_tools_run_python_code.py │ ├── test_tools_search_knowledge.py │ └── test_utils.py └── timeseries_forecasting/ ├── Dockerfile ├── README.md ├── app.py ├── build_and_run.sh ├── 
model_handler.py └── test_forecasting.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: '' assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. Go to '...' 2. Click on '....' 3. Scroll down to '....' 4. See error **Expected behavior** A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your problem. **Additional context** Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: '' assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. ================================================ FILE: .github/workflows/docs.yml ================================================ name: Build and Deploy Documentation on: push: branches: [ main ] pull_request: branches: [ main ] permissions: contents: read pages: write id-token: write concurrency: group: "pages" cancel-in-progress: false jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 with: python-version: '3.11' - name: Install dependencies run: | python -m pip install --upgrade pip setuptools wheel pip install -e ".[docs]" - name: Build documentation run: | cd docs make html - name: Setup Pages uses: actions/configure-pages@v4 - name: Upload artifact uses: actions/upload-pages-artifact@v3 with: path: './docs/build/html' deploy: environment: name: github-pages url: ${{ steps.deployment.outputs.page_url }} runs-on: ubuntu-latest needs: build if: github.ref == 'refs/heads/main' steps: - name: Deploy to GitHub Pages id: deployment uses: actions/deploy-pages@v4 ================================================ FILE: .github/workflows/publish.yml ================================================ name: Publish to PyPI on: release: types: [published] # Trigger when a release is published workflow_dispatch: # Allow manual triggering jobs: test: runs-on: ubuntu-latest strategy: matrix: python-version: ['3.11', '3.12'] steps: - uses: actions/checkout@v4 - name: Install uv uses: astral-sh/setup-uv@v4 with: version: "latest" - name: Set up Python ${{ matrix.python-version }} run: uv python install ${{ matrix.python-version }} - name: Install dependencies run: | uv sync --all-extras - name: Run linting run: | uv run black --check . 
- name: Run tests run: | uv run pytest -v --cov=openchatbi --cov-report=xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: file: ./coverage.xml flags: unittests name: codecov-umbrella build: needs: test runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Install uv uses: astral-sh/setup-uv@v4 with: version: "latest" - name: Set up Python run: uv python install 3.11 - name: Build package run: | uv build - name: Check build artifacts run: | ls -la dist/ uv run twine check dist/* - name: Upload build artifacts uses: actions/upload-artifact@v4 with: name: dist path: dist/ publish: needs: build runs-on: ubuntu-latest if: github.event_name == 'release' environment: name: pypi url: https://pypi.org/p/openchatbi permissions: id-token: write # Required for PyPI trusted publishing contents: read steps: - name: Download build artifacts uses: actions/download-artifact@v4 with: name: dist path: dist/ - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: # Uses PyPI trusted publishing, no API token needed verbose: true print-hash: true publish-test: needs: build runs-on: ubuntu-latest if: github.event_name == 'workflow_dispatch' # Only publish to TestPyPI when manually triggered environment: name: testpypi url: https://test.pypi.org/p/openchatbi permissions: id-token: write contents: read steps: - name: Download build artifacts uses: actions/download-artifact@v4 with: name: dist path: dist/ - name: Publish to TestPyPI uses: pypa/gh-action-pypi-publish@release/v1 with: repository-url: https://test.pypi.org/legacy/ verbose: true print-hash: true ================================================ FILE: .github/workflows/runledger.yml ================================================ name: runledger on: workflow_dispatch: pull_request: paths: - "openchatbi/**" jobs: runledger: if: github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'runledger') runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: "3.11" - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install runledger python -m pip install . - name: Run deterministic evals (replay) run: | runledger run evals/runledger --mode replay --baseline baselines/runledger-openchatbi.json - name: Upload artifacts uses: actions/upload-artifact@v4 with: name: runledger-artifacts path: runledger_out/** ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[codz] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py.cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # UV # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. #uv.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock #poetry.toml # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. # https://pdm-project.org/en/latest/usage/project/#working-with-version-control #pdm.lock #pdm.toml .pdm-python .pdm-build/ # pixi # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. #pixi.lock # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one # in the .venv directory. It is recommended not to include this directory in version control. .pixi # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .envrc .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ # Abstra # Abstra is an AI-powered process automation framework. # Ignore directories containing user credentials, local state, and settings. 
# Learn more at https://abstra.io/docs .abstra/ # Visual Studio Code # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore # and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder # .vscode/ # Ruff stuff: .ruff_cache/ # PyPI configuration file .pypirc # Cursor # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data # refer to https://docs.cursor.com/context/ignore-files .cursorignore .cursorindexingignore # Marimo marimo/_static/ marimo/_lsp/ __marimo__/ # project spec openchatbi/config.yaml memory.db checkpoints.db data hf_model timeseries_forecasting/hf_model runledger_out/ ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to OpenChatBI Hi there! Thank you for your interest in contributing to OpenChatBI. OpenChatBI started as a personal project, with the hope of making it easier for businesses to build their own ChatBI applications with less effort. To achieve this goal, I made it open source, and I greatly appreciate contributions of all kinds. Whether you’d like to propose a new feature, refactor the code, enhance documentation, or fix bugs, your contributions are always welcome. ================================================ FILE: Dockerfile.python-executor ================================================ FROM python:3.11-slim # Set working directory WORKDIR /app # Install basic packages that might be needed for data analysis RUN pip install --no-cache-dir \ pandas \ numpy \ matplotlib \ seaborn \ requests \ json5 # Create a directory for code execution RUN mkdir -p /app/code # Set up a non-root user for security RUN useradd -m -u 1000 executor USER executor # Set the default command CMD ["python3"] ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2025 Yu Zhong Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # OpenChatBI OpenChatBI is an open source, chat-based intelligent BI tool powered by large language models, designed to help users query, analyze, and visualize data through natural language conversations. Built on the LangGraph and LangChain ecosystem, it provides chat agents and workflows that support natural language to SQL conversion and streamlined data analysis. Join the Slack channel to discuss: https://join.slack.com/t/openchatbicommunity/shared_invite/zt-3jpzpx9mv-Sk88RxpO4Up0L~YTZYf4GQ ## Core Features 1. **Natural Language Interaction**: Get data analysis results by asking questions in natural language 2. **Automatic SQL Generation**: Convert natural language queries into SQL statements using advanced text2sql workflows with schema linking and well-organized prompt engineering 3. **Data Visualization**: Generate intuitive data visualizations (via plotly) 4. **Data Catalog Management**: Automatically discovers and indexes database table structures, supports flexible catalog storage backends with vector-based or BM25-based retrieval, and makes it easy to maintain business explanations for tables and columns and to optimize prompts 5. **Time Series Forecasting**: Forecasting models deployed in-house that can be called as tools 6. **Code Execution**: Execute Python code for data analysis and visualization 7. **Interactive Problem-Solving**: Proactively asks users for more context when information is incomplete 8. **Persistent Memory**: Conversation management and user characteristic memory based on LangGraph checkpointing 9. **MCP Support**: Integration with MCP tools via configuration 10. **Knowledge Base Integration**: Answer complex questions by combining catalog-based knowledge retrieval and external knowledge base retrieval (via MCP tools) 11. **Web UI Interface**: Provides two sample UIs, simple and streaming, built with Gradio and Streamlit and easy to integrate into other web applications ## Roadmap 1. **Anomaly Detection Algorithm**: Time series anomaly detection 2. **Root Cause Analysis Algorithm**: Multi-dimensional drill-down capabilities for anomaly investigation # Getting started ## Installation & Setup ### Prerequisites - Python 3.11 or higher - Access to a supported LLM provider (OpenAI, Anthropic, etc.) - Data Warehouse (Database) credentials (like Presto, PostgreSQL, MySQL, etc.) - (Optional) Embedding model for vector-based retrieval - if not available, BM25-based retrieval will be used - (Optional) Docker - required only for `docker` executor mode **Note on Chinese Text Segmentation**: For better Chinese text retrieval, `jieba` is used for word segmentation. However, `jieba` is not compatible with Python 3.12+. On Python 3.12 and higher, the system automatically falls back to simple punctuation-based segmentation for Chinese text. ### Installation 1. **Using uv (recommended):** ```bash git clone git@github.com:zhongyu09/openchatbi cd openchatbi uv sync ``` 2. **Using pip:** ```bash pip install openchatbi ``` 3. **For development:** ```bash git clone git@github.com:zhongyu09/openchatbi cd openchatbi uv sync --group dev ``` Optional: If you want to use `pysqlite3` (newer SQLite builds), you can install it manually.
If the build fails, install SQLite first: On macOS, try to install sqlite using Homebrew: ```bash brew install sqlite brew info sqlite export LDFLAGS="-L/opt/homebrew/opt/sqlite/lib" export CPPFLAGS="-I/opt/homebrew/opt/sqlite/include" ``` On Amazon Linux / RHEL / CentOS: ```bash sudo yum install sqlite-devel ``` On Ubuntu / Debian: ```bash sudo apt-get update sudo apt-get install libsqlite3-dev ``` ### Run Demo Run the demo using an **example dataset** from the Spider dataset. You need to provide your OpenAI API key or change the config to use another LLM provider. **Note**: The demo example includes embedding model configuration. If you want to run without an embedding model, you can remove the `embedding_model` section in the config - BM25 retrieval will be used automatically. ```bash cp example/config.yaml openchatbi/config.yaml sed -i 's/YOUR_API_KEY_HERE/[YOUR OPENAI API KEY]/g' openchatbi/config.yaml python run_streamlit_ui.py ``` ### Configuration 1. **Create configuration file** Copy the configuration template: ```bash cp openchatbi/config.yaml.template openchatbi/config.yaml ``` Or create an empty YAML file. 2. **Configure your LLMs:** ```yaml # Select which provider to use default_llm: openai # Define one or more providers llm_providers: openai: default_llm: class: langchain_openai.ChatOpenAI params: api_key: YOUR_API_KEY_HERE model: gpt-4.1 temperature: 0.02 max_tokens: 8192 # Optional: Embedding model for vector-based retrieval and memory tools # If not configured, BM25-based retrieval will be used, and the memory tools will not work embedding_model: class: langchain_openai.OpenAIEmbeddings params: api_key: YOUR_API_KEY_HERE model: text-embedding-3-large chunk_size: 1024 ``` 3. **Configure your data warehouse:** ```yaml organization: Your Company dialect: presto data_warehouse_config: uri: "presto://user@host:8080/catalog/schema" include_tables: - your_table_name database_name: "catalog.schema" ``` ### Running the Application 1. **Invoking LangGraph:** ```bash export CONFIG_FILE=YOUR_CONFIG_FILE_PATH ``` ```python from openchatbi import get_default_graph graph = get_default_graph() graph.invoke({"messages": [{"role": "user", "content": "Show me ctr trends for the past 7 days"}]}, config={"configurable": {"thread_id": "1"}}) ``` ``` # System-generated SQL SELECT date, SUM(clicks)/SUM(impression) AS ctr FROM ad_performance WHERE date >= CURRENT_DATE - INTERVAL '7' DAY GROUP BY date ORDER BY date; ``` 2. **Sample Web UI:** Run the Streamlit based UI: ```bash streamlit run sample_ui/streamlit_ui.py ``` Run the Gradio based UI: ```bash python sample_ui/streaming_ui.py ```
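For programmatic use beyond the single call shown in step 1, conversation state is checkpointed per `thread_id`, so follow-up questions on the same thread keep their context. A minimal multi-turn sketch (same `get_default_graph` API as step 1; the follow-up question and thread id are illustrative):

```python
# Minimal sketch: multi-turn conversation on one checkpointed thread.
# Assumes CONFIG_FILE is exported as shown in "Invoking LangGraph" above.
from openchatbi import get_default_graph

graph = get_default_graph()
thread = {"configurable": {"thread_id": "demo-1"}}  # one thread_id = one conversation

for question in [
    "Show me ctr trends for the past 7 days",
    "Now break that down by day of week",
]:
    result = graph.invoke({"messages": [{"role": "user", "content": question}]}, config=thread)
    # The graph returns the accumulated message list; the last entry is the assistant reply.
    print(result["messages"][-1].content)
```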
## Configuration Instructions The configuration template is provided at `config.yaml.template`. Key configuration sections include: ### Basic Settings - `organization`: Organization name (e.g., "Your Company") - `dialect`: Database dialect (e.g., "presto") - `bi_config_file`: Path to BI configuration file (e.g., "example/bi.yaml") ### Catalog Store Configuration - `catalog_store`: Configuration for data catalog storage - `store_type`: Storage type (e.g., "file_system") - `data_path`: Path to catalog data stored by file system (e.g., "./example") ### Data Warehouse Configuration - `data_warehouse_config`: Database connection settings - `uri`: Connection string for your database - `include_tables`: List of tables to include in the catalog; leave empty to include all tables - `database_name`: Database name for the catalog - `token_service`: Token service URL (for data warehouses that need token authentication, such as Presto) - `user_name` / `password`: Token service credentials ### LLM Configuration Various LLMs are supported via LangChain; see the LangChain API reference (https://python.langchain.com/api_reference/reference.html#integrations) for the full list of integrations that provide `chat_models`. You can configure different LLMs for different tasks: - `default_llm`: Primary language model for general tasks - `embedding_model`: (Optional) Model for embedding generation. If not configured, BM25-based text retrieval will be used as a fallback, and the memory tools will not work - `text2sql_llm`: (Optional) Specialized model for SQL generation. If not configured, uses `default_llm` Multiple providers (optional): - Configure multiple providers under `llm_providers` and select one with `default_llm: <provider_name>`. - In `sample_ui/streamlit_ui.py`, a provider dropdown appears when `llm_providers` is configured. - In `sample_api/async_api.py`, pass `provider` in the `/chat/stream` request body. Commonly used LLM providers with their corresponding classes and installation commands (a config-driven instantiation sketch follows this list): - **Anthropic**: `langchain_anthropic.ChatAnthropic`, `pip install langchain-anthropic` - **OpenAI**: `langchain_openai.ChatOpenAI`, `pip install langchain-openai` - **Azure OpenAI**: `langchain_openai.AzureChatOpenAI`, `pip install langchain-openai` - **Google Vertex AI**: `langchain_google_vertexai.ChatVertexAI`, `pip install langchain-google-vertexai` - **Bedrock**: `langchain_aws.ChatBedrock`, `pip install langchain-aws` - **Huggingface**: `langchain_huggingface.ChatHuggingFace`, `pip install langchain-huggingface` - **Deepseek**: `langchain_deepseek.ChatDeepSeek`, `pip install langchain-deepseek` - **Ollama**: `langchain_ollama.ChatOllama`, `pip install langchain-ollama`
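Each provider entry is just a LangChain chat-model class path plus constructor parameters, so switching providers is a pure config change. A rough illustration of the mechanism such a `class`/`params` entry drives (not OpenChatBI's actual loader code, just the general pattern):

```python
# Illustrative only: build a LangChain chat model from a class/params entry
# shaped like the YAML above.
from importlib import import_module

def build_chat_model(entry: dict):
    module_path, class_name = entry["class"].rsplit(".", 1)
    cls = getattr(import_module(module_path), class_name)  # e.g. langchain_openai.ChatOpenAI
    return cls(**entry["params"])

llm = build_chat_model({
    "class": "langchain_openai.ChatOpenAI",
    "params": {"api_key": "YOUR_API_KEY_HERE", "model": "gpt-4.1", "temperature": 0.02},
})
```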
### Advanced Configuration OpenChatBI supports sophisticated customization through prompt engineering and catalog management features: - **Prompt Engineering Configuration**: Customize system prompts, business glossaries, and data warehouse introductions - **Data Catalog Management**: Configure table metadata, column descriptions, and SQL generation rules - **Business Rules**: Define table selection criteria and domain-specific SQL constraints - **Forecasting Service**: Configure the forecasting service URL and prompt based on your own deployment For detailed configuration options and examples, see the [Advanced Features](#advanced-features) section. ## Architecture Overview OpenChatBI is built using a modular architecture with clear separation of concerns: 1. **LangGraph Workflows**: Core orchestration using state machines for complex multi-step processes 2. **Catalog Management**: Flexible data catalog system with intelligent retrieval (vector-based or BM25 fallback) 3. **Text2SQL Pipeline**: Advanced natural language to SQL conversion with schema linking 4. **Code Execution**: Sandboxed Python execution environment for data analysis 5. **Tool Integration**: Extensible tool system for human interaction and knowledge search 6. **Persistent Memory**: SQLite-based conversation state management ## Technology Stack - **Frameworks**: LangGraph, LangChain, FastAPI, Gradio/Streamlit - **Large Language Models**: Azure OpenAI (GPT-4), Anthropic Claude, OpenAI GPT models - **Text Retrieval**: Vector-based (with embedding models) or BM25-based (fallback without embeddings) - **Databases**: Presto, Trino, MySQL with SQLAlchemy support - **Code Execution**: Local Python, RestrictedPython, Docker containerization - **Development**: Python 3.11+, with modern tooling (Black, Ruff, MyPy, Pytest) - **Storage**: SQLite for conversation checkpointing, file system catalog storage ### Agent Graph *(agent workflow diagram)* ### Text2SQL Graph *(Text2SQL workflow diagram)* ## Project Structure ``` openchatbi/ ├── README.md # Project documentation ├── pyproject.toml # Modern Python project configuration ├── Dockerfile.python-executor # Docker image for isolated code execution ├── run_tests.py # Test runner script ├── run_streamlit_ui.py # Streamlit UI launcher ├── openchatbi/ # Core application code │ ├── __init__.py # Package initialization │ ├── config.yaml.template # Configuration template │ ├── config_loader.py # Configuration management │ ├── constants.py # Application constants │ ├── agent_graph.py # Main LangGraph workflow │ ├── graph_state.py # State definition for workflows │ ├── context_config.py # Context management configuration │ ├── context_manager.py # Context window and token management │ ├── text_segmenter.py # Text segmentation with jieba support │ ├── utils.py # Utility functions and SimpleStore (BM25-based retrieval) │ ├── catalog/ # Data catalog management │ │ ├── __init__.py # Package initialization │ │ ├── catalog_loader.py # Catalog loading logic │ │ ├── catalog_store.py # Catalog storage interface │ │ ├── factory.py # Catalog factory patterns │ │ ├── helper.py # Catalog helper functions │ │ ├── retrival_helper.py # Retrieval helper utilities │ │ ├── schema_retrival.py # Schema retrieval logic │ │ ├── token_service.py # Token service integration │ │ └── store/ # Catalog storage implementations │ │ └── file_system.py # File system-based catalog storage │ ├── code/ # Code execution framework │ │ ├── executor_base.py # Base executor interface │ │ ├── local_executor.py # Local Python execution │ │ ├── restricted_local_executor.py # RestrictedPython execution │ │ └── docker_executor.py # Docker-based isolated execution │ ├── llm/ # LLM integration layer │ │ └── llm.py # LLM management and retry logic │ ├── prompts/ # Prompt templates and engineering │ │ ├── agent_prompt.md # Main agent prompts │ │ ├── extraction_prompt.md # Information extraction prompts │ │ ├── schema_linking_prompt.md # Schema linking prompts │ │ ├── summary_prompt.md # Conversation summary prompts │ │ ├── system_prompt.py # System prompt management │ │ ├── text2sql_prompt.md # Text-to-SQL prompts │ │ ├── visualization_prompt.md # Visualization prompts │ │ └── sql_dialect/ # SQL dialect-specific prompts │ ├── text2sql/ # Text-to-SQL conversion pipeline │ │ ├── __init__.py # Package initialization │ │ ├── data.py # Data and retriever for Text-to-SQL │ │ ├── extraction.py # Information extraction │ │ ├── generate_sql.py # SQL generation and
execution logic │ │ ├── schema_linking.py # Schema linking process │ │ ├── sql_graph.py # SQL generation LangGraph workflow │ │ ├── text2sql_utils.py # Text2SQL utilities │ │ └── visualization.py # Data visualization functions │ └── tool/ # LangGraph tools and functions │ ├── ask_human.py # Human-in-the-loop interactions │ ├── memory.py # Memory management tool │ ├── mcp_tools.py # MCP (Model Context Protocol) integration │ ├── run_python_code.py # Configurable Python code execution │ ├── save_report.py # Report saving functionality │ ├── search_knowledge.py # Knowledge base search │ └── timeseries_forecast.py # Time series forecasting tool ├── sample_api/ # API implementations │ └── async_api.py # Asynchronous FastAPI example ├── sample_ui/ # Web interface implementations │ ├── memory_ui.py # Memory-enhanced UI interface │ ├── plotly_utils.py # Plotly utilities and helpers │ ├── simple_ui.py # Simple non-streaming Gradio UI │ ├── streaming_ui.py # Streaming Gradio UI with real-time updates │ ├── streamlit_ui.py # Streaming Streamlit UI with enhanced features │ └── style.py # UI styling and CSS ├── example/ # Example configurations and data │ ├── bi.yaml # BI configuration example │ ├── config.yaml # Application config example │ ├── table_info.yaml # Table information │ ├── table_columns.csv # Table column registry │ ├── common_columns.csv # Common column definitions │ ├── sql_example.yaml # SQL examples for retrieval │ ├── table_selection_example.csv # Table selection examples │ └── tracking_orders.sqlite # Sample SQLite database ├── timeseries_forecasting/ # Time series forecasting service │ ├── README.md # Forecasting service documentation │ └── ... # Forecasting service implementation ├── tests/ # Test suite │ ├── __init__.py # Package initialization │ ├── conftest.py # Test configuration │ ├── test_*.py # Test modules for various components │ └── README.md # Testing documentation ├── docs/ # Documentation │ ├── source/ # Sphinx documentation source │ ├── build/ # Built documentation │ ├── Makefile # Documentation build scripts │ └── make.bat # Windows build script └── .github/ # GitHub workflows and templates └── workflows/ # CI/CD workflows ``` ## Advanced Features ### Visualization configuration You can choose rule-based or llm-based visualization or disable visualization. ```yaml # Options: "rule" (rule-based), "llm" (LLM-based), or null (skip visualization) visualization_mode: llm ``` ### Prompt Engineering #### Basic Knowledge & Glossary You can define basic knowledge and glossary in `example/bi.yaml`, for example: ```yaml basic_knowledge_glossary: | # Basic Knowledge Introduction The basic knowledge about your company and its business, including key concepts, metrics, and processes. # Glossary Common terms and their definitions used in your business context. ``` #### Data Warehouse Introduction You can provide a brief introduction of your data warehouse in `example/bi.yaml`, for example: ```yaml data_warehouse_introduction: | # Data Warehouse Introduction This data warehouse is built on Presto and contains various tables related to XXXXX. The main fact tables include XXXX metrics, while dimension tables include XXXXX. The data is updated hourly and is used for reporting and analysis purposes. 
``` #### Table Selection Rules You can configure table selection rules in `example/bi.yaml`, for example: ```yaml table_selection_extra_rule: | - All tables with is_valid can support both valid and invalid traffic ``` #### Custom SQL Rules You can define additional SQL generation rules for tables in `example/table_info.yaml`, for example: ```yaml sql_rule: | ### SQL Rules - All event_date in the table are stored in **UTC**. If the user specifies a timezone (e.g., CET, PST), convert between timezones accordingly. ``` ### Catalog Management #### Introduction High-quality catalog data is essential for accurate Text2SQL generation and data analysis. OpenChatBI automatically discovers and indexes data warehouse table structures while providing flexible management for business metadata, column descriptions, and query optimization rules. #### Catalog Structure The catalog system organizes metadata in a hierarchical structure: **Database Level** - Top-level container for all tables and schemas **Table Level** - `description`: Business functionality and purpose of the table - `selection_rule`: Guidelines for when and how to use this table in queries - `sql_rule`: Specific SQL generation rules and constraints for this table **Column Level** - **Required Fields**: Essential metadata for each column to enable effective Text2SQL generation - `column_name`: Technical database column name - `display_name`: Human-readable name for business users - `alias`: Alternative names or abbreviations - `type`: Data type (string, integer, date, etc.) - `category`: Business category (dimension or metric) - `tag`: Additional labels for filtering and organization - `description`: Detailed explanation of column purpose and usage - **Two Types** of Columns - **Common Columns**: Columns with standardized business meanings shared across tables - **Table-Specific Columns**: Columns with context-dependent meanings that vary between tables - **Derived Metrics**: Virtual metrics calculated from existing columns using SQL formulas - Computed dynamically during query execution rather than stored as physical columns - Examples: CTR (clicks/impressions), conversion rates, profit margins - Enable complex business calculations without pre-computing values
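To make these fields concrete, here is an illustrative common-column entry and derived metric using the fields listed above (values are invented for illustration; the file-system store serializes this metadata via CSV and YAML files, as described in the next sections):

```python
# Illustrative catalog entries using the column fields described above.
common_column = {
    "column_name": "clicks",
    "display_name": "Clicks",
    "alias": "click_count",
    "type": "integer",
    "category": "metric",
    "tag": "ad_performance",
    "description": "Number of ad clicks recorded for the row's dimensions.",
}

# A derived metric is a SQL formula over existing columns, evaluated at query
# time rather than stored physically (the dict form here is illustrative only).
derived_metric = {
    "display_name": "Click-through Rate",
    "alias": "CTR",
    "formula": "SUM(clicks) / SUM(impressions)",
}
```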
#### Loading Catalog from Database OpenChatBI can automatically discover and load table structures from your data warehouse: 1. **Automatic Discovery**: Connects to your configured data warehouse and scans table schemas 2. **Metadata Extraction**: Extracts column names, data types, and basic structural information 3. **Incremental Updates**: Supports updating catalog data as your database schema evolves Configure automatic catalog loading in your `config.yaml`: ```yaml catalog_store: store_type: file_system data_path: ./catalog_data data_warehouse_config: include_tables: - your_table_pattern # Leave empty to include all accessible tables ``` #### File System Catalog Store The file system catalog store organizes metadata across multiple files for maintainability and version control: **Core Table Information** - `table_info.yaml`: Comprehensive table metadata organized hierarchically (database → table → information) - `type`: Table classification (e.g., "fact" for Fact Tables, "dimension" for Dimension Tables) - `description`: Business functionality and purpose - `selection_rule`: Usage guidelines in markdown list format (each line starts with `-`) - `sql_rule`: SQL generation rules in markdown header format (each rule starts with `####`) - `derived_metric`: Virtual metrics with calculation formulas, organized by groups: ```md #### Derived Ratio Metrics Click-through Rate (alias CTR): SUM(clicks) / SUM(impression) Conversion Rate (alias CVR): SUM(conversions) / SUM(clicks) ``` **Column Management** - `table_columns.csv`: Basic column registry with schema `db_name,table_name,column_name` - `table_spec_columns.csv`: Table-specific column metadata with full schema: `db_name,table_name,column_name,display_name,alias,type,category,tag,description` - `common_columns.csv`: Shared column definitions across tables with schema: `column_name,display_name,alias,type,category,tag,description` **Query Examples and Training Data** - `table_selection_example.csv`: Table selection training examples with schema `question,selected_tables` - `sql_example.yaml`: Query examples organized by database and table structure: ```yaml your_database: ad_performance: | Q: Show me CTR trends for the past 7 days A: SELECT date, SUM(clicks)/SUM(impressions) AS ctr FROM ad_performance WHERE date >= CURRENT_DATE - INTERVAL '7' DAY GROUP BY date ORDER BY date; ``` ### Time Series Forecasting Service Setup OpenChatBI can integrate with a time series forecasting service for advanced predictive analytics. Follow these steps to set up the service: #### 1. Build and Run the Forecasting Service See detailed instructions in [timeseries_forecasting/README.md](timeseries_forecasting/README.md) Quick start: ```bash cd timeseries_forecasting ./build_and_run.sh ``` #### 2. Configure Tool Usage Rules In your `bi.yaml`, add constraints for the timeseries_forecast tool, e.g., if you are using the `timer-base-84m` model: ```yaml extra_tool_use_rule: | - timeseries_forecast tool requires at least 96 time points in input data. If there is not enough input data, set input_len to 96 to pad with zeros. ``` #### 3. Configure Service URL In your `config.yaml`: ```yaml # Time Series Forecasting Service Configuration timeseries_forecasting_service_url: "http://localhost:8765" ``` **Important**: Adjust the URL based on your deployment scenario: - **Local development** (OpenChatBI on host, Forecasting service in Docker): `http://localhost:8765` - **Remote service**: `http://your-service-host:8765`
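You can also verify connectivity from Python; a minimal sketch against the same `/health` endpoint used in step 4 below (assumes the `requests` package is installed):

```python
# Minimal readiness check for the forecasting service's /health endpoint.
import requests

resp = requests.get("http://localhost:8765/health", timeout=5)
resp.raise_for_status()
health = resp.json()
assert health.get("model_initialized"), "service is up but the model is not loaded yet"
print(health["status"], f"(up {health['uptime_seconds']:.0f}s)")
```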
#### 4. Verify Service Health Test that the service is accessible: ```bash curl http://localhost:8765/health ``` Expected response: ```json { "status": "healthy", "model_initialized": true, "uptime_seconds": 123.45 } ``` ### Python Code Execution Configuration OpenChatBI supports multiple execution environments for running Python code with different security and performance characteristics: ```yaml # Python Code Execution Configuration python_executor: local # Options: "local", "restricted_local", "docker" ``` #### Executor Types - **`local`** (Default) - **Performance**: Fastest execution - **Security**: Least secure (code runs in the current Python process) - **Capabilities**: Full Python capabilities and library access - **Use Case**: Development environments, trusted code execution - **`restricted_local`** - **Performance**: Moderate execution speed - **Security**: Moderate security with RestrictedPython sandboxing - **Capabilities**: Limited Python features (no imports, no file access, etc.) - **Use Case**: Semi-trusted environments with controlled execution - **`docker`** - **Performance**: Slower due to container overhead - **Security**: Highest security with complete process isolation - **Capabilities**: Full Python capabilities within an isolated container - **Use Case**: Production environments, untrusted code execution - **Requirements**: Docker must be installed and running #### Docker Executor Setup For production deployments or when running untrusted code, the Docker executor provides complete isolation: 1. **Install Docker**: Download and install Docker Desktop or Docker Engine 2. **Configure executor**: Set `python_executor: docker` in your config 3. **Automatic setup**: OpenChatBI will automatically build the required Docker image 4. **Fallback behavior**: If Docker is unavailable, automatically falls back to the local executor **Docker Executor Features**: - Pre-installed data science libraries (pandas, numpy, matplotlib, seaborn) - Network isolation for security - Automatic container cleanup - Resource isolation from host system ## Development & Testing ### Code Quality Tools The project uses modern Python tooling for code quality: ```bash # Format code uv run black . # Lint code uv run ruff check . # Type checking uv run mypy openchatbi/ # Security scanning uv run bandit -r openchatbi/ ``` ### Testing Run the test suite: ```bash # Run all tests uv run pytest # Run with coverage uv run pytest --cov=openchatbi --cov-report=html # Run specific test files uv run pytest tests/test_text2sql_generate_sql.py uv run pytest tests/context_management/test_agent_graph_integration.py ``` ### Pre-commit Hooks Install pre-commit hooks for automatic code quality checks: ```bash uv run pre-commit install ``` ## Contribution Guidelines 1. Fork the repository 2. Create a feature branch (`git checkout -b feature/fooBar`) 3. Commit your changes (`git commit -am 'Add some fooBar'`) 4. Push to the branch (`git push origin feature/fooBar`) 5.
Create a new Pull Request ## License This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details ## Contact & Support - **Author**: Yu Zhong ([zhongyu8@gmail.com](mailto:zhongyu8@gmail.com)) - **Repository**: [github.com/zhongyu09/openchatbi](https://github.com/zhongyu09/openchatbi) - **Issues**: [Report bugs and feature requests](https://github.com/zhongyu09/openchatbi/issues) ================================================ FILE: baselines/runledger-openchatbi.json ================================================ { "aggregates": { "cases_error": 0, "cases_fail": 0, "cases_pass": 1, "cases_total": 1, "metrics": { "cost_usd": { "max": null, "mean": null, "min": null, "p50": null, "p95": null }, "steps": { "max": null, "mean": null, "min": null, "p50": null, "p95": null }, "tokens_in": { "max": null, "mean": null, "min": null, "p50": null, "p95": null }, "tokens_out": { "max": null, "mean": null, "min": null, "p50": null, "p95": null }, "tool_calls": { "max": 1.0, "mean": 1.0, "min": 1.0, "p50": 1.0, "p95": 1.0 }, "tool_errors": { "max": 0.0, "mean": 0.0, "min": 0.0, "p50": 0.0, "p95": 0.0 }, "wall_ms": { "max": 1.0, "mean": 1.0, "min": 1.0, "p50": 1.0, "p95": 1.0 } }, "pass_rate": 1.0 }, "cases": [ { "assertions": { "failed": 0, "total": 1 }, "cost_usd": null, "failed_assertions": null, "failure_reason": null, "id": "t1", "replay": { "cassette_path": "evals/runledger/cassettes/t1.jsonl", "cassette_sha256": "7e9830609490d140bf09178106dfa647bba4c9ec15859072b5aa2c3ae1659289" }, "status": "pass", "steps": null, "tokens_in": null, "tokens_out": null, "tool_calls": 1, "tool_calls_by_name": { "search_knowledge": 1 }, "tool_errors": 0, "tool_errors_by_name": {}, "wall_ms": 1 } ], "generated_at": "2026-01-03T19:10:00Z", "run": { "ci": null, "exit_status": "success", "git_sha": null, "mode": "replay", "run_id": "baseline" }, "runledger_version": "0.1.1", "schema_version": 1, "suite": { "agent_command": [ "python", "evals/runledger/agent/agent.py" ], "cases_total": 1, "name": "runledger-openchatbi", "suite_config_hash": null, "suite_path": "evals/runledger/suite.yaml", "tool_mode": "replay" } } ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=source set BUILDDIR=build %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. 
echo.If you don't have Sphinx installed, grab it from echo.https://www.sphinx-doc.org/ exit /b 1 ) if "%1" == "" goto help %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd ================================================ FILE: docs/source/_templates/layout.html ================================================ {% extends "!layout.html" %} {% block extrahead %} {{ super() }} {% endblock %} ================================================ FILE: docs/source/catalog.rst ================================================ Catalog System ============== Overview -------- The catalog system manages metadata for database tables, columns, and business rules. Catalog Store ------------- .. automodule:: openchatbi.catalog.catalog_store :members: :undoc-members: :show-inheritance: Filesystem Implementation ^^^^^^^^^^^^^^^^^^^^^^^^^ .. automodule:: openchatbi.catalog.store.file_system :members: :show-inheritance: Catalog Loader -------------- .. automodule:: openchatbi.catalog.catalog_loader :members: :undoc-members: :show-inheritance: Schema Retrieval ---------------- .. automodule:: openchatbi.catalog.schema_retrival :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/source/code.rst ================================================ Code Execution ============== Code Module ----------- .. automodule:: openchatbi.code :members: :undoc-members: :show-inheritance: Executor Base ------------- .. automodule:: openchatbi.code.executor_base :members: :undoc-members: :show-inheritance: Local Executor -------------- .. automodule:: openchatbi.code.local_executor :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/source/conf.py ================================================ # Configuration file for the Sphinx documentation builder. 
# # For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import os import sys sys.path.insert(0, os.path.abspath("../..")) project = "OpenChatBI" copyright = "2025, Yu Zhong" author = "Yu Zhong" release = "0.2.2" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration # Mock dependencies for documentation build autodoc_mock_imports = [] extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx.ext.viewcode", "sphinx.ext.githubpages", "myst_parser", ] # Set an environment variable to indicate we're building docs (os is imported above) os.environ["SPHINX_BUILD"] = "1" # MyST parser configuration myst_enable_extensions = [ "colon_fence", "deflist", "html_admonition", "html_image", ] myst_heading_anchors = 3 templates_path = ["_templates"] exclude_patterns = [] # Autodoc configuration autodoc_default_options = { "members": True, "member-order": "bysource", "special-members": "__init__", "undoc-members": True, "exclude-members": "__weakref__", } # Napoleon configuration for Google/NumPy style docstrings napoleon_google_docstring = True napoleon_numpy_docstring = True napoleon_include_init_with_doc = False napoleon_include_private_with_doc = False napoleon_include_special_with_doc = True napoleon_use_admonition_for_examples = False napoleon_use_admonition_for_notes = False napoleon_use_admonition_for_references = False napoleon_use_ivar = False napoleon_use_param = True napoleon_use_rtype = True napoleon_preprocess_types = False napoleon_type_aliases = None napoleon_attr_annotations = True # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = "sphinx_rtd_theme" html_static_path = ["_static"] # GitHub Pages configuration html_baseurl = "https://zhongyu09.github.io/openchatbi/" # Theme options for RTD theme html_theme_options = { "navigation_depth": 4, "collapse_navigation": False, "sticky_navigation": True, "includehidden": True, "titles_only": False, } ================================================ FILE: docs/source/config.rst ================================================ Configuration ============= The configuration system consists of two main classes: - **Config**: Defines the configuration model. - **ConfigLoader**: Manages loading and accessing configuration. Config ------ .. autoclass:: openchatbi.config_loader.Config :exclude-members: organization, dialect, default_llm, embedding_model, text2sql_llm, bi_config, data_warehouse_config, catalog_store, mcp_servers, report_directory, python_executor ConfigLoader ------------ .. autoclass:: openchatbi.config_loader.ConfigLoader :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/source/core.rst ================================================ Core Module =========== Main Module ----------- .. automodule:: openchatbi :members: :undoc-members: :show-inheritance: Agent Graph ----------- .. automodule:: openchatbi.agent_graph :members: :undoc-members: :show-inheritance: State Management ---------------- .. automodule:: openchatbi.graph_state :members: :undoc-members: :show-inheritance: Utilities --------- ..
automodule:: openchatbi.utils :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/source/index.rst ================================================ OpenChatBI Documentation ======================== `GitHub Repository <https://github.com/zhongyu09/openchatbi>`_ .. include:: ../../README.md :parser: myst_parser.sphinx_ .. toctree:: :maxdepth: 4 :caption: Documentation: :titlesonly: self .. toctree:: :maxdepth: 2 :caption: API Reference: Core Module <core> Configuration <config> Catalog System <catalog> Text2SQL System <text2sql> Code Execution <code> LLM Integration <llm> Tools and Utilities <tools> Time Series Forecasting Service <timeseries> Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/source/llm.rst ================================================ LLM Integration =============== LLM Module ---------- .. automodule:: openchatbi.llm :members: :undoc-members: :show-inheritance: LLM Implementation ------------------ .. automodule:: openchatbi.llm.llm :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/source/text2sql.rst ================================================ Text2SQL System =============== Overview -------- Natural language to SQL conversion pipeline with schema linking and prompt engineering. SQL Graph --------- .. automodule:: openchatbi.text2sql.sql_graph :members: :undoc-members: :show-inheritance: SQL Generation -------------- .. automodule:: openchatbi.text2sql.generate_sql :members: :undoc-members: :show-inheritance: Schema Linking -------------- .. automodule:: openchatbi.text2sql.schema_linking :members: :undoc-members: :show-inheritance: Information Extraction ---------------------- .. automodule:: openchatbi.text2sql.extraction :members: :undoc-members: :show-inheritance: Text2SQL Utilities ------------------- .. automodule:: openchatbi.text2sql.text2sql_utils :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/source/timeseries.rst ================================================ Time Series Forecasting Service =============================== `GitHub Repository <https://github.com/zhongyu09/openchatbi>`_ .. include:: ../../timeseries_forecasting/README.md :parser: myst_parser.sphinx_ ================================================ FILE: docs/source/tools.rst ================================================ Tools and Utilities =================== Overview -------- LangGraph tools for human interaction, code execution, and knowledge search. Python Code Execution ---------------------- .. automodule:: openchatbi.tool.run_python_code :members: :undoc-members: :show-inheritance: Human Interaction ----------------- .. automodule:: openchatbi.tool.ask_human :members: :undoc-members: :show-inheritance: Memory Management ----------------- .. automodule:: openchatbi.tool.memory :members: :undoc-members: :show-inheritance: Knowledge Search ---------------- .. automodule:: openchatbi.tool.search_knowledge :members: :undoc-members: :show-inheritance: ================================================ FILE: evals/__init__.py ================================================ """Evaluation suites for RunLedger.""" ================================================ FILE: evals/runledger/README.md ================================================ # RunLedger eval (OpenChatBI) This suite is **replay-only** by default. It runs a deterministic CI check using a JSONL adapter that proxies tool calls through RunLedger and replays results from a cassette.
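The adapter speaks a line-oriented JSON protocol on stdin/stdout: RunLedger sends a `task_start` message, the agent emits `tool_call` messages and waits for matching `tool_result` lines, then finishes with a `final_output`. A condensed sketch of one exchange, with message shapes taken from `agent/agent.py` and `cassettes/t1.jsonl`:

```python
# One round-trip of the JSONL adapter protocol, from the agent's side.
import json
import sys

# 1. RunLedger -> agent: {"type": "task_start", "input": {"prompt": "..."}}
task = json.loads(sys.stdin.readline())

# 2. Agent -> RunLedger: proxy a tool call and flush immediately.
print(json.dumps({
    "type": "tool_call", "name": "search_knowledge", "call_id": "c1",
    "args": {"reasoning": "Look up relevant knowledge",
             "query_list": [task["input"]["prompt"]],
             "knowledge_bases": ["columns"], "with_table_list": False},
}), flush=True)

# 3. RunLedger -> agent: {"type": "tool_result", "call_id": "c1", "ok": true, "result": ...}
tool_result = json.loads(sys.stdin.readline())

# 4. Agent -> RunLedger: the final answer, validated against schema.json.
print(json.dumps({"type": "final_output",
                  "output": {"category": "bi", "reply": "..."}}), flush=True)
```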
## Run (replay) ```bash runledger run evals/runledger --mode replay --baseline baselines/runledger-openchatbi.json ``` ## Record / update cassette (optional) If you want to re-record the cassette with real tool outputs, run in record mode in a fully configured OpenChatBI environment (valid `openchatbi/config.yaml`, data warehouse/catalog, LLM keys). ```bash runledger run evals/runledger --mode record \ --baseline baselines/runledger-openchatbi.json \ --tool-module evals.runledger.tools ``` Notes: - Tool args are passed as JSON objects; see `evals/runledger/cassettes/t1.jsonl` for the exact shape. - After recording, promote the new baseline: ```bash runledger baseline promote \ --from runledger_out/runledger-openchatbi/ \ --to baselines/runledger-openchatbi.json ``` ================================================ FILE: evals/runledger/__init__.py ================================================ """RunLedger eval suite for OpenChatBI.""" ================================================ FILE: evals/runledger/agent/agent.py ================================================ import json import sys from itertools import count from typing import Any from unittest.mock import MagicMock import builtins from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.tools import StructuredTool from langgraph.checkpoint.memory import MemorySaver from pydantic import BaseModel, Field from openchatbi import config import openchatbi.agent_graph as agent_graph _CALL_COUNTER = count(1) _ORIG_PRINT = builtins.print def _safe_print(*args: Any, **kwargs: Any) -> None: """Suppress stdout prints so JSONL stays clean; allow stderr.""" target = kwargs.get("file") if target is None or target is sys.stdout: return _ORIG_PRINT(*args, **kwargs) builtins.print = _safe_print class JsonlChannel: def __init__(self, stream: Any) -> None: self._stream = stream def read(self) -> dict[str, Any] | None: while True: line = self._stream.readline() if not line: return None line = line.strip() if not line: continue try: return json.loads(line) except json.JSONDecodeError: continue @staticmethod def send(payload: dict[str, Any]) -> None: sys.stdout.write(json.dumps(payload) + "\n") sys.stdout.flush() def _last_user_text(messages: list[Any]) -> str: for message in reversed(messages): if isinstance(message, HumanMessage): return str(message.content).strip() return "OpenChatBI" def _runledger_tool_call(channel: JsonlChannel, name: str, args: dict[str, Any]) -> Any: call_id = f"c{next(_CALL_COUNTER)}" channel.send({"type": "tool_call", "name": name, "call_id": call_id, "args": args}) while True: message = channel.read() if message is None: raise RuntimeError("Tool result missing") if message.get("type") != "tool_result": continue if message.get("call_id") != call_id: continue if message.get("ok"): return message.get("result") raise RuntimeError(message.get("error") or "Tool error") class SearchKnowledgeInput(BaseModel): reasoning: str = Field(description="Reason for searching knowledge") query_list: list[str] = Field(description="Query terms") knowledge_bases: list[str] = Field(description="Knowledge bases to search") with_table_list: bool = Field(default=False, description="Include table list") class ShowSchemaInput(BaseModel): reasoning: str = Field(description="Reason for showing schema") tables: list[str] = Field(description="Table names") class Text2SQLInput(BaseModel): reasoning: str = Field(description="Reason for calling text2sql") context: str = Field(description="Full context for the SQL graph") 
class RunPythonInput(BaseModel): reasoning: str = Field(description="Reason for running python code") code: str = Field(description="Python code to execute") class SaveReportInput(BaseModel): content: str = Field(description="Report content") title: str = Field(description="Report title") file_format: str = Field(description="File extension") def _build_tool_proxies(channel: JsonlChannel) -> dict[str, StructuredTool]: def search_knowledge( reasoning: str, query_list: list[str], knowledge_bases: list[str], with_table_list: bool = False, ) -> Any: return _runledger_tool_call( channel, "search_knowledge", { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": with_table_list, }, ) def show_schema(reasoning: str, tables: list[str]) -> Any: return _runledger_tool_call( channel, "show_schema", {"reasoning": reasoning, "tables": tables}, ) def text2sql(reasoning: str, context: str) -> Any: return _runledger_tool_call( channel, "text2sql", {"reasoning": reasoning, "context": context}, ) def run_python_code(reasoning: str, code: str) -> Any: return _runledger_tool_call( channel, "run_python_code", {"reasoning": reasoning, "code": code}, ) def save_report(content: str, title: str, file_format: str = "md") -> Any: return _runledger_tool_call( channel, "save_report", {"content": content, "title": title, "file_format": file_format}, ) return { "search_knowledge": StructuredTool.from_function( func=search_knowledge, name="search_knowledge", description="RunLedger proxy for search_knowledge", args_schema=SearchKnowledgeInput, ), "show_schema": StructuredTool.from_function( func=show_schema, name="show_schema", description="RunLedger proxy for show_schema", args_schema=ShowSchemaInput, ), "text2sql": StructuredTool.from_function( func=text2sql, name="text2sql", description="RunLedger proxy for text2sql", args_schema=Text2SQLInput, ), "run_python_code": StructuredTool.from_function( func=run_python_code, name="run_python_code", description="RunLedger proxy for run_python_code", args_schema=RunPythonInput, ), "save_report": StructuredTool.from_function( func=save_report, name="save_report", description="RunLedger proxy for save_report", args_schema=SaveReportInput, ), } def _stub_llm_call(chat_model: Any, messages: list[Any], **_kwargs: Any) -> AIMessage: tool_seen = any(isinstance(msg, ToolMessage) or getattr(msg, "type", None) == "tool" for msg in messages) if tool_seen: return AIMessage(content="Here is a deterministic summary based on the tool result.", tool_calls=[]) user_text = _last_user_text(messages) tool_args = { "reasoning": "Look up relevant knowledge", "query_list": [user_text], "knowledge_bases": ["columns"], "with_table_list": False, } return AIMessage( content="Searching knowledge base.", tool_calls=[{"name": "search_knowledge", "args": tool_args, "id": "call_1"}], ) def _configure_agent_graph(channel: JsonlChannel) -> None: tool_proxies = _build_tool_proxies(channel) agent_graph.search_knowledge = tool_proxies["search_knowledge"] agent_graph.show_schema = tool_proxies["show_schema"] agent_graph.run_python_code = tool_proxies["run_python_code"] agent_graph.save_report = tool_proxies["save_report"] agent_graph.get_sql_tools = lambda *_args, **_kwargs: tool_proxies["text2sql"] agent_graph.build_sql_graph = lambda *_args, **_kwargs: object() agent_graph.get_memory_tools = lambda *_args, **_kwargs: [] agent_graph.create_mcp_tools_sync = lambda *_args, **_kwargs: [] agent_graph.check_forecast_service_health = lambda: False 
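# Finally, swap the real LLM entry point for the deterministic stub defined
# above: the stub issues a single search_knowledge call on the first turn and
# returns a fixed summary once a ToolMessage is present, so replays never
# depend on a live model.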
agent_graph.call_llm_chat_model_with_retry = _stub_llm_call def _bootstrap_config() -> None: config.set( { "default_llm": MagicMock(), "data_warehouse_config": {}, "catalog_store": {"store_type": "file_system", "auto_load": False}, } ) def main() -> int: channel = JsonlChannel(sys.stdin) message = channel.read() if not message or message.get("type") != "task_start": return 1 _bootstrap_config() _configure_agent_graph(channel) prompt = "" payload = message.get("input", {}) if isinstance(payload, dict): prompt = payload.get("prompt") or payload.get("question") or payload.get("query") or "" if not prompt: prompt = "OpenChatBI" graph = agent_graph.build_agent_graph_sync( catalog=config.get().catalog_store, checkpointer=MemorySaver(), memory_store=None, enable_context_management=False, ) result = graph.invoke({"messages": [{"role": "user", "content": prompt}]}) output = "" if isinstance(result, dict) and result.get("messages"): output = str(result["messages"][-1].content) channel.send( { "type": "final_output", "output": {"category": "bi", "reply": output or "Completed request."}, } ) return 0 if __name__ == "__main__": raise SystemExit(main()) ================================================ FILE: evals/runledger/cases/t1.yaml ================================================ id: t1 description: "basic BI flow with a single search_knowledge tool call" input: prompt: "OpenChatBI" cassette: cassettes/t1.jsonl ================================================ FILE: evals/runledger/cassettes/t1.jsonl ================================================ {"tool":"search_knowledge","args":{"knowledge_bases":["columns"],"query_list":["OpenChatBI"],"reasoning":"Look up relevant knowledge","with_table_list":false},"ok":true,"result":{"columns":"# Relevant Columns and Description:\n## openchatbi\n- Column Category: metric\n- Display Name: OpenChatBI\n- Description \"Project overview\""}} ================================================ FILE: evals/runledger/schema.json ================================================ { "type": "object", "properties": { "category": { "type": "string" }, "reply": { "type": "string" } }, "required": [ "category", "reply" ] } ================================================ FILE: evals/runledger/suite.yaml ================================================ suite_name: runledger-openchatbi agent_command: ["python", "evals/runledger/agent/agent.py"] mode: replay cases_path: cases tool_registry: - search_knowledge tool_module: evals.runledger.tools assertions: - type: json_schema schema_path: schema.json budgets: max_wall_ms: 20000 max_tool_calls: 5 max_tool_errors: 0 baseline_path: ../../baselines/runledger-openchatbi.json ================================================ FILE: evals/runledger/tools.py ================================================ from __future__ import annotations from typing import Any from openchatbi.tool.search_knowledge import search_knowledge def _invoke_tool(tool, args: dict[str, Any]) -> Any: return tool.invoke(args) def _search_knowledge(args: dict[str, Any]) -> Any: return _invoke_tool(search_knowledge, args) TOOLS = { "search_knowledge": _search_knowledge, } ================================================ FILE: example/bi.yaml ================================================ extra_tool_use_rule: | - Try your best to provide appropriate parameters when calling tools. - The timeseries_forecast tool requires at least 96 time points in the input data. If there is not enough input data, set input_len to 96 to pad with zeros.
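# The keys below supply extra prompt text for table selection, SQL generation,
# and the business glossary for this example dataset; keep the rules short and
# imperative, since they are passed to the LLM as instructions.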
table_selection_extra_rule: | - When users ask about orders, consider if they need customer information (join with Customers table) - For product-related queries, check if order information is needed (join with Order_Items) - Shipment queries often require order and product details (join with multiple tables) - Invoice questions may need shipment information for complete tracking text2sql_extra_rule: | - Use proper JOIN syntax when connecting related tables - Use LIKE operator for partial string matches in product names or customer names - Handle NULL values properly in optional fields like details columns basic_knowledge_glossary: | # Sales Business System Glossary ## Overview You're answering questions related to a sales order tracking business system that manages the complete customer order lifecycle from placement to delivery. ## Key Business Concepts **Customer Management:** - Customer: Individual or entity who places orders - Customer Details: Additional information like contact info, preferences, or notes **Order Processing:** - Order: A request from a customer to purchase products - Order Status: Current state - Valid values: "Shipped", "Packing", "On Road" - Order Item: Individual product within an order (orders can contain multiple items) - Order Item Status: Status of specific items - Valid values: "Finish", "Payed", "Cancel" **Product Catalog:** - Product: Items available for purchase - Product Details: Specifications, descriptions, or additional product information **Fulfillment & Shipping:** - Shipment: Physical delivery package sent to customer - Shipment Items: Specific order items included in a shipment - Tracking Number: Unique identifier for package tracking - Shipment Date: When package was dispatched **Financial Processing:** - Invoice: Bill generated for completed orders - Invoice Number: Unique identifier for billing purposes - Invoice Date: When billing document was created ## Business Rules - One order can have multiple items (products) - One order can be fulfilled through multiple shipments - Each shipment links to one invoice for billing - Order items can have different statuses within the same order - Customers can have multiple orders over time ================================================ FILE: example/common_columns.csv ================================================ column_name,display_name,alias,type,category,tag,description,dimension_table,default customer_id,Customer ID,cust_id,INTEGER,identifier,customer,Unique identifier for customers,Customers, customer_name,Customer Name,cust_name,VARCHAR(80),attribute,customer,Name of the customer,Customers, customer_details,Customer Details,cust_details,VARCHAR(255),attribute,customer,Additional customer information,Customers, invoice_number,Invoice Number,inv_num,INTEGER,identifier,financial,Unique invoice identifier,Invoices, invoice_date,Invoice Date,inv_date,DATETIME,temporal,financial,Date the invoice was created,Invoices, invoice_details,Invoice Details,inv_details,VARCHAR(255),attribute,financial,Additional invoice information,Invoices, order_item_id,Order Item ID,oi_id,INTEGER,identifier,order,Unique identifier for order items,Order_Items, product_id,Product ID,prod_id,INTEGER,identifier,product,Unique identifier for products,Products, order_id,Order ID,ord_id,INTEGER,identifier,order,Unique identifier for orders,Orders, order_item_status,Order Item Status,oi_status,VARCHAR(10),status,order,Current status of the order item (Finish|Payed|Cancel),Order_Items, order_item_details,Order Item 
Details,oi_details,VARCHAR(255),attribute,order,Additional order item information,Order_Items, order_status,Order Status,ord_status,VARCHAR(10),status,order,Current status of the order (Shipped|Packing|On Road),Orders, date_order_placed,Order Placed Date,ord_date,DATETIME,temporal,order,Date when the order was placed,Orders, order_details,Order Details,ord_details,VARCHAR(255),attribute,order,Additional order information,Orders, product_name,Product Name,prod_name,VARCHAR(80),attribute,product,Name of the product,Products, product_details,Product Details,prod_details,VARCHAR(255),attribute,product,Additional product information,Products, shipment_id,Shipment ID,ship_id,INTEGER,identifier,shipment,Unique identifier for shipments,Shipments, shipment_tracking_number,Tracking Number,track_num,VARCHAR(80),identifier,shipment,Tracking number for shipment,Shipments, shipment_date,Shipment Date,ship_date,DATETIME,temporal,shipment,Date when the shipment was sent,Shipments, other_shipment_details,Shipment Details,ship_details,VARCHAR(255),attribute,shipment,Additional shipment information,Shipments, ================================================ FILE: example/config.yaml ================================================ organization: MyCompany dialect: sqlite bi_config_file: example/bi.yaml python_executor: docker # Visualization configuration visualization_mode: llm # Catalog store configuration catalog_store: store_type: file_system data_path: ./example # Data warehouse configuration data_warehouse_config: # sqlite from spider->tracking_orders dataset uri: "sqlite:///example/tracking_orders.sqlite" database_name: "" # LLM configurations # Use OpenAI LLM, replace YOUR_API_KEY_HERE with your actual API key default_llm: class: langchain_openai.ChatOpenAI params: api_key: YOUR_API_KEY_HERE model: gpt-4.1 temperature: 0.01 max_tokens: 8192 embedding_model: class: langchain_openai.OpenAIEmbeddings params: api_key: YOUR_API_KEY_HERE model: text-embedding-3-large chunk_size: 1024 # If you cannot access OpenAI or another cloud LLM provider, # uncomment the following lines instead to use the Ollama local LLM #default_llm: # class: langchain_ollama.ChatOllama # params: # model: gpt-oss:20b # temperature: 0.01 # num_predict: 8192 ================================================ FILE: example/sql_example.yaml ================================================ '': Customers: | Q: Show me all customers with their names and details A: SELECT customer_id, customer_name, customer_details FROM Customers ORDER BY customer_name Invoices: | Q: List all invoices from the last 30 days A: SELECT invoice_number, invoice_date, invoice_details FROM Invoices WHERE invoice_date >= DATE('now', '-30 days') ORDER BY invoice_date DESC Order_Items: | Q: Show me all items in order 123 A: SELECT oi.order_item_id, p.product_name, oi.order_item_status, oi.order_item_details FROM Order_Items oi JOIN Products p ON oi.product_id = p.product_id WHERE oi.order_id = 123 Orders: | Q: Find all pending orders with customer information A: SELECT o.order_id, c.customer_name, o.order_status, o.date_order_placed FROM Orders o JOIN Customers c ON o.customer_id = c.customer_id WHERE o.order_status = 'pending' ORDER BY o.date_order_placed Products: | Q: Search for products containing 'laptop' in the name A: SELECT product_id, product_name, product_details FROM Products WHERE product_name LIKE '%laptop%' ORDER BY product_name Shipment_Items: | Q: Show which order items are in shipment 456 A: SELECT si.shipment_id, si.order_item_id,
p.product_name FROM Shipment_Items si JOIN Order_Items oi ON si.order_item_id = oi.order_item_id JOIN Products p ON oi.product_id = p.product_id WHERE si.shipment_id = 456 Shipments: | Q: Track all shipments for order 789 A: SELECT shipment_id, shipment_tracking_number, shipment_date, other_shipment_details FROM Shipments WHERE order_id = 789 ORDER BY shipment_date ================================================ FILE: example/table_columns.csv ================================================ db_name,table_name,column_name ,Customers,customer_id ,Customers,customer_name ,Customers,customer_details ,Invoices,invoice_number ,Invoices,invoice_date ,Invoices,invoice_details ,Order_Items,order_item_id ,Order_Items,product_id ,Order_Items,order_id ,Order_Items,order_item_status ,Order_Items,order_item_details ,Orders,order_id ,Orders,customer_id ,Orders,order_status ,Orders,date_order_placed ,Orders,order_details ,Products,product_id ,Products,product_name ,Products,product_details ,Shipment_Items,shipment_id ,Shipment_Items,order_item_id ,Shipments,shipment_id ,Shipments,order_id ,Shipments,invoice_number ,Shipments,shipment_tracking_number ,Shipments,shipment_date ,Shipments,other_shipment_details ================================================ FILE: example/table_info.yaml ================================================ ? '' : Customers: description: 'Contains customer information including unique ID, name, and additional details' selection_rule: 'Select when queries involve customer information, customer names, or need to join orders with customer data' sql_rule: 'Use customer_id as primary key for joins. Always include customer_name when displaying customer information' Invoices: description: 'Stores invoice information with unique invoice numbers, dates, and details' selection_rule: 'Select when queries involve billing, invoice tracking, or financial reporting' sql_rule: 'Use invoice_number as primary key. Filter by invoice_date for temporal queries' Order_Items: description: 'Links products to orders with individual item status and details' selection_rule: 'Select when queries need product details within orders or item-level status tracking' sql_rule: 'Always join with Products table via product_id and Orders table via order_id for complete information' Orders: description: 'Main order table containing order status, placement date, and customer relationships' selection_rule: 'Select when queries involve order status, order history, or customer order relationships' sql_rule: 'Use order_id as primary key. Join with Customers via customer_id for customer information' Products: description: 'Product catalog containing product names, IDs, and detailed product information' selection_rule: 'Select when queries involve product information, product searches, or inventory-related questions' sql_rule: 'Use product_id as primary key. Use LIKE operator for product_name searches' Shipment_Items: description: 'Junction table linking shipments to specific order items' selection_rule: 'Select when queries need to track which specific items are in which shipments' sql_rule: 'Always join with both Shipments and Order_Items tables. No primary key - composite key of shipment_id and order_item_id' Shipments: description: 'Shipment tracking information including tracking numbers, dates, and shipment details' selection_rule: 'Select when queries involve shipping, delivery tracking, or fulfillment information' sql_rule: 'Use shipment_id as primary key. 
Join with Orders via order_id and Invoices via invoice_number for complete shipping context' ================================================ FILE: example/table_selection_example.csv ================================================ question,selected_tables "Show me all customers","[""Customers""]" "What orders were placed today?","[""Orders""]" "List all products and their details","[""Products""]" "Show me customer orders with their details","[""Customers"", ""Orders""]" "What products are in each order?","[""Orders"", ""Order_Items"", ""Products""]" "Show shipment tracking information","[""Shipments""]" "Which items are in each shipment?","[""Shipments"", ""Shipment_Items"", ""Order_Items""]" "Show order status and customer information","[""Orders"", ""Customers""]" "What invoices were created this month?","[""Invoices""]" "Show complete order fulfillment chain","[""Orders"", ""Order_Items"", ""Products"", ""Shipments"", ""Invoices""]" ================================================ FILE: openchatbi/__init__.py ================================================ """OpenChatBI core module initialization.""" import os from langgraph.graph.state import CompiledStateGraph from openchatbi.config_loader import ConfigLoader # Global configuration instance config = ConfigLoader() # Skip config loading during documentation build if not os.environ.get("SPHINX_BUILD"): config.load() else: config.set({}) def get_default_graph(): """ Build the synchronous mode of the agent graph using default catalog in config. Returns: CompiledStateGraph: Compiled agent graph ready for execution. """ if os.environ.get("SPHINX_BUILD"): return None from langgraph.checkpoint.memory import MemorySaver from openchatbi.agent_graph import build_agent_graph_sync from openchatbi.tool.memory import get_sync_memory_store checkpointer = MemorySaver() return build_agent_graph_sync( config.get().catalog_store, checkpointer=checkpointer, memory_store=get_sync_memory_store() ) ================================================ FILE: openchatbi/agent_graph.py ================================================ """Main agent graph construction and execution logic.""" import datetime import logging import traceback from collections.abc import Callable from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.tools import StructuredTool from langchain_openai.chat_models.base import BaseChatOpenAI from langgraph.constants import START from langgraph.errors import GraphInterrupt from langgraph.graph import END, StateGraph from langgraph.graph.state import CompiledStateGraph from langgraph.prebuilt import ToolNode from langgraph.store.base import BaseStore from langgraph.types import Checkpointer, Send, interrupt from pydantic import BaseModel, Field from openchatbi import config from openchatbi.catalog import CatalogStore from openchatbi.constants import datetime_format from openchatbi.context_config import get_context_config from openchatbi.context_manager import ContextManager from openchatbi.graph_state import AgentState, InputState, OutputState from openchatbi.llm.llm import call_llm_chat_model_with_retry, get_llm from openchatbi.prompts.system_prompt import get_agent_prompt_template from openchatbi.text2sql.sql_graph import build_sql_graph from openchatbi.tool.ask_human import AskHuman from openchatbi.tool.mcp_tools import create_mcp_tools_sync, get_mcp_tools_async from openchatbi.tool.memory import get_memory_tools from 
openchatbi.tool.run_python_code import run_python_code from openchatbi.tool.save_report import save_report from openchatbi.tool.search_knowledge import search_knowledge, show_schema from openchatbi.tool.timeseries_forecast import check_forecast_service_health, timeseries_forecast from openchatbi.utils import log, recover_incomplete_tool_calls logger = logging.getLogger(__name__) def get_mcp_servers(): """Get MCP servers from config with fallback for tests.""" try: return config.get().mcp_servers except ValueError: return [] def ask_human(state: AgentState) -> dict[str, Any]: """Node function to ask human for additional information or clarification. Args: state (AgentState): The current graph state containing messages and context. Returns: dict: Updated state with human feedback as a tool message and user input. """ tool_call = state["messages"][-1].tool_calls[0] tool_call_id = tool_call["id"] args = tool_call["args"] user_feedback = interrupt({"text": args["question"], "buttons": args.get("options", None)}) tool_message = [{"tool_call_id": tool_call_id, "type": "tool", "content": user_feedback}] return { "messages": tool_message, "history_messages": [AIMessage(args["question"]), HumanMessage(user_feedback)], "user_input": user_feedback, } class CallSQLGraphInput(BaseModel): reasoning: str = Field( description="Explanation of why the Text2SQL tool is needed", ) context: str = Field( description="""The full context to pass to the Text2SQL tool; make sure not to miss any potential information related to the user's question. Follow this format: History Conversation: (user and assistant history dialog) Information: (the relevant knowledge you retrieved, like metrics and dimensions) User's latest question:""", ) # Description for SQL tools TEXT2SQL_TOOL_DESCRIPTION = """Text2SQL tool to generate and execute a SQL query and build a visualization DSL for the UI based on the user's question and context. Returns: str: A formatted response containing SQL, data, and visualization status. Important notes: - If the user wants to change the visualization chart type or style, add the requirement to the question - Make sure to provide the question in English """ def _format_sql_response(sql_graph_response: dict) -> str: """Format SQL graph response into a standardized string format. Args: sql_graph_response: The response dictionary from the SQL graph Returns: str: Formatted response string """ sql = sql_graph_response.get("sql", "") data = sql_graph_response.get("data", "") visualization_dsl = sql_graph_response.get("visualization_dsl", {}) response_parts = [] if sql: response_parts.append(f"SQL Query:\n```sql\n{sql}\n```") if data: response_parts.append(f"\nQuery Results (CSV format):\n```csv\n{data}\n```") # Include visualization status if visualization_dsl and "error" not in visualization_dsl: chart_type = visualization_dsl.get("chart_type", "unknown") response_parts.append( f"\nVisualization Created: {chart_type} chart has been automatically generated and will be displayed in the UI." ) elif visualization_dsl and "error" in visualization_dsl: response_parts.append(f"\nVisualization Error: {visualization_dsl['error']}") return "\n\n".join(response_parts) if response_parts else "No results returned." def get_sql_tools(sql_graph: CompiledStateGraph, sync_mode: bool = False) -> Callable: """Create SQL generation tool from compiled SQL graph. Args: sql_graph (CompiledStateGraph): The compiled SQL generation subgraph.
sync_mode (bool): Whether to create synchronous or asynchronous tools Returns: function: Tool function for SQL generation. """ def call_sql_graph_sync(reasoning: str, context: str) -> str: """Sync node function for Text2SQL tool""" log(f"Call SQL graph (sync) with reasoning: {reasoning}, context: {context}") try: sql_graph_response = sql_graph.invoke({"messages": context}) return _format_sql_response(sql_graph_response) except GraphInterrupt as e: log(f"Sql graph interrupted:\n{repr(e)}") raise e except Exception as e: log(f"Run sql graph error:\n{repr(e)}") traceback.print_exc() return "Error occurred when calling Text2SQL tool." async def call_sql_graph_async(reasoning: str, context: str) -> str: """Async node function for Text2SQL tool""" log(f"Call SQL graph (async) with reasoning: {reasoning}, context: {context}") try: sql_graph_response = await sql_graph.ainvoke({"messages": context}) return _format_sql_response(sql_graph_response) except GraphInterrupt as e: log(f"Sql graph interrupted:\n{repr(e)}") raise e except Exception as e: log(f"Run sql graph error:\n{repr(e)}") traceback.print_exc() return "Error occurred when calling Text2SQL tool." if sync_mode: return StructuredTool.from_function( func=call_sql_graph_sync, name="text2sql", description=TEXT2SQL_TOOL_DESCRIPTION, args_schema=CallSQLGraphInput, return_direct=False, ) else: return StructuredTool.from_function( coroutine=call_sql_graph_async, name="text2sql", description=TEXT2SQL_TOOL_DESCRIPTION, args_schema=CallSQLGraphInput, return_direct=False, ) def agent_llm_call(llm: BaseChatModel, tools: list, context_manager: ContextManager = None) -> Callable: """Create llm call function to generate reasoning and determine next node based on tool calls in LLM response. Args: llm (BaseChatModel): The LLM for agent decision-making. tools: List of tools. context_manager: Optional context manager for handling long conversations. Returns: function: function that processes state and determines next node. 
""" # OpenAI models support strict tool calling if isinstance(llm, BaseChatOpenAI): llm_with_tools = llm.bind_tools(tools, strict=True) else: llm_with_tools = llm.bind_tools(tools) def _call_model(state: AgentState): # First, check and recover any incomplete tool calls recovery_ops = recover_incomplete_tool_calls(state) if recovery_ops: return {"messages": recovery_ops, "agent_next_node": "llm_node"} messages = state["messages"] final_messages = [] if isinstance(messages[-1], HumanMessage): final_messages.append(messages[-1]) # Apply context management if available (before processing) if context_manager: original_count = len(messages) context_manager.manage_context_messages(messages) if len(messages) != original_count: logger.info(f"Context management: modified messages from {original_count} to {len(messages)}") system_prompt = get_agent_prompt_template().replace( "[time_field_placeholder]", datetime.datetime.now().strftime(datetime_format) ) response = call_llm_chat_model_with_retry( llm_with_tools, ([SystemMessage(system_prompt)] + messages), streaming_tokens=True, bound_tools=tools, parallel_tool_call=True, ) if isinstance(response, AIMessage): tool_calls = response.tool_calls print("Tool Call:", ", ".join(tool["name"] for tool in tool_calls)) if tool_calls: # Group tool calls by type for parallel routing ask_human_calls = [call for call in tool_calls if call["name"] == "AskHuman"] normal_tool_calls = [call for call in tool_calls if call["name"] != "AskHuman"] # Create Send objects for parallel routing sends = [] if ask_human_calls: # Create message with only AskHuman calls ask_human_msg = AIMessage(content=response.content, tool_calls=ask_human_calls) sends.append(Send("ask_human", {"messages": [ask_human_msg]})) if normal_tool_calls: # Create message with only normal tool calls tool_msg = AIMessage(content=response.content, tool_calls=normal_tool_calls) sends.append(Send("use_tool", {"messages": [tool_msg]})) return {"messages": [response], "history_messages": final_messages, "sends": sends} else: final_messages.append(AIMessage(response.content)) return { "messages": [response], "final_answer": response.content, "history_messages": final_messages, "agent_next_node": END, } elif response is None: return { "messages": [AIMessage("Sorry, the LLM service is currently unavailable.")], "history_messages": final_messages, "agent_next_node": END, } else: return {"messages": [response], "history_messages": final_messages, "agent_next_node": END} return _call_model def _build_graph_core( catalog: CatalogStore, sync_mode: bool, checkpointer: Checkpointer, memory_store: BaseStore, memory_tools: list[Callable] | None, mcp_tools: list, enable_context_management: bool = True, llm_provider: str | None = None, ) -> CompiledStateGraph: """Core graph building logic shared by both sync and async versions. 
Args: catalog: Catalog store containing schema information sync_mode: Whether to use synchronous mode for tools and operations checkpointer: The Checkpointer for state persistence memory_store: The BaseStore to use for long-term memory memory_tools: List of memory tools (manage_memory_tool, search_memory_tool) mcp_tools: Pre-initialized MCP tools enable_context_management: Whether to enable context management Returns: CompiledStateGraph: Compiled agent graph ready for execution """ sql_graph = build_sql_graph(catalog, checkpointer, memory_store, llm_provider=llm_provider) call_sql_graph_tool = get_sql_tools(sql_graph=sql_graph, sync_mode=sync_mode) # Use provided memory tools or create them if not memory_tools: memory_tools = get_memory_tools(get_llm(llm_provider), sync_mode=sync_mode, store=memory_store) log(str(mcp_tools)) normal_tools = [ search_knowledge, show_schema, call_sql_graph_tool, run_python_code, save_report, ] if memory_tools: normal_tools.extend(memory_tools) if check_forecast_service_health(): normal_tools.append(timeseries_forecast) else: logger.warning("Time series forecasting service is not healthy. Skipping timeseries_forecast tool.") normal_tools.extend(mcp_tools) # Initialize context manager if enabled context_manager = None if enable_context_management: context_manager = ContextManager(llm=get_llm(llm_provider), config=get_context_config()) tool_node = ToolNode(normal_tools) # Define the agent graph graph = StateGraph(AgentState, input_schema=InputState, output_schema=OutputState) # Add nodes to the graph graph.add_node("llm_node", agent_llm_call(get_llm(llm_provider), normal_tools + [AskHuman], context_manager)) graph.add_node("ask_human", ask_human) graph.add_node("use_tool", tool_node) # Add edges between nodes graph.add_edge(START, "llm_node") graph.add_edge("ask_human", "llm_node") graph.add_edge("use_tool", "llm_node") # Add conditional routing from llm node def route_tools(state: AgentState): # Only use sends if the last message came from the llm node (has tool_calls) last_message = state["messages"][-1] if state["messages"] else None if ( last_message and isinstance(last_message, AIMessage) and last_message.tool_calls and "sends" in state and state["sends"] ): return state["sends"] # Return Send objects for parallel execution elif "agent_next_node" in state: return state["agent_next_node"] # Return single node name else: return END graph.add_conditional_edges( "llm_node", route_tools, # mapping of paths to node names (for single routing) { "llm_node": "llm_node", "ask_human": "ask_human", "use_tool": "use_tool", END: END, }, ) graph = graph.compile(name="agent_graph", checkpointer=checkpointer, store=memory_store) return graph def build_agent_graph_sync( catalog: CatalogStore, checkpointer: Checkpointer = None, memory_store: BaseStore = None, enable_context_management: bool = True, llm_provider: str | None = None, ) -> CompiledStateGraph: """Build the main agent graph with all nodes and edges (sync version). Args: catalog: Catalog store containing schema information. checkpointer: The Checkpointer for state persistence (short memory). If None, no short memory. memory_store: The BaseStore to use for long-term memory. If None, will auto assign according to sync_mode. enable_context_management: Whether to enable context management for long conversations. Returns: CompiledStateGraph: Compiled agent graph ready for execution. 
""" # Get MCP tools for sync context mcp_tools = create_mcp_tools_sync(get_mcp_servers()) return _build_graph_core( catalog=catalog, sync_mode=True, checkpointer=checkpointer, memory_store=memory_store, memory_tools=None, # Always None for sync version - creates its own mcp_tools=mcp_tools, enable_context_management=enable_context_management, llm_provider=llm_provider, ) async def build_agent_graph_async( catalog: CatalogStore, checkpointer: Checkpointer = None, memory_store: BaseStore = None, memory_tools: list[Callable] = None, enable_context_management: bool = True, llm_provider: str | None = None, ) -> CompiledStateGraph: """Build the main agent graph with all nodes and edges (async version). This function is identical to build_agent_graph_sync but properly handles async MCP tool initialization when called from async contexts. Args: catalog: Catalog store containing schema information. checkpointer: The Checkpointer for state persistence (short memory). If None, no short memory. memory_store: The BaseStore to use for long-term memory. If None, will auto assign according to sync_mode. memory_tools: List of memory tools (manage_memory_tool, search_memory_tool). If None, creates async tools. enable_context_management: Whether to enable context management for long conversations. Returns: CompiledStateGraph: Compiled agent graph ready for execution. """ # Get MCP tools for async context mcp_tools = await get_mcp_tools_async(get_mcp_servers()) return _build_graph_core( catalog=catalog, sync_mode=False, checkpointer=checkpointer, memory_store=memory_store, memory_tools=memory_tools, mcp_tools=mcp_tools, enable_context_management=enable_context_management, llm_provider=llm_provider, ) ================================================ FILE: openchatbi/catalog/__init__.py ================================================ """Data catalog management module for OpenChatBI.""" from openchatbi.catalog.catalog_loader import ( DataCatalogLoader, load_catalog_from_data_warehouse, ) from openchatbi.catalog.catalog_store import CatalogStore from openchatbi.catalog.factory import create_catalog_store __all__ = [ "CatalogStore", "DataCatalogLoader", "load_catalog_from_data_warehouse", ] ================================================ FILE: openchatbi/catalog/catalog_loader.py ================================================ import logging from typing import Any from sqlalchemy import MetaData, inspect from sqlalchemy.engine import Engine from .catalog_store import CatalogStore logger = logging.getLogger(__name__) class DataCatalogLoader: """ The loader to load data catalog from data warehouse metadata and save to catalog store. """ def __init__(self, engine: Engine, include_tables: list[str] | None = None): """ Initialize catalog loader. Args: engine (Engine): SQLAlchemy engine instance include_tables (Optional[List[str]]): List of table names to include, None for all """ self.engine = engine self.include_tables = include_tables self.metadata = MetaData() self.inspector = inspect(engine) def get_tables_and_columns(self) -> dict[str, list[dict[str, Any]]]: """ Extract table and column metadata including comments using SQLAlchemy inspector. 
Returns: Dict[str, List[Dict[str, Any]]]: Dictionary mapping table names to list of column information """ try: tables_columns = {} # Get all table names table_names = self.inspector.get_table_names() # Filter to specific tables if configured if self.include_tables: table_names = [name for name in table_names if name in self.include_tables] logger.info(f"Found {len(table_names)} tables to process") for table_name in table_names: try: # Get column information for the table columns = self.inspector.get_columns(table_name) column_list = [] for column in columns: # Each `column` is a dict from the inspector, so compare its name rather than the dict itself is_common_column = column["name"] not in ("id", "name", "type", "status") column_info = { "column_name": column["name"], "display_name": "", "alias": "", "type": str(column["type"]), "category": "", "tag": "", "description": column.get("comment", "") or "", "dimension_table": "", "default": str(column.get("default", "")) if column.get("default") is not None else "", "is_common": is_common_column, } column_list.append(column_info) tables_columns[table_name] = column_list logger.debug(f"Processed table {table_name} with {len(column_list)} columns") except Exception as e: logger.error(f"Failed to process table {table_name}: {e}") continue logger.info(f"Successfully processed {len(tables_columns)} tables") return tables_columns except Exception as e: logger.error(f"Failed to get tables and columns from data warehouse: {e}") return {} def get_table_indexes(self, table_name: str) -> list[dict[str, Any]]: """ Get index information for a specific table. Args: table_name (str): Name of the table Returns: List[Dict[str, Any]]: List of index information """ try: indexes = self.inspector.get_indexes(table_name) return indexes except Exception as e: logger.warning(f"Failed to get indexes for table {table_name}: {e}") return [] def get_foreign_keys(self, table_name: str) -> list[dict[str, Any]]: """ Get foreign key information for a specific table. Args: table_name (str): Name of the table Returns: List[Dict[str, Any]]: List of foreign key information """ try: foreign_keys = self.inspector.get_foreign_keys(table_name) return foreign_keys except Exception as e: logger.warning(f"Failed to get foreign keys for table {table_name}: {e}") return [] def save_to_catalog_store( self, catalog_store: CatalogStore, database_name: str | None = None, update: bool = False ) -> bool: """ Extract warehouse metadata and save to catalog store.
Args: catalog_store (CatalogStore): Target catalog store to load data to database_name (Optional[str]): Database name in catalog, defaults to 'default' update (bool): Update existing catalog store to sync with data warehouse Returns: bool: True if load was successful, False otherwise """ try: if database_name is None: database_name = "default" # Get tables and columns from data warehouse tables_columns = self.get_tables_and_columns() if not tables_columns: logger.warning("No tables found in data warehouse") return True # Import each table success_count = 0 total_count = len(tables_columns) for table_name, columns in tables_columns.items(): try: # Get table comment if available table_comment = "" try: table_info = self.inspector.get_table_comment(table_name) table_comment = table_info.get("text", "") if table_info else "" except Exception: # Some databases don't support table comments pass table_info = {"description": table_comment, "selection_rule": "", "sql_rule": ""} if catalog_store.save_table_information(table_name, table_info, columns, database_name): success_count += 1 logger.info(f"Successfully loaded table: {database_name}.{table_name}") else: logger.error(f"Failed to load table: {database_name}.{table_name}") # init null SQL examples catalog_store.save_table_sql_examples( table_name, [{"question": "null", "answer": "null"}], database_name ) except Exception as e: logger.error(f"Error loading table {table_name}: {e}") # init empty table selection examples catalog_store.save_table_selection_examples([("", [])]) logger.info(f"Load completed: {success_count}/{total_count} tables loaded successfully") return success_count == total_count except Exception as e: logger.error(f"Failed to load data warehouse to catalog store: {e}") return False def load_catalog_from_data_warehouse(catalog_store: CatalogStore) -> bool: """ Load catalog data from data warehouse using SQLAlchemy based on data warehouse config (URI) Main entry point for catalog loading. Args: catalog_store (CatalogStore): Target catalog store Returns: bool: True if load was successful, False otherwise """ try: data_warehouse_config = catalog_store.get_data_warehouse_config() database_uri = data_warehouse_config.get("uri") include_tables = data_warehouse_config.get("include_tables") database_name = data_warehouse_config.get("database_name", "default") engine = catalog_store.get_sql_engine() loader = DataCatalogLoader(engine, include_tables) return loader.save_to_catalog_store(catalog_store, database_name) except Exception as e: logger.error(f"Failed to import catalog from data warehouse URI {database_uri}: {e}") return False ================================================ FILE: openchatbi/catalog/catalog_store.py ================================================ from abc import ABC, abstractmethod from typing import Any from sqlalchemy import Engine class CatalogStore(ABC): """ Abstract base class defining the storage interface for data catalog (database, table, column definitions, descriptions, and additional prompts). Common columns which have same meanings across tables will be store centralized to avoid data duplication. 
Column attribute: - column_name: the name of the column - display_name: the display name of the column - type: the data type of the column - category: dimension or metric - description: the description of the column - is_common: is common column or not """ @abstractmethod def get_data_warehouse_config(self) -> dict: """ Get the data warehouse configuration Returns: dict: Data warehouse configuration """ pass @abstractmethod def get_sql_engine(self) -> Engine: """ Get the SQLAlchemy engine for the catalog Returns: Engine: SQLAlchemy engine """ pass @abstractmethod def get_database_list(self) -> list[str]: """ Get a list of all databases Returns: List[str]: List of database names """ pass @abstractmethod def get_table_list(self, database: str | None = None) -> list[str]: """ Get a list of all tables in the specified database, if database is None, return all tables Args: database (Optional[str]): Database name Returns: List[str]: List of table names """ pass @abstractmethod def get_column_list(self, table: str | None = None, database: str | None = None) -> list[dict[str, Any]]: """ Get all column information for the specified table, if table is None, return all common columns in the catalog Args: table (Optional[str]): Table name database (Optional[str]): Database name Returns: List[Dict[str, Any]]: List of column information, each column contains name, type, description, etc. """ pass @abstractmethod def get_table_information(self, table: str, database: str | None = None) -> dict[str, Any]: """ Get the information for the specified table Args: table (str): Table name database (Optional[str]): Database name Returns: Dict[str, Any]: Table information, including description text, selection rules, etc. """ pass @abstractmethod def get_sql_examples( self, table: str | None = None, database: str | None = None ) -> list[tuple[str, str, list[str]]]: """ Get SQL examples Args: table (Optional[str]): Table name database (Optional[str]): Database name Returns: List[Tuple[str, str, List[str]]]: List of SQL examples, each example is a Tuple-3 as (question, SQL, full_table_names) """ pass @abstractmethod def get_table_selection_examples(self) -> list[tuple[str, list[str]]]: """ Get table selection examples Returns: List[Tuple[str, List[str]]]: List of table selection examples, each example is a Tuple-2 as (question, selected tables) """ pass @abstractmethod def save_table_information( self, table: str, information: dict[str, Any], columns: list[dict[str, Any]], database: str | None = None, update_existing: bool = False, ) -> bool: """ Save the information and columns for a table Args: table (str): Table name information (Dict[str, Any]): Table information columns (List[Dict[str, Any]]): List of column information, each column dict contains at lease column_name, type, category, description database (Optional[str]): Database name update_existing (bool): Update existing table and column information Returns: bool: Whether the save was successful """ pass @abstractmethod def save_table_sql_examples(self, table: str, examples: list[dict[str, str]], database: str | None = None) -> bool: """ Save SQL examples for a table Args: table (str): Table name examples (List[Dict[str, str]]): List of SQL examples database (Optional[str]): Database name Returns: bool: Whether the save was successful """ pass @abstractmethod def save_table_selection_examples(self, examples: list[tuple[str, list[str]]]) -> bool: """ Save table selection examples Args: examples (List[Tuple[str, List[str]]]): List of table selection examples 
Returns: bool: Whether the save was successful """ pass @abstractmethod def check_exists(self) -> bool: """ Check if the catalog store has existing data/content Returns: bool: True if catalog store has existing data, False if empty or missing essential files """ pass def split_db_table_name(table: str, database: str | None = None) -> tuple[str, str, str]: """ Split full table name into db name and table name Args: table (str): if database is None, should be full table name like `db.table`, otherwise should be only table name database (Optional[str]): Database name Returns: Tuple[str, str, str]: full_table_name, db_name, table_name """ full_table_name = table if database is not None and "." not in table: full_table_name = f"{database}.{table}" if "." in full_table_name: db_name, table_name = full_table_name.rsplit(".", 1) else: db_name = "" table_name = full_table_name return full_table_name, db_name, table_name ================================================ FILE: openchatbi/catalog/factory.py ================================================ import logging import os from openchatbi.catalog.catalog_loader import load_catalog_from_data_warehouse from openchatbi.catalog.catalog_store import CatalogStore from openchatbi.catalog.store.file_system import FileSystemCatalogStore logger = logging.getLogger(__name__) # Factory function for creating CatalogStore instances def create_catalog_store( store_type: str, auto_load: bool = True, data_warehouse_config: dict = None, **kwargs ) -> CatalogStore: """ Create a CatalogStore instance Args: store_type (str): Storage type, supports 'file_system' auto_load (bool): Whether to autoload from database if catalog files don't exist data_warehouse_config (dict): Data warehouse configuration dictionary **kwargs: Other parameters Returns: CatalogStore: CatalogStore instance Raises: ValueError: If the storage type is not supported """ if store_type == "file_system": data_path = kwargs.get("data_path", "data") # convert relative path to absolute path if not data_path.startswith("/"): data_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), data_path) catalog_store = FileSystemCatalogStore(data_path, data_warehouse_config) # Check if autoload is enabled and if catalog files are missing if auto_load: _auto_load_catalog_if_needed(catalog_store) return catalog_store else: raise ValueError(f"Unsupported storage type: {store_type}") def _auto_load_catalog_if_needed(catalog_store: CatalogStore) -> None: """ Autoload catalog from data warehouse if catalog files are missing or empty Args: catalog_store (CatalogStore): The catalog store instance """ # Check if catalog store has existing data using the store's own check_exists method if not catalog_store.check_exists(): logger.info("Catalog files missing or empty, attempting to load from data warehouse...") try: # Get data warehouse config from loaded configuration data_warehouse_config = catalog_store.get_data_warehouse_config() if not data_warehouse_config: logger.warning("No data warehouse configuration found, skipping autoload") return warehouse_uri = data_warehouse_config.get("uri") if not warehouse_uri: logger.warning("No data warehouse URI found in configuration, skipping autoload") return # load catalog from data warehouse success = load_catalog_from_data_warehouse(catalog_store) if success: logger.info("Successfully loaded catalog from data warehouse") else: logger.error("Failed to load catalog from data warehouse") raise Exception("Failed to load catalog from data warehouse") except 
Exception as e: logger.warning(f"Autoload from data warehouse failed: {e}") raise Exception("Failed to load catalog from data warehouse") from e ================================================ FILE: openchatbi/catalog/helper.py ================================================ from typing import Any import requests from sqlalchemy import Engine, create_engine from openchatbi.catalog.token_service import apply_token_for_user from openchatbi.utils import log def get_requests_session(token: str, header_extra_params: dict) -> requests.Session: """Create HTTP session with bearer token authentication.""" session = requests.Session() session.headers.update({"Authorization": f"Bearer {token}"}) if header_extra_params: session.headers.update(header_extra_params) return session def create_sqlalchemy_engine_instance(data_warehouse_config: dict[str, Any]) -> Engine: """ Create SQLAlchemy engine instance from data warehouse config Args: data_warehouse_config: Config dict with 'uri' and optional 'token_service' Returns: Configured SQLAlchemy engine """ database_uri = data_warehouse_config.get("uri") engine_args = {"echo": True} # Handle Presto authentication if "presto" in database_uri and "token_service" in data_warehouse_config: token_service = data_warehouse_config.get("token_service") user_name = data_warehouse_config.get("user_name") password = data_warehouse_config.get("password") header_extra_params = data_warehouse_config.get("header_extra_params", {}) token = apply_token_for_user(token_service, user_name, password) log(f"Applied presto token: {token} for user: {user_name}") engine_args["connect_args"] = { "protocol": "https", "requests_session": get_requests_session(token, header_extra_params), } database_uri = database_uri.format(user_name=user_name) engine = create_engine(database_uri, **engine_args) return engine ================================================ FILE: openchatbi/catalog/retrival_helper.py ================================================ """Helper functions for building column retrieval systems.""" from rank_bm25 import BM25Okapi from openchatbi.llm.llm import get_embedding_model from openchatbi.text_segmenter import _segmenter from openchatbi.utils import create_vector_db, log def get_columns_metadata(catalog): """Extract column metadata for indexing. Args: catalog: Catalog store instance. 
Returns: tuple: (columns, col_dict, column_tokens, embedding_keys) """ columns = catalog.get_column_list() col_dict = {} column_tokens = [] embedding_keys = [] for column in columns: col_dict[column["column_name"]] = column text_parts = [ column.get("column_name", ""), column.get("display_name", ""), column.get("alias", ""), column.get("tag", ""), column.get("description", ""), ] text = " ".join(text_parts) tokens = [token for token in _segmenter.cut(text) if token not in ("_", " ")] column_tokens.append(tokens) embedding_key = f"{column['column_name']}: {column['display_name']}" embedding_keys.append(embedding_key) return columns, col_dict, column_tokens, embedding_keys def build_column_tables_mapping(catalog): """Build a mapping of column names to their corresponding table names.""" column_tables_mapping = {} for table_name in catalog.get_table_list(): for column in catalog.get_column_list(table_name): column_name = column["column_name"] if column_name not in column_tables_mapping: column_tables_mapping[column_name] = [] column_tables_mapping[column_name].append(table_name) return column_tables_mapping def build_columns_retriever(catalog, vector_db_path: str = None): """Build BM25 and vector retrievers for columns. Args: catalog: Catalog store instance. vector_db_path: Path to the vector database file. Returns: tuple: (bm25, vector_db, columns, col_dict) """ columns, col_dict, column_tokens, embedding_keys = get_columns_metadata(catalog) bm25 = BM25Okapi(column_tokens) log("Building vector database for columns...") vector_db = create_vector_db( embedding_keys, get_embedding_model(), metadatas=columns, collection_name="columns", collection_metadata={"hnsw:space": "cosine"}, chroma_db_path=vector_db_path, ) return bm25, vector_db, columns, col_dict ================================================ FILE: openchatbi/catalog/schema_retrival.py ================================================ """Schema and column retrieval functionality for finding relevant database structures.""" import os import re import Levenshtein from openchatbi import config from openchatbi.catalog.retrival_helper import build_column_tables_mapping, build_columns_retriever from openchatbi.text_segmenter import _segmenter from openchatbi.utils import log # Skip build during documentation build if not os.environ.get("SPHINX_BUILD"): try: _catalog_store = config.get().catalog_store except ValueError: _catalog_store = None else: _catalog_store = None if _catalog_store: bm25, vector_db, columns, col_dict = build_columns_retriever(_catalog_store, config.get().vector_db_path) column_tables_mapping = build_column_tables_mapping(_catalog_store) else: bm25, vector_db, columns, col_dict = None, None, [], {} column_tables_mapping = {} def column_retrieval(query, db, k=10, threshold=0.5, filter=None): """Retrieves relevant columns based on a similarity search. Args: query (str): The query string to search for. db: The vector database to search in. k (int, optional): The number of top results to return. Defaults to 10. threshold (float, optional): The similarity threshold for filtering results. Defaults to 0.5. filter (dict, optional): A filter to apply to the search. Defaults to None. Returns: list: List of relevant column names. 
""" log(f"Get the top relevant columns for query: {query}") similar_column_key_scores = db.similarity_search_with_score(query, k=k, filter=filter) # log(f"similar_column_key_scores: {similar_column_key_scores}") column_names = [key.metadata["column_name"] for (key, score) in similar_column_key_scores if score < threshold] log(f"Filtered relevant columns: {column_names}") return column_names def merge_list(list1, list2): return list(set(list1 + list2)) def edit_distance_score(key1, key2): """Calculate normalized edit distance score between two strings. Returns: float: Score between 0 (identical) and 1 (completely different). """ dist = Levenshtein.distance(key1, key2) max_len = max(len(key1), len(key2)) return dist / max_len if max_len > 0 else 1 def edit_distance_search(keywords_list, top_k=10, threshold=0.5): """Searches for columns using edit distance similarity. Args: keywords_list (list): List of keywords to search for. top_k (int, optional): The number of top results to return per keyword. Defaults to 10. threshold (float, optional): The maximum edit distance score to consider. Defaults to 0.5. Returns: list: List of relevant column names. """ keys = set([re.sub(r"(_id|_name| id| name)$", "", key.lower()) for key in keywords_list]) column_similarity_score = set() for key in keys: key_column_similarity_score = {} for column_name, row in col_dict.items(): column_name_score = edit_distance_score( key, re.sub(r"(_id|_name| id| name)$", "", row.get("column_name", "")) ) display_score = edit_distance_score( key, re.sub(r"(_id|_name| id| name)$", "", row.get("display_name", "").lower()) ) if column_name_score < threshold or display_score < threshold: key_column_similarity_score[column_name] = min(column_name_score, display_score) key_top_column = [ key for key, _ in sorted(key_column_similarity_score.items(), key=lambda x: x[1], reverse=True)[:top_k] ] column_similarity_score.update(key_top_column) return list(column_similarity_score) def bm25_search(query_list, top_k=5, score_threshold=0.5): """Performs a BM25 search on columns based on the query. Args: query_list (list): List of query terms. top_k (int, optional): The number of top results to return. Defaults to 5. score_threshold (float, optional): The minimum BM25 score to consider. Defaults to 0.5. Returns: list: List of relevant column names. """ query_tokens = [token for token in _segmenter.cut(" ".join(query_list)) if token not in ("_", " ")] scores = bm25.get_scores(query_tokens) ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True) results = [] for idx, score in ranked[:top_k]: if score_threshold and score < score_threshold: continue results.append(columns[idx]["column_name"]) return results def get_relevant_columns(keywords_list, dimensions, metrics): """Get the most relevant columns for given keywords, dimensions, and metrics. Uses multiple retrieval methods (BM25, edit distance, vector similarity) to find the best matching columns. Args: keywords_list (list): General keywords to search for. dimensions (list): Dimension-specific keywords. metrics (list): Metric-specific keywords. Returns: list: Relevant column names. """ # 1. BM25 search for general keywords total_results = bm25_search(keywords_list, top_k=len(keywords_list) * 4) # 2. Edit distance search for exact matches keyword_len = len(keywords_list + dimensions + metrics) ed_results = edit_distance_search(keywords_list + dimensions + metrics, top_k=keyword_len, threshold=0.3) total_results = merge_list(total_results, ed_results) # 3. 
Vector similarity search for dimensions if dimensions: d_results = column_retrieval(" ".join(dimensions), vector_db, k=10, filter={"category": "dimension"}) total_results = merge_list(total_results, d_results) # 4. Vector similarity search for metrics if metrics: m_results = column_retrieval(" ".join(metrics), vector_db, k=10, threshold=0.55, filter={"category": "metric"}) total_results = merge_list(total_results, m_results) log(f"Relevant columns: {total_results}") return total_results ================================================ FILE: openchatbi/catalog/store/__init__.py ================================================ """Catalog store implementations.""" from .file_system import FileSystemCatalogStore ================================================ FILE: openchatbi/catalog/store/file_system.py ================================================ """File system-based catalog store implementation.""" import csv import logging import os import re import traceback from typing import Any import yaml from sqlalchemy import Engine from ..catalog_store import CatalogStore, split_db_table_name from ..helper import create_sqlalchemy_engine_instance logger = logging.getLogger(__name__) class FileSystemCatalogStore(CatalogStore): """File system-based data catalog storage implementation. Stores catalog data in CSV and YAML files on the local filesystem. """ data_path: str table_info_file: str sql_example_file: str table_selection_example_file: str table_columns_file: str common_columns_file: str table_spec_columns_file: str _table_info_cache: dict | None _table_columns_cache: dict | None _common_columns_cache: dict | None _table_spec_columns_cache: dict | None _sql_example_cache: dict | None _table_selection_example_cache: dict | None _data_warehouse_config: dict _sql_engine: Engine def __init__(self, data_path: str, data_warehouse_config: dict): """Initialize filesystem catalog store. Args: data_path (str): Directory absolute path for storing catalog files. 
data_warehouse_config (dict): Data warehouse configuration dictionary with keys: - uri (str): Database connection URI - include_tables (Optional[List[str]]): List of tables to include, if None include all - database_name (Optional[str]): Database name to use in catalog """ if not isinstance(data_path, str) or not data_path.strip(): raise ValueError("data_path must be a non-empty string") if data_warehouse_config is None: data_warehouse_config = {} elif not isinstance(data_warehouse_config, dict): raise ValueError("data_warehouse_config must be a dictionary") self.data_path = data_path.strip() self.table_info_file = os.path.join(data_path, "table_info.yaml") self.sql_example_file = os.path.join(data_path, "sql_example.yaml") self.table_selection_example_file = os.path.join(data_path, "table_selection_example.csv") self.table_columns_file = os.path.join(data_path, "table_columns.csv") self.common_columns_file = os.path.join(data_path, "common_columns.csv") self.table_spec_columns_file = os.path.join(data_path, "table_spec_columns.csv") # Ensure directory exists with proper error handling try: os.makedirs(self.data_path, exist_ok=True) except (OSError, PermissionError) as e: raise RuntimeError(f"Failed to create data directory '{self.data_path}': {e}") from e # Initialize cache self._table_info_cache = None self._table_columns_cache = None self._common_columns_cache = None self._table_spec_columns_cache = None self._sql_example_cache = None self._table_selection_example_cache = None self._data_warehouse_config = data_warehouse_config try: self._sql_engine = create_sqlalchemy_engine_instance(data_warehouse_config) except Exception as e: logger.warning(f"Failed to create SQL engine: {e}. Some catalog operations may not work.") self._sql_engine = None def _clear_cache(self) -> None: """ Clear all cached data to ensure consistency after data modifications """ self._table_info_cache = None self._table_columns_cache = None self._common_columns_cache = None self._table_spec_columns_cache = None self._sql_example_cache = None self._table_selection_example_cache = None logger.debug("Cleared all caches") def get_data_warehouse_config(self) -> dict: return self._data_warehouse_config def get_sql_engine(self) -> Engine: if self._sql_engine is None: raise RuntimeError("SQL engine is not available. 
Check data warehouse configuration.") return self._sql_engine def _validate_table_name(self, table: str) -> bool: """ Validate table name Args: table (str): Table name Returns: bool: Whether the table name is valid Raises: ValueError: If table name is invalid """ if not table or not isinstance(table, str): raise ValueError("Table name must be a non-empty string") # Check for invalid characters (allow dots for db.table format) invalid_chars = ["/", "\\", "*", "?", "<", ">", "|", '"', "'"] if any(char in table for char in invalid_chars): raise ValueError(f"Table name contains invalid characters: {table}") return True def _validate_column_data(self, columns: list[dict[str, Any]]) -> bool: """ Validate column data format Args: columns (List[Dict[str, Any]]): List of column information Returns: bool: Whether the column data is valid Raises: ValueError: If column data is invalid """ if not isinstance(columns, list): raise ValueError("Columns must be a list") required_fields = {"column_name", "type"} for i, column in enumerate(columns): if not isinstance(column, dict): raise ValueError(f"Column {i} must be a dictionary") # Check required fields missing_fields = required_fields - set(column.keys()) if missing_fields: raise ValueError(f"Column {i} missing required fields: {missing_fields}") # Validate column_name column_name = column.get("column_name") if not isinstance(column_name, str) or not column_name.strip(): raise ValueError(f"Column {i}: column_name must be a non-empty string") # Validate type column_type = column.get("type") if not isinstance(column_type, str) or not column_type.strip(): raise ValueError(f"Column {i}: type must be a non-empty string") return True def _validate_table_information(self, information: dict[str, Any]) -> bool: """ Validate table information format Args: information (Dict[str, Any]): Table information Returns: bool: Whether the table information is valid Raises: ValueError: If table information is invalid """ if not isinstance(information, dict): raise ValueError("Table information must be a dictionary") # Validate optional string fields string_fields = ["description", "selection_rule"] for field in string_fields: if field in information: value = information[field] if value is not None and not isinstance(value, str): raise ValueError(f"Table information field '{field}' must be a string or None") return True def _validate_sql_examples(self, examples: list[dict[str, str]]) -> bool: """ Validate SQL examples format Args: examples (List[Dict[str, str]]): List of SQL examples Returns: bool: Whether the SQL examples are valid Raises: ValueError: If SQL examples are invalid """ if not isinstance(examples, list): raise ValueError("Examples must be a list") required_fields = {"question", "answer"} for i, example in enumerate(examples): if not isinstance(example, dict): raise ValueError(f"Example {i} must be a dictionary") # Check required fields missing_fields = required_fields - set(example.keys()) if missing_fields: raise ValueError(f"Example {i} missing required fields: {missing_fields}") # Validate fields are non-empty strings for field in required_fields: value = example.get(field) if not isinstance(value, str) or not value.strip(): raise ValueError(f"Example {i}: {field} must be a non-empty string") return True @staticmethod def _load_yaml_file(file_path: str) -> dict: """ Load YAML file Args: file_path (str): File path Returns: Dict: YAML content """ if not os.path.exists(file_path): logger.debug(f"YAML file does not exist: {file_path}") return {} try: with 
open(file_path, encoding="utf-8") as f: data = yaml.safe_load(f) or {} logger.debug(f"Successfully loaded YAML file: {file_path}") return data except Exception as e: logger.error(f"Failed to load YAML file {file_path}: {e}") logger.error(traceback.format_stack()) return {} @staticmethod def _load_csv_file(file_path: str) -> list[dict[str, str]]: """ Load CSV file Args: file_path (str): File path Returns: List[Dict[str, str]]: List of rows as dictionaries """ if not os.path.exists(file_path): logger.debug(f"CSV file does not exist: {file_path}") return [] try: result = [] with open(file_path, encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: result.append(row) logger.debug(f"Successfully loaded CSV file: {file_path} with {len(result)} rows") return result except Exception as e: logger.error(f"Failed to load CSV file {file_path}: {e}") logger.error(traceback.format_stack()) return [] @staticmethod def _save_yaml_file(file_path: str, data: dict) -> bool: """ Save YAML file Args: file_path (str): File path data (Dict): Data to save Returns: bool: Whether the save was successful """ try: with open(file_path, "w", encoding="utf-8") as f: yaml.dump(data, f, default_flow_style=False, allow_unicode=True) return True except Exception as e: logger.error(f"Failed to save YAML file {file_path}: {e}") logger.error(traceback.format_stack()) return False @staticmethod def _save_csv_file(file_path: str, data: list[dict[str, str]], headers: list[str] = None) -> bool: """ Save CSV file Args: file_path (str): File path data (List[Dict[str, str]]): List of rows as dictionaries headers (List[str]): List of header names in sequence Returns: bool: Whether the save was successful """ try: if not data: return True # Get all possible headers from all rows all_headers = set() for row in data: all_headers.update(row.keys()) # If specify field_names, make sure all keys are in field_names if headers is not None: for key in all_headers: if key not in headers: headers.append(key) with open(file_path, "w", encoding="utf-8", newline="") as f: writer = csv.DictWriter(f, fieldnames=headers) writer.writeheader() for row in data: writer.writerow(row) return True except Exception as e: logger.error(f"Failed to save CSV file {file_path}: {e}") logger.error(traceback.format_stack()) return False def _load_tables(self) -> dict[str, list[str]]: # Load table_columns.csv table_columns_csv = self._load_csv_file(self.table_columns_file) # Get unique db_name.table_name combinations table_dict = {} for row in table_columns_csv: if "db_name" in row and "table_name" in row and "column_name" in row: db_name = row["db_name"] table_name = row["table_name"] column_name = row["column_name"] full_table_name = f"{db_name}.{table_name}" if full_table_name not in table_dict: table_dict[full_table_name] = [] table_dict[full_table_name].append(column_name) return table_dict def _load_common_columns(self) -> dict[str, dict[str, Any]]: # Load common_columns.csv to get column details columns_csv = self._load_csv_file(self.common_columns_file) # Filter and return column details column_dict = {} for row in columns_csv: if row.get("column_name") and row.get("type"): # Convert row to Dict[str, Any] column_info = {} for key, value in row.items(): if key != "": column_info[key] = value column_dict[row["column_name"]] = column_info return column_dict def _load_table_spec_columns(self) -> dict[str, dict[str, Any]]: """ Load info of table spec columns Returns: Dict[str, Dict[str, Any]]: Dictionary of table specific columns information, keyed 
by "full_table_name:column_name" """ # Load table_spec_columns.csv to get table specific column details columns_csv = self._load_csv_file(self.table_spec_columns_file) # Filter and return column details column_dict = {} for row in columns_csv: if "db_name" in row and "table_name" in row and "column_name" in row and row["column_name"]: # Convert row to Dict[(str, str), Any] full_table_name = f"{row['db_name']}.{row['table_name']}" column_info = {} for key, value in row.items(): if key != "": column_info[key] = value column_dict[f"{full_table_name}:{row['column_name']}"] = column_info return column_dict def _parse_example_text(self, example_text: str) -> list[tuple[str, str]]: """ Parse example text, format is Q: ... A: ... Args: example_text (str): Example text Returns: List[Tuple[str, str]]: List of parsed question-answer pairs """ examples = [] lines = example_text.strip().split("\n") question = "" answer = "" current_type = None for line in lines: if line.startswith("Q:"): # If there is already a complete question-answer pair, add it to the results if question and answer: examples.append((question.strip(), answer.strip())) question = "" answer = "" question = line[2:] current_type = "Q" elif line.startswith("A:"): answer = line[2:] current_type = "A" else: # Continue adding to the current type if current_type == "Q": question += "\n" + line elif current_type == "A": answer += "\n" + line # Add the last question-answer pair if question and answer: examples.append((question.strip(), answer.strip())) return examples def get_database_list(self) -> list[str]: # Extract unique database names databases = set() for table in self._get_all_table_schema().keys(): full_table_name, db_name, table_name = split_db_table_name(table) databases.add(db_name) return list(databases) def _get_all_table_schema(self) -> dict[str, list[str]]: """ Get all tables schema (columns of table) Returns: Dict[str, List[str]]: Tables schema (columns) dict, keyed by table name """ if self._table_columns_cache is None: self._table_columns_cache = self._load_tables() # Return a deep copy to prevent external modifications return {k: v.copy() for k, v in self._table_columns_cache.items()} def get_table_list(self, database: str | None = None) -> list[str]: tables = self._get_all_table_schema() if database is None: return list(tables.keys()) # Filter by database filtered_tables = [] for full_table_name in tables.keys(): _, db_name, table_name = split_db_table_name(full_table_name) if db_name == database: filtered_tables.append(full_table_name) return filtered_tables def _get_common_columns(self) -> dict[str, dict[str, Any]]: """ Get information of all common columns Returns: Dict[str, Dict[str, Any]]: Dictionary of columns information, keyed by column name """ if self._common_columns_cache is None: self._common_columns_cache = self._load_common_columns() # Return a deep copy to prevent external modifications return {k: v.copy() for k, v in self._common_columns_cache.items()} def _get_table_spec_columns(self) -> dict[str, dict[str, Any]]: """ Get information of all table specific columns Returns: Dict[str, Dict[str, Any]]: Dictionary of table specific columns information, keyed by "full_table_name:column_name" """ if self._table_spec_columns_cache is None: self._table_spec_columns_cache = self._load_table_spec_columns() # Return a deep copy to prevent external modifications return {k: v.copy() for k, v in self._table_spec_columns_cache.items()} def get_column_list(self, table: str | None = None, database: str | None = None) -> 
list[dict[str, Any]]: _common_columns = self._get_common_columns() if table is None: return list(_common_columns.values()) # Get the full table name full_table_name, db_name, table_name = split_db_table_name(table, database) # Filter table columns tables_dict = self._get_all_table_schema() if full_table_name not in tables_dict: return [] table_columns = tables_dict[full_table_name] # If no columns found, return empty list if not table_columns: return [] # Filter and return column details result = [] _table_spec_columns = self._get_table_spec_columns() for column in table_columns: # check if the column is table specific key = f"{full_table_name}:{column}" if key in _table_spec_columns: column_info = _table_spec_columns[key] column_info["is_common"] = False result.append(column_info) else: column_info = _common_columns.get(column) if column_info: column_info["is_common"] = True result.append(column_info) return result def get_table_information(self, table: str, database: str | None = None) -> dict[str, Any]: full_table_name, db_name, table_name = split_db_table_name(table, database) if self._table_info_cache is None: self._table_info_cache = self._load_yaml_file(self.table_info_file) if db_name in self._table_info_cache and table_name in self._table_info_cache[db_name]: # Return a copy to prevent external modifications return self._table_info_cache[db_name][table_name].copy() return {} def get_sql_examples( self, table: str | None = None, database: str | None = None ) -> list[tuple[str, str, list[str]]]: if self._sql_example_cache is None: self._sql_example_cache = self._load_yaml_file(self.sql_example_file) if table is None: # If no table specified, return all examples examples = [] for db_name, tables in self._sql_example_cache.items(): for table_name, example_text in tables.items(): qa_pairs = self._parse_example_text(example_text) examples.extend([(q, a, [f"{db_name}.{table_name}"]) for (q, a) in qa_pairs]) return examples full_table_name, db_name, table_name = split_db_table_name(table, database) # Find examples that include this table examples = [] # Check the fact section if db_name in self._sql_example_cache: if table_name in self._sql_example_cache[db_name]: # Parse example text, format is Q: ... A: ... 
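# An entry in sql_example.yaml is a single text block holding one or more pairs,
# e.g. (hypothetical table and query, for illustration only):
#   Q: How many events were logged yesterday?
#   A: SELECT COUNT(*) FROM db.events WHERE dt = date_add('day', -1, current_date)
# _parse_example_text() splits such a block into (question, answer) tuples.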
qa_pairs = self._parse_example_text(self._sql_example_cache[db_name][table_name]) examples.extend([(q, a, [full_table_name]) for (q, a) in qa_pairs]) return examples @staticmethod def _load_table_selection_examples_from_csv(file_path: str) -> list[tuple[str, list[str]]]: examples = [] try: with open(file_path, encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: question = row.get("question", "").strip() selected_tables = row.get("selected_tables", "").strip() if question and selected_tables: table_list = [p.strip() for p in re.split(r"[ ,\n]", selected_tables) if p.strip()] examples.append((question, table_list)) except (FileNotFoundError, PermissionError, UnicodeDecodeError) as e: logger.warning(f"Failed to load table selection examples from {file_path}: {e}") return examples def get_table_selection_examples(self) -> list[tuple[str, list[str]]]: if self._table_selection_example_cache is None: self._table_selection_example_cache = self._load_table_selection_examples_from_csv( self.table_selection_example_file ) return self._table_selection_example_cache def save_table_information( self, table: str, information: dict[str, Any], columns: list[dict[str, Any]], database: str | None = None, update_existing: bool = False, ) -> bool: # Validate input data (let validation errors propagate) self._validate_table_name(table) self._validate_table_information(information) self._validate_column_data(columns) try: full_table_name, db_name, table_name = split_db_table_name(table, database) table_info = self._load_yaml_file(self.table_info_file) # Save columns first if not self._save_columns(table_name, columns, db_name, update_existing): logger.error(f"Failed to save columns for table {full_table_name}") return False # Save table information (ensure proper structure) if db_name not in table_info: table_info[db_name] = {} if update_existing or table_name not in table_info[db_name]: table_info[db_name][table_name] = information success = self._save_yaml_file(self.table_info_file, table_info) if success: logger.info(f"Successfully saved table information for {full_table_name}") # Clear cache to ensure consistency self._clear_cache() return success except Exception as e: logger.error(f"Unexpected error when saving table information: {e}") logger.error(traceback.format_stack()) return False def _save_columns( self, table_name: str, columns: list[dict[str, Any]], db_name: str = "", update_existing: bool = False ) -> bool: """ Save columns information to common_columns.csv and columns of tables to table_columns.csv Args: table_name (str): Table name columns (List[Dict[str, Any]]): List of column information db_name (str): Database name update_existing (bool): Update existing column information Returns: bool: Whether the save was successful """ full_table_name, db_name, table_name = split_db_table_name(table_name, db_name) # Load existing data tables_data = self._load_csv_file(self.table_columns_file) common_columns_dict = self._load_common_columns() table_spec_columns_dict = self._load_table_spec_columns() # Create a set of existing table-column combinations existing_table_columns = set() for row in tables_data: if "db_name" in row and "table_name" in row and "column_name" in row: key = f"{row['db_name']}.{row['table_name']}:{row['column_name']}" existing_table_columns.add(key) # Update table_columns.csv and track new columns to add for column in columns: if "column_name" not in column: continue column_name = column["column_name"] is_common_column = column.get("is_common", False) key = 
f"{full_table_name}:{column_name}" column_info = {k: str(v) for k, v in column.items() if k != "is_common"} if not is_common_column: column_info["db_name"] = db_name column_info["table_name"] = table_name # New column of the table -> add to table_columns.csv if key not in existing_table_columns: tables_data.append({"db_name": db_name, "table_name": table_name, "column_name": column_name}) existing_table_columns.add(key) if is_common_column: # Handle common_columns.csv - avoid duplicates if column_name not in common_columns_dict: # Add new columns to columns_data logger.info(f"Add new column column {column_name}") common_columns_dict[column_name] = column_info else: table_spec_columns_dict[key] = column_info # Apply updates to existing columns in columns_data elif update_existing: if is_common_column: common_columns_dict[column_name] = column_info else: table_spec_columns_dict[key] = column_info # Save updated data tables_success = self._save_csv_file( self.table_columns_file, tables_data, ["db_name", "table_name", "column_name"] ) common_columns_success = self._save_csv_file( self.common_columns_file, list(common_columns_dict.values()), ["column_name", "display_name", "alias", "type", "category", "tag", "description"], ) table_spec_columns_success = self._save_csv_file( self.table_spec_columns_file, list(table_spec_columns_dict.values()), ["db_name", "table_name", "column_name", "display_name", "alias", "type", "category", "tag", "description"], ) success = tables_success and common_columns_success and table_spec_columns_success if success: # Clear cache to ensure consistency self._clear_cache() logger.debug(f"Successfully saved columns for table {table_name}") return success def save_table_sql_examples(self, table: str, examples: list[dict[str, str]], database: str | None = None) -> bool: # Validate input data (let validation errors propagate) self._validate_table_name(table) self._validate_sql_examples(examples) try: full_table_name, db_name, table_name = split_db_table_name(table, database) sql_examples = self._load_yaml_file(self.sql_example_file) # Ensure database exists in structure if db_name not in sql_examples: sql_examples[db_name] = {} # example text example_text = "" for example in examples: example_text += f"Q: {example['question']}\nA: {example['answer']}\n\n" sql_examples[db_name][table_name] = example_text.strip() success = self._save_yaml_file(self.sql_example_file, sql_examples) if success: logger.info(f"Successfully saved {len(examples)} examples for table {full_table_name}") # Update cache self._sql_example_cache = sql_examples return success except Exception as e: logger.error(f"Unexpected error when saving table examples: {e}") logger.error(traceback.format_stack()) return False def save_table_selection_examples(self, examples: list[tuple[str, list[str]]]) -> bool: example_data = [] for example in examples: example_data.append({"question": example[0], "selected_tables": example[1]}) save_success = self._save_csv_file( self.table_selection_example_file, example_data, ["question", "selected_tables"] ) if save_success: logger.info(f"Successfully saved {len(examples)} table selection examples.") return save_success def check_exists(self) -> bool: try: # Check if essential catalog files exist and have content files_missing = ( not os.path.exists(self.table_columns_file) or not os.path.exists(self.common_columns_file) or os.path.getsize(self.table_columns_file) <= 1 # Empty or just header or os.path.getsize(self.common_columns_file) <= 1 ) return not files_missing except 
Exception as e: logger.warning(f"Error checking catalog existence: {e}") logger.error(traceback.format_stack()) return False ================================================ FILE: openchatbi/catalog/token_service.py ================================================ """Token service for authentication with external services.""" import json import requests class TokenService: """Service for managing authentication tokens. Handles token application, validation, and authentication with external services. """ base_url = None token = None user_name = None password = None def __init__(self, user_name: str, password: str): """Initialize token service.""" self.user_name = user_name self.password = password def apply_token(self): """Apply for authentication token using credentials.""" response = requests.post( self.base_url + "/apply_token", data=json.dumps({"user_name": self.user_name, "password": self.password}) ) resp_json = response.json() self.token = resp_json.get("token") def apply_token_for_user(token_url: str, user_name: str, password: str): """Apply for token and return token with username. Args: token_url (str): Base URL for token service. user_name (str): The user name. password (str): The password. Returns: token """ token_service = TokenService(user_name, password) token_service.base_url = token_url token_service.apply_token() return token_service.token ================================================ FILE: openchatbi/code/docker_executor.py ================================================ import os import shutil import subprocess import tempfile from pathlib import Path import docker from docker.errors import ContainerError from openchatbi.code.executor_base import ExecutorBase def check_docker_status() -> tuple[bool, str]: """ Check Docker installation and status without initializing DockerExecutor. Returns: Tuple[bool, str]: (is_available, status_message) """ try: # Check if Docker CLI is installed if not shutil.which("docker"): return False, "Docker is not installed. Please install Docker." # Check if Docker daemon is running result = subprocess.run(["docker", "info"], capture_output=True, text=True, timeout=10) if result.returncode == 0: return True, "Docker is installed and running" else: if "Cannot connect to the Docker daemon" in result.stderr: return False, "Docker is installed but not running. Please start the Docker daemon." else: return False, f"Docker is not available: {result.stderr.strip()}" except subprocess.TimeoutExpired: return False, "Docker command timed out. Docker may not be running properly." except FileNotFoundError: return False, "Docker command not found. Please install Docker." except Exception as e: return False, f"Error checking Docker status: {str(e)}" class DockerExecutor(ExecutorBase): """Docker-based Python code executor for isolated execution.""" def __init__(self, variable: dict = None): super().__init__(variable) self.image_name = "python-executor" self.dockerfile_path = Path(__file__).parent.parent.parent / "Dockerfile.python-executor" # Check Docker installation and status self._check_docker_availability() try: self.client = docker.from_env() # Build Docker image if it doesn't exist self._ensure_image_exists() except Exception as e: self._handle_docker_error(e) @staticmethod def _check_docker_availability(): """Check if Docker is installed and available.""" # Check if Docker CLI is installed if not shutil.which("docker"): raise RuntimeError("Docker is not installed. 
Please install Docker and ensure it's in your system PATH.") # Check if Docker daemon is running try: result = subprocess.run(["docker", "info"], capture_output=True, text=True, timeout=10) if result.returncode != 0: if "Cannot connect to the Docker daemon" in result.stderr: raise RuntimeError( "Docker is installed but not running. Please start the Docker daemon and try again." ) else: raise RuntimeError( f"Docker is not available. Please check Docker installation and status. " f"Error: {result.stderr.strip()}" ) except subprocess.TimeoutExpired: raise RuntimeError("Docker command timed out. Please check if Docker is running properly.") except FileNotFoundError: raise RuntimeError("Docker command not found. Please install Docker and ensure it's in your system PATH.") @staticmethod def _handle_docker_error(error: Exception): """Handle Docker-related errors with specific error messages.""" error_str = str(error).lower() if "connection aborted" in error_str and "no such file or directory" in error_str: raise RuntimeError("Docker is not running. Please start the Docker daemon and try again.") elif "permission denied" in error_str: raise RuntimeError( "Permission denied accessing Docker. Please ensure your user has Docker permissions " "or try running with appropriate privileges." ) elif "docker daemon" in error_str or "connection refused" in error_str: raise RuntimeError("Cannot connect to Docker daemon. Please start the Docker daemon and try again.") else: raise RuntimeError( f"Failed to initialize Docker client. Please ensure Docker is installed and running. " f"Error: {str(error)}" ) def _ensure_image_exists(self): """Build Docker image if it doesn't exist.""" try: self.client.images.get(self.image_name) except docker.errors.ImageNotFound: print(f"Building Docker image '{self.image_name}'...") self.client.images.build( path=str(self.dockerfile_path.parent), dockerfile=self.dockerfile_path.name, tag=self.image_name, rm=True, ) print(f"Docker image '{self.image_name}' built successfully.") def run_code(self, code: str) -> tuple[bool, str]: """Execute Python code in a Docker container.""" try: # Create a temporary file with the code with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: # Add variable definitions to the code variable_code = "" for key, value in self._variable.items(): if isinstance(value, str): variable_code += f'{key} = "{value}"\n' else: variable_code += f"{key} = {repr(value)}\n" full_code = variable_code + "\n" + code f.write(full_code) temp_file_path = f.name try: # Run the code in a Docker container container = self.client.containers.run( self.image_name, command=["python3", f"/app/{os.path.basename(temp_file_path)}"], volumes={temp_file_path: {"bind": f"/app/{os.path.basename(temp_file_path)}", "mode": "ro"}}, remove=True, detach=False, stdout=True, stderr=True, network_mode="none", # Disable network access for security ) # Get the output output = container.decode("utf-8") return True, output except ContainerError as e: # Container exited with non-zero code error_output = e.stderr if e.stderr else str(e) return False, f"Container execution failed: {error_output}" except Exception as e: return False, f"Docker execution error: {str(e)}" finally: # Clean up temporary file if "temp_file_path" in locals() and os.path.exists(temp_file_path): try: os.unlink(temp_file_path) except (OSError, PermissionError) as e: # Log but don't fail the operation for cleanup issues print(f"Warning: Failed to clean up temporary file {temp_file_path}: {e}") def 
__del__(self): """Clean up Docker client on deletion.""" try: if hasattr(self, "client") and self.client is not None: self.client.close() except Exception: # Ignore cleanup errors during object destruction pass ================================================ FILE: openchatbi/code/executor_base.py ================================================ from typing import Any class ExecutorBase: """Base class for executing python code.""" _variable: dict def __init__(self, variable: dict = None): if variable is None: self._variable = {} else: self._variable = variable def run_code(self, code: str) -> (bool, str): """Execute python code.""" raise NotImplementedError() def set_variable(self, key: str, value: Any) -> None: """Set variable.""" self._variable[key] = value ================================================ FILE: openchatbi/code/local_executor.py ================================================ import sys from io import StringIO from openchatbi.code.executor_base import ExecutorBase class LocalExecutor(ExecutorBase): def run_code(self, code: str) -> str: safe_globals = {"__builtins__": __builtins__} original_stdout = sys.stdout output_buffer = StringIO() sys.stdout = output_buffer try: exec(code, safe_globals, safe_globals) output = output_buffer.getvalue() return True, output except Exception as e: return False, str(e) finally: sys.stdout = original_stdout ================================================ FILE: openchatbi/code/restricted_local_executor.py ================================================ import sys from io import StringIO from RestrictedPython import compile_restricted, safe_globals, utility_builtins from RestrictedPython.Guards import safe_builtins, safer_getattr from openchatbi.code.executor_base import ExecutorBase class RestrictedLocalExecutor(ExecutorBase): def run_code(self, code: str) -> (bool, str): try: # compile restricted code byte_code = compile_restricted(code, "", "exec") if byte_code is None: return False, "Failed to compile restricted code" restricted_locals = {} restricted_globals = safe_globals.copy() # Set up restricted environment with necessary functions restricted_globals.update(safe_builtins) restricted_globals["_getattr_"] = safer_getattr restricted_globals["__builtins__"] = utility_builtins # Add variable definitions to the restricted locals for key, value in self._variable.items(): restricted_locals[key] = value # Capture print output original_stdout = sys.stdout output_buffer = StringIO() sys.stdout = output_buffer # Use the standard print function for RestrictedPython restricted_globals["_print_"] = lambda *args, **kwargs: print(*args, **kwargs) exec(byte_code, restricted_globals, restricted_locals) output = output_buffer.getvalue() return True, output except Exception as e: return False, str(e) finally: if "original_stdout" in locals(): sys.stdout = original_stdout ================================================ FILE: openchatbi/config.yaml.template ================================================ organization: The Company dialect: presto bi_config_file: example/bi.yaml # Python Code Execution Configuration # Options: "local", "restricted_local", "docker" # - local: Run code in the current Python process (fastest, least secure) # - restricted_local: Run code with RestrictedPython (moderate security, some limitations) # - docker: Run code in isolated Docker containers (slowest, most secure, requires Docker to be installed) python_executor: local # Visualization configuration # Options: "rule" (rule-based), "llm" (LLM-based), or null (skip 
visualization) # visualization_mode: llm # Context management configuration # Controls how conversation context is managed and compressed when it becomes too long context_config: # Enable/disable context management entirely enabled: true # Token limit that triggers context management (when conversation exceeds this, compression starts) summary_trigger_tokens: 12000 # Number of recent messages to always preserve in full (never compress these) keep_recent_messages: 20 # Historical tool output compression limits max_tool_output_length: 2000 # Max length for historical tool outputs max_sql_result_rows: 50 # Max rows to keep in CSV results max_code_output_lines: 50 # Max lines for code execution output # Conversation summarization settings enable_summarization: true # Enable conversation summarization enable_conversation_summary: true # Enable detailed conversation summary summary_max_messages: 50 # Max messages to include in summary context # Content preservation settings preserve_tool_errors: true # Always preserve error messages in full preserve_recent_sql: true # Preserve SQL content (less aggressive compression) # Time Series Forecasting Service Configuration # URL for the time series forecasting service endpoint, adjust based on your deployment scenario: # - Local development (OpenChatBI on host, Forecasting service in Docker): "http://localhost:8765" # - Remote service: "http://your-service-host:8765" timeseries_forecasting_service_url: "http://localhost:8765" # Catalog store configuration catalog_store: store_type: file_system data_path: ./example # Data warehouse configuration data_warehouse_config: uri: "presto://{user_name}@domain:8080/db/default" include_tables: - null # null means include all tables, or specify yaml list database_name: "db.default" # database name to use in catalog token_service: "https://tokens-domain:8080/v1" user_name: TOKEN_SERVICE_USER_NAME password: TOKEN_SERVICE_PASSWORD # Vector database (chroma) path # vector_db_path: ./.chroma_db # LLM configurations (multiple providers) # # 1) Define providers under `llm_providers` # 2) Select which one to use by setting `default_llm: ` default_llm: openai llm_providers: openai: default_llm: class: langchain_openai.ChatOpenAI params: api_key: YOUR_API_KEY_HERE model: gpt-4.1 temperature: 0.01 max_tokens: 8192 embedding_model: class: langchain_openai.OpenAIEmbeddings params: api_key: YOUR_API_KEY_HERE model: text-embedding-3-large chunk_size: 1024 # Optional text2sql_llm: class: langchain_openai.ChatOpenAI params: api_key: YOUR_API_KEY_HERE model: gpt-4.1 temperature: 0.0 max_tokens: 8192 # anthropic: # default_llm: # class: langchain_anthropic.ChatAnthropic # params: # api_key: YOUR_API_KEY_HERE # model: claude-3-5-sonnet-latest # MCP (Model Context Protocol) server configurations mcp_servers: # File system MCP server (stdio transport) - name: filesystem transport: stdio command: ["npx", "-y", "@modelcontextprotocol/server-filesystem"] args: ["--path", "/tmp"] enabled: false timeout: 30 # Example HTTP-based MCP server (streamable_http transport) - name: weather transport: streamable_http url: "http://localhost:8000/mcp/" headers: Authorization: "Bearer YOUR_TOKEN" enabled: false timeout: 30 ================================================ FILE: openchatbi/config_loader.py ================================================ import importlib import os from importlib.util import find_spec from typing import Any from unittest.mock import MagicMock from langchain_core.language_models import BaseChatModel from pydantic import 
BaseModel from openchatbi.catalog.factory import create_catalog_store from openchatbi.utils import log class LLMProviderConfig(BaseModel): """Resolved LLM objects for a single provider.""" model_config = {"arbitrary_types_allowed": True} default_llm: BaseChatModel | MagicMock embedding_model: BaseModel | MagicMock | None = None text2sql_llm: BaseChatModel | MagicMock | None = None class Config(BaseModel): """Configuration model for the OpenChatBI application. Attributes: organization (str): Organization name. Defaults to "The Company". dialect (str): SQL dialect to use. Defaults to "presto". default_llm (BaseChatModel): Default language model for general tasks. embedding_model (BaseModel): Language model for embedding generation. text2sql_llm (Optional[BaseChatModel]): Language model specifically for text-to-SQL tasks. bi_config (Dict[str, Any]): BI configuration loaded from YAML file. Defaults to empty dict. data_warehouse_config (Dict[str, Any]): Data warehouse configuration. Defaults to empty dict. """ model_config = {"arbitrary_types_allowed": True} # General Configurations organization: str = "The Company" dialect: str = "presto" # LLM Configurations default_llm: BaseChatModel | MagicMock embedding_model: BaseModel | MagicMock | None = None text2sql_llm: BaseChatModel | MagicMock | None = None # Multiple LLM providers (optional) llm_provider: str | None = None llm_providers: dict[str, LLMProviderConfig] = {} # BI Configuration bi_config: dict[str, Any] = {} # Data Warehouse Configuration data_warehouse_config: dict[str, Any] = {} # Catalog Store catalog_store: Any = None # Path to the vector database file vector_db_path: str = None # MCP Servers Configuration mcp_servers: list[dict[str, Any]] = [] # Report Configuration report_directory: str = "./data" # Code Execution Configuration python_executor: str = "local" # Options: "local", "restricted_local", "docker" # Visualization Configuration visualization_mode: str | None = "rule" # Options: "rule", "llm", None (skip visualization) # Context Management Configuration context_config: dict[str, Any] = {} # Time Series Service Configuration timeseries_forecasting_service_url: str = "http://localhost:8765" @classmethod def from_dict(cls, config: dict[str, Any]) -> "Config": """Creates a Config instance from a dictionary. Args: config (Dict[str, Any]): Dictionary containing configuration values. Returns: Config: A new Config instance with the provided values. """ return cls(**config) class ConfigLoader: """Singleton class to load and manage configuration settings for OpenChatBI. This class provides methods to load, get, and set configuration parameters for the application, including LLM models, SQL dialect, and other settings. """ _instance = None _config: Config = None def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance llm_configs = ["default_llm", "embedding_model", "text2sql_llm"] def get(self) -> Config: """Get the current configuration. Returns: Config: The current configuration instance. Raises: ValueError: If the configuration has not been loaded. """ if self._config is None: raise ValueError("Configuration has not been loaded. Please call load() or set() first.") return self._config def load(self, config_file: str = None) -> None: """Load configuration from a YAML file. Args: config_file (str, optional): Path to configuration file. Uses CONFIG_FILE environment variable or 'openchatbi/config.yaml' if not provided. Raises: ImportError: If pyyaml is not installed. 
FileNotFoundError: If the configuration file cannot be found. """ if config_file is None: config_file = os.getenv("CONFIG_FILE", "openchatbi/config.yaml") if not find_spec("yaml"): raise ImportError("Please install pyyaml to use this feature.") import yaml try: with open(config_file, encoding="utf-8") as file: config_data = yaml.safe_load(file) if config_data is None: config_data = {} except FileNotFoundError: log(f"Configuration file not found: {config_file}, leave config un-loaded.") return except yaml.YAMLError as e: raise ValueError(f"Invalid YAML in configuration file {config_file}: {e}") except Exception as e: raise RuntimeError(f"Failed to read configuration file {config_file}: {e}") self._process_config_dict(config_data) self._config = Config.from_dict(config_data) def _process_config_dict(self, config_data: dict[str, Any]) -> None: """ Processes a configuration dictionary. """ self._process_llm_providers(config_data) providers = config_data.get("llm_providers", {}) selected_provider = None default_llm_value = config_data.get("default_llm") if isinstance(default_llm_value, str): # Simplified multi-provider config: default_llm: if not providers: raise ValueError("default_llm is a provider name but llm_providers is missing.") selected_provider = default_llm_value elif providers: # Backwards-compat: allow selecting provider via llm_provider legacy_provider = config_data.get("llm_provider") if isinstance(legacy_provider, str): selected_provider = legacy_provider elif "default_llm" not in config_data: # Pick the first provider in config order for backwards-compatible YAML behavior selected_provider = next(iter(providers.keys()), None) elif isinstance(default_llm_value, dict): raise ValueError( "When using llm_providers, set default_llm to a provider name (e.g. default_llm: openai), " "not a class config." ) if providers: if not selected_provider or selected_provider not in providers: raise ValueError(f"Unknown LLM provider '{selected_provider}'. Available: {sorted(providers.keys())}") # Store selected provider for runtime lookups (UI/API can still override per-request) config_data["llm_provider"] = selected_provider # Populate top-level LLM objects for legacy call sites config_data["default_llm"] = providers[selected_provider].default_llm config_data.setdefault("embedding_model", providers[selected_provider].embedding_model) config_data.setdefault("text2sql_llm", providers[selected_provider].text2sql_llm) elif "default_llm" not in config_data: raise ValueError("Missing LLM config key: default_llm") if not config_data.get("embedding_model"): log("WARN: Missing LLM config key: embedding_model, will use BM25 based retrival only") if "data_warehouse_config" not in config_data: raise ValueError("Missing Data Warehouse config key: data_warehouse_config") # Load BI configuration if "bi_config_file" in config_data: bi_config = self.load_bi_config(config_data["bi_config_file"]) bi_config.update(config_data.get("bi_config", {})) config_data["bi_config"] = bi_config if "catalog_store" in config_data: if "store_type" not in config_data["catalog_store"]: raise ValueError("catalog_store must have a store_type field.") catalog_store = create_catalog_store( **config_data["catalog_store"], auto_load=config_data["catalog_store"].get("auto_load", True), data_warehouse_config=config_data.get("data_warehouse_config"), ) else: log("Catalog store config key `catalog_store` not found. 
Using default file system store.") catalog_store = create_catalog_store( store_type="file_system", auto_load=True, data_warehouse_config=config_data.get("data_warehouse_config"), ) config_data["catalog_store"] = catalog_store for config_key in self.llm_configs: config_item = config_data.get(config_key) if not isinstance(config_item, dict) or "class" not in config_item: continue config_data[config_key] = self._instantiate_from_config_dict(config_item, config_key=config_key) def _instantiate_from_config_dict(self, config_item: dict[str, Any], *, config_key: str) -> Any: try: class_path = config_item["class"] if "." not in class_path: raise ValueError(f"Invalid class path format: {class_path}") module_name, class_name = class_path.rsplit(".", 1) module = importlib.import_module(module_name) llm_cls = getattr(module, class_name) params = config_item.get("params", {}) return llm_cls(**params) except (ImportError, AttributeError, ValueError, TypeError) as e: raise RuntimeError(f"Failed to load {config_key} class '{config_item.get('class', '')}': {e}") from e def _process_llm_providers(self, config_data: dict[str, Any]) -> None: """Resolve llm_providers into instantiated provider configs (if present).""" raw_providers = config_data.get("llm_providers") if not raw_providers: return if not isinstance(raw_providers, dict): raise ValueError("llm_providers must be a mapping of provider_name -> config") providers: dict[str, LLMProviderConfig] = {} for provider_name, provider_cfg in raw_providers.items(): if isinstance(provider_cfg, LLMProviderConfig): providers[str(provider_name)] = provider_cfg continue if not isinstance(provider_cfg, dict): raise ValueError(f"llm_providers.{provider_name} must be a mapping") resolved_cfg: dict[str, Any] = dict(provider_cfg) for config_key in self.llm_configs: config_item = resolved_cfg.get(config_key) if not isinstance(config_item, dict) or "class" not in config_item: continue resolved_cfg[config_key] = self._instantiate_from_config_dict( config_item, config_key=f"llm_providers.{provider_name}.{config_key}" ) if "default_llm" not in resolved_cfg or resolved_cfg["default_llm"] is None: raise ValueError(f"llm_providers.{provider_name} missing default_llm") providers[str(provider_name)] = LLMProviderConfig(**resolved_cfg) config_data["llm_providers"] = providers def load_bi_config(self, bi_config_file: str) -> dict[str, Any]: """Load BI configuration from a YAML file. Args: bi_config_file (str): Path to the BI configuration file. Defaults to 'example/bi.yaml'. Returns: Dict[str, Any]: The loaded BI configuration as a dictionary. Raises: ImportError: If pyyaml is not installed. FileNotFoundError: If the BI configuration file cannot be found. """ if not find_spec("yaml"): raise ImportError("Please install pyyaml to use this feature.") import yaml bi_config_data = {} try: with open(bi_config_file, encoding="utf-8") as file: bi_config_data = yaml.safe_load(file) or {} except FileNotFoundError: log(f"Warning: BI config file '{bi_config_file}' not found. Ignore load BI config from yaml file.") except yaml.YAMLError as e: log(f"Warning: Invalid YAML in BI config file '{bi_config_file}': {e}. Using empty config.") except Exception as e: log(f"Warning: Failed to read BI config file '{bi_config_file}': {e}. Using empty config.") return bi_config_data def set(self, config: dict[str, Any]) -> None: """Set the configuration from a dictionary. Args: config (Dict[str, Any]): Dictionary containing configuration values. 
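
        Example (a minimal sketch; ``my_llm`` stands for any pre-built LangChain
        chat model instance, and the warehouse URI is illustrative):

            >>> ConfigLoader().set({
            ...     "default_llm": my_llm,
            ...     "data_warehouse_config": {"uri": "sqlite:///demo.db"},
            ... })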
""" self._process_config_dict(config) self._config = Config.from_dict(config) ================================================ FILE: openchatbi/constants.py ================================================ """Constants used throughout the OpenChatBI application.""" # Date/time format strings datetime_format = "%Y-%m-%d %H:%M:%S" date_format = "%Y-%m-%d" datetime_format_ms = "%Y-%m-%d %H:%M:%S.%f" datetime_format_ms_T = "%Y-%m-%dT%H:%M:%S.%fZ" # SQL execution status codes SQL_NA = "SQL_NA" SQL_SUCCESS = "SQL_SUCCESS" SQL_EXECUTE_TIMEOUT = "SQL_CHECK_TIMEOUT" SQL_SYNTAX_ERROR = "SQL_SYNTAX_ERROR" SQL_UNKNOWN_ERROR = "SQL_UNKNOWN_ERROR" MCP_TOOL_DEFAULT_TIMEOUT_SECONDS = 60 ================================================ FILE: openchatbi/context_config.py ================================================ """Configuration for context management settings.""" from dataclasses import dataclass from openchatbi import config @dataclass class ContextConfig: """Configuration class for context management settings.""" # Enable/disable context management enabled: bool = True # Token limits for triggering context management summary_trigger_tokens: int = 12000 # Message retention (how many recent messages to always preserve) keep_recent_messages: int = 20 # Historical tool output compression limits max_tool_output_length: int = 2000 # Max length for historical tool outputs max_sql_result_rows: int = 50 # Max rows to keep in CSV results max_code_output_lines: int = 50 # Max lines for code execution output # Conversation summarization enable_summarization: bool = True enable_conversation_summary: bool = True summary_max_messages: int = 50 # Max messages to include in summary context # Content preservation settings preserve_tool_errors: bool = True # Always preserve error messages in full preserve_recent_sql: bool = True # Preserve SQL content (less aggressive compression) def get_context_config() -> ContextConfig: """Get the current context configuration. This function loads context configuration from the main config system. Falls back to default configuration if not available. Returns: ContextConfig: The current context configuration """ try: main_config = config.get() # Check if context_config exists in the main config if hasattr(main_config, "context_config") and main_config.context_config: context_config_dict = main_config.context_config # Create ContextConfig from the loaded configuration context_config = ContextConfig() for key, value in context_config_dict.items(): if hasattr(context_config, key): setattr(context_config, key, value) return context_config except (ImportError, ValueError, AttributeError): # Fall back to default if config system is not available or configured pass return ContextConfig() def update_context_config(**kwargs) -> ContextConfig: """Update context configuration with new values. 
Args: **kwargs: Configuration parameters to update Returns: ContextConfig: Updated configuration """ config = get_context_config() for key, value in kwargs.items(): if hasattr(config, key): setattr(config, key, value) return config ================================================ FILE: openchatbi/context_manager.py ================================================ """Context management utilities for handling long conversations.""" import json import re import uuid from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage from openchatbi.context_config import ContextConfig, get_context_config from openchatbi.llm.llm import call_llm_chat_model_with_retry from openchatbi.prompts.system_prompt import get_summary_prompt_template from openchatbi.utils import log class ContextManager: """Manages conversation context to prevent token limit issues.""" def __init__(self, llm: BaseChatModel, config: ContextConfig = None): """Initialize context manager. Args: llm: Language model for summarization config: Context configuration. If None, uses default config. """ self.llm = llm self.config = config or get_context_config() # ============================================================================ # PUBLIC API METHODS # ============================================================================ def manage_context_messages(self, messages: list) -> None: """Main context management function that directly modifies messages list. Args: messages: The list of messages to manage (modified in place) """ if not self.config.enabled: return if not messages: return # Check if we need to manage context estimated_tokens = self.estimate_message_tokens(messages) if estimated_tokens <= self.config.summary_trigger_tokens: return # No action needed log(f"Context management triggered: {estimated_tokens} tokens > {self.config.summary_trigger_tokens}") # Apply historical tool message compression directly self._compress_historical_tool_messages(messages) # Check if we still need summarization after compression remaining_tokens = self.estimate_message_tokens(messages) if remaining_tokens > self.config.summary_trigger_tokens and self.config.enable_summarization: self._apply_conversation_summarization(messages) log("Context management completed") # ============================================================================ # TOKEN ESTIMATION METHODS # ============================================================================ @staticmethod def estimate_tokens(text: str) -> int: """Rough token estimation (1 token ≈ 4 characters for most languages).""" return len(text) // 4 def estimate_message_tokens(self, messages: list[BaseMessage]) -> int: """Estimate total tokens in a list of messages.""" total = 0 for msg in messages: total += self.estimate_tokens(str(msg.content)) # Add tokens for metadata and structure total += 50 return total # ============================================================================ # TOOL OUTPUT TRIMMING METHODS # ============================================================================ def trim_tool_output(self, content: str, tool_name: str = "") -> str: """Trim tool output to manageable size while preserving key information.""" if len(content) <= self.config.max_tool_output_length: return content # Preserve full error messages if configured if self.config.preserve_tool_errors and ("Error:" in content or "Traceback" in content): return content # For SQL results, preserve structure if "```sql" in 
content or "```csv" in content: return self._trim_structured_output(content) # For code execution results if "```python" in content or "Traceback" in content: return self._trim_code_output(content) # Generic trimming max_len = self.config.max_tool_output_length trimmed = content[: max_len // 2] + "\n\n... [Output truncated] ...\n\n" + content[-max_len // 2 :] return trimmed def _trim_structured_output(self, content: str) -> str: """Trim SQL/CSV output while preserving structure.""" parts = [] # Extract SQL query (always keep) sql_match = re.search(r"```sql\n(.*?)\n```", content, re.DOTALL) if sql_match: parts.append(f"```sql\n{sql_match.group(1)}\n```") # Extract and trim CSV data csv_match = re.search(r"```csv\n(.*?)\n```", content, re.DOTALL) if csv_match: csv_data = csv_match.group(1) lines = csv_data.split("\n") max_rows = self.config.max_sql_result_rows if len(lines) > max_rows: # Keep header + first half + last quarter keep_start = max_rows // 2 keep_end = max_rows // 4 trimmed_csv = "\n".join( lines[: keep_start + 1] + [f"... [{len(lines) - keep_start - keep_end - 1} rows omitted] ..."] + lines[-keep_end:] ) parts.append(f"```csv\n{trimmed_csv}\n```") else: parts.append(f"```csv\n{csv_data}\n```") # Keep visualization info viz_match = re.search(r"Visualization Created:.*", content) if viz_match: parts.append(viz_match.group(0)) return "\n\n".join(parts) def _trim_code_output(self, content: str) -> str: """Trim Python code execution output.""" # Keep error messages (full) if configured if self.config.preserve_tool_errors and ("Traceback" in content or "Error:" in content): return content lines = content.split("\n") max_lines = self.config.max_code_output_lines if len(lines) <= max_lines: return content # Keep first half and last quarter keep_start = max_lines // 2 keep_end = max_lines // 4 return "\n".join(lines[:keep_start] + ["... [Output truncated] ..."] + lines[-keep_end:]) # ============================================================================ # CONVERSATION SUMMARIZATION METHODS # ============================================================================ def summarize_conversation(self, messages: list[BaseMessage]) -> str: """Create a summary of conversation history.""" if not self.config.enable_conversation_summary: return "" # Filter out system messages for summarization # Note: The messages passed in are already historical messages (split point already calculated) messages_to_summarize = [] for msg in messages: if not isinstance(msg, SystemMessage): messages_to_summarize.append(msg) if not messages_to_summarize: return "" # Create summarization prompt conversation_text = self._format_messages_for_summary(messages_to_summarize) # Get the summary prompt template from the file and replace placeholder summary_prompt = get_summary_prompt_template().replace("[conversation_text]", conversation_text) try: response = call_llm_chat_model_with_retry( self.llm, [HumanMessage(content=summary_prompt)], parallel_tool_call=False ) if isinstance(response, AIMessage): return f"[Conversation Summary]: {response.content}" return "[Summary generation failed]" except Exception as e: log(f"Failed to generate conversation summary: {e}") return "[Summary generation failed]" def _truncate_text(self, text: str, truncate_len: int = 500) -> str: # do not truncate Conversation Summary if text.startswith("[Conversation Summary]"): return text if len(text) > truncate_len: return text[:truncate_len] + "... 
[truncated]" return text def _truncate_text_or_list(self, content): results = [] if isinstance(content, str): results.append(self._truncate_text(content)) elif isinstance(content, list): for item in content: if isinstance(item, str): results.append(self._truncate_text(item)) elif isinstance(item, dict): if item["type"] == "text": results.append(self._truncate_text(item["text"])) elif item["type"] == "tool_use": results.append(json.dumps(item)) return results def _format_messages_for_summary(self, messages: list[BaseMessage]) -> str: """Format messages for summary generation.""" formatted = [] max_messages = self.config.summary_max_messages # Limit messages for summary context for msg in messages[-max_messages:]: if isinstance(msg, HumanMessage): formatted.append(f" {msg.content} ") elif isinstance(msg, AIMessage): content = msg.content or "" formatted.append("") formatted.extend(self._truncate_text_or_list(content)) formatted.append("") elif isinstance(msg, ToolMessage): formatted.append( f" tool_call_id: {msg.tool_call_id}, " f"tool: {msg.name}, " f"status: {msg.status}, " f"result: {self._truncate_text_or_list(msg.content)} " ) return "\n".join(formatted) # ============================================================================ # CONTEXT MANAGEMENT IMPLEMENTATION METHODS # ============================================================================ def _compress_historical_tool_messages(self, messages: list[BaseMessage]) -> None: """Compress historical (not recent) tool messages in place.""" # Find a safe split point recent_start_index = self._find_safe_split_point(messages) # Find tool messages in historical part (before recent_start_index) that need compression for i in range(recent_start_index): msg = messages[i] if isinstance(msg, ToolMessage): original_content = str(msg.content) # Apply intelligent filtering for tool message compression if self._should_compress_historical_tool_message(msg, original_content): trimmed_content = self.trim_tool_output(original_content) if len(trimmed_content) < len(original_content): # Update message content directly messages[i] = ToolMessage( content=trimmed_content, tool_call_id=msg.tool_call_id, id=msg.id, # Keep original ID to preserve position ) log( f"Compressed historical tool message: {len(original_content)} -> {len(trimmed_content)} chars" ) def _apply_conversation_summarization(self, messages: list[BaseMessage]) -> None: """Apply conversation summarization by modifying messages list in place.""" if not self.config.enable_conversation_summary: return # Find a safe split point that doesn't separate AI messages with tool calls from their ToolMessages recent_start_index = self._find_safe_split_point(messages) if recent_start_index == 0: return # No historical messages to summarize historical_messages = messages[:recent_start_index] recent_messages = messages[recent_start_index:] if len(historical_messages) == 1: msg = historical_messages[0] if isinstance(msg, AIMessage) and msg.content.startswith("[Conversation Summary]"): return # Generate summary summary_text = self.summarize_conversation(historical_messages) if summary_text: # Rebuild messages list in place: summary + recent new_messages = [AIMessage(content=summary_text, id=str(uuid.uuid4()))] + recent_messages # Clear and repopulate the list in place messages.clear() messages.extend(new_messages) log(f"Applied conversation summary, removed {len(historical_messages)} historical messages") def _find_safe_split_point(self, messages: list[BaseMessage]) -> int: """Find a safe split point that 
starts at a HumanMessage. Returns the index where recent messages should start (everything before this index is historical). """ if len(messages) <= self.config.keep_recent_messages: return 0 # Keep all messages as recent # If keep_recent_messages is 0, return all messages as historical if self.config.keep_recent_messages <= 0: return len(messages) # Start from the naive split point naive_split = len(messages) - self.config.keep_recent_messages # Find the nearest HumanMessage for i in range(naive_split, -1, -1): msg = messages[i] if isinstance(msg, HumanMessage) or (isinstance(msg, dict) and msg.get("role") == "user"): return i # Split before this HumanMessage return naive_split # ============================================================================ # CONTENT ANALYSIS HELPER METHODS # ============================================================================ def _should_compress_historical_tool_message(self, tool_msg: ToolMessage, content: str) -> bool: """Determine if a historical tool message should be compressed. Args: tool_msg: The tool message to evaluate content: The content of the tool message Returns: bool: True if the message should be compressed """ # Don't compress if content is already short if len(content) <= self.config.max_tool_output_length: return False # Always preserve error messages if configured if self.config.preserve_tool_errors and self._is_error_content(content): return False # Don't compress recent SQL results if configured if self.config.preserve_recent_sql and self._is_sql_content(content): return False # Compress large outputs from specific tools more aggressively if self._is_data_query_result(content): return True # Compress Python execution results but preserve errors if self._is_python_execution_result(content): return not self._is_error_content(content) # Default: compress if content is long return True def _is_error_content(self, content: str) -> bool: """Check if content contains error information.""" error_indicators = [ "error:", "Error:", "ERROR:", "exception:", "Exception:", "EXCEPTION:", "traceback", "Traceback", "TRACEBACK", "failed", "Failed", "FAILED", "KeyError", "ValueError", "TypeError", "AttributeError", "FileNotFoundError", "ConnectionError", ] return any(indicator in content for indicator in error_indicators) def _is_sql_content(self, content: str) -> bool: """Check if content contains SQL query results.""" sql_indicators = [ "```sql", "query results", "sql query:", "select ", "insert ", "update ", "delete ", "create table", "alter table", ] content_lower = content.lower() return any(indicator in content_lower for indicator in sql_indicators) def _is_data_query_result(self, content: str) -> bool: """Check if content is a data query result that can be safely compressed.""" indicators = [ "```csv", "query results", "rows returned", "records found", "records in the database", "found records", "csv format", ] content_lower = content.lower() return any(indicator in content_lower for indicator in indicators) def _is_python_execution_result(self, content: str) -> bool: """Check if content is Python code execution result.""" indicators = [ "```python", "execution completed", "output:", "result:", "print(", ] return any(indicator.lower() in content.lower() for indicator in indicators) ================================================ FILE: openchatbi/graph_state.py ================================================ """State classes for OpenChatBI graph execution.""" from typing import Annotated, Any from langchain_core.messages import AIMessage,
HumanMessage from langgraph.graph import MessagesState from langgraph.types import Send def add_history_messages(left: list, right: list): if left: total_messages = left + right else: total_messages = right return total_messages class AgentState(MessagesState): """State for the main agent graph execution. Extends MessagesState with additional fields for routing and responses. """ history_messages: Annotated[list[HumanMessage | AIMessage], add_history_messages] agent_next_node: str sends: list[Send] sql: str final_answer: str class SQLGraphState(MessagesState): """State for SQL generation subgraph. Contains rewritten question, table selection, extracted entities, and generated SQL. """ rewrite_question: str tables: list[dict[str, Any]] info_entities: dict[str, Any] sql: str sql_retry_count: int sql_execution_result: str schema_info: dict[str, Any] # Data schema analysis results data: str # CSV data for display previous_sql_errors: list[dict[str, Any]] visualization_dsl: dict[str, Any] class InputState(MessagesState): """Input state schema for the main graph.""" pass class OutputState(MessagesState): """Output state schema for the main graph.""" pass class SQLOutputState(MessagesState): """Output state schema for the SQL generation subgraph.""" rewrite_question: str tables: list[dict[str, Any]] sql: str schema_info: dict[str, Any] # Data schema analysis results data: str # CSV data for display visualization_dsl: dict[str, Any] ================================================ FILE: openchatbi/llm/llm.py ================================================ import time import traceback from langchain_core.language_models import BaseChatModel from langchain_core.runnables.base import RunnableBinding from langchain_core.tools import StructuredTool from openchatbi import config from openchatbi.tool.ask_human import AskHuman from openchatbi.utils import log def list_llm_providers() -> list[str]: """List configured LLM provider names (if any).""" try: providers = getattr(config.get(), "llm_providers", None) or {} except ValueError: return [] return sorted(providers.keys()) def _get_provider_config(provider: str | None): cfg = config.get() providers = getattr(cfg, "llm_providers", None) or {} if not provider: provider = getattr(cfg, "llm_provider", None) if not provider: return None if provider not in providers: raise ValueError(f"Unknown llm_provider '{provider}'. 
Available: {sorted(providers.keys())}") return providers[provider] def get_embedding_model(provider: str | None = None): """Get embedding model from config (optionally scoped to a provider).""" provider_cfg = _get_provider_config(provider) if provider_cfg and getattr(provider_cfg, "embedding_model", None) is not None: return provider_cfg.embedding_model return config.get().embedding_model def get_default_llm(provider: str | None = None): """Get default LLM from config (optionally scoped to a provider).""" provider_cfg = _get_provider_config(provider) if provider_cfg: return provider_cfg.default_llm return config.get().default_llm def get_llm(provider: str | None = None): """Get the chat model to use (alias for `get_default_llm`).""" return get_default_llm(provider) def get_text2sql_llm(provider: str | None = None): """Get text2sql LLM from config (optionally scoped to a provider).""" provider_cfg = _get_provider_config(provider) if provider_cfg: return provider_cfg.text2sql_llm or provider_cfg.default_llm return config.get().text2sql_llm or get_default_llm() def _invalid_tool_names(valid_tools, tool_calls) -> str: invalid_tools = [] for tool in tool_calls: if tool["name"] not in valid_tools: invalid_tools.append(tool["name"]) return ",".join(invalid_tools) def call_llm_chat_model_with_retry( chat_model: BaseChatModel, messages, streaming_tokens=False, bound_tools=None, parallel_tool_call=False ): """Calls a language model chat endpoint with retry logic. Retries up to 3 times if there are errors or invalid tool calls. Args: chat_model: The chat model to invoke. messages (list): List of messages to send to the model. streaming_tokens (bool, optional): flag to indicate whether or not to show streaming tokens in UI. bound_tools (list, optional): List of valid tool names that can be called. parallel_tool_call (bool, optional): whether or not to call multiple tools in parallel. Returns: AIMessage or None: The model response or None if all retries failed. """ new_messages = list(messages) valid_tools = [] if bound_tools: for tool in bound_tools: if isinstance(tool, str): valid_tools.append(tool) elif isinstance(tool, StructuredTool): valid_tools.append(tool.name) elif tool == AskHuman: valid_tools.append("AskHuman") elif isinstance(chat_model, RunnableBinding) and "tools" in chat_model.kwargs: valid_tools += [tool["name"] for tool in chat_model.kwargs["tools"] if "name" in tool] extra_prompt = ( " Please select the `AskHuman` tool if you need to confirm with user." 
if "AskHuman" in valid_tools else "" ) response = None retry = 0 # retry 3 times while retry < 3: start_time = time.time() try: log(f"Call LLM chat model with retry {retry} times.") response = chat_model.invoke(new_messages, config={"metadata": {"streaming_tokens": streaming_tokens}}) run_time = int(time.time() - start_time) log(f"LLM response after {run_time} seconds.") except Exception: run_time = int(time.time() - start_time) retry += 1 log(f"LLM response error after {run_time} seconds, retry {retry} times.") log("===== Messages:") log(str(messages)) traceback.print_exc() continue if response.tool_calls: if len(response.tool_calls) > 1 and not parallel_tool_call: retry += 1 log(f"More than one tool {response.tool_calls}, retry {retry} times.") new_messages += [{"role": "user", "content": "You should only respond with one tool call."}] response = None continue invalid_tools = _invalid_tool_names(valid_tools, response.tool_calls) if invalid_tools: retry += 1 log(f"Invalid tool {invalid_tools}, retry {retry} times.") new_messages += [ { "role": "user", "content": f"You should not use a tool that does not exist: `{invalid_tools}`. " f"Available tools are: {valid_tools}. Please choose a valid tool and try again." f"{extra_prompt}", } ] response = None continue break return response ================================================ FILE: openchatbi/prompts/agent_prompt.md ================================================ You are a helpful BI assistant that can answer users' questions. Use the instructions below and the tools available to you to assist the user. # Capabilities: 1. Answer general questions. 2. Answer questions based on the knowledge base. 3. Answer data query questions by calling the SQL graph to write SQL that answers the question. 4. Answer questions that require analyzing data by writing and executing Python code # Guidelines: - You should be concise, direct, and to the point. - Do not fabricate information; if you don't know, just say you don't know. - Summarize the information you found to answer the question. - When data analysis results include a "Visualization Created" message, acknowledge that an interactive chart has been automatically generated and focus on interpreting the data insights rather than creating additional charts. # Tool usage policy - If you cannot answer the question, call tools that are available. - For the `run_python_code` tool, you can use these libs when writing Python code: pandas numpy matplotlib seaborn requests json5 - IMPORTANT: DO NOT create charts/visualizations with Python code if the text2sql tool response already indicates "Visualization Created". The interactive chart is automatically generated and displayed in the UI. Simply summarize the results without duplicating the visualization. - If the user provides personalized information that needs to be remembered, or wants to forget or correct something mentioned before, use the `manage_memory` tool to save, delete, or update the long-term memory - If the question is related to user information, characteristics, or preferences, proactively use the `search_memory` tool to retrieve the long-term memory - If the question is not clear, or some information is missing, ask the user to clarify by calling the AskHuman tool. - When generating reports, analysis results, or data summaries that users might want to save or share, use the `save_report` tool to save the content to a file and provide a download link. - **When text2sql tool returns empty SQL**: This indicates the current data capabilities cannot support the requested query.
Explain to the user that the requested data or analysis is not available in the current system, and suggest alternative queries that might be supported based on available data sources. ## Knowledge Search Optimization - **AVOID excessive knowledge searches** for data queries that contain standard business terms already covered in your basic knowledge - **ONLY search knowledge** when: - User asks about unfamiliar business terms, metrics, or dimensions not in basic knowledge - Question contains ambiguous terminology that needs clarification - Need to understand complex business relationships or derived metrics - User explicitly asks "what is [term]" or requests definitions - **SKIP knowledge search** for straightforward data queries since `text2sql` tool will handle it - **Prioritize direct SQL execution** over knowledge lookup for routine data analysis requests [extra_tool_use_rule] # Basic Business Knowledge: [basic_knowledge_glossary] # Realtime Environment Current time is [time_field_placeholder] (format 'yyyy-MM-dd HH:mm:ss') Review current state and decide what to do next. If the information is sufficient to answer the question, generate the well summarized final answer. ================================================ FILE: openchatbi/prompts/extraction_prompt.md ================================================ You are a specialized language expert responsible for analyzing user questions and extracting structured information for business intelligence queries. Your task is to process natural language questions and convert them into structured data that can be used for SQL generation and data analysis. # Context You will be provided with: - Business knowledge glossary of [organization] - User question - Chat history (if available) [basic_knowledge_glossary] # Core Processing Steps ## Step 1: Information Extraction Extract and categorize the following information from the user's question and context: ### 1.1 Keywords (Required Array) Extract all relevant business terms, including: - Dimension names and aliases - Metric names and aliases - Entity types (exclude specific IDs/values) **Example**: "Show revenue for order 10001" → Extract: ["revenue", "order"] (exclude "10001") ### 1.2 Dimensions (Required Array) Identify categorical data fields that can be used for grouping or filtering: - Database column names (e.g., "order_id", "country", "site_id") - Distinguish between ID fields (numeric identifiers) and name fields (text labels) ### 1.3 Metrics (Optional Array) Identify measurable quantities that can be aggregated: - Numeric values that can be summed, averaged, counted, etc. - For derived metrics (defined in glossary), extract all component parts - Example: For "click-through rate", extract ["click-through rate", "clicks", "impressions"] ### 1.4 Time Range (Optional) **start_time** and **end_time**: Convert relative time expressions to absolute timestamps if the question is related to date/time like trends, aggregated metric, etc. - Format: `'%Y-%m-%d %H:%M:%S'` - Handle expressions like "yesterday", "last 7 days", "from X to Y" - Default to "last 7 days" if no time range and granularity specified - Specific default if user mentioned granularity: - Weekly -> "last 12 weeks" - Monthly -> "last 12 months" - Yearly -> "Full data" **Example**: ``` Question: "show top 10 ads by CTR yesterday" (today = 2025-05-11) start_time: "2025-05-10 00:00:00" end_time: "2025-05-10 23:59:59" ``` ### 1.5 Timezone (Optional) Extract timezone information using this priority: 1. 
Explicit mention in current question (e.g., "in CET", "EST time") 2. Previously mentioned timezone in conversation history 3. Reset timezone requests → "UTC" **Common formats**: "America/New_York", "CET", "UTC", "Europe/London" ## Step 2: Filter Conditions Generate SQL-compatible filter expressions: **Rules**: - **Text matching**: Use `LIKE '%text%'` for partial name matches - **Exact IDs**: Use `=` for numeric identifiers - **Missing context**: Generate `AskHuman` tool call for clarification **Examples**: - "profile 1234" → `["profile_id=1234"]` - "exam sites" → `["site_name LIKE '%exam%'"]` - "the site" (no context) → Ask for clarification ## Step 3: Question Rewriting Transform the original question into a clear, comprehensive query specification. **Process**: 1. **Analysis**: Break down each component of the user's request 2. **Verification**: Confirm all elements are understood and unambiguous 3. **Rewrite**: Create detailed, explicit version with no ambiguity **Enhancement Rules**: - Add metric definitions in brackets: "CTR" → "click-through rate (clicks/impressions)" - Include default time range if none specified - Include visualization preference if provided by user - Preserve user intent while adding necessary context - Use conversation history to fill gaps # Knowledge Search Decision Before extracting information, determine if knowledge search is needed: ## When to Search Knowledge (use `search_knowledge` tool): - **Unfamiliar terms**: Business-specific jargon, custom metrics, or domain acronyms not in basic knowledge - **Ambiguous terminology**: Terms that could have multiple meanings in business context - **Complex derived metrics**: Multi-component calculations requiring formula understanding - **Explicit requests**: User asks "what is [term]" or requests definitions ## When to Skip Knowledge Search (proceed with JSON extraction): - **Standard business terms**: Common metrics (revenue, orders, users, clicks, CTR, conversion rate) - **Basic dimensions**: Standard fields (date, time, location, category, status, id) - **Clear data requests**: Simple queries with well-understood terminology - **Routine analytics**: Top N, totals, averages, trends with common business terms **Decision rule**: Only search knowledge if you encounter terms that are NOT covered in your basic business knowledge or if terminology is genuinely ambiguous in the business context. 
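The decision rule above lives entirely in the prompt, but its intent is easy to illustrate in code. Below is a minimal sketch, assuming a hypothetical `BASIC_TERMS` set standing in for the basic business knowledge glossary; this helper is illustrative only and is not part of the repository:

```python
# Illustrative only: approximate the knowledge-search decision rule above.
BASIC_TERMS = {
    "revenue", "orders", "users", "clicks", "impressions", "ctr",
    "conversion rate", "date", "time", "location", "category", "status", "id",
}


def should_search_knowledge(keywords: list[str]) -> bool:
    """Search knowledge only when a term falls outside basic business knowledge."""
    return any(kw.lower() not in BASIC_TERMS for kw in keywords)


assert should_search_knowledge(["revenue", "net retention cohort"])  # unfamiliar term
assert not should_search_knowledge(["revenue", "clicks"])  # standard terms only
```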
# Output Format Return a JSON object with the following structure: ```json { "reasoning": "Step-by-step analysis of user input and decision-making process", "keywords": ["array", "of", "extracted", "keywords"], "dimensions": ["array", "of", "dimension", "names"], "metrics": ["array", "of", "metric", "names"], "filter": ["array", "of", "sql", "expressions"], "start_time": "YYYY-MM-DD HH:MM:SS", "end_time": "YYYY-MM-DD HH:MM:SS", "timezone": "timezone_identifier", "rewrite_question": "Complete and detailed question rewrite" } ``` # Quality Guidelines ## Data Consistency - If a dimension appears in filters, include it in the dimensions array - Extract all aliases for derived metrics as defined in the glossary ## Accuracy Rules - **No fabrication**: Only use information present in context or glossary - **Prioritization**: Current question takes precedence over chat history - **Completeness**: Use chat history to fill gaps when current question lacks detail ## Output Formatting - **Standard response**: JSON wrapped in ```json code blocks - **Clarification needed**: Generate `AskHuman` tool call instead of JSON - **Required fields**: Always include `reasoning`, `keywords`, `dimensions`, `filter`, `rewrite_question` # Comprehensive Example **Input Question**: "Show me site 1001's CTR trend from 2024-04-01 to 2024-04-10" **Expected Output**: ```json { "reasoning": "User wants to analyze click-through rate trends for a specific site. Breaking down the request: 1) Site identifier: 1001 (numeric ID), 2) Metric: CTR (click-through rate, calculated as clicks/impressions), 3) Analysis type: trend (time-based progression), 4) Time range: 2024-04-01 to 2024-04-10 (9-day period). Since it's a short time range, hourly granularity is most appropriate for trend analysis. All components are clear and complete.", "keywords": ["site", "click-through rate", "CTR", "clicks", "impressions", "trend"], "dimensions": ["site_id"], "metrics": ["click-through rate", "clicks", "impressions"], "filter": ["site_id=1001"], "start_time": "2024-04-01 00:00:00", "end_time": "2024-04-11 00:00:00", "rewrite_question": "Show me the hourly click-through rate (calculated as clicks/impressions) trend for site_id = 1001 from 2024-04-01 to 2024-04-10" } ``` # Special Cases ## Case 1: Insufficient Information **Input**: "Show me revenue trends for the site" **Action**: Generate `AskHuman` tool call requesting site identification ## Case 2: Conversation Context Usage **Previous**: "Let's analyze site ABC performance" **Current**: "Show me CTR for last week" **Result**: Inherit site "ABC" context ## Case 3: Timezone Handling **Input**: "Yesterday's metrics in EST" **Result**: Extract timezone="America/New_York", calculate yesterday in EST # Environment Variables - Current date: `[time_field_placeholder]`\ ================================================ FILE: openchatbi/prompts/schema_linking_prompt.md ================================================ You are a language expert and professional SQL engineer tasked with analyzing questions from [organization] users and selecting the appropriate table to write SQL. - You need to analyze the user's question, find the possible dimensions and metrics, and then select the tables and all required columns related to the query. - I will give you the business knowledge introduction and the glossaries of [organization] for reference. - I will give you the data warehouse introduction about how these tables are generated and organized. 
- I will give you the candidate tables and their schema; read the table description and rule carefully to understand the purpose and capability of each table, and select the appropriate tables and columns. [basic_knowledge_glossary] [data_warehouse_introduction] # Candidate Tables I found the following tables and their relevant columns and descriptions that might contain the data the user is looking for. [tables] # Examples Here are some examples of questions and selected tables related to the user's question. [examples] # General Rules - Must follow the table description and rule to select the table first - If it is not clear which table to select, you can check the columns in the table to find the columns most related to the question - The "Candidate Tables" contain all the tables and columns you can use, NEVER make up columns or tables. - VERY IMPORTANT: the columns you output **MUST** be contained in the table you selected, as described in the "# Candidate Tables" section. - If the question is asking about the metadata of an entity only, you should find a suitable dimension table - If the question needs to join the fact table with the dimension table, you should also output the dimension table - If there are very similar questions in examples, you can refer to the selected tables in examples. - If multiple tables all meet the requirements, you should select the most relevant one. - Select and output multiple tables when a single table does not contain all the required fields and a join across multiple tables is needed. # Output Format You should output a JSON object that includes: - tables: JSON array of selected tables and columns - table: The selected table - columns: The columns in the table that are related to the question - reasoning: The reasoning behind the table selection Strictly only output the JSON format below, and do not output any extra description content. ## Example ```json { "reasoning": "the reason you select the two tables and columns", "tables": [ { "table": "table_name1", "columns": ["column1", "column2", "column3"] }, { "table": "table_name2", "columns": ["column4", "column5"] }] } ``` ================================================ FILE: openchatbi/prompts/sql_dialect/presto.md ================================================ # Rules for Presto SQL - Use 'LIKE' instead of 'ILIKE' in the Presto SQL. - If there is a 'GROUP BY' clause in the user query, you can use ordinal numbers (1, 2, 3, ...) instead of the column names in the 'GROUP BY' clause. - When filtering an Array-type dimension, use ARRAYS_OVERLAP, e.g. ARRAYS_OVERLAP(states, ARRAY['CA']) - If you have to write two SQL statements, ensure to separate them with a semicolon `;`. - If no 'limit' or 'top' count is mentioned in the user question, default to "LIMIT 10000" in the SQL query. - If the SQL you provide includes fuzzy matching filters (e.g., 'name LIKE ...'), you should apply a GROUP BY clause for this dimension to handle cases where multiple rows have similar names. - If a table name is referenced multiple times in the SQL query (e.g., table_name.column), assign an alias to the table to simplify the query, such as (a.column). - If you need to write an SQL statement that involves division between two columns, use CAST to convert the numerator to a double type and ensure the denominator is not zero.
For example: CASE WHEN SUM(column2)=0 THEN 0 ELSE CAST(SUM(column1) AS DOUBLE) / SUM(column2) END ## Datetime filter related rules - Please use "INTERVAL '7' DAY" instead of "INTERVAL '1' WEEK" in the Presto SQL you generate. - Do not use "DATE_SUB" or "DATE_ADD"; use only datetime calculations like "NOW() - INTERVAL '1' DAY". - Use `NOW()` instead of `CURRENT_DATE` when you need to get the current date. - If the user asks for a date range measured in days or months, you need to make sure the end date is included, example: "from 2025-03-12 to 2025-03-16", the condition should be `WHERE event_date >= timestamp '2025-03-12 00:00:00' and event_date < timestamp '2025-03-17 00:00:00'` - If the user asks for a time range measured in hours, the end time is not included, example: "from 2025-03-12 00:00:00 to 2025-03-16 00:00:00", the condition should be `WHERE event_date >= timestamp '2025-03-12 00:00:00' and event_date < timestamp '2025-03-16 00:00:00'` - If the user asks for a daily breakdown, make sure whole days are correctly filtered, e.g. "last 7 days per day" should be `WHERE event_date >= DATE_TRUNC('day', (NOW() - INTERVAL '7' DAY)) AND event_date < DATE_TRUNC('day', NOW())` ## Rules for Timezone ### 1. Default Timezone All event_date in the table are stored in **UTC**. If the user specifies a timezone (e.g., CET, PST), convert between timezones accordingly. ### 2. Timezone Conversion Syntax - Use `AT TIME ZONE` to convert event_date to another timezone. Example, to convert to CET: `event_date_expr AT TIME ZONE 'CET'` - Use the `with_timezone` function to define a constant timestamp with a timezone. Example, 2025-05-06 00:00 at CET: `with_timezone(timestamp '2025-05-06 00:00:00', 'CET')` ### 3. WHERE Clause Conversion - If the user query includes absolute (constant) filters with a specific timezone, convert the timestamp with the user timezone to UTC. Keep relative time filters unchanged. - Example: `WHERE event_date >= timestamp '2025-01-01 00:00:00' and event_date < NOW() - INTERVAL '1' DAY` -> `WHERE event_date >= with_timezone(timestamp '2025-01-01 00:00:00', 'CET') AT TIME ZONE 'UTC' AND event_date < NOW() - INTERVAL '1' DAY` - Instructions when the user asks for a daily breakdown with a timezone - If asked for a relative date, the filter condition should use the date as "trunc date at that timezone first, then convert to UTC" - Example: "last 7 days per day in NY time" should be `AND event_date >= DATE_TRUNC('day', (NOW() - INTERVAL '7' DAY) AT TIME ZONE 'America/New_York') AT TIME ZONE 'UTC' AND event_date < DATE_TRUNC('day', NOW() AT TIME ZONE 'America/New_York') AT TIME ZONE 'UTC'` - If asked for an absolute (constant) date, the filter condition should convert the 00:00 timestamp with the user timezone to UTC - Example: "during 2025-05-04 and 2025-05-14 per day in NY time" should be `AND event_date >= with_timezone(timestamp '2025-05-04 00:00:00', 'America/New_York') AT TIME ZONE 'UTC' AND event_date < with_timezone(timestamp '2025-05-15 00:00:00', 'America/New_York') AT TIME ZONE 'UTC'` ### 4. SELECT Clause Conversion - If applying a function f to a datetime column (e.g. date_trunc, format_datetime), convert the event_date from UTC to the user’s timezone before applying f, then cast the result as TIMESTAMP. - Example: `SELECT f(event_date) AS event_date` -> `SELECT CAST(f(event_date AT TIME ZONE 'CET') AS TIMESTAMP) AS event_date` ### 5.
Full Example - User Question: "Show me hourly pv using table fact_table from 2025-01-01 to yesterday in CET" - Generated SQL: ``` SELECT CAST(date_trunc('hour', event_date AT TIME ZONE 'CET') AS TIMESTAMP) AS event_date, SUM(pv) AS "PV" FROM fact_table WHERE event_date >= with_timezone(timestamp '2025-01-01 00:00:00', 'CET') AT TIME ZONE 'UTC' AND event_date < (NOW() - INTERVAL '1' DAY) ``` ## Rules for Array Dimension - Filtering: Use ARRAYS_OVERLAP - When filtering a value in an Array-type dimension, use ARRAYS_OVERLAP, e.g. ARRAYS_OVERLAP(states, ARRAY['CA']) - Flattening Arrays: Use CROSS JOIN UNNEST - When flattening Array-type dimension items from one row into multiple rows, use `CROSS JOIN UNNEST(COALESCE(NULLIF(id_array, ARRAY[]), ARRAY[-1])) AS t(id)`; MAKE SURE the unnested alias has a format like `t(id)` - Additionally, when the subquery uses CROSS JOIN UNNEST, do not sum the metrics across all array items without grouping by the unnested id. - Avoid UNNEST if the user didn’t request it - If the user refers to an array-type dimension without specifying any particular item, or does not ask to expand it into individual elements, you should not use CROSS JOIN UNNEST. ================================================ FILE: openchatbi/prompts/summary_prompt.md ================================================ Create a concise summary of this conversation for continuing the data analysis work. Focus on: 1. **User's Main Questions and Objectives**: What data insights or analysis the user is seeking, business questions they want answered 2. **Key Data Analysis Results**: Important findings from SQL queries, relevant tables/columns discovered, key metrics or patterns identified 3. **Tools and Data Sources Overview**: - Key databases/tables accessed - Main analysis tools used (SQL, Python, etc.) - Data export formats generated 4. **Business Context**: Important business concepts, domain-specific terms, data definitions that were clarified 5. **Conversation Flow**: List ALL user messages with corresponding response summaries. These are critical for understanding user feedback and changing intent: - User message -> Response summary (what was done/analyzed) - User feedback -> How the analysis was adjusted - Follow-up questions -> Additional insights provided 6. **Current Progress**: What analysis was completed, any ongoing tasks, user feedback on results or requested modifications Here's an example of how your output should be structured: 1. **User's Main Questions and Objectives**: [User's data analysis goals and business questions] 2. **Key Data Analysis Results**: - [Important SQL query results] - [Key tables and relationships discovered] - [Metrics and insights found] 3. **Tools and Data Sources Overview**: - [Databases: customer_db, sales_warehouse] - [Analysis tools: SQL, Python pandas] - [Exports: CSV reports, dashboard charts] 4. **Business Context**: - [Business concepts and terminology] - [Data field definitions] - [Domain knowledge gained] 5. **Conversation Flow**: - [User message 1: original request] -> [Response: SQL query executed, found X insights] - [User message 2: clarification about metric Y] -> [Response: adjusted analysis, discovered Z pattern] - [User message 3: follow-up question] -> [Response: additional exploration, provided interpretation] 6.
**Current Progress**: [What was completed, ongoing work, and user feedback] Conversation to summarize: [conversation_text] Please provide your summary based on the conversation so far, following this structure and ensuring precision and thoroughness in your response. ================================================ FILE: openchatbi/prompts/system_prompt.py ================================================ """System prompt templates and business configuration.""" import importlib.resources from openchatbi import config # Global cache variables for lazy loading (only for file I/O operations) _dialect_rules_cache = None _agent_prompt_template_cache = None _extraction_prompt_template_cache = None _table_selection_prompt_template_cache = None _text2sql_prompt_template_cache = None _visualization_prompt_template_cache = None _summary_prompt_template_cache = None def get_basic_knowledge(): """Get basic knowledge from config.""" try: return config.get().bi_config.get("basic_knowledge_glossary", "") except ValueError: return "" def get_data_warehouse_introduction(): """Get data warehouse introduction from config.""" try: return config.get().bi_config.get("data_warehouse_introduction", "") except ValueError: return "" def get_agent_extra_tool_use_rule(): """Get agent extra tool use rule from config.""" try: return config.get().bi_config.get("extra_tool_use_rule", "") except ValueError: return "" def get_organization(): """Get organization from config.""" try: return config.get().organization except ValueError: return "The Company" def get_dialect_rules(): """Get SQL dialect rules with lazy loading and caching.""" global _dialect_rules_cache if _dialect_rules_cache is None: dialect_dir = importlib.resources.files("openchatbi.prompts.sql_dialect") _dialect_rules_cache = {} for item in dialect_dir.iterdir(): if item.is_file() and item.name.endswith(".md"): dialect_name = item.name[:-3] with item.open() as f: prompt = f.read() _dialect_rules_cache[dialect_name] = prompt return _dialect_rules_cache def get_agent_prompt_template() -> str: """Get agent prompt template with caching.""" global _agent_prompt_template_cache if _agent_prompt_template_cache is None: with importlib.resources.files("openchatbi.prompts").joinpath("agent_prompt.md").open("r") as f: prompt = f.read() _agent_prompt_template_cache = ( prompt.replace("[organization]", get_organization()) .replace("[basic_knowledge_glossary]", get_basic_knowledge()) .replace("[extra_tool_use_rule]", get_agent_extra_tool_use_rule()) ) return _agent_prompt_template_cache def get_extraction_prompt_template() -> str: """Get extraction prompt template with caching.""" global _extraction_prompt_template_cache if _extraction_prompt_template_cache is None: with importlib.resources.files("openchatbi.prompts").joinpath("extraction_prompt.md").open("r") as f: prompt = f.read() _extraction_prompt_template_cache = prompt.replace("[organization]", get_organization()).replace( "[basic_knowledge_glossary]", get_basic_knowledge() ) return _extraction_prompt_template_cache def get_table_selection_prompt_template() -> str: """Get table selection prompt template with caching.""" global _table_selection_prompt_template_cache if _table_selection_prompt_template_cache is None: with importlib.resources.files("openchatbi.prompts").joinpath("schema_linking_prompt.md").open("r") as f: prompt = f.read() _table_selection_prompt_template_cache = prompt.replace("[organization]", get_organization()).replace( "[basic_knowledge_glossary]", get_basic_knowledge() ) return 
_table_selection_prompt_template_cache def get_text2sql_prompt_template() -> str: """Get text2sql prompt template with caching.""" global _text2sql_prompt_template_cache if _text2sql_prompt_template_cache is None: with importlib.resources.files("openchatbi.prompts").joinpath("text2sql_prompt.md").open("r") as f: prompt = f.read() _text2sql_prompt_template_cache = ( prompt.replace("[organization]", get_organization()) .replace("[basic_knowledge_glossary]", get_basic_knowledge()) .replace("[data_warehouse_introduction]", get_data_warehouse_introduction()) ) return _text2sql_prompt_template_cache def get_visualization_prompt_template() -> str: """Get visualization prompt template with caching.""" global _visualization_prompt_template_cache if _visualization_prompt_template_cache is None: with importlib.resources.files("openchatbi.prompts").joinpath("visualization_prompt.md").open("r") as f: _visualization_prompt_template_cache = f.read() return _visualization_prompt_template_cache def get_summary_prompt_template() -> str: """Get summary prompt template with caching.""" global _summary_prompt_template_cache if _summary_prompt_template_cache is None: with importlib.resources.files("openchatbi.prompts").joinpath("summary_prompt.md").open("r") as f: _summary_prompt_template_cache = f.read() return _summary_prompt_template_cache def get_text2sql_dialect_prompt_template(dialect: str) -> str: """Get text2sql prompt template for specific SQL dialect.""" prompt = get_text2sql_prompt_template() if not prompt: prompt = "Generate SQL query for the given question in [dialect] dialect." dialect_rules = get_dialect_rules() prompt = prompt.replace("[dialect]", dialect).replace("[sql_dialect_rules]", dialect_rules.get(dialect, "")) return prompt def reset_cache(): """Reset all cached values. Useful for testing.""" global _dialect_rules_cache, _agent_prompt_template_cache global _extraction_prompt_template_cache, _table_selection_prompt_template_cache global _text2sql_prompt_template_cache, _visualization_prompt_template_cache global _summary_prompt_template_cache _dialect_rules_cache = None _agent_prompt_template_cache = None _extraction_prompt_template_cache = None _table_selection_prompt_template_cache = None _text2sql_prompt_template_cache = None _visualization_prompt_template_cache = None _summary_prompt_template_cache = None ================================================ FILE: openchatbi/prompts/text2sql_prompt.md ================================================ You are a professional SQL engineer; your task is to transform the user query into [dialect] SQL. - I will give you the business knowledge introduction and the glossaries of [organization] for reference. - I will give you the selected tables; you need to analyze the user query, read the table description, schema, constraints, and examples carefully to write [dialect] SQL that answers the user's question. - You are a read-only analytics assistant. NEVER generate DELETE, DROP, UPDATE, or INSERT statements. [basic_knowledge_glossary] # Tables [table_schema] # Examples [examples] # Rules for [dialect] SQL [sql_dialect_rules] # Rules for Task - I will provide you with the data schema definition and the explanation and usage scenario of each field. - You can only use the tables listed in "# Tables". - You can only use the metrics, dimensions, and columns from the schema I provided. - You should only use the display name as a column alias in the query if one is provided in the schema. - Never create or assume additional tables or columns, even if they were mentioned in history messages.
- Do not use any id or date in example SQL. - Do not output any explanations or comments. - If the query asks for a metric or field not explicitly defined in the table schema, do not generate a SQL query with an invented field; instead, you should output "NULL". - Only answer when you are very confident; otherwise, please output "NULL" # Output format (case sensitive) ```sql ``` # Realtime Environment Current time is [time_field_placeholder] (format 'yyyy-MM-dd HH:mm:ss') Based on the Tables and Columns, take your time to think through the user query carefully, transform it into [dialect] SQL, and reply following the Output format. ================================================ FILE: openchatbi/prompts/visualization_prompt.md ================================================ You are a data visualization expert. Analyze the user's question and data to recommend the most appropriate chart type. ## User Question [question] ## Data Schema - Columns: [columns] - Numeric columns: [numeric_columns] - Categorical columns: [categorical_columns] - DateTime columns: [datetime_columns] - Row count: [row_count] ## Data Sample [data_sample] ## Available Chart Types 1. **line** - For trends over time, time series data 2. **bar** - For comparing categories, discrete comparisons 3. **pie** - For showing proportions, parts of a whole (best for <= 6 categories) 4. **scatter** - For showing relationships between two numeric variables 5. **histogram** - For showing distribution of a single numeric variable 6. **box** - For showing statistical distribution, outliers, quartiles 7. **heatmap** - For showing correlation or intensity across two dimensions 8. **table** - For detailed data examination, small datasets, or when charts aren't suitable ## Analysis Guidelines Consider: - The user's intent and question keywords - Data types and structure - Number of data points and categories - What insights the user is likely seeking ## Response Format Respond with ONLY the chart type name (line, bar, pie, scatter, histogram, box, heatmap, or table). No explanation needed.
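Since this prompt requests a bare chart-type token, the caller can normalize and validate the reply before using it. Below is a minimal sketch, assuming a hypothetical `parse_chart_type` helper; the actual parsing lives in `VisualizationService` in `openchatbi/text2sql/visualization.py` and may differ:

```python
# Illustrative only: validate the one-word chart-type reply from the LLM.
ALLOWED_CHART_TYPES = {"line", "bar", "pie", "scatter", "histogram", "box", "heatmap", "table"}


def parse_chart_type(llm_reply: str) -> str:
    """Normalize the reply and fall back to 'table' on anything unexpected."""
    candidate = llm_reply.strip().strip("`'\"").lower()
    return candidate if candidate in ALLOWED_CHART_TYPES else "table"


print(parse_chart_type(" Line "))  # -> "line"
print(parse_chart_type("donut"))   # -> "table" (safe fallback)
```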
================================================ FILE: openchatbi/text2sql/__init__.py ================================================ """Text-to-SQL conversion module for OpenChatBI.""" ================================================ FILE: openchatbi/text2sql/data.py ================================================ import os from openchatbi import config from openchatbi.text2sql.text2sql_utils import init_sql_example_retriever, init_table_selection_example_dict # Skip init during documentation build if not os.environ.get("SPHINX_BUILD"): try: _catalog_store = config.get().catalog_store except ValueError: _catalog_store = None else: _catalog_store = None if _catalog_store: sql_example_retriever, sql_example_dicts = init_sql_example_retriever(_catalog_store, config.get().vector_db_path) table_selection_retriever, table_selection_example_dict = init_table_selection_example_dict( _catalog_store, config.get().vector_db_path ) else: sql_example_retriever, sql_example_dicts = None, {} table_selection_retriever, table_selection_example_dict = None, {} ================================================ FILE: openchatbi/text2sql/extraction.py ================================================ """Information extraction module for text2sql processing.""" import traceback from collections.abc import Callable from datetime import date from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from openchatbi.graph_state import SQLGraphState from openchatbi.llm.llm import call_llm_chat_model_with_retry from openchatbi.prompts.system_prompt import get_basic_knowledge, get_extraction_prompt_template from openchatbi.utils import extract_json_from_answer, get_text_from_content, log def generate_extraction_prompt() -> str: """Generate extraction prompt. Returns: str: Generated prompt with placeholders replaced. """ prompt = get_extraction_prompt_template() date_str = date.today().strftime("%Y-%m-%d") prompt = prompt.replace("[time_field_placeholder]", date_str) prompt = prompt.replace("[basic_knowledge_glossary]", get_basic_knowledge()) return prompt def parse_extracted_info_json(llm_answer_content: Any) -> dict[str, Any]: """Extract and parse JSON from LLM response. Args: llm_answer_content: LLM response containing JSON. Returns: dict: Parsed JSON or empty dict if parsing fails. """ try: text = get_text_from_content(llm_answer_content) result = extract_json_from_answer(text) except Exception: log(traceback.format_exc()) result = {} return result def information_extraction(llm: BaseChatModel) -> Callable: """Create function to extract information from questions. Args: llm (BaseChatModel): Language model for information extraction. Returns: function: Node function that extracts information from questions. """ def _extract(state: SQLGraphState): """Extract information from question in state. Args: state (SQLGraphState): Current SQL graph state with question. Returns: dict: Updated state with extracted information. """ messages = state["messages"] last_message = messages[-1] user_input = last_message.content log(f"information_extraction: {user_input}") system_prompt = generate_extraction_prompt() prompt = "Please extract the information according to the context." 
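        # Invoke the LLM through the retry helper; bound_tools restricts valid tool calls to
        # search_knowledge (glossary lookups) and AskHuman (clarification requests), the only
        # two tools the extraction prompt is allowed to emit.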
response = call_llm_chat_model_with_retry( llm, ([SystemMessage(system_prompt)] + messages + [HumanMessage(prompt)]), bound_tools=["search_knowledge", "AskHuman"] ) if response: log(response) if response.tool_calls: return {"messages": [response]} else: llm_answer_content = response.content parsed_result = parse_extracted_info_json(llm_answer_content) return { "messages": [response], "rewrite_question": parsed_result.get("rewrite_question"), "info_entities": parsed_result, } else: return {"messages": [AIMessage(role="system", content="{}")]} return _extract def information_extraction_conditional_edges(state: SQLGraphState): """Determine next node after information extraction. Args: state (SQLGraphState): Current SQL graph state. Returns: str: Next node ('ask_human', 'search_knowledge', 'next', or 'end'). """ messages = state["messages"] last_message = messages[-1] tool_calls = None if isinstance(last_message, AIMessage): tool_calls = last_message.tool_calls log(f"tool_calls: {tool_calls}") if tool_calls: if tool_calls[0]["name"] == "AskHuman": return "ask_human" elif tool_calls[0]["name"] == "search_knowledge": return "search_knowledge" else: log(f"Unknown tool call: {tool_calls[0]['name']}") return "end" else: if "rewrite_question" in state: return "next" else: return "end" ================================================ FILE: openchatbi/text2sql/generate_sql.py ================================================ import datetime from collections.abc import Callable from typing import Any import pandas as pd from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from sqlalchemy import text from sqlalchemy.exc import DatabaseError, OperationalError, ProgrammingError, TimeoutError from openchatbi.catalog import CatalogStore from openchatbi.constants import ( SQL_EXECUTE_TIMEOUT, SQL_NA, SQL_SUCCESS, SQL_SYNTAX_ERROR, SQL_UNKNOWN_ERROR, datetime_format, ) from openchatbi.graph_state import SQLGraphState from openchatbi.prompts.system_prompt import get_text2sql_dialect_prompt_template from openchatbi.text2sql.data import sql_example_dicts, sql_example_retriever from openchatbi.text2sql.visualization import VisualizationService from openchatbi.utils import get_text_from_content, log COLUMN_PROMPT_TEMPLATE = """### Columns Column(Name, Type, Display Name, Description): [ {} ] """ def create_sql_nodes( llm: BaseChatModel, catalog: CatalogStore, dialect: str, visualization_mode: str | None = "rule" ) -> tuple[Callable, Callable, Callable, Callable]: """Creates the four SQL processing nodes for LangGraph. Args: llm (BaseChatModel): The language model to use for SQL generation. catalog (CatalogStore): The catalog store containing schema information. dialect (str): The SQL dialect to use (e.g., 'presto', 'mysql'). visualization_mode (str | None): Visualization analysis mode ("rule", "llm", or None to skip).
Returns: tuple: Four node functions (generate_sql_node, execute_sql_node, regenerate_sql_node, generate_visualization_node) """ # Initialize visualization service based on configuration visualization_service = VisualizationService(llm if visualization_mode == "llm" else None) def _get_column_prompt(column: dict[str, Any]) -> str: alias_prompt = f"alias({column['alias']})" if "alias" in column and column["alias"] else "" return ( f""" Column("{column['column_name']}", {column['type']}, {column['display_name']},""" f""" "{alias_prompt}{column['description']}"),""" ) def _get_table_schema_prompt(tables_columns: list[dict[str, Any]]) -> str: """Generates a prompt string for table schemas, including table description, columns, derived metrics, and rules for writing SQL. Args: tables_columns (List[Dict[str, Any]]): List of tables with selected columns. Returns: str: Formatted table schema prompt string. """ schema_prompt = [] for table_dict in tables_columns: table_name = table_dict["table"] # TODO maybe use columns in prompt columns = table_dict["columns"] table_info = catalog.get_table_information(table_name) single_table_schema_prompt = f"## Table {table_name}\n{table_info['description']}\n" columns = catalog.get_column_list(table_name) single_table_schema_prompt += COLUMN_PROMPT_TEMPLATE.format( "\n".join([_get_column_prompt(column) for column in columns]) ) single_table_schema_prompt += table_info.get("derived_metric", "") single_table_schema_prompt += table_info["sql_rule"] schema_prompt.append(single_table_schema_prompt) return "\n".join(schema_prompt) def _get_relevant_sql_examples_prompt(question, tables_columns: list[dict[str, Any]]) -> str: """Retrieves relevant SQL examples based on the question and selected tables. Args: question (str): The natural language question. tables_columns (List[Dict[str, Any]]): List of selected tables with selected columns. Returns: str: Formatted string of relevant SQL examples.
""" tables = [d["table"] for d in tables_columns] relevant_questions = sql_example_retriever.invoke(question) # log(f"Retrieved examples for question: {question} \n Relevant questions: {relevant_questions}") # filter examples that only use the selected tables examples = [] for relevant_document in relevant_questions: question = relevant_document.page_content example_sql, used_tables = sql_example_dicts[question] if all(table in tables for table in used_tables): examples.append(f"\nQ: {question}\nA: {example_sql}\n\n") log(f"Examples using selected tables: {examples}") return "\n".join(examples) def _analyze_dataframe_schema(df: pd.DataFrame) -> dict[str, Any]: """Analyze DataFrame to understand column types and characteristics.""" try: schema_info = { "columns": list(df.columns), "column_types": {}, "row_count": len(df), "numeric_columns": [], "categorical_columns": [], "datetime_columns": [], } for col in df.columns: dtype = str(df[col].dtype) schema_info["column_types"][col] = dtype # Classify column types if df[col].dtype in ["int64", "float64", "int32", "float32"]: schema_info["numeric_columns"].append(col) elif df[col].dtype == "object": # Check if it could be datetime try: pd.to_datetime(df[col].head(10)) schema_info["datetime_columns"].append(col) except Exception: schema_info["categorical_columns"].append(col) # Calculate unique value counts for categorical columns schema_info["unique_counts"] = {} for col in schema_info["categorical_columns"]: schema_info["unique_counts"][col] = df[col].nunique() return schema_info except Exception as e: return {"error": f"Failed to analyze data schema: {str(e)}"} def _execute_sql(sql: str) -> tuple[dict, str]: """Executes the generated SQL query and returns the result with schema analysis. Args: sql (str): The SQL query to execute. Returns: Tuple[dict, str]: A tuple containing (schema_info, CSV string). """ with catalog.get_sql_engine().connect() as connection: result = connection.execute(text(sql)) # Fetch all rows from the result rows = result.fetchall() # Get column names columns = list(result.keys()) # Create DataFrame for analysis df = pd.DataFrame(rows, columns=columns) # Analyze data schema schema_info = _analyze_dataframe_schema(df) # Format as CSV csv_data = df.to_csv(index=False) connection.commit() return schema_info, csv_data def generate_sql_node(state: SQLGraphState) -> dict: """First node: Generates initial SQL query based on the state. Args: state (SQLGraphState): The current SQL graph state containing the question and tables. Returns: dict: Updated state with generated SQL query.
""" if "rewrite_question" not in state: log("Missing rewrite question, skipping SQL generation.") return {} if "tables" not in state or len(state["tables"]) == 0: log("Missing tables, skipping SQL generation.") return {} question = state["rewrite_question"] tables_columns = state["tables"] system_prompt = ( get_text2sql_dialect_prompt_template(dialect) .replace("[table_schema]", _get_table_schema_prompt(tables_columns)) .replace("[examples]", _get_relevant_sql_examples_prompt(question, tables_columns)) .replace("[time_field_placeholder]", datetime.datetime.now().strftime(datetime_format)) ) user_prompt = f"""Generate a SQL query for the question: {question}""" messages = [SystemMessage(system_prompt)] + list(state["messages"]) + [HumanMessage(user_prompt)] response = llm.invoke(messages) response_content = get_text_from_content(response.content) sql_query = response_content.replace("```sql", "").replace("```", "").strip() if not sql_query or sql_query.lower() == "null": log(f"Generated SQL query is empty. LLM output: {response.content}") return { "messages": [AIMessage(response_content)], "sql": "", "sql_retry_count": 0, "sql_execution_result": "", "previous_sql_errors": [], } return {"sql": sql_query, "sql_retry_count": 0, "sql_execution_result": "", "previous_sql_errors": []} def execute_sql_node(state: SQLGraphState) -> dict: """Second node: Executes the SQL query and returns result or error. Args: state (SQLGraphState): The current SQL graph state containing the SQL query. Returns: dict: Updated state with execution result or error information. """ sql_query = state.get("sql", "").strip() if not sql_query: return {"sql_execution_result": SQL_NA, "messages": [AIMessage("No SQL query to execute")]} try: schema_info, csv_result = _execute_sql(sql_query) result = f"```sql\n{sql_query}\n```\nSQL Result:\n```csv\n{csv_result}\n```" return { "sql_execution_result": SQL_SUCCESS, "schema_info": schema_info, "data": csv_result, "messages": [AIMessage(result)], } except (OperationalError, TimeoutError) as e: log(f"Database connection/timeout error: {str(e)}") error_result = ( f"```sql\n{sql_query}\n```\nDatabase Connection Timeout: {str(e)}\nPlease check database connectivity." ) return {"sql_execution_result": SQL_EXECUTE_TIMEOUT, "messages": [AIMessage(error_result)]} except Exception as e: error_type = "Unexpected error" if isinstance(e, ProgrammingError): error_type = "SQL syntax error" elif isinstance(e, DatabaseError): error_type = "Database error" log(f"{error_type}: {str(e)}") # Add error to previous errors list previous_errors = list(state.get("previous_sql_errors", [])) previous_errors.append({"sql": sql_query, "error": f"{error_type}: {str(e)}", "error_type": error_type}) return { "sql_execution_result": SQL_UNKNOWN_ERROR if error_type == "Unexpected error" else SQL_SYNTAX_ERROR, "previous_sql_errors": previous_errors, } def regenerate_sql_node(state: SQLGraphState) -> dict: """Third node: Regenerates SQL based on previous errors. Args: state (SQLGraphState): The current SQL graph state containing error information. Returns: dict: Updated state with regenerated SQL query.
""" question = state["rewrite_question"] tables = state["tables"] previous_errors = state.get("previous_sql_errors", []) retry_count = state.get("sql_retry_count", 0) + 1 system_prompt = ( get_text2sql_dialect_prompt_template(dialect) .replace("[table_schema]", _get_table_schema_prompt(tables)) .replace("[examples]", _get_relevant_sql_examples_prompt(question, tables)) .replace("[time_field_placeholder]", datetime.datetime.now().strftime(datetime_format)) ) user_prompt = f"""Generate a SQL query for the question: {question}""" if previous_errors: user_prompt += "\n\nPrevious attempts failed with errors:" for i, error_info in enumerate(previous_errors, 1): user_prompt += f"\n\nAttempt {i}:\nSQL: {error_info['sql']}\nError: {error_info['error']}" user_prompt += "\n\nPlease analyze the errors above and generate a corrected SQL query." messages = [SystemMessage(system_prompt)] + list(state["messages"]) + [HumanMessage(user_prompt)] response = llm.invoke(messages) response_content = get_text_from_content(response.content) sql_query = response_content.replace("```sql", "").replace("```", "").strip() if not sql_query: log(f"Generated SQL query is empty. LLM output: {response.content}") error_result = f"Failed to regenerate valid SQL after {retry_count} attempts." return { "messages": [AIMessage(error_result)], "sql": "", "sql_retry_count": retry_count, "sql_execution_result": SQL_NA, } return {"sql": sql_query, "sql_retry_count": retry_count, "sql_execution_result": ""} def generate_visualization_node(state: SQLGraphState) -> dict: """Fourth node: Generates visualization DSL based on successful SQL execution result. Args: state (SQLGraphState): The current SQL graph state containing query data and results. Returns: dict: Updated state with visualization DSL. """ execution_result = state.get("sql_execution_result", "") if execution_result != SQL_SUCCESS: # No visualization for failed queries return {"visualization_dsl": {}} question = state.get("rewrite_question", "") schema_info = state.get("schema_info", {}) data = state.get("data", "") if not question or not schema_info or not data or not visualization_mode: return {"visualization_dsl": {}} try: # Generate visualization DSL using configured service viz_dsl = visualization_service.generate_visualization(question, schema_info, data) # Handle case where visualization is skipped if viz_dsl is None: return {"visualization_dsl": {}} # Update the AI message to include visualization information messages = list(state.get("messages", [])) if messages and hasattr(messages[-1], "content"): current_content = messages[-1].content viz_info = f"\n\n**Visualization Generated**: {viz_dsl.chart_type.title()} chart with {len(viz_dsl.data_columns)} column(s)" messages[-1] = AIMessage(current_content + viz_info) return {"visualization_dsl": viz_dsl.to_dict(), "messages": messages} except Exception as e: log(f"Visualization generation error: {str(e)}") return {"visualization_dsl": {"error": f"Failed to generate visualization: {str(e)}"}} return generate_sql_node, execute_sql_node, regenerate_sql_node, generate_visualization_node def should_retry_sql(state: SQLGraphState) -> str: """Conditional edge function to determine if SQL should be retried. 
Args: state (SQLGraphState): Current state Returns: str: Next node name - "regenerate_sql" if retry needed, "end" if done """ execution_result = state.get("sql_execution_result", "") retry_count = state.get("sql_retry_count", 0) max_retries = 3 if execution_result in (SQL_SUCCESS, SQL_EXECUTE_TIMEOUT): return "end" elif retry_count < max_retries: return "regenerate_sql" else: # Max retries reached or other terminal state if retry_count >= max_retries: previous_errors = state.get("previous_sql_errors", []) if previous_errors: last_error = previous_errors[-1] error_result = f"```sql\n{last_error['sql']}\n```\n{last_error['error']}\nFailed to generate valid SQL after {max_retries} attempts." else: error_result = f"Failed to generate valid SQL after {max_retries} attempts." # Update state with final error message state["messages"] = [AIMessage(error_result)] state["sql_execution_result"] = SQL_NA return "end" def should_execute_sql(state: SQLGraphState) -> str: """Conditional edge function to determine if SQL should be executed. Args: state (SQLGraphState): Current state Returns: str: Next node name - "execute_sql" if SQL is generated, "end" if done """ sql = state.get("sql", "") if not sql: return "end" else: return "execute_sql" ================================================ FILE: openchatbi/text2sql/schema_linking.py ================================================ """Schema linking module for table and column selection in text2sql.""" from datetime import datetime from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage, SystemMessage from openchatbi.catalog import CatalogStore from openchatbi.catalog.schema_retrival import col_dict, column_tables_mapping, get_relevant_columns from openchatbi.constants import datetime_format from openchatbi.graph_state import SQLGraphState from openchatbi.prompts.system_prompt import get_table_selection_prompt_template from openchatbi.text2sql.data import table_selection_example_dict, table_selection_retriever from openchatbi.utils import extract_json_from_answer, log def schema_linking(llm: BaseChatModel, catalog: CatalogStore): """Create function for schema linking: select appropriate tables and columns for a question. Args: llm (BaseChatModel): Language model for table selection. catalog (CatalogStore): Catalog store with schema information. Returns: function: Node function for schema linking based on question. """ def _get_related_tables_and_columns(keywords_list, dimensions, metrics, start_time=None, invalid_table=None): """Retrieves tables and columns related to the given keywords, dimensions, and metrics. Args: keywords_list (list): List of keywords extracted from the question. dimensions (list): List of dimensions mentioned in the question. metrics (list): List of metrics mentioned in the question. start_time (str, optional): Start time for filtering tables. invalid_table (list, optional): List of tables to exclude. Returns: dict: Dictionary mapping table names to their information and related columns. """ # 1. Get the top similar columns relevant_columns = get_relevant_columns(keywords_list, dimensions, metrics) # 2. Get all the related tables candidate_tables = set() for column in relevant_columns: table_list = column_tables_mapping.get(column, []) candidate_tables.update(table_list) if start_time: try: start_time = datetime.strptime(start_time, datetime_format) except ValueError: start_time = None # 3. 
Get all the tables' related columns
        related_table_column_dict = {}
        for table_name in candidate_tables:
            if invalid_table and table_name in invalid_table:
                continue
            table_info = catalog.get_table_information(table_name)
            if not table_info:
                continue
            if start_time and "start_time" in table_info:
                if datetime.strptime(table_info.get("start_time"), datetime_format) > start_time:
                    continue
            columns = []
            for column_name in relevant_columns:
                if table_name not in column_tables_mapping.get(column_name, []):
                    continue
                columns.append(col_dict[column_name].copy())
            related_table_column_dict[table_name] = (table_info, columns)
        return related_table_column_dict

    def _example_retrieval(query, candidate_tables):
        """Retrieves example questions and their selected tables that match the candidate tables.

        Args:
            query (str): The natural language question.
            candidate_tables (list): List of candidate table names.

        Returns:
            dict: Dictionary mapping example questions to their selected tables.
        """
        similar_questions = table_selection_retriever.invoke(query)
        valid_examples = {}
        for question_doc in similar_questions:
            question = question_doc.page_content
            if not question:
                continue
            expected_tables = table_selection_example_dict[question]
            expected_tables = [table for table in expected_tables if table in candidate_tables]
            if expected_tables:
                valid_examples[question] = expected_tables
        return valid_examples

    def _build_table_selection_prompt(related_table_column_dict, similar_examples):
        """Builds a prompt for table selection based on related tables and examples.

        Args:
            related_table_column_dict (dict): Dictionary of tables with their information and columns.
            similar_examples (dict): Dictionary of example questions and their selected tables.

        Returns:
            str: Formatted prompt for table selection.
        """
        similar_examples = [
            f"- Question: {example} Selected Tables: [{','.join(selected_tables)}]"
            for example, selected_tables in similar_examples.items()
        ]
        table_column_descs = []
        for table_name, (table_info, columns) in related_table_column_dict.items():
            columns_desc = "\n".join(
                [
                    f"- {column['category']}({column['column_name']}, {column['display_name']}, \"{column['description']}\")"
                    for column in columns
                ]
            )
            desc_part = f"\n### Table Description: \n{table_info['description']}"
            rule_part = f"\n### Rule: \n{table_info.get('selection_rule')}" if table_info.get("selection_rule") else ""
            table_desc = (
                f"\n## Table: {table_name} {desc_part} {rule_part}"
                "\n### Columns: \nCategory(Name, Display Name, Description): "
                f"\n{columns_desc}"
            )
            table_column_descs.append(table_desc)

        # Build the LLM prompt
        prompt = (
            get_table_selection_prompt_template()
            .replace("[tables]", "\n\n".join(table_column_descs))
            .replace("[examples]", "\n".join(similar_examples))
        )
        return prompt

    def _verify_table(selected_tables, candidate_tables):
        """Verifies that selected tables are valid candidates.

        Args:
            selected_tables (list): List of tables selected by the model.
            candidate_tables (list): List of candidate tables.

        Returns:
            bool: True if all selected tables are valid candidates.
        """
        if not selected_tables:
            return False
        for table in selected_tables:
            if table.get("table") not in candidate_tables:
                return False
        return True

    def _call_llm_select(llm: BaseChatModel, system_prompt, messages, question, candidate_tables):
        """Calls the language model to select appropriate tables for the question.

        Retries up to 3 times if the LLM's answer is invalid.

        Args:
            llm (BaseChatModel): The language model to use.
            system_prompt (str): The system prompt for table selection.
messages (list): List of previous messages. question (str): The natural language question. candidate_tables (list): List of candidate tables. Returns: dict: Dictionary containing selected tables. """ log("Selecting appropriate tables...") # print(f"candidate_tables: {candidate_tables}") prompt = f"""Please select the appropriate tables for the question: {question}""" messages.append(HumanMessage(prompt)) retry_flag = True retry_cnt = 1 while retry_flag: try: log("Ask LLM to select the table...") # print("_call_llm_select") # print(messages) response = llm.invoke([SystemMessage(system_prompt)] + messages) result = extract_json_from_answer(response.content) selected_tables = result.get("tables") log(result) if _verify_table(selected_tables, candidate_tables): return {"tables": selected_tables} else: messages.append( HumanMessage( f'The selected table {",".join([table.get("table") for table in result.get("tables")])} is not valid. ' f"Do not select this table, please try again." ) ) retry_cnt += 1 if retry_cnt > 3: retry_flag = False if retry_flag: log( f"The selected table {','.join([table.get('table') for table in result.get('tables')])} is not in the candidate tables." ) log("Retry Table Selection...") except Exception as e: log(str(e)) retry_cnt += 1 if retry_cnt > 3: retry_flag = False return {} def _select(state: SQLGraphState) -> dict: if not state.get("rewrite_question"): log("Missing rewrite question, skipping schema linking.") return {} messages = state["messages"] question = state["rewrite_question"] info_entities = state["info_entities"] keywords_list = info_entities.get("keywords", []) dimensions = info_entities.get("dimensions", []) metrics = info_entities.get("metrics", []) start_time = info_entities.get("start_time") invalid_table = [] log("Retrieving related table schema...") # 1. Get related tables and columns related_table_column_dict = _get_related_tables_and_columns( keywords_list, dimensions, metrics, start_time, invalid_table ) candidate_tables = related_table_column_dict.keys() # 2. Get the similar examples similar_examples = _example_retrieval(" ".join(keywords_list), related_table_column_dict.keys()) # 3. Build tables prompt system_prompt = _build_table_selection_prompt(related_table_column_dict, similar_examples) # 4. 
Call LLM to select the table
        return _call_llm_select(llm, system_prompt, messages, question, candidate_tables)

    return _select


================================================
FILE: openchatbi/text2sql/sql_graph.py
================================================
"""SQL generation graph construction and execution."""

from langchain_openai.chat_models.base import BaseChatOpenAI
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer, interrupt

from openchatbi import config
from openchatbi.catalog import CatalogStore
from openchatbi.constants import SQL_EXECUTE_TIMEOUT, SQL_SUCCESS
from openchatbi.graph_state import InputState, SQLGraphState, SQLOutputState
from openchatbi.llm.llm import get_llm, get_text2sql_llm
from openchatbi.text2sql.extraction import information_extraction, information_extraction_conditional_edges
from openchatbi.text2sql.generate_sql import create_sql_nodes, should_execute_sql
from openchatbi.text2sql.schema_linking import schema_linking
from openchatbi.tool.ask_human import AskHuman
from openchatbi.tool.search_knowledge import search_knowledge


def ask_human(state):
    """Node function to ask human for additional information or clarification.

    Args:
        state (SQLGraphState): The current SQL graph state containing messages and context.

    Returns:
        dict: Updated state with human feedback as a tool message and user input.
    """
    tool_call = state["messages"][-1].tool_calls[0]
    tool_call_id = tool_call["id"]
    args = tool_call["args"]
    user_feedback = interrupt({"text": args["question"], "buttons": args.get("options", None)})
    tool_message = [{"tool_call_id": tool_call_id, "type": "tool", "content": user_feedback}]
    return {"messages": tool_message, "user_input": user_feedback}


def should_generate_visualization_or_retry(state: SQLGraphState) -> str:
    """Conditional edge function to determine next action after execute_sql.

    Args:
        state (SQLGraphState): Current state

    Returns:
        str: Next node name - "generate_visualization" if SQL succeeded, "regenerate_sql" if retry needed,
            "end" if done
    """
    execution_result = state.get("sql_execution_result", "")
    retry_count = state.get("sql_retry_count", 0)
    max_retries = 3

    if execution_result == SQL_SUCCESS:
        return "generate_visualization"
    elif retry_count < max_retries and execution_result != SQL_EXECUTE_TIMEOUT:
        return "regenerate_sql"
    else:
        return "end"


def build_sql_graph(
    catalog: CatalogStore, checkpointer: Checkpointer, memory_store: BaseStore, llm_provider: str | None = None
) -> CompiledStateGraph:
    """Build SQL generation graph with all nodes and edges.

    Args:
        catalog: Catalog store containing schema information.
        checkpointer: The Checkpointer for state persistence (short memory). If None, no short memory.
        memory_store: The BaseStore to use for long-term memory. If None, no long-term memory.
        llm_provider: Optional LLM provider name; when None, the configured default provider is used.

    Returns:
        CompiledStateGraph: Compiled SQL graph ready for execution.
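
    Example (illustrative; assumes config and catalog are already initialized):
        >>> graph = build_sql_graph(catalog, checkpointer=None, memory_store=None)
        >>> graph.invoke({"messages": [HumanMessage("top 10 campaigns by spend")]})  # doctest: +SKIP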
""" tools = [search_knowledge, AskHuman] search_tool_node = ToolNode([search_knowledge]) default_llm = get_llm(llm_provider) if isinstance(default_llm, BaseChatOpenAI): llm_with_tools = default_llm.bind_tools(tools, strict=True).bind(response_format={"type": "json_object"}) else: llm_with_tools = default_llm.bind_tools(tools) # Create SQL processing nodes with visualization configuration generate_sql_node, execute_sql_node, regenerate_sql_node, generate_visualization_node = create_sql_nodes( get_text2sql_llm(llm_provider), catalog, dialect=config.get().dialect, visualization_mode=config.get().visualization_mode, ) # Define the SQL generation graph graph = StateGraph(SQLGraphState, input_schema=InputState, output_schema=SQLOutputState) # Add nodes to the graph graph.add_node("search_knowledge", search_tool_node) graph.add_node("ask_human", ask_human) graph.add_node("information_extraction", information_extraction(llm_with_tools)) graph.add_node("table_selection", schema_linking(default_llm, catalog)) graph.add_node("generate_sql", generate_sql_node) graph.add_node("execute_sql", execute_sql_node) graph.add_node("regenerate_sql", regenerate_sql_node) graph.add_node("generate_visualization", generate_visualization_node) # Add basic edges graph.add_edge(START, "information_extraction") graph.add_edge("ask_human", "information_extraction") graph.add_edge("search_knowledge", "information_extraction") graph.add_edge("table_selection", "generate_sql") # Add conditional routing from information extraction graph.add_conditional_edges( "information_extraction", information_extraction_conditional_edges, # mapping of paths to node names { "ask_human": "ask_human", "search_knowledge": "search_knowledge", "next": "table_selection", "end": END, }, ) # Add conditional edges for generate_sql graph.add_conditional_edges( "generate_sql", should_execute_sql, { "execute_sql": "execute_sql", "end": END, }, ) # Add conditional edges for regenerate_sql graph.add_conditional_edges( "regenerate_sql", should_execute_sql, { "execute_sql": "execute_sql", "end": END, }, ) # Add conditional edges for execute_sql - either retry, generate visualization, or end graph.add_conditional_edges( "execute_sql", should_generate_visualization_or_retry, { "generate_visualization": "generate_visualization", "regenerate_sql": "regenerate_sql", "end": END, }, ) # Add edge from visualization to end graph.add_edge("generate_visualization", END) graph = graph.compile(name="text2sql_graph", checkpointer=checkpointer, store=memory_store) return graph ================================================ FILE: openchatbi/text2sql/text2sql_utils.py ================================================ """Utility functions for text2sql retrieval systems.""" from openchatbi.llm.llm import get_embedding_model from openchatbi.utils import create_vector_db def init_sql_example_retriever(catalog, vector_db_path: str = None): """Initialize SQL example retriever from catalog. Args: catalog: Catalog store containing SQL examples. vector_db_path: Path to the vector database file. 
Returns: tuple: (retriever, sql_example_dict) """ sql_examples = catalog.get_sql_examples() sql_example_dict = {q: (sql, table) for q, sql, table in sql_examples} texts = list(sql_example_dict.keys()) vector_db = create_vector_db( texts, get_embedding_model(), collection_name="text2sql", collection_metadata={"hnsw:space": "cosine"}, chroma_db_path=vector_db_path, ) retriever = vector_db.as_retriever( search_type="mmr", search_kwargs={"distance_metric": "cosine", "fetch_k": 30, "k": 10} ) return retriever, sql_example_dict def init_table_selection_example_dict(catalog, vector_db_path: str = None): """Initialize table selection example retriever from catalog. Args: catalog: Catalog store containing table selection examples. vector_db_path: Path to the vector database file. Returns: tuple: (retriever, table_selection_example_dict) """ sql_examples = catalog.get_table_selection_examples() table_selection_example_dict = dict((q, tables) for q, tables in sql_examples) texts = list(table_selection_example_dict.keys()) if not texts: texts = [""] # Empty text as fallback vector_db = create_vector_db( texts, get_embedding_model(), collection_name="table_selection_example", collection_metadata={"hnsw:space": "cosine"}, chroma_db_path=vector_db_path, ) retriever = vector_db.as_retriever( search_type="mmr", search_kwargs={"distance_metric": "cosine", "fetch_k": 30, "k": 10} ) return retriever, table_selection_example_dict ================================================ FILE: openchatbi/text2sql/visualization.py ================================================ """Visualization generation for SQL query results using Plotly.""" from dataclasses import dataclass from enum import Enum from io import StringIO from typing import Any import pandas as pd from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage from openchatbi.prompts.system_prompt import get_visualization_prompt_template class ChartType(Enum): """Supported chart types for data visualization.""" LINE = "line" BAR = "bar" PIE = "pie" SCATTER = "scatter" HISTOGRAM = "histogram" BOX = "box" HEATMAP = "heatmap" TABLE = "table" @dataclass class VisualizationConfig: """Configuration for generating visualization DSL.""" chart_type: ChartType x_column: str | None = None y_column: str | None = None color_column: str | None = None size_column: str | None = None title: str | None = None x_title: str | None = None y_title: str | None = None show_legend: bool = True width: int | None = None height: int | None = None @dataclass class VisualizationDSL: """Plotly-friendly DSL for data visualization.""" chart_type: str data_columns: list[str] config: dict[str, Any] layout: dict[str, Any] def to_dict(self) -> dict[str, Any]: """Convert to dictionary for JSON serialization.""" return { "chart_type": self.chart_type, "data_columns": self.data_columns, "config": self.config, "layout": self.layout, } class VisualizationService: """Service class to handle visualization generation with configurable analysis method.""" # Chart type mapping for LLM responses CHART_TYPE_MAPPING = { "line": ChartType.LINE, "bar": ChartType.BAR, "pie": ChartType.PIE, "scatter": ChartType.SCATTER, "histogram": ChartType.HISTOGRAM, "box": ChartType.BOX, "heatmap": ChartType.HEATMAP, "table": ChartType.TABLE, } def __init__(self, llm: BaseChatModel | None = None): """Initialize visualization service. 
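
        With llm=None, chart types are picked purely by the rule-based
        heuristics in `_get_chart_type_by_rule`; with an LLM, the model
        recommends the chart type and the rules serve as fallback.
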
Args: llm: BaseChatModel LLM instance, will skip using LLM if None """ self.llm = llm def _get_chart_type_by_rule(self, question: str, schema_info: dict[str, Any]) -> ChartType: """Recommend chart type based on user question and data schema using rules.""" question_lower = question.lower() # Get data characteristics numeric_cols = schema_info.get("numeric_columns", []) categorical_cols = schema_info.get("categorical_columns", []) datetime_cols = schema_info.get("datetime_columns", []) row_count = schema_info.get("row_count", 0) # Question-based heuristics if any(keyword in question_lower for keyword in ["trend", "over time", "timeline", "time series"]): return ChartType.LINE elif any(keyword in question_lower for keyword in ["distribution", "frequency", "histogram"]): return ChartType.HISTOGRAM elif any(keyword in question_lower for keyword in ["correlation", "relationship", "scatter"]): return ChartType.SCATTER elif any(keyword in question_lower for keyword in ["proportion", "percentage", "share", "pie"]): return ChartType.PIE elif any(keyword in question_lower for keyword in ["compare", "comparison", "vs", "versus", "bar"]): return ChartType.BAR elif any(keyword in question_lower for keyword in ["summary", "range", "quartile", "box"]): return ChartType.BOX # Data-based heuristics if len(datetime_cols) > 0 and len(numeric_cols) > 0: return ChartType.LINE elif len(categorical_cols) == 1 and len(numeric_cols) == 1: unique_count = schema_info.get("unique_counts", {}).get(categorical_cols[0], 0) if unique_count <= 10: return ChartType.PIE if unique_count <= 6 else ChartType.BAR else: return ChartType.BAR elif len(numeric_cols) == 2: return ChartType.SCATTER elif len(numeric_cols) == 1 and len(categorical_cols) == 0: return ChartType.HISTOGRAM elif row_count <= 20: # Changed from 50 to 20 return ChartType.TABLE else: return ChartType.BAR def generate_visualization_dsl( self, question: str, schema_info: dict[str, Any], chart_type: ChartType | None = None ) -> VisualizationDSL: """Generate visualization DSL based on question and schema info.""" if "error" in schema_info: # Return table view for error cases return VisualizationDSL( chart_type="table", data_columns=["error"], config={"error": schema_info["error"]}, layout={"title": "Data Analysis Error"}, ) # Determine chart type if chart_type is None: chart_type = self._get_chart_type_by_rule(question, schema_info) columns = schema_info["columns"] numeric_cols = schema_info["numeric_columns"] categorical_cols = schema_info["categorical_columns"] datetime_cols = schema_info["datetime_columns"] # Generate DSL based on chart type if chart_type == ChartType.LINE: x_col = datetime_cols[0] if datetime_cols else (categorical_cols[0] if categorical_cols else columns[0]) # For line charts, include all numeric columns for multiple metrics y_cols = numeric_cols if numeric_cols else [columns[-1]] data_columns = [x_col] + y_cols # Support multiple y-axis columns config = {"x": x_col, "mode": "lines+markers"} if len(y_cols) == 1: config["y"] = y_cols[0] title = f"Line Chart: {y_cols[0]} over {x_col}" else: config["y"] = y_cols # Multiple metrics title = f"Line Chart: {', '.join(y_cols)} over {x_col}" return VisualizationDSL( chart_type="line", data_columns=data_columns, config=config, layout={"title": title, "xaxis_title": x_col, "yaxis_title": "Value"}, ) elif chart_type == ChartType.BAR: x_col = categorical_cols[0] if categorical_cols else columns[0] # For bar charts, include all numeric columns for multiple metrics y_cols = numeric_cols if numeric_cols else 
[columns[-1]] data_columns = [x_col] + y_cols config = {"x": x_col} if len(y_cols) == 1: config["y"] = y_cols[0] title = f"Bar Chart: {y_cols[0]} by {x_col}" else: config["y"] = y_cols # Multiple metrics title = f"Bar Chart: {', '.join(y_cols)} by {x_col}" return VisualizationDSL( chart_type="bar", data_columns=data_columns, config=config, layout={"title": title, "xaxis_title": x_col, "yaxis_title": "Value"}, ) elif chart_type == ChartType.PIE: label_col = categorical_cols[0] if categorical_cols else columns[0] value_col = numeric_cols[0] if numeric_cols else columns[-1] return VisualizationDSL( chart_type="pie", data_columns=[label_col, value_col], config={"labels": label_col, "values": value_col}, layout={"title": f"Pie Chart: {value_col} by {label_col}"}, ) elif chart_type == ChartType.SCATTER: x_col = numeric_cols[0] if len(numeric_cols) > 0 else columns[0] y_col = numeric_cols[1] if len(numeric_cols) > 1 else columns[-1] return VisualizationDSL( chart_type="scatter", data_columns=[x_col, y_col], config={"x": x_col, "y": y_col, "mode": "markers"}, layout={"title": f"Scatter Plot: {y_col} vs {x_col}", "xaxis_title": x_col, "yaxis_title": y_col}, ) elif chart_type == ChartType.HISTOGRAM: col = numeric_cols[0] if numeric_cols else columns[0] return VisualizationDSL( chart_type="histogram", data_columns=[col], config={"x": col, "nbins": 20}, layout={"title": f"Histogram: Distribution of {col}", "xaxis_title": col, "yaxis_title": "Frequency"}, ) elif chart_type == ChartType.BOX: y_col = numeric_cols[0] if numeric_cols else columns[0] x_col = categorical_cols[0] if categorical_cols else None config = {"y": y_col} if x_col: config["x"] = x_col return VisualizationDSL( chart_type="box", data_columns=[col for col in [x_col, y_col] if col], config=config, layout={ "title": f"Box Plot: {y_col}" + (f" by {x_col}" if x_col else ""), "xaxis_title": x_col if x_col else "", "yaxis_title": y_col, }, ) else: # TABLE or fallback return VisualizationDSL( chart_type="table", data_columns=columns, config={"columns": columns}, layout={"title": "Data Table"} ) def _llm_recommend_chart_type(self, question: str, schema_info: dict[str, Any], data_sample: str) -> ChartType: """Use LLM to recommend chart type based on question and data analysis. Args: question: User's question or intent schema_info: Data schema information data_sample: Sample of the data Returns: ChartType: Recommended chart type """ try: prompt = ( get_visualization_prompt_template() .replace("[question]", question) .replace("[columns]", str(schema_info.get("columns", []))) .replace("[numeric_columns]", str(schema_info.get("numeric_columns", []))) .replace("[categorical_columns]", str(schema_info.get("categorical_columns", []))) .replace("[datetime_columns]", str(schema_info.get("datetime_columns", []))) .replace("[row_count]", str(schema_info.get("row_count", 0))) .replace("[data_sample]", data_sample) ) # Call LLM with the formatted prompt response = self.llm.invoke([HumanMessage(content=prompt)]) chart_type_str = response.content.strip().lower() return self.CHART_TYPE_MAPPING.get(chart_type_str, ChartType.TABLE) except Exception: # Fallback to rule-based recommendation on other LLM errors return self._get_chart_type_by_rule(question, schema_info) def generate_visualization( self, question: str, schema_info: dict[str, Any], csv_data: str, chart_type: ChartType | None = None ) -> VisualizationDSL | None: """Generate visualization using the configured analysis method. 
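
        Resolution order for the chart type: an explicitly passed chart_type wins;
        otherwise the LLM recommendation is used when an LLM is configured; the
        rule-based heuristics cover the remaining cases.
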
Args: question: User's question or intent schema_info: Pre-analyzed schema information csv_data: CSV data string for LLM analysis if needed chart_type: Optional specific chart type to use Returns: VisualizationDSL or None: Generated visualization configuration, or None if skipped """ # Use existing DSL generation if chart type is already specified if chart_type is not None: return self.generate_visualization_dsl(question, schema_info, chart_type) # Determine chart type based on configured method if self.llm: if "error" in schema_info: return VisualizationDSL( chart_type="table", data_columns=["error"], config={"error": schema_info["error"]}, layout={"title": "Data Analysis Error"}, ) # Prepare data sample for LLM analysis try: df = pd.read_csv(StringIO(csv_data)) data_sample = df.head(3).to_string() if len(df) > 0 else "No data available" except Exception: data_sample = "Unable to parse data" chart_type = self._llm_recommend_chart_type(question, schema_info, data_sample) # Generate DSL using determined or recommended chart type return self.generate_visualization_dsl(question, schema_info, chart_type) ================================================ FILE: openchatbi/text_segmenter.py ================================================ """Text segmentation utility with jieba support.""" import re import string import sys # Try to import jieba, fallback to None if not available # Note: jieba is not compatible with Python 3.12+ _jieba_available = False if sys.version_info < (3, 12): try: import jieba _jieba_available = True except ImportError: _jieba_available = False class TextSegmenter: """A text segmenter that uses jieba for Chinese text and simple splitting for others. This segmenter tries to use jieba for better Chinese word segmentation. If jieba is not available or Python version is 3.12+, it falls back to simple punctuation/whitespace splitting. Note: jieba is not compatible with Python 3.12+, so simple segmentation will be used on Python 3.12 and higher versions. """ def __init__(self, use_jieba: bool = True): """Initialize the text segmenter. Args: use_jieba: Whether to use jieba for Chinese text segmentation. Defaults to True. Will automatically fall back to simple segmentation if jieba is not available. """ self.use_jieba = use_jieba and _jieba_available # Include both English and Chinese punctuation chinese_punctuation = ",。!?;:" "''()【】《》〈〉「」『』〔〕" all_separators = string.punctuation + chinese_punctuation + " \t\n\r" # Create regex pattern to split on any separator self.split_pattern = "[" + re.escape(all_separators) + "]+" @staticmethod def _contains_chinese(text: str) -> bool: """Check if text contains Chinese characters. Args: text: Input text to check Returns: True if text contains Chinese characters, False otherwise """ return any("\u4e00" <= char <= "\u9fff" for char in text) def _simple_cut(self, text: str) -> list[str]: """Simple segmentation by splitting on punctuation and whitespace. Args: text: Input text to be segmented Returns: List of tokens """ if not text: return [] # Split by separators and filter empty strings tokens = re.split(self.split_pattern, text) return [token for token in tokens if token.strip()] def cut(self, text: str) -> list[str]: """Segment text into tokens. For Chinese text with jieba available, uses jieba for word segmentation. Otherwise, splits by punctuation and whitespace. 
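
        For example, when jieba is unavailable, the simple fallback splits
        "Hello, world! 你好" into ["Hello", "world", "你好"].
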
Args: text: Input text to be segmented Returns: List of tokens """ if not text: return [] # Use jieba for Chinese text if available if self.use_jieba and self._contains_chinese(text): return list(jieba.cut(text)) # Fall back to simple segmentation return self._simple_cut(text) class SimpleSegmenter: """A simple text segmenter that splits text by punctuation and whitespace. This is a lightweight text segmentation tool that provides basic functionality without external dependencies. Note: This class is kept for backward compatibility. Consider using TextSegmenter instead for better Chinese text support. """ def __init__(self): # Include both English and Chinese punctuation chinese_punctuation = ",。!?;:" "''()【】《》〈〉「」『』〔〕" all_separators = string.punctuation + chinese_punctuation + " \t\n\r" # Create regex pattern to split on any separator self.split_pattern = "[" + re.escape(all_separators) + "]+" def cut(self, text: str) -> list[str]: """Segment text into tokens by splitting on punctuation and whitespace. Args: text: Input text to be segmented Returns: List of tokens """ if not text: return [] # Split by separators and filter empty strings tokens = re.split(self.split_pattern, text) return [token for token in tokens if token.strip()] # Global instance - use TextSegmenter with jieba support _segmenter = TextSegmenter() ================================================ FILE: openchatbi/tool/ask_human.py ================================================ """Tool for asking human clarification when information is ambiguous.""" from pydantic import BaseModel, Field class AskHuman(BaseModel): """Ask user for clarification when data is missing or ambiguous. Use this tool ONLY when you are STRONGLY certain that information is ambiguous or missing. First try to solve the question with available user input before calling this tool. """ question: str = Field(description="Question to ask the user for clarification") options: list[str] = Field(description="Options for user to choose (max 3). Empty if not a choice question.") ================================================ FILE: openchatbi/tool/mcp_tools.py ================================================ """MCP (Model Context Protocol) tools integration for OpenChatBI. This module provides integration with MCP servers using langchain-mcp-adapters, allowing the agent to use external tools through the Model Context Protocol. """ import asyncio import logging from concurrent.futures import ThreadPoolExecutor from typing import Any from langchain_core.tools import StructuredTool from langchain_mcp_adapters.client import MultiServerMCPClient from pydantic import BaseModel, Field from openchatbi.constants import MCP_TOOL_DEFAULT_TIMEOUT_SECONDS logger = logging.getLogger(__name__) def make_tool_sync_compatible(tool: StructuredTool, timeout: int) -> StructuredTool: """Make an async-only StructuredTool compatible with sync invocation. This wraps the async coroutine with a sync function that runs it in an event loop. 
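
    When a loop is already running (i.e. we are called from inside an async
    context), the coroutine is instead executed on a fresh event loop in a
    worker thread, so the wrapper never re-enters the running loop.
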
Args: tool: The StructuredTool to make sync-compatible timeout: Timeout in seconds for tool execution Returns: StructuredTool with sync compatibility """ if tool.func is not None: # Tool already has sync support return tool if tool.coroutine is None: # Tool has no async function either, can't help return tool def sync_wrapper(*args: Any, **kwargs: Any) -> Any: """Synchronous wrapper for async tool function.""" try: loop = asyncio.get_event_loop() if loop.is_running(): # We're in an async context, can't use run_until_complete # Create a new thread with its own event loop with ThreadPoolExecutor(max_workers=1) as executor: def run_in_new_loop() -> Any: new_loop = asyncio.new_event_loop() asyncio.set_event_loop(new_loop) try: return new_loop.run_until_complete(tool.coroutine(*args, **kwargs)) # type: ignore finally: new_loop.close() future = executor.submit(run_in_new_loop) return future.result(timeout=timeout) else: # No running loop, we can use run_until_complete return loop.run_until_complete(tool.coroutine(*args, **kwargs)) # type: ignore except RuntimeError: # No event loop exists, create one loop = asyncio.new_event_loop() try: return loop.run_until_complete(tool.coroutine(*args, **kwargs)) # type: ignore finally: loop.close() # Create a new StructuredTool with both sync and async functions return StructuredTool( name=tool.name, description=tool.description, args_schema=tool.args_schema, func=sync_wrapper, coroutine=tool.coroutine, ) class MCPServerConfig(BaseModel): """Configuration for MCP server connection.""" name: str = Field(description="Name of the MCP server") transport: str = Field(default="stdio", description="Transport type: stdio, sse, or streamable_http") # For stdio transport command: list[str] = Field(default_factory=list, description="Command to start the MCP server") args: list[str] = Field(default_factory=list, description="Arguments for the MCP server") env: dict[str, str] = Field(default_factory=dict, description="Environment variables") # For HTTP transports (sse, streamable_http) url: str = Field(default="", description="URL for HTTP-based transports") headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers") # Common settings enabled: bool = Field(default=True, description="Whether this MCP server is enabled") timeout: int = Field(default=MCP_TOOL_DEFAULT_TIMEOUT_SECONDS, description="Connection timeout in seconds") async def create_mcp_tools_async(server_configs: list[dict[str, Any]]) -> list[StructuredTool]: """Create MCP tools asynchronously from server configurations. This function processes MCP server configurations, establishes connections to enabled servers, retrieves available tools, and makes them sync-compatible with proper timeout configuration. 
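
    Example server configuration entry (illustrative values only):
        {"name": "filesystem", "transport": "stdio",
         "command": ["npx", "@modelcontextprotocol/server-filesystem", "/tmp"],
         "enabled": True, "timeout": 30}
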
Args: server_configs: List of MCP server configuration dictionaries containing server connection details, transport settings, and timeouts Returns: List of LangChain StructuredTool instances with mcp_ prefixes and sync compatibility """ if not server_configs: return [] # Filter enabled servers and convert to MCPServerConfig enabled_servers = {} max_timeout = MCP_TOOL_DEFAULT_TIMEOUT_SECONDS # Default from constants for config_dict in server_configs: try: config = MCPServerConfig(**config_dict) if not config.enabled: continue server_name = config.name # Track the maximum timeout across all servers max_timeout = max(max_timeout, config.timeout) # Build server configuration for MultiServerMCPClient if config.transport == "stdio": if not config.command: logger.warning(f"MCP server {server_name}: command required for stdio transport") continue enabled_servers[server_name] = { "transport": "stdio", "command": config.command[0] if config.command else "", "args": config.command[1:] + config.args if len(config.command) > 1 else config.args, "env": config.env, } elif config.transport in ["sse", "streamable_http"]: if not config.url: logger.warning(f"MCP server {server_name}: url required for {config.transport} transport") continue server_config: dict[str, Any] = { "transport": config.transport, "url": config.url, } if config.headers: server_config["headers"] = config.headers enabled_servers[server_name] = server_config else: logger.warning(f"MCP server {server_name}: unsupported transport {config.transport}") continue except Exception as e: logger.error(f"Invalid MCP server configuration: {e}") continue if not enabled_servers: logger.info("No enabled MCP servers found") return [] try: # Create MultiServerMCPClient and get tools with timeout client = MultiServerMCPClient(enabled_servers) tools = await asyncio.wait_for(client.get_tools(), timeout=max_timeout) logger.info(f"Successfully loaded {len(tools)} MCP tools from {len(enabled_servers)} servers") # Add server prefix to tool names and make sync-compatible prefixed_tools = [] for tool in tools: # Get server name from tool metadata or guess from tool name original_name = tool.name if not original_name.startswith("mcp_"): tool.name = f"mcp_{original_name}" # Make tool sync-compatible with configured timeout sync_compatible_tool = make_tool_sync_compatible(tool, timeout=max_timeout) prefixed_tools.append(sync_compatible_tool) return prefixed_tools except Exception as e: logger.error(f"Failed to initialize MCP client: {e}") return [] def create_mcp_tools_sync(server_configs: list[dict[str, Any]]) -> list[StructuredTool]: """Create MCP tools from server configurations synchronously. This function initializes MCP tools in a separate thread with its own event loop to avoid conflicts with existing async contexts. 
Args: server_configs: List of MCP server configuration dictionaries Returns: List of LangChain StructuredTool instances with sync compatibility """ if not server_configs: return [] # For sync mode, run async initialization in a thread def sync_initialize() -> list[StructuredTool]: # Create new event loop for this thread loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: return loop.run_until_complete(create_mcp_tools_async(server_configs)) except Exception as e: logger.error(f"Failed to create MCP tools in sync mode: {e}") return [] finally: loop.close() try: with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(sync_initialize) return future.result(timeout=MCP_TOOL_DEFAULT_TIMEOUT_SECONDS) except Exception as e: logger.error(f"MCP tools sync initialization failed: {e}") return [] # Global variable to store async-initialized tools _async_mcp_tools = None async def get_mcp_tools_async(server_configs: list[dict[str, Any]]) -> list[StructuredTool]: """Get MCP tools asynchronously, using cached version if available. Args: server_configs: List of MCP server configuration dictionaries Returns: List of cached or newly created LangChain StructuredTool instances """ global _async_mcp_tools if _async_mcp_tools is None: _async_mcp_tools = await create_mcp_tools_async(server_configs) return _async_mcp_tools def reset_mcp_tools_cache() -> None: """Reset the async MCP tools cache.""" global _async_mcp_tools _async_mcp_tools = None ================================================ FILE: openchatbi/tool/memory.py ================================================ import functools import sys from typing import Any try: import pysqlite3 as sqlite3 except ImportError: # pragma: no cover import sqlite3 # Make sure langgraph sqlite connector uses the same sqlite module. 
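# pysqlite3, when installed, bundles a newer SQLite build than many system
# Pythons ship; registering it under the stdlib name below means langgraph's
# SqliteStore transparently picks it up with no further changes.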
sys.modules["sqlite3"] = sqlite3 from langchain.tools import StructuredTool from langchain_core.language_models import BaseChatModel from langchain_openai.chat_models.base import BaseChatOpenAI from langgraph.store.sqlite import SqliteStore from langgraph.store.sqlite.aio import AsyncSqliteStore from langmem import ( create_manage_memory_tool, create_memory_store_manager, create_search_memory_tool, ) from openchatbi import config try: from pydantic import BaseModel, ConfigDict except ImportError: ConfigDict = None # Use AsyncSqliteStore for async operations async_memory_store = None async_store_context_manager = None sync_memory_store = None memory_manager = None # Define profile structure class UserProfile(BaseModel): """Represents the full representation of a user.""" name: str | None = None language: str | None = None timezone: str | None = None jargon: str | None = None def get_sync_memory_store() -> SqliteStore | None: global sync_memory_store embedding_model = config.get().embedding_model if not embedding_model: return None if sync_memory_store is None: # For backwards compatibility and sync operations conn = sqlite3.connect("memory.db", check_same_thread=False) conn.isolation_level = None sync_memory_store = SqliteStore( conn, index={ "dims": 1536, "embed": embedding_model, "fields": ["text"], # specify which fields to embed }, ) try: sync_memory_store.setup() except Exception: pass return sync_memory_store async def get_async_memory_store() -> AsyncSqliteStore | None: """Get or create the async memory store.""" global async_memory_store, async_store_context_manager embedding_model = config.get().embedding_model if not embedding_model: return None if async_memory_store is None: # AsyncSqliteStore.from_conn_string returns an async context manager async_store_context_manager = AsyncSqliteStore.from_conn_string( "memory.db", index={ "dims": 1536, "embed": embedding_model, "fields": ["text"], # specify which fields to embed }, ) async_memory_store = await async_store_context_manager.__aenter__() return async_memory_store async def cleanup_async_memory_store() -> None: """Cleanup async memory store resources.""" global async_memory_store, async_store_context_manager if async_memory_store is not None and async_store_context_manager is not None: try: await async_store_context_manager.__aexit__(None, None, None) except Exception as e: print(f"Error cleaning up async memory store: {e}") finally: async_memory_store = None async_store_context_manager = None async def setup_async_memory_store() -> Any: """Setup async memory store for langmem.""" await get_async_memory_store() def fix_schema_for_openai(schema: dict) -> None: props = schema.get("properties", {}) schema["required"] = list(props.keys()) # Since Pydantic 2.11, it will always add `additionalProperties: True` for arbitrary dictionary schemas # If it is already set to True, we need override it to False # Can remove this fix when the patch release: https://github.com/langchain-ai/langchain/pull/32879 def fix(obj): if isinstance(obj, dict): if obj.get("type") == "object" and "additionalProperties" in obj and obj["additionalProperties"]: obj["additionalProperties"] = False for v in obj.values(): fix(v) elif isinstance(obj, list): for item in obj: fix(item) fix(schema) def get_memory_manager() -> Any: global memory_manager if memory_manager is None: memory_manager = create_memory_store_manager( config.get().default_llm, schemas=[UserProfile], instructions="Extract user profile information", enable_inserts=False, ) return memory_manager 
class StructuredToolWithRequired(StructuredTool): def __init__(self, orig_tool: StructuredTool): name = getattr(orig_tool, "name", None) super().__init__( name=name, description=orig_tool.description, args_schema=orig_tool.args_schema, func=orig_tool.func, coroutine=orig_tool.coroutine, ) @functools.cached_property def tool_call_schema(self) -> "ArgsSchema": tcs = super().tool_call_schema try: if tcs.model_config: tcs.model_config["json_schema_extra"] = fix_schema_for_openai elif ConfigDict is not None: tcs.model_config = ConfigDict(json_schema_extra=fix_schema_for_openai) except Exception: pass return tcs def get_memory_tools( llm: BaseChatModel, sync_mode: bool = False, store: Any | None = None ) -> list[StructuredTool] | None: # Get the appropriate store based on mode if not store: if sync_mode: store = get_sync_memory_store() else: store = None if not store: return None # create langmem manage memory tool with {user_id} template manage_memory_tool = create_manage_memory_tool(namespace=("memories", "{user_id}"), store=store) search_memory_tool = create_search_memory_tool(namespace=("memories", "{user_id}"), store=store) if isinstance(llm, BaseChatOpenAI): manage_memory_tool = StructuredToolWithRequired(manage_memory_tool) search_memory_tool = StructuredToolWithRequired(search_memory_tool) return [manage_memory_tool, search_memory_tool] async def get_async_memory_tools(llm: BaseChatModel) -> list[StructuredTool]: """Get memory tools configured with async store.""" async_store = await get_async_memory_store() return get_memory_tools(llm, sync_mode=False, store=async_store) ================================================ FILE: openchatbi/tool/run_python_code.py ================================================ """Tool for running python code.""" from langchain.tools import tool from pydantic import BaseModel, Field from openchatbi.code.docker_executor import DockerExecutor, check_docker_status from openchatbi.code.local_executor import LocalExecutor from openchatbi.code.restricted_local_executor import RestrictedLocalExecutor from openchatbi.config_loader import ConfigLoader from openchatbi.utils import log class PythonCodeInput(BaseModel): reasoning: str = Field(description="Reason for using this run python code tool") code: str = Field(description="The python code to execute") def _create_executor(): """Create appropriate executor based on configuration.""" config_loader = ConfigLoader() try: config = config_loader.get() executor_type = config.python_executor.lower() except ValueError: # Configuration not loaded, use default local executor log("Configuration not loaded, using default LocalExecutor") return LocalExecutor() log(f"Creating executor of type: {executor_type}") if executor_type == "docker": # Check if Docker is available before creating DockerExecutor is_available, status_message = check_docker_status() if not is_available: log(f"Docker is not available ({status_message}), falling back to LocalExecutor") return LocalExecutor() log("Docker is available, creating DockerExecutor") return DockerExecutor() elif executor_type == "restricted_local": log("Creating RestrictedLocalExecutor") return RestrictedLocalExecutor() elif executor_type == "local": log("Creating LocalExecutor") return LocalExecutor() else: log(f"Unknown executor type '{executor_type}', using LocalExecutor as fallback") return LocalExecutor() @tool("run_python_code", args_schema=PythonCodeInput, return_direct=False, infer_schema=True) def run_python_code(reasoning: str, code: str) -> str: """Run python code string. 
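
    The executor backend (docker, restricted_local, or local) is chosen from the
    `python_executor` config value at call time; LocalExecutor is the fallback
    whenever the configured backend is unavailable.
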
Note: Only print outputs are visible, function return values will be ignored. Use print statements to see results. Returns: str: The print outputs of the python code """ log(f"Run Python Code, Reasoning: {reasoning}") try: executor = _create_executor() log(f"Using {executor.__class__.__name__} for code execution") success, output = executor.run_code(code) if success: return output else: return f"Error: {output}" except Exception as e: log(f"Failed to create executor: {e}") # Fallback to LocalExecutor if configuration fails log("Falling back to LocalExecutor") executor = LocalExecutor() success, output = executor.run_code(code) if success: return output else: return f"Error: {output}" ================================================ FILE: openchatbi/tool/save_report.py ================================================ """Tool for saving reports to files.""" import datetime from pathlib import Path from langchain.tools import tool from pydantic import BaseModel, Field from openchatbi import config from openchatbi.utils import log class SaveReportInput(BaseModel): content: str = Field(description="The content of the report to save") title: str = Field(description="The title of the report (will be used in filename)") file_format: str = Field( description="The file format/extension, only support 'md', 'csv', 'txt', 'json', 'html', 'xml'" ) @tool("save_report", args_schema=SaveReportInput, return_direct=False, infer_schema=True) def save_report(content: str, title: str, file_format: str = "md") -> str: """Save a report to a file with timestamp and title in filename. Args: content: The content of the report to save title: The title of the report (will be used in filename) file_format: The file format/extension, only support 'md', 'csv', 'txt', 'json', 'html', 'xml' Returns: str: Success message with download link or error message """ allowed_formats = {"md", "csv", "txt", "json", "html", "xml"} if file_format not in allowed_formats: raise ValueError(f"Unsupported file format: {file_format}") try: # Get report directory from config report_dir = config.get().report_directory # Create directory if it doesn't exist Path(report_dir).mkdir(parents=True, exist_ok=True) # Generate timestamp for filename timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") # Clean title for filename (remove invalid characters) clean_title = "".join(c for c in title if c.isalnum() or c in (" ", "-")).rstrip() clean_title = clean_title.replace(" ", "_") # Create filename filename = f"{timestamp}_{clean_title}.{file_format}" file_path = Path(report_dir) / filename # Write content to file with open(file_path, "w", encoding="utf-8") as f: f.write(content) log(f"Report saved: {file_path}") # Return success message with download link download_url = f"/api/download/report/{filename}" return f"Report saved successfully! 
Download link: {download_url}" except Exception as e: error_msg = f"Failed to save report: {str(e)}" log(error_msg) return error_msg ================================================ FILE: openchatbi/tool/search_knowledge.py ================================================ """Tools for searching knowledge bases and schema information.""" from langchain.tools import tool from pydantic import BaseModel, Field from openchatbi import config from openchatbi.catalog.schema_retrival import col_dict, column_tables_mapping, get_relevant_columns from openchatbi.utils import log class SearchInput(BaseModel): """Input schema for knowledge search tool.""" reasoning: str = Field(description="Reason for using this search tool") query_list: list[str] = Field(description="Query terms to search (max 5, avoid duplicates)") knowledge_bases: list[str] = Field( description="""Knowledge bases to search, options are: - `"columns"`: The description, alias of columns, including dimensions and metrics. - `"business"`: The business knowledge.""" ) with_table_list: bool = Field( description="Include table list for columns (only set to True when user asks about table-column relationships)" ) @tool("search_knowledge", args_schema=SearchInput, return_direct=False, infer_schema=True) def search_knowledge( reasoning: str, query_list: list[str], knowledge_bases: list[str], with_table_list: bool = False ) -> dict[str, str]: """Search relevant knowledge from knowledge bases. Returns: Dict[str, str]: Search results for each knowledge base. """ log(f"Search knowledge, query_list={query_list}, knowledge_bases={knowledge_bases}, reasoning={reasoning}") final_results = {} if "columns" in knowledge_bases: column_results = search_column_from_catalog(query_list, with_table_list) final_results["columns"] = f"# Relevant Columns and Description:\n{column_results}" return final_results class ShowSchemaInput(BaseModel): """Input schema for show schema tool.""" reasoning: str = Field(description="Reason for showing schema") tables: list[str] = Field(description="Full table names to show (max 5)") @tool("show_schema", args_schema=ShowSchemaInput, return_direct=False, infer_schema=True) def show_schema(reasoning: str, tables: list[str]) -> list[str]: """Show table schemas including description, columns, and derived metrics. Returns: list[str]: Schema information for each table. 
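
    Example (illustrative table name; real names come from the catalog):
        >>> show_schema.invoke({"reasoning": "user asked for table structure",
        ...                     "tables": ["events"]})  # doctest: +SKIP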
""" log(f"Show schema, tables={tables}, reasoning={reasoning}") result = list_table_from_catalog(tables) return result def search_column_from_catalog(query_list: list[str], with_table_list: bool) -> str: """Search columns from catalog based on query list.""" relevant_column_set = set() for keywords in query_list: relevant_columns = get_relevant_columns(keywords.split(" "), keywords.split(" "), keywords.split(" ")) relevant_column_set.update(relevant_columns) column_results = render_column_result(relevant_column_set, with_table_list) return "\n".join(column_results) def list_table_from_catalog(tables: list[str]) -> list[str]: """Get table information from catalog.""" result = [] catalog_store = config.get().catalog_store for table_name in tables: table_info = catalog_store.get_table_information(table_name) if not table_info: continue table_desc = f"Table: `{table_name}` \n# Description: {table_info['description']}\n" columns = catalog_store.get_column_list(table_name) column_names = [info["column_name"] for info in columns] column_results = render_column_result(column_names) table_desc += "# Columns:\n" table_desc += "\n".join(column_results) if table_info.get("derived_metric"): table_desc += "## Derived metrics:\n" table_desc += table_info["derived_metric"] result.append(table_desc) return result def render_column_result(column_list: list[str], with_table_list: bool = False) -> list[str]: """Render column information as formatted strings.""" column_results = [] for column_name in column_list: if column_name not in col_dict: continue table_list = column_tables_mapping.get(column_name, []) column = col_dict[column_name] column_desc = ( f"## {column['column_name']}" f"\n- Column Category: {column['category']}" f"\n- Display Name: {column['display_name']} " f"\n- Description \"{column['description']}\"" ) if with_table_list: column_desc += f"\n- Related Tables: {table_list}" column_results.append(column_desc) return column_results ================================================ FILE: openchatbi/tool/timeseries_forecast.py ================================================ """Tool for time series forecasting.""" import logging from typing import Any import requests from langchain.tools import tool from pydantic import BaseModel, Field from openchatbi import config from openchatbi.utils import log logger = logging.getLogger(__name__) class TimeseriesForecastInput(BaseModel): """Input schema for time series forecasting tool.""" reasoning: str = Field(description="Reason for using time series forecasting and what insights you expect to gain") input_data: list[float | int | dict[str, Any]] = Field( description="Time series data as list of numbers or structured data with timestamps and values" ) forecast_window: int = Field( default=24, description="Number of future time points to predict (1-200)", ge=1, le=200 ) frequency: str = Field(default="hourly", description="Time series frequency: hourly, daily, weekly, monthly, etc.") input_length: int | None = Field( default=None, description="Optional limit on input data length to use for prediction" ) target_column: str = Field( default="value", description="Column name to forecast for structured data (default: 'value')" ) def _check_service_health(service_url: str) -> bool: """Check if time series forecasting service is available.""" try: response = requests.get(f"{service_url}/health", timeout=5) if response.status_code == 200: health_data = response.json() return health_data.get("model_initialized", False) return False except 
requests.exceptions.RequestException: return False def check_forecast_service_health() -> bool: """Check whether the configured forecasting service is reachable and its model is initialized.""" try: service_url = config.get().timeseries_forecasting_service_url return _check_service_health(service_url) except ValueError: # Configuration not loaded yet (e.g., in tests) return False def _call_timeseries_service( service_url: str, input_data: list[float | int | dict[str, Any]], forecast_window: int, frequency: str, input_length: int | None = None, target_column: str = "value", ) -> dict[str, Any]: """Call time series forecasting service.""" try: # Prepare request payload payload = {"input": input_data, "forecast_window": forecast_window, "frequency": frequency} if input_length is not None: payload["input_len"] = input_length if target_column != "value": payload["target_column"] = target_column # Make request to time series forecasting service response = requests.post(f"{service_url}/predict", json=payload, timeout=30) if response.status_code == 200: return response.json() else: return { "error": f"Service returned status {response.status_code}: {response.text}", "status": "http_error", "status_code": response.status_code, } except requests.exceptions.Timeout: return {"error": "Request timeout - forecasting service took too long to respond", "status": "error"} except requests.exceptions.RequestException as e: return {"error": f"Failed to connect to forecasting service: {str(e)}", "status": "error"} except Exception as e: return {"error": f"Unexpected error: {str(e)}", "status": "error"} def _format_forecast_result(result: dict[str, Any], reasoning: str, input_data_length: int) -> str: """Format the forecasting result for the agent.""" if result.get("status") == "error": return f"""Time Series Forecasting Error: {result.get('error', 'Unknown error occurred')} Please check: 1. The time series forecasting service is running (docker run -p 8765:8765 timeseries-forecasting) 2. The model loaded successfully 3. Retry if the request timed out""" elif result.get("status") == "http_error": if result.get("status_code") == 400: return f"""Time Series Forecasting Error: {result.get('error', 'Unknown error occurred')} Please check: 1. The input data format is correct 2. input_len is increased if the input data length is insufficient 3. The forecast window is reasonable (1-200)""" else: return f"""Time Series Forecasting Error: {result.get('error', 'Unknown error occurred')}""" predictions = result.get("predictions", []) forecast_window = result.get("forecast_window", len(predictions)) frequency = result.get("frequency", "unknown") if not predictions: return "No predictions were generated. Please check your input data."
# Calculate basic statistics sum_predictions = sum(predictions) avg_prediction = sum_predictions / len(predictions) if predictions else 0 min_prediction = min(predictions) if predictions else 0 max_prediction = max(predictions) if predictions else 0 # Create formatted response response_parts = [ "✅ Time Series Forecasting Completed", "", "Forecast Summary:", f" • Input data points: {input_data_length}", f" • Forecast window: {forecast_window} {frequency.lower()} periods", "", "Predictions:", f" • Average forecast: {avg_prediction:.2f}", f" • Sum: {sum_predictions:.2f}", f" • Range: {min_prediction:.2f} to {max_prediction:.2f}", f" • Total periods forecasted: {len(predictions)}", "", "Detailed Forecast Values:", ] for i, pred in enumerate(predictions): period_label = f"Period {i + 1}" response_parts.append(f" • {period_label}: {pred:.2f}") return "\n".join(response_parts) @tool("timeseries_forecast", args_schema=TimeseriesForecastInput, return_direct=False, infer_schema=True) def timeseries_forecast( reasoning: str, input_data: list[float | int | dict[str, Any]], forecast_window: int = 24, frequency: str = "hourly", input_length: int | None = None, target_column: str = "value", ) -> str: """Forecast future values for time series data using advanced deep learning models. This tool uses state-of-the-art deep learning models (currently transformer based) to predict future values based on historical time series data. Perfect for sales forecasting, demand planning, trend analysis, and business intelligence. Args: reasoning: Explanation of why forecasting is needed and what insights are expected input_data: Historical time series data as list of numbers or structured data with timestamps forecast_window: Number of future time points to predict (1-200, default: 24) frequency: Time series frequency - hourly, daily, weekly, monthly, etc. input_length: Optional limit on how much historical data to use for prediction target_column: Column name to forecast for structured data (default: 'value') Returns: str: Formatted forecast results with predictions, statistics, and interpretation guidance Examples: - Sales forecasting: Predict next month's daily sales based on historical data - Demand planning: Forecast product demand for inventory management - Financial planning: Predict revenue, costs, or other financial metrics - Operational planning: Forecast website traffic, resource usage, etc. """ # Get service URL from config service_url = config.get().timeseries_forecasting_service_url log(f"Time Series Forecast: {reasoning}") log(f"Input data points: {len(input_data)}, Forecast window: {forecast_window}, Frequency: {frequency}") # Validate input data if not input_data: return "Error: Input data cannot be empty. Please provide historical time series data." if len(input_data) < 3: return "Error: Need at least 3 data points for reliable forecasting. Please provide more historical data." # Check service availability if not _check_service_health(service_url): return """Time Series Forecasting Service Unavailable. The time series forecasting service is not running or not in service. 
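Start it with: docker run -p 8765:8765 timeseries-forecasting (see timeseries_forecasting/README.md).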
""" # Call the forecasting service result = _call_timeseries_service( service_url=service_url, input_data=input_data, forecast_window=forecast_window, frequency=frequency, input_length=input_length, target_column=target_column, ) # Format and return the result return _format_forecast_result(result, reasoning, len(input_data)) ================================================ FILE: openchatbi/utils.py ================================================ """Utility functions for OpenChatBI.""" import json import sys import uuid from pathlib import Path from typing import Any from fastapi import HTTPException from fastapi.responses import FileResponse from langchain_chroma import Chroma from langchain_core.documents import Document from langchain_core.messages import AIMessage, AIMessageChunk, RemoveMessage, ToolMessage from langchain_core.vectorstores import VectorStore from rank_bm25 import BM25Okapi from regex import regex from openchatbi.graph_state import AgentState from openchatbi.text_segmenter import _segmenter def log(args) -> None: """Log messages to stderr for debugging.""" print(args, file=sys.stderr, flush=True) def get_text_from_content(content: str | list[str | dict]) -> str: """Extract text from various content formats. Args: content: String, list of strings, or list of dicts with 'text' key. Returns: str: Extracted text content. """ if isinstance(content, str): return content elif isinstance(content, list): if isinstance(content[0], str): return "".join(content) elif isinstance(content[0], dict): return "".join([item.get("text", "") for item in content]) return "" def get_text_from_message_chunk(chunk: AIMessageChunk) -> str: """Extract content from an AIMessageChunk. Args: chunk (AIMessageChunk): The message chunk to extract text from. Returns: str: Extracted text content or empty string. """ if not isinstance(chunk, AIMessageChunk) or not hasattr(chunk, "content") or not chunk.content: return "" return get_text_from_content(chunk.content) def extract_json_from_answer(answer: str) -> dict: """Extract the first JSON object from a string answer. Args: answer (str): String that may contain JSON objects. Returns: dict: Parsed JSON object or empty dict if none found. """ pattern = regex.compile(r"\{(?:[^{}]+|(?R))*\}") matches = pattern.findall(answer) json_result = matches[0] if matches else "{}" return json.loads(json_result) def get_report_download_response(filename: str) -> FileResponse: """Get FileResponse for downloading a report file. 
Args: filename: The filename of the report to download Returns: FileResponse: Response object for file download Raises: HTTPException: Various HTTP errors for invalid requests """ try: # Import config here to avoid circular imports from openchatbi import config # Get report directory from config report_dir = config.get().report_directory file_path = Path(report_dir) / filename # Check if file exists and is within the report directory if not file_path.exists(): raise HTTPException(status_code=404, detail="Report file not found") if not file_path.is_file(): raise HTTPException(status_code=400, detail="Invalid file path") # Ensure the file is within the report directory (security check) try: file_path.resolve().relative_to(Path(report_dir).resolve()) except ValueError: raise HTTPException(status_code=403, detail="Access denied") from None # Determine media type based on file extension media_type_map = { ".md": "text/markdown", ".csv": "text/csv", ".txt": "text/plain", ".json": "application/json", ".html": "text/html", ".xml": "application/xml", } file_extension = file_path.suffix.lower() media_type = media_type_map.get(file_extension, "application/octet-stream") return FileResponse(path=str(file_path), media_type=media_type, filename=filename) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to download report: {str(e)}") from e def _create_chroma_from_texts( texts: list[str], embedding, collection_name: str, metadatas, collection_metadata: dict, chroma_dir: str, ): """Helper function to create Chroma client from texts.""" return Chroma.from_texts( texts, embedding, metadatas=metadatas, collection_name=collection_name, collection_metadata=collection_metadata, persist_directory=chroma_dir, ) def create_vector_db( texts: list[str], embedding=None, collection_name: str = "langchain", metadatas=None, collection_metadata: dict = None, chroma_db_path: str = None, ) -> VectorStore: """Create or reuse a Chroma vector database. Args: texts (List[str]): Text documents to index. embedding: Embedding function to use. collection_name (str): Name of the collection. metadatas: Metadata for each document. collection_metadata (dict): Collection-level metadata. chroma_db_path (str): Path to chroma database file. Returns: Chroma: Vector database instance. """ # fallback to Simple vector store using BM25 if no embedding model configured if not embedding: return SimpleStore(texts, metadatas) chroma_dir = chroma_db_path or "./.chroma_db" client = Chroma( collection_name, persist_directory=chroma_dir, embedding_function=embedding, collection_metadata=collection_metadata, ) use_cache = False existing_docs = None try: # Try to get documents to check if collection exists and has content existing_docs = client.get() if not existing_docs["documents"]: print(f"Init new client from text for {collection_name}...") else: # Check if cached texts match the input texts cached_texts = existing_docs["documents"] # Compare texts: check count first, then content if len(cached_texts) != len(texts): print( f"Texts count mismatch for {collection_name} " f"(cached: {len(cached_texts)}, input: {len(texts)}). Recreating collection..." ) else: # Compare content by sorting both lists to handle order differences sorted_cached = sorted(cached_texts) sorted_input = sorted(texts) if sorted_cached != sorted_input: print(f"Cache content mismatch for {collection_name}. 
Recreating collection...") else: print(f"Re-use collection for {collection_name}") use_cache = True except Exception: # If collection doesn't exist or any error, create new one print(f"Init new client from text for {collection_name}...") if not use_cache: # Clear existing collection before recreating to avoid data duplication if existing_docs and existing_docs["documents"]: try: client.reset_collection() print(f"Cleared existing collection {collection_name} before recreating...") except Exception as e: # If reset fails, log and continue with recreation print(f"Warning: Failed to clear collection {collection_name}: {e}") client = _create_chroma_from_texts( texts, embedding, collection_name, metadatas, collection_metadata, chroma_dir ) return client def recover_incomplete_tool_calls(state: AgentState) -> list: """Recover from incomplete tool calls by creating message operations to insert ToolMessages correctly. When the graph execution is interrupted (e.g., by kill or app restart) during tool execution, the state can end up with AIMessage containing tool_calls but no corresponding ToolMessage responses. This function detects such cases and creates the necessary message operations to insert failure ToolMessages in the correct position (right after the AIMessage). Args: state (AgentState): The current graph state containing messages. Returns: list: Message operations to insert recovery ToolMessages, or empty list if no recovery needed. """ messages = state.get("messages", []) if not messages: return [] # Find the last AIMessage with tool_calls last_ai_message = None last_ai_index = -1 for i in range(len(messages) - 1, -1, -1): if isinstance(messages[i], AIMessage) and messages[i].tool_calls: last_ai_message = messages[i] last_ai_index = i break if not last_ai_message: return [] # Check if there are any ToolMessages after this AIMessage tool_call_ids = {call["id"] for call in last_ai_message.tool_calls} handled_tool_call_ids = set() # Look for ToolMessages that respond to these tool calls for msg in messages[last_ai_index + 1 :]: if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids: handled_tool_call_ids.add(msg.tool_call_id) # Find unhandled tool calls unhandled_tool_call_ids = tool_call_ids - handled_tool_call_ids if not unhandled_tool_call_ids: return [] # All tool calls have responses # Create failure ToolMessages for unhandled tool calls recovery_messages = [] for tool_call in last_ai_message.tool_calls: if tool_call["id"] in unhandled_tool_call_ids: failure_msg = ToolMessage( content=f"Tool `{tool_call['name']}` execution was interrupted due to system restart or process termination. 
Please retry the operation.", tool_call_id=tool_call["id"], ) recovery_messages.append(failure_msg) # Build operations to insert recovery messages in correct position operations = [] messages_after_ai = messages[last_ai_index + 1 :] # Collect IDs that will be removed removed_ids = set() # If there are messages after the AIMessage, we need to remove them first if messages_after_ai: for msg in messages_after_ai: operations.append(RemoveMessage(id=msg.id)) removed_ids.add(msg.id) # Add recovery messages (they will be inserted right after the AIMessage) operations.extend(recovery_messages) # Re-add the messages that were after the AIMessage (if any) # CRITICAL: Must regenerate Message ids if matches a RemoveMessage to prevent RemoveMessage from being cancelled if messages_after_ai: for msg in messages_after_ai: # Only regenerate ID if this message's ID was removed if msg.id in removed_ids: # Create a copy with new ID to prevent the RemoveMessage from being discarded new_msg = msg.model_copy(update={"id": str(uuid.uuid4())}) operations.append(new_msg) else: # Keep original message as-is if ID wasn't removed operations.append(msg) log(f"Recovered {len(recovery_messages)} incomplete tool calls") return operations class SimpleStore(VectorStore): """Simple vector store using BM25 for text retrieval without embeddings.""" def __init__( self, texts: list[str], metadatas: list[dict] | None = None, ids: list[str] | None = None, ): """Initialize SimpleStore with texts. Args: texts: List of text documents to store. metadatas: Optional list of metadata dicts for each document. ids: Optional list of IDs for each document. """ self.texts = texts self.metadatas = metadatas or [{} for _ in texts] self.ids = ids or [str(uuid.uuid4()) for _ in texts] # Create Document objects self.documents = [ Document(id=doc_id, page_content=text, metadata=meta) for doc_id, text, meta in zip(self.ids, self.texts, self.metadatas) ] # Tokenize texts and create BM25 index self.tokenized_corpus = [self._tokenize(text) for text in texts] # BM25Okapi doesn't support empty corpus, so set to None if empty self.bm25 = BM25Okapi(self.tokenized_corpus) if texts else None def _tokenize(self, text: str) -> list[str]: """Tokenize text for BM25 indexing using TextSegmenter. Args: text: Text to tokenize. Returns: List of tokens. """ return _segmenter.cut(text) def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> list[Document]: """Search for documents similar to the query using BM25. Args: query: Query text. k: Number of documents to return. **kwargs: Additional arguments (unused). Returns: List of most similar Document objects. """ if not self.texts: return [] # Tokenize query tokenized_query = self._tokenize(query) # Get BM25 scores scores = self.bm25.get_scores(tokenized_query) # Get top-k indices top_k = min(k, len(scores)) top_k_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k] # Return corresponding documents return [self.documents[i] for i in top_k_indices] def similarity_search_with_score(self, query: str, k: int = 4, **kwargs: Any) -> list[tuple[Document, float]]: """Search for documents similar to the query with BM25 scores. Args: query: Query text. k: Number of documents to return. **kwargs: Additional arguments (unused). Returns: List of (Document, score) tuples. 
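Scores are raw BM25 scores (unbounded, higher is more relevant); they are not normalized to [0, 1].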
""" if not self.texts: return [] # Tokenize query tokenized_query = self._tokenize(query) # Get BM25 scores scores = self.bm25.get_scores(tokenized_query) # Get top-k items top_k = min(k, len(scores)) top_k_items = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:top_k] # Return (Document, score) tuples return [(self.documents[i], score) for i, score in top_k_items] def _select_relevance_score_fn(self): """Return relevance score function for BM25. BM25 scores are already relevance scores, so return identity function. """ return lambda score: score def add_texts( self, texts: list[str], metadatas: list[dict] | None = None, *, ids: list[str] | None = None, **kwargs: Any, ) -> list[str]: """Add texts to the store. Args: texts: Texts to add. metadatas: Optional metadata for each text. ids: Optional IDs for each text. **kwargs: Additional arguments (unused). Returns: List of IDs of added texts. """ if metadatas is None: metadatas = [{} for _ in texts] if ids is None: ids = [str(uuid.uuid4()) for _ in texts] # Add to existing data self.texts.extend(texts) self.metadatas.extend(metadatas) self.ids.extend(ids) # Create new Document objects new_documents = [ Document(id=doc_id, page_content=text, metadata=meta) for doc_id, text, meta in zip(ids, texts, metadatas) ] self.documents.extend(new_documents) # Update BM25 index new_tokenized = [self._tokenize(text) for text in texts] self.tokenized_corpus.extend(new_tokenized) self.bm25 = BM25Okapi(self.tokenized_corpus) return ids def delete(self, ids: list[str] | None = None, **kwargs: Any) -> bool | None: """Delete documents by IDs. Args: ids: List of document IDs to delete. **kwargs: Additional arguments (unused). Returns: True if deletion successful, False otherwise. """ if ids is None: return False # Find indices to delete indices_to_delete = [i for i, doc_id in enumerate(self.ids) if doc_id in ids] if not indices_to_delete: return False # Remove items in reverse order to maintain indices for idx in sorted(indices_to_delete, reverse=True): del self.texts[idx] del self.metadatas[idx] del self.ids[idx] del self.documents[idx] del self.tokenized_corpus[idx] # Rebuild BM25 index if self.tokenized_corpus: self.bm25 = BM25Okapi(self.tokenized_corpus) else: self.bm25 = None return True def get_by_ids(self, ids: list[str], /) -> list[Document]: """Get documents by their IDs. Args: ids: List of document IDs to retrieve. Returns: List of Document objects. """ id_to_doc = {doc.id: doc for doc in self.documents} return [id_to_doc[doc_id] for doc_id in ids if doc_id in id_to_doc] @classmethod def from_texts( cls, texts: list[str], embedding: Any = None, # Unused but required by interface metadatas: list[dict] | None = None, *, ids: list[str] | None = None, **kwargs: Any, ) -> "SimpleStore": """Create SimpleStore from texts. Args: texts: List of texts. embedding: Unused (SimpleStore doesn't use embeddings). metadatas: Optional metadata for each text. ids: Optional IDs for each text. **kwargs: Additional arguments (unused). Returns: SimpleStore instance. """ return cls(texts, metadatas, ids) def max_marginal_relevance_search( self, query: str, k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. Args: query: Text to look up documents similar to. k: Number of `Document` objects to return. 
fetch_k: Number of `Document` objects to fetch to pass to MMR algorithm. lambda_mult: Number between `0` and `1` that determines the degree of diversity among the results with `0` corresponding to maximum diversity and `1` to minimum diversity. **kwargs: Arguments to pass to the search method. Returns: List of `Document` objects selected by maximal marginal relevance. """ if not self.texts: return [] # Get initial candidates using BM25 similarity search candidates = self.similarity_search_with_score(query, k=fetch_k, **kwargs) if not candidates: return [] if len(candidates) <= k: return [doc for doc, _ in candidates] # Normalize BM25 scores to [0, 1] for proper MMR calculation scores = [score for _, score in candidates] min_score = min(scores) if scores else 0 max_score = max(scores) if scores else 1 score_range = max_score - min_score if max_score > min_score else 1 normalized_candidates = [(doc, (score - min_score) / score_range) for doc, score in candidates] # MMR implementation following standard algorithm selected = [] remaining = list(range(len(normalized_candidates))) # Select documents iteratively using MMR formula while len(selected) < k and remaining: best_mmr_score = float("-inf") best_idx = -1 best_remaining_idx = -1 for i, doc_idx in enumerate(remaining): candidate_doc, relevance_score = normalized_candidates[doc_idx] # Calculate maximum similarity to already selected documents max_similarity = 0.0 if selected: max_similarity = max( self._calculate_similarity(candidate_doc, normalized_candidates[sel_idx][0]) for sel_idx in selected ) # Standard MMR formula: λ * Sim(q, d) - (1-λ) * max(Sim(d, s)) for s in selected mmr_score = lambda_mult * relevance_score - (1 - lambda_mult) * max_similarity if mmr_score > best_mmr_score: best_mmr_score = mmr_score best_idx = doc_idx best_remaining_idx = i if best_idx != -1: selected.append(best_idx) remaining.pop(best_remaining_idx) return [normalized_candidates[idx][0] for idx in selected] def _calculate_similarity(self, doc1: Document, doc2: Document) -> float: """Calculate similarity between two documents using Jaccard similarity. Args: doc1: First document. doc2: Second document. Returns: Similarity score between 0 and 1 (higher means more similar). 
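Computed as |tokens1 ∩ tokens2| / |tokens1 ∪ tokens2| over the segmenter tokens of the two documents.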
""" tokens1 = set(self._tokenize(doc1.page_content)) tokens2 = set(self._tokenize(doc2.page_content)) # Calculate Jaccard similarity intersection = len(tokens1 & tokens2) union = len(tokens1 | tokens2) return intersection / union if union > 0 else 0.0 ================================================ FILE: pyproject.toml ================================================ [project] name = "openchatbi" version = "0.2.2" description = "OpenChatBI - Natural language business intelligence powered by LLMs for intuitive data analysis and SQL generation" authors = [ { name = "Yu Zhong", email = "zhongyu8@gmail.com" }, ] license = { text = "MIT" } readme = "README.md" keywords = [ "business intelligence", "bi", "analytics", "llm", "gpt", "ai", "machine learning", "nlp", "text2sql", "agent", "query data", "talk to data", "analyze data", "data agent", "database", "langchain", "langgraph", "natural language", "conversational ai", "timeseries", "forecasting", "prediction" ] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "Intended Audience :: End Users/Desktop", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Database", "Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Office/Business", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.11", "Operating System :: OS Independent", "License :: OSI Approved :: MIT License", ] requires-python = ">=3.11,<4.0" dependencies = [ "requests>=2.31.0,<3.0.0", "langgraph>=0.4.7,<1.0.0", "langchain-openai>=0.3.18,<1.0.0", "langchain-anthropic>=0.3.13,<1.0.0", "langchain-community>=0.3.27,<1.0.0", "langgraph-checkpoint-sqlite>=2.0.11", "langchain-chroma>=0.2.5", "langchain-mcp-adapters>=0.1.9,<0.2.0", "langmem>=0.0.29", "sqlalchemy>=2.0.41,<3.0.0", "sqlalchemy-trino>=0.5.0", "aiosqlite>=0.21.0", "pyhive[presto]>=0.7.0", "rank-bm25>=0.2.2,<1.0.0", "python-levenshtein>=0.27.1", "gradio>=5.43.1,<6.0.0", "streamlit>=1.49.1,<2.0.0", "RestrictedPython>=8.0,<9.0", "docker>=7.0.0,<8.0.0", "pandas>=2.2.0,<3.0.0", "numpy>=2.3.0,<3.0.0", "matplotlib>=3.10.6,<4.0.0", "seaborn>=0.13.0,<1.0.0", "plotly>=5.17.0,<6.0.0", "json5>=0.10.0,<1.0.0", "jieba>=0.42.1", # Note: jieba is not compatible with Python 3.12+ ] [project.urls] Homepage = "https://github.com/zhongyu09/openchatbi" Repository = "https://github.com/zhongyu09/openchatbi" Documentation = "https://github.com/zhongyu09/openchatbi/tree/main" "Bug Tracker" = "https://github.com/zhongyu09/openchatbi/issues" [project.optional-dependencies] docs = [ "sphinx>=8.2.3,<9.0.0", "sphinx-rtd-theme>=3.0.0,<4.0.0", "sphinx-autodoc-typehints>=2.5.0,<3.0.0", "myst_parser", "autodoc-pydantic", ] test = [ "pytest>=7.4.0,<9.0.0", "pytest-mock>=3.14.0,<4.0.0", "pytest-asyncio>=0.23.8,<1.0.0", "pytest-sugar>=1.0.0,<2.0.0", "pytest-cov>=6.0.0,<7.0.0", "aioresponses>=0.7.7,<1.0.0", "responses>=0.25.3,<1.0.0", "langsmith[pytest]>=0.4.8,<1.0.0", "openevals>=0.1.0,<1.0.0", ] dev = [ "openchatbi[test,docs]", "black>=24.10.0,<25.0.0", "mypy>=1.13.0,<2.0.0", "ruff>=0.8.0,<1.0.0", "pre-commit>=4.0.1,<5.0.0", "bandit>=1.8.6,<2.0.0", "types-setuptools>=75.6.0.20241126", "twine>=6.0.0,<7.0.0", ] [tool.uv] managed = true dev-dependencies = [ "openchatbi[dev]", ] [build-system] requires = ["hatchling>=1.26.0"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["openchatbi"] [tool.hatch.build.targets.sdist] include = [ 
"/openchatbi", "/tests", "/README.md", "/LICENSE", ] [tool.hatch.metadata] allow-direct-references = true [tool.black] line-length = 120 target-version = ["py311"] include = '\.pyi?$' exclude = ''' /( \.git | \.mypy_cache | \.tox | \.venv | _build | buck-out | build | dist )/ ''' skip-string-normalization = false skip-magic-trailing-comma = false preview = false [tool.ruff] line-length = 120 target-version = "py311" exclude = [ ".git", ".mypy_cache", ".tox", ".venv", "_build", "buck-out", "build", "dist", ] [tool.ruff.lint] select = [ "E", # pycodestyle errors "W", # pycodestyle warnings "F", # pyflakes "I", # isort "C", # flake8-comprehensions "B", # flake8-bugbear "UP", # pyupgrade ] ignore = [ "E501", # line too long, handled by black "B008", # do not perform function calls in argument defaults "C901", # too complex ] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] "tests/**/*" = ["B011"] [tool.mypy] python_version = "3.11" strict = true warn_return_any = true warn_unused_configs = true disallow_untyped_defs = true disallow_incomplete_defs = true check_untyped_defs = true disallow_untyped_decorators = true no_implicit_optional = true warn_redundant_casts = true warn_unused_ignores = true warn_no_return = true warn_unreachable = true ignore_missing_imports = true show_error_codes = true [tool.pytest.ini_options] minversion = "7.0" addopts = [ "--strict-markers", "--strict-config", "--cov=openchatbi", "--cov-report=term-missing", "--cov-report=html", "--cov-report=xml", ] testpaths = ["tests"] markers = [ "unit: Unit tests", "integration: Integration tests", "slow: Slow tests that may take several seconds", "requires_db: Tests that require database connection", "requires_llm: Tests that require LLM service", "asyncio: Asynchronous tests" ] filterwarnings = [ "error", "ignore::UserWarning", "ignore::DeprecationWarning", ] [tool.coverage.run] source = ["openchatbi"] omit = [ "*/tests/*", "*/test_*.py", "setup.py", ] [tool.coverage.report] exclude_lines = [ "pragma: no cover", "def __repr__", "if self.debug:", "if settings.DEBUG", "raise AssertionError", "raise NotImplementedError", "if 0:", "if __name__ == .__main__.:", "class .*\\bProtocol\\):", "@(abc\\.)?abstractmethod", ] ================================================ FILE: run_streamlit_ui.py ================================================ #!/usr/bin/env python3 """ Launch script for the Streamlit-based OpenChatBI interface. 
Usage: python run_streamlit_ui.py This will start the Streamlit server on http://localhost:8501 """ import os import subprocess import sys def main(): """Launch the Streamlit UI""" # Change to the project directory project_dir = os.path.dirname(os.path.abspath(__file__)) os.chdir(project_dir) print("🚀 Starting OpenChatBI Streamlit UI...") print("📍 URL: http://localhost:8501") print("⏹️ Press Ctrl+C to stop the server") print("-" * 50) try: # Run streamlit with the new UI file subprocess.run( [ sys.executable, "-m", "streamlit", "run", "sample_ui/streamlit_ui.py", "--server.port=8501", "--server.address=localhost", ], check=True, ) except KeyboardInterrupt: print("\n👋 Stopping Streamlit server...") except subprocess.CalledProcessError as e: print(f"❌ Error starting Streamlit: {e}") print("\n💡 Make sure Streamlit is installed:") print(" pip install streamlit") except FileNotFoundError: print("❌ Python or Streamlit not found") print("\n💡 Make sure Python and Streamlit are installed:") print(" pip install streamlit") if __name__ == "__main__": main() ================================================ FILE: run_tests.py ================================================ #!/usr/bin/env python3 """Test runner script for OpenChatBI.""" import argparse import subprocess import sys def run_command(cmd, description): """Run a command and return the result.""" print(f"\n{'=' * 60}") print(f"Running: {description}") print(f"Command: {' '.join(cmd)}") print(f"{'=' * 60}") result = subprocess.run(cmd, capture_output=True, text=True) if result.stdout: print("STDOUT:") print(result.stdout) if result.stderr: print("STDERR:") print(result.stderr) if result.returncode != 0: print(f"❌ {description} failed with return code {result.returncode}") return False else: print(f"✅ {description} passed") return True def main(): """Main test runner function.""" parser = argparse.ArgumentParser(description="Run OpenChatBI tests") parser.add_argument("--unit", action="store_true", help="Run only unit tests") parser.add_argument("--integration", action="store_true", help="Run only integration tests") parser.add_argument("--coverage", action="store_true", help="Run with coverage report") parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") parser.add_argument("--fast", action="store_true", help="Skip slow tests") parser.add_argument("--lint", action="store_true", help="Run linting checks") parser.add_argument("--type-check", action="store_true", help="Run type checking") parser.add_argument("--all", action="store_true", help="Run all checks (tests, lint, type-check)") parser.add_argument("--file", help="Run specific test file") args = parser.parse_args() # Determine test command base_cmd = ["uv", "run", "pytest"] if args.verbose: base_cmd.append("-v") if args.coverage: base_cmd.extend(["--cov=openchatbi", "--cov-report=html", "--cov-report=term-missing"]) if args.unit: base_cmd.extend(["-m", "unit"]) elif args.integration: base_cmd.extend(["-m", "integration"]) elif args.fast: base_cmd.extend(["-m", "not slow"]) if args.file: base_cmd.append(f"tests/{args.file}") success = True # Run tests if not args.lint and not args.type_check: success &= run_command(base_cmd, "Unit Tests") # Run linting if requested if args.lint or args.all: lint_commands = [ (["uv", "run", "black", "--check", "."], "Black formatting check"), (["uv", "run", "isort", "--check-only", "."], "Import sorting check"), (["uv", "run", "ruff", "check", "."], "Ruff linting"), (["uv", "run", "bandit", "-r", "openchatbi/"], "Security
scanning"), ] for cmd, desc in lint_commands: success &= run_command(cmd, desc) # Run type checking if requested if args.type_check or args.all: success &= run_command(["uv", "run", "mypy", "openchatbi/"], "Type checking") # Run all tests if --all is specified if args.all: test_commands = [ (["uv", "run", "pytest", "-m", "unit", "-v"], "Unit Tests"), (["uv", "run", "pytest", "-m", "integration", "-v"], "Integration Tests"), (["uv", "run", "pytest", "--cov=openchatbi", "--cov-report=html"], "Coverage Report"), ] for cmd, desc in test_commands: success &= run_command(cmd, desc) # Print summary print(f"\n{'=' * 60}") if success: print("🎉 All checks passed!") sys.exit(0) else: print("❌ Some checks failed!") sys.exit(1) if __name__ == "__main__": main() ================================================ FILE: sample_api/async_api.py ================================================ """Async API for streaming chat responses from OpenChatBI.""" import asyncio from typing import Any from collections import defaultdict from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from langchain_core.messages import AIMessageChunk from pydantic import BaseModel from openchatbi import config from openchatbi.agent_graph import build_agent_graph_async from openchatbi.utils import get_report_download_response # Session state storage: session_id -> state sessions = defaultdict(dict) # Graphs keyed by provider name graphs: dict[str, Any] = {} graphs_lock = asyncio.Lock() async def get_or_build_graph(provider: str | None): """Get (or lazily build) a graph for the requested provider.""" key = provider or "__default__" if key in graphs: return graphs[key] async with graphs_lock: if key in graphs: return graphs[key] graphs[key] = await build_agent_graph_async(config.get().catalog_store, llm_provider=provider) return graphs[key] @asynccontextmanager async def lifespan(app: FastAPI): """Manage application lifespan events.""" # Startup: Initialize the async graph graphs["__default__"] = await build_agent_graph_async(config.get().catalog_store) yield # Shutdown: cleanup if needed graphs.clear() app = FastAPI(lifespan=lifespan) class UserRequest(BaseModel): """Request model for streaming chat.""" input: str user_id: str | None = "default" session_id: str | None = "default" provider: str | None = None @app.post("/chat/stream") async def chat_stream(req: UserRequest): """Stream chat responses from the agent graph.""" user_id = req.user_id or "default" session_id = req.session_id or "default" provider = req.provider # Create user-session ID just like in UI user_session_id = f"{user_id}-{session_id}" stream_input = {"messages": [("user", req.input)]} # Named graph_config so it does not shadow the imported config module graph_config = {"configurable": {"thread_id": user_session_id, "user_id": user_id}} try: graph = await get_or_build_graph(provider) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) from e async def event_generator(): """Generate streaming events from the graph.""" async for _namespace, event_type, event_value in graph.astream( stream_input, config=graph_config, stream_mode=["updates", "messages"], subgraphs=True ): text = "" if event_type == "messages": message_chunk = event_value[0] if isinstance(message_chunk, AIMessageChunk): text = message_chunk.content elif event_value.get("llm_node") and event_value["llm_node"].get("final_answer"): text = event_value["llm_node"]["final_answer"] if text: yield text return StreamingResponse(event_generator(), media_type="text/plain")
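# Example client for the streaming endpoint above: a minimal sketch, assuming the API
# runs on localhost:8000 as in the __main__ block below; the payload fields mirror
# UserRequest and the example question is hypothetical.
#
#   import requests
#
#   with requests.post(
#       "http://localhost:8000/chat/stream",
#       json={"input": "Show revenue by day for last week", "user_id": "u1", "session_id": "s1"},
#       stream=True,
#   ) as resp:
#       resp.raise_for_status()
#       # The endpoint streams plain text; print chunks as they arrive
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)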
@app.get("/user/{user_id}/memories") async def get_user_memories(user_id: str): """Get all memories for a specific user.""" try: # Import required modules for memory access import json from openchatbi.tool.memory import get_async_memory_store # Get the async memory store memory_store = await get_async_memory_store() memories = [] namespace = ("memories", user_id) try: # Search for all memories for this user search_results = memory_store.search(namespace) for item in search_results: # Parse the memory data try: content = json.loads(item.value.decode("utf-8")) if isinstance(item.value, bytes) else item.value except (json.JSONDecodeError, AttributeError): content = str(item.value) memory_data = { "key": item.key, "content": content, "namespace": str(namespace), "created_at": getattr(item, "created_at", "Unknown"), "updated_at": getattr(item, "updated_at", "Unknown"), } memories.append(memory_data) return {"user_id": user_id, "total_memories": len(memories), "memories": memories} except Exception as e: raise HTTPException(status_code=500, detail=f"Error retrieving memories: {str(e)}") from e except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to access memory store: {str(e)}") from e @app.get("/api/download/report/{filename}") async def download_report(filename: str): """Download a saved report file.""" return get_report_download_response(filename) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000) ================================================ FILE: sample_ui/async_graph_manager.py ================================================ """Common AsyncGraphManager for UIs.""" from typing import Any from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from openchatbi import config from openchatbi.agent_graph import build_agent_graph_async from openchatbi.tool.memory import cleanup_async_memory_store, get_async_memory_store, setup_async_memory_store from openchatbi.utils import log class AsyncGraphManager: """Manages the async graph and checkpointer lifecycle""" def __init__(self): self.checkpointer = None self.graph = None # Default graph (backwards compatible) self.graphs: dict[str, Any] = {} self._context_manager = None self._memory_store = None self._initialized = False async def initialize(self): """Initialize the graph and checkpointer""" if self._initialized: return try: # Setup async memory store await setup_async_memory_store() # Initialize checkpointer self._context_manager = AsyncSqliteSaver.from_conn_string("checkpoints.db") self.checkpointer = await self._context_manager.__aenter__() # Cache store for graph builds self._memory_store = await get_async_memory_store() self._initialized = True # Build default graph for backwards compatibility self.graph = await self.get_graph() log("Graph initialized successfully") except Exception as e: self._initialized = False log(f"Failed to initialize graph: {e}") raise async def get_graph(self, llm_provider: str | None = None): """Get or build a graph for the requested LLM provider.""" if not self._initialized: await self.initialize() key = llm_provider or "__default__" if key in self.graphs: return self.graphs[key] graph = await build_agent_graph_async( config.get().catalog_store, checkpointer=self.checkpointer, memory_store=self._memory_store, memory_tools=None, # Let graph builder create provider-appropriate tools llm_provider=llm_provider, ) self.graphs[key] = graph return graph async def cleanup(self): """Cleanup resources""" if self.checkpointer is not None and 
self._context_manager is not None: try: await self._context_manager.__aexit__(None, None, None) await cleanup_async_memory_store() log("Graph cleaned up successfully") except Exception as e: log(f"Error during cleanup: {e}") finally: self.checkpointer = None self.graph = None self.graphs = {} self._context_manager = None self._memory_store = None self._initialized = False ================================================ FILE: sample_ui/memory_ui.py ================================================ """Memory listing UI for OpenChatBI using FastAPI and Gradio.""" import json from typing import Any import gradio as gr import uvicorn from fastapi import FastAPI from sample_ui.style import custom_css def get_thread_memory_store() -> tuple[Any, Any]: """Create a thread-safe memory store connection. Returns a (store, connection) tuple.""" try: import pysqlite3 as sqlite3 except ImportError: import sqlite3 from langgraph.store.sqlite import SqliteStore from openchatbi import config conn = sqlite3.connect("memory.db", check_same_thread=False) conn.isolation_level = None # Use autocommit mode to avoid transaction conflicts store = SqliteStore(conn, index={"dims": 1536, "embed": config.get().embedding_model, "fields": ["text"]}) try: store.setup() except Exception: pass # Store might already be set up return store, conn def list_all_memories() -> list[dict[str, Any]]: """ Retrieve all memories from the memory store. Returns: List of memory items with their metadata """ try: memory_store, conn = get_thread_memory_store() memories = [] try: # Use search with partial namespace to find all memory items items = memory_store.search(("memories",), limit=1000) for item in items: memory_data = { "namespace": item.namespace, "key": item.key, "value": item.value, "created_at": getattr(item, "created_at", "Unknown"), "updated_at": getattr(item, "updated_at", "Unknown"), } memories.append(memory_data) except Exception as e: return [{"error": f"Failed to retrieve memories: {str(e)}"}] finally: conn.close() return memories except Exception as e: return [{"error": f"Failed to access memory store: {str(e)}"}] def format_memories_for_display(memories: list[dict[str, Any]]) -> str: """ Format memories for display in the Gradio interface. Args: memories: List of memory items Returns: Formatted string for display """ if not memories: return "No memories found." if len(memories) == 1 and "error" in memories[0]: return f"Error: {memories[0]['error']}" formatted = [] for i, memory in enumerate(memories, 1): if "error" in memory: formatted.append(f"**Error:** {memory['error']}") continue formatted.append(f"## Memory {i}") formatted.append(f"**Namespace:** {memory['namespace']}") formatted.append(f"**Key:** {memory['key']}") # Format the value nicely value = memory["value"] if isinstance(value, dict): try: value_str = json.dumps(value, indent=2) formatted.append(f"**Content:**\n```json\n{value_str}\n```") except (TypeError, ValueError): formatted.append(f"**Content:** {str(value)}") else: formatted.append(f"**Content:** {str(value)}") formatted.append(f"**Created:** {memory['created_at']}") formatted.append(f"**Updated:** {memory['updated_at']}") formatted.append("---") return "\n".join(formatted) def refresh_memories() -> str: """Refresh and return formatted memories.""" memories = list_all_memories() return format_memories_for_display(memories) def delete_memory_by_key(namespace_str: str, key: str) -> str: """ Delete a memory by namespace and key.
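The namespace string is parsed with ast.literal_eval, so it must be a valid Python tuple literal, e.g. "('memories', 'user1')".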
Args: namespace_str: String representation of namespace (e.g., "('memories', 'user1')") key: Memory key to delete Returns: Status message """ try: import ast memory_store, conn = get_thread_memory_store() try: # Parse namespace string back to tuple namespace = ast.literal_eval(namespace_str) # Delete the item memory_store.delete(namespace, key) return f"Successfully deleted memory: {key} from namespace {namespace}" finally: conn.close() except Exception as e: return f"Failed to delete memory: {str(e)}" # ---------- FastAPI ---------- app = FastAPI() # ---------- Gradio UI ---------- with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo: gr.Markdown("## 🧠 Memory Store Viewer") gr.Markdown("View and manage long-term memories stored in the OpenChatBI system.") with gr.Row(): with gr.Column(scale=3): memories_display = gr.Markdown(value=refresh_memories(), elem_id="memories-display") with gr.Column(scale=1): gr.Markdown("### Actions") refresh_btn = gr.Button("🔄 Refresh Memories", variant="primary") gr.Markdown("### Delete Memory") namespace_input = gr.Textbox( label="Namespace", placeholder="('memories', 'user_id')", info="Copy the exact namespace from the memory list", ) key_input = gr.Textbox( label="Key", placeholder="memory_key", info="Copy the exact key from the memory list" ) delete_btn = gr.Button("🗑️ Delete Memory", variant="stop") delete_status = gr.Textbox(label="Status", interactive=False) # Event handlers refresh_btn.click(fn=refresh_memories, outputs=[memories_display]) delete_btn.click(fn=delete_memory_by_key, inputs=[namespace_input, key_input], outputs=[delete_status]).then( fn=refresh_memories, outputs=[memories_display] ) # ---------- Application Startup ---------- # Mount Gradio app to FastAPI app = gr.mount_gradio_app(app, demo, path="/memory") if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8001) ================================================ FILE: sample_ui/plotly_utils.py ================================================ """Plotly utilities for generating charts from visualization DSL.""" from io import StringIO from typing import Any import pandas as pd import plotly.express as px import plotly.graph_objects as go def create_plotly_chart(data_csv: str, visualization_dsl: dict[str, Any]) -> go.Figure: """Create a plotly chart from CSV data and visualization DSL. 
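Supported chart types: line, bar, pie, scatter, histogram, box, and table; any other type produces an explanatory empty chart.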
Args: data_csv: CSV string containing the data visualization_dsl: Dictionary containing chart configuration Returns: Plotly Figure object """ if not data_csv or not visualization_dsl: return create_empty_chart("No data available") if "error" in visualization_dsl: return create_empty_chart(f"Visualization error: {visualization_dsl['error']}") try: # Parse CSV data df = pd.read_csv(StringIO(data_csv)) if df.empty: return create_empty_chart("No data to visualize") chart_type = visualization_dsl.get("chart_type", "table") config = visualization_dsl.get("config", {}) layout = visualization_dsl.get("layout", {}) # Create chart based on type if chart_type == "line": return create_line_chart(df, config, layout) elif chart_type == "bar": return create_bar_chart(df, config, layout) elif chart_type == "pie": return create_pie_chart(df, config, layout) elif chart_type == "scatter": return create_scatter_chart(df, config, layout) elif chart_type == "histogram": return create_histogram_chart(df, config, layout) elif chart_type == "box": return create_box_chart(df, config, layout) elif chart_type == "table": return create_table_chart(df, config, layout) else: return create_empty_chart(f"Unsupported chart type: {chart_type}") except Exception as e: return create_empty_chart(f"Chart generation error: {str(e)}") def create_line_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure: """Create a line chart.""" x_col = config.get("x") y_col = config.get("y") color_col = config.get("color") if not x_col or x_col not in df.columns: return create_empty_chart("Missing required x column for line chart") # Handle multiple y columns case if isinstance(y_col, list): # Multiple metrics - need to melt the data if not all(col in df.columns for col in y_col): return create_empty_chart("Some y columns missing from data") # Melt the dataframe to long format for multiple series melted_df = df.melt(id_vars=[x_col], value_vars=y_col, var_name="metric", value_name="value") fig = px.line(melted_df, x=x_col, y="value", color="metric") else: # Single y column if not y_col or y_col not in df.columns: return create_empty_chart("Missing required y column for line chart") # Check if color column exists and is valid if color_col and color_col in df.columns: fig = px.line(df, x=x_col, y=y_col, color=color_col) else: fig = px.line(df, x=x_col, y=y_col) fig.update_layout(**layout) return fig def create_bar_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure: """Create a bar chart.""" x_col = config.get("x") y_col = config.get("y") if not x_col or x_col not in df.columns: return create_empty_chart("Missing required x column for bar chart") # Handle multiple y columns case if isinstance(y_col, list): # Multiple metrics - need to melt the data if not all(col in df.columns for col in y_col): return create_empty_chart("Some y columns missing from data") # Melt the dataframe to long format for multiple series melted_df = df.melt(id_vars=[x_col], value_vars=y_col, var_name="metric", value_name="value") fig = px.bar(melted_df, x=x_col, y="value", color="metric") else: # Single y column if not y_col or y_col not in df.columns: return create_empty_chart("Missing required y column for bar chart") fig = px.bar(df, x=x_col, y=y_col) fig.update_layout(**layout) return fig def create_pie_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure: """Create a pie chart.""" labels_col = config.get("labels") values_col = config.get("values") if not labels_col or not 
values_col or labels_col not in df.columns or values_col not in df.columns: return create_empty_chart("Missing required columns for pie chart") fig = px.pie(df, names=labels_col, values=values_col) fig.update_layout(**layout) return fig def create_scatter_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure: """Create a scatter plot.""" x_col = config.get("x") y_col = config.get("y") if not x_col or not y_col or x_col not in df.columns or y_col not in df.columns: return create_empty_chart("Missing required columns for scatter plot") fig = px.scatter(df, x=x_col, y=y_col) fig.update_layout(**layout) return fig def create_histogram_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure: """Create a histogram.""" x_col = config.get("x") nbins = config.get("nbins", 20) if not x_col or x_col not in df.columns: return create_empty_chart("Missing required column for histogram") fig = px.histogram(df, x=x_col, nbins=nbins) fig.update_layout(**layout) return fig def create_box_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure: """Create a box plot.""" y_col = config.get("y") x_col = config.get("x") if not y_col or y_col not in df.columns: return create_empty_chart("Missing required column for box plot") if x_col and x_col in df.columns: fig = px.box(df, x=x_col, y=y_col) else: fig = px.box(df, y=y_col) fig.update_layout(**layout) return fig def create_table_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure: """Create a table display.""" columns = config.get("columns", list(df.columns)) # Limit to first 100 rows for display display_df = df.head(100) fig = go.Figure( data=[ go.Table( header=dict(values=columns, fill_color="lightblue", align="left"), cells=dict( values=[display_df[col] for col in columns if col in display_df.columns], fill_color="white", align="left", ), ) ] ) fig.update_layout(**layout) return fig def create_empty_chart(message: str) -> go.Figure: """Create an empty chart with a message.""" fig = go.Figure() fig.add_annotation( text=message, xref="paper", yref="paper", x=0.5, y=0.5, xanchor="center", yanchor="middle", showarrow=False, font=dict(size=16), ) fig.update_layout( title="Chart Generation Issue", xaxis=dict(showgrid=False, showticklabels=False, zeroline=False), yaxis=dict(showgrid=False, showticklabels=False, zeroline=False), ) return fig def visualization_dsl_to_gradio_plot(data_csv: str, visualization_dsl: dict[str, Any]) -> tuple[go.Figure, str]: """Convert visualization DSL to Gradio-compatible plotly figure. Args: data_csv: CSV string containing the data visualization_dsl: Dictionary containing chart configuration Returns: Tuple of (plotly figure, description string) """ fig = create_plotly_chart(data_csv, visualization_dsl) if visualization_dsl: chart_type = visualization_dsl.get("chart_type", "unknown") layout_title = visualization_dsl.get("layout", {}).get("title", f"{chart_type.title()} Chart") description = f"Generated {chart_type} visualization: {layout_title}" else: description = "Data table view" return fig, description def create_inline_chart_markdown(data_csv: str, visualization_dsl: dict[str, Any]) -> str: """Create a simplified markdown representation of the chart for inline display. This creates a text-based summary with a clickable link to show the interactive chart. 
""" if not data_csv or not visualization_dsl: return "📊 *No visualization data available*" if "error" in visualization_dsl: return f"⚠️ *Visualization error: {visualization_dsl['error']}*" try: from io import StringIO import pandas as pd df = pd.read_csv(StringIO(data_csv)) chart_type = visualization_dsl.get("chart_type", "table") layout = visualization_dsl.get("layout", {}) title = layout.get("title", f"{chart_type.title()} Chart") # Create a text summary with key data points and view instruction summary_lines = [ f"📊 **{title}**", "", f"*Chart Type: {chart_type.title()}* | *Data Points: {len(df)} rows, {len(df.columns)} columns*", "", "✨ **Interactive chart will appear automatically in the chart panel →**", "", ] # Add a small data sample if len(df) > 0: summary_lines.append("**Sample Data:**") summary_lines.append("```") # Show first few rows in a clean format sample_df = df.head(3) summary_lines.append(sample_df.to_string(index=False)) if len(df) > 3: summary_lines.append(f"... and {len(df) - 3} more rows") summary_lines.append("```\n") return "\n".join(summary_lines) except Exception as e: return f"⚠️ *Chart generation error: {str(e)}*" ================================================ FILE: sample_ui/simple_ui.py ================================================ """Simple web UI for OpenChatBI using FastAPI and Gradio.""" from collections import defaultdict import gradio as gr import uvicorn from fastapi import FastAPI from langgraph.checkpoint.sqlite import SqliteSaver from langgraph.types import Command from openchatbi import config from openchatbi.agent_graph import build_agent_graph_sync from openchatbi.tool.memory import get_sync_memory_store from openchatbi.utils import get_report_download_response, log from sample_ui.style import custom_css # Session state storage: session_id -> state session_interrupt = defaultdict(bool) # Use SqliteSaver for persistence sqlite_checkpointer_cm = SqliteSaver.from_conn_string("checkpoints.db") sqlite_checkpointer = sqlite_checkpointer_cm.__enter__() graph = build_agent_graph_sync( config.get().catalog_store, checkpointer=sqlite_checkpointer, memory_store=get_sync_memory_store() ) # ---------- FastAPI ---------- app = FastAPI() # ---------- Gradio UI ---------- def chat_fn(message: str, history: list[tuple[str, str]], user_id: str = "default", session_id: str = "default") -> str: """Chat function for Gradio interface.""" user_session_id = f"{user_id}-{session_id}" config = {"configurable": {"thread_id": user_session_id, "user_id": user_id}} if session_interrupt[user_session_id]: inputs = Command(resume=message) else: inputs = {"messages": [{"role": "user", "content": message}]} # Use synchronous call result = graph.invoke(inputs, config=config) state = graph.get_state(config) if state.interrupts: log(f"state.interrupts: {state.interrupts}") output_content = state.interrupts[0].value.get("text") session_interrupt[user_session_id] = True else: session_interrupt[user_session_id] = False output_content = result["messages"][-1].content return output_content # Create Gradio interface with custom CSS and theme with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo: gr.Markdown("## 💬 OpenChatBI Agent Chatbot") with gr.Row(): with gr.Column(scale=4): chatbot = gr.Chatbot(elem_id="chatbot", label="", bubble_full_width=False, height=600) msg = gr.Textbox(placeholder="Type a message and press Enter", label="Input", elem_id="msg") with gr.Column(scale=1): user_box = gr.Textbox(value="default", label="User ID", interactive=True) session_box = 
gr.Textbox(value="default", label="Session ID", interactive=True) gr.Markdown( """ **Instructions** - Type a message and press Enter to send - User ID is used for memory isolation between users - Session ID can be used to differentiate between conversations """, elem_id="description", ) def respond( message: str, chat_history: list[tuple[str, str]], user_id: str, session_id: str ) -> tuple[str, list[tuple[str, str]]]: """Handle response in Gradio chat interface.""" response = chat_fn(message, chat_history, user_id, session_id) chat_history.append((message, response)) return "", chat_history msg.submit(respond, [msg, chatbot, user_box, session_box], [msg, chatbot]) # ---------- API Endpoints ---------- @app.get("/api/download/report/{filename}") def download_report(filename: str): """Download a saved report file.""" return get_report_download_response(filename) # ---------- Application Startup ---------- # Mount Gradio app to FastAPI app = gr.mount_gradio_app(app, demo, path="/ui") if __name__ == "__main__": try: uvicorn.run(app, host="0.0.0.0", port=8000) finally: # Cleanup checkpointer sqlite_checkpointer_cm.__exit__(None, None, None) ================================================ FILE: sample_ui/streaming_ui.py ================================================ """Gradio-based Streaming UI for OpenChatBI with real-time chat interface.""" import asyncio import sys from collections import defaultdict from contextlib import asynccontextmanager import gradio as gr try: import pysqlite3 as sqlite3 except ImportError: # pragma: no cover import sqlite3 from fastapi import FastAPI from langchain_core.messages import AIMessage sys.modules["sqlite3"] = sqlite3 from langgraph.types import Command from openchatbi.utils import get_report_download_response, get_text_from_message_chunk, log from sample_ui.async_graph_manager import AsyncGraphManager from sample_ui.plotly_utils import create_inline_chart_markdown, visualization_dsl_to_gradio_plot from sample_ui.style import custom_css # Session state storage: user_session_id -> state session_interrupt = defaultdict(bool) # Global event loop for async operations (similar to Streamlit approach) global_event_loop = None # Global graph manager (similar to Streamlit approach) graph_manager = AsyncGraphManager() @asynccontextmanager async def lifespan(app: FastAPI): """Async context manager for FastAPI lifespan""" # Startup await graph_manager.initialize() yield # Shutdown await graph_manager.cleanup() # ---------- FastAPI ---------- app = FastAPI(lifespan=lifespan) # ---------- Gradio UI functions ---------- def get_or_create_event_loop(): """Get or create an independent event loop""" global global_event_loop if global_event_loop is None or global_event_loop.is_closed(): global_event_loop = asyncio.new_event_loop() asyncio.set_event_loop(global_event_loop) return global_event_loop async def _async_respond_helper(message, chat_history, user_id, session_id): """ Helper async function that contains the actual async logic. This will be run in an independent event loop. Collects all responses and returns them as a list. 
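Each collected item is a (message_box_value, chat_history, plot_figure, chart_panel_update) tuple, matching the outputs yielded by respond().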
""" responses = [] # Collect all yield values user_session_id = f"{user_id}-{session_id}" full_response = "" plot_figure = None chart_panel_update = gr.update() if session_interrupt[user_session_id]: stream_input = Command(resume=message) else: stream_input = {"messages": [{"role": "user", "content": message}]} config = {"configurable": {"thread_id": user_session_id, "user_id": user_id}} # Ensure graph is available if not graph_manager._initialized: try: await graph_manager.initialize() except Exception as e: log(f"Failed to initialize graph: {e}") chat_history[-1] = (chat_history[-1][0], f"Error: Failed to initialize system - {str(e)}") responses.append(("", chat_history, plot_figure, chart_panel_update)) return responses data_csv = None # Asynchronously iterate through LangGraph stream async for _namespace, event_type, event_value in graph_manager.graph.astream( stream_input, config=config, stream_mode=["updates", "messages"], subgraphs=True, debug=True ): token = "" if event_type == "messages": chunk = event_value[0] metadata = event_value[1] # Keep llm node messages only to avoid duplicates if metadata["langgraph_node"] != "llm_node" or not metadata.get("streaming_tokens", False): continue token = get_text_from_message_chunk(chunk) else: # Process intermediate graph node updates if event_value.get("llm_node"): message_obj = event_value["llm_node"].get("messages")[0] if message_obj and isinstance(message_obj, AIMessage) and message_obj.tool_calls: token = f"\nUse tool: {', '.join(tool['name'] for tool in message_obj.tool_calls)}\n" else: token = "\n" elif event_value.get("information_extraction"): message_obj = event_value["information_extraction"].get("messages")[0] if message_obj.tool_calls: token = f"Use tool: {message_obj.tool_calls[0]['name']}\n" else: token = f"Rewrite question: {event_value['information_extraction'].get('rewrite_question')}\n" elif event_value.get("table_selection"): token = f"Selected tables: {event_value['table_selection'].get('tables')}\n" elif event_value.get("generate_sql"): token = f"SQL: \n ```sql \n{event_value['generate_sql'].get('sql')}\n```\n" elif event_value.get("execute_sql"): token = "Running SQL...\n" data_csv = event_value["execute_sql"].get("data") elif event_value.get("regenerate_sql"): token = f"SQL: \n ```sql \n{event_value['regenerate_sql'].get('sql')}\n```\n" elif event_value.get("generate_visualization"): visualization_dsl = event_value["generate_visualization"].get("visualization_dsl") # Check for visualization data in the final state and embed in response if visualization_dsl and "error" not in visualization_dsl and data_csv: try: plot_figure, plot_description = visualization_dsl_to_gradio_plot(data_csv, visualization_dsl) # Add markdown representation to the chat chart_markdown = create_inline_chart_markdown(data_csv, visualization_dsl) full_response += f"\n\n{chart_markdown}" chat_history[-1] = (chat_history[-1][0], full_response) # Auto-show chart panel when plot is generated chart_panel_update = gr.update(visible=True) responses.append(("", chat_history, plot_figure, chart_panel_update)) except Exception as e: log(f"Visualization generation error: {str(e)}") full_response += f"\n\n⚠️ Visualization error: {str(e)}" chat_history[-1] = (chat_history[-1][0], full_response) responses.append(("", chat_history, plot_figure, chart_panel_update)) # Update chat history with new tokens and collect response if token: full_response += token chat_history[-1] = (chat_history[-1][0], full_response) responses.append(("", chat_history, plot_figure, 
chart_panel_update)) # Get final state and check for visualization data state = await graph_manager.graph.aget_state(config) final_state_values = state.values if state.interrupts: log(f"state.interrupts: {state.interrupts}") output_content = state.interrupts[0].value.get("text") if "buttons" in state.interrupts[0].value: output_content += str(state.interrupts[0].value.get("buttons")) full_response += output_content chat_history[-1] = (chat_history[-1][0], full_response) session_interrupt[user_session_id] = True responses.append(("", chat_history, plot_figure, chart_panel_update)) else: session_interrupt[user_session_id] = False return responses def respond(message, chat_history, user_id, session_id="default"): """ Synchronous callback for Gradio Chatbot with streaming updates. This function processes user input and streams responses from the LangGraph agent. Returns: message_input, chat_history, plot_figure, chart_panel_visibility """ # Add a placeholder in chat history chat_history.append((message, "")) plot_figure = None chart_panel_update = gr.update() yield "", chat_history, plot_figure, chart_panel_update # Stream updates to UI # Get or create independent event loop loop = get_or_create_event_loop() # Run the async helper in the independent event loop try: responses = loop.run_until_complete(_async_respond_helper(message, chat_history, user_id, session_id)) # Yield all collected responses for response in responses: yield response except Exception as e: log(f"Error in respond: {e}") import traceback traceback.print_exc() chat_history[-1] = (chat_history[-1][0], f"Error: {str(e)}") yield "", chat_history, plot_figure, chart_panel_update # ---------- Memory Management Functions ---------- def list_user_memories(user_id: str) -> str: """List all memories for a specific user.""" try: import json try: import pysqlite3 as sqlite3 except ImportError: import sqlite3 from langgraph.store.sqlite import SqliteStore from openchatbi import config # Create a new connection in this thread to avoid SQLite threading issues conn = sqlite3.connect("memory.db", check_same_thread=False) conn.isolation_level = None # Use autocommit mode to avoid transaction conflicts thread_memory_store = SqliteStore( conn, index={"dims": 1536, "embed": config.get().embedding_model, "fields": ["text"]} ) try: thread_memory_store.setup() except Exception: pass # Store might already be set up memories = [] namespace = ("memories", user_id) try: # Use search with namespace to find all items for this user items = thread_memory_store.search(namespace, limit=1000) for item in items: memory_data = { "key": item.key, "value": item.value, "created_at": getattr(item, "created_at", "Unknown"), "updated_at": getattr(item, "updated_at", "Unknown"), } memories.append(memory_data) except Exception as e: return f"No memories found for user {user_id} or error: {str(e)}" finally: conn.close() if not memories: return f"No memories found for user {user_id}" formatted = [f"## Memories for User: {user_id}\n"] for i, memory in enumerate(memories, 1): formatted.append(f"### Memory {i}") formatted.append(f"**Key:** {memory['key']}") value = memory["value"] if isinstance(value, dict): try: value_str = json.dumps(value, indent=2) formatted.append(f"**Content:**\n```json\n{value_str}\n```") except (TypeError, ValueError): # json.dumps raises TypeError for non-serializable values formatted.append(f"**Content:** {str(value)}") else: formatted.append(f"**Content:** {str(value)}") formatted.append(f"**Created:** {memory['created_at']}") formatted.append(f"**Updated:** {memory['updated_at']}") formatted.append("---") return
"\n".join(formatted) except Exception as e: return f"Error accessing memories: {str(e)}" # ---------- Gradio UI Blocks ---------- # Create Gradio interface with custom CSS and theme with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo: gr.Markdown("## 💬 OpenChatBI Agent Chatbot with Streaming & On-Demand Visualization") with gr.Tabs(): with gr.TabItem("💬 Chat"): with gr.Row(): with gr.Column(scale=4): chatbot = gr.Chatbot( elem_id="chatbot", label="Chat", bubble_full_width=False, height=500, show_label=False, sanitize_html=False, render_markdown=True, ) msg = gr.Textbox(placeholder="Type a message and press Enter", label="Input", elem_id="msg") with gr.Column(scale=2, visible=False) as chart_panel: with gr.Row(): with gr.Column(scale=3): gr.Markdown("### 📊 Interactive Chart") with gr.Column(scale=1): hide_chart_btn = gr.Button("✖️ Hide", elem_id="hide-chart-btn", size="sm") plot = gr.Plot(label="", visible=True, show_label=False) with gr.Column(scale=1): user_box = gr.Textbox(value="default", label="User ID", interactive=True) session_box = gr.Textbox(value="default", label="Session ID", interactive=True) show_chart_btn = gr.Button("📊 Show Chart Panel", variant="secondary") gr.Markdown( """ **Instructions** - Type a data question and press Enter - Supports streaming output (real-time display) - Click chart links in chat to view interactive charts - Use 'Show Chart Panel' to make panel visible - Session ID can be used to differentiate between conversations """, elem_id="description", ) def show_chart_panel(): """Show the chart panel.""" return gr.update(visible=True) def hide_chart_panel(): """Hide the chart panel.""" return gr.update(visible=False) # Register async submit handler for message input with plot output msg.submit(respond, [msg, chatbot, user_box, session_box], [msg, chatbot, plot, chart_panel]) show_chart_btn.click(show_chart_panel, outputs=[chart_panel]) hide_chart_btn.click(hide_chart_panel, outputs=[chart_panel]) with gr.TabItem("🧠 Memory Store"): gr.Markdown("### Long-term Memory Viewer") gr.Markdown("View memories stored for each user in the system.") with gr.Row(): with gr.Column(scale=3): memory_display = gr.Markdown( value="Enter a User ID and click 'Load Memories' to view stored memories.", elem_id="memory-display", ) with gr.Column(scale=1): memory_user_input = gr.Textbox(label="User ID", placeholder="default", value="default") load_memories_btn = gr.Button("🔍 Load Memories", variant="primary") # Event handler for loading memories load_memories_btn.click(fn=list_user_memories, inputs=[memory_user_input], outputs=[memory_display]) # ---------- API Endpoints ---------- @app.get("/api/download/report/{filename}") async def download_report(filename: str): """Download a saved report file.""" return get_report_download_response(filename) # ---------- Application Startup ---------- # Mount Gradio app to FastAPI app = gr.mount_gradio_app(app, demo, path="/ui") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000) ================================================ FILE: sample_ui/streamlit_ui.py ================================================ """Streamlit-based Streaming UI for OpenChatBI with collapsible thinking sections.""" import asyncio import sys import traceback import uuid from pathlib import Path import plotly.graph_objects as go try: import pysqlite3 as sqlite3 except ImportError: # pragma: no cover import sqlite3 import streamlit as st from langchain_core.messages import AIMessage sys.modules["sqlite3"] = sqlite3 from 
langgraph.types import Command from openchatbi import config as openchatbi_config from openchatbi.llm.llm import list_llm_providers from openchatbi.utils import get_text_from_message_chunk, log from sample_ui.plotly_utils import visualization_dsl_to_gradio_plot from sample_ui.async_graph_manager import AsyncGraphManager # Configuration st.set_page_config(page_title="OpenChatBI - Streamlit Interface", page_icon="💬", layout="wide") # Initialize session state if "messages" not in st.session_state: st.session_state.messages = [] if "graph_manager" not in st.session_state: st.session_state.graph_manager = AsyncGraphManager() if "session_interrupts" not in st.session_state: st.session_state.session_interrupts = {} if "event_loop" not in st.session_state: st.session_state.event_loop = None async def process_user_message_stream( message: str, user_id: str, session_id: str, llm_provider: str | None, thinking_container, response_container ): """ Process user message through the OpenChatBI graph with real-time updates Updates the thinking_container and response_container as processing happens """ thinking_steps = [] final_response = "" plot_figure = None # Initialize graph if needed if not st.session_state.graph_manager._initialized: await st.session_state.graph_manager.initialize() graph = await st.session_state.graph_manager.get_graph(llm_provider) user_session_id = f"{user_id}-{session_id}" # Check for interrupts if st.session_state.session_interrupts.get(user_session_id, False): stream_input = Command(resume=message) else: stream_input = {"messages": [{"role": "user", "content": message}]} config = {"configurable": {"thread_id": user_session_id, "user_id": user_id}} data_csv = None # Use empty container for real-time updates thinking_placeholder = thinking_container.empty() # Build content chronologically - all events in time order base_content = "🔄 **Processing...**\n\n" chronological_content = "" # All content in time order def update_display(): full_content = base_content + chronological_content thinking_placeholder.markdown(full_content) # Initial display update_display() # Stream through the graph async for _namespace, event_type, event_value in graph.astream( stream_input, config=config, stream_mode=["updates", "messages"], subgraphs=True, debug=True ): if event_type == "messages": chunk = event_value[0] metadata = event_value[1] # Keep llm node messages only to avoid duplicates if metadata["langgraph_node"] != "llm_node" or not metadata.get("streaming_tokens", False): continue token = get_text_from_message_chunk(chunk) if token: final_response += token # Add to thinking content during processing if len(final_response) == len(token): chronological_content += "\n**🤖 AI Response:** " chronological_content += token update_display() else: # Process tool calls and intermediate steps step_description = "" if event_value.get("llm_node"): message_obj = event_value["llm_node"].get("messages")[0] if message_obj and isinstance(message_obj, AIMessage) and message_obj.tool_calls: step_description = f"🛠️ Using tools: {', '.join(tool['name'] for tool in message_obj.tool_calls)}" elif event_value.get("information_extraction"): message_obj = event_value["information_extraction"].get("messages")[0] if message_obj and message_obj.tool_calls: step_description = f"🛠️ Using tool: {message_obj.tool_calls[0]['name']}" else: rewrite_q = event_value["information_extraction"].get("rewrite_question") if rewrite_q: step_description = f"📝 Rewriting question: {rewrite_q}" elif event_value.get("table_selection"): tables = 
event_value["table_selection"].get("tables") if tables: step_description = f"🗂️ Selected tables: {tables}" elif event_value.get("generate_sql"): sql = event_value["generate_sql"].get("sql") if sql: step_description = f"💾 Generated SQL:\n```sql\n{sql}\n```" elif event_value.get("execute_sql"): step_description = "⚡ Executing SQL query..." data_csv = event_value["execute_sql"].get("data") elif event_value.get("regenerate_sql"): sql = event_value["regenerate_sql"].get("sql") if sql: step_description = f"🔄 Regenerated SQL:\n```sql\n{sql}\n```" elif event_value.get("generate_visualization"): visualization_dsl = event_value["generate_visualization"].get("visualization_dsl") if visualization_dsl and "error" not in visualization_dsl and data_csv: try: plot_figure, plot_description = visualization_dsl_to_gradio_plot(data_csv, visualization_dsl) step_description = f"📊 Generated visualization: {plot_description}" except Exception as e: step_description = f"⚠️ Visualization error: {str(e)}" if step_description: thinking_steps.append(step_description) # Append new step to chronological content in time order step_number = len(thinking_steps) # Ensure proper spacing before new step if chronological_content and not chronological_content.endswith("\n\n"): chronological_content += "\n\n" chronological_content += f"**Step {step_number}:** {step_description}\n\n" update_display() # Check for interrupts in final state state = await graph.aget_state(config) if state.interrupts: log(f"State interrupts: {state.interrupts}") output_content = state.interrupts[0].value.get("text", "") if "buttons" in state.interrupts[0].value: output_content += str(state.interrupts[0].value.get("buttons")) final_response += output_content # Append interrupt content to chronological content chronological_content += output_content update_display() st.session_state.session_interrupts[user_session_id] = True else: st.session_state.session_interrupts[user_session_id] = False # Final update - add completion message to chronological content # Add some spacing if the last content didn't end with newlines if not chronological_content.endswith("\n\n"): chronological_content += "\n\n" chronological_content += "✅ **Analysis complete!**" update_display() # Extract final answer (last part without tool calls) and display outside thinking if final_response: # Find the last occurrence of tool usage to separate final answer lines = final_response.split("\n") final_answer_lines = [] collecting_final = False for line in reversed(lines): if "Use tool:" in line or "Using tools:" in line or "Using tool:" in line: break final_answer_lines.append(line) collecting_final = True if collecting_final and final_answer_lines: # Reverse back to correct order final_answer_lines.reverse() final_answer_text = "\n".join(final_answer_lines).strip() if final_answer_text: with response_container: processed_final_answer_text = process_download_links(final_answer_text) render_content_with_downloads(processed_final_answer_text) # Final update to response container - only show plot if available (text response is in thinking container) with response_container: if plot_figure: st.plotly_chart(plot_figure, use_container_width=True, key=str(uuid.uuid4())) # Extract final answer for separate storage final_answer_text = "" if final_response: lines = final_response.split("\n") final_answer_lines = [] collecting_final = False for line in reversed(lines): if "Use tool:" in line or "Using tools:" in line or "Using tool:" in line: break final_answer_lines.append(line) collecting_final = 
True if collecting_final and final_answer_lines: final_answer_lines.reverse() final_answer_text = "\n".join(final_answer_lines).strip() return final_response, plot_figure, thinking_steps, chronological_content, final_answer_text def get_available_reports() -> list[str]: """Get list of available report files for download.""" try: # Import config here to avoid circular imports from openchatbi import config report_dir = Path(config.get().report_directory) if not report_dir.exists(): return [] # Get all files in the report directory report_files = [] for file_path in report_dir.iterdir(): if file_path.is_file(): report_files.append(file_path.name) return sorted(report_files) except Exception as e: st.error(f"Error accessing reports: {str(e)}") return [] def get_report_file_content(filename: str) -> tuple[bytes | None, str | None]: """Get report file content for download. Returns: tuple: (file_content_bytes, mime_type) or (None, None) if error """ try: # Import config here to avoid circular imports from openchatbi import config report_dir = Path(config.get().report_directory) file_path = report_dir / filename # Security check - ensure file is within report directory if not file_path.exists() or not file_path.is_file(): st.error(f"Report file not found: {filename}") return None, None try: file_path.resolve().relative_to(report_dir.resolve()) except ValueError: st.error("Access denied to file") return None, None # Determine MIME type mime_type_map = { ".md": "text/markdown", ".csv": "text/csv", ".txt": "text/plain", ".json": "application/json", ".html": "text/html", ".xml": "application/xml", } file_extension = file_path.suffix.lower() mime_type = mime_type_map.get(file_extension, "application/octet-stream") # Read file content with open(file_path, "rb") as f: content = f.read() return content, mime_type except Exception as e: st.error(f"Error reading report file: {str(e)}") return None, None def process_download_links(content: str) -> str: """Process download links in content and replace them with Streamlit-compatible ones. 
Args: content: Message content that may contain download links Returns: str: Content with download links replaced """ import re if not content: return content # Pattern to match both full URLs and path-only download links # Matches: http://localhost:8501/api/download/report/filename.ext or /api/download/report/filename.ext download_pattern = r"(?:https?://[^/\s]+)?/api/download/report/([^)\s\]<>]+)" def replace_link(match): filename = match.group(1) # Return a placeholder that we'll replace with actual download button return f"[DOWNLOAD_LINK:{filename}]" processed_content = re.sub(download_pattern, replace_link, content) # Debug log to see if processing worked if processed_content != content: st.write(f"🔍 Debug: Processed download links - found {content.count('/api/download/report/')} links") return processed_content def render_content_with_downloads(content: str) -> None: """Render content and replace download placeholders with actual download buttons.""" import re # Split content by download placeholders (exclude ']' so adjacent placeholders don't merge) download_pattern = r"\[DOWNLOAD_LINK:([^\]]+)\]" parts = re.split(download_pattern, content) for i, part in enumerate(parts): if i % 2 == 0: # Regular content if part.strip(): st.markdown(part) else: # Download link filename filename = part file_content, mime_type = get_report_file_content(filename) if file_content is not None: st.download_button( label=f"📥 Download {filename}", data=file_content, file_name=filename, mime=mime_type, key=f"inline_download_{filename}_{hash(content)}", ) else: st.error(f"❌ Could not load report: {filename}") def display_message_with_thinking( role: str, content: str, thinking_steps: list[str] | None = None, plot_figure: go.Figure | None = None ): """Display a message with collapsible thinking section""" with st.chat_message(role): if thinking_steps and role == "assistant": # Create thinking section with all content inside with st.expander("💭 AI Thinking Process", expanded=False): for i, step in enumerate(thinking_steps, 1): st.markdown(f"**Step {i}:** {step}") if content: st.markdown("**🤖 AI Response:**") render_content_with_downloads(content) st.success("✅ Analysis complete") # For non-assistant messages, display content normally elif content and role != "assistant": render_content_with_downloads(content) # Display plot if available (outside thinking container) if plot_figure: st.plotly_chart(plot_figure, use_container_width=True, key=str(uuid.uuid4())) # Main UI st.title("💬 OpenChatBI - Streamlit UI") st.markdown("*AI-powered Business Intelligence Chat with Thinking*") # Sidebar for configuration with st.sidebar: st.header("⚙️ Configuration") user_id = st.text_input("User ID", value="default", help="Unique identifier for the user session") session_id = st.text_input("Session ID", value="default", help="Session identifier for conversation continuity") # Optional multi-provider support llm_provider = None provider_options = list_llm_providers() if provider_options: try: default_provider = getattr(openchatbi_config.get(), "llm_provider", None) except Exception: default_provider = None default_index = provider_options.index(default_provider) if default_provider in provider_options else 0 llm_provider = st.selectbox( "LLM Provider", options=provider_options, index=default_index, help="Select which configured LLM provider to use for this session", ) st.markdown("---") st.markdown( """ **💡 How to use:** - Type your business questions - Watch the AI thinking process in collapsible sections - View generated charts and analyses - Use different session IDs for separate
conversations """ ) if st.button("🗑️ Clear Chat History"): st.session_state.messages = [] st.rerun() st.markdown("---") st.markdown("### 📁 Report Downloads") # Get available reports available_reports = get_available_reports() if available_reports: selected_report = st.selectbox( "Select a report to download:", options=[""] + available_reports, help="Choose a report file to download" ) if selected_report and st.button("📥 Download Report"): file_content, mime_type = get_report_file_content(selected_report) if file_content is not None: st.download_button( label=f"💾 Save {selected_report}", data=file_content, file_name=selected_report, mime=mime_type, key=f"download_{selected_report}", ) st.success(f"✅ {selected_report} is ready for download!") else: st.info("No reports available for download.") # Display chat history for msg in st.session_state.messages: if msg["type"] == "chronological_message": # Display chronological content in expander - all collapsed after completion with st.chat_message(msg["role"]): with st.expander("💭 AI Thinking Process", expanded=False): st.markdown(msg["chronological_content"]) # Extract and display final answer text outside thinking if msg.get("final_answer"): render_content_with_downloads(msg["final_answer"]) # Display plot if available (outside thinking container) if msg.get("plot_figure"): st.plotly_chart(msg["plot_figure"], use_container_width=True, key=str(uuid.uuid4())) elif msg["type"] == "thinking_message": display_message_with_thinking( msg["role"], msg["content"], msg.get("thinking_steps", []), msg.get("plot_figure") ) else: with st.chat_message(msg["role"]): if msg["type"] == "text": render_content_with_downloads(msg["content"]) elif msg["type"] == "plot" and msg.get("plot_figure"): st.plotly_chart(msg["plot_figure"], use_container_width=True, key=str(uuid.uuid4())) # Chat input if prompt := st.chat_input("Ask me anything about your data..."): # Add user message st.session_state.messages.append({"role": "user", "type": "text", "content": prompt}) # Display user message immediately with st.chat_message("user"): st.markdown(prompt) # Process assistant response with real-time streaming with st.chat_message("assistant"): # Create thinking and response containers thinking_expander = st.expander("💭 AI Thinking Process...", expanded=True) thinking_container = thinking_expander.container() response_container = st.container() # Process the message asynchronously with real-time updates try: # Reuse the same event loop to avoid binding issues if st.session_state.event_loop is None or st.session_state.event_loop.is_closed(): st.session_state.event_loop = asyncio.new_event_loop() asyncio.set_event_loop(st.session_state.event_loop) loop = st.session_state.event_loop final_response, plot_figure, thinking_steps, full_chronological_content, final_answer = ( loop.run_until_complete( process_user_message_stream( prompt, user_id, session_id, llm_provider, thinking_container, response_container ) ) ) # No need to create another expander - content is already shown in real-time # Process download links in the content before storing processed_chronological_content = process_download_links(full_chronological_content) processed_final_answer = process_download_links(final_answer) if final_answer else final_answer # Store the complete message with the processed content st.session_state.messages.append( { "role": "assistant", "type": "chronological_message", "chronological_content": processed_chronological_content, "final_answer": processed_final_answer, "plot_figure": plot_figure, 
} ) # Trigger rerun to collapse the thinking section st.rerun() except Exception as e: traceback.print_exc() st.error(f"❌ Error processing request: {str(e)}") error_content = f"❌ Error: {str(e)}" processed_error_content = process_download_links(error_content) st.session_state.messages.append({"role": "assistant", "type": "text", "content": processed_error_content}) # Cleanup on session end def cleanup_session(): """Cleanup resources when session ends""" if "graph_manager" in st.session_state: try: # Use the same event loop for cleanup if st.session_state.event_loop and not st.session_state.event_loop.is_closed(): loop = st.session_state.event_loop loop.run_until_complete(st.session_state.graph_manager.cleanup()) loop.close() st.session_state.event_loop = None except Exception as e: log(f"Error during session cleanup: {e}") # Register cleanup (this is a simplified approach - in production you might want more robust cleanup) import atexit atexit.register(cleanup_session) ================================================ FILE: sample_ui/style.py ================================================ # Custom CSS for styling the chat interface custom_css = """ #chatbot { height: 600px !important; font-family: "Inter", "Helvetica Neue", sans-serif; } #chatbot .wrap.svelte-1cl84sx { background: #f5f7fa; border-radius: 12px; padding: 8px; } .message.user { background-color: #d1e9ff !important; border-radius: 12px 12px 0 12px; margin: 4px 0; padding: 10px 14px; font-size: 15px; } .message.bot { background-color: #ffffff !important; border-radius: 12px 12px 12px 0; margin: 4px 0; padding: 10px 14px; font-size: 15px; box-shadow: 0px 1px 3px rgba(0,0,0,0.08); } #msg { font-family: "Inter", "Helvetica Neue", sans-serif; font-size: 15px; } #description { font-family: "Inter", "Helvetica Neue", sans-serif; font-size: 14px; color: #374151; line-height: 1.6; } """ ================================================ FILE: tests/README.md ================================================ # OpenChatBI Test Suite This directory contains comprehensive unit tests for the OpenChatBI project. The test suite is built using pytest and follows modern Python testing best practices. 
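For orientation, here is a minimal, self-contained test in the style this suite uses. The file name `tests/test_example.py` is illustrative (it does not exist in the repo), but `mock_llm` and `temp_dir` are the real shared fixtures defined in `conftest.py`:

```python
"""Illustrative sketch only -- e.g. tests/test_example.py."""
import pytest


class TestExample:
    """Minimal test class following the suite's conventions."""

    @pytest.mark.unit
    def test_mock_llm_returns_canned_response(self, mock_llm):
        """Happy path: the shared mock_llm fixture is a FakeListChatModel
        that cycles through canned responses, starting with the first."""
        response = mock_llm.invoke("How many users are there?")
        assert response.content == "SELECT COUNT(*) FROM test_table;"

    def test_missing_file_raises(self, temp_dir):
        """Error condition: reading a file that was never created fails."""
        with pytest.raises(FileNotFoundError):
            (temp_dir / "missing.csv").read_text()
```

The sections below describe the layout, fixtures, and conventions this example relies on.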
## Test Structure ``` tests/ ├── __init__.py # Test package initialization ├── conftest.py # Shared fixtures and configuration ├── README.md # This file │ ├── Core Module Tests ├── test_config_loader.py # Configuration loading tests ├── test_graph_state.py # State management tests ├── test_utils.py # Utility function tests │ ├── Catalog System Tests ├── test_catalog_store.py # Catalog store interface tests ├── test_catalog_loader.py # Database catalog loading tests │ ├── Text2SQL Pipeline Tests ├── test_text2sql_extraction.py # Information extraction tests ├── test_text2sql_generate_sql.py # SQL generation tests ├── test_text2sql_schema_linking.py # Schema linking tests ├── test_text2sql_visualization.py # Data visualization tests │ ├── Tool Tests ├── test_tools_ask_human.py # Human interaction tool tests ├── test_tools_run_python_code.py # Python code execution tests ├── test_tools_search_knowledge.py # Knowledge search tests │ ├── Additional Module Tests ├── test_memory.py # Memory management tests ├── test_plotly_utils.py # Plotly utilities tests ├── test_incomplete_tool_calls.py # Incomplete tool call handling tests │ └── Context Management Tests └── context_management/ # Context management test suite (see context_management/README.md) ``` ## Running Tests ### Prerequisites Ensure you have the development dependencies installed: ```bash # Using uv (recommended) uv sync --group dev # Or using pip pip install -e ".[dev]" ``` ### Basic Test Execution ```bash # Run all tests uv run pytest # Run tests with verbose output uv run pytest -v # Run specific test file uv run pytest tests/test_config_loader.py # Run specific test class uv run pytest tests/test_config_loader.py::TestConfigLoader # Run specific test method uv run pytest tests/test_config_loader.py::TestConfigLoader::test_load_config_from_file ``` ### Test Coverage ```bash # Run tests with coverage report uv run pytest --cov=openchatbi --cov-report=html --cov-report=term-missing # View HTML coverage report open htmlcov/index.html ``` ### Test Categories ```bash # Run only fast unit tests (exclude slow integration tests) uv run pytest -m "not slow" # Run tests for specific components uv run pytest tests/test_catalog* -k "catalog" uv run pytest tests/test_text2sql* -k "text2sql" uv run pytest tests/test_tools* -k "tools" # Run context management tests uv run pytest tests/context_management/ # Run memory and utility tests uv run pytest tests/test_memory.py tests/test_plotly_utils.py # Run incomplete tool call tests uv run pytest tests/test_incomplete_tool_calls.py ``` ## Test Configuration ### Environment Variables The test suite uses several environment variables that can be set to customize test behavior: - `OPENCHATBI_TEST_MODE=true` - Enables test mode - `OPENCHATBI_CONFIG_PATH` - Path to test configuration file - `PYTEST_TIMEOUT=300` - Test timeout in seconds ### Fixtures The `conftest.py` file provides shared fixtures used across tests: #### Core Fixtures - `test_config` - Test configuration dictionary - `temp_dir` - Temporary directory for test files - `mock_llm` - Mocked language model for testing - `sample_agent_state` - Sample AgentState for testing #### Catalog Fixtures - `mock_catalog_store` - Mocked catalog store with sample data - `mock_database_engine` - Mocked database engine - `sample_table_info` - Sample table metadata #### Database Fixtures - `mock_presto_connection` - Mocked Presto database connection - `mock_token_service` - Mocked authentication token service ## Writing Tests ### Test Naming Conventions Follow 
these naming conventions for consistency: ```python # Test files test_<module_name>.py # Test classes class TestModuleName: # Test methods def test_specific_functionality(self): def test_error_condition_handling(self): def test_edge_case_scenario(self): ``` ### Test Categories Use pytest marks to categorize tests: ```python import pytest @pytest.mark.unit def test_basic_functionality(): """Unit test for basic functionality.""" pass @pytest.mark.integration def test_database_integration(): """Integration test with database.""" pass @pytest.mark.slow def test_performance_benchmark(): """Slow performance test.""" pass @pytest.mark.parametrize("input,expected", [ ("test1", "result1"), ("test2", "result2"), ]) def test_multiple_scenarios(input, expected): """Test multiple input/output scenarios.""" pass ``` ### Mocking Best Practices Use proper mocking for external dependencies: ```python from unittest.mock import Mock, patch, MagicMock # Mock external services @patch('openchatbi.module.external_service') def test_with_external_service(mock_service): mock_service.return_value = "expected_result" # Test implementation # Mock LLM responses def test_llm_integration(mock_llm): mock_llm.invoke.return_value = AIMessage(content="Mock response") # Test implementation ``` ### Async Test Support For testing async functionality: ```python import pytest import asyncio @pytest.mark.asyncio async def test_async_functionality(): """Test asynchronous operations.""" result = await some_async_function() assert result is not None ``` ## Test Data ### Sample Data Files Test data is managed through fixtures and temporary files: ```python def test_with_sample_data(temp_dir): """Test using temporary sample data.""" # Create test data file data_file = temp_dir / "test_data.csv" data_file.write_text("col1,col2\nval1,val2") # Test with the data assert data_file.exists() ``` ### Mock Responses Common mock responses are defined in fixtures: ```python # SQL generation mock response mock_llm.invoke.return_value = AIMessage( content="SELECT COUNT(*) FROM test_table;" ) # Catalog search mock response mock_catalog.search_tables.return_value = [ {"table_name": "users", "description": "User data"} ] ``` ## Continuous Integration ### GitHub Actions Tests run automatically on: - Pull requests - Pushes to main branch - Scheduled runs (daily) ### Test Matrix Tests run against multiple configurations: - Python versions: 3.11+ - Dependencies: Minimum and latest versions ## Debugging Tests ### Common Issues 1. **Import Errors** ```bash # Ensure package is installed in development mode pip install -e . ``` 2. **Missing Dependencies** ```bash # Install test dependencies pip install -e ".[test]" ``` 3. **Configuration Issues** ```bash # Set test environment variables export OPENCHATBI_TEST_MODE=true ``` ### Debug Output Enable debug output for failing tests: ```bash # Run with debug output uv run pytest -v -s --tb=long # Run with pdb on failures uv run pytest --pdb # Run with coverage debug uv run pytest --cov-report=term-missing -v ``` ## Performance Testing ### Benchmarks Performance tests are marked with `@pytest.mark.slow`: ```bash # Run performance tests uv run pytest -m slow # Skip performance tests uv run pytest -m "not slow" ``` ### Memory Profiling For memory usage testing: ```bash # Install memory profiler pip install memory-profiler # Run with memory profiling uv run pytest --profile-mem ``` ## Contributing ### Adding New Tests 1. Create test file following naming conventions 2. Import required fixtures from `conftest.py` 3.
Write comprehensive test cases covering: - Happy path scenarios - Error conditions - Edge cases - Performance considerations 4. Use appropriate mocking for external dependencies 5. Add docstrings explaining test purpose 6. Run tests locally before submitting PR ### Test Review Guidelines When reviewing test PRs: - Ensure adequate test coverage - Verify mock usage is appropriate - Check for test independence - Validate error case handling - Confirm performance test categorization ## Resources - [Pytest Documentation](https://docs.pytest.org/) - [Python unittest.mock](https://docs.python.org/3/library/unittest.mock.html) - [Coverage.py Documentation](https://coverage.readthedocs.io/) - [pytest-asyncio](https://pytest-asyncio.readthedocs.io/) ================================================ FILE: tests/__init__.py ================================================ """Test package for OpenChatBI.""" ================================================ FILE: tests/conftest.py ================================================ """Pytest configuration and shared fixtures.""" import tempfile from collections.abc import Generator from pathlib import Path from typing import Any from unittest.mock import Mock import pytest from langchain_core.language_models import FakeListChatModel from langchain_core.messages import AIMessage, HumanMessage from sqlalchemy import create_engine, text from openchatbi.catalog.store.file_system import FileSystemCatalogStore from openchatbi.config_loader import ConfigLoader from openchatbi.graph_state import AgentState @pytest.fixture(scope="session") def test_config() -> dict[str, Any]: """Test configuration fixture.""" return { "organization": "TestOrg", "dialect": "presto", "bi_config_file": "test_bi.yaml", "catalog_store": {"store_type": "file_system", "data_path": "./test_data"}, "default_llm": { "class": "langchain_core.language_models.FakeListChatModel", "params": {"responses": ["Test response"]}, }, } @pytest.fixture def temp_dir() -> Generator[Path, None, None]: """Temporary directory fixture.""" with tempfile.TemporaryDirectory() as temp_dir: yield Path(temp_dir) @pytest.fixture def mock_llm() -> FakeListChatModel: """Mock LLM fixture for testing.""" return FakeListChatModel( responses=["SELECT COUNT(*) FROM test_table;", "This is a test SQL query.", "Test analysis result."] ) @pytest.fixture def sample_agent_state() -> AgentState: """Sample agent state for testing.""" return AgentState( messages=[HumanMessage(content="Test query")], sql="SELECT * FROM test_table;", agent_next_node="sql_generation", final_answer="Test data results", ) @pytest.fixture def mock_catalog_store(temp_dir: Path) -> FileSystemCatalogStore: """Mock catalog store fixture.""" # Create test data files test_data_dir = temp_dir / "test_data" test_data_dir.mkdir(exist_ok=True) # Create sample table_info.yaml tables_info_file = test_data_dir / "table_info.yaml" tables_info_file.write_text( """test: test_table: type: fact description: A test table for unit tests user_data: type: fact description: User information table""" ) # Create sample table_columns.csv tables_file = test_data_dir / "table_columns.csv" tables_file.write_text( """db_name,table_name,column_name test,test_table,id test,test_table,name test,user_data,user_id""" ) # Create sample table_spec_columns.csv columns_file = test_data_dir / "table_spec_columns.csv" columns_file.write_text( """db_name,table_name,column_name,type,display_name,description test,test_table,id,bigint,Id,Primary key test,test_table,name,varchar,Name,User name
test,user_data,user_id,bigint,User Id,User identifier""" ) # Create sample common_columns.csv common_columns_file = test_data_dir / "common_columns.csv" common_columns_file.write_text( """column_name,type,display_name,description status,varchar,Status,Record status created_at,timestamp,Created At,Creation timestamp updated_at,timestamp,Updated At,Last update timestamp""" ) # Mock data warehouse config data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"} return FileSystemCatalogStore(data_path=str(test_data_dir), data_warehouse_config=data_warehouse_config) @pytest.fixture def mock_database_engine(): """Mock database engine fixture.""" engine = create_engine("sqlite:///:memory:") # Create test tables (raw SQL strings must be wrapped in text() for SQLAlchemy 2.x) with engine.connect() as conn: conn.execute(text("CREATE TABLE test_table (id INTEGER, name TEXT)")) conn.execute(text("INSERT INTO test_table VALUES (1, 'Test User')")) conn.commit() return engine @pytest.fixture def sample_table_info() -> dict[str, Any]: """Sample table information fixture.""" return { "test_table": { "columns": [ {"name": "id", "type": "bigint", "description": "Primary key"}, {"name": "name", "type": "varchar", "description": "User name"}, ], "description": "A test table for unit tests", "sql_rule": "Always filter by active status", } } @pytest.fixture def sample_messages() -> list: """Sample message history fixture.""" return [ HumanMessage(content="What's the user count?"), AIMessage(content="I'll help you get the user count from the database."), HumanMessage(content="Show me the SQL query"), ] @pytest.fixture(autouse=True) def reset_config_loader(): """Reset ConfigLoader singleton state before each test.""" # Reset the singleton instance to ensure clean state ConfigLoader._instance = None ConfigLoader._config = None yield # Clean up after test ConfigLoader._instance = None ConfigLoader._config = None @pytest.fixture def mock_config(): """Provide a mock configuration for tests that need it.""" from unittest.mock import MagicMock config_dict = { "organization": "Test Company", "dialect": "presto", "default_llm": MagicMock(), "embedding_model": MagicMock(), "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"}, "report_directory": "./data", "python_executor": "local", "visualization_mode": "rule", "context_config": {}, } loader = ConfigLoader() loader.set(config_dict) return loader.get() @pytest.fixture(autouse=True) def setup_test_env(monkeypatch, temp_dir): """Setup test environment variables.""" monkeypatch.setenv("OPENCHATBI_CONFIG_PATH", str(temp_dir / "config.yaml")) monkeypatch.setenv("OPENCHATBI_TEST_MODE", "true") class MockTokenService: """Mock token service for testing.""" def __init__(self): self.token = "mock_token_12345" def get_token(self) -> str: return self.token @pytest.fixture def mock_token_service() -> MockTokenService: """Mock token service fixture.""" return MockTokenService() @pytest.fixture def sample_sql_examples() -> list: """Sample SQL examples fixture.""" return [ {"question": "How many users are there?", "sql": "SELECT COUNT(*) FROM users;", "tables": ["users"]}, { "question": "What's the average age?", "sql": "SELECT AVG(age) FROM users WHERE age IS NOT NULL;", "tables": ["users"], }, ] @pytest.fixture def mock_presto_connection(): """Mock Presto connection fixture.""" mock_conn = Mock() mock_cursor = Mock() # Setup cursor behavior mock_cursor.fetchall.return_value = [("table1", "Test table 1"), ("table2", "Test table 2")] mock_cursor.description =
[("table_name",), ("description",)] mock_conn.cursor.return_value = mock_cursor mock_conn.execute.return_value = mock_cursor return mock_conn ================================================ FILE: tests/context_management/README.md ================================================ # Context Management Test Suite This directory contains comprehensive tests for the context management functionality in OpenChatBI. ## Test Structure ### 📁 Test Files - **`test_context_manager.py`** - Unit tests for the `ContextManager` class - **`test_context_config.py`** - Tests for context configuration management - **`test_agent_graph_integration.py`** - Integration tests for agent graph with context management - **`test_edge_cases.py`** - Edge case handling - **`test_state_operations.py`** - Tests for state operations and message processing - **`conftest.py`** - Shared pytest fixtures and configuration - **`test_runner.py`** - Custom test runner script ## 🧪 Test Categories ### Unit Tests (`test_context_manager.py`) Tests core functionality of the `ContextManager` class: - ✅ Token estimation and message token calculation - ✅ Tool output trimming (generic, SQL, Python code) - ✅ Conversation summarization - ✅ Context management with sliding window - ✅ Tool wrapper functionality - ✅ Configuration-based behavior **Key test cases:** - `test_trim_sql_output()` - Tests intelligent SQL result trimming - `test_conversation_summary_success()` - Tests LLM-based summarization - `test_manage_context_with_summarization()` - Tests full context management flow ### Configuration Tests (`test_context_config.py`) Tests configuration management and validation: - ✅ Default configuration values - ✅ Custom configuration creation - ✅ Configuration updates - ✅ Edge cases (zero/negative values) - ✅ Different configuration presets **Key test cases:** - `test_update_context_config_multiple_values()` - Tests configuration updates - `test_production_optimized_config()` - Tests realistic production settings ### Integration Tests (`test_agent_graph_integration.py`) Tests integration with the agent graph system: - ✅ Agent router with context management - ✅ Graph building with/without context management - ✅ Tool wrapping in graph context - ✅ Full conversation flow testing - ✅ System message preservation **Key test cases:** - `test_agent_router_with_context_manager()` - Tests router integration - `test_full_context_management_flow()` - Tests end-to-end functionality ### State Operations Tests (`test_state_operations.py`) Tests state manipulation and message processing operations: - ✅ Message trimming and truncation logic - ✅ State updates and modifications - ✅ Message type handling and conversion - ✅ Context state preservation during operations - ✅ Error handling in state operations **Key test cases:** - `test_trim_messages_by_token_count()` - Tests message trimming logic - `test_state_message_processing()` - Tests state message operations - `test_context_state_updates()` - Tests context state modifications ### Edge Cases (`test_edge_cases.py`) Tests system behavior under stress and edge conditions: - ✅ Unicode and encoding edge cases - ✅ Malformed input handling **Key test cases:** - `test_sql_output_edge_cases()` - SQL edge cases - `test_extremely_nested_or_complex_structures()` - Complex data structures ## 🚀 Running Tests ### Using the Test Runner ```bash # Run all tests python tests/context_management/test_runner.py # Run only unit tests python tests/context_management/test_runner.py --type unit # Run with coverage reporting python 
tests/context_management/test_runner.py --coverage ``` ### Using Pytest Directly ```bash # Run all context management tests pytest tests/context_management/ # Run specific test file pytest tests/context_management/test_context_manager.py # Run with verbose output pytest tests/context_management/ -v # Run with coverage pytest tests/context_management/ --cov=openchatbi.context_manager --cov-report=html ``` ## 📊 Test Markers Tests are organized using pytest markers: - `@pytest.mark.integration` - Integration tests - `@pytest.mark.slow` - Slow-running tests (can be excluded) ## 🎯 Test Coverage Areas ### Core Functionality - [x] Token estimation and management - [x] Message processing and trimming - [x] Conversation summarization - [x] Context compression strategies ### Tool Output Management - [x] SQL output trimming with structure preservation - [x] Python code output handling - [x] Error message preservation - [x] Generic output trimming ### Configuration Management - [x] Default and custom configurations - [x] Configuration validation - [x] Runtime configuration updates - [x] Edge case configurations ### Integration Points - [x] Agent router integration - [x] Graph building integration - [x] Tool wrapper integration - [x] LLM service integration ### Edge Cases - [x] Unicode and encoding issues - [x] Malformed input handling ## 🧩 Fixtures ### Common Fixtures (in `conftest.py`) - `mock_llm` - Mock language model for testing - `standard_config` - Standard test configuration - `minimal_config` - Minimal configuration for edge testing - `sample_conversation` - Sample conversation data - `large_sql_output` - Large SQL output for trimming tests - `error_output` - Sample error output for preservation tests ## 🔧 Extending Tests ### Adding New Test Cases 1. Choose the appropriate test file based on the functionality 2. Use existing fixtures from `conftest.py` 3. Follow the naming convention: `test_feature_description()` 4. Add appropriate markers for categorization ### Adding New Fixtures Add shared fixtures to `conftest.py` if they'll be used across multiple test files. ## 🐛 Debugging Tests ### Common Issues 1. **Mock LLM failures**: Ensure proper mocking of LLM responses 2. **Configuration conflicts**: Use isolated config instances 3. 
**Memory leaks in large tests**: Force garbage collection with `gc.collect()` ### Debugging Tools ```bash # Run with debugging output pytest tests/context_management/ -v -s # Run single test with full traceback pytest tests/context_management/test_name.py::test_function -v --tb=long # Profile test performance pytest tests/context_management/ --profile ``` ## 📋 Test Results Expected test results: - **Total tests**: ~100+ test cases across 6 test files - **Coverage target**: >95% for context management modules - **State operations tests**: All message processing should work correctly - **Edge cases**: All should handle gracefully without crashes ================================================ FILE: tests/context_management/__init__.py ================================================ """Context management test package.""" # Test package initialization ================================================ FILE: tests/context_management/conftest.py ================================================ """Pytest configuration and fixtures for context management tests.""" from unittest.mock import Mock import pytest from langchain_core.messages import AIMessage from openchatbi.context_config import ContextConfig @pytest.fixture def mock_llm(): """Mock LLM for testing across all test modules.""" llm = Mock() llm.bind_tools = Mock(return_value=llm) return llm @pytest.fixture def mock_llm_with_summary_response(): """Mock LLM that returns a summary response.""" llm = Mock() llm.bind_tools = Mock(return_value=llm) return llm @pytest.fixture def standard_config(): """Standard test configuration.""" return ContextConfig( enabled=True, summary_trigger_tokens=12000, keep_recent_messages=10, max_tool_output_length=2000, max_sql_result_rows=20, max_code_output_lines=50, enable_summarization=True, enable_conversation_summary=True, preserve_tool_errors=True, ) @pytest.fixture def minimal_config(): """Minimal test configuration.""" return ContextConfig( enabled=True, summary_trigger_tokens=800, keep_recent_messages=3, max_tool_output_length=200, max_sql_result_rows=5, max_code_output_lines=10, ) @pytest.fixture def disabled_config(): """Configuration with context management disabled.""" return ContextConfig( enabled=False, summary_trigger_tokens=12000, keep_recent_messages=10, max_tool_output_length=2000 ) @pytest.fixture def sample_conversation(): """Sample conversation for testing.""" from langchain_core.messages import HumanMessage, ToolMessage return [ HumanMessage(content="Can you analyze our sales data?"), AIMessage(content="I'll help you analyze the sales data. Let me query the database."), ToolMessage(content="Query executed successfully. 
Found 1000 records.", tool_call_id="query_1"), HumanMessage(content="What are the top trends?"), AIMessage(content="Based on the data, I can see several key trends..."), HumanMessage(content="Can you create a visualization?"), AIMessage( content="I'll create a chart for you.", tool_calls=[{"name": "create_chart", "args": {"type": "bar"}, "id": "chart_1"}], ), ToolMessage(content="Chart created successfully.", tool_call_id="chart_1"), ] @pytest.fixture def large_sql_output(): """Large SQL output for testing trimming.""" csv_data = "id,name,value,date\n" csv_data += "\n".join([f"{i},Customer_{i},{i*100},2023-01-{i%30+1:02d}" for i in range(100)]) return f"""SQL Query: ```sql SELECT id, name, value, date FROM customers ORDER BY value DESC LIMIT 100; ``` Query Results (CSV format): ```csv {csv_data} ``` Visualization Created: bar chart has been automatically generated and will be displayed in the UI.""" @pytest.fixture def large_python_output(): """Large Python code execution output.""" output_lines = [] output_lines.append("Processing data...") for i in range(100): output_lines.append(f"Step {i}: Processing record {i} - Status: OK") output_lines.append("Processing complete!") return "\n".join(output_lines) @pytest.fixture def error_output(): """Sample error output.""" return """Traceback (most recent call last): File "/app/analysis.py", line 42, in analyze_sales df = pd.read_csv('nonexistent_file.csv') File "/usr/local/lib/python3.9/site-packages/pandas/io/parsers/readers.py", line 912, in read_csv return _read(filepath_or_buffer, kwds) FileNotFoundError: [Errno 2] No such file or directory: 'nonexistent_file.csv' Error: Could not load the sales data file. Please check that the file exists and is accessible.""" # Pytest configuration def pytest_configure(config): """Configure pytest with custom markers.""" config.addinivalue_line("markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')") config.addinivalue_line("markers", "integration: marks tests as integration tests") def pytest_collection_modifyitems(config, items): """Modify test items to add markers.""" for item in items: # Mark integration tests if "integration" in item.nodeid.lower(): item.add_marker(pytest.mark.integration) # Mark slow tests based on certain patterns if any(pattern in item.nodeid.lower() for pattern in ["large", "stress", "concurrent"]): item.add_marker(pytest.mark.slow) ================================================ FILE: tests/context_management/test_agent_graph_integration.py ================================================ """Integration tests for agent graph with context management.""" from unittest.mock import Mock, patch import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.tools import StructuredTool from openchatbi.agent_graph import _build_graph_core, agent_llm_call, build_agent_graph_async, build_agent_graph_sync from openchatbi.context_config import ContextConfig from openchatbi.context_manager import ContextManager from openchatbi.graph_state import AgentState class TestAgentGraphIntegration: """Integration tests for agent graph with context management.""" @pytest.fixture def mock_catalog(self): """Mock catalog store for testing.""" catalog = Mock() catalog.get_schema = Mock(return_value={"tables": []}) return catalog @pytest.fixture def mock_llm(self): """Mock LLM for testing.""" llm = Mock() llm.bind_tools = Mock(return_value=llm) return llm @pytest.fixture def mock_tools(self): """Mock tools for testing.""" def mock_tool_func(query: 
str) -> str: return "Mock tool result" tool = StructuredTool.from_function(func=mock_tool_func, name="mock_tool", description="Mock tool for testing") return [tool] @pytest.fixture def test_config(self): """Test configuration for context management.""" return ContextConfig( enabled=True, summary_trigger_tokens=800, keep_recent_messages=3, max_tool_output_length=100, ) def test_agent_llm_node_with_context_manager(self, mock_llm, mock_tools, test_config): """Test agent llm_node with context manager integration.""" context_manager = ContextManager(llm=mock_llm, config=test_config) # Mock LLM response mock_response = AIMessage(content="Test response", tool_calls=[]) with patch("openchatbi.agent_graph.call_llm_chat_model_with_retry", return_value=mock_response): llm_node_func = agent_llm_call(mock_llm, mock_tools, context_manager) # Create test state with long messages to trigger context management long_messages = [ HumanMessage(content="A" * 500), # Long message AIMessage(content="B" * 500), # Long message ToolMessage(content="C" * 200, tool_call_id="123"), # Long tool output HumanMessage(content="Recent question"), ] state = AgentState(messages=long_messages) result = llm_node_func(state) # Should have processed the state assert "messages" in result assert isinstance(result["messages"][0], AIMessage) def test_agent_llm_node_without_context_manager(self, mock_llm, mock_tools): """Test agent llm_node without context manager.""" mock_response = AIMessage(content="Test response", tool_calls=[]) with patch("openchatbi.agent_graph.call_llm_chat_model_with_retry", return_value=mock_response): llm_node_func = agent_llm_call(mock_llm, mock_tools, context_manager=None) state = AgentState(messages=[HumanMessage(content="Test")]) result = llm_node_func(state) assert "messages" in result assert isinstance(result["messages"][0], AIMessage) def test_build_graph_core_with_context_management(self, mock_catalog, mock_llm): """Test core graph building with context management enabled.""" def create_mock_tool(name): def mock_func(input_str: str) -> str: return f"Mock {name} result" return StructuredTool.from_function(func=mock_func, name=name, description=f"Mock {name} tool") # Mock all the tool imports directly with ( patch("openchatbi.agent_graph.search_knowledge", create_mock_tool("search_knowledge")), patch("openchatbi.agent_graph.show_schema", create_mock_tool("show_schema")), patch("openchatbi.agent_graph.run_python_code", create_mock_tool("run_python_code")), patch("openchatbi.agent_graph.save_report", create_mock_tool("save_report")), patch("openchatbi.agent_graph.timeseries_forecast", create_mock_tool("timeseries_forecast")), patch("openchatbi.agent_graph.get_sql_tools") as mock_get_sql_tools, patch("openchatbi.agent_graph.build_sql_graph") as mock_sql_graph, patch("openchatbi.agent_graph.get_memory_tools") as mock_memory_tools, patch("openchatbi.agent_graph.create_mcp_tools_sync") as mock_mcp_tools, patch("openchatbi.agent_graph.get_llm", return_value=mock_llm), ): # Setup function-based mocks mock_get_sql_tools.return_value = create_mock_tool("call_sql_graph_tool") mock_sql_graph.return_value = Mock() mock_memory_tools.return_value = ( create_mock_tool("manage_memory_tool"), create_mock_tool("search_memory_tool"), ) mock_mcp_tools.return_value = [] graph = _build_graph_core( catalog=mock_catalog, sync_mode=True, checkpointer=None, memory_store=None, memory_tools=None, mcp_tools=[], enable_context_management=True, ) # Should create a compiled graph assert graph is not None # Verify that SQL graph was 
initialized mock_sql_graph.assert_called_once() def test_build_graph_core_without_context_management(self, mock_catalog, mock_llm): """Test core graph building with context management disabled.""" def create_mock_tool(name): def mock_func(input_str: str) -> str: return f"Mock {name} result" return StructuredTool.from_function(func=mock_func, name=name, description=f"Mock {name} tool") # Mock all the tool imports directly - same pattern as with context management with ( patch("openchatbi.agent_graph.search_knowledge", create_mock_tool("search_knowledge")), patch("openchatbi.agent_graph.show_schema", create_mock_tool("show_schema")), patch("openchatbi.agent_graph.run_python_code", create_mock_tool("run_python_code")), patch("openchatbi.agent_graph.save_report", create_mock_tool("save_report")), patch("openchatbi.agent_graph.timeseries_forecast", create_mock_tool("timeseries_forecast")), patch("openchatbi.agent_graph.get_sql_tools") as mock_get_sql_tools, patch("openchatbi.agent_graph.build_sql_graph") as mock_sql_graph, patch("openchatbi.agent_graph.get_memory_tools") as mock_memory_tools, patch("openchatbi.agent_graph.create_mcp_tools_sync") as mock_mcp_tools, patch("openchatbi.agent_graph.get_llm", return_value=mock_llm), ): # Setup function-based mocks mock_get_sql_tools.return_value = create_mock_tool("call_sql_graph_tool") mock_sql_graph.return_value = Mock() mock_memory_tools.return_value = ( create_mock_tool("manage_memory_tool"), create_mock_tool("search_memory_tool"), ) mock_mcp_tools.return_value = [] graph = _build_graph_core( catalog=mock_catalog, sync_mode=True, checkpointer=None, memory_store=None, memory_tools=None, mcp_tools=[], enable_context_management=False, ) # Should still create a compiled graph assert graph is not None def test_build_agent_graph_sync_with_context_management(self, mock_catalog): """Test sync graph building with context management.""" with ( patch("openchatbi.agent_graph.create_mcp_tools_sync") as mock_mcp_tools, patch("openchatbi.agent_graph._build_graph_core") as mock_build_core, ): mock_build_core.return_value = Mock() mock_mcp_tools.return_value = [] graph = build_agent_graph_sync(catalog=mock_catalog, enable_context_management=True) # Verify _build_graph_core was called with correct parameters mock_build_core.assert_called_once() call_args = mock_build_core.call_args assert call_args[1]["enable_context_management"] is True # Should return the graph assert graph is not None @pytest.mark.asyncio async def test_build_agent_graph_async_with_context_management(self, mock_catalog): """Test async graph building with context management.""" with ( patch("openchatbi.agent_graph.get_mcp_tools_async") as mock_mcp_tools, patch("openchatbi.agent_graph._build_graph_core") as mock_build_core, ): mock_build_core.return_value = Mock() # Mock async function mock_mcp_tools.return_value = [] graph = await build_agent_graph_async(catalog=mock_catalog, enable_context_management=True) # Verify _build_graph_core was called with correct parameters mock_build_core.assert_called_once() call_args = mock_build_core.call_args assert call_args[1]["enable_context_management"] is True # Should return the graph assert graph is not None @patch("openchatbi.agent_graph.call_llm_chat_model_with_retry") def test_full_context_management_flow(self, mock_llm_call, mock_catalog): """Test full context management flow in agent graph.""" # Mock LLM responses mock_llm_call.side_effect = [ AIMessage(content="Response 1"), AIMessage(content="Summary of conversation"), # For summarization 
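# unittest.mock returns side_effect list items one per call, in order:
# the first reply serves the initial LLM turn, the summary above serves
# the context manager's summarization call, and the final reply follows.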
AIMessage(content="Final response"), ] context_manager = ContextManager( llm=Mock(), config=ContextConfig( enabled=True, summary_trigger_tokens=80, keep_recent_messages=2, ), ) # Create many messages to trigger context management messages = [] for i in range(10): messages.extend( [ HumanMessage(content=f"Question {i}" * 10), # Make messages longer AIMessage(content=f"Response {i}" * 10), ToolMessage(content=f"Tool result {i}" * 20, tool_call_id=f"tool_{i}"), ] ) # Test context management original_count = len(messages) context_manager.manage_context_messages(messages) managed_messages = messages # Should have fewer messages than input assert len(managed_messages) < original_count # Should preserve recent messages assert any("Question 9" in str(msg.content) for msg in managed_messages if hasattr(msg, "content")) class TestContextManagementEdgeCases: """Test edge cases for context management in agent graph.""" def test_empty_message_handling(self): """Test handling of empty messages.""" config = ContextConfig(enabled=True) context_manager = ContextManager(llm=Mock(), config=config) messages = [] context_manager.manage_context_messages(messages) result = messages assert result == [] def test_state_message_type_validation(self): """Test that only valid state message types are maintained during context management.""" config = ContextConfig(enabled=True) context_manager = ContextManager(llm=Mock(), config=config) # State should only contain valid message types (no SystemMessage) messages = [ HumanMessage(content="A" * 100), # Long message AIMessage(content="B" * 100), # Long message HumanMessage(content="Recent question"), ] with patch( "openchatbi.context_manager.call_llm_chat_model_with_retry", return_value=AIMessage(content="Summary") ): context_manager.manage_context_messages(messages) result = messages # Should only contain valid state message types valid_types = {HumanMessage, AIMessage, ToolMessage} assert all(type(msg) in valid_types for msg in result), "Should only contain valid state message types" def test_context_management_with_tool_calls(self): """Test context management when AI messages have tool calls.""" config = ContextConfig(enabled=True) context_manager = ContextManager(llm=Mock(), config=config) ai_message_with_tools = AIMessage( content="I'll help you with that.", tool_calls=[{"name": "search_tool", "args": {"query": "test"}, "id": "call_123"}], ) messages = [ai_message_with_tools, HumanMessage(content="Follow up")] context_manager.manage_context_messages(messages) result = messages # AI message with tool calls should be preserved ai_msgs = [msg for msg in result if isinstance(msg, AIMessage)] assert len(ai_msgs) > 0 assert any(hasattr(msg, "tool_calls") and msg.tool_calls for msg in ai_msgs) @patch("openchatbi.context_manager.call_llm_chat_model_with_retry") def test_summarization_failure_fallback(self, mock_llm_call): """Test fallback behavior when summarization fails.""" # Mock LLM failure mock_llm_call.side_effect = Exception("LLM unavailable") config = ContextConfig(enabled=True) context_manager = ContextManager(llm=Mock(), config=config) # Create messages that would trigger summarization (no SystemMessage in state) messages = [ HumanMessage(content="A" * 100), # Long messages to trigger AIMessage(content="B" * 100), HumanMessage(content="C" * 100), AIMessage(content="D" * 100), HumanMessage(content="Recent"), ] context_manager.manage_context_messages(messages) result = messages # Should fallback to sliding window assert len(result) <= len(messages) # Should preserve 
recent messages and only contain valid state message types assert any("Recent" in str(msg.content) for msg in result if hasattr(msg, "content")) valid_types = {HumanMessage, AIMessage, ToolMessage} assert all(type(msg) in valid_types for msg in result), "Should only contain valid state message types" ================================================ FILE: tests/context_management/test_context_config.py ================================================ """Unit tests for context configuration.""" from openchatbi.context_config import ContextConfig, get_context_config, update_context_config class TestContextConfig: """Test cases for ContextConfig class.""" def test_default_config_values(self): """Test that default configuration has expected values.""" config = ContextConfig() # Test default values assert config.enabled is True assert config.summary_trigger_tokens == 12000 assert config.keep_recent_messages == 20 assert config.max_tool_output_length == 2000 assert config.max_sql_result_rows == 50 assert config.max_code_output_lines == 50 # Test boolean flags assert config.enable_summarization is True assert config.enable_conversation_summary is True assert config.preserve_tool_errors is True assert config.preserve_recent_sql is True def test_custom_config_values(self): """Test creating config with custom values.""" config = ContextConfig( enabled=False, summary_trigger_tokens=8000, keep_recent_messages=5, max_tool_output_length=1000, enable_summarization=False, ) assert config.enabled is False assert config.summary_trigger_tokens == 8000 assert config.keep_recent_messages == 5 assert config.max_tool_output_length == 1000 assert config.enable_summarization is False # Other values should use defaults assert config.max_sql_result_rows == 50 assert config.preserve_tool_errors is True def test_config_validation_logic(self): """Test logical relationships in configuration.""" config = ContextConfig() # Keep recent messages should be reasonable assert config.keep_recent_messages > 0 assert config.keep_recent_messages < 100 # Sanity check # Output limits should be positive assert config.max_tool_output_length > 0 assert config.max_sql_result_rows > 0 assert config.max_code_output_lines > 0 # Token limits should be reasonable assert config.summary_trigger_tokens > 0 def test_get_context_config(self): """Test getting context configuration.""" config = get_context_config() assert isinstance(config, ContextConfig) def test_update_context_config_single_value(self): """Test updating a single configuration value.""" original_trigger_tokens = get_context_config().summary_trigger_tokens updated_config = update_context_config(summary_trigger_tokens=15000) assert updated_config.summary_trigger_tokens == 15000 # Other values should remain unchanged assert updated_config.keep_recent_messages == get_context_config().keep_recent_messages def test_update_context_config_multiple_values(self): """Test updating multiple configuration values.""" updated_config = update_context_config( summary_trigger_tokens=20000, keep_recent_messages=15, enable_summarization=False, max_tool_output_length=3000, ) assert updated_config.summary_trigger_tokens == 20000 assert updated_config.keep_recent_messages == 15 assert updated_config.enable_summarization is False assert updated_config.max_tool_output_length == 3000 def test_update_context_config_invalid_attribute(self): """Test updating config with invalid attribute name.""" # Should not raise error, just ignore invalid attributes config = update_context_config(invalid_attribute=123) 
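# Per the docstring above, unknown keyword arguments should be ignored
# rather than raise, so the returned config simply never gains the
# attribute (verified below).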
assert not hasattr(config, "invalid_attribute") def test_update_context_config_returns_copy(self): """Test that update_context_config returns a modified copy.""" original_config = get_context_config() updated_config = update_context_config(summary_trigger_tokens=30000) # Original should be unchanged (if it's designed that way) # Updated should have new values assert updated_config.summary_trigger_tokens == 30000 class TestContextConfigPresets: """Test different configuration presets for common scenarios.""" def test_minimal_context_config(self): """Test configuration for minimal context management.""" config = ContextConfig( enabled=True, enable_summarization=False, enable_conversation_summary=False, max_tool_output_length=500, ) assert config.enabled is True assert config.enable_summarization is False assert config.enable_conversation_summary is False def test_aggressive_compression_config(self): """Test configuration for aggressive context compression.""" config = ContextConfig( enabled=True, summary_trigger_tokens=6000, keep_recent_messages=5, max_tool_output_length=1000, max_sql_result_rows=10, max_code_output_lines=20, enable_summarization=True, ) assert config.summary_trigger_tokens == 6000 assert config.keep_recent_messages == 5 assert config.max_tool_output_length == 1000 assert config.max_sql_result_rows == 10 assert config.max_code_output_lines == 20 def test_development_debug_config(self): """Test configuration suitable for development/debugging.""" config = ContextConfig( enabled=True, summary_trigger_tokens=40000, keep_recent_messages=20, max_tool_output_length=10000, # Don't trim much preserve_tool_errors=True, # Always preserve errors ) assert config.preserve_tool_errors is True def test_production_optimized_config(self): """Test configuration optimized for production use.""" config = ContextConfig( enabled=True, summary_trigger_tokens=10000, keep_recent_messages=8, max_tool_output_length=1500, max_sql_result_rows=15, enable_summarization=True, preserve_tool_errors=True, ) assert config.summary_trigger_tokens == 10000 assert config.enable_summarization is True assert config.preserve_tool_errors is True class TestContextConfigEdgeCases: """Test edge cases and boundary conditions for context configuration.""" def test_zero_values(self): """Test configuration with zero values.""" config = ContextConfig( summary_trigger_tokens=0, keep_recent_messages=0, max_tool_output_length=0, ) # Should accept zero values (might cause issues in practice) assert config.keep_recent_messages == 0 assert config.summary_trigger_tokens == 0 assert config.max_tool_output_length == 0 def test_very_large_values(self): """Test configuration with very large values.""" config = ContextConfig( summary_trigger_tokens=900000, keep_recent_messages=1000, max_tool_output_length=100000, ) assert config.keep_recent_messages == 1000 def test_inconsistent_token_limits(self): """Test configuration where summary trigger > max tokens.""" # Should accept but might cause logical issues def test_all_features_disabled(self): """Test configuration with all features disabled.""" config = ContextConfig( enabled=False, enable_summarization=False, enable_conversation_summary=False, ) assert config.enabled is False assert config.enable_summarization is False assert config.enable_conversation_summary is False def test_config_serialization(self): """Test that config can be converted to/from dict (if needed).""" config = ContextConfig(summary_trigger_tokens=15000, enable_summarization=True) # Test converting to dict-like 
representation config_dict = { "summary_trigger_tokens": config.summary_trigger_tokens, "enable_summarization": config.enable_summarization, "enabled": config.enabled, } assert config_dict["summary_trigger_tokens"] == 15000 assert config_dict["enable_summarization"] is True assert config_dict["enabled"] is True def test_config_immutability_simulation(self): """Test that config behaves consistently across operations.""" config1 = ContextConfig(summary_trigger_tokens=10000) config2 = ContextConfig(summary_trigger_tokens=10000) # Same values should be equal assert config1.summary_trigger_tokens == config2.summary_trigger_tokens assert config1.enabled == config2.enabled def test_realistic_configuration_scenarios(self): """Test realistic configuration scenarios.""" # Small dataset scenario small_config = ContextConfig() # Large dataset scenario large_config = ContextConfig() # Interactive analysis scenario interactive_config = ContextConfig( keep_recent_messages=50, # Keep more context for back-and-forth preserve_tool_errors=True, max_tool_output_length=5000, # Don't trim too aggressively ) assert interactive_config.keep_recent_messages > small_config.keep_recent_messages assert interactive_config.preserve_tool_errors is True ================================================ FILE: tests/context_management/test_context_manager.py ================================================ """Unit tests for ContextManager class.""" from unittest.mock import Mock, patch import pytest from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from openchatbi.context_config import ContextConfig from openchatbi.context_manager import ContextManager class TestContextManager: """Test cases for ContextManager class.""" @pytest.fixture def mock_llm(self): """Mock LLM for testing.""" llm = Mock() # Mock response for summarization llm_response = AIMessage(content="This is a test summary of the conversation.") with patch("openchatbi.context_manager.call_llm_chat_model_with_retry", return_value=llm_response): yield llm @pytest.fixture def default_config(self): """Default context configuration for testing.""" return ContextConfig( enabled=True, summary_trigger_tokens=900, keep_recent_messages=3, max_tool_output_length=200, max_sql_result_rows=5, max_code_output_lines=10, enable_conversation_summary=True, enable_summarization=True, ) @pytest.fixture def context_manager(self, mock_llm, default_config): """Context manager instance for testing.""" return ContextManager(llm=mock_llm, config=default_config) def test_token_estimation(self, context_manager): """Test token estimation functionality.""" # Test basic token estimation short_text = "Hello world" assert context_manager.estimate_tokens(short_text) == len(short_text) // 4 # Test longer text long_text = "This is a longer text that should have more tokens estimated." 
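# Worked example of the chars // 4 heuristic asserted above:
# len("Hello world") == 11, so the short text estimates 11 // 4 == 2
# tokens, and the longer sentence below must estimate strictly more.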
assert context_manager.estimate_tokens(long_text) > context_manager.estimate_tokens(short_text) def test_message_token_estimation(self, context_manager): """Test token estimation for messages.""" messages = [ HumanMessage(content="Hello"), AIMessage(content="Hi there!"), ToolMessage(content="Tool result", tool_call_id="123"), ] total_tokens = context_manager.estimate_message_tokens(messages) assert total_tokens > 0 # Should include content tokens plus metadata overhead assert total_tokens > sum(len(str(msg.content)) // 4 for msg in messages) def test_trim_short_tool_output(self, context_manager): """Test trimming tool output that's already short enough.""" short_output = "This is a short output." result = context_manager.trim_tool_output(short_output) assert result == short_output def test_trim_long_generic_output(self, context_manager): """Test trimming long generic tool output.""" long_output = "A" * 500 # Much longer than max_tool_output_length (200) result = context_manager.trim_tool_output(long_output) assert len(result) < len(long_output) assert "... [Output truncated] ..." in result assert result.startswith("A") assert result.endswith("A") def test_trim_sql_output(self, context_manager): """Test trimming SQL output with structured data.""" sql_output = """SQL Query: ```sql SELECT * FROM users WHERE age > 18; ``` Query Results (CSV format): ```csv id,name,age,email 1,John,25,john@example.com 2,Jane,30,jane@example.com 3,Bob,22,bob@example.com 4,Alice,28,alice@example.com 5,Charlie,35,charlie@example.com 6,Diana,27,diana@example.com 7,Eve,31,eve@example.com ``` Visualization Created: bar chart has been automatically generated.""" result = context_manager.trim_tool_output(sql_output) # Should preserve SQL query assert "SELECT * FROM users WHERE age > 18;" in result # Should preserve visualization info assert "Visualization Created:" in result # Should trim CSV data but keep structure assert "```csv" in result assert "rows omitted" in result def test_trim_code_output(self, context_manager): """Test trimming Python code execution output.""" # Test long output without errors long_code_output = "\n".join([f"Line {i}: Some output here" for i in range(50)]) result = context_manager.trim_tool_output(long_code_output) assert len(result.split("\n")) < 50 assert "... [Output truncated] ..." 
in result def test_preserve_error_output(self, context_manager): """Test that error outputs are preserved when configured.""" error_output = """Traceback (most recent call last): File "test.py", line 1, in <module> print(undefined_variable) NameError: name 'undefined_variable' is not defined""" # With preserve_tool_errors=True (default in test config) result = context_manager.trim_tool_output(error_output) assert result == error_output # Should be preserved in full # Test with preserve_tool_errors=False context_manager.config.preserve_tool_errors = False result = context_manager.trim_tool_output(error_output) # Should still preserve because it's an error, but could be trimmed based on length # Tool output trimming disable test removed - trimming is always enabled now def test_conversation_summary_disabled(self, context_manager): """Test conversation summary when disabled.""" context_manager.config.enable_conversation_summary = False messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")] summary = context_manager.summarize_conversation(messages) assert summary == "" @patch("openchatbi.context_manager.call_llm_chat_model_with_retry") def test_conversation_summary_success(self, mock_llm_call, context_manager): """Test successful conversation summarization.""" # Mock successful LLM response mock_response = AIMessage(content="Summary: User asked about data analysis.") mock_llm_call.return_value = mock_response messages = [ HumanMessage(content="Can you analyze our sales data?"), AIMessage(content="I'll help you analyze the sales data."), ToolMessage(content="Query results: 100 records", tool_call_id="123"), HumanMessage(content="What are the trends?"), AIMessage(content="The trends show increasing sales."), HumanMessage(content="Recent question"), # This should be excluded from summary ] summary = context_manager.summarize_conversation(messages) assert summary.startswith("[Conversation Summary]:") assert "Summary: User asked about data analysis."
in summary mock_llm_call.assert_called_once() @patch("openchatbi.context_manager.call_llm_chat_model_with_retry") def test_conversation_summary_failure(self, mock_llm_call, context_manager): """Test conversation summary when LLM call fails.""" # Mock LLM failure mock_llm_call.side_effect = Exception("LLM service unavailable") # Need more messages than keep_recent_messages (3) to trigger summarization messages = [ HumanMessage(content="First message"), AIMessage(content="First response"), HumanMessage(content="Second message"), AIMessage(content="Second response"), HumanMessage(content="Third message"), AIMessage(content="Third response"), ] summary = context_manager.summarize_conversation(messages) assert summary == "[Summary generation failed]" def test_manage_context_disabled(self, context_manager): """Test context management when disabled.""" context_manager.config.enabled = False messages = [HumanMessage(content="Test")] context_manager.manage_context_messages(messages) result = messages assert result == messages # Should return unchanged def test_manage_context_empty_messages(self, context_manager): """Test context management with empty message list.""" messages = [] context_manager.manage_context_messages(messages) result = messages assert result == [] def test_manage_context_tool_message_trimming(self, context_manager): """Test that tool messages are trimmed during context management.""" long_content = "A" * 500 # Add enough messages to trigger context management, with ToolMessage in historical part # keep_recent_messages=3, so we need more than 3 messages after the ToolMessage messages = [ HumanMessage(content="This is a long question that helps reach the token threshold " * 10), ToolMessage(content=long_content, tool_call_id="123"), # This should be in historical part AIMessage(content="This is a long response that helps reach the token threshold " * 10), HumanMessage(content="Another long question to increase token count " * 10), AIMessage(content="Response " * 20), HumanMessage(content="Final question"), AIMessage(content="Final response"), ] context_manager.manage_context_messages(messages) result = messages # Find the tool message in results tool_msg = next(msg for msg in result if isinstance(msg, ToolMessage)) assert len(str(tool_msg.content)) < len(long_content) assert "... [Output truncated] ..." 
in str(tool_msg.content) @patch("openchatbi.context_manager.call_llm_chat_model_with_retry") def test_manage_context_with_summarization(self, mock_llm_call, context_manager): """Test context management triggering summarization.""" # Mock successful summarization mock_response = AIMessage(content="Conversation summary here.") mock_llm_call.return_value = mock_response # Create many messages to trigger summarization messages = [] for i in range(10): messages.extend( [ HumanMessage(content=f"Question {i}"), AIMessage(content=f"Response {i}" * 100), # Long responses to increase token count ] ) original_length = len(messages) context_manager.manage_context_messages(messages) result = messages # Should have fewer messages than input assert len(result) < original_length # Should contain summary message assert any("[Conversation Summary]:" in str(msg.content) for msg in result if hasattr(msg, "content")) # Verify LLM was called for summarization mock_llm_call.assert_called_once() # Tool wrapper tests removed - we now handle context at state level def test_format_messages_for_summary(self, context_manager): """Test message formatting for summary generation.""" messages = [ HumanMessage(content="User question"), AIMessage(content="AI response"), ToolMessage(content="Tool result with some data", tool_call_id="123"), SystemMessage(content="System message"), # Should be excluded ] formatted = context_manager._format_messages_for_summary(messages) assert "User question" in formatted assert "AI response" in formatted assert "tool_result" in formatted assert "System message" not in formatted # System messages excluded def test_format_long_ai_message_for_summary(self, context_manager): """Test that long AI messages are truncated in summary formatting.""" long_content = "A" * 1000 messages = [AIMessage(content=long_content)] formatted = context_manager._format_messages_for_summary(messages) assert len(formatted) < len(f"Assistant: {long_content}") assert "... [truncated]" in formatted # Tool wrapping tests removed - we now handle context at state level instead of wrapping tools # Pytest fixtures and test data @pytest.fixture def sample_sql_output(): """Sample SQL output for testing.""" return """SQL Query: ```sql SELECT customer_id, SUM(amount) as total FROM orders WHERE order_date >= '2023-01-01' GROUP BY customer_id ORDER BY total DESC; ``` Query Results (CSV format): ```csv customer_id,total 1001,15420.50 1002,12350.75 1003,11200.00 1004,9875.25 1005,8650.00 1006,7500.50 1007,6200.75 1008,5800.00 1009,4950.25 1010,4200.00 ``` Visualization Created: bar chart has been automatically generated and will be displayed in the UI.""" @pytest.fixture def sample_error_output(): """Sample error output for testing.""" return """Traceback (most recent call last): File "/app/code.py", line 15, in analyze_data result = df.groupby('nonexistent_column').sum() File "/usr/local/lib/python3.9/site-packages/pandas/core/groupby/groupby.py", line 1647, in sum return self._cython_transform("sum", numeric_only=numeric_only, **kwargs) KeyError: 'nonexistent_column' Error: Column 'nonexistent_column' not found in DataFrame.
Available columns: ['customer_id', 'order_date', 'amount', 'product_id']""" ================================================ FILE: tests/context_management/test_edge_cases.py ================================================ """Edge cases for context management.""" from unittest.mock import Mock, patch import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from openchatbi.context_config import ContextConfig from openchatbi.context_manager import ContextManager class TestContextManagementEdgeCases: """Edge cases and boundary conditions for context management.""" @pytest.fixture def edge_case_config(self): """Configuration for edge case testing.""" return ContextConfig( enabled=True, summary_trigger_tokens=800, keep_recent_messages=2, max_tool_output_length=50, ) @pytest.fixture def context_manager(self, edge_case_config): """Context manager for edge case testing.""" return ContextManager(llm=Mock(), config=edge_case_config) def test_empty_and_none_inputs(self, context_manager): """Test handling of empty and None inputs.""" # Empty list messages = [] context_manager.manage_context_messages(messages) assert messages == [] # List with None elements (should be filtered out gracefully) messages = [HumanMessage(content="Test"), None, AIMessage(content="Response")] # Filter out None values before passing to context manager filtered_messages = [msg for msg in messages if msg is not None] context_manager.manage_context_messages(filtered_messages) result = filtered_messages assert len(result) == 2 def test_malformed_messages(self, context_manager): """Test handling of malformed messages.""" # Message with None content try: malformed_msg = HumanMessage(content=None) messages = [malformed_msg] context_manager.manage_context_messages(messages) result = messages # Should handle gracefully assert isinstance(result, list) except Exception as e: # If it raises an exception, it should be a reasonable one assert "content" in str(e).lower() def test_extremely_long_single_message(self, context_manager): """Test handling of extremely long single messages.""" # Create a message longer than the entire context limit very_long_content = "A" * 100000 # Much longer than context limit long_message = HumanMessage(content=very_long_content) messages = [long_message] context_manager.manage_context_messages(messages) result = messages # Should still return the message (context management doesn't trim individual message content) assert len(result) == 1 assert isinstance(result[0], HumanMessage) def test_tool_message_without_tool_call_id(self, context_manager): """Test handling of tool messages without proper tool_call_id.""" try: # This might raise an error depending on LangChain's validation tool_msg = ToolMessage(content="Result", tool_call_id="") messages = [tool_msg] context_manager.manage_context_messages(messages) result = messages assert isinstance(result, list) except Exception: # If LangChain validates and raises, that's acceptable pass def test_circular_references_in_content(self, context_manager): """Test handling of complex content that might cause issues.""" # Content with special characters and formatting special_content = ( """ Content with: - Unicode: 🚀 中文 العربية - Code blocks: ```python\nprint("hello")\n``` - JSON: {"key": "value", "nested": {"array": [1,2,3]}} - HTML:
<div>content</div>
- URLs: https://example.com/path?param=value - Very long line: """ + "X" * 1000 ) message = HumanMessage(content=special_content) messages = [message] context_manager.manage_context_messages(messages) result = messages assert len(result) == 1 assert isinstance(result[0], HumanMessage) def test_zero_configuration_values(self): """Test behavior with zero configuration values.""" zero_config = ContextConfig( enabled=True, summary_trigger_tokens=0, keep_recent_messages=0, max_tool_output_length=0, ) context_manager = ContextManager(llm=Mock(), config=zero_config) messages = [HumanMessage(content="Test")] # Should handle zero values gracefully context_manager.manage_context_messages(messages) result = messages assert isinstance(result, list) def test_negative_configuration_values(self): """Test behavior with negative configuration values.""" negative_config = ContextConfig( enabled=True, summary_trigger_tokens=-50, keep_recent_messages=-5, max_tool_output_length=-10, ) context_manager = ContextManager(llm=Mock(), config=negative_config) messages = [HumanMessage(content="Test")] # Should handle negative values gracefully (might treat as disabled) context_manager.manage_context_messages(messages) result = messages assert isinstance(result, list) def test_unicode_and_encoding_edge_cases(self, context_manager): """Test handling of various Unicode and encoding scenarios.""" unicode_messages = [ HumanMessage(content="English text"), HumanMessage(content="中文内容测试"), HumanMessage(content="العربية"), HumanMessage(content="Русский текст"), HumanMessage(content="🚀🎉💡🔥"), # Emojis HumanMessage(content="Mixed: Hello 世界 🌍"), ToolMessage(content="Unicode tool result: café naïve résumé", tool_call_id="unicode_1"), ] context_manager.manage_context_messages(unicode_messages) result = unicode_messages # Should handle all Unicode content assert len(result) > 0 assert all(isinstance(msg, (HumanMessage, AIMessage, ToolMessage)) for msg in result) def test_extremely_nested_or_complex_structures(self, context_manager): """Test handling of complex nested data structures in tool outputs.""" # Simulate deeply nested JSON output nested_data = {"level1": {"level2": {"level3": {"data": ["item"] * 1000}}}} complex_output = str(nested_data) * 100 # Make it very large # Create messages so the tool message is in historical part (not recent) # keep_recent_messages=2, so add more than 2 messages after the tool message messages = [ ToolMessage(content=complex_output, tool_call_id="complex_1"), # Historical part HumanMessage(content="Question 1"), AIMessage(content="Response 1"), HumanMessage(content="Recent question"), # Recent part starts here ] context_manager.manage_context_messages(messages) result = messages # Should trim the complex output since it's in historical part tool_msg = next(msg for msg in result if isinstance(msg, ToolMessage)) assert len(str(tool_msg.content)) < len(complex_output) def test_sql_output_edge_cases(self, context_manager): """Test SQL output trimming with edge cases.""" # SQL with no results empty_sql_output = """SQL Query: ```sql SELECT * FROM users WHERE id = -1; ``` Query Results (CSV format): ```csv id,name ```""" # SQL with single row single_row_sql = """SQL Query: ```sql SELECT COUNT(*) as total FROM users; ``` Query Results (CSV format): ```csv total 42 ```""" # Malformed SQL output malformed_sql = """Something that looks like SQL but isn't: ```sql INVALID QUERY HERE ``` Random text after""" test_cases = [empty_sql_output, single_row_sql, malformed_sql] for sql_output in test_cases: tool_msg = 
ToolMessage(content=sql_output, tool_call_id="sql_test") messages = [tool_msg] context_manager.manage_context_messages(messages) result = messages # Should handle all cases gracefully assert len(result) == 1 assert isinstance(result[0], ToolMessage) def test_conversation_state_consistency(self, context_manager): """Test that conversation state remains consistent through management.""" # Create a conversation with specific patterns (no SystemMessage in state) messages = [ HumanMessage(content="Question 1"), AIMessage(content="Response 1"), ToolMessage(content="Tool result 1", tool_call_id="tool_1"), HumanMessage(content="Question 2"), AIMessage( content="Response 2 with tool calls", tool_calls=[{"name": "test_tool", "args": {"param": "value"}, "id": "call_1"}], ), ToolMessage(content="Tool result 2", tool_call_id="call_1"), HumanMessage(content="Final question"), ] with patch( "openchatbi.context_manager.call_llm_chat_model_with_retry", return_value=AIMessage(content="Summary") ): context_manager.manage_context_messages(messages) result = messages # Should maintain message type consistency (only valid state message types) message_types = [type(msg) for msg in result] valid_types = {HumanMessage, AIMessage, ToolMessage} assert all( msg_type in valid_types for msg_type in message_types ), "Should only contain valid state message types" # Should not have orphaned tool messages without corresponding AI messages for i, msg in enumerate(result): if isinstance(msg, ToolMessage): # There should be an AI message with tool calls before this previous_ai_msgs = [m for m in result[:i] if isinstance(m, AIMessage)] assert len(previous_ai_msgs) > 0, "Tool message should have corresponding AI message" ================================================ FILE: tests/context_management/test_runner.py ================================================ """Test runner script for context management tests.""" import argparse import subprocess import sys from pathlib import Path def run_tests(test_type="all", verbose=False, coverage=False): """Run context management tests. 
Args: test_type: Type of tests to run ('all', 'unit', 'integration', 'edge_cases') verbose: Enable verbose output coverage: Enable coverage reporting """ # Base pytest command cmd = ["python", "-m", "pytest"] # Test directory test_dir = Path(__file__).parent # Add specific test files based on type if test_type == "all": cmd.append(str(test_dir)) elif test_type == "unit": cmd.extend([str(test_dir / "test_context_manager.py"), str(test_dir / "test_context_config.py")]) elif test_type == "integration": cmd.append(str(test_dir / "test_agent_graph_integration.py")) elif test_type == "edge_cases": cmd.extend([str(test_dir / "test_edge_cases.py"), str(test_dir / "test_state_operations.py")]) else: print(f"Unknown test type: {test_type}") return False # Add verbose flag if verbose: cmd.append("-v") # Add coverage if coverage: cmd.extend( [ "--cov=openchatbi.context_manager", "--cov=openchatbi.context_config", "--cov-report=html", "--cov-report=term-missing", ] ) # Add other useful flags cmd.extend( [ "--tb=short", # Shorter traceback format "-x", # Stop on first failure "--strict-markers", # Strict marker checking ] ) print(f"Running command: {' '.join(cmd)}") print("-" * 50) # Run the tests try: result = subprocess.run(cmd, check=False) return result.returncode == 0 except KeyboardInterrupt: print("\nTests interrupted by user") return False def main(): """Main function for test runner.""" parser = argparse.ArgumentParser(description="Run context management tests") parser.add_argument( "--type", "-t", choices=["all", "unit", "integration", "edge_cases"], default="all", help="Type of tests to run (default: all)", ) parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output") parser.add_argument("--coverage", "-c", action="store_true", help="Enable coverage reporting") args = parser.parse_args() success = run_tests(test_type=args.type, verbose=args.verbose, coverage=args.coverage) if success: print("\n✅ All tests passed!") sys.exit(0) else: print("\n❌ Some tests failed!") sys.exit(1) if __name__ == "__main__": main() ================================================ FILE: tests/context_management/test_state_operations.py ================================================ """Tests for message-based context management operations.""" from unittest.mock import Mock, patch import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from openchatbi.context_config import ContextConfig from openchatbi.context_manager import ContextManager class TestMessageBasedContextManagement: """Test message-based context management with direct modification.""" @pytest.fixture def test_config(self): """Configuration for testing message operations.""" return ContextConfig( enabled=True, summary_trigger_tokens=300, # Lower threshold to trigger management keep_recent_messages=3, max_tool_output_length=200, preserve_tool_errors=True, preserve_recent_sql=True, ) @pytest.fixture def context_manager(self, test_config): """Context manager for testing.""" mock_llm = Mock() return ContextManager(llm=mock_llm, config=test_config) def test_no_operations_when_disabled(self, context_manager): """Test that no operations are performed when context management is disabled.""" context_manager.config.enabled = False messages = [HumanMessage(content="Test", id="test_1")] original_messages = messages.copy() context_manager.manage_context_messages(messages) assert messages == original_messages # Should be unchanged def test_no_operations_when_under_limit(self, context_manager): """Test that 
no operations are performed when context is under token limit.""" # Short messages that won't trigger context management messages = [HumanMessage(content="Hi", id="human_1"), AIMessage(content="Hello", id="ai_1")] original_messages = messages.copy() context_manager.manage_context_messages(messages) assert messages == original_messages # Should be unchanged def test_historical_tool_compression(self, context_manager): """Test compression of historical tool messages.""" # Disable conversation summarization to test only tool compression context_manager.config.enable_conversation_summary = False context_manager.config.enable_summarization = False # Create messages with large historical tool outputs messages = [ HumanMessage(content="Query data", id="human_1"), AIMessage(content="Running query", id="ai_1"), # Large historical tool message (should be compressed) ToolMessage(content="A" * 1000, tool_call_id="query_1", id="tool_1_historical"), # Large content HumanMessage(content="More analysis", id="human_2"), AIMessage(content="Analyzing", id="ai_2"), # Another large historical tool message ToolMessage(content="B" * 800, tool_call_id="query_2", id="tool_2_historical"), # Large content # Recent messages (should be preserved) HumanMessage(content="Recent question", id="human_recent"), AIMessage(content="Recent response", id="ai_recent"), ToolMessage(content="Recent result", tool_call_id="recent_1", id="tool_recent"), ] original_count = len(messages) context_manager.manage_context_messages(messages) # Should have same number of messages but some content should be compressed assert len(messages) == original_count # Check that historical tool messages are compressed historical_tool_msgs = [ msg for msg in messages if isinstance(msg, ToolMessage) and msg.id in ["tool_1_historical", "tool_2_historical"] ] for msg in historical_tool_msgs: assert len(str(msg.content)) < 1000, "Historical tool messages should be compressed" def test_error_message_preservation(self, context_manager): """Test that error messages are preserved even if they're historical.""" error_content = """Traceback (most recent call last): File "test.py", line 1, in <module> raise ValueError("Test error") ValueError: Test error""" messages = [ HumanMessage(content="Run code", id="human_1"), AIMessage(content="Executing", id="ai_1"), # Historical error message (should be preserved) ToolMessage(content=error_content, tool_call_id="code_1", id="error_tool_historical"), # Recent messages HumanMessage(content="What happened?", id="human_recent"), AIMessage(content="There was an error", id="ai_recent"), ] original_error_content = messages[2].content context_manager.manage_context_messages(messages) # Error message should be preserved error_msg = next(msg for msg in messages if msg.id == "error_tool_historical") assert error_msg.content == original_error_content, "Error messages should be preserved" def test_sql_content_preservation(self, context_manager): """Test that SQL content is preserved when configured.""" sql_content = """SQL Query: ```sql SELECT * FROM users WHERE active = 1; ``` Query Results (CSV format): ```csv id,name,email 1,John,john@example.com 2,Jane,jane@example.com ```""" messages = [ HumanMessage(content="Get user data", id="human_1"), AIMessage(content="Querying users", id="ai_1"), # Historical SQL result (should be preserved if preserve_recent_sql=True) ToolMessage(content=sql_content, tool_call_id="sql_1", id="sql_tool_historical"), # Recent messages HumanMessage(content="Analyze results", id="human_recent"),
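# Assumption, based on the inline comments in this test: the trailing
# recent messages push the SQL ToolMessage into the historical window,
# so it is preserve_recent_sql (not recency) that keeps it intact.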
AIMessage(content="Analyzing", id="ai_recent"), ] # Test with SQL preservation enabled context_manager.config.preserve_recent_sql = True original_sql_content = messages[2].content context_manager.manage_context_messages(messages) # SQL should be preserved when preserve_recent_sql=True sql_msg = next(msg for msg in messages if msg.id == "sql_tool_historical") assert sql_msg.content == original_sql_content, "SQL content should be preserved when configured" @patch("openchatbi.context_manager.call_llm_chat_model_with_retry") def test_conversation_summarization(self, mock_llm_call, context_manager): """Test conversation summarization with message modification.""" # Mock LLM response for summarization mock_llm_call.return_value = AIMessage(content="Summary of the conversation") # Create a long conversation that will trigger summarization messages = [] # Add many historical messages for i in range(20): messages.extend( [ HumanMessage(content=f"Question {i}" * 10, id=f"human_{i}"), AIMessage(content=f"Response {i}" * 10, id=f"ai_{i}"), ] ) # Add recent messages messages.extend( [ HumanMessage(content="Recent question", id="human_recent"), AIMessage(content="Recent response", id="ai_recent"), ToolMessage(content="Recent result", tool_call_id="recent_1", id="tool_recent"), ] ) original_count = len(messages) context_manager.manage_context_messages(messages) # Should have fewer messages due to summarization assert len(messages) < original_count # Should have a summary message summary_msgs = [msg for msg in messages if isinstance(msg, AIMessage) and "Summary" in str(msg.content)] assert len(summary_msgs) > 0, "Should create a summary message" def test_content_type_detection(self, context_manager): """Test content type detection methods.""" # Test error content detection error_contents = [ "Error: Something went wrong", "Traceback (most recent call last):\n File test.py", "ValueError: Invalid input", "Connection failed with status 500", ] for content in error_contents: assert context_manager._is_error_content(content), f"Should detect error in: {content[:50]}" # Test SQL content detection sql_contents = [ "```sql\nSELECT * FROM users;\n```", "Query results: 100 rows returned", "SQL Query:\nSELECT id FROM table", ] for content in sql_contents: assert context_manager._is_sql_content(content), f"Should detect SQL in: {content[:50]}" # Test data query result detection data_contents = [ "```csv\nid,name\n1,test\n```", "Query Results (CSV format):", "Found 500 records in the database", ] for content in data_contents: assert context_manager._is_data_query_result(content), f"Should detect data result in: {content[:50]}" def test_should_compress_logic(self, context_manager): """Test the logic for determining whether to compress historical tool messages.""" # Short content should not be compressed short_msg = ToolMessage(content="Short", tool_call_id="test", id="short") assert not context_manager._should_compress_historical_tool_message(short_msg, "Short") # Long non-error content should be compressed long_content = "A" * 1000 long_msg = ToolMessage(content=long_content, tool_call_id="test", id="long") assert context_manager._should_compress_historical_tool_message(long_msg, long_content) # Long error content should not be compressed (if preserve_tool_errors=True) error_content = "Error: " + "A" * 1000 error_msg = ToolMessage(content=error_content, tool_call_id="test", id="error") context_manager.config.preserve_tool_errors = True assert not context_manager._should_compress_historical_tool_message(error_msg, 
error_content) # But should be compressed if preserve_tool_errors=False context_manager.config.preserve_tool_errors = False assert context_manager._should_compress_historical_tool_message(error_msg, error_content) def test_recent_messages_always_preserved(self, context_manager): """Test that recent messages are always preserved regardless of content.""" # Create messages where recent ones are large but should still be preserved messages = [] # Historical messages for i in range(10): messages.extend( [ HumanMessage(content=f"Historical {i}", id=f"hist_human_{i}"), ToolMessage(content="A" * 500, tool_call_id=f"hist_{i}", id=f"hist_tool_{i}"), ] ) # Recent messages (including large tool output) messages.extend( [ HumanMessage(content="Recent question", id="recent_human"), AIMessage(content="Recent response", id="recent_ai"), ToolMessage(content="B" * 1000, tool_call_id="recent", id="recent_tool"), # Large but recent ] ) original_count = len(messages) context_manager.manage_context_messages(messages) # Recent messages should be preserved (even if content gets compressed due to summarization) recent_ids = ["recent_human", "recent_ai", "recent_tool"] remaining_recent = [msg for msg in messages if hasattr(msg, "id") and msg.id in recent_ids] # All recent message IDs should still be present (even if summarization occurred) assert len(remaining_recent) >= 2, "Most recent messages should be preserved" def test_message_order_preservation(self, context_manager): """Test that message ordering is preserved during context management.""" # Disable conversation summarization to test only tool compression context_manager.config.enable_conversation_summary = False context_manager.config.enable_summarization = False # Create messages with specific order messages = [ HumanMessage(content="Question 1", id="human_1"), AIMessage(content="Response 1", id="ai_1"), ToolMessage(content="A" * 1000, tool_call_id="tool_1", id="tool_1"), # Will be compressed HumanMessage(content="Question 2", id="human_2"), AIMessage(content="Response 2", id="ai_2"), ToolMessage(content="B" * 1000, tool_call_id="tool_2", id="tool_2"), # Will be compressed HumanMessage(content="Recent question", id="human_recent"), # Recent, should not be compressed AIMessage(content="Recent response", id="ai_recent"), # Recent ToolMessage( content="C" * 1000, tool_call_id="tool_recent", id="tool_recent" ), # Recent, should not be compressed ] original_order = [msg.id for msg in messages if hasattr(msg, "id")] context_manager.manage_context_messages(messages) # Extract the IDs in the new order result_order = [msg.id for msg in messages if hasattr(msg, "id")] # The order should be preserved assert result_order == original_order, "Message order should be preserved" # Verify that historical tool messages were actually compressed historical_tools = [msg for msg in messages if isinstance(msg, ToolMessage) and msg.id in ["tool_1", "tool_2"]] for msg in historical_tools: assert len(str(msg.content)) < 1000, "Historical tool messages should be compressed" ================================================ FILE: tests/test_catalog_loader.py ================================================ """Tests for catalog loader functionality.""" from unittest.mock import Mock, patch import pytest from openchatbi.catalog.catalog_loader import DataCatalogLoader, load_catalog_from_data_warehouse class TestDataCatalogLoader: """Test DataCatalogLoader functionality.""" @pytest.fixture def mock_engine(self): """Mock SQLAlchemy engine.""" engine = Mock() return engine def 
test_catalog_loader_initialization(self, mock_engine): """Test DataCatalogLoader initialization.""" with patch("openchatbi.catalog.catalog_loader.inspect") as mock_inspect: mock_inspect.return_value = Mock() loader = DataCatalogLoader(engine=mock_engine, include_tables=["table1", "table2"]) assert loader.engine == mock_engine assert loader.include_tables == ["table1", "table2"] def test_catalog_loader_without_include_tables(self, mock_engine): """Test DataCatalogLoader without include tables.""" with patch("openchatbi.catalog.catalog_loader.inspect") as mock_inspect: mock_inspect.return_value = Mock() loader = DataCatalogLoader(engine=mock_engine, include_tables=None) assert loader.include_tables is None def test_get_tables_and_columns(self, mock_engine): """Test getting tables and columns metadata.""" # Mock inspector mock_inspector = Mock() mock_inspector.get_table_names.return_value = ["table1", "table2"] mock_inspector.get_columns.return_value = [ {"name": "col1", "type": "VARCHAR(50)", "comment": "Test column", "default": None, "primary_key": False} ] with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector): loader = DataCatalogLoader(engine=mock_engine, include_tables=["table1"]) result = loader.get_tables_and_columns() assert "table1" in result assert len(result["table1"]) == 1 assert result["table1"][0]["column_name"] == "col1" def test_get_table_indexes(self, mock_engine): """Test getting table indexes.""" mock_inspector = Mock() mock_inspector.get_indexes.return_value = [{"name": "idx_test", "column_names": ["col1"]}] with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector): loader = DataCatalogLoader(engine=mock_engine) result = loader.get_table_indexes("table1") assert len(result) == 1 assert result[0]["name"] == "idx_test" def test_get_foreign_keys(self, mock_engine): """Test getting foreign keys.""" mock_inspector = Mock() mock_inspector.get_foreign_keys.return_value = [ {"name": "fk_test", "constrained_columns": ["col1"], "referred_table": "ref_table"} ] with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector): loader = DataCatalogLoader(engine=mock_engine) result = loader.get_foreign_keys("table1") assert len(result) == 1 assert result[0]["name"] == "fk_test" def test_save_to_catalog_store_success(self, mock_engine): """Test saving to catalog store successfully.""" mock_catalog_store = Mock() mock_catalog_store.save_table_information.return_value = True mock_catalog_store.save_table_sql_examples.return_value = True mock_catalog_store.save_table_selection_examples.return_value = True mock_inspector = Mock() mock_inspector.get_table_names.return_value = ["table1"] mock_inspector.get_columns.return_value = [ {"name": "col1", "type": "VARCHAR(50)", "comment": "Test column", "default": None, "primary_key": False} ] mock_inspector.get_table_comment.return_value = {"text": "Test table"} with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector): loader = DataCatalogLoader(engine=mock_engine, include_tables=["table1"]) result = loader.save_to_catalog_store(mock_catalog_store, "test_db") assert result == True mock_catalog_store.save_table_information.assert_called() mock_catalog_store.save_table_sql_examples.assert_called() mock_catalog_store.save_table_selection_examples.assert_called() def test_save_to_catalog_store_failure(self, mock_engine): """Test handling catalog store save failures.""" mock_catalog_store = Mock() mock_catalog_store.save_table_information.return_value 
= False mock_inspector = Mock() mock_inspector.get_table_names.return_value = ["table1"] mock_inspector.get_columns.return_value = [] with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector): loader = DataCatalogLoader(engine=mock_engine) result = loader.save_to_catalog_store(mock_catalog_store) assert result == False def test_load_catalog_from_data_warehouse(self): """Test main entry point for catalog loading.""" mock_catalog_store = Mock() mock_catalog_store.get_data_warehouse_config.return_value = { "uri": "test://user@host/db", "include_tables": ["table1"], "database_name": "test_db", } mock_catalog_store.get_sql_engine.return_value = Mock() with patch("openchatbi.catalog.catalog_loader.DataCatalogLoader") as mock_loader_class: mock_loader = Mock() mock_loader.save_to_catalog_store.return_value = True mock_loader_class.return_value = mock_loader result = load_catalog_from_data_warehouse(mock_catalog_store) assert result == True mock_loader.save_to_catalog_store.assert_called_once() def test_error_handling_in_get_tables_and_columns(self, mock_engine): """Test error handling in get_tables_and_columns method.""" mock_inspector = Mock() mock_inspector.get_table_names.side_effect = Exception("Database error") with patch("openchatbi.catalog.catalog_loader.inspect", return_value=mock_inspector): loader = DataCatalogLoader(engine=mock_engine) result = loader.get_tables_and_columns() assert result == {} ================================================ FILE: tests/test_catalog_store.py ================================================ """Tests for catalog store functionality.""" import pytest from openchatbi.catalog.catalog_store import CatalogStore from openchatbi.catalog.store.file_system import FileSystemCatalogStore class TestCatalogStore: """Test base CatalogStore functionality.""" def test_catalog_store_is_abstract(self): """Test that CatalogStore cannot be instantiated directly.""" with pytest.raises(TypeError): CatalogStore() def test_catalog_store_interface_methods(self): """Test that CatalogStore defines required interface methods.""" # Check that abstract methods exist assert hasattr(CatalogStore, "get_table_list") assert hasattr(CatalogStore, "get_column_list") assert hasattr(CatalogStore, "get_table_information") assert hasattr(CatalogStore, "get_data_warehouse_config") assert hasattr(CatalogStore, "get_sql_engine") assert hasattr(CatalogStore, "save_table_information") class TestFileSystemCatalogStore: """Test FileSystemCatalogStore functionality.""" def test_filesystem_store_initialization(self, temp_dir): """Test FileSystemCatalogStore initialization.""" data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"} data_path = str(temp_dir) store = FileSystemCatalogStore(data_path=data_path, data_warehouse_config=data_warehouse_config) assert store.data_path == data_path assert isinstance(store, CatalogStore) def test_get_tables_from_csv(self, mock_catalog_store): """Test getting tables from CSV file.""" tables = mock_catalog_store.get_table_list() assert isinstance(tables, list) assert len(tables) >= 1 def test_get_columns_from_csv(self, mock_catalog_store): """Test getting columns from CSV file.""" columns = mock_catalog_store.get_column_list("test_table", "test") assert isinstance(columns, list) if columns: column = columns[0] assert "column_name" in column or "name" in column assert "data_type" in column or "type" in column def test_get_table_info(self, mock_catalog_store): """Test getting table 
information.""" table_info = mock_catalog_store.get_table_information("test.test_table") assert isinstance(table_info, dict) def test_get_tables_file_not_found(self, temp_dir): """Test handling when tables file doesn't exist.""" empty_dir = temp_dir / "empty" empty_dir.mkdir() data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"} store = FileSystemCatalogStore(data_path=str(empty_dir), data_warehouse_config=data_warehouse_config) # Should handle missing file gracefully tables = store.get_table_list() assert isinstance(tables, list) def test_get_columns_file_not_found(self, temp_dir): """Test handling when columns file doesn't exist.""" empty_dir = temp_dir / "empty" empty_dir.mkdir() data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"} store = FileSystemCatalogStore(data_path=str(empty_dir), data_warehouse_config=data_warehouse_config) # Should handle missing file gracefully columns = store.get_column_list("nonexistent_table") assert isinstance(columns, list) def test_get_tables_malformed_csv(self, temp_dir): """Test handling malformed CSV files.""" # Create malformed CSV malformed_csv = temp_dir / "table_columns.csv" malformed_csv.write_text("invalid,csv,format\\nno,proper\\nheaders") data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"} store = FileSystemCatalogStore(data_path=str(temp_dir), data_warehouse_config=data_warehouse_config) # Should handle malformed CSV gracefully tables = store.get_table_list() assert isinstance(tables, list) def test_get_tables_pandas_error(self, temp_dir): """Test handling pandas errors.""" data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"} store = FileSystemCatalogStore(data_path=str(temp_dir), data_warehouse_config=data_warehouse_config) # Should handle pandas errors gracefully tables = store.get_table_list() assert isinstance(tables, list) def test_get_table_schema(self, mock_catalog_store): """Test getting complete table schema.""" # Use get_table_information instead of get_table_schema schema = mock_catalog_store.get_table_information("test.test_table") assert isinstance(schema, dict) def test_search_tables(self, mock_catalog_store): """Test searching for tables by keyword.""" # This method might not exist in current implementation # but it's a common catalog feature if hasattr(mock_catalog_store, "search_tables"): results = mock_catalog_store.search_tables("test") assert isinstance(results, list) def test_get_all_table_names(self, mock_catalog_store): """Test getting all table names.""" tables = mock_catalog_store.get_table_list() # get_table_list() returns list of strings (table names), not dictionaries assert isinstance(tables, list) # Verify all items are strings for table_name in tables: assert isinstance(table_name, str) def test_case_insensitive_table_lookup(self, mock_catalog_store): """Test case-insensitive table lookups.""" # Test with different cases test_cases = ["test_table", "TEST_TABLE", "Test_Table"] for table_name in test_cases: columns = mock_catalog_store.get_column_list(table_name) assert isinstance(columns, list) def test_data_path_validation(self): """Test data path validation.""" data_warehouse_config = {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"} # Test with None path with pytest.raises((ValueError, TypeError)): FileSystemCatalogStore(data_path=None, 
data_warehouse_config=data_warehouse_config) # Test with empty string with pytest.raises((ValueError, FileNotFoundError)): FileSystemCatalogStore(data_path="", data_warehouse_config=data_warehouse_config) def test_concurrent_access(self, mock_catalog_store): """Test concurrent access to catalog store.""" import threading import time results = [] errors = [] def worker(): try: tables = mock_catalog_store.get_table_list() results.append(len(tables)) time.sleep(0.01) columns = mock_catalog_store.get_column_list("test_table", "test") results.append(len(columns)) except Exception as e: errors.append(e) # Create multiple threads threads = [] for _ in range(5): thread = threading.Thread(target=worker) threads.append(thread) # Start all threads for thread in threads: thread.start() # Wait for completion for thread in threads: thread.join() # Should not have errors from concurrent access assert len(errors) == 0 assert len(results) > 0 ================================================ FILE: tests/test_config_loader.py ================================================ """Tests for configuration loading functionality.""" from unittest.mock import MagicMock, patch import pytest import yaml from openchatbi.config_loader import Config, ConfigLoader class TestConfigLoader: """Test configuration loading functionality.""" def test_config_initialization(self): """Test Config model initialization.""" from unittest.mock import MagicMock mock_llm = MagicMock() mock_embedding = MagicMock() config = Config(organization="TestOrg", dialect="presto", default_llm=mock_llm, embedding_model=mock_embedding) assert config.organization == "TestOrg" assert config.dialect == "presto" assert config.default_llm == mock_llm assert config.embedding_model == mock_embedding def test_config_from_dict(self): """Test creating Config from dictionary.""" from unittest.mock import MagicMock mock_llm = MagicMock() mock_embedding = MagicMock() config_dict = { "organization": "TestOrg", "dialect": "mysql", "default_llm": mock_llm, "embedding_model": mock_embedding, } config = Config.from_dict(config_dict) assert config.organization == "TestOrg" assert config.dialect == "mysql" assert config.default_llm == mock_llm assert config.embedding_model == mock_embedding def test_config_loader_initialization(self): """Test ConfigLoader initialization.""" loader = ConfigLoader() # Initially, config should be None until loaded # Don't assert _config state since it depends on previous tests def test_load_config_from_file(self, temp_dir): """Test loading configuration from YAML file.""" config_data = { "organization": "TestOrg", "dialect": "presto", "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}}, "embedding_model": { "class": "langchain_openai.OpenAIEmbeddings", "params": {"model": "text-embedding-ada-002"}, }, "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"}, } config_file = temp_dir / "test_config.yaml" with open(config_file, "w") as f: yaml.dump(config_data, f) loader = ConfigLoader() with ( patch("langchain_openai.ChatOpenAI") as mock_openai, patch("langchain_openai.OpenAIEmbeddings") as mock_embeddings, ): # Create a proper mock that satisfies BaseChatModel interface from langchain_core.language_models import BaseChatModel mock_llm_instance = MagicMock(spec=BaseChatModel) mock_embedding_instance = MagicMock() mock_openai.return_value = mock_llm_instance mock_embeddings.return_value = mock_embedding_instance loader.load(str(config_file)) config = 
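Note: the concurrent-access test above manages raw threads by hand. The same read-path check can be written more compactly with the standard library's ThreadPoolExecutor; this is a hedged alternative sketch, not the repository's code, and `store` stands in for any CatalogStore implementation.

# Hypothetical compaction of the concurrent-access check using ThreadPoolExecutor.
from concurrent.futures import ThreadPoolExecutor


def read_once(store):
    # Each worker performs the same read path as the threaded test.
    return len(store.get_table_list())


def assert_thread_safe_reads(store, workers=5):
    with ThreadPoolExecutor(max_workers=workers) as pool:
        counts = list(pool.map(read_once, [store] * workers))
    # All workers should observe the same table count.
    assert len(set(counts)) == 1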
================================================
FILE: tests/test_config_loader.py
================================================
"""Tests for configuration loading functionality."""

from unittest.mock import MagicMock, patch

import pytest
import yaml

from openchatbi.config_loader import Config, ConfigLoader


class TestConfigLoader:
    """Test configuration loading functionality."""

    def test_config_initialization(self):
        """Test Config model initialization."""
        from unittest.mock import MagicMock

        mock_llm = MagicMock()
        mock_embedding = MagicMock()
        config = Config(organization="TestOrg", dialect="presto", default_llm=mock_llm, embedding_model=mock_embedding)
        assert config.organization == "TestOrg"
        assert config.dialect == "presto"
        assert config.default_llm == mock_llm
        assert config.embedding_model == mock_embedding

    def test_config_from_dict(self):
        """Test creating Config from dictionary."""
        from unittest.mock import MagicMock

        mock_llm = MagicMock()
        mock_embedding = MagicMock()
        config_dict = {
            "organization": "TestOrg",
            "dialect": "mysql",
            "default_llm": mock_llm,
            "embedding_model": mock_embedding,
        }
        config = Config.from_dict(config_dict)
        assert config.organization == "TestOrg"
        assert config.dialect == "mysql"
        assert config.default_llm == mock_llm
        assert config.embedding_model == mock_embedding

    def test_config_loader_initialization(self):
        """Test ConfigLoader initialization."""
        loader = ConfigLoader()
        # Initially, config should be None until loaded
        # Don't assert _config state since it depends on previous tests

    def test_load_config_from_file(self, temp_dir):
        """Test loading configuration from YAML file."""
        config_data = {
            "organization": "TestOrg",
            "dialect": "presto",
            "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
            "embedding_model": {
                "class": "langchain_openai.OpenAIEmbeddings",
                "params": {"model": "text-embedding-ada-002"},
            },
            "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
        }
        config_file = temp_dir / "test_config.yaml"
        with open(config_file, "w") as f:
            yaml.dump(config_data, f)
        loader = ConfigLoader()
        with (
            patch("langchain_openai.ChatOpenAI") as mock_openai,
            patch("langchain_openai.OpenAIEmbeddings") as mock_embeddings,
        ):
            # Create a proper mock that satisfies BaseChatModel interface
            from langchain_core.language_models import BaseChatModel

            mock_llm_instance = MagicMock(spec=BaseChatModel)
            mock_embedding_instance = MagicMock()
            mock_openai.return_value = mock_llm_instance
            mock_embeddings.return_value = mock_embedding_instance
            loader.load(str(config_file))
            config = loader.get()
            assert config.organization == "TestOrg"
            assert config.dialect == "presto"
            assert config.default_llm == mock_llm_instance
            assert config.embedding_model == mock_embedding_instance

    def test_load_config_missing_file(self):
        """Test handling of missing configuration file."""
        loader = ConfigLoader()
        # Reset the config to ensure clean state
        loader._config = None
        # The loader now logs and returns instead of raising FileNotFoundError
        loader.load("/nonexistent/path.yaml")
        # Verify that the config was not loaded (remains None)
        with pytest.raises(ValueError, match="Configuration has not been loaded"):
            loader.get()

    def test_load_config_invalid_yaml(self, temp_dir):
        """Test handling of invalid YAML syntax."""
        config_file = temp_dir / "invalid_config.yaml"
        config_file.write_text("invalid: yaml: content: [")
        loader = ConfigLoader()
        with pytest.raises(ValueError, match="Invalid YAML in configuration file"):
            loader.load(str(config_file))

    def test_load_config_with_bi_config_file(self, temp_dir):
        """Test loading configuration with BI config file."""
        bi_config_data = {"metrics": ["revenue", "users"], "dimensions": ["date", "region"]}
        bi_config_file = temp_dir / "bi_config.yaml"
        with open(bi_config_file, "w") as f:
            yaml.dump(bi_config_data, f)
        config_data = {
            "organization": "TestOrg",
            "bi_config_file": str(bi_config_file),
            "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
            "embedding_model": {
                "class": "langchain_openai.OpenAIEmbeddings",
                "params": {"model": "text-embedding-ada-002"},
            },
            "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
        }
        config_file = temp_dir / "test_config.yaml"
        with open(config_file, "w") as f:
            yaml.dump(config_data, f)
        loader = ConfigLoader()
        with (
            patch("langchain_openai.ChatOpenAI") as mock_openai,
            patch("langchain_openai.OpenAIEmbeddings") as mock_embeddings,
        ):
            mock_llm_instance = MagicMock()
            mock_embedding_instance = MagicMock()
            mock_openai.return_value = mock_llm_instance
            mock_embeddings.return_value = mock_embedding_instance
            loader.load(str(config_file))
            config = loader.get()
            assert config.bi_config["metrics"] == ["revenue", "users"]
            assert config.bi_config["dimensions"] == ["date", "region"]

    def test_load_config_with_catalog_store(self, temp_dir):
        """Test loading configuration with catalog store."""
        config_data = {
            "organization": "TestOrg",
            "catalog_store": {"store_type": "file_system", "data_path": str(temp_dir / "catalog_data")},
            "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
            "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
            "embedding_model": {
                "class": "langchain_openai.OpenAIEmbeddings",
                "params": {"model": "text-embedding-ada-002"},
            },
        }
        config_file = temp_dir / "test_config.yaml"
        with open(config_file, "w") as f:
            yaml.dump(config_data, f)
        loader = ConfigLoader()
        with (
            patch("langchain_openai.ChatOpenAI") as mock_openai,
            patch("langchain_openai.OpenAIEmbeddings") as mock_embeddings,
        ):
            mock_llm_instance = MagicMock()
            mock_embedding_instance = MagicMock()
            mock_openai.return_value = mock_llm_instance
            mock_embeddings.return_value = mock_embedding_instance
            loader.load(str(config_file))
            config = loader.get()
            # Just verify that a catalog store was created
            assert config.catalog_store is not None
            assert hasattr(config.catalog_store, "get_table_list")

    def test_load_config_with_llm_configs(self, temp_dir):
        """Test loading configuration with LLM configs."""
        config_data = {
            "organization": "TestOrg",
            "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4", "temperature": 0.1}},
            "embedding_model": {
                "class": "langchain_openai.OpenAIEmbeddings",
                "params": {"model": "text-embedding-ada-002"},
            },
            "text2sql_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-3.5-turbo"}},
            "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
        }
        config_file = temp_dir / "test_config.yaml"
        with open(config_file, "w") as f:
            yaml.dump(config_data, f)
        loader = ConfigLoader()
        with (
            patch("langchain_openai.ChatOpenAI") as mock_openai,
            patch("langchain_openai.OpenAIEmbeddings") as mock_embeddings,
        ):
            # Create proper mocks that satisfy BaseChatModel interface
            from langchain_core.language_models import BaseChatModel

            mock_instance1 = MagicMock(spec=BaseChatModel)
            mock_instance2 = MagicMock(spec=BaseChatModel)
            mock_embedding_instance = MagicMock()
            mock_openai.side_effect = [mock_instance1, mock_instance2]
            mock_embeddings.return_value = mock_embedding_instance
            loader.load(str(config_file))
            config = loader.get()
            assert config.default_llm == mock_instance1
            assert config.embedding_model == mock_embedding_instance
            assert config.text2sql_llm == mock_instance2

    def test_load_config_with_llm_providers_selected_by_default_llm(self, temp_dir):
        """Test loading configuration using llm_providers with default_llm provider selector."""
        config_data = {
            "organization": "TestOrg",
            "dialect": "presto",
            "default_llm": "openai",
            "llm_providers": {
                "openai": {
                    "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
                    "embedding_model": {
                        "class": "langchain_openai.OpenAIEmbeddings",
                        "params": {"model": "text-embedding-ada-002"},
                    },
                },
                "anthropic": {
                    "default_llm": {"class": "langchain_anthropic.ChatAnthropic", "params": {"model": "claude"}},
                },
            },
            "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
        }
        config_file = temp_dir / "test_config.yaml"
        with open(config_file, "w") as f:
            yaml.dump(config_data, f)
        loader = ConfigLoader()
        with (
            patch("langchain_openai.ChatOpenAI") as mock_openai,
            patch("langchain_openai.OpenAIEmbeddings") as mock_embeddings,
            patch("langchain_anthropic.ChatAnthropic") as mock_anthropic,
        ):
            from langchain_core.language_models import BaseChatModel

            mock_openai_instance = MagicMock(spec=BaseChatModel)
            mock_anthropic_instance = MagicMock(spec=BaseChatModel)
            mock_embedding_instance = MagicMock()
            mock_openai.return_value = mock_openai_instance
            mock_anthropic.return_value = mock_anthropic_instance
            mock_embeddings.return_value = mock_embedding_instance
            loader.load(str(config_file))
            config = loader.get()
            assert config.llm_provider == "openai"
            assert config.default_llm == mock_openai_instance
            assert config.embedding_model == mock_embedding_instance
            assert set(config.llm_providers.keys()) == {"openai", "anthropic"}

    def test_set_config(self):
        """Test setting configuration from dictionary."""
        config_dict = {
            "organization": "SetOrg",
            "dialect": "postgresql",
            "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
            "embedding_model": {
                "class": "langchain_openai.OpenAIEmbeddings",
                "params": {"model": "text-embedding-ada-002"},
            },
            "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
        }
        loader = ConfigLoader()
        with (
            patch("langchain_openai.ChatOpenAI") as mock_openai,
            patch("langchain_openai.OpenAIEmbeddings") as mock_embeddings,
        ):
            mock_llm_instance = MagicMock()
            mock_embedding_instance = MagicMock()
            mock_openai.return_value = mock_llm_instance
            mock_embeddings.return_value = mock_embedding_instance
            loader.set(config_dict)
            config = loader.get()
            assert config.organization == "SetOrg"
            assert config.dialect == "postgresql"

    def test_get_config_not_loaded(self):
        """Test getting configuration when not loaded."""
        loader = ConfigLoader()
        loader._config = None
        with pytest.raises(ValueError, match="Configuration has not been loaded"):
            loader.get()

    def test_load_bi_config_missing_file(self, temp_dir):
        """Test loading missing BI config file."""
        nonexistent_file = temp_dir / "nonexistent_bi.yaml"
        loader = ConfigLoader()
        # Should not raise exception, just return empty dict
        result = loader.load_bi_config(str(nonexistent_file))
        assert result == {}

    def test_catalog_store_missing_store_type(self, temp_dir):
        """Test catalog store configuration without store_type."""
        config_data = {
            "organization": "TestOrg",
            "catalog_store": {
                "data_path": "/test/path"  # Missing store_type
            },
            "default_llm": {"class": "langchain_openai.ChatOpenAI", "params": {"model": "gpt-4"}},
            "embedding_model": {
                "class": "langchain_openai.OpenAIEmbeddings",
                "params": {"model": "text-embedding-ada-002"},
            },
            "data_warehouse_config": {"uri": "sqlite:///:memory:", "include_tables": None, "database_name": "test_db"},
        }
        config_file = temp_dir / "test_config.yaml"
        with open(config_file, "w") as f:
            yaml.dump(config_data, f)
        loader = ConfigLoader()
        with pytest.raises(ValueError, match="catalog_store must have a store_type field"):
            loader.load(str(config_file))
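Note: these tests patch `langchain_openai.ChatOpenAI` by its dotted path, which implies ConfigLoader resolves the `{"class": ..., "params": ...}` blocks dynamically. The sketch below shows one common way such a block can be instantiated; it is an assumption for illustration, not the repository's exact implementation.

# Minimal sketch (assumption, not openchatbi's actual code) of turning a
# {"class": "pkg.mod.ClassName", "params": {...}} spec into a live object.
from importlib import import_module


def build_from_spec(spec: dict):
    # Split "langchain_openai.ChatOpenAI" into module path and class name.
    module_path, _, class_name = spec["class"].rpartition(".")
    cls = getattr(import_module(module_path), class_name)
    return cls(**spec.get("params", {}))


# Usage with a stdlib class so the sketch is runnable anywhere:
counter = build_from_spec({"class": "collections.Counter", "params": {}})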
original_state["sql"] == "SELECT 1;" assert original_state["agent_next_node"] == "original_node" assert original_state["final_answer"] == "Original answer" # Updated state should have new values assert len(updated_state["messages"]) == 2 assert updated_state["sql"] == "SELECT 2;" assert updated_state["agent_next_node"] == "updated_node" assert updated_state["final_answer"] == "Updated answer" class TestInputState: """Test InputState functionality.""" def test_input_state_creation(self): """Test creating InputState.""" messages = [HumanMessage(content="Input message")] state = InputState(messages=messages) assert state["messages"] == messages def test_input_state_empty_messages(self): """Test InputState with empty messages.""" state = InputState(messages=[]) assert state["messages"] == [] class TestOutputState: """Test OutputState functionality.""" def test_output_state_creation(self): """Test creating OutputState.""" messages = [AIMessage(content="Output message")] state = OutputState(messages=messages) assert state["messages"] == messages def test_output_state_with_multiple_messages(self): """Test OutputState with conversation history.""" messages = [ HumanMessage(content="Question"), AIMessage(content="Answer"), HumanMessage(content="Follow-up"), AIMessage(content="Final response"), ] state = OutputState(messages=messages) assert len(state["messages"]) == 4 assert state["messages"] == messages class TestStateIntegration: """Test integration between different state types.""" def test_input_to_agent_state_conversion(self): """Test converting InputState to AgentState.""" input_messages = [HumanMessage(content="User input")] input_state = InputState(messages=input_messages) # Simulate conversion to AgentState agent_state = AgentState(messages=input_state["messages"], sql="", agent_next_node="", final_answer="") assert agent_state["messages"] == input_messages assert agent_state["sql"] == "" def test_agent_to_output_state_conversion(self): """Test converting AgentState to OutputState.""" agent_messages = [HumanMessage(content="Question"), AIMessage(content="Generated response")] agent_state = AgentState( messages=agent_messages, sql="SELECT * FROM test_table;", agent_next_node="output", final_answer="Generated response", ) # Simulate conversion to OutputState output_state = OutputState(messages=agent_state["messages"]) assert output_state["messages"] == agent_messages def test_state_serialization_compatibility(self): """Test that states can be serialized and deserialized.""" original_state = AgentState( messages=[HumanMessage(content="Test"), AIMessage(content="Response")], sql="SELECT COUNT(*) FROM table1;", agent_next_node="final", final_answer="Count results", ) # Convert to dict (simulating serialization) state_dict = { "messages": original_state["messages"], "sql": original_state["sql"], "agent_next_node": original_state["agent_next_node"], "final_answer": original_state["final_answer"], } # Recreate from dict (simulating deserialization) recreated_state = AgentState(**state_dict) assert recreated_state["messages"] == original_state["messages"] assert recreated_state["sql"] == original_state["sql"] assert recreated_state["agent_next_node"] == original_state["agent_next_node"] assert recreated_state["final_answer"] == original_state["final_answer"] ================================================ FILE: tests/test_incomplete_tool_calls.py ================================================ """Tests for incomplete tool call recovery functionality.""" from unittest.mock import Mock from 
================================================
FILE: tests/test_incomplete_tool_calls.py
================================================
"""Tests for incomplete tool call recovery functionality."""

from unittest.mock import Mock

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

from openchatbi.agent_graph import agent_llm_call
from openchatbi.graph_state import AgentState
from openchatbi.utils import recover_incomplete_tool_calls


class TestIncompleteToolCallRecovery:
    """Test cases for recover_incomplete_tool_calls function."""

    def test_no_messages(self):
        """Test recovery with empty message list."""
        state = AgentState(messages=[])
        result = recover_incomplete_tool_calls(state)
        assert result == []

    def test_no_tool_calls(self):
        """Test recovery with messages but no tool calls."""
        messages = [HumanMessage(content="Hello"), AIMessage(content="Hi there!")]
        state = AgentState(messages=messages)
        result = recover_incomplete_tool_calls(state)
        assert result == []

    def test_complete_tool_calls(self):
        """Test recovery when all tool calls have responses."""
        messages = [
            HumanMessage(content="Search for data"),
            AIMessage(
                content="I'll search for that data.",
                tool_calls=[{"name": "search", "args": {"query": "data"}, "id": "call_1"}],
            ),
            ToolMessage(content="Search completed", tool_call_id="call_1"),
        ]
        state = AgentState(messages=messages)
        result = recover_incomplete_tool_calls(state)
        assert result == []

    def test_incomplete_single_tool_call(self):
        """Test recovery when there's one incomplete tool call."""
        messages = [
            HumanMessage(content="Search for data"),
            AIMessage(
                content="I'll search for that data.",
                tool_calls=[{"name": "search", "args": {"query": "data"}, "id": "call_1"}],
            ),
        ]
        state = AgentState(messages=messages)
        result = recover_incomplete_tool_calls(state)
        assert isinstance(result, list)
        assert len(result) == 1  # Just the recovery message
        failure_msg = result[0]
        assert failure_msg.tool_call_id == "call_1"
        assert "interrupted" in failure_msg.content.lower()

    def test_incomplete_multiple_tool_calls(self):
        """Test recovery when there are multiple incomplete tool calls."""
        messages = [
            HumanMessage(content="Search and analyze"),
            AIMessage(
                content="I'll search and analyze.",
                tool_calls=[
                    {"name": "search", "args": {"query": "data"}, "id": "call_1"},
                    {"name": "analyze", "args": {"data": "result"}, "id": "call_2"},
                ],
            ),
        ]
        state = AgentState(messages=messages)
        result = recover_incomplete_tool_calls(state)
        assert isinstance(result, list)
        assert len(result) == 2  # Just the recovery messages
        # Check that both tool calls get failure messages
        recovery_messages = result
        tool_call_ids = {msg.tool_call_id for msg in recovery_messages}
        assert tool_call_ids == {"call_1", "call_2"}
        for msg in recovery_messages:
            assert isinstance(msg, ToolMessage)
            assert "interrupted" in msg.content.lower()

    def test_partial_incomplete_tool_calls(self):
        """Test recovery when some tool calls are complete, others are not."""
        messages = [
            HumanMessage(content="Search and analyze"),
            AIMessage(
                content="I'll search and analyze.",
                tool_calls=[
                    {"name": "search", "args": {"query": "data"}, "id": "call_1"},
                    {"name": "analyze", "args": {"data": "result"}, "id": "call_2"},
                ],
            ),
            ToolMessage(content="Search completed", tool_call_id="call_1"),
            # Missing ToolMessage for call_2
        ]
        state = AgentState(messages=messages)
        result = recover_incomplete_tool_calls(state)
        assert isinstance(result, list)
        assert len(result) == 3  # RemoveMessage + recovery message + re-added message
        # Should have: RemoveMessage, ToolMessage(recovery for call_2), ToolMessage(original for call_1)
        operations = result
        assert "RemoveMessage" in str(type(operations[0]))  # Remove the existing ToolMessage
        assert isinstance(operations[1], ToolMessage)  # Recovery message for call_2
        assert isinstance(operations[2], ToolMessage)  # Re-added original message for call_1
        # The recovery message should be for call_2
        recovery_msg = operations[1]
        assert recovery_msg.tool_call_id == "call_2"
        assert "interrupted" in recovery_msg.content.lower()
        # The re-added message should be the original for call_1
        original_msg = operations[2]
        assert original_msg.tool_call_id == "call_1"
        assert original_msg.content == "Search completed"

    def test_multiple_ai_messages_with_tool_calls(self):
        """Test recovery considers only the last AIMessage with tool calls."""
        messages = [
            HumanMessage(content="First task"),
            AIMessage(content="Doing first task.", tool_calls=[{"name": "task1", "args": {}, "id": "old_call"}]),
            ToolMessage(content="Task 1 done", tool_call_id="old_call"),
            HumanMessage(content="Second task"),
            AIMessage(content="Doing second task.", tool_calls=[{"name": "task2", "args": {}, "id": "new_call"}]),
            # Missing ToolMessage for new_call
        ]
        state = AgentState(messages=messages)
        result = recover_incomplete_tool_calls(state)
        assert isinstance(result, list)
        assert len(result) == 1  # Just the recovery message
        # The recovery message should be for new_call only
        recovery_msg = result[0]
        assert recovery_msg.tool_call_id == "new_call"
        assert "interrupted" in recovery_msg.content.lower()

    def test_llm_node_integration_with_recovery(self):
        """Test that the llm_node handles recovery correctly and continues processing."""
        # Create a mock llm_node function for testing
        mock_llm = Mock()
        mock_tools = []
        llm_node_func = agent_llm_call(mock_llm, mock_tools)
        # State with incomplete tool calls
        messages = [
            HumanMessage(content="Search for data"),
            AIMessage(
                content="I'll search for that data.",
                tool_calls=[{"name": "search", "args": {"query": "data"}, "id": "call_1"}],
            ),
        ]
        state = AgentState(messages=messages)
        # Call the llm node - it should detect incomplete tool calls and return recovery
        result = llm_node_func(state)
        # Should return message operations and continue to llm_node
        assert "messages" in result
        assert "agent_next_node" in result
        assert result["agent_next_node"] == "llm_node"
        # Should have recovery ToolMessage operation for the incomplete call
        operations = result["messages"]
        assert len(operations) == 1  # Only recovery message needed
        assert isinstance(operations[0], ToolMessage)
        assert operations[0].tool_call_id == "call_1"
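Note: the core detection step these tests exercise is matching tool_call ids on the last AIMessage against the ToolMessages that follow it. A hedged sketch of that step (the authoritative recovery logic is openchatbi.utils.recover_incomplete_tool_calls; this helper name is hypothetical):

# Hypothetical helper: find tool_call ids with no matching ToolMessage response.
from langchain_core.messages import AIMessage, ToolMessage


def dangling_tool_call_ids(messages):
    # Walk backwards to the most recent AIMessage that issued tool calls.
    last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage) and m.tool_calls), None)
    if last_ai is None:
        return set()
    idx = messages.index(last_ai)
    # Ids that already received a ToolMessage after that AIMessage.
    answered = {m.tool_call_id for m in messages[idx + 1 :] if isinstance(m, ToolMessage)}
    return {tc["id"] for tc in last_ai.tool_calls} - answered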
================================================
FILE: tests/test_memory.py
================================================
"""Tests for memory tool functionality."""

from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch

import pytest
from langchain_core.language_models import FakeListChatModel
from langchain_openai import ChatOpenAI

# Check if pysqlite3 is available, if not skip these tests
pysqlite3 = pytest.importorskip("pysqlite3", reason="pysqlite3 not available")

from openchatbi.tool.memory import (
    StructuredToolWithRequired,
    UserProfile,
    cleanup_async_memory_store,
    fix_schema_for_openai,
    get_async_memory_store,
    get_async_memory_tools,
    get_memory_manager,
    get_memory_tools,
    get_sync_memory_store,
    setup_async_memory_store,
)


class TestUserProfile:
    """Test UserProfile model functionality."""

    def test_user_profile_basic_initialization(self):
        """Test basic UserProfile model creation."""
        profile = UserProfile(name="John Doe", language="English", timezone="UTC", jargon="Technical")
        assert profile.name == "John Doe"
        assert profile.language == "English"
        assert profile.timezone == "UTC"
        assert profile.jargon == "Technical"

    def test_user_profile_optional_fields(self):
        """Test UserProfile with optional fields."""
        profile = UserProfile()
        assert profile.name is None
        assert profile.language is None
        assert profile.timezone is None
        assert profile.jargon is None

    def test_user_profile_partial_initialization(self):
        """Test UserProfile with partial field initialization."""
        profile = UserProfile(name="Jane Smith", language="Spanish")
        assert profile.name == "Jane Smith"
        assert profile.language == "Spanish"
        assert profile.timezone is None
        assert profile.jargon is None

    def test_user_profile_serialization(self):
        """Test UserProfile model serialization."""
        profile = UserProfile(name="Test User", timezone="EST")
        data = profile.model_dump()
        assert data["name"] == "Test User"
        assert data["timezone"] == "EST"
        assert data["language"] is None
        assert data["jargon"] is None


class TestMemoryStoreManagement:
    """Test memory store management functions."""

    @pytest.fixture(autouse=True)
    def setup_test_env(self, tmp_path: Path):
        """Setup test environment with temporary database."""
        self.temp_db_path = tmp_path / "test_memory.db"
        # Clean up any global state
        import openchatbi.tool.memory as memory_module

        memory_module.sync_memory_store = None
        memory_module.async_memory_store = None
        memory_module.async_store_context_manager = None

    @patch("openchatbi.tool.memory.sqlite3.connect")
    @patch("openchatbi.tool.memory.config.get")
    def test_get_sync_memory_store(self, mock_config, mock_connect):
        """Test sync memory store creation."""
        mock_config.return_value.embedding_model = Mock()
        mock_conn = Mock()
        mock_connect.return_value = mock_conn
        # Mock SqliteStore
        with patch("openchatbi.tool.memory.SqliteStore") as mock_store_class:
            mock_store = Mock()
            mock_store_class.return_value = mock_store
            store = get_sync_memory_store()
            assert store == mock_store
            mock_store_class.assert_called_once()
            mock_store.setup.assert_called_once()

    @pytest.mark.asyncio
    @patch("openchatbi.tool.memory.AsyncSqliteStore.from_conn_string")
    @patch("openchatbi.tool.memory.config.get")
    async def test_get_async_memory_store(self, mock_config, mock_from_conn_string):
        """Test async memory store creation."""
        mock_config.return_value.embedding_model = Mock()
        # Mock the async context manager
        mock_context_manager = AsyncMock()
        mock_store = Mock()
        mock_context_manager.__aenter__.return_value = mock_store
        mock_from_conn_string.return_value = mock_context_manager
        store = await get_async_memory_store()
        assert store == mock_store
        mock_from_conn_string.assert_called_once()
        mock_context_manager.__aenter__.assert_called_once()

    @pytest.mark.asyncio
    @patch("openchatbi.tool.memory.async_memory_store", new=Mock())
    @patch("openchatbi.tool.memory.async_store_context_manager")
    async def test_cleanup_async_memory_store(self, mock_context_manager):
        """Test async memory store cleanup."""
        mock_context_manager.__aexit__ = AsyncMock()
        await cleanup_async_memory_store()
        mock_context_manager.__aexit__.assert_called_once_with(None, None, None)

    @pytest.mark.asyncio
    @patch("openchatbi.tool.memory.get_async_memory_store")
    async def test_setup_async_memory_store(self, mock_get_store):
        """Test async memory store setup."""
        mock_store = Mock()
        mock_get_store.return_value = mock_store
        result = await setup_async_memory_store()
        mock_get_store.assert_called_once()
        assert result is None


class TestMemoryTools:
    """Test memory tools creation and management."""

    @patch("openchatbi.tool.memory.create_manage_memory_tool")
    @patch("openchatbi.tool.memory.create_search_memory_tool")
    @patch("openchatbi.tool.memory.get_sync_memory_store")
    def test_get_memory_tools_sync_mode(self, mock_get_store, mock_search_tool, mock_manage_tool):
        """Test getting memory tools in sync mode."""
        mock_llm = FakeListChatModel(responses=["test"])
        mock_store = Mock()
        mock_get_store.return_value = mock_store
        mock_manage = Mock()
        mock_search = Mock()
        mock_manage_tool.return_value = mock_manage
        mock_search_tool.return_value = mock_search
        memory_tools = get_memory_tools(mock_llm, sync_mode=True)
        manage_tool, search_tool = memory_tools[0], memory_tools[1]
        assert manage_tool == mock_manage
        assert search_tool == mock_search
        mock_manage_tool.assert_called_once_with(namespace=("memories", "{user_id}"), store=mock_store)
        mock_search_tool.assert_called_once_with(namespace=("memories", "{user_id}"), store=mock_store)

    @patch("openchatbi.tool.memory.create_manage_memory_tool")
    @patch("openchatbi.tool.memory.create_search_memory_tool")
    @patch("openchatbi.tool.memory.config.get")
    def test_get_memory_tools_with_openai_llm(self, mock_config, mock_search_tool, mock_manage_tool):
        """Test getting memory tools with OpenAI LLM (requires structured tool wrapper)."""
        mock_llm = Mock(spec=ChatOpenAI)
        mock_config.return_value.embedding_model = Mock()
        mock_manage = Mock()
        mock_search = Mock()
        mock_manage_tool.return_value = mock_manage
        mock_search_tool.return_value = mock_search
        with patch("openchatbi.tool.memory.StructuredToolWithRequired") as mock_wrapper:
            mock_wrapped_manage = Mock()
            mock_wrapped_search = Mock()
            mock_wrapper.side_effect = [mock_wrapped_manage, mock_wrapped_search]
            memory_tools = get_memory_tools(mock_llm, sync_mode=True)
            manage_tool, search_tool = memory_tools[0], memory_tools[1]
            assert manage_tool == mock_wrapped_manage
            assert search_tool == mock_wrapped_search
            assert mock_wrapper.call_count == 2

    @pytest.mark.asyncio
    @patch("openchatbi.tool.memory.get_async_memory_store")
    @patch("openchatbi.tool.memory.get_memory_tools")
    @patch("openchatbi.tool.memory.config.get")
    async def test_get_async_memory_tools(self, mock_config, mock_get_tools, mock_get_store):
        """Test getting async memory tools."""
        mock_llm = FakeListChatModel(responses=["test"])
        mock_store = Mock()
        mock_get_store.return_value = mock_store
        mock_config.return_value.embedding_model = Mock()
        mock_manage = Mock()
        mock_search = Mock()
        mock_get_tools.return_value = (mock_manage, mock_search)
        manage_tool, search_tool = await get_async_memory_tools(mock_llm)
        assert manage_tool == mock_manage
        assert search_tool == mock_search
        mock_get_store.assert_called_once()
        mock_get_tools.assert_called_once_with(mock_llm, sync_mode=False, store=mock_store)


class TestMemoryManager:
    """Test memory manager functionality."""

    @patch("openchatbi.tool.memory.create_memory_store_manager")
    @patch("openchatbi.tool.memory.config.get")
    def test_get_memory_manager(self, mock_config, mock_create_manager):
        """Test memory manager creation."""
        mock_llm = Mock()
        mock_config.return_value.default_llm = mock_llm
        mock_manager = Mock()
        mock_create_manager.return_value = mock_manager
        manager = get_memory_manager()
        assert manager == mock_manager
        mock_create_manager.assert_called_once_with(
            mock_llm,
            schemas=[UserProfile],
            instructions="Extract user profile information",
            enable_inserts=False,
        )

    @patch("openchatbi.tool.memory.memory_manager", new=Mock())
    @patch("openchatbi.tool.memory.create_memory_store_manager")
    @patch("openchatbi.tool.memory.config.get")
    def test_get_memory_manager_singleton(self, mock_config, mock_create_manager):
        """Test memory manager singleton behavior."""
        # Reset the global variable for this test
        import openchatbi.tool.memory as memory_module

        existing_manager = Mock()
        memory_module.memory_manager = existing_manager
        manager = get_memory_manager()
        # Should return existing manager without creating new one
        assert manager == existing_manager
        mock_create_manager.assert_not_called()


class TestSchemaFixer:
    """Test schema fixing functionality for OpenAI compatibility."""

    def test_fix_schema_for_openai_basic(self):
        """Test basic schema fixing."""
        schema = {"properties": {"field1": {"type": "string"}, "field2": {"type": "number"}}}
        fix_schema_for_openai(schema)
        assert schema["required"] == ["field1", "field2"]

    def test_fix_schema_for_openai_nested_object(self):
        """Test schema fixing with nested objects."""
        schema = {
            "properties": {
                "nested": {"type": "object", "additionalProperties": True, "properties": {"inner": {"type": "string"}}}
            }
        }
        fix_schema_for_openai(schema)
        assert schema["required"] == ["nested"]
        assert schema["properties"]["nested"]["additionalProperties"] is False

    def test_fix_schema_for_openai_with_arrays(self):
        """Test schema fixing with array properties."""
        schema = {"properties": {"items": {"type": "array", "items": {"type": "object", "additionalProperties": True}}}}
        fix_schema_for_openai(schema)
        assert schema["required"] == ["items"]
        assert schema["properties"]["items"]["items"]["additionalProperties"] is False


class TestStructuredToolWithRequired:
    """Test StructuredToolWithRequired wrapper functionality."""

    def test_structured_tool_with_required_initialization(self):
        """Test StructuredToolWithRequired initialization."""
        mock_original_tool = Mock()
        mock_original_tool.name = "test_tool"
        mock_original_tool.description = "Test description"
        mock_original_tool.args_schema = Mock()
        mock_original_tool.func = Mock()
        mock_original_tool.coroutine = None
        with patch("openchatbi.tool.memory.StructuredTool.__init__", return_value=None) as mock_init:
            wrapper = StructuredToolWithRequired(mock_original_tool)
            # Verify the __init__ was called with correct parameters
            mock_init.assert_called_once()
            call_args = mock_init.call_args
            assert call_args.kwargs["name"] == "test_tool"
            assert call_args.kwargs["description"] == "Test description"

    def test_tool_call_schema_property(self):
        """Test tool_call_schema cached property."""
        mock_original_tool = Mock()
        mock_original_tool.name = "test_tool"
        mock_original_tool.description = "Test description"
        mock_original_tool.args_schema = Mock()
        mock_original_tool.func = Mock()
        mock_original_tool.coroutine = None
        with patch("openchatbi.tool.memory.StructuredTool.__init__", return_value=None):
            wrapper = StructuredToolWithRequired(mock_original_tool)
            # Mock the parent's tool_call_schema
            mock_tcs = Mock()
            mock_tcs.model_config = {}
            with patch("openchatbi.tool.memory.StructuredTool.tool_call_schema", new_callable=lambda: mock_tcs):
                result = wrapper.tool_call_schema
                assert result == mock_tcs
                assert "json_schema_extra" in mock_tcs.model_config
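Note: the TestSchemaFixer cases pin down the contract of fix_schema_for_openai: every property becomes required and nested objects must forbid extra keys. Below is a hedged sketch of one way such a fixer can be written; it satisfies exactly the assertions above, but openchatbi.tool.memory.fix_schema_for_openai remains the authoritative version.

# Hedged approximation of the schema normalization the tests assert.
def demo_fix_schema(schema: dict) -> None:
    if "properties" in schema:
        # OpenAI strict tool schemas expect every property to be required.
        schema["required"] = list(schema["properties"])
        for prop in schema["properties"].values():
            demo_fix_schema(prop)
    if schema.get("type") == "object":
        # Strict mode also disallows undeclared keys.
        schema["additionalProperties"] = False
    if isinstance(schema.get("items"), dict):
        demo_fix_schema(schema["items"])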
"data_columns": ["month", "sales"], "config": {"x": "month", "y": "sales", "mode": "lines+markers"}, "layout": {"title": "Sales Over Time", "xaxis_title": "Month", "yaxis_title": "Sales"}, } @pytest.fixture def sample_bar_dsl(): """Sample DSL for bar chart.""" return { "chart_type": "bar", "data_columns": ["region", "sales"], "config": {"x": "region", "y": "sales"}, "layout": {"title": "Sales by Region", "xaxis_title": "Region", "yaxis_title": "Sales"}, } @pytest.fixture def sample_pie_dsl(): """Sample DSL for pie chart.""" return { "chart_type": "pie", "data_columns": ["product", "sales"], "config": {"labels": "product", "values": "sales"}, "layout": {"title": "Sales Distribution by Product"}, } class TestPlotlyChartCreation: """Tests for individual chart creation functions.""" def test_create_line_chart_success(self, sample_csv_data, sample_line_dsl): """Test successful line chart creation.""" fig = create_plotly_chart(sample_csv_data, sample_line_dsl) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 assert fig.layout.title.text == "Sales Over Time" def test_create_line_chart_with_color(self, sample_csv_data): """Test line chart creation with color parameter for multiple series.""" multi_series_dsl = { "chart_type": "line", "data_columns": ["month", "sales", "product"], "config": {"x": "month", "y": "sales", "color": "product"}, "layout": {"title": "Sales Over Time by Product", "xaxis_title": "Month", "yaxis_title": "Sales"}, } fig = create_plotly_chart(sample_csv_data, multi_series_dsl) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 # Should have multiple traces for different products assert fig.layout.title.text == "Sales Over Time by Product" def test_create_line_chart_with_multiple_y_columns(self): """Test line chart creation with multiple y columns.""" multi_metric_data = """date,revenue,profit,users 2023-01-01,50000,15000,1000 2023-02-01,55000,18000,1100 2023-03-01,60000,20000,1200""" multi_y_dsl = { "chart_type": "line", "data_columns": ["date", "revenue", "profit"], "config": {"x": "date", "y": ["revenue", "profit"]}, "layout": {"title": "Multiple Metrics Over Time", "xaxis_title": "Date", "yaxis_title": "Value"}, } fig = create_plotly_chart(multi_metric_data, multi_y_dsl) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 # Should have multiple traces for different metrics assert fig.layout.title.text == "Multiple Metrics Over Time" def test_create_bar_chart_success(self, sample_csv_data, sample_bar_dsl): """Test successful bar chart creation.""" fig = create_plotly_chart(sample_csv_data, sample_bar_dsl) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 assert fig.layout.title.text == "Sales by Region" def test_create_pie_chart_success(self, sample_csv_data, sample_pie_dsl): """Test successful pie chart creation.""" fig = create_plotly_chart(sample_csv_data, sample_pie_dsl) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 assert fig.layout.title.text == "Sales Distribution by Product" def test_create_scatter_chart(self, sample_csv_data): """Test scatter chart creation.""" scatter_dsl = { "chart_type": "scatter", "data_columns": ["sales", "region"], "config": {"x": "sales", "y": "region", "mode": "markers"}, "layout": {"title": "Sales Scatter Plot"}, } fig = create_plotly_chart(sample_csv_data, scatter_dsl) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 def test_create_histogram_chart(self, sample_csv_data): """Test histogram chart creation.""" histogram_dsl = { "chart_type": "histogram", "data_columns": ["sales"], 
"config": {"x": "sales", "nbins": 10}, "layout": {"title": "Sales Distribution"}, } fig = create_plotly_chart(sample_csv_data, histogram_dsl) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 def test_create_box_chart(self, sample_csv_data): """Test box chart creation.""" box_dsl = { "chart_type": "box", "data_columns": ["sales", "region"], "config": {"y": "sales", "x": "region"}, "layout": {"title": "Sales Distribution by Region"}, } fig = create_plotly_chart(sample_csv_data, box_dsl) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 def test_create_table_chart(self, sample_csv_data): """Test table chart creation.""" table_dsl = { "chart_type": "table", "data_columns": ["product", "sales", "region", "month"], "config": {"columns": ["product", "sales", "region", "month"]}, "layout": {"title": "Data Table"}, } fig = create_plotly_chart(sample_csv_data, table_dsl) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 assert fig.data[0].type == "table" class TestErrorHandling: """Tests for error handling in chart creation.""" def test_empty_data(self): """Test handling of empty data.""" fig = create_plotly_chart("", {}) assert isinstance(fig, go.Figure) # Should create an empty chart with error message def test_invalid_csv_data(self, sample_bar_dsl): """Test handling of invalid CSV data.""" invalid_csv = "invalid,csv\ndata" fig = create_plotly_chart(invalid_csv, sample_bar_dsl) assert isinstance(fig, go.Figure) # Should create an empty chart with error message def test_missing_columns(self, sample_csv_data): """Test handling of missing columns in DSL.""" invalid_dsl = { "chart_type": "line", "data_columns": ["nonexistent_col"], "config": {"x": "nonexistent_col", "y": "another_missing_col"}, "layout": {"title": "Invalid Chart"}, } fig = create_plotly_chart(sample_csv_data, invalid_dsl) assert isinstance(fig, go.Figure) # Should create an empty chart with error message def test_unsupported_chart_type(self, sample_csv_data): """Test handling of unsupported chart types.""" invalid_dsl = {"chart_type": "unsupported_type", "data_columns": ["sales"], "config": {}, "layout": {}} fig = create_plotly_chart(sample_csv_data, invalid_dsl) assert isinstance(fig, go.Figure) # Should create an empty chart with error message def test_visualization_dsl_error(self): """Test handling of DSL with error field.""" error_dsl = {"error": "Failed to generate visualization"} fig = create_plotly_chart("some,data\n1,2", error_dsl) assert isinstance(fig, go.Figure) # Should create an empty chart with error message class TestVisualizationDslToGradioPlot: """Tests for the main interface function.""" def test_successful_conversion(self, sample_csv_data, sample_line_dsl): """Test successful DSL to Gradio plot conversion.""" fig, description = visualization_dsl_to_gradio_plot(sample_csv_data, sample_line_dsl) assert isinstance(fig, go.Figure) assert isinstance(description, str) assert "line" in description.lower() assert "Sales Over Time" in description def test_empty_dsl(self, sample_csv_data): """Test conversion with empty DSL.""" fig, description = visualization_dsl_to_gradio_plot(sample_csv_data, {}) assert isinstance(fig, go.Figure) assert isinstance(description, str) assert "table" in description.lower() def test_no_data(self, sample_line_dsl): """Test conversion with no data.""" fig, description = visualization_dsl_to_gradio_plot("", sample_line_dsl) assert isinstance(fig, go.Figure) assert isinstance(description, str) class TestCreateEmptyChart: """Tests for empty chart creation.""" def 
test_create_empty_chart(self): """Test empty chart creation with message.""" message = "Test error message" fig = create_empty_chart(message) assert isinstance(fig, go.Figure) assert fig.layout.title.text == "Chart Generation Issue" # Check if annotation contains the message assert len(fig.layout.annotations) > 0 assert fig.layout.annotations[0].text == message @pytest.fixture def sample_time_series_data(): """Sample time series data for testing.""" return """date,revenue,users 2023-01-01,50000,1000 2023-02-01,55000,1100 2023-03-01,60000,1200 2023-04-01,52000,1050 2023-05-01,58000,1150""" class TestIntegrationScenarios: """Integration tests for complete visualization scenarios.""" def test_sales_dashboard_scenario(self, sample_csv_data): """Test a complete sales dashboard scenario.""" # Test multiple chart types with the same data chart_configs = [ {"chart_type": "bar", "config": {"x": "product", "y": "sales"}, "layout": {"title": "Sales by Product"}}, { "chart_type": "pie", "config": {"labels": "region", "values": "sales"}, "layout": {"title": "Sales by Region"}, }, ] for config in chart_configs: config["data_columns"] = list(config["config"].values()) fig = create_plotly_chart(sample_csv_data, config) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 def test_time_series_scenario(self, sample_time_series_data): """Test time series visualization scenario.""" line_dsl = { "chart_type": "line", "data_columns": ["date", "revenue"], "config": {"x": "date", "y": "revenue", "mode": "lines+markers"}, "layout": {"title": "Revenue Trend Over Time", "xaxis_title": "Date", "yaxis_title": "Revenue"}, } fig, description = visualization_dsl_to_gradio_plot(sample_time_series_data, line_dsl) assert isinstance(fig, go.Figure) assert "line" in description.lower() assert "Revenue Trend Over Time" in description def test_multiple_metrics_scenario(self, sample_time_series_data): """Test scenario with multiple metrics.""" scatter_dsl = { "chart_type": "scatter", "data_columns": ["revenue", "users"], "config": {"x": "revenue", "y": "users", "mode": "markers"}, "layout": {"title": "Revenue vs Users Correlation", "xaxis_title": "Revenue", "yaxis_title": "Users"}, } fig = create_plotly_chart(sample_time_series_data, scatter_dsl) assert isinstance(fig, go.Figure) assert len(fig.data) > 0 assert fig.layout.title.text == "Revenue vs Users Correlation" ================================================ FILE: tests/test_simple_store.py ================================================ """Unit tests for SimpleStore.""" import pytest from openchatbi.utils import SimpleStore class TestSimpleStore: """Test suite for SimpleStore class.""" @pytest.fixture def sample_texts(self): """Sample texts for testing.""" return [ "Python is a programming language", "Machine learning is a subset of AI", "Deep learning uses neural networks", "Natural language processing works with text", ] @pytest.fixture def sample_metadatas(self): """Sample metadata for testing.""" return [ {"category": "programming"}, {"category": "ai"}, {"category": "ai"}, {"category": "nlp"}, ] @pytest.fixture def simple_store(self, sample_texts): """Create a SimpleStore instance for testing.""" return SimpleStore(sample_texts) def test_initialization_basic(self, sample_texts): """Test basic initialization.""" store = SimpleStore(sample_texts) assert len(store.texts) == len(sample_texts) assert store.texts == sample_texts assert len(store.documents) == len(sample_texts) assert store.bm25 is not None def test_initialization_with_metadata_and_ids(self, 
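Note: the chart tests imply a dispatch from the visualization DSL's chart_type to a Plotly constructor. The sketch below shows the general shape of that dispatch using Plotly Express; names like `demo_chart` are illustrative, and sample_ui/plotly_utils.py holds the real implementation (which also covers pie, histogram, box, and table types).

# Hedged sketch of a DSL-to-figure dispatch (illustrative, not the repo's code).
import io

import pandas as pd
import plotly.express as px


def demo_chart(csv_data: str, dsl: dict):
    df = pd.read_csv(io.StringIO(csv_data))
    cfg, layout = dsl.get("config", {}), dsl.get("layout", {})
    builders = {"line": px.line, "bar": px.bar, "scatter": px.scatter}
    fig = builders[dsl["chart_type"]](df, x=cfg.get("x"), y=cfg.get("y"))
    fig.update_layout(title=layout.get("title"))
    return fig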
================================================
FILE: tests/test_simple_store.py
================================================
"""Unit tests for SimpleStore."""

import pytest

from openchatbi.utils import SimpleStore


class TestSimpleStore:
    """Test suite for SimpleStore class."""

    @pytest.fixture
    def sample_texts(self):
        """Sample texts for testing."""
        return [
            "Python is a programming language",
            "Machine learning is a subset of AI",
            "Deep learning uses neural networks",
            "Natural language processing works with text",
        ]

    @pytest.fixture
    def sample_metadatas(self):
        """Sample metadata for testing."""
        return [
            {"category": "programming"},
            {"category": "ai"},
            {"category": "ai"},
            {"category": "nlp"},
        ]

    @pytest.fixture
    def simple_store(self, sample_texts):
        """Create a SimpleStore instance for testing."""
        return SimpleStore(sample_texts)

    def test_initialization_basic(self, sample_texts):
        """Test basic initialization."""
        store = SimpleStore(sample_texts)
        assert len(store.texts) == len(sample_texts)
        assert store.texts == sample_texts
        assert len(store.documents) == len(sample_texts)
        assert store.bm25 is not None

    def test_initialization_with_metadata_and_ids(self, sample_texts, sample_metadatas):
        """Test initialization with metadata and custom IDs."""
        ids = ["id1", "id2", "id3", "id4"]
        store = SimpleStore(sample_texts, sample_metadatas, ids)
        assert store.texts == sample_texts
        assert store.metadatas == sample_metadatas
        assert store.ids == ids
        # Check documents are created correctly
        for doc, text, meta, doc_id in zip(store.documents, sample_texts, sample_metadatas, ids):
            assert doc.page_content == text
            assert doc.metadata == meta
            assert doc.id == doc_id

    def test_similarity_search(self, simple_store):
        """Test similarity search functionality."""
        query = "programming"
        results = simple_store.similarity_search(query, k=2)
        assert len(results) == 2
        assert "Python" in results[0].page_content
        # Test k parameter
        results = simple_store.similarity_search(query, k=10)
        assert len(results) == 4  # Should return all documents

    def test_similarity_search_with_score(self, simple_store):
        """Test similarity search with scores."""
        query = "programming"
        results = simple_store.similarity_search_with_score(query, k=2)
        assert len(results) == 2
        for doc, score in results:
            assert hasattr(doc, "page_content")
            assert isinstance(score, (int, float))
            assert score >= 0
        # Scores should be in descending order
        scores = [score for _, score in results]
        assert scores == sorted(scores, reverse=True)

    def test_empty_store(self):
        """Test empty store operations."""
        store = SimpleStore([])
        assert store.bm25 is None
        assert store.similarity_search("test", k=5) == []
        assert store.similarity_search_with_score("test", k=5) == []

    def test_add_texts(self, simple_store):
        """Test adding texts with and without metadata."""
        initial_count = len(simple_store.texts)
        new_texts = ["Data science is important", "Statistics is fundamental"]
        new_metadatas = [{"type": "test"}, {"type": "example"}]
        # Add with metadata and custom IDs
        custom_ids = ["custom_1", "custom_2"]
        returned_ids = simple_store.add_texts(new_texts, metadatas=new_metadatas, ids=custom_ids)
        assert returned_ids == custom_ids
        assert len(simple_store.texts) == initial_count + len(new_texts)
        assert all(text in simple_store.texts for text in new_texts)
        # Check metadata was added correctly
        added_docs = [doc for doc in simple_store.documents if doc.id in custom_ids]
        assert len(added_docs) == 2
        assert added_docs[0].metadata == {"type": "test"}
        # Verify BM25 index is updated
        results = simple_store.similarity_search("data science", k=1)
        assert "data" in results[0].page_content.lower() or "science" in results[0].page_content.lower()

    def test_delete(self):
        """Test deleting documents."""
        texts = ["Text A", "Text B", "Text C", "Text D"]
        ids = ["id1", "id2", "id3", "id4"]
        store = SimpleStore(texts, ids=ids)
        # Delete specific IDs
        result = store.delete(["id2", "id3"])
        assert result is True
        assert len(store.texts) == 2
        assert store.texts == ["Text A", "Text D"]
        assert store.ids == ["id1", "id4"]
        # Delete non-existent IDs
        result = store.delete(["nonexistent"])
        assert result is False
        # Delete with None
        result = store.delete(None)
        assert result is False
        # Delete all remaining documents
        result = store.delete(["id1", "id4"])
        assert result is True
        assert len(store.texts) == 0
        assert store.bm25 is None

    def test_get_by_ids(self, sample_texts):
        """Test retrieving documents by IDs."""
        ids = ["id1", "id2", "id3", "id4"]
        store = SimpleStore(sample_texts, ids=ids)
        # Get existing IDs
        docs = store.get_by_ids(["id1", "id3"])
        assert len(docs) == 2
        assert docs[0].id == "id1"
        assert docs[0].page_content == sample_texts[0]
        # Get non-existent IDs
        docs = store.get_by_ids(["nonexistent"])
        assert len(docs) == 0
        # Mixed existent and non-existent
        docs = store.get_by_ids(["id1", "nonexistent", "id3"])
        assert len(docs) == 2

    def test_from_texts(self, sample_texts, sample_metadatas):
        """Test creating store using from_texts class method."""
        ids = ["id1", "id2", "id3", "id4"]
        store = SimpleStore.from_texts(sample_texts, embedding=None, metadatas=sample_metadatas, ids=ids)
        assert isinstance(store, SimpleStore)
        assert store.texts == sample_texts
        assert store.metadatas == sample_metadatas
        assert store.ids == ids

    def test_as_retriever(self, simple_store):
        """Test creating a retriever from the store."""
        retriever = simple_store.as_retriever(search_kwargs={"k": 2})
        results = retriever.invoke("programming")
        assert len(results) <= 2
        assert all(hasattr(doc, "page_content") for doc in results)

    def test_chinese_and_mixed_language(self):
        """Test search with Chinese and mixed language texts."""
        from openchatbi.text_segmenter import _jieba_available

        mixed_texts = [
            "Python programming language",
            "机器学习很重要",
            "Deep learning neural networks",
            "数据科学分析",
        ]
        store = SimpleStore(mixed_texts)
        # Search in English
        en_results = store.similarity_search("programming", k=1)
        assert "Python" in en_results[0].page_content
        # Search in Chinese - result depends on jieba availability
        cn_results = store.similarity_search("机器学习", k=2)
        assert len(cn_results) > 0
        # If jieba is available, expect better Chinese matching
        if _jieba_available:
            assert "机器学习" in cn_results[0].page_content
        else:
            # Without jieba, just verify results are returned
            # (Chinese text may not be perfectly tokenized)
            assert any("机器学习" in doc.page_content for doc in cn_results) or any(
                "数据科学" in doc.page_content for doc in cn_results
            )

    def test_max_marginal_relevance_search(self, simple_store):
        """Test max_marginal_relevance_search method."""
        query = "programming language"
        # Test basic MMR search
        results = simple_store.max_marginal_relevance_search(query, k=2, fetch_k=4, lambda_mult=0.5)
        assert len(results) == 2
        assert all(hasattr(doc, "page_content") for doc in results)
        # Test relevance-focused search (lambda_mult = 1.0)
        results_relevant = simple_store.max_marginal_relevance_search(query, k=3, fetch_k=4, lambda_mult=1.0)
        assert len(results_relevant) == 3
        # Test diversity-focused search (lambda_mult = 0.0)
        results_diverse = simple_store.max_marginal_relevance_search(query, k=3, fetch_k=4, lambda_mult=0.0)
        assert len(results_diverse) == 3
        # Verify different lambda values produce different results
        # (unless there are ties in scoring)
        assert len(results_relevant) == len(results_diverse)
        # Test with k >= fetch_k
        results = simple_store.max_marginal_relevance_search(query, k=5, fetch_k=3, lambda_mult=0.5)
        assert len(results) == 3  # Should return fetch_k documents
        # Test empty query
        results = simple_store.max_marginal_relevance_search("", k=2)
        assert len(results) <= 2
        # Test empty store
        empty_store = SimpleStore([])
        results = empty_store.max_marginal_relevance_search(query, k=2)
        assert results == []

    def test_calculate_similarity(self, simple_store):
        """Test _calculate_similarity method."""
        # Get two documents
        doc1 = simple_store.documents[0]  # "Python is a programming language"
        doc2 = simple_store.documents[1]  # "Machine learning is a subset of AI"
        doc3 = simple_store.documents[0]  # Same as doc1
        # Test similarity between different documents
        similarity_diff = simple_store._calculate_similarity(doc1, doc2)
        assert 0.0 <= similarity_diff <= 1.0
        # Test similarity between identical documents
        similarity_same = simple_store._calculate_similarity(doc1, doc3)
        assert similarity_same == 1.0
        # Test with empty documents
        from langchain_core.documents import Document

        empty_doc1 = Document(page_content="", metadata={})
        empty_doc2 = Document(page_content="", metadata={})
        similarity_empty = simple_store._calculate_similarity(empty_doc1, empty_doc2)
        assert similarity_empty == 0.0  # Empty sets have 0 Jaccard similarity
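Note: the MMR tests above pin down the role of lambda_mult: 1.0 is pure relevance, 0.0 is pure diversity. A hedged sketch of the standard max-marginal-relevance selection loop follows; it illustrates the scoring trade-off only, and SimpleStore's own implementation (BM25 relevance plus Jaccard similarity) may differ in detail.

# Hedged sketch of MMR selection: score = lambda * relevance - (1 - lambda) * max
# similarity to anything already selected. `relevance` and `pairwise_sim` are
# caller-supplied callables in this illustration.
def demo_mmr(candidates, relevance, pairwise_sim, k=3, lambda_mult=0.5):
    selected = []
    pool = list(candidates)
    while pool and len(selected) < k:

        def mmr_score(d):
            diversity = max((pairwise_sim(d, s) for s in selected), default=0.0)
            return lambda_mult * relevance(d) - (1 - lambda_mult) * diversity

        best = max(pool, key=mmr_score)
        selected.append(best)
        pool.remove(best)
    return selected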

================================================
FILE: tests/test_text2sql_extraction.py
================================================
"""Tests for text2sql information extraction functionality."""

import json
from datetime import date
from unittest.mock import Mock, patch

import pytest
from langchain_core.messages import AIMessage, HumanMessage

from openchatbi.graph_state import SQLGraphState
from openchatbi.text2sql.extraction import (
    generate_extraction_prompt,
    information_extraction,
    information_extraction_conditional_edges,
    parse_extracted_info_json,
)


class TestText2SQLExtraction:
    """Test text2sql information extraction functionality."""

    def test_generate_extraction_prompt(self):
        """Test extraction prompt generation."""
        prompt = generate_extraction_prompt()

        # Should replace time placeholder with today's date
        today_str = date.today().strftime("%Y-%m-%d")
        assert today_str in prompt
        # Should contain basic knowledge
        assert "[basic_knowledge_glossary]" not in prompt
        assert "[time_field_placeholder]" not in prompt

    def test_parse_extracted_info_json_valid(self):
        """Test parsing valid JSON from LLM response."""
        json_response = {
            "keywords": ["revenue", "sales"],
            "dimensions": ["date", "region"],
            "metrics": ["total_revenue"],
            "filters": [],
        }
        # Mock LLM response with JSON
        llm_content = f"```json\n{json.dumps(json_response)}\n```"

        with patch("openchatbi.text2sql.extraction.get_text_from_content", return_value=llm_content):
            with patch("openchatbi.text2sql.extraction.extract_json_from_answer", return_value=json_response):
                result = parse_extracted_info_json(llm_content)

                assert result == json_response
                assert "keywords" in result
                assert "dimensions" in result

    def test_parse_extracted_info_json_invalid(self):
        """Test parsing invalid JSON returns empty dict."""
        invalid_content = "Not valid JSON content"

        with patch("openchatbi.text2sql.extraction.get_text_from_content", return_value=invalid_content):
            with patch("openchatbi.text2sql.extraction.extract_json_from_answer", side_effect=Exception("Parse error")):
                result = parse_extracted_info_json(invalid_content)

                assert result == {}

    def test_information_extraction_function_creation(self):
        """Test creating information extraction function."""
        mock_llm = Mock()
        extraction_func = information_extraction(mock_llm)

        # Should return a callable function
        assert callable(extraction_func)

    def test_information_extraction_successful(self):
        """Test successful information extraction."""
        mock_llm = Mock()
        # Mock LLM response
        extracted_info = {
            "rewrite_question": "What is the total revenue by region?",
            "keywords": ["revenue", "total"],
            "dimensions": ["region"],
            "metrics": ["revenue"],
            "filters": [],
        }
        mock_response = AIMessage(content=json.dumps(extracted_info))

        with patch("openchatbi.text2sql.extraction.call_llm_chat_model_with_retry", return_value=mock_response):
            with patch("openchatbi.text2sql.extraction.parse_extracted_info_json", return_value=extracted_info):
                extraction_func = information_extraction(mock_llm)
                state = SQLGraphState(
                    messages=[HumanMessage(content="Show me revenue by region")], question="Show me revenue by region"
                )

                result = extraction_func(state)

                assert "info_entities" in result
                assert result["rewrite_question"] == "What is the total revenue by region?"

    def test_information_extraction_empty_response(self):
        """Test handling empty extraction response."""
        mock_llm = Mock()
        mock_response = AIMessage(content="")

        with patch("openchatbi.text2sql.extraction.call_llm_chat_model_with_retry", return_value=mock_response):
            with patch("openchatbi.text2sql.extraction.parse_extracted_info_json", return_value={}):
                extraction_func = information_extraction(mock_llm)
                state = SQLGraphState(messages=[HumanMessage(content="Test question")], question="Test question")

                result = extraction_func(state)

                # Should handle empty response gracefully
                assert "info_entities" in result
                assert result["info_entities"] == {}

    def test_information_extraction_conditional_edges_success(self):
        """Test conditional edges with successful extraction."""
        state = SQLGraphState(
            messages=[HumanMessage(content="Test question")],
            question="Test question",
            rewrite_question="What is the total revenue by region?",
            info_entities={"keywords": ["revenue"], "dimensions": ["date"]},
        )

        result = information_extraction_conditional_edges(state)

        # Should proceed to next when rewrite_question exists
        assert result == "next"

    def test_information_extraction_conditional_edges_failure(self):
        """Test conditional edges with failed extraction."""
        state = SQLGraphState(
            messages=[HumanMessage(content="Test question")], question="Test question", info_entities={}
        )

        result = information_extraction_conditional_edges(state)

        # Should end when no info extracted
        assert result == "end"

    def test_information_extraction_conditional_edges_missing(self):
        """Test conditional edges with missing info_entities."""
        state = SQLGraphState(messages=[HumanMessage(content="Test question")], question="Test question")

        result = information_extraction_conditional_edges(state)

        # Should end when info_entities not present
        assert result == "end"

    def test_information_extraction_with_retry_on_failure(self):
        """Test information extraction with retry mechanism."""
        mock_llm = Mock()
        # First call fails, second succeeds
        extracted_info = {
            "rewrite_question": "Test question",
            "keywords": ["test"],
            "dimensions": [],
            "metrics": [],
            "filters": [],
        }
        mock_response = AIMessage(content=json.dumps(extracted_info))

        with patch("openchatbi.text2sql.extraction.call_llm_chat_model_with_retry", return_value=mock_response):
            with patch("openchatbi.text2sql.extraction.parse_extracted_info_json", return_value=extracted_info):
                extraction_func = information_extraction(mock_llm)
                state = SQLGraphState(messages=[HumanMessage(content="Test question")], question="Test question")

                result = extraction_func(state)

                assert "info_entities" in result
                assert result["info_entities"]["keywords"] == ["test"]

    def test_information_extraction_time_period_detection(self):
        """Test time period detection in queries."""
        mock_llm = Mock()
        extracted_info = {
            "rewrite_question": "Show data for the last 7 days",
            "keywords": ["data"],
            "dimensions": ["date"],
            "metrics": [],
            "filters": [],
            "start_time": "2024-01-01",
        }
        mock_response = AIMessage(content=json.dumps(extracted_info))

        with patch("openchatbi.text2sql.extraction.call_llm_chat_model_with_retry", return_value=mock_response):
            with patch("openchatbi.text2sql.extraction.parse_extracted_info_json", return_value=extracted_info):
                extraction_func = information_extraction(mock_llm)
                state = SQLGraphState(
                    messages=[HumanMessage(content="Test question")], question="Show data for last 7 days"
                )

                result = extraction_func(state)

                assert "info_entities" in result
                assert "start_time" in result["info_entities"]

    def test_information_extraction_error_handling(self):
        """Test error handling in information extraction."""
        mock_llm = Mock()

        # Mock call to raise exception
        with patch(
            "openchatbi.text2sql.extraction.call_llm_chat_model_with_retry", side_effect=Exception("LLM error")
        ):
            extraction_func = information_extraction(mock_llm)
            state = SQLGraphState(messages=[HumanMessage(content="Test question")], question="Test question")

            # Should raise: the extraction function has no try/except of its own
            with pytest.raises(Exception, match="LLM error"):
                extraction_func(state)
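
The conditional-edge tests above fix a small routing contract: the graph proceeds only when extraction produced a rewritten question. A hypothetical condensation of that rule (the real implementation in `openchatbi/text2sql/extraction.py` may key on additional fields):

```python
def routing_sketch(state: dict) -> str:
    # "next" only when extraction produced a rewritten question; otherwise "end".
    return "next" if state.get("rewrite_question") else "end"
```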

================================================
FILE: tests/test_text2sql_generate_sql.py
================================================
"""Tests for text2sql SQL generation functionality."""

from unittest.mock import MagicMock, Mock, patch

import pytest
from langchain_core.messages import AIMessage
from sqlalchemy.exc import OperationalError, ProgrammingError

from openchatbi.constants import SQL_EXECUTE_TIMEOUT, SQL_NA, SQL_SUCCESS, SQL_SYNTAX_ERROR
from openchatbi.graph_state import SQLGraphState
from openchatbi.text2sql.generate_sql import create_sql_nodes, should_execute_sql, should_retry_sql


class TestText2SQLGenerateSQL:
    """Test text2sql SQL generation functionality."""

    @pytest.fixture
    def mock_llm(self):
        """Mock LLM for testing."""
        llm = Mock()
        llm.invoke.return_value = AIMessage(content="SELECT * FROM users")
        return llm

    @pytest.fixture
    def mock_catalog(self):
        """Mock catalog store for testing."""
        catalog = Mock()
        catalog.get_table_information.return_value = {
            "description": "User data table",
            "sql_rule": "",
            "derived_metric": "",
        }
        catalog.get_column_list.return_value = [
            {
                "column_name": "user_id",
                "type": "bigint",
                "display_name": "User ID",
                "description": "Unique user identifier",
                "alias": "",
            }
        ]

        # Mock SQL engine with proper context manager
        mock_engine = Mock()
        mock_connection = Mock()
        mock_result = Mock()
        mock_result.fetchall.return_value = [("1", "John"), ("2", "Jane")]
        mock_result.keys.return_value = ["id", "name"]
        mock_connection.execute.return_value = mock_result

        # Create a proper context manager mock using MagicMock
        mock_context_manager = MagicMock()
        mock_context_manager.__enter__.return_value = mock_connection
        mock_context_manager.__exit__.return_value = None
        mock_engine.connect.return_value = mock_context_manager
        catalog.get_sql_engine.return_value = mock_engine
        return catalog

    def test_create_sql_nodes(self, mock_llm, mock_catalog):
        """Test creating SQL processing nodes."""
        generate_node, execute_node, regenerate_node, visualization_node = create_sql_nodes(
            mock_llm, mock_catalog, "presto"
        )

        assert callable(generate_node)
        assert callable(execute_node)
        assert callable(regenerate_node)
        assert callable(visualization_node)

    def test_generate_sql_node_success(self, mock_llm, mock_catalog):
        """Test successful SQL generation."""
        generate_node, _, _, _ = create_sql_nodes(mock_llm, mock_catalog, "presto")
        state = SQLGraphState(
            messages=[],
            question="Show all users",
            rewrite_question="Show all users",
            tables=[{"table": "users", "columns": []}],
        )

        with patch("openchatbi.text2sql.generate_sql.sql_example_retriever") as mock_retriever:
            mock_retriever.invoke.return_value = []
            result = generate_node(state)

            assert "sql" in result
            assert result["sql"] == "SELECT * FROM users"

    def test_generate_sql_node_missing_rewrite_question(self, mock_llm, mock_catalog):
        """Test SQL generation with missing rewrite question."""
        generate_node, _, _, _ = create_sql_nodes(mock_llm, mock_catalog, "presto")
        state = SQLGraphState(
            messages=[],
            question="Show all users",
            # Missing rewrite_question
        )

        result = generate_node(state)
        assert result == {}

    def test_generate_sql_node_missing_tables(self, mock_llm, mock_catalog):
        """Test SQL generation with missing tables."""
        generate_node, _, _, _ = create_sql_nodes(mock_llm, mock_catalog, "presto")
        state = SQLGraphState(
            messages=[],
            question="Show all users",
            rewrite_question="Show all users",
            tables=[],  # Empty tables
        )

        result = generate_node(state)
        assert result == {}

    def test_execute_sql_node_success(self, mock_llm, mock_catalog):
        """Test successful SQL execution."""
        _, execute_node, _, _ = create_sql_nodes(mock_llm, mock_catalog, "presto")
        state = SQLGraphState(messages=[], sql="SELECT * FROM users")

        result = execute_node(state)

        assert "sql_execution_result" in result
        assert result["sql_execution_result"] == SQL_SUCCESS
        assert "data" in result

    def test_execute_sql_node_empty_sql(self, mock_llm, mock_catalog):
        """Test SQL execution with empty SQL."""
        _, execute_node, _, _ = create_sql_nodes(mock_llm, mock_catalog, "presto")
        state = SQLGraphState(messages=[], sql="")  # Empty SQL

        result = execute_node(state)

        assert "sql_execution_result" in result
        assert result["sql_execution_result"] == SQL_NA

    def test_execute_sql_node_syntax_error(self, mock_llm, mock_catalog):
        """Test SQL execution with syntax error."""
        _, execute_node, _, _ = create_sql_nodes(mock_llm, mock_catalog, "presto")

        # Mock SQL execution to raise syntax error
        mock_engine = mock_catalog.get_sql_engine.return_value
        mock_connection = mock_engine.connect.return_value.__enter__.return_value
        mock_connection.execute.side_effect = ProgrammingError("", "", "Syntax error")

        state = SQLGraphState(messages=[], sql="SELECT * FRON users")  # Intentional syntax error

        result = execute_node(state)

        assert "sql_execution_result" in result
        assert result["sql_execution_result"] == SQL_SYNTAX_ERROR
        assert "previous_sql_errors" in result

    def test_regenerate_sql_node_success(self, mock_llm, mock_catalog):
        """Test successful SQL regeneration."""
        _, _, regenerate_node, _ = create_sql_nodes(mock_llm, mock_catalog, "presto")
        state = SQLGraphState(
            messages=[],
            question="Show all users",
            rewrite_question="Show all users",
            tables=[{"table": "users", "columns": []}],
            previous_sql_errors=[
                {"sql": "SELECT * FRON users", "error": "Syntax error: FRON", "error_type": "SQL syntax error"}
            ],
            sql_retry_count=1,
        )

        with patch("openchatbi.text2sql.generate_sql.sql_example_retriever") as mock_retriever:
            mock_retriever.invoke.return_value = []
            result = regenerate_node(state)

            assert "sql" in result
            assert "sql_retry_count" in result
            assert result["sql_retry_count"] == 2

    def test_should_retry_sql_success(self):
        """Test retry decision with successful execution."""
        state = SQLGraphState(sql_execution_result=SQL_SUCCESS, sql_retry_count=1)

        result = should_retry_sql(state)
        assert result == "end"

    def test_should_retry_sql_timeout(self):
        """Test retry decision with timeout."""
        state = SQLGraphState(sql_execution_result=SQL_EXECUTE_TIMEOUT, sql_retry_count=1)

        result = should_retry_sql(state)
        assert result == "end"

    def test_should_retry_sql_retry_needed(self):
        """Test retry decision when retry is needed."""
        state = SQLGraphState(sql_execution_result="SYNTAX_ERROR", sql_retry_count=1)

        result = should_retry_sql(state)
        assert result == "regenerate_sql"

    def test_should_retry_sql_max_retries_reached(self):
        """Test retry decision when max retries reached."""
        state = SQLGraphState(sql_execution_result="SYNTAX_ERROR", sql_retry_count=3)

        result = should_retry_sql(state)
        assert result == "end"

    def test_should_execute_sql_with_sql(self):
        """Test execute decision with SQL present."""
        state = SQLGraphState(sql="SELECT * FROM users")

        result = should_execute_sql(state)
        assert result == "execute_sql"

    def test_should_execute_sql_without_sql(self):
        """Test execute decision without SQL."""
        state = SQLGraphState(sql="")

        result = should_execute_sql(state)
        assert result == "end"

    def test_sql_generation_with_examples(self, mock_llm, mock_catalog):
        """Test SQL generation with relevant examples."""
        generate_node, _, _, _ = create_sql_nodes(mock_llm, mock_catalog, "presto")
        state = SQLGraphState(
            messages=[],
            question="Show user count",
            rewrite_question="Show user count",
            tables=[{"table": "users", "columns": []}],
        )

        # Mock example retrieval
        mock_document = Mock()
        mock_document.page_content = "How many users are there?"

        with patch("openchatbi.text2sql.generate_sql.sql_example_retriever") as mock_retriever:
            mock_retriever.invoke.return_value = [mock_document]
            with patch(
                "openchatbi.text2sql.generate_sql.sql_example_dicts",
                {"How many users are there?": ("SELECT COUNT(*) FROM users", ["users"])},
            ):
                result = generate_node(state)

                assert "sql" in result

    def test_sql_error_handling_database_error(self, mock_llm, mock_catalog):
        """Test handling of database connection errors."""
        _, execute_node, _, _ = create_sql_nodes(mock_llm, mock_catalog, "presto")

        # Mock database connection error
        mock_engine = mock_catalog.get_sql_engine.return_value
        mock_connection = mock_engine.connect.return_value.__enter__.return_value
        mock_connection.execute.side_effect = OperationalError("", "", "Connection failed")

        state = SQLGraphState(messages=[], sql="SELECT * FROM users")

        result = execute_node(state)

        assert "sql_execution_result" in result
        assert result["sql_execution_result"] == SQL_EXECUTE_TIMEOUT

    def test_regenerate_sql_empty_response(self, mock_llm, mock_catalog):
        """Test regeneration with empty LLM response."""
        mock_llm.invoke.return_value = AIMessage(content="")
        _, _, regenerate_node, _ = create_sql_nodes(mock_llm, mock_catalog, "presto")
        state = SQLGraphState(
            messages=[],
            question="Show all users",
            rewrite_question="Show all users",
            tables=[{"table": "users", "columns": []}],
            previous_sql_errors=[],
            sql_retry_count=1,
        )

        with patch("openchatbi.text2sql.generate_sql.sql_example_retriever") as mock_retriever:
            mock_retriever.invoke.return_value = []
            result = regenerate_node(state)

            assert "sql" in result
            assert result["sql"] == ""
            assert result["sql_execution_result"] == SQL_NA
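
Taken together, the `should_retry_sql` tests encode a retry budget: success and timeouts end the flow, other failures trigger regeneration until a cap is hit. A hedged sketch of that policy (the cap of 3 is inferred from `test_should_retry_sql_max_retries_reached`; the real rule lives in `openchatbi/text2sql/generate_sql.py` and may differ):

```python
from openchatbi.constants import SQL_EXECUTE_TIMEOUT, SQL_SUCCESS

MAX_SQL_RETRIES = 3  # assumed cap, inferred from the max-retries test


def retry_policy_sketch(state: dict) -> str:
    result = state.get("sql_execution_result")
    if result in (SQL_SUCCESS, SQL_EXECUTE_TIMEOUT):
        return "end"  # nothing to fix, or a retry is unlikely to help
    if state.get("sql_retry_count", 0) >= MAX_SQL_RETRIES:
        return "end"  # retry budget exhausted
    return "regenerate_sql"
```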

================================================
FILE: tests/test_text2sql_schema_linking.py
================================================
"""Tests for text2sql schema linking functionality."""

from unittest.mock import Mock, patch

import pytest
from langchain_core.messages import AIMessage

from openchatbi.graph_state import SQLGraphState
from openchatbi.text2sql.schema_linking import schema_linking


class TestText2SQLSchemaLinking:
    """Test text2sql schema linking functionality."""

    @pytest.fixture
    def mock_llm(self):
        """Mock LLM for testing."""
        llm = Mock()
        llm.invoke.return_value = AIMessage(content='{"tables": [{"table": "users", "reason": "Contains user data"}]}')
        return llm

    @pytest.fixture
    def mock_catalog(self):
        """Mock catalog store for testing."""
        catalog = Mock()
        catalog.get_table_information.return_value = {
            "description": "User data table",
            "selection_rule": "Use for user-related queries",
        }
        return catalog

    def test_select_table_function_creation(self, mock_llm, mock_catalog):
        """Test creating table selection function."""
        select_func = schema_linking(mock_llm, mock_catalog)
        assert callable(select_func)

    def test_select_table_success(self, mock_llm, mock_catalog):
        """Test successful table selection."""
        col_dict = {
            "user_id": {
                "column_name": "user_id",
                "category": "dimension",
                "display_name": "User ID",
                "description": "Unique user identifier",
            },
            "name": {
                "column_name": "name",
                "category": "dimension",
                "display_name": "Name",
                "description": "User full name",
            },
            "email": {
                "column_name": "email",
                "category": "dimension",
                "display_name": "Email",
                "description": "User email address",
            },
        }
        with (
            patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as mock_get_columns,
            patch(
                "openchatbi.text2sql.schema_linking.column_tables_mapping",
                {"user_id": ["users", "profiles"], "name": ["users"], "email": ["users", "contacts"]},
            ),
            patch("openchatbi.text2sql.schema_linking.col_dict", col_dict),
            patch("openchatbi.text2sql.schema_linking.table_selection_retriever") as mock_retriever,
            patch("openchatbi.text2sql.schema_linking.table_selection_example_dict", {}),
            patch("openchatbi.text2sql.schema_linking.extract_json_from_answer") as mock_extract,
        ):
            mock_get_columns.return_value = ["user_id", "name", "email"]
            mock_retriever.invoke.return_value = []
            mock_extract.return_value = {"tables": [{"table": "users", "reason": "Contains user data"}]}

            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user information",
                rewrite_question="Show user information",
                info_entities={
                    "keywords": ["user", "information"],
                    "dimensions": ["name", "email"],
                    "metrics": [],
                },
            )

            result = select_func(state)

            assert "tables" in result
            assert len(result["tables"]) == 1
            assert result["tables"][0]["table"] == "users"

    def test_select_table_missing_rewrite_question(self, mock_llm, mock_catalog):
        """Test table selection with missing rewrite question."""
        select_func = schema_linking(mock_llm, mock_catalog)
        state = SQLGraphState(
            messages=[],
            question="Show user information",
            # Missing rewrite_question
        )

        result = select_func(state)
        assert result == {}

    def test_select_table_with_examples(self, mock_llm, mock_catalog):
        """Test table selection with similar examples."""
        col_dict = {
            "user_id": {
                "column_name": "user_id",
                "category": "dimension",
                "display_name": "User ID",
                "description": "Unique user identifier",
            },
            "revenue": {
                "column_name": "revenue",
                "category": "metric",
                "display_name": "Revenue",
                "description": "Total revenue amount",
            },
        }
        # Mock similar examples
        mock_document = Mock()
        mock_document.page_content = "What is user revenue?"

        with (
            patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as mock_get_columns,
            patch(
                "openchatbi.text2sql.schema_linking.column_tables_mapping",
                {"user_id": ["users"], "revenue": ["sales"]},
            ),
            patch("openchatbi.text2sql.schema_linking.col_dict", col_dict),
            patch("openchatbi.text2sql.schema_linking.table_selection_retriever") as mock_retriever,
            patch(
                "openchatbi.text2sql.schema_linking.table_selection_example_dict",
                {"What is user revenue?": ["users", "sales"]},
            ),
            patch("openchatbi.text2sql.schema_linking.extract_json_from_answer") as mock_extract,
        ):
            mock_get_columns.return_value = ["user_id", "revenue"]
            mock_retriever.invoke.return_value = [mock_document]
            mock_extract.return_value = {"tables": [{"table": "users"}, {"table": "sales"}]}

            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user revenue",
                rewrite_question="Show user revenue",
                info_entities={
                    "keywords": ["user", "revenue"],
                    "dimensions": ["user_id"],
                    "metrics": ["revenue"],
                },
            )

            result = select_func(state)

            assert "tables" in result
            assert len(result["tables"]) == 2

    def test_select_table_invalid_table_selection(self, mock_llm, mock_catalog):
        """Test handling of invalid table selection."""
        col_dict = {
            "user_id": {
                "column_name": "user_id",
                "category": "dimension",
                "display_name": "User ID",
                "description": "Unique user identifier",
            }
        }
        with (
            patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as mock_get_columns,
            patch("openchatbi.text2sql.schema_linking.column_tables_mapping", {"user_id": ["users"]}),
            patch("openchatbi.text2sql.schema_linking.col_dict", col_dict),
            patch("openchatbi.text2sql.schema_linking.table_selection_retriever") as mock_retriever,
            patch("openchatbi.text2sql.schema_linking.table_selection_example_dict", {}),
            patch("openchatbi.text2sql.schema_linking.extract_json_from_answer") as mock_extract,
        ):
            mock_get_columns.return_value = ["user_id"]
            mock_retriever.invoke.return_value = []
            # Return invalid table not in candidate list
            mock_extract.return_value = {"tables": [{"table": "invalid_table"}]}

            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user info",
                rewrite_question="Show user info",
                info_entities={"keywords": ["user"], "dimensions": ["user_id"], "metrics": []},
            )

            result = select_func(state)

            # Should return empty dict when invalid table selected
            assert result == {}

    def test_select_table_retry_mechanism(self, mock_llm, mock_catalog):
        """Test retry mechanism for table selection."""
        col_dict = {
            "user_id": {
                "column_name": "user_id",
                "category": "dimension",
                "display_name": "User ID",
                "description": "Unique user identifier",
            }
        }
        with (
            patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as mock_get_columns,
            patch("openchatbi.text2sql.schema_linking.column_tables_mapping", {"user_id": ["users"]}),
            patch("openchatbi.text2sql.schema_linking.col_dict", col_dict),
            patch("openchatbi.text2sql.schema_linking.table_selection_retriever") as mock_retriever,
            patch("openchatbi.text2sql.schema_linking.table_selection_example_dict", {}),
            patch("openchatbi.text2sql.schema_linking.extract_json_from_answer") as mock_extract,
        ):
            mock_get_columns.return_value = ["user_id"]
            mock_retriever.invoke.return_value = []
            # First returns invalid, then valid
            mock_extract.side_effect = [
                {"tables": [{"table": "invalid_table"}]},
                {"tables": [{"table": "users"}]},
            ]

            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user info",
                rewrite_question="Show user info",
                info_entities={"keywords": ["user"], "dimensions": ["user_id"], "metrics": []},
            )

            result = select_func(state)

            assert "tables" in result
            assert result["tables"][0]["table"] == "users"

    def test_select_table_with_time_filter(self, mock_llm, mock_catalog):
        """Test table selection with time filtering."""
        # Mock table with start_time
        mock_catalog.get_table_information.return_value = {
            "description": "User data table",
            "selection_rule": "Use for user queries",
            "start_time": "2024-01-01",
        }
        col_dict = {
            "user_id": {
                "column_name": "user_id",
                "category": "dimension",
                "display_name": "User ID",
                "description": "Unique user identifier",
            }
        }
        with (
            patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as mock_get_columns,
            patch("openchatbi.text2sql.schema_linking.column_tables_mapping", {"user_id": ["users"]}),
            patch("openchatbi.text2sql.schema_linking.col_dict", col_dict),
            patch("openchatbi.text2sql.schema_linking.table_selection_retriever") as mock_retriever,
            patch("openchatbi.text2sql.schema_linking.table_selection_example_dict", {}),
            patch("openchatbi.text2sql.schema_linking.extract_json_from_answer") as mock_extract,
        ):
            mock_get_columns.return_value = ["user_id"]
            mock_retriever.invoke.return_value = []
            mock_extract.return_value = {"tables": [{"table": "users"}]}

            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show recent user info",
                rewrite_question="Show recent user info",
                info_entities={
                    "keywords": ["user"],
                    "dimensions": ["user_id"],
                    "metrics": [],
                    "start_time": "2024-06-01",  # Later than table start_time
                },
            )

            result = select_func(state)

            assert "tables" in result
            assert result["tables"][0]["table"] == "users"

    def test_select_table_llm_error_handling(self, mock_llm, mock_catalog):
        """Test handling of LLM errors during table selection."""
        mock_llm.invoke.side_effect = Exception("LLM service error")
        col_dict = {
            "user_id": {
                "column_name": "user_id",
                "category": "dimension",
                "display_name": "User ID",
                "description": "Unique user identifier",
            }
        }
        with (
            patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as mock_get_columns,
            patch("openchatbi.text2sql.schema_linking.column_tables_mapping", {"user_id": ["users"]}),
            patch("openchatbi.text2sql.schema_linking.col_dict", col_dict),
            patch("openchatbi.text2sql.schema_linking.table_selection_retriever") as mock_retriever,
        ):
            mock_get_columns.return_value = ["user_id"]
            mock_retriever.invoke.return_value = []

            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user info",
                rewrite_question="Show user info",
                info_entities={"keywords": ["user"], "dimensions": ["user_id"], "metrics": []},
            )

            result = select_func(state)

            # Should handle error gracefully and return empty dict
            assert result == {}

    def test_select_table_max_retries_exceeded(self, mock_llm, mock_catalog):
        """Test behavior when max retries are exceeded."""
        col_dict = {
            "user_id": {
                "column_name": "user_id",
                "category": "dimension",
                "display_name": "User ID",
                "description": "Unique user identifier",
            }
        }
        with (
            patch("openchatbi.text2sql.schema_linking.get_relevant_columns") as mock_get_columns,
            patch("openchatbi.text2sql.schema_linking.column_tables_mapping", {"user_id": ["users"]}),
            patch("openchatbi.text2sql.schema_linking.col_dict", col_dict),
            patch("openchatbi.text2sql.schema_linking.table_selection_retriever") as mock_retriever,
            patch("openchatbi.text2sql.schema_linking.table_selection_example_dict", {}),
            patch("openchatbi.text2sql.schema_linking.extract_json_from_answer") as mock_extract,
        ):
            mock_get_columns.return_value = ["user_id"]
            mock_retriever.invoke.return_value = []
            # Always return invalid table
            mock_extract.return_value = {"tables": [{"table": "invalid_table"}]}

            select_func = schema_linking(mock_llm, mock_catalog)
            state = SQLGraphState(
                messages=[],
                question="Show user info",
                rewrite_question="Show user info",
                info_entities={"keywords": ["user"], "dimensions": ["user_id"], "metrics": []},
            )

            result = select_func(state)

            # Should return empty dict after max retries
            assert result == {}
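
A recurring theme in these tests is that the LLM's table pick is treated as untrusted output: candidate tables are derived from the column-to-table mapping, an invalid pick triggers a retry, and repeated failure yields `{}`. A hypothetical sketch of that loop (helper names mirror the patched targets; the real control flow in `schema_linking.py` may differ, including the retry count):

```python
def select_tables_sketch(pick_tables, question, relevant_columns, column_tables_mapping, max_attempts=2):
    # Candidates come from the catalog mapping, not from the LLM.
    candidates = {t for col in relevant_columns for t in column_tables_mapping.get(col, [])}
    for _ in range(max_attempts):
        picked = pick_tables(question)  # e.g., JSON parsed from the LLM answer
        if picked.get("tables") and all(e["table"] in candidates for e in picked["tables"]):
            return picked
    return {}  # give up after repeated invalid selections
```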

================================================
FILE: tests/test_text2sql_visualization.py
================================================
"""Tests for text2sql visualization functionality."""

import pytest

from openchatbi.text2sql.visualization import ChartType, VisualizationConfig, VisualizationDSL, VisualizationService


class TestVisualizationService:
    """Tests for the VisualizationService class."""

    def test_generate_visualization_dsl_basic(self):
        """Test basic DSL generation with schema info."""
        schema_info = {
            "columns": ["name", "age", "salary", "department"],
            "row_count": 4,
            "numeric_columns": ["age", "salary"],
            "categorical_columns": ["name", "department"],
            "datetime_columns": [],
            "unique_counts": {"name": 4, "department": 2},
        }

        service = VisualizationService()
        question = "Compare salary by department"
        dsl = service.generate_visualization_dsl(question, schema_info)

        assert dsl.chart_type == "bar"
        # Should use first categorical column which is "name"
        assert "name" in dsl.data_columns
        # Both numeric columns should be included
        assert "age" in dsl.data_columns and "salary" in dsl.data_columns

    def test_get_chart_type_by_rule_with_datetime(self):
        """Test chart type recommendation with datetime columns."""
        schema_info = {
            "columns": ["date", "sales", "region"],
            "numeric_columns": ["sales"],
            "categorical_columns": ["region"],
            "datetime_columns": ["date"],
            "row_count": 3,
        }

        service = VisualizationService()
        question = "Show sales trend over time"
        chart_type = service._get_chart_type_by_rule(question, schema_info)

        assert chart_type == ChartType.LINE

    def test_generate_visualization_dsl_error_handling(self):
        """Test DSL generation with error in schema info."""
        schema_info = {"error": "Failed to analyze data schema"}

        service = VisualizationService()
        question = "Show data"
        dsl = service.generate_visualization_dsl(question, schema_info)

        assert dsl.chart_type == "table"
        assert "error" in dsl.config

    def test_get_chart_type_by_rule_line_chart(self):
        """Test recommendation for line chart based on question keywords."""
        question = "Show me the sales trend over time"
        schema = {
            "numeric_columns": ["sales"],
            "categorical_columns": ["region"],
            "datetime_columns": ["date"],
            "row_count": 10,
        }

        service = VisualizationService()
        chart_type = service._get_chart_type_by_rule(question, schema)

        assert chart_type == ChartType.LINE

    def test_get_chart_type_by_rule_pie_chart(self):
        """Test recommendation for pie chart based on question keywords."""
        question = "What is the percentage breakdown by department?"
        schema = {
            "numeric_columns": ["count"],
            "categorical_columns": ["department"],
            "datetime_columns": [],
            "row_count": 5,
            "unique_counts": {"department": 4},
        }

        service = VisualizationService()
        chart_type = service._get_chart_type_by_rule(question, schema)

        assert chart_type == ChartType.PIE

    def test_get_chart_type_by_rule_bar_chart(self):
        """Test recommendation for bar chart based on question keywords."""
        question = "Compare sales by region"
        schema = {
            "numeric_columns": ["sales"],
            "categorical_columns": ["region"],
            "datetime_columns": [],
            "row_count": 10,
            "unique_counts": {"region": 4},
        }

        service = VisualizationService()
        chart_type = service._get_chart_type_by_rule(question, schema)

        assert chart_type == ChartType.BAR

    def test_get_chart_type_by_rule_scatter_plot(self):
        """Test recommendation for scatter plot based on data characteristics."""
        question = "Show relationship between age and salary"
        schema = {
            "numeric_columns": ["age", "salary"],
            "categorical_columns": ["name"],
            "datetime_columns": [],
            "row_count": 10,
        }

        service = VisualizationService()
        chart_type = service._get_chart_type_by_rule(question, schema)

        assert chart_type == ChartType.SCATTER

    def test_get_chart_type_by_rule_histogram(self):
        """Test recommendation for histogram based on keywords."""
        question = "What is the distribution of ages?"
        schema = {"numeric_columns": ["age"], "categorical_columns": [], "datetime_columns": [], "row_count": 100}

        service = VisualizationService()
        chart_type = service._get_chart_type_by_rule(question, schema)

        assert chart_type == ChartType.HISTOGRAM

    def test_get_chart_type_by_rule_data_based_priority(self):
        """Test that data characteristics take priority over row count."""
        question = "Show all records"
        schema = {
            "numeric_columns": ["value"],
            "categorical_columns": ["category"],
            "datetime_columns": [],
            "row_count": 15,
            "unique_counts": {"category": 5},  # Small number of categories
        }

        service = VisualizationService()
        chart_type = service._get_chart_type_by_rule(question, schema)

        # Should choose PIE because of categorical + numeric columns, not TABLE due to row count
        assert chart_type == ChartType.PIE

    def test_generate_visualization_dsl_line_chart(self):
        """Test DSL generation for line chart."""
        question = "Show sales trend over time"
        schema_info = {
            "columns": ["date", "sales"],
            "numeric_columns": ["sales"],
            "categorical_columns": [],
            "datetime_columns": ["date"],
            "row_count": 3,
        }

        service = VisualizationService()
        dsl = service.generate_visualization_dsl(question, schema_info)

        assert dsl.chart_type == "line"
        assert "date" in dsl.data_columns
        assert "sales" in dsl.data_columns
        assert dsl.config["x"] == "date"
        assert dsl.config["y"] == "sales"

    def test_generate_visualization_dsl_bar_chart(self):
        """Test DSL generation for bar chart."""
        question = "Compare sales by region"
        schema_info = {
            "columns": ["region", "sales"],
            "numeric_columns": ["sales"],
            "categorical_columns": ["region"],
            "datetime_columns": [],
            "row_count": 4,
            "unique_counts": {"region": 4},
        }

        service = VisualizationService()
        dsl = service.generate_visualization_dsl(question, schema_info)

        assert dsl.chart_type == "bar"
        assert "region" in dsl.data_columns
        assert "sales" in dsl.data_columns
        assert dsl.config["x"] == "region"
        assert dsl.config["y"] == "sales"

    def test_generate_visualization_dsl_pie_chart(self):
        """Test DSL generation for pie chart."""
        question = "Show percentage breakdown by department"
        schema_info = {
            "columns": ["department", "count"],
            "numeric_columns": ["count"],
            "categorical_columns": ["department"],
            "datetime_columns": [],
            "row_count": 4,
            "unique_counts": {"department": 4},
        }

        service = VisualizationService()
        dsl = service.generate_visualization_dsl(question, schema_info, ChartType.PIE)

        assert dsl.chart_type == "pie"
        assert dsl.config["labels"] == "department"
        assert dsl.config["values"] == "count"

    def test_generate_visualization_dsl_empty_data(self):
        """Test DSL generation with empty data."""
        question = "Show data"
        schema_info = {
            "columns": [],
            "numeric_columns": [],
            "categorical_columns": [],
            "datetime_columns": [],
            "row_count": 0,
        }

        service = VisualizationService()
        dsl = service.generate_visualization_dsl(question, schema_info)

        assert dsl.chart_type == "table"
        assert "columns" in dsl.config

    def test_visualization_config_dataclass(self):
        """Test VisualizationConfig dataclass."""
        config = VisualizationConfig(
            chart_type=ChartType.BAR, x_column="category", y_column="value", title="Test Chart"
        )

        assert config.chart_type == ChartType.BAR
        assert config.x_column == "category"
        assert config.y_column == "value"
        assert config.title == "Test Chart"
        assert config.show_legend is True  # default value

    def test_visualization_dsl_to_dict(self):
        """Test VisualizationDSL to_dict method."""
        dsl = VisualizationDSL(
            chart_type="bar",
            data_columns=["x", "y"],
            config={"x": "category", "y": "value"},
            layout={"title": "Test Chart"},
        )

        result = dsl.to_dict()

        assert result["chart_type"] == "bar"
        assert result["data_columns"] == ["x", "y"]
        assert result["config"]["x"] == "category"
        assert result["layout"]["title"] == "Test Chart"


class TestChartType:
    """Tests for ChartType enum."""

    def test_chart_type_values(self):
        """Test ChartType enum values."""
        assert ChartType.LINE.value == "line"
        assert ChartType.BAR.value == "bar"
        assert ChartType.PIE.value == "pie"
        assert ChartType.SCATTER.value == "scatter"
        assert ChartType.HISTOGRAM.value == "histogram"
        assert ChartType.BOX.value == "box"
        assert ChartType.HEATMAP.value == "heatmap"
        assert ChartType.TABLE.value == "table"


@pytest.fixture
def sample_csv_data():
    """Fixture providing sample CSV data for testing."""
    return """product,sales,region,quarter
Widget A,10000,North,Q1
Widget B,15000,South,Q1
Widget C,8000,East,Q1
Widget A,12000,North,Q2
Widget B,18000,South,Q2
Widget C,9000,East,Q2"""


@pytest.fixture
def sample_time_series_data():
    """Fixture providing sample time series data for testing."""
    return """date,revenue,users
2023-01-01,50000,1000
2023-02-01,55000,1100
2023-03-01,60000,1200
2023-04-01,52000,1050
2023-05-01,58000,1150"""


class TestVisualizationIntegration:
    """Integration tests for visualization functionality."""

    def test_complete_workflow_line_chart(self, sample_time_series_data):
        """Test complete workflow for generating line chart."""
        question = "Show revenue trend over time"

        # Mock schema info for time series data
        schema_info = {
            "columns": ["date", "revenue", "users"],
            "numeric_columns": ["revenue", "users"],
            "categorical_columns": [],
            "datetime_columns": ["date"],
            "row_count": 5,
        }

        service = VisualizationService()

        # Recommend chart type
        chart_type = service._get_chart_type_by_rule(question, schema_info)

        # Generate DSL
        dsl = service.generate_visualization_dsl(question, schema_info, chart_type)

        assert chart_type == ChartType.LINE
        assert dsl.chart_type == "line"
        assert "date" in dsl.data_columns
        assert "revenue" in dsl.data_columns

    def test_complete_workflow_bar_chart(self, sample_csv_data):
        """Test complete workflow for generating bar chart."""
        question = "Compare sales by product"

        # Mock schema info for sample CSV data
        schema_info = {
            "columns": ["product", "sales", "region", "quarter"],
            "numeric_columns": ["sales"],
            "categorical_columns": ["product", "region", "quarter"],
            "datetime_columns": [],
            "row_count": 6,
            "unique_counts": {"product": 3, "region": 3, "quarter": 2},
        }

        service = VisualizationService()

        # Generate DSL directly (will analyze schema internally)
        dsl = service.generate_visualization_dsl(question, schema_info)

        assert dsl.chart_type in ["bar", "line"]  # Could be either based on heuristics
        assert len(dsl.data_columns) >= 2
        assert dsl.layout.get("title") is not None

================================================
FILE: tests/test_tools_ask_human.py
================================================
"""Tests for ask_human tool functionality."""

import pytest
from pydantic import ValidationError

from openchatbi.tool.ask_human import AskHuman


class TestAskHuman:
    """Test AskHuman model functionality."""

    def test_ask_human_basic_initialization(self):
        """Test basic AskHuman model creation."""
        question = "What time period should I analyze?"
        options = ["Last 7 days", "Last 30 days", "Last year"]

        ask_human = AskHuman(question=question, options=options)

        assert ask_human.question == question
        assert ask_human.options == options

    def test_ask_human_empty_options(self):
        """Test AskHuman with empty options list."""
        ask_human = AskHuman(question="Simple question?", options=[])

        assert ask_human.question == "Simple question?"
        assert ask_human.options == []

    def test_ask_human_validation_error(self):
        """Test AskHuman model validation."""
        with pytest.raises(ValidationError):
            AskHuman()  # Missing required fields

        with pytest.raises(ValidationError):
            AskHuman(question="Test")  # Missing options field

    def test_ask_human_serialization(self):
        """Test AskHuman model serialization."""
        ask_human = AskHuman(question="Which analysis method?", options=["Statistical", "Machine Learning"])

        data = ask_human.model_dump()

        assert data["question"] == "Which analysis method?"
        assert data["options"] == ["Statistical", "Machine Learning"]



================================================
FILE: tests/test_tools_run_python_code.py
================================================
"""Tests for run_python_code tool functionality."""

from unittest.mock import patch

from openchatbi.tool.run_python_code import run_python_code


class TestRunPythonCode:
    """Test run_python_code tool functionality."""

    def test_run_python_code_basic(self):
        """Test basic Python code execution."""
        reasoning = "Testing basic print functionality"
        code = "print('Hello, World!')"

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (True, "Hello, World!\n")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Hello, World!" in result

    def test_run_python_code_with_variables(self):
        """Test Python code execution with variables."""
        reasoning = "Testing variable operations"
        code = """
x = 10
y = 20
result = x + y
print(f"Result: {result}")
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (True, "Result: 30\n")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Result: 30" in result

    def test_run_python_code_data_analysis(self):
        """Test Python code for data analysis operations."""
        reasoning = "Performing data analysis calculations"
        code = """
import math
data = [1, 2, 3, 4, 5]
mean = sum(data) / len(data)
print(f"Mean: {mean}")
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (True, "Mean: 3.0\n")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Mean: 3.0" in result

    def test_run_python_code_matplotlib_plot(self):
        """Test Python code for creating plots."""
        reasoning = "Creating a matplotlib visualization"
        code = """
import matplotlib.pyplot as plt
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]
plt.plot(x, y)
plt.title('Sample Plot')
print('Plot created successfully')
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (True, "Plot created successfully\n")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Plot created successfully" in result

    def test_run_python_code_syntax_error(self):
        """Test Python code execution with syntax errors."""
        reasoning = "Testing error handling for syntax errors"
        code = "print('Hello World'"  # Missing closing parenthesis

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (False, "SyntaxError: unexpected EOF while parsing")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Error:" in result
            assert "SyntaxError" in result

    def test_run_python_code_runtime_error(self):
        """Test Python code execution with runtime errors."""
        reasoning = "Testing error handling for runtime errors"
        code = """
x = 10
y = 0
result = x / y
print(result)
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (False, "ZeroDivisionError: division by zero")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Error:" in result
            assert "ZeroDivisionError" in result

    def test_run_python_code_import_error(self):
        """Test Python code execution with import errors."""
        reasoning = "Testing error handling for import errors"
        code = """
import nonexistent_module
print('This should not print')
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (False, "ModuleNotFoundError: No module named 'nonexistent_module'")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Error:" in result
            assert "ModuleNotFoundError" in result

    def test_run_python_code_multiline_output(self):
        """Test Python code with multiple print statements."""
        reasoning = "Testing multiple output lines"
        code = """
for i in range(3):
    print(f"Line {i + 1}")
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (True, "Line 1\nLine 2\nLine 3\n")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Line 1" in result
            assert "Line 2" in result
            assert "Line 3" in result

    def test_run_python_code_with_sql_data(self):
        """Test Python code working with SQL-like data."""
        reasoning = "Processing SQL query results"
        code = """
data = [
    {'name': 'Alice', 'age': 30, 'salary': 50000},
    {'name': 'Bob', 'age': 25, 'salary': 45000},
    {'name': 'Charlie', 'age': 35, 'salary': 55000}
]
total_salary = sum(row['salary'] for row in data)
print(f"Total salary: ${total_salary}")
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (True, "Total salary: $150000\n")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Total salary: $150000" in result

    def test_run_python_code_empty_code(self):
        """Test Python code execution with empty code."""
        reasoning = "Testing empty code handling"
        code = ""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (True, "")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert result == ""

    def test_run_python_code_whitespace_only(self):
        """Test Python code execution with whitespace only."""
        reasoning = "Testing whitespace-only code"
        code = " \n \t \n "

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (True, "")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert result == ""

    def test_run_python_code_with_comments(self):
        """Test Python code execution with comments."""
        reasoning = "Testing code with comments"
        code = """
# This is a comment
x = 5  # Another comment
print(f"Value: {x}")  # Final comment
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (True, "Value: 5\n")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Value: 5" in result

    def test_run_python_code_security_restrictions(self):
        """Test Python code with potentially restricted operations."""
        reasoning = "Testing security restrictions"
        code = """
# Attempting file operations
try:
    with open('/etc/passwd', 'r') as f:
        content = f.read()
    print("File read successfully")
except Exception as e:
    print(f"Security restriction: {e}")
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (
                False,
                "PermissionError: [Errno 13] Permission denied: '/etc/passwd'",
            )

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Error:" in result

    def test_run_python_code_timeout_handling(self):
        """Test Python code execution timeout scenarios."""
        reasoning = "Testing timeout handling"
        code = """
import time
time.sleep(10)  # Long running operation
print("This might timeout")
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (False, "TimeoutError: Code execution timed out")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Error:" in result
            assert "TimeoutError" in result

    def test_run_python_code_memory_limit(self):
        """Test Python code execution with memory limitations."""
        reasoning = "Testing memory limit handling"
        code = """
# Creating a large list that might exceed memory limits
large_list = [0] * (10**8)
print(f"Created list with {len(large_list)} elements")
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (False, "MemoryError: Unable to allocate memory")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Error:" in result
            assert "MemoryError" in result

    def test_run_python_code_return_values(self):
        """Test that return values are not captured (only prints)."""
        reasoning = "Testing return value handling"
        code = """
def calculate():
    return 42

result = calculate()
print(f"Function returned: {result}")
# The return value itself should not be captured
calculate()
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (True, "Function returned: 42\n")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Function returned: 42" in result
            # Should not contain the raw return value
            assert result.strip() == "Function returned: 42"

    def test_run_python_code_exception_details(self):
        """Test detailed exception information."""
        reasoning = "Testing detailed exception handling"
        code = """
def faulty_function():
    raise ValueError("This is a custom error message")

faulty_function()
"""

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor:
            mock_instance = mock_executor.return_value
            mock_instance.run_code.return_value = (False, "ValueError: This is a custom error message")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            assert isinstance(result, str)
            assert "Error:" in result
            assert "ValueError" in result
            assert "custom error message" in result

    def test_run_python_code_executor_selection(self):
        """Test that LocalExecutor is properly instantiated and used."""
        reasoning = "Testing executor instantiation"
        code = "print('Executor test')"

        with patch("openchatbi.tool.run_python_code.LocalExecutor") as mock_executor_class:
            mock_instance = mock_executor_class.return_value
            mock_instance.run_code.return_value = (True, "Executor test\n")

            result = run_python_code.run({"reasoning": reasoning, "code": code})

            # Verify LocalExecutor was instantiated
            mock_executor_class.assert_called_once()
            # Verify run_code was called with the correct code
            mock_instance.run_code.assert_called_once_with(code)

            assert isinstance(result, str)
            assert "Executor test" in result
import search_knowledge, show_schema class TestSearchKnowledge: """Test search_knowledge tool functionality.""" def test_search_knowledge_basic(self): """Test basic knowledge search functionality.""" reasoning = "Looking for user information" query_list = ["user", "information"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "user_id: User identifier\nuser_name: User name" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) assert "columns" in result assert "User identifier" in result["columns"] mock_search.assert_called_once_with(query_list, False) def test_search_knowledge_table_matching(self): """Test knowledge search with table matching.""" reasoning = "Finding table relationships" query_list = ["user", "metrics"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "user_id: Unique identifier\nmetrics_value: Metric value" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": True, } ) assert isinstance(result, dict) assert "columns" in result mock_search.assert_called_once_with(query_list, True) def test_search_knowledge_empty_query(self): """Test knowledge search with empty query.""" reasoning = "Testing empty search" query_list = [] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) assert "columns" in result mock_search.assert_called_once_with(query_list, False) def test_search_knowledge_no_matches(self): """Test knowledge search with no matches.""" reasoning = "Testing no matches" query_list = ["nonexistent"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) assert "columns" in result assert result["columns"] == "# Relevant Columns and Description:\n" def test_search_knowledge_multiple_matches(self): """Test knowledge search with multiple matches.""" reasoning = "Finding multiple matches" query_list = ["user", "data", "profile"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "user_id: User ID\nuser_name: Name\nprofile_data: Profile" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) assert "columns" in result assert "user_id" in result["columns"] assert "profile_data" in result["columns"] def test_search_knowledge_with_synonyms(self): """Test knowledge search with synonym matching.""" reasoning = "Testing synonym search" query_list = ["customer", "client", "user"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value 
= "customer_id: Customer identifier\nclient_name: Client name" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) assert "columns" in result def test_search_knowledge_case_insensitive(self): """Test case insensitive knowledge search.""" reasoning = "Testing case sensitivity" query_list = ["USER", "Data", "PROFILE"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "user_data: User information" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) assert "columns" in result def test_search_knowledge_partial_matches(self): """Test knowledge search with partial matches.""" reasoning = "Testing partial matching" query_list = ["usr", "prof"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "user_profile: User profile data" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) def test_search_knowledge_error_handling(self): """Test knowledge search error handling.""" reasoning = "Testing error handling" query_list = ["test"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.side_effect = Exception("Search error") # Should handle exceptions gracefully with pytest.raises(Exception): search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) def test_show_schema_basic(self): """Test basic schema display functionality.""" reasoning = "Showing basic schema" tables = ["user_data"] with patch("openchatbi.tool.search_knowledge.list_table_from_catalog") as mock_list: mock_list.return_value = ["Table: user_data\n# Description: User information\n# Columns:\nuser_id: User ID"] result = show_schema.run({"reasoning": reasoning, "tables": tables}) assert isinstance(result, list) assert len(result) == 1 assert "user_data" in result[0] mock_list.assert_called_once_with(tables) def test_show_schema_detailed_info(self): """Test detailed schema information.""" reasoning = "Showing detailed schema" tables = ["user_data", "metrics"] with patch("openchatbi.tool.search_knowledge.list_table_from_catalog") as mock_list: mock_list.return_value = [ "Table: user_data\n# Columns: user_id, name, email", "Table: metrics\n# Columns: metric_id, value, timestamp", ] result = show_schema.run({"reasoning": reasoning, "tables": tables}) assert isinstance(result, list) assert len(result) == 2 assert any("user_data" in schema for schema in result) assert any("metrics" in schema for schema in result) def test_show_schema_nonexistent_table(self): """Test schema display for nonexistent table.""" reasoning = "Testing nonexistent table" tables = ["nonexistent_table"] with patch("openchatbi.tool.search_knowledge.list_table_from_catalog") as mock_list: mock_list.return_value = [] result = show_schema.run({"reasoning": reasoning, "tables": tables}) assert isinstance(result, list) assert len(result) == 0 def test_show_schema_table_error(self): """Test schema display error handling.""" reasoning = "Testing schema 
errors" tables = ["error_table"] with patch("openchatbi.tool.search_knowledge.list_table_from_catalog") as mock_list: mock_list.side_effect = Exception("Table access error") with pytest.raises(Exception): show_schema.run({"reasoning": reasoning, "tables": tables}) def test_show_schema_complex_table(self): """Test schema display for complex table structure.""" reasoning = "Showing complex schema" tables = ["complex_table"] with patch("openchatbi.tool.search_knowledge.list_table_from_catalog") as mock_list: mock_list.return_value = [ "Table: complex_table\n# Description: Complex data structure\n# Columns:\nid: Primary key\ndata: JSON data\ncreated_at: Timestamp" ] result = show_schema.run({"reasoning": reasoning, "tables": tables}) assert isinstance(result, list) assert "complex_table" in result[0] assert "Primary key" in result[0] def test_search_knowledge_with_metrics(self): """Test knowledge search focusing on metrics.""" reasoning = "Finding metrics columns" query_list = ["revenue", "clicks", "impressions"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "revenue: Revenue amount\nclicks: Click count\nimpressions: Impression count" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) assert "revenue" in result["columns"] assert "clicks" in result["columns"] def test_search_knowledge_contextual_search(self): """Test contextual knowledge search.""" reasoning = "Contextual search for user behavior" query_list = ["user", "behavior", "tracking"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "user_behavior: User activity tracking\ntracking_id: Tracking identifier" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) assert "behavior" in result["columns"] def test_search_knowledge_with_aggregations(self): """Test knowledge search for aggregation columns.""" reasoning = "Finding aggregation metrics" query_list = ["sum", "count", "average"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "total_count: Count aggregation\naverage_value: Average calculation" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) def test_show_schema_with_examples(self): """Test schema display with usage examples.""" reasoning = "Showing schema with examples" tables = ["example_table"] with patch("openchatbi.tool.search_knowledge.list_table_from_catalog") as mock_list: mock_list.return_value = [ "Table: example_table\n# Description: Example usage\n## Derived metrics:\nSELECT COUNT(*) FROM example_table" ] result = show_schema.run({"reasoning": reasoning, "tables": tables}) assert isinstance(result, list) assert "example_table" in result[0] assert "Derived metrics" in result[0] def test_search_knowledge_performance(self): """Test knowledge search performance characteristics.""" reasoning = "Testing search performance" query_list = ["performance", "speed", "optimization"] knowledge_bases = ["columns"] with 
patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "performance_metric: Performance measurement" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) # Just ensure it completes without performance issues def test_search_knowledge_special_characters(self): """Test knowledge search with special characters.""" reasoning = "Testing special character handling" query_list = ["user@domain", "data-point", "metric_value"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "user_email: User email address\ndata_point: Data measurement" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) def test_search_knowledge_unicode_support(self): """Test knowledge search with unicode characters.""" reasoning = "Testing unicode support" query_list = ["utilización", "données", "用户"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "user_data: International user data" result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) def test_knowledge_integration_with_state(self): """Test knowledge search integration with agent state.""" reasoning = "Testing state integration" query_list = ["state", "integration"] knowledge_bases = ["columns"] with patch("openchatbi.tool.search_knowledge.search_column_from_catalog") as mock_search: mock_search.return_value = "state_data: Application state information" # Test that the tool can be called in the context of agent state result = search_knowledge.run( { "reasoning": reasoning, "query_list": query_list, "knowledge_bases": knowledge_bases, "with_table_list": False, } ) assert isinstance(result, dict) assert "columns" in result ================================================ FILE: tests/test_utils.py ================================================ """Tests for utility functions.""" import io from unittest.mock import patch import pytest from openchatbi.utils import log class TestUtilityFunctions: """Test utility functions.""" def test_log_function_basic(self): """Test basic logging functionality.""" # Capture stdout captured_output = io.StringIO() with patch("sys.stderr", captured_output): log("Test message") output = captured_output.getvalue() assert "Test message" in output def test_log_function_multiple_messages(self): """Test logging with multiple messages.""" captured_output = io.StringIO() with patch("sys.stderr", captured_output): log("First message") log("Second message") output = captured_output.getvalue() assert "First message" in output assert "Second message" in output def test_log_function_empty_message(self): """Test logging with empty message.""" captured_output = io.StringIO() with patch("sys.stderr", captured_output): log("") output = captured_output.getvalue() # Should handle empty messages gracefully assert output is not None def test_log_function_none_message(self): """Test logging with None message.""" captured_output = io.StringIO() with patch("sys.stderr", captured_output): log(None) output = captured_output.getvalue() # Should handle 
None messages gracefully assert "None" in output or output == "" def test_log_function_complex_objects(self): """Test logging with complex objects.""" test_dict = {"key": "value", "number": 42} test_list = [1, 2, 3, "string"] captured_output = io.StringIO() with patch("sys.stderr", captured_output): log(test_dict) log(test_list) output = captured_output.getvalue() assert "key" in output or str(test_dict) in output assert "string" in output or str(test_list) in output def test_log_function_with_exception(self): """Test logging exception information.""" try: raise ValueError("Test exception") except ValueError as e: captured_output = io.StringIO() with patch("sys.stderr", captured_output): log(f"Exception occurred: {e}") output = captured_output.getvalue() assert "Exception occurred" in output assert "Test exception" in output @patch("sys.stderr") def test_log_function_stderr_error(self, mock_stderr): """Test logging when stderr has issues.""" mock_stderr.write.side_effect = OSError("stderr error") # Current implementation raises exception when stderr fails - this is expected with pytest.raises(OSError, match="stderr error"): log("Test message") def test_log_function_unicode_handling(self): """Test logging with unicode characters.""" unicode_message = "Test with émojis: 🚀 and spéciál characters: ñáéíóú" captured_output = io.StringIO() with patch("sys.stderr", captured_output): log(unicode_message) output = captured_output.getvalue() # Should handle unicode characters properly assert len(output) > 0 def test_log_function_large_message(self): """Test logging with very large messages.""" large_message = "A" * 10000 # 10KB message captured_output = io.StringIO() with patch("sys.stderr", captured_output): log(large_message) output = captured_output.getvalue() assert len(output) > 0 assert "A" in output def test_log_function_newline_handling(self): """Test logging with messages containing newlines.""" multiline_message = "Line 1\\nLine 2\\nLine 3" captured_output = io.StringIO() with patch("sys.stderr", captured_output): log(multiline_message) output = captured_output.getvalue() assert "Line 1" in output assert "Line 2" in output assert "Line 3" in output def test_log_function_timestamp_format(self): """Test that log includes timestamp information.""" captured_output = io.StringIO() with patch("sys.stderr", captured_output): log("Timestamp test") output = captured_output.getvalue() # Check if output contains timestamp-like format (basic check) # The actual implementation might vary assert len(output) > len("Timestamp test") def test_log_function_concurrent_calls(self): """Test logging with concurrent-like calls.""" import threading import time captured_output = io.StringIO() def log_worker(message): log(f"Worker: {message}") time.sleep(0.01) # Small delay # Patch stderr for all threads with patch("sys.stderr", captured_output): # Create multiple threads (simulating concurrency) threads = [] for i in range(5): thread = threading.Thread(target=log_worker, args=(f"message_{i}",)) threads.append(thread) # Start all threads for thread in threads: thread.start() # Wait for all threads for thread in threads: thread.join() output = captured_output.getvalue() # Should handle concurrent access gracefully assert len(output) > 0 ================================================ FILE: timeseries_forecasting/Dockerfile ================================================ FROM python:3.10-slim # Install only essential build tools RUN apt-get update && apt-get install -y --no-install-recommends \ curl \ wget \ 
build-essential \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean # Install Python dependencies for time series forecasting RUN pip3 --no-cache-dir install \ fastapi==0.120.4 \ uvicorn==0.38.0 \ transformers==4.40.1 \ torch==2.9.0 \ numpy==2.2.6 \ pandas==2.3.3 \ pydantic==2.12.3 # Set working directory WORKDIR /home/model-server # Copy the model (build_and_run.sh copies ../hf_model into the build context first; COPY cannot reference paths outside the context) COPY hf_model /home/model-server/hf_model # Copy application files COPY app.py model_handler.py /home/model-server/ # Set environment variables ENV PYTHONPATH=/home/model-server ENV PYTHONUNBUFFERED=1 # Expose port EXPOSE 8765 # Define entrypoint and default command CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8765"] ================================================ FILE: timeseries_forecasting/README.md ================================================ # Transformer Time Series Forecasting Service A Docker-based time series forecasting service using Transformer-based models for accurate time series prediction. This service provides a FastAPI-based REST API for easy integration with various applications. ## Features - **Transformer Model Integration**: Uses state-of-the-art Transformer models for time series forecasting - **FastAPI Backend**: Modern, fast web framework with automatic API documentation - **Docker Support**: Fully containerized service for easy deployment - **Flexible Input**: Supports both simple numeric arrays and structured data with timestamps - **Multiple Forecast Horizons**: Configure prediction length from 1 to 200 time steps - **GPU Support**: Automatic GPU detection and utilization when available ## Prerequisites - Docker installed and running - Transformer model files (compatible with Hugging Face transformers library) ## Quick Start ### 1. Download Transformer Model Download a pre-trained model from Hugging Face and place it in the `hf_model` directory. For example, use the recommended `timer-base-84m` model from https://huggingface.co/thuml/timer-base-84m: > **Note**: The `timer-base-84m` model requires at least 96 time points in the input data. When integrating with OpenChatBI, add this restriction to your `extra_tool_use_rule` in bi.yaml: > ``` > - timeseries_forecast tool requires at least 96 time points in input data. If there is not enough input data, set input_len to 96 to pad with zeros. > ``` ### 2. Build and Run ```bash cd timeseries_forecasting chmod +x build_and_run.sh ./build_and_run.sh ``` The service will be available at: - **Predictions**: `http://localhost:8765/predict` - **Health Check**: `http://localhost:8765/health` - **API Documentation**: `http://localhost:8765/docs` - **Model Info**: `http://localhost:8765/model/info` ### 3. Make a Prediction ```bash curl -X POST http://localhost:8765/predict \ -H "Content-Type: application/json" \ -d '{ "input": [100, 102, 98, 105, 107, 103, 99, 101, 104, 106], "input_len": 100, "forecast_window": 5, "frequency": "H" }' ``` ### 4. Test the Service Run the comprehensive test suite: ```bash python test_forecasting.py --url http://localhost:8765 ``` ## API Reference ### Prediction Endpoint **POST** `/predict` #### Request Format ```json { "input": [...], // Time series data (required) "forecast_window": 24, // Number of future points to predict (default: 24, max: 200) "frequency": "H", // Frequency: "H" (hourly), "D" (daily), etc. (default: "H") "input_len": null, // Optional input length; if provided, input is truncated or zero-padded to this length "target_column": "value" // Column name for structured data (default: "value") } ```
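For programmatic access, here is a minimal Python equivalent of the curl example above (a sketch, assuming the service is reachable on localhost:8765 and the `requests` package is installed; the `input_len: 96` padding follows the timer-base-84m note in Quick Start):

```python
import requests

# Hypothetical example series; input_len=96 zero-pads shorter inputs (see note above).
payload = {
    "input": [100, 102, 98, 105, 107, 103, 99, 101, 104, 106],
    "input_len": 96,
    "forecast_window": 5,
    "frequency": "H",
}

resp = requests.post("http://localhost:8765/predict", json=payload, timeout=30)
resp.raise_for_status()  # non-2xx responses raise an HTTPError
result = resp.json()
print(result["status"], result["predictions"])
```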
(default: "H") "input_len": null, // Limit input length, if provided, will use it to truncate input or pad zero (optional) "target_column": "value" // Column name for structured data (default: "value") } ``` #### Input Data Formats **Simple Numeric Array:** ```json { "input": [100, 102, 98, 105, 107, 103, 99, 101], "input_len": 100, "forecast_window": 12 } ``` **Structured Data with Timestamps:** ```json { "input": [ {"timestamp": "2024-01-01T00:00:00", "value": 100}, {"timestamp": "2024-01-01T01:00:00", "value": 102}, {"timestamp": "2024-01-01T02:00:00", "value": 98} ], "input_len": 100, "forecast_window": 24, "target_column": "value" } ``` #### Response Format ```json { "predictions": [101.5, 103.2, 99.8, ...], "forecast_window": 24, "frequency": "H", "status": "success" } ``` ## Configuration ### Environment Variables - `PYTHONPATH`: Python path for modules (default: /home/model-server) - `PYTHONUNBUFFERED`: Disable Python output buffering (default: 1) ### Docker Run Options ```bash # Basic run docker run -p 8765:8765 timeseries-forecasting # With volume mount for models docker run -p 8765:8765 \ -v /path/to/model:/app/hf_model \ timeseries-forecasting # With custom environment variables docker run -p 8765:8765 \ -e PYTHONPATH=/home/model-server \ timeseries-forecasting ``` ## Testing ### Service Tests Run the test script to validate the service: ```bash # Make test script executable chmod +x test_forecasting.py # Install test dependencies pip install requests numpy # Run tests python test_forecasting.py --url http://localhost:8765 ``` ## Model Information - **Recommended Models**: https://huggingface.co/thuml/timer-base-84m - **Model Type**: Transformer-based Causal Language Model for Time Series - **Framework**: Hugging Face Transformers - **Architecture**: AutoModelForCausalLM - **Device Support**: Automatic GPU/CPU detection - **Capabilities**: Univariate time series forecasting with automatic normalization ## Troubleshooting ### Common Issues 1. **Service Not Starting** - Check if port 8765 is available: `lsof -i :8765` - Verify Docker has sufficient memory allocated (minimum 4GB recommended) - Check logs: `docker logs time-series-forecasting-service` 2. **Model Loading Errors** - Ensure model files are properly copied during build - Check available disk space (models can be several GB) - Verify Hugging Face transformers library compatibility 3. 
## Configuration ### Environment Variables - `PYTHONPATH`: Python path for modules (default: /home/model-server) - `PYTHONUNBUFFERED`: Disable Python output buffering (default: 1) ### Docker Run Options ```bash # Basic run docker run -p 8765:8765 timeseries-forecasting # With volume mount for models docker run -p 8765:8765 \ -v /path/to/model:/home/model-server/hf_model \ timeseries-forecasting # With custom environment variables docker run -p 8765:8765 \ -e PYTHONPATH=/home/model-server \ timeseries-forecasting ``` ## Testing ### Service Tests Run the test script to validate the service: ```bash # Make test script executable chmod +x test_forecasting.py # Install test dependencies pip install requests numpy # Run tests python test_forecasting.py --url http://localhost:8765 ``` ## Model Information - **Recommended Models**: https://huggingface.co/thuml/timer-base-84m - **Model Type**: Transformer-based Causal Language Model for Time Series - **Framework**: Hugging Face Transformers - **Architecture**: AutoModelForCausalLM - **Device Support**: Automatic GPU/CPU detection - **Capabilities**: Univariate time series forecasting with automatic normalization ## Troubleshooting ### Common Issues 1. **Service Not Starting** - Check if port 8765 is available: `lsof -i :8765` - Verify Docker has sufficient memory allocated (minimum 4GB recommended) - Check logs: `docker logs time-series-forecasting-service` 2. **Model Loading Errors** - Ensure model files are properly copied during build - Check available disk space (models can be several GB) - Verify Hugging Face transformers library compatibility 3. **Prediction Errors** - Validate the input data format - Check that the forecast window is reasonable - Ensure the input data has sufficient length ### Debug Mode Enable debug logging: ```bash docker run -p 8765:8765 \ -e PYTHONPATH=/home/model-server \ -e LOGGING_LEVEL=DEBUG \ timeseries-forecasting ``` ## Performance - **Cold Start**: ~10 seconds (model loading) - **Inference Time**: ~100-300ms per request (varies by input size and model) - **Memory Usage**: ~2-4GB (depending on input size and model) - **Concurrent Requests**: Supported (configure workers) ## Limitations - Maximum forecast window: 200 time points - Univariate forecasting (single time series) - Requires a minimum amount of input data for reliable predictions (timer-base-84m needs at least 96 time points) - Model-specific context length limitations may apply ================================================ FILE: timeseries_forecasting/app.py ================================================ """app.py: FastAPI application for Transformer time series forecasting.""" import logging import time from typing import Any import uvicorn from fastapi import FastAPI, HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel, Field from starlette.requests import Request from model_handler import TransformerModelHandler, get_model_handler # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Create FastAPI app app = FastAPI( title="Transformer Time Series Forecasting API", description="A REST API for time series forecasting using Transformer model", version="1.0.0", ) # Request models class ForecastRequest(BaseModel): """Request model for forecasting.""" input: list[float | int | dict[str, Any]] = Field( ..., description="Time series data as list of numbers or structured data", example=[100, 102, 98, 105, 107, 103, 99, 101], ) forecast_window: int = Field(default=24, ge=1, le=200, description="Number of future points to predict") input_len: int | None = Field(default=None, description="Optional input length limit") frequency: str = Field(default="hourly", description="Frequency of the time series (hourly, daily, etc.)") target_column: str = Field(default="value", description="Column name for structured data") class ForecastResponse(BaseModel): """Response model for forecasting.""" predictions: list[float] = Field(description="Forecasted values") forecast_window: int = Field(description="Number of predictions") frequency: str = Field(description="Time series frequency") status: str = Field(description="Response status") class ErrorResponse(BaseModel): """Error response model.""" error: str = Field(description="Error message") status: str = Field(description="Response status") # Global variables model_handler: TransformerModelHandler | None = None startup_time: float | None = None @app.on_event("startup") async def startup_event(): """Initialize model on startup.""" global model_handler, startup_time startup_time = time.time() logger.info("Starting Transformer Forecasting API...") try: # Initialize model handler model_handler = get_model_handler() model_success = model_handler.initialize() if model_success: logger.info("Model initialized successfully") else: logger.error("Failed to initialize model") except Exception as e: logger.error(f"Startup failed: {str(e)}") @app.get("/health") async def health_check(): """Health check endpoint.""" uptime = time.time() - startup_time if startup_time else 0 return { "status": "healthy", "model_initialized": model_handler.initialized if model_handler else
False, "uptime_seconds": round(uptime, 2), } @app.get("/ping") async def ping(): """Simple ping endpoint.""" return {"status": "ok"} @app.post( "/predict", response_model=ForecastResponse | ErrorResponse, responses={ 400: {"model": ErrorResponse, "description": "Bad Request"}, 422: {"model": ErrorResponse, "description": "Validation Error"}, 500: {"model": ErrorResponse, "description": "Internal Error"}, }, ) async def predict(request: ForecastRequest): """ Main forecasting endpoint. Args: request: Forecast request containing time series data and parameters Returns: Forecast response with predictions or error """ try: logger.info(f"Received prediction request: {len(request.input)} data points, horizon={request.forecast_window}") # Check if model is initialized if not model_handler or not model_handler.initialized: raise HTTPException(status_code=500, detail="Model not initialized") # Validate input if len(request.input) == 0: raise HTTPException(status_code=400, detail="Input data cannot be empty") # Make prediction result = model_handler.predict( time_series_data=request.input, forecast_window=request.forecast_window, input_len=request.input_len, frequency=request.frequency, target_column=request.target_column, ) # Check if prediction was successful if result.get("status") == "error": raise HTTPException(status_code=result.get("code", 500), detail=result.get("error", "Prediction failed")) logger.info(f"Prediction successful: {len(result['predictions'])} predictions generated") return ForecastResponse(**result) except HTTPException as e: return JSONResponse(status_code=e.status_code, content=ErrorResponse(error=str(e), status="error").model_dump()) except Exception as e: logger.error(f"Prediction error: {str(e)}") return JSONResponse(status_code=500, content=ErrorResponse(error=str(e), status="error").model_dump()) @app.get("/model/info") async def model_info(): """Get model information.""" if not model_handler or not model_handler.initialized: return {"error": "Model not initialized", "status": "error"} return { "model_path": model_handler.model_path, "device": str(model_handler.device), "initialized": model_handler.initialized, "config": str(model_handler.config) if model_handler.config else None, } @app.get("/") async def root(): """Root endpoint with API information.""" return { "name": "Transformer Time Series Forecasting API", "version": "1.0.0", "description": "REST API for time series forecasting using Transformer model", "endpoints": { "predict": "/predict", "health": "/health", "ping": "/ping", "model_info": "/model/info", "docs": "/docs", }, } # Error handlers @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): """Handle HTTP exceptions.""" return JSONResponse( status_code=exc.status_code, content={"status": "error", "message": exc.detail}, ) @app.exception_handler(Exception) async def general_exception_handler(request: Request, exc: Exception): """Handle general exceptions.""" logger.error(f"Unhandled exception: {str(exc)}") return JSONResponse( status_code=500, content={"status": "error", "message": "Internal server error"}, ) if __name__ == "__main__": # For development uvicorn.run("app:app", host="0.0.0.0", port=8765, reload=True, log_level="info") ================================================ FILE: timeseries_forecasting/build_and_run.sh ================================================ #!/bin/bash # Build and run script for time series forecasting service set -e echo "=== Building Timeseries Forecasting Docker 
Container ===" # Check if the hf_model model directory exists MODEL_DIR="../hf_model" if [ ! -d "$MODEL_DIR" ]; then echo "Error: Hugging face model directory not found at $MODEL_DIR" echo "Please ensure the model is downloaded and available at this location" exit 1 fi echo "✓ Found Hugging face model at: $MODEL_DIR" rm -rf hf_model cp -r $MODEL_DIR . # Build the Docker image echo "Building Docker image..." docker build -t timeseries-forecasting . if [ $? -eq 0 ]; then echo "✓ Docker image built successfully" else echo "✗ Failed to build Docker image" exit 1 fi # Check if container is already running CONTAINER_NAME="time-series-forecasting-service" if [ "$(docker ps -q -f name=$CONTAINER_NAME)" ]; then echo "Stopping existing container..." docker stop $CONTAINER_NAME docker rm $CONTAINER_NAME fi echo "=== Starting Timeseries Forecasting Service ===" # Run the container docker run -d \ --name $CONTAINER_NAME \ -p 8765:8765 \ timeseries-forecasting if [ $? -eq 0 ]; then echo "✓ Container started successfully" echo "" echo "Service endpoints:" echo " - Predictions: http://localhost:8765/predict" echo " - Health Check: http://localhost:8765/health" echo " - API Docs: http://localhost:8765/docs" echo "" echo "Container logs:" echo " docker logs -f $CONTAINER_NAME" echo "" echo "To test the service:" echo " python test_forecasting.py" echo "" echo "To stop the service:" echo " docker stop $CONTAINER_NAME" else echo "✗ Failed to start container" exit 1 fi # Wait a moment and check if container is still running sleep 5 if [ "$(docker ps -q -f name=$CONTAINER_NAME)" ]; then echo "✓ Service is running" # Show few logs echo "" echo "=== Initial Service Logs ===" docker logs "$CONTAINER_NAME" | head -n 50 else echo "✗ Service failed to start" echo "Checking logs..." docker logs $CONTAINER_NAME exit 1 fi ================================================ FILE: timeseries_forecasting/model_handler.py ================================================ """model_handler.py: Transformer based model handler for time series forecasting.""" import logging from typing import Any import numpy as np import pandas as pd import torch from transformers import AutoConfig, AutoModelForCausalLM # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class TransformerModelHandler: """ Transformer based Model handler for time series forecasting. """ def __init__(self, model_path: str = "hf_model"): """Initialize the model handler.""" logger.info("Initializing Transformer Model Handler") self.model_path = model_path self.model = None self.config = None self.device = None self.initialized = False def initialize(self) -> bool: """ Initialize model. 
Returns: bool: True if initialization successful """ try: logger.info("Starting model initialization") # Set device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {self.device}") logger.info(f"Loading model from: {self.model_path}") # Load model configuration self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True) # Load the pretrained model self.model = AutoModelForCausalLM.from_pretrained( self.model_path, config=self.config, trust_remote_code=True ) # Move model to device self.model.to(self.device) self.model.eval() self.initialized = True logger.info("Transformer model loaded successfully") logger.info(f"Model config: {self.config}") return True except Exception as e: logger.error(f"Failed to initialize model: {str(e)}") self.initialized = False return False def preprocess( self, time_series_data: list, forecast_window: int = 24, input_len: int | None = None, frequency: str = "hourly", target_column: str = "value", ) -> tuple[torch.Tensor, dict[str, Any]]: """ Transform raw input into model input data. Args: time_series_data: Input time series data forecast_window: Number of future points to predict input_len: Optional input length limit frequency: Frequency of the time series target_column: Column name for structured data Returns: Tuple of (processed_tensor, metadata) """ try: logger.info(f"Input data length: {len(time_series_data) if isinstance(time_series_data, list) else 'N/A'}") logger.info(f"Forecast window: {forecast_window}") # Convert input to numpy array if isinstance(time_series_data, list): if len(time_series_data) > 0 and isinstance(time_series_data[0], dict): # Handle structured data (with timestamps) df = pd.DataFrame(time_series_data) if target_column in df.columns: values = df[target_column].values else: # Use the first numeric column numeric_cols = df.select_dtypes(include=[np.number]).columns if len(numeric_cols) > 0: values = df[numeric_cols[0]].values else: values = np.array([float(x) for x in time_series_data]) else: # Handle simple numeric list values = np.array([float(x) for x in time_series_data]) else: values = np.array(time_series_data) # Handle input length constraint if input_len is not None: if input_len > len(values): # Pad with zeros if input is shorter than required values = np.pad(values, (input_len - len(values), 0), mode="constant", constant_values=0) elif input_len < len(values): # Take the last input_len values values = values[-input_len:] # Normalize the data (simple z-score normalization) mean_val = np.mean(values) std_val = np.std(values) if std_val > 0: normalized_values = (values - mean_val) / std_val else: normalized_values = values - mean_val # Convert to tensor tensor = torch.tensor(normalized_values, dtype=torch.float32).unsqueeze(0) tensor = tensor.to(self.device) # Store metadata for post-processing metadata = { "mean": mean_val, "std": std_val, "forecast_window": forecast_window, "frequency": frequency, "original_length": len(values), } logger.info(f"Preprocessed tensor shape: {tensor.shape}") return tensor, metadata except Exception as e: logger.error(f"Preprocessing failed: {str(e)}") raise e def inference(self, input_tensor: torch.Tensor, metadata: dict[str, Any]) -> torch.Tensor: """ Run inference on the model. 
Args: input_tensor: Preprocessed input tensor metadata: Preprocessing metadata Returns: Model output tensor """ try: if not self.initialized: raise RuntimeError("Model not initialized") with torch.no_grad(): forecast_window = metadata.get("forecast_window", 24) # Use generate method output = self.model.generate(input_tensor, max_new_tokens=forecast_window) logger.info(f"Model output shape: {output.shape}") return output except ValueError as e: logger.error(f"Inference failed due to ValueError: {str(e)}") raise e except Exception as e: logger.error(f"Inference failed: {str(e)}") raise e def postprocess(self, output_tensor: torch.Tensor, metadata: dict[str, Any]) -> list[float]: """ Transform model output to final prediction format. Args: output_tensor: Raw model output metadata: Preprocessing metadata Returns: Final predictions as list """ try: # Extract predictions from tensor if output_tensor.dim() > 1: predictions = output_tensor[0].cpu().numpy() else: predictions = output_tensor.cpu().numpy() # Denormalize the predictions mean_val = metadata.get("mean", 0) std_val = metadata.get("std", 1) if std_val > 0: denormalized_predictions = predictions * std_val + mean_val else: denormalized_predictions = predictions + mean_val # Convert to list and ensure it's the right length forecast_window = metadata.get("forecast_window", 24) result = denormalized_predictions[:forecast_window].tolist() logger.info(f"Final predictions length: {len(result)}") return result except Exception as e: logger.error(f"Postprocessing failed: {str(e)}") raise e def predict( self, time_series_data: list, forecast_window: int = 24, input_len: int | None = None, frequency: str = "hourly", target_column: str = "value", ) -> dict[str, Any]: """ Main prediction method. Args: time_series_data: Input time series data forecast_window: Number of future points to predict input_len: Optional input length limit frequency: Frequency of the time series target_column: Column name for structured data Returns: Dictionary containing predictions and metadata """ try: # Ensure model is initialized if not self.initialized: if not self.initialize(): raise RuntimeError("Failed to initialize model") # Preprocess input input_tensor, metadata = self.preprocess( time_series_data, forecast_window, input_len, frequency, target_column ) # Run inference output_tensor = self.inference(input_tensor, metadata) # Postprocess output predictions = self.postprocess(output_tensor, metadata) # Format result result = { "predictions": predictions, "forecast_window": metadata.get("forecast_window", 24), "frequency": metadata.get("frequency", "hourly"), "status": "success", } return result except ValueError as e: logger.error(f"Prediction failed due to ValueError: {str(e)}") return {"error": str(e), "code": 400, "status": "error"} except Exception as e: logger.error(f"Prediction failed: {str(e)}") return {"error": str(e), "status": "error"} # Global model handler instance _model_handler = None def get_model_handler() -> TransformerModelHandler: """Get or create global model handler instance.""" global _model_handler if _model_handler is None: _model_handler = TransformerModelHandler() return _model_handler ================================================ FILE: timeseries_forecasting/test_forecasting.py ================================================ #!/usr/bin/env python3 """test_forecasting.py: Test script for Timer forecasting service.""" import time from datetime import datetime, timedelta import numpy as np import requests from requests.exceptions import 
RequestException class TimeseriesForecastingTester: """Test class for Timer forecasting service.""" def __init__(self, base_url="http://localhost:8765"): """Initialize the tester.""" self.base_url = base_url self.predictions_endpoint = f"{base_url}/predict" self.health_endpoint = f"{base_url}/health" def generate_sample_data(self, length=100, frequency="H"): """Generate sample time series data for testing.""" # Generate synthetic time series with trend and seasonality t = np.arange(length) # Add trend trend = 0.1 * t # Add seasonality (daily pattern for hourly data) if frequency == "H": seasonality = 5 * np.sin(2 * np.pi * t / 24) else: seasonality = 3 * np.sin(2 * np.pi * t / 7) # Weekly pattern for daily data # Add noise noise = np.random.normal(0, 1, length) # Combine components values = 100 + trend + seasonality + noise return values.tolist() def test_basic_forecasting(self): """Test basic time series forecasting.""" print("\n=== Testing Basic Forecasting ===") # Generate sample data sample_data = self.generate_sample_data(length=168, frequency="H") # 1 week of hourly data # Prepare request payload payload = { "input": sample_data, "forecast_window": 24, # Forecast next 24 hours "frequency": "H", "input_len": 168, # Use last week of data } # Send request try: response = requests.post( self.predictions_endpoint, json=payload, headers={"Content-Type": "application/json"}, timeout=30 ) if response.status_code == 200: result = response.json() print("✓ Basic forecasting successful") print(f" - Input length: {len(sample_data)}") print(f" - Forecast Window: {payload['forecast_window']}") print(f" - Predictions length: {len(result.get('predictions', []))}") print(f" - Sample predictions: {result.get('predictions', [])[:5]}") return True else: print(f"✗ Request failed with status: {response.status_code}") print(f" Response: {response.text}") return False except requests.exceptions.RequestException as e: print(f"✗ Request failed: {str(e)}") return False def test_structured_data(self): """Test forecasting with structured data (timestamps + values).""" print("\n=== Testing Structured Data Forecasting ===") # Generate structured data with timestamps start_time = datetime.now() - timedelta(days=7) structured_data = [] for i in range(168): # 1 week of hourly data timestamp = start_time + timedelta(hours=i) value = 100 + 0.1 * i + 5 * np.sin(2 * np.pi * i / 24) + np.random.normal(0, 1) structured_data.append({"timestamp": timestamp.isoformat(), "value": value}) # Prepare request payload payload = { "input": structured_data, "forecast_window": 48, # Forecast next 48 hours "frequency": "H", "target_column": "value", } # Send request try: response = requests.post( self.predictions_endpoint, json=payload, headers={"Content-Type": "application/json"}, timeout=30 ) if response.status_code == 200: result = response.json() print("✓ Structured data forecasting successful") print(f" - Input records: {len(structured_data)}") print(f" - Forecast Window: {payload['forecast_window']}") print(f" - Predictions length: {len(result.get('predictions', []))}") return True else: print(f"✗ Request failed with status: {response.status_code}") print(f" Response: {response.text}") return False except requests.exceptions.RequestException as e: print(f"✗ Request failed: {str(e)}") return False def test_different_windows(self): """Test forecasting with different forecast windows.""" print("\n=== Testing Different Forecast Horizons ===") sample_data = self.generate_sample_data(length=100) windows = [1, 12, 24, 48, 72] for window 
in windows: payload = {"input": sample_data, "forecast_window": window, "frequency": "H"} try: response = requests.post( self.predictions_endpoint, json=payload, headers={"Content-Type": "application/json"}, timeout=30 ) if response.status_code == 200: result = response.json() predictions_len = len(result.get("predictions", [])) print(f"✓ Window {window}: {predictions_len} predictions") else: print(f"✗ Window {window}: Failed with status {response.status_code}") return False except requests.exceptions.RequestException as e: print(f"✗ Window {window}: Request failed - {str(e)}") return False return True def test_error_handling(self): """Test error handling with invalid inputs.""" print("\n=== Testing Error Handling ===") # Test empty input try: response = requests.post( self.predictions_endpoint, json={"input": []}, headers={"Content-Type": "application/json"}, timeout=10 ) print(f"Empty input: Status {response.status_code}") if response.status_code != 400: print("✗ Empty input: Expected 400 status code") return False except RequestException: print("✗ Empty input: unexpected request exception") return False # Test invalid JSON try: response = requests.post( self.predictions_endpoint, data="invalid json", headers={"Content-Type": "application/json"}, timeout=10 ) print(f"Invalid JSON: Status {response.status_code}") if response.status_code != 422: print("✗ Invalid JSON: Expected 422 status code") return False except RequestException: print("✗ Invalid JSON: unexpected request exception") return False return True def test_health_check(self): """Test service health check.""" print("\n=== Testing Service Health ===") try: # Test health endpoint response = requests.get(self.health_endpoint, timeout=5) if response.status_code == 200: result = response.json() print("✓ Service health check passed") print(f" - Model initialized: {result.get('model_initialized', 'Unknown')}") print(f" - Uptime: {result.get('uptime_seconds', 'Unknown')} seconds") return True else: print(f"✗ Health check failed: {response.status_code}") return False except RequestException as e: print(f"Health check failed: {str(e)}") return False def run_all_tests(self): """Run all tests.""" print("=" * 50) print("TIMER FORECASTING SERVICE TESTS") print("=" * 50) # Wait for service to be ready print("Waiting for service to be ready...") for _i in range(30): # Wait up to 30 seconds try: response = requests.get(self.health_endpoint, timeout=2) if response.status_code == 200: result = response.json() if result.get("model_initialized", False): print("✓ Service is ready") break except RequestException: pass time.sleep(1) else: print("✗ Service not ready after 30 seconds") return False # Run tests tests = [ self.test_health_check, self.test_basic_forecasting, self.test_structured_data, self.test_different_windows, self.test_error_handling, ] passed = 0 total = len(tests) for test in tests: try: if test(): passed += 1 except Exception as e: print(f"✗ Test {test.__name__} failed with exception: {str(e)}") print("\n" + "=" * 50) print(f"TESTS COMPLETED: {passed}/{total} passed") print("=" * 50) return passed == total def main(): """Main function.""" import argparse parser = argparse.ArgumentParser(description="Test Timer forecasting service") parser.add_argument( "--url", default="http://localhost:8765", help="Base URL of the service (default: http://localhost:8765)" ) args = parser.parse_args() tester = TimeseriesForecastingTester(args.url) success = tester.run_all_tests() exit(0 if success else 1) if __name__ == "__main__": main()
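# NOTE (hypothetical usage, not part of the original script): the tester class above can also
# be driven interactively to re-run a single check against a live service, e.g.:
#   tester = TimeseriesForecastingTester("http://localhost:8765")
#   tester.test_health_check()
#   tester.test_basic_forecasting()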