[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**To Reproduce**\nSteps to reproduce the behavior:\n1. Go to '...'\n2. Click on '....'\n3. Scroll down to '....'\n4. See error\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Screenshots**\nIf applicable, add screenshots to help explain your problem.\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".github/workflows/docs.yml",
    "content": "name: Build and Deploy Documentation\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\npermissions:\n  contents: read\n  pages: write\n  id-token: write\n\nconcurrency:\n  group: \"pages\"\n  cancel-in-progress: false\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v4\n\n    - name: Set up Python\n      uses: actions/setup-python@v4\n      with:\n        python-version: '3.11'\n\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip setuptools wheel\n        pip install -e \".[docs]\"\n\n    - name: Build documentation\n      run: |\n        cd docs\n        make html\n\n    - name: Setup Pages\n      uses: actions/configure-pages@v4\n\n    - name: Upload artifact\n      uses: actions/upload-pages-artifact@v3\n      with:\n        path: './docs/build/html'\n\n  deploy:\n    environment:\n      name: github-pages\n      url: ${{ steps.deployment.outputs.page_url }}\n    runs-on: ubuntu-latest\n    needs: build\n    if: github.ref == 'refs/heads/main'\n    steps:\n      - name: Deploy to GitHub Pages\n        id: deployment\n        uses: actions/deploy-pages@v4"
  },
  {
    "path": ".github/workflows/publish.yml",
    "content": "name: Publish to PyPI\n\non:\n  release:\n    types: [published]  # Trigger when a release is published\n  workflow_dispatch:  # Allow manual triggering\n\njobs:\n  test:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: ['3.11', '3.12']\n\n    steps:\n    - uses: actions/checkout@v4\n\n    - name: Install uv\n      uses: astral-sh/setup-uv@v4\n      with:\n        version: \"latest\"\n\n    - name: Set up Python ${{ matrix.python-version }}\n      run: uv python install ${{ matrix.python-version }}\n\n    - name: Install dependencies\n      run: |\n        uv sync --all-extras\n\n    - name: Run linting\n      run: |\n        uv run black --check .\n\n    - name: Run tests\n      run: |\n        uv run pytest -v --cov=openchatbi --cov-report=xml\n\n    - name: Upload coverage to Codecov\n      uses: codecov/codecov-action@v3\n      with:\n        file: ./coverage.xml\n        flags: unittests\n        name: codecov-umbrella\n\n  build:\n    needs: test\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v4\n\n    - name: Install uv\n      uses: astral-sh/setup-uv@v4\n      with:\n        version: \"latest\"\n\n    - name: Set up Python\n      run: uv python install 3.11\n\n    - name: Build package\n      run: |\n        uv build\n\n    - name: Check build artifacts\n      run: |\n        ls -la dist/\n        uv run twine check dist/*\n\n    - name: Upload build artifacts\n      uses: actions/upload-artifact@v4\n      with:\n        name: dist\n        path: dist/\n\n  publish:\n    needs: build\n    runs-on: ubuntu-latest\n    if: github.event_name == 'release'\n    environment:\n      name: pypi\n      url: https://pypi.org/p/openchatbi\n    permissions:\n      id-token: write  # Required for PyPI trusted publishing\n      contents: read\n\n    steps:\n    - name: Download build artifacts\n      uses: actions/download-artifact@v4\n      with:\n        name: dist\n        path: dist/\n\n    - 
name: Publish to PyPI\n      uses: pypa/gh-action-pypi-publish@release/v1\n      with:\n        # Uses PyPI trusted publishing, no API token needed\n        verbose: true\n        print-hash: true\n\n  publish-test:\n    needs: build\n    runs-on: ubuntu-latest\n    if: github.event_name == 'workflow_dispatch'  # Only publish to TestPyPI when manually triggered\n    environment:\n      name: testpypi\n      url: https://test.pypi.org/p/openchatbi\n    permissions:\n      id-token: write\n      contents: read\n\n    steps:\n    - name: Download build artifacts\n      uses: actions/download-artifact@v4\n      with:\n        name: dist\n        path: dist/\n\n    - name: Publish to TestPyPI\n      uses: pypa/gh-action-pypi-publish@release/v1\n      with:\n        repository-url: https://test.pypi.org/legacy/\n        verbose: true\n        print-hash: true"
  },
  {
    "path": ".github/workflows/runledger.yml",
    "content": "name: runledger\non:\n  workflow_dispatch:\n  pull_request:\n    paths:\n      - \"openchatbi/**\"\n\njobs:\n  runledger:\n    if: github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'runledger')\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          python -m pip install runledger\n          python -m pip install .\n      - name: Run deterministic evals (replay)\n        run: |\n          runledger run evals/runledger --mode replay --baseline baselines/runledger-openchatbi.json\n      - name: Upload artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          name: runledger-artifacts\n          path: runledger_out/**\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[codz]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py.cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# UV\n#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#uv.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in 
version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n#poetry.toml\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#   pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.\n#   https://pdm-project.org/en/latest/usage/project/#working-with-version-control\n#pdm.lock\n#pdm.toml\n.pdm-python\n.pdm-build/\n\n# pixi\n#   Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.\n#pixi.lock\n#   Pixi creates a virtual environment in the .pixi directory, just like venv module creates one\n#   in the .venv directory. It is recommended not to include this directory in version control.\n.pixi\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.envrc\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  
For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n\n# Abstra\n# Abstra is an AI-powered process automation framework.\n# Ignore directories containing user credentials, local state, and settings.\n# Learn more at https://abstra.io/docs\n.abstra/\n\n# Visual Studio Code\n#  Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore \n#  that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore\n#  and can be added to the global gitignore or merged into this file. However, if you prefer, \n#  you could uncomment the following to ignore the entire vscode folder\n# .vscode/\n\n# Ruff stuff:\n.ruff_cache/\n\n# PyPI configuration file\n.pypirc\n\n# Cursor\n#  Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to\n#  exclude from AI features like autocomplete and code analysis. Recommended for sensitive data\n#  refer to https://docs.cursor.com/context/ignore-files\n.cursorignore\n.cursorindexingignore\n\n# Marimo\nmarimo/_static/\nmarimo/_lsp/\n__marimo__/\n\n# project spec\nopenchatbi/config.yaml\nmemory.db\ncheckpoints.db\ndata\nhf_model\ntimeseries_forecasting/hf_model\nrunledger_out/\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to OpenChatBI\nHi there! Thank you for your interest in contributing to OpenChatBI.\n\nOpenChatBI started as a personal project, with the hope of making it easier for businesses to build their own ChatBI applications with less effort. To achieve this goal, I made it open source, and I greatly appreciate contributions of all kinds.\n\nWhether you’d like to propose a new feature, refactor the code, enhance documentation, or fix bugs, your contributions are always welcome.\n"
  },
  {
    "path": "Dockerfile.python-executor",
    "content": "FROM python:3.11-slim\n\n# Set working directory\nWORKDIR /app\n\n# Install basic packages that might be needed for data analysis\nRUN pip install --no-cache-dir \\\n    pandas \\\n    numpy \\\n    matplotlib \\\n    seaborn \\\n    requests \\\n    json5\n\n# Create a directory for code execution\nRUN mkdir -p /app/code\n\n# Set up a non-root user for security\nRUN useradd -m -u 1000 executor\nUSER executor\n\n# Set the default command\nCMD [\"python3\"]"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2025 Yu Zhong\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# OpenChatBI\n\nOpenChatBI is an open source, chat-based intelligent BI tool powered by large language models, designed to help users \nquery, analyze, and visualize data through natural language conversations. Built on LangGraph and LangChain ecosystem, \nit provides chat agents and workflows that support natural language to SQL conversion and streamlined data analysis.\n\nJoin the Slack channel to discuss: https://join.slack.com/t/openchatbicommunity/shared_invite/zt-3jpzpx9mv-Sk88RxpO4Up0L~YTZYf4GQ\n\n<img src=\"https://github.com/zhongyu09/openchatbi/raw/main/example/demo.gif\" alt=\"Demo\" width=\"800\">\n\n## Core Features\n\n1. **Natural Language Interaction**: Get data analysis results by asking questions in natural language\n2. **Automatic SQL Generation**: Convert natural language queries into SQL statements using advanced text2sql workflows\n   with schema linking and well-organized prompt engineering\n3. **Data Visualization**: Generate intuitive data visualizations (via plotly)\n4. **Data Catalog Management**: Automatically discovers and indexes database table structures, supports flexible catalog\n   storage backends with vector-based or BM25-based retrieval, and easily maintains business explanations for tables\n   and columns as well as optimizes prompts.\n5. **Time Series Forecasting**: Forecasting models deployed in-house that can be called as tools\n6. **Code Execution**: Execute Python code for data analysis and visualization\n7. **Interactive Problem-Solving**: Proactively ask users for more context when information is incomplete\n8. **Persistent Memory**: Conversation management and user characteristic memory based on LangGraph checkpointing\n9. **MCP Support**: Integration with MCP tools by configuration\n10. **Knowledge Base Integration**: Answer complex questions by combining catalog-based knowledge retrieval and external\n   knowledge base retrieval (via MCP tools)\n11. 
**Web UI Interface**: Provides two sample UIs: simple and streaming web interfaces using Gradio and Streamlit, easy to\n   integrate with other web applications\n\n## Roadmap\n\n1. **Anomaly Detection Algorithm**: Time series anomaly detection\n2. **Root Cause Analysis Algorithm**: Multi-dimensional drill-down capabilities for anomaly investigation\n\n# Getting started\n\n## Installation & Setup\n\n### Prerequisites\n\n- Python 3.11 or higher\n- Access to a supported LLM provider (OpenAI, Anthropic, etc.)\n- Data Warehouse (Database) credentials (like Presto, PostgreSQL, MySQL, etc.)\n- (Optional) Embedding model for vector-based retrieval - if not available, BM25-based retrieval will be used\n- (Optional) Docker - required only for `docker` executor mode\n\n**Note on Chinese Text Segmentation**: For better Chinese text retrieval, `jieba` is used for word segmentation. However, `jieba` is not compatible with Python 3.12+. On Python 3.12 and higher, the system automatically falls back to simple punctuation-based segmentation for Chinese text.\n\n### Installation\n\n1. **Using uv (recommended):**\n\n```bash\ngit clone git@github.com:zhongyu09/openchatbi\ncd openchatbi\nuv sync\n```\n\n2. **Using pip:**\n\n```bash\npip install openchatbi\n```\n\n3. **For development:**\n\n```bash\ngit clone git@github.com:zhongyu09/openchatbi\ncd openchatbi\nuv sync --group dev\n```\n\nOptional: If you want to use `pysqlite3` (newer SQLite builds), you can install it manually. If the build fails, install SQLite first:\n\nOn macOS, try to install sqlite using Homebrew:\n```bash\nbrew install sqlite\nbrew info sqlite\nexport LDFLAGS=\"-L/opt/homebrew/opt/sqlite/lib\"\nexport CPPFLAGS=\"-I/opt/homebrew/opt/sqlite/include\"\n```\nOn Amazon Linux / RHEL / CentOS:\n```bash\nsudo yum install sqlite-devel\n```\nOn Ubuntu / Debian:\n```bash\nsudo apt-get update\nsudo apt-get install libsqlite3-dev\n```\n\n### Run Demo\n\nRun the demo using the **example dataset** from the Spider dataset. 
You need to provide \"YOUR OPENAI API KEY\" or change config to use other LLM providers.\n\n**Note**: The demo example includes embedding model configuration. If you want to run without an embedding model, you can remove the `embedding_model` section in the config - BM25 retrieval will be used automatically.\n\n```bash\ncp example/config.yaml openchatbi/config.yaml\nsed -i 's/YOUR_API_KEY_HERE/[YOUR OPENAI API KEY]/g' openchatbi/config.yaml\npython run_streamlit_ui.py\n```\n\n### Configuration\n\n1. **Create configuration file**\n\nCopy the configuration template:\n```bash\ncp openchatbi/config.yaml.template openchatbi/config.yaml\n```\nOr create an empty YAML file.\n\n2. **Configure your LLMs:**\n\n```yaml\n# Select which provider to use\ndefault_llm: openai\n\n# Define one or more providers\nllm_providers:\n  openai:\n    default_llm:\n      class: langchain_openai.ChatOpenAI\n      params:\n        api_key: YOUR_API_KEY_HERE\n        model: gpt-4.1\n        temperature: 0.02\n        max_tokens: 8192\n\n    # Optional: Embedding model for vector-based retrieval and memory tools\n    # If not configured, BM25-based retrieval will be used, and the memory tools will not work\n    embedding_model:\n      class: langchain_openai.OpenAIEmbeddings\n      params:\n        api_key: YOUR_API_KEY_HERE\n        model: text-embedding-3-large\n        chunk_size: 1024\n```\n\n3. **Configure your data warehouse:**\n\n```yaml\norganization: Your Company\ndialect: presto\ndata_warehouse_config:\n  uri: \"presto://user@host:8080/catalog/schema\"\n  include_tables:\n    - your_table_name\n  database_name: \"catalog.schema\"\n```\n\n### Running the Application\n\n1. 
**Invoking LangGraph:**\n\n```bash\nexport CONFIG_FILE=YOUR_CONFIG_FILE_PATH\n```\n\n```python\nfrom openchatbi import get_default_graph\n\ngraph = get_default_graph()\ngraph.invoke({\"messages\": [{\"role\": \"user\", \"content\": \"Show me ctr trends for the past 7 days\"}]},\n    config={\"configurable\": {\"thread_id\": \"1\"}})\n```\n\n```\n# System-generated SQL\nSELECT date, SUM(clicks)/SUM(impression) AS ctr\nFROM ad_performance\nWHERE date >= CURRENT_DATE - 7 DAYS\nGROUP BY date\nORDER BY date;\n```\n\n2. **Sample Web UI:**\n\nRun Streamlit based UI:\n```bash\nstreamlit run sample_ui/streamlit_ui.py\n```\n\nRun Gradio based UI:\n```bash\npython sample_ui/streaming_ui.py\n```\n\n## Configuration Instructions\n\nThe configuration template is provided at `openchatbi/config.yaml.template`. Key configuration sections include:\n\n### Basic Settings\n\n- `organization`: Organization name (e.g., \"Your Company\")\n- `dialect`: Database dialect (e.g., \"presto\")\n- `bi_config_file`: Path to BI configuration file (e.g., \"example/bi.yaml\")\n\n### Catalog Store Configuration\n\n- `catalog_store`: Configuration for data catalog storage\n    - `store_type`: Storage type (e.g., \"file_system\")\n    - `data_path`: Path to catalog data stored by file system (e.g., \"./example\")\n\n### Data Warehouse Configuration\n\n- `data_warehouse_config`: Database connection settings\n    - `uri`: Connection string for your database\n    - `include_tables`: List of tables to include in catalog, leave empty to include all tables\n    - `database_name`: Database name for catalog\n    - `token_service`: Token service URL (for data warehouses that need token authentication like Presto)\n    - `user_name` / `password`: Token service credentials\n\n### LLM Configuration\n\nVarious LLMs are supported based on LangChain; see the\n[LangChain API Document](https://python.langchain.com/api_reference/reference.html#integrations) for the full list of integrations that support\n`chat_models`. 
You can configure different LLMs for different tasks:\n\n- `default_llm`: Primary language model for general tasks\n- `embedding_model`: (Optional) Model for embedding generation. If not configured, BM25-based text retrieval will be used as fallback, and the memory tools will not work\n- `text2sql_llm`: (Optional) Specialized model for SQL generation. If not configured, uses `default_llm`\n\nMultiple providers (optional):\n\n- Configure multiple providers under `llm_providers` and select with `default_llm: <provider_name>`.\n- In `sample_ui/streamlit_ui.py`, a provider dropdown appears when `llm_providers` is configured.\n- In `sample_api/async_api.py`, pass `provider` in the `/chat/stream` request body.\n\nCommonly used LLM providers and their corresponding classes and installation commands:\n\n- **Anthropic**: `langchain_anthropic.ChatAnthropic`, `pip install langchain-anthropic`\n- **OpenAI**: `langchain_openai.ChatOpenAI`, `pip install langchain-openai`\n- **Azure OpenAI**: `langchain_openai.AzureChatOpenAI`, `pip install langchain-openai`\n- **Google Vertex AI**: `langchain_google_vertexai.ChatVertexAI`, `pip install langchain-google-vertexai`\n- **Bedrock**: `langchain_aws.ChatBedrock`, `pip install langchain-aws`\n- **Huggingface**: `langchain_huggingface.ChatHuggingFace`, `pip install langchain-huggingface`\n- **Deepseek**: `langchain_deepseek.ChatDeepSeek`, `pip install langchain-deepseek`\n- **Ollama**: `langchain_ollama.ChatOllama`, `pip install langchain-ollama`\n\n### Advanced Configuration\n\nOpenChatBI supports sophisticated customization through prompt engineering and catalog management features:\n\n- **Prompt Engineering Configuration**: Customize system prompts, business glossaries, and data warehouse introductions\n- **Data Catalog Management**: Configure table metadata, column descriptions, and SQL generation rules\n- **Business Rules**: Define table selection criteria and domain-specific SQL constraints\n- **Forecasting Service**: Configure the 
forecasting service url and prompt based on your own deployment \n\nFor detailed configuration options and examples, see the [Advanced Features](#advanced-features) section.\n\n## Architecture Overview\n\nOpenChatBI is built using a modular architecture with clear separation of concerns:\n\n1. **LangGraph Workflows**: Core orchestration using state machines for complex multi-step processes\n2. **Catalog Management**: Flexible data catalog system with intelligent retrieval (vector-based or BM25 fallback)\n3. **Text2SQL Pipeline**: Advanced natural language to SQL conversion with schema linking\n4. **Code Execution**: Sandboxed Python execution environment for data analysis\n5. **Tool Integration**: Extensible tool system for human interaction and knowledge search\n6. **Persistent Memory**: SQLite-based conversation state management\n\n## Technology Stack\n\n- **Frameworks**: LangGraph, LangChain, FastAPI, Gradio/Streamlit\n- **Large Language Models**: Azure OpenAI (GPT-4), Anthropic Claude, OpenAI GPT models\n- **Text Retrieval**: Vector-based (with embedding models) or BM25-based (fallback without embeddings)\n- **Databases**: Presto, Trino, MySQL with SQLAlchemy support\n- **Code Execution**: Local Python, RestrictedPython, Docker containerization\n- **Development**: Python 3.11+, with modern tooling (Black, Ruff, MyPy, Pytest)\n- **Storage**: SQLite for conversation checkpointing, file system catalog storage\n\n### Agent Graph\n<img src=\"https://github.com/zhongyu09/openchatbi/raw/main/assets/agent_graph.png\" alt=\"Agent Graph\" width=\"800\">\n\n### Text2SQL Graph\n<img src=\"https://github.com/zhongyu09/openchatbi/raw/main/assets/text2sql_graph.png\" alt=\"Text2SQL Graph\" width=\"800\">\n\n## Project Structure\n\n```\nopenchatbi/\n├── README.md                    # Project documentation\n├── pyproject.toml               # Modern Python project configuration\n├── Dockerfile.python-executor  # Docker image for isolated code execution\n├── run_tests.py          
      # Test runner script\n├── run_streamlit_ui.py         # Streamlit UI launcher\n├── openchatbi/                 # Core application code\n│   ├── __init__.py             # Package initialization\n│   ├── config.yaml.template    # Configuration template\n│   ├── config_loader.py        # Configuration management\n│   ├── constants.py            # Application constants\n│   ├── agent_graph.py          # Main LangGraph workflow\n│   ├── graph_state.py          # State definition for workflows\n│   ├── context_config.py       # Context management configuration\n│   ├── context_manager.py      # Context window and token management\n│   ├── text_segmenter.py       # Text segmentation with jieba support\n│   ├── utils.py                # Utility functions and SimpleStore (BM25-based retrieval)\n│   ├── catalog/                # Data catalog management\n│   │   ├── __init__.py         # Package initialization\n│   │   ├── catalog_loader.py   # Catalog loading logic\n│   │   ├── catalog_store.py    # Catalog storage interface\n│   │   ├── factory.py          # Catalog factory patterns\n│   │   ├── helper.py           # Catalog helper functions\n│   │   ├── retrival_helper.py  # Retrieval helper utilities\n│   │   ├── schema_retrival.py  # Schema retrieval logic\n│   │   ├── token_service.py    # Token service integration\n│   │   └── store/              # Catalog storage implementations\n│   │       └── file_system.py  # File system-based catalog storage\n│   ├── code/                   # Code execution framework\n│   │   ├── __init__.py         # Package initialization\n│   │   ├── executor_base.py    # Base executor interface\n│   │   ├── local_executor.py   # Local Python execution\n│   │   ├── restricted_local_executor.py # RestrictedPython execution\n│   │   └── docker_executor.py  # Docker-based isolated execution\n│   ├── llm/                    # LLM integration layer\n│   │   ├── __init__.py         # Package initialization\n│   │   └── llm.py              # 
LLM management and retry logic\n│   ├── prompts/                # Prompt templates and engineering\n│   │   ├── __init__.py         # Package initialization\n│   │   ├── agent_prompt.md     # Main agent prompts\n│   │   ├── extraction_prompt.md # Information extraction prompts\n│   │   ├── system_prompt.py    # System prompt management\n│   │   ├── summary_prompt.md   # Summary conversation prompts\n│   │   ├── table_selection_prompt.md # Table selection prompts\n│   │   ├── text2sql_prompt.md  # Text-to-SQL prompts\n│   │   └── sql_dialect/        # SQL dialect-specific prompts\n│   ├── text2sql/               # Text-to-SQL conversion pipeline\n│   │   ├── __init__.py         # Package initialization\n│   │   ├── data.py             # Data and retriever for Text-to-SQL\n│   │   ├── extraction.py       # Information extraction\n│   │   ├── generate_sql.py     # SQL generation and execution logic\n│   │   ├── schema_linking.py   # Schema linking process\n│   │   ├── sql_graph.py        # SQL generation LangGraph workflow\n│   │   ├── text2sql_utils.py   # Text2SQL utilities\n│   │   └── visualization.py    # Data visualization functions\n│   └── tool/                   # LangGraph tools and functions\n│       ├── ask_human.py        # Human-in-the-loop interactions\n│       ├── memory.py           # Memory management tool\n│       ├── mcp_tools.py        # MCP (Model Context Protocol) integration\n│       ├── run_python_code.py  # Configurable Python code execution\n│       ├── save_report.py      # Report saving functionality\n│       ├── search_knowledge.py # Knowledge base search\n│       └── timeseries_forecast.py # Time series forecasting tool\n├── sample_api/                 # API implementations\n│   └── async_api.py            # Asynchronous FastAPI example\n├── sample_ui/                  # Web interface implementations\n│   ├── memory_ui.py            # Memory-enhanced UI interface\n│   ├── plotly_utils.py         # Plotly utilities and helpers\n│   ├── 
simple_ui.py            # Simple non-streaming Gradio UI\n│   ├── streaming_ui.py         # Streaming Gradio UI with real-time updates\n│   ├── streamlit_ui.py         # Streaming Streamlit UI with enhanced features\n│   └── style.py                # UI styling and CSS\n├── example/                    # Example configurations and data\n│   ├── bi.yaml                 # BI configuration example\n│   ├── config.yaml             # Application config example\n│   ├── table_info.yaml         # Table information\n│   ├── table_columns.csv       # Table column registry\n│   ├── common_columns.csv      # Common column definitions\n│   ├── sql_example.yaml        # SQL examples for retrieval\n│   ├── table_selection_example.csv # Table selection examples\n│   └── tracking_orders.sqlite  # Sample SQLite database\n├── timeseries_forecasting/     # Time series forecasting service\n│   ├── README.md               # Forecasting service documentation\n│   └── ...                     # Forecasting service implementation\n├── tests/                      # Test suite\n│   ├── __init__.py             # Package initialization\n│   ├── conftest.py             # Test configuration\n│   ├── test_*.py               # Test modules for various components\n│   └── README.md               # Testing documentation\n├── docs/                       # Documentation\n│   ├── source/                 # Sphinx documentation source\n│   ├── build/                  # Built documentation\n│   ├── Makefile                # Documentation build scripts\n│   └── make.bat                # Windows build script\n└── .github/                    # GitHub workflows and templates\n    └── workflows/              # CI/CD workflows\n```\n\n## Advanced Features\n\n### Visualization configuration\nYou can choose rule-based or llm-based visualization or disable visualization.\n```yaml\n# Options: \"rule\" (rule-based), \"llm\" (LLM-based), or null (skip visualization)\nvisualization_mode: llm\n```\n\n### Prompt 
Engineering\n#### Basic Knowledge & Glossary\n\nYou can define basic knowledge and glossary in `example/bi.yaml`, for example:\n\n```yaml\nbasic_knowledge_glossary: |\n  # Basic Knowledge Introduction\n    The basic knowledge about your company and its business, including key concepts, metrics, and processes.\n  # Glossary\n    Common terms and their definitions used in your business context.\n```\n\n#### Data Warehouse Introduction\n\nYou can provide a brief introduction of your data warehouse in `example/bi.yaml`, for example:\n\n```yaml\ndata_warehouse_introduction: |\n  # Data Warehouse Introduction\n    This data warehouse is built on Presto and contains various tables related to XXXXX.\n    The main fact tables include XXXX metrics, while dimension tables include XXXXX.\n    The data is updated hourly and is used for reporting and analysis purposes.\n```\n\n#### Table Selection Rules\n\nYou can configure table selection rules in `example/bi.yaml`, for example:\n\n```yaml\ntable_selection_extra_rule: |\n  - All tables with is_valid can support both valid and invalid traffics\n```\n\n#### Custom SQL Rules\n\nYou can define your additional SQL Generation rules for tables in `example/table_info.yaml`, for example:\n\n```yaml\nsql_rule: |\n  ### SQL Rules\n  - All event_date in the table are stored in **UTC**. If the user specifies a timezone (e.g., CET, PST), convert between timezones accordingly.\n\n```\n\n\n### Catalog Management\n\n#### Introduction\n\nHigh-quality catalog data is essential for accurate Text2SQL generation and data analysis. 
OpenChatBI automatically \ndiscovers and indexes data warehouse table structures while providing flexible management for business metadata, column \ndescriptions, and query optimization rules.\n\n#### Catalog Structure\n\nThe catalog system organizes metadata in a hierarchical structure:\n\n**Database Level**\n- Top-level container for all tables and schemas\n\n**Table Level**\n- `description`: Business functionality and purpose of the table\n- `selection_rule`: Guidelines for when and how to use this table in queries\n- `sql_rule`: Specific SQL generation rules and constraints for this table\n\n**Column Level**\n- **Required Fields**: Essential metadata for each column to enable effective Text2SQL generation\n  - `column_name`: Technical database column name\n  - `display_name`: Human-readable name for business users\n  - `alias`: Alternative names or abbreviations\n  - `type`: Data type (string, integer, date, etc.)\n  - `category`: Business category, dimension or metric\n  - `tag`: Additional labels for filtering and organization\n  - `description`: Detailed explanation of column purpose and usage\n- **Two Types** of Columns\n  - **Common Columns**: Columns with standardized business meanings shared across tables\n  - **Table-Specific Columns**: Columns with context-dependent meanings that vary between tables\n- **Derived Metrics**: Virtual metrics calculated from existing columns using SQL formulas\n  - Computed dynamically during query execution rather than stored as physical columns\n  - Examples: CTR (clicks/impressions), conversion rates, profit margins\n  - Enable complex business calculations without pre-computing values\n  \n#### Loading Catalog from Database\n\nOpenChatBI can automatically discover and load table structures from your data warehouse:\n\n1. **Automatic Discovery**: Connects to your configured data warehouse and scans table schemas\n2. **Metadata Extraction**: Extracts column names, data types, and basic structural information\n3. 
**Incremental Updates**: Supports updating catalog data as your database schema evolves\n\nConfigure automatic catalog loading in your `config.yaml`:\n\n```yaml\ncatalog_store:\n  store_type: file_system\n  data_path: ./catalog_data\ndata_warehouse_config:\n  include_tables:\n    - your_table_pattern\n  # Leave empty to include all accessible tables\n```\n\n#### File System Catalog Store\n\nThe file system catalog store organizes metadata across multiple files for maintainability and version control:\n\n**Core Table Information**\n- `table_info.yaml`: Comprehensive table metadata organized hierarchically (database → table → information)\n  - `type`: Table classification (e.g., \"fact\" for Fact Tables, \"dimension\" for Dimension Tables)\n  - `description`: Business functionality and purpose\n  - `selection_rule`: Usage guidelines in markdown list format (each line starts with `-`)\n  - `sql_rule`: SQL generation rules in markdown header format (each rule starts with `####`)\n  - `derived_metric`: Virtual metrics with calculation formulas, organized by groups:\n    ```md\n    #### Derived Ratio Metrics\n    Click-through Rate (alias CTR): SUM(clicks) / SUM(impression)\n    Conversion Rate (alias CVR): SUM(conversions) / SUM(clicks)\n    ```\n\n**Column Management**\n- `table_columns.csv`: Basic column registry with schema `db_name,table_name,column_name`\n- `table_spec_columns.csv`: Table-specific column metadata with full schema:\n  `db_name,table_name,column_name,display_name,alias,type,category,tag,description`\n- `common_columns.csv`: Shared column definitions across tables with schema:\n  `column_name,display_name,alias,type,category,tag,description`\n\n**Query Examples and Training Data**\n- `table_selection_example.csv`: Table selection training examples with schema `question,selected_tables`\n- `sql_example.yaml`: Query examples organized by database and table structure:\n  ```yaml\n  your_database:\n    ad_performance: |\n      Q: Show me CTR trends for 
the past 7 days\n      A: SELECT date, SUM(clicks)/SUM(impressions) AS ctr\n         FROM ad_performance\n         WHERE date >= CURRENT_DATE - INTERVAL 7 DAY\n         GROUP BY date\n         ORDER BY date;\n  ```\n\n### Time Series Forecasting Service Setup\n\nOpenChatBI can integrate with a time series forecasting service for advanced predictive analytics. Follow these steps to set up the service:\n\n#### 1. Build and Run the Forecasting Service\n\nSee detailed instructions in [timeseries_forecasting/README.md](timeseries_forecasting/README.md)\n\nQuick start:\n```bash\ncd timeseries_forecasting\n./build_and_run.sh\n```\n\n#### 2. Configure Tool Usage Rules\n\nIn your `bi.yaml`, add constraints for the timeseries_forecast tool, e.g. if you are using `timer-base-84m` model:\n```yaml\nextra_tool_use_rule: |\n  - timeseries_forecast tool requires at least 96 time points in input data. If no enough input data, set input_len to 96 to pad with zeros.\n```\n\n#### 3. Configure Service URL\n\nIn your `config.yaml`:\n```yaml\n# Time Series Forecasting Service Configuration\ntimeseries_forecasting_service_url: \"http://localhost:8765\"\n```\n\n**Important**: Adjust the URL based on your deployment scenario:\n- **Local development** (OpenChatBI on host, Forecasting service in Docker): `http://localhost:8765`\n- **Remote service**: `http://your-service-host:8765`\n\n\n#### 4. 
Verify Service Health\n\nTest the service is accessible:\n```bash\ncurl http://localhost:8765/health\n```\n\nExpected response:\n```json\n{\n  \"status\": \"healthy\",\n  \"model_initialized\": true,\n  \"uptime_seconds\": 123.45\n}\n``` \n\n### Python Code Execution Configuration\n\nOpenChatBI supports multiple execution environments for running Python code with different security and performance characteristics:\n\n```yaml\n# Python Code Execution Configuration\npython_executor: local  # Options: \"local\", \"restricted_local\", \"docker\"\n```\n\n#### Executor Types\n\n- **`local`** (Default)\n  - **Performance**: Fastest execution\n  - **Security**: Least secure (code runs in current Python process)\n  - **Capabilities**: Full Python capabilities and library access\n  - **Use Case**: Development environments, trusted code execution\n\n- **`restricted_local`**\n  - **Performance**: Moderate execution speed\n  - **Security**: Moderate security with RestrictedPython sandboxing\n  - **Capabilities**: Limited Python features (no imports, file access, etc.)\n  - **Use Case**: Semi-trusted environments with controlled execution\n\n- **`docker`**\n  - **Performance**: Slower due to container overhead\n  - **Security**: Highest security with complete process isolation\n  - **Capabilities**: Full Python capabilities within isolated container\n  - **Use Case**: Production environments, untrusted code execution\n  - **Requirements**: Docker must be installed and running\n\n#### Docker Executor Setup\n\nFor production deployments or when running untrusted code, the Docker executor provides complete isolation:\n\n1. **Install Docker**: Download and install Docker Desktop or Docker Engine\n2. **Configure executor**: Set `python_executor: docker` in your config\n3. **Automatic setup**: OpenChatBI will automatically build the required Docker image\n4. 
**Fallback behavior**: If Docker is unavailable, automatically falls back to local executor\n\n**Docker Executor Features**:\n- Pre-installed data science libraries (pandas, numpy, matplotlib, seaborn)\n- Network isolation for security\n- Automatic container cleanup\n- Resource isolation from host system\n\n## Development & Testing\n\n### Code Quality Tools\n\nThe project uses modern Python tooling for code quality:\n\n```bash\n# Format code\nuv run black .\n\n# Lint code  \nuv run ruff check .\n\n# Type checking\nuv run mypy openchatbi/\n\n# Security scanning\nuv run bandit -r openchatbi/\n```\n\n### Testing\n\nRun the test suite:\n\n```bash\n# Run all tests\nuv run pytest\n\n# Run with coverage\nuv run pytest --cov=openchatbi --cov-report=html\n\n# Run specific test files\nuv run pytest tests/test_generate_sql.py\nuv run pytest tests/test_agent_graph.py\n```\n\n### Pre-commit Hooks\n\nInstall pre-commit hooks for automatic code quality checks:\n\n```bash\nuv run pre-commit install\n```\n\n## Contribution Guidelines\n\n1. Fork the repository\n2. Create a feature branch (`git checkout -b feature/fooBar`)\n3. Commit your changes (`git commit -am 'Add some fooBar'`)\n4. Push to the branch (`git push origin feature/fooBar`)\n5. Create a new Pull Request\n\n## License\n\nThis project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details\n\n## Contact & Support\n\n- **Author**: Yu Zhong ([zhongyu8@gmail.com](mailto:zhongyu8@gmail.com))\n- **Repository**: [github.com/zhongyu09/openchatbi](https://github.com/zhongyu09/openchatbi)\n- **Issues**: [Report bugs and feature requests](https://github.com/zhongyu09/openchatbi/issues)\n"
  },
  {
    "path": "baselines/runledger-openchatbi.json",
    "content": "{\n  \"aggregates\": {\n    \"cases_error\": 0,\n    \"cases_fail\": 0,\n    \"cases_pass\": 1,\n    \"cases_total\": 1,\n    \"metrics\": {\n      \"cost_usd\": {\n        \"max\": null,\n        \"mean\": null,\n        \"min\": null,\n        \"p50\": null,\n        \"p95\": null\n      },\n      \"steps\": {\n        \"max\": null,\n        \"mean\": null,\n        \"min\": null,\n        \"p50\": null,\n        \"p95\": null\n      },\n      \"tokens_in\": {\n        \"max\": null,\n        \"mean\": null,\n        \"min\": null,\n        \"p50\": null,\n        \"p95\": null\n      },\n      \"tokens_out\": {\n        \"max\": null,\n        \"mean\": null,\n        \"min\": null,\n        \"p50\": null,\n        \"p95\": null\n      },\n      \"tool_calls\": {\n        \"max\": 1.0,\n        \"mean\": 1.0,\n        \"min\": 1.0,\n        \"p50\": 1.0,\n        \"p95\": 1.0\n      },\n      \"tool_errors\": {\n        \"max\": 0.0,\n        \"mean\": 0.0,\n        \"min\": 0.0,\n        \"p50\": 0.0,\n        \"p95\": 0.0\n      },\n      \"wall_ms\": {\n        \"max\": 1.0,\n        \"mean\": 1.0,\n        \"min\": 1.0,\n        \"p50\": 1.0,\n        \"p95\": 1.0\n      }\n    },\n    \"pass_rate\": 1.0\n  },\n  \"cases\": [\n    {\n      \"assertions\": {\n        \"failed\": 0,\n        \"total\": 1\n      },\n      \"cost_usd\": null,\n      \"failed_assertions\": null,\n      \"failure_reason\": null,\n      \"id\": \"t1\",\n      \"replay\": {\n        \"cassette_path\": \"evals/runledger/cassettes/t1.jsonl\",\n        \"cassette_sha256\": \"7e9830609490d140bf09178106dfa647bba4c9ec15859072b5aa2c3ae1659289\"\n      },\n      \"status\": \"pass\",\n      \"steps\": null,\n      \"tokens_in\": null,\n      \"tokens_out\": null,\n      \"tool_calls\": 1,\n      \"tool_calls_by_name\": {\n        \"search_knowledge\": 1\n      },\n      \"tool_errors\": 0,\n      \"tool_errors_by_name\": {},\n      \"wall_ms\": 1\n    }\n  ],\n  
\"generated_at\": \"2026-01-03T19:10:00Z\",\n  \"run\": {\n    \"ci\": null,\n    \"exit_status\": \"success\",\n    \"git_sha\": null,\n    \"mode\": \"replay\",\n    \"run_id\": \"baseline\"\n  },\n  \"runledger_version\": \"0.1.1\",\n  \"schema_version\": 1,\n  \"suite\": {\n    \"agent_command\": [\n      \"python\",\n      \"evals/runledger/agent/agent.py\"\n    ],\n    \"cases_total\": 1,\n    \"name\": \"runledger-openchatbi\",\n    \"suite_config_hash\": null,\n    \"suite_path\": \"evals/runledger/suite.yaml\",\n    \"tool_mode\": \"replay\"\n  }\n}\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=source\r\nset BUILDDIR=build\r\n\r\n%SPHINXBUILD% >NUL 2>NUL\r\nif errorlevel 9009 (\r\n\techo.\r\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\r\n\techo.installed, then set the SPHINXBUILD environment variable to point\r\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\r\n\techo.may add the Sphinx directory to PATH.\r\n\techo.\r\n\techo.If you don't have Sphinx installed, grab it from\r\n\techo.https://www.sphinx-doc.org/\r\n\texit /b 1\r\n)\r\n\r\nif \"%1\" == \"\" goto help\r\n\r\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\ngoto end\r\n\r\n:help\r\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\n\r\n:end\r\npopd\r\n"
  },
  {
    "path": "docs/source/_templates/layout.html",
    "content": "{% extends \"!layout.html\" %}\n\n{% block extrahead %}\n  {{ super() }}\n  <meta name=\"google-site-verification\" content=\"geDcsz839O_UHavbn1pIpMOk6sJgneL4NlULcVJ4-KA\" />\n  <script async src=\"https://www.googletagmanager.com/gtag/js?id=AW-17595718197\"></script>\n  <script>\n    window.dataLayer = window.dataLayer || [];\n    function gtag(){dataLayer.push(arguments);}\n    gtag('js', new Date());\n    gtag('config', 'AW-17595718197');\n  </script>\n  <script>\n    gtag('event', 'conversion', {\n        'send_to': 'AW-17595718197/JxBiCPzC06AbELW0pcZB',\n        'value': 1.0,\n        'currency': 'SGD'\n    });\n  </script>\n{% endblock %}"
  },
  {
    "path": "docs/source/catalog.rst",
    "content": "Catalog System\n==============\n\nOverview\n--------\n\nThe catalog system manages metadata for database tables, columns, and business rules.\n\nCatalog Store\n-------------\n\n.. automodule:: openchatbi.catalog.catalog_store\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nFilesystem Implementation\n^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. automodule:: openchatbi.catalog.store.file_system\n    :members:\n    :show-inheritance:\n\nCatalog Loader\n--------------\n\n.. automodule:: openchatbi.catalog.catalog_loader\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nSchema Retrieval\n----------------\n\n.. automodule:: openchatbi.catalog.schema_retrival\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "docs/source/code.rst",
    "content": "Code Execution\n==============\n\nCode Module\n-----------\n\n.. automodule:: openchatbi.code\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nExecutor Base\n-------------\n\n.. automodule:: openchatbi.code.executor_base\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nLocal Executor\n--------------\n\n.. automodule:: openchatbi.code.local_executor\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# For the full list of built-in configuration values, see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Project information -----------------------------------------------------\n# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information\nimport os\nimport sys\n\nsys.path.insert(0, os.path.abspath(\"../..\"))\n\nproject = \"OpenChatBI\"\ncopyright = \"2025, Yu Zhong\"\nauthor = \"Yu Zhong\"\nrelease = \"0.2.2\"\n\n# -- General configuration ---------------------------------------------------\n# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration\n\n# Mock dependencies for documentation build\nautodoc_mock_imports = []\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx.ext.viewcode\",\n    \"sphinx.ext.githubpages\",\n    \"myst_parser\",\n]\n\n# Set an environment variable to indicate we're building docs\nimport os\n\nos.environ[\"SPHINX_BUILD\"] = \"1\"\n\n# MyST parser configuration\nmyst_enable_extensions = [\n    \"colon_fence\",\n    \"deflist\",\n    \"html_admonition\",\n    \"html_image\",\n]\nmyst_heading_anchors = 3\n\ntemplates_path = [\"_templates\"]\nexclude_patterns = []\n\n# Autodoc configuration\nautodoc_default_options = {\n    \"members\": True,\n    \"member-order\": \"bysource\",\n    \"special-members\": \"__init__\",\n    \"undoc-members\": True,\n    \"exclude-members\": \"__weakref__\",\n}\n\n# Napoleon configuration for Google/NumPy style docstrings\nnapoleon_google_docstring = True\nnapoleon_numpy_docstring = True\nnapoleon_include_init_with_doc = False\nnapoleon_include_private_with_doc = False\nnapoleon_include_special_with_doc = True\nnapoleon_use_admonition_for_examples = False\nnapoleon_use_admonition_for_notes = False\nnapoleon_use_admonition_for_references = False\nnapoleon_use_ivar = False\nnapoleon_use_param = 
True\nnapoleon_use_rtype = True\nnapoleon_preprocess_types = False\nnapoleon_type_aliases = None\nnapoleon_attr_annotations = True\n\n# -- Options for HTML output -------------------------------------------------\n# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output\n\nhtml_theme = \"sphinx_rtd_theme\"\nhtml_static_path = [\"_static\"]\n\n# GitHub Pages configuration\nhtml_baseurl = \"https://zhongyu09.github.io/openchatbi/\"\n\n# Theme options for RTD theme\nhtml_theme_options = {\n    \"navigation_depth\": 4,\n    \"collapse_navigation\": False,\n    \"sticky_navigation\": True,\n    \"includehidden\": True,\n    \"titles_only\": False,\n}\n"
  },
  {
    "path": "docs/source/config.rst",
    "content": "Configuration\n=============\n\nThe configuration system consists of two main classes:\n\n- **Config**: Defines the configuration model.\n- **ConfigLoader**: Manages loading and accessing configuration.\n\nConfig\n------\n\n.. autoclass:: openchatbi.config_loader.Config\n    :exclude-members: organization, dialect, default_llm, embedding_model, text2sql_llm, bi_config, data_warehouse_config, catalog_store, mcp_servers, report_directory, python_executor\n\nConfigLoader\n------------\n\n.. autoclass:: openchatbi.config_loader.ConfigLoader\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "docs/source/core.rst",
    "content": "Core Module\n===========\n\nMain Module\n-----------\n\n.. automodule:: openchatbi\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nAgent Graph\n-----------\n\n.. automodule:: openchatbi.agent_graph\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nState Management\n----------------\n\n.. automodule:: openchatbi.graph_state\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nUtilities\n---------\n\n.. automodule:: openchatbi.utils\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "docs/source/index.rst",
    "content": "OpenChatBI Documentation\n========================\n\n`GitHub Repository <https://github.com/zhongyu09/openchatbi>`_\n\n.. include:: ../../README.md\n   :parser: myst_parser.sphinx_\n\n.. toctree::\n   :maxdepth: 4\n   :caption: Documentation:\n   :titlesonly:\n\n   self\n\n.. toctree::\n   :maxdepth: 2\n   :caption: API Reference:\n\n   Core Module <core>\n   Configuration <config>\n   Catalog System <catalog>\n   Text2SQL System <text2sql>\n   Code Execution <code>\n   LLM Integration <llm>\n   Tools and Utilities <tools>\n   Time Series Forecasting Service <timeseries>\n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/source/llm.rst",
    "content": "LLM Integration\n===============\n\nLLM Module\n----------\n\n.. automodule:: openchatbi.llm\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nLLM Implementation\n------------------\n\n.. automodule:: openchatbi.llm.llm\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "docs/source/text2sql.rst",
    "content": "Text2SQL System\n===============\n\nOverview\n--------\n\nNatural language to SQL conversion pipeline with schema linking and prompt engineering.\n\n\nSQL Graph\n---------\n\n.. automodule:: openchatbi.text2sql.sql_graph\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nSQL Generation\n--------------\n\n.. automodule:: openchatbi.text2sql.generate_sql\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nSchema Linking\n--------------\n\n.. automodule:: openchatbi.text2sql.schema_linking\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nInformation Extraction\n----------------------\n\n.. automodule:: openchatbi.text2sql.extraction\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nText2SQL Utilities\n-------------------\n\n.. automodule:: openchatbi.text2sql.text2sql_utils\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "docs/source/timeseries.rst",
    "content": "Time Series Forecasting Service\n===============================\n\n`GitHub Repository <https://github.com/zhongyu09/openchatbi/tree/main/timeseries_forecasting>`_\n\n.. include:: ../../timeseries_forecasting/README.md\n   :parser: myst_parser.sphinx_\n"
  },
  {
    "path": "docs/source/tools.rst",
    "content": "Tools and Utilities\n===================\n\nOverview\n--------\n\nLangGraph tools for human interaction, code execution, and knowledge search.\n\nPython Code Execution\n----------------------\n\n.. automodule:: openchatbi.tool.run_python_code\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nHuman Interaction\n-----------------\n\n.. automodule:: openchatbi.tool.ask_human\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nMemory Management\n-----------------\n\n.. automodule:: openchatbi.tool.memory\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nKnowledge Search\n----------------\n\n.. automodule:: openchatbi.tool.search_knowledge\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "evals/__init__.py",
    "content": "\"\"\"Evaluation suites for RunLedger.\"\"\"\n"
  },
  {
    "path": "evals/runledger/README.md",
    "content": "# RunLedger eval (OpenChatBI)\n\nThis suite is **replay-only** by default. It runs a deterministic CI check using a JSONL adapter that proxies tool calls through RunLedger and replays results from a cassette.\n\n## Run (replay)\n\n```bash\nrunledger run evals/runledger --mode replay --baseline baselines/runledger-openchatbi.json\n```\n\n## Record / update cassette (optional)\n\nIf you want to re-record the cassette with real tool outputs, run in record mode in a fully configured OpenChatBI environment (valid `openchatbi/config.yaml`, data warehouse/catalog, LLM keys).\n\n```bash\nrunledger run evals/runledger --mode record \\\n  --baseline baselines/runledger-openchatbi.json \\\n  --tool-module evals.runledger.tools\n```\n\nNotes:\n- Tool args are passed as JSON objects; see `evals/runledger/cassettes/t1.jsonl` for the exact shape.\n- After recording, promote the new baseline:\n\n```bash\nrunledger baseline promote \\\n  --from runledger_out/runledger-openchatbi/<run_id> \\\n  --to baselines/runledger-openchatbi.json\n```\n\n"
  },
  {
    "path": "evals/runledger/__init__.py",
    "content": "\"\"\"RunLedger eval suite for OpenChatBI.\"\"\"\n"
  },
  {
    "path": "evals/runledger/agent/agent.py",
    "content": "import json\nimport sys\nfrom itertools import count\nfrom typing import Any\nfrom unittest.mock import MagicMock\n\nimport builtins\nfrom langchain_core.messages import AIMessage, HumanMessage, ToolMessage\nfrom langchain_core.tools import StructuredTool\nfrom langgraph.checkpoint.memory import MemorySaver\nfrom pydantic import BaseModel, Field\n\nfrom openchatbi import config\nimport openchatbi.agent_graph as agent_graph\n\n\n_CALL_COUNTER = count(1)\n_ORIG_PRINT = builtins.print\n\n\ndef _safe_print(*args: Any, **kwargs: Any) -> None:\n    \"\"\"Suppress stdout prints so JSONL stays clean; allow stderr.\"\"\"\n    target = kwargs.get(\"file\")\n    if target is None or target is sys.stdout:\n        return\n    _ORIG_PRINT(*args, **kwargs)\n\n\nbuiltins.print = _safe_print\n\n\nclass JsonlChannel:\n    def __init__(self, stream: Any) -> None:\n        self._stream = stream\n\n    def read(self) -> dict[str, Any] | None:\n        while True:\n            line = self._stream.readline()\n            if not line:\n                return None\n            line = line.strip()\n            if not line:\n                continue\n            try:\n                return json.loads(line)\n            except json.JSONDecodeError:\n                continue\n\n    @staticmethod\n    def send(payload: dict[str, Any]) -> None:\n        sys.stdout.write(json.dumps(payload) + \"\\n\")\n        sys.stdout.flush()\n\n\ndef _last_user_text(messages: list[Any]) -> str:\n    for message in reversed(messages):\n        if isinstance(message, HumanMessage):\n            return str(message.content).strip()\n    return \"OpenChatBI\"\n\n\ndef _runledger_tool_call(channel: JsonlChannel, name: str, args: dict[str, Any]) -> Any:\n    call_id = f\"c{next(_CALL_COUNTER)}\"\n    channel.send({\"type\": \"tool_call\", \"name\": name, \"call_id\": call_id, \"args\": args})\n    while True:\n        message = channel.read()\n        if message is None:\n            raise 
RuntimeError(\"Tool result missing\")\n        if message.get(\"type\") != \"tool_result\":\n            continue\n        if message.get(\"call_id\") != call_id:\n            continue\n        if message.get(\"ok\"):\n            return message.get(\"result\")\n        raise RuntimeError(message.get(\"error\") or \"Tool error\")\n\n\nclass SearchKnowledgeInput(BaseModel):\n    reasoning: str = Field(description=\"Reason for searching knowledge\")\n    query_list: list[str] = Field(description=\"Query terms\")\n    knowledge_bases: list[str] = Field(description=\"Knowledge bases to search\")\n    with_table_list: bool = Field(default=False, description=\"Include table list\")\n\n\nclass ShowSchemaInput(BaseModel):\n    reasoning: str = Field(description=\"Reason for showing schema\")\n    tables: list[str] = Field(description=\"Table names\")\n\n\nclass Text2SQLInput(BaseModel):\n    reasoning: str = Field(description=\"Reason for calling text2sql\")\n    context: str = Field(description=\"Full context for the SQL graph\")\n\n\nclass RunPythonInput(BaseModel):\n    reasoning: str = Field(description=\"Reason for running python code\")\n    code: str = Field(description=\"Python code to execute\")\n\n\nclass SaveReportInput(BaseModel):\n    content: str = Field(description=\"Report content\")\n    title: str = Field(description=\"Report title\")\n    file_format: str = Field(description=\"File extension\")\n\n\ndef _build_tool_proxies(channel: JsonlChannel) -> dict[str, StructuredTool]:\n    def search_knowledge(\n        reasoning: str,\n        query_list: list[str],\n        knowledge_bases: list[str],\n        with_table_list: bool = False,\n    ) -> Any:\n        return _runledger_tool_call(\n            channel,\n            \"search_knowledge\",\n            {\n                \"reasoning\": reasoning,\n                \"query_list\": query_list,\n                \"knowledge_bases\": knowledge_bases,\n                \"with_table_list\": with_table_list,\n    
        },\n        )\n\n    def show_schema(reasoning: str, tables: list[str]) -> Any:\n        return _runledger_tool_call(\n            channel,\n            \"show_schema\",\n            {\"reasoning\": reasoning, \"tables\": tables},\n        )\n\n    def text2sql(reasoning: str, context: str) -> Any:\n        return _runledger_tool_call(\n            channel,\n            \"text2sql\",\n            {\"reasoning\": reasoning, \"context\": context},\n        )\n\n    def run_python_code(reasoning: str, code: str) -> Any:\n        return _runledger_tool_call(\n            channel,\n            \"run_python_code\",\n            {\"reasoning\": reasoning, \"code\": code},\n        )\n\n    def save_report(content: str, title: str, file_format: str = \"md\") -> Any:\n        return _runledger_tool_call(\n            channel,\n            \"save_report\",\n            {\"content\": content, \"title\": title, \"file_format\": file_format},\n        )\n\n    return {\n        \"search_knowledge\": StructuredTool.from_function(\n            func=search_knowledge,\n            name=\"search_knowledge\",\n            description=\"RunLedger proxy for search_knowledge\",\n            args_schema=SearchKnowledgeInput,\n        ),\n        \"show_schema\": StructuredTool.from_function(\n            func=show_schema,\n            name=\"show_schema\",\n            description=\"RunLedger proxy for show_schema\",\n            args_schema=ShowSchemaInput,\n        ),\n        \"text2sql\": StructuredTool.from_function(\n            func=text2sql,\n            name=\"text2sql\",\n            description=\"RunLedger proxy for text2sql\",\n            args_schema=Text2SQLInput,\n        ),\n        \"run_python_code\": StructuredTool.from_function(\n            func=run_python_code,\n            name=\"run_python_code\",\n            description=\"RunLedger proxy for run_python_code\",\n            args_schema=RunPythonInput,\n        ),\n        \"save_report\": 
StructuredTool.from_function(\n            func=save_report,\n            name=\"save_report\",\n            description=\"RunLedger proxy for save_report\",\n            args_schema=SaveReportInput,\n        ),\n    }\n\n\ndef _stub_llm_call(chat_model: Any, messages: list[Any], **_kwargs: Any) -> AIMessage:\n    tool_seen = any(isinstance(msg, ToolMessage) or getattr(msg, \"type\", None) == \"tool\" for msg in messages)\n    if tool_seen:\n        return AIMessage(content=\"Here is a deterministic summary based on the tool result.\", tool_calls=[])\n\n    user_text = _last_user_text(messages)\n    tool_args = {\n        \"reasoning\": \"Look up relevant knowledge\",\n        \"query_list\": [user_text],\n        \"knowledge_bases\": [\"columns\"],\n        \"with_table_list\": False,\n    }\n    return AIMessage(\n        content=\"Searching knowledge base.\",\n        tool_calls=[{\"name\": \"search_knowledge\", \"args\": tool_args, \"id\": \"call_1\"}],\n    )\n\n\ndef _configure_agent_graph(channel: JsonlChannel) -> None:\n    tool_proxies = _build_tool_proxies(channel)\n\n    agent_graph.search_knowledge = tool_proxies[\"search_knowledge\"]\n    agent_graph.show_schema = tool_proxies[\"show_schema\"]\n    agent_graph.run_python_code = tool_proxies[\"run_python_code\"]\n    agent_graph.save_report = tool_proxies[\"save_report\"]\n    agent_graph.get_sql_tools = lambda *_args, **_kwargs: tool_proxies[\"text2sql\"]\n    agent_graph.build_sql_graph = lambda *_args, **_kwargs: object()\n    agent_graph.get_memory_tools = lambda *_args, **_kwargs: []\n    agent_graph.create_mcp_tools_sync = lambda *_args, **_kwargs: []\n    agent_graph.check_forecast_service_health = lambda: False\n    agent_graph.call_llm_chat_model_with_retry = _stub_llm_call\n\n\ndef _bootstrap_config() -> None:\n    config.set(\n        {\n            \"default_llm\": MagicMock(),\n            \"data_warehouse_config\": {},\n            \"catalog_store\": {\"store_type\": \"file_system\", 
\"auto_load\": False},\n        }\n    )\n\n\ndef main() -> int:\n    channel = JsonlChannel(sys.stdin)\n    message = channel.read()\n    if not message or message.get(\"type\") != \"task_start\":\n        return 1\n\n    _bootstrap_config()\n    _configure_agent_graph(channel)\n\n    prompt = \"\"\n    payload = message.get(\"input\", {})\n    if isinstance(payload, dict):\n        prompt = payload.get(\"prompt\") or payload.get(\"question\") or payload.get(\"query\") or \"\"\n    if not prompt:\n        prompt = \"OpenChatBI\"\n\n    graph = agent_graph.build_agent_graph_sync(\n        catalog=config.get().catalog_store,\n        checkpointer=MemorySaver(),\n        memory_store=None,\n        enable_context_management=False,\n    )\n\n    result = graph.invoke({\"messages\": [{\"role\": \"user\", \"content\": prompt}]})\n    output = \"\"\n    if isinstance(result, dict) and result.get(\"messages\"):\n        output = str(result[\"messages\"][-1].content)\n\n    channel.send(\n        {\n            \"type\": \"final_output\",\n            \"output\": {\"category\": \"bi\", \"reply\": output or \"Completed request.\"},\n        }\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    raise SystemExit(main())\n"
  },
  {
    "path": "evals/runledger/cases/t1.yaml",
    "content": "id: t1\ndescription: \"basic BI flow with a single search_knowledge tool call\"\ninput:\n  prompt: \"OpenChatBI\"\ncassette: cassettes/t1.jsonl\n"
  },
  {
    "path": "evals/runledger/cassettes/t1.jsonl",
    "content": "{\"tool\":\"search_knowledge\",\"args\":{\"knowledge_bases\":[\"columns\"],\"query_list\":[\"OpenChatBI\"],\"reasoning\":\"Look up relevant knowledge\",\"with_table_list\":false},\"ok\":true,\"result\":{\"columns\":\"# Relevant Columns and Description:\\n## openchatbi\\n- Column Category: metric\\n- Display Name: OpenChatBI\\n- Description \\\"Project overview\\\"\"}}\n"
  },
  {
    "path": "evals/runledger/schema.json",
    "content": "{\n  \"type\": \"object\",\n  \"properties\": {\n    \"category\": {\n      \"type\": \"string\"\n    },\n    \"reply\": {\n      \"type\": \"string\"\n    }\n  },\n  \"required\": [\n    \"category\",\n    \"reply\"\n  ]\n}\n"
  },
  {
    "path": "evals/runledger/suite.yaml",
    "content": "suite_name: runledger-openchatbi\nagent_command: [\"python\", \"evals/runledger/agent/agent.py\"]\nmode: replay\ncases_path: cases\ntool_registry:\n  - search_knowledge\ntool_module: evals.runledger.tools\n\nassertions:\n  - type: json_schema\n    schema_path: schema.json\n\nbudgets:\n  max_wall_ms: 20000\n  max_tool_calls: 5\n  max_tool_errors: 0\n\nbaseline_path: ../../baselines/runledger-openchatbi.json\n"
  },
  {
    "path": "evals/runledger/tools.py",
    "content": "from __future__ import annotations\n\nfrom typing import Any\n\nfrom openchatbi.tool.search_knowledge import search_knowledge\n\n\ndef _invoke_tool(tool, args: dict[str, Any]) -> Any:\n    return tool.invoke(args)\n\n\ndef _search_knowledge(args: dict[str, Any]) -> Any:\n    return _invoke_tool(search_knowledge, args)\n\n\nTOOLS = {\n    \"search_knowledge\": _search_knowledge,\n}\n"
  },
  {
    "path": "example/bi.yaml",
    "content": "extra_tool_use_rule: |\n  - Try your best to give appropriate parameters when calling tools.\n  - timeseries_forecast tool requires at least 96 time points in input data. If no enough input data, set input_len to 96 to pad with zeros.\n\ntable_selection_extra_rule: |\n  - When users ask about orders, consider if they need customer information (join with Customers table)\n  - For product-related queries, check if order information is needed (join with Order_Items)  \n  - Shipment queries often require order and product details (join with multiple tables)\n  - Invoice questions may need shipment information for complete tracking\n\ntext2sql_extra_rule: |\n  - Use proper JOIN syntax when connecting related tables\n  - Use LIKE operator for partial string matches in product names or customer names\n  - Handle NULL values properly in optional fields like details columns\n\nbasic_knowledge_glossary: |\n  # Sales Business System Glossary\n  \n  ## Overview\n  You're answering questions related to a sales order tracking business system that manages the complete customer order lifecycle from placement to delivery.\n\n  ## Key Business Concepts\n  \n  **Customer Management:**\n  - Customer: Individual or entity who places orders\n  - Customer Details: Additional information like contact info, preferences, or notes\n  \n  **Order Processing:**\n  - Order: A request from a customer to purchase products\n  - Order Status: Current state - Valid values: \"Shipped\", \"Packing\", \"On Road\"\n  - Order Item: Individual product within an order (orders can contain multiple items)\n  - Order Item Status: Status of specific items - Valid values: \"Finish\", \"Payed\", \"Cancel\"\n  \n  **Product Catalog:**\n  - Product: Items available for purchase\n  - Product Details: Specifications, descriptions, or additional product information\n  \n  **Fulfillment & Shipping:**\n  - Shipment: Physical delivery package sent to customer\n  - Shipment Items: Specific order items 
included in a shipment\n  - Tracking Number: Unique identifier for package tracking\n  - Shipment Date: When package was dispatched\n  \n  **Financial Processing:**\n  - Invoice: Bill generated for completed orders\n  - Invoice Number: Unique identifier for billing purposes\n  - Invoice Date: When billing document was created\n  \n  ## Business Rules\n  - One order can have multiple items (products)\n  - One order can be fulfilled through multiple shipments\n  - Each shipment links to one invoice for billing\n  - Order items can have different statuses within the same order\n  - Customers can have multiple orders over time"
  },
  {
    "path": "example/common_columns.csv",
    "content": "column_name,display_name,alias,type,category,tag,description,dimension_table,default\r\ncustomer_id,Customer ID,cust_id,INTEGER,identifier,customer,Unique identifier for customers,Customers,\r\ncustomer_name,Customer Name,cust_name,VARCHAR(80),attribute,customer,Name of the customer,Customers,\r\ncustomer_details,Customer Details,cust_details,VARCHAR(255),attribute,customer,Additional customer information,Customers,\r\ninvoice_number,Invoice Number,inv_num,INTEGER,identifier,financial,Unique invoice identifier,Invoices,\r\ninvoice_date,Invoice Date,inv_date,DATETIME,temporal,financial,Date the invoice was created,Invoices,\r\ninvoice_details,Invoice Details,inv_details,VARCHAR(255),attribute,financial,Additional invoice information,Invoices,\r\norder_item_id,Order Item ID,oi_id,INTEGER,identifier,order,Unique identifier for order items,Order_Items,\r\nproduct_id,Product ID,prod_id,INTEGER,identifier,product,Unique identifier for products,Products,\r\norder_id,Order ID,ord_id,INTEGER,identifier,order,Unique identifier for orders,Orders,\r\norder_item_status,Order Item Status,oi_status,VARCHAR(10),status,order,Current status of the order item (Finish|Payed|Cancel),Order_Items,\r\norder_item_details,Order Item Details,oi_details,VARCHAR(255),attribute,order,Additional order item information,Order_Items,\r\norder_status,Order Status,ord_status,VARCHAR(10),status,order,Current status of the order (Shipped|Packing|On Road),Orders,\r\ndate_order_placed,Order Placed Date,ord_date,DATETIME,temporal,order,Date when the order was placed,Orders,\r\norder_details,Order Details,ord_details,VARCHAR(255),attribute,order,Additional order information,Orders,\r\nproduct_name,Product Name,prod_name,VARCHAR(80),attribute,product,Name of the product,Products,\r\nproduct_details,Product Details,prod_details,VARCHAR(255),attribute,product,Additional product information,Products,\r\nshipment_id,Shipment ID,ship_id,INTEGER,identifier,shipment,Unique identifier for 
shipments,Shipments,\r\nshipment_tracking_number,Tracking Number,track_num,VARCHAR(80),identifier,shipment,Tracking number for shipment,Shipments,\r\nshipment_date,Shipment Date,ship_date,DATETIME,temporal,shipment,Date when the shipment was sent,Shipments,\r\nother_shipment_details,Shipment Details,ship_details,VARCHAR(255),attribute,shipment,Additional shipment information,Shipments,\r\n"
  },
  {
    "path": "example/config.yaml",
    "content": "organization: MyCompany\ndialect: sqlite\nbi_config_file: example/bi.yaml\n\npython_executor: docker\n\n# Visualization configuration\nvisualization_mode: llm\n\n# Catalog store configuration\ncatalog_store:\n  store_type: file_system\n  data_path: ./example\n\n# Data warehouse configuration\ndata_warehouse_config:\n  # sqlite from spider->tracking_orders dataset\n  uri: \"sqlite:///example/tracking_orders.sqlite\"\n  database_name: \"\"\n\n# LLM configurations\n# Use the OpenAI LLM; replace YOUR_API_KEY_HERE with your actual API key\ndefault_llm:\n  class: langchain_openai.ChatOpenAI\n  params:\n    api_key: YOUR_API_KEY_HERE\n    model: gpt-4.1\n    temperature: 0.01\n    max_tokens: 8192\n\nembedding_model:\n  class: langchain_openai.OpenAIEmbeddings\n  params:\n    api_key: YOUR_API_KEY_HERE\n    model: text-embedding-3-large\n    chunk_size: 1024\n\n# If you cannot access OpenAI or another cloud LLM provider,\n# uncomment the following lines instead to use a local Ollama LLM\n#default_llm:\n#  class: langchain_ollama.ChatOllama\n#  params:\n#    model: gpt-oss:20b\n#    temperature: 0.01\n#    num_predict: 8192\n"
  },
  {
    "path": "example/sql_example.yaml",
    "content": "'':\n  Customers: |\n    Q: Show me all customers with their names and details\n    A: SELECT customer_id, customer_name, customer_details \n    FROM Customers \n    ORDER BY customer_name\n  Invoices: |\n    Q: List all invoices from the last 30 days\n    A: SELECT invoice_number, invoice_date, invoice_details \n    FROM Invoices \n    WHERE invoice_date >= DATE('now', '-30 days') \n    ORDER BY invoice_date DESC\n    \n  Order_Items: |\n    Q: Show me all items in order 123\n    A: SELECT oi.order_item_id, p.product_name, oi.order_item_status, oi.order_item_details \n    FROM Order_Items oi \n    JOIN Products p ON oi.product_id = p.product_id \n    WHERE oi.order_id = 123\n  Orders: |\n    Q: Find all pending orders with customer information\n    A: SELECT o.order_id, c.customer_name, o.order_status, o.date_order_placed \n    FROM Orders o \n    JOIN Customers c ON o.customer_id = c.customer_id \n    WHERE o.order_status = 'pending' \n    ORDER BY o.date_order_placed\n  Products: |\n    Q: Search for products containing 'laptop' in the name\n    A: SELECT product_id, product_name, product_details \n    FROM Products \n    WHERE product_name LIKE '%laptop%' \n    ORDER BY product_name\n  Shipment_Items: |\n    Q: Show which order items are in shipment 456\n    A: SELECT si.shipment_id, si.order_item_id, p.product_name \n    FROM Shipment_Items si \n    JOIN Order_Items oi ON si.order_item_id = oi.order_item_id \n    JOIN Products p ON oi.product_id = p.product_id \n    WHERE si.shipment_id = 456\n  Shipments: |\n    Q: Track all shipments for order 789\n    A: SELECT shipment_id, shipment_tracking_number, shipment_date, other_shipment_details \n    FROM Shipments \n    WHERE order_id = 789 \n    ORDER BY shipment_date\n"
  },
  {
    "path": "example/table_columns.csv",
    "content": "db_name,table_name,column_name\r\n,Customers,customer_id\r\n,Customers,customer_name\r\n,Customers,customer_details\r\n,Invoices,invoice_number\r\n,Invoices,invoice_date\r\n,Invoices,invoice_details\r\n,Order_Items,order_item_id\r\n,Order_Items,product_id\r\n,Order_Items,order_id\r\n,Order_Items,order_item_status\r\n,Order_Items,order_item_details\r\n,Orders,order_id\r\n,Orders,customer_id\r\n,Orders,order_status\r\n,Orders,date_order_placed\r\n,Orders,order_details\r\n,Products,product_id\r\n,Products,product_name\r\n,Products,product_details\r\n,Shipment_Items,shipment_id\r\n,Shipment_Items,order_item_id\r\n,Shipments,shipment_id\r\n,Shipments,order_id\r\n,Shipments,invoice_number\r\n,Shipments,shipment_tracking_number\r\n,Shipments,shipment_date\r\n,Shipments,other_shipment_details\r\n"
  },
  {
    "path": "example/table_info.yaml",
    "content": "? ''\n: Customers:\n    description: 'Contains customer information including unique ID, name, and additional details'\n    selection_rule: 'Select when queries involve customer information, customer names, or need to join orders with customer data'\n    sql_rule: 'Use customer_id as primary key for joins. Always include customer_name when displaying customer information'\n  Invoices:\n    description: 'Stores invoice information with unique invoice numbers, dates, and details'\n    selection_rule: 'Select when queries involve billing, invoice tracking, or financial reporting'\n    sql_rule: 'Use invoice_number as primary key. Filter by invoice_date for temporal queries'\n  Order_Items:\n    description: 'Links products to orders with individual item status and details'\n    selection_rule: 'Select when queries need product details within orders or item-level status tracking'\n    sql_rule: 'Always join with Products table via product_id and Orders table via order_id for complete information'\n  Orders:\n    description: 'Main order table containing order status, placement date, and customer relationships'\n    selection_rule: 'Select when queries involve order status, order history, or customer order relationships'\n    sql_rule: 'Use order_id as primary key. Join with Customers via customer_id for customer information'\n  Products:\n    description: 'Product catalog containing product names, IDs, and detailed product information'\n    selection_rule: 'Select when queries involve product information, product searches, or inventory-related questions'\n    sql_rule: 'Use product_id as primary key. Use LIKE operator for product_name searches'\n  Shipment_Items:\n    description: 'Junction table linking shipments to specific order items'\n    selection_rule: 'Select when queries need to track which specific items are in which shipments'\n    sql_rule: 'Always join with both Shipments and Order_Items tables. 
No primary key - composite key of shipment_id and order_item_id'\n  Shipments:\n    description: 'Shipment tracking information including tracking numbers, dates, and shipment details'\n    selection_rule: 'Select when queries involve shipping, delivery tracking, or fulfillment information'\n    sql_rule: 'Use shipment_id as primary key. Join with Orders via order_id and Invoices via invoice_number for complete shipping context'\n"
  },
  {
    "path": "example/table_selection_example.csv",
    "content": "question,selected_tables\r\n\"Show me all customers\",\"[\"\"Customers\"\"]\"\r\n\"What orders were placed today?\",\"[\"\"Orders\"\"]\"\r\n\"List all products and their details\",\"[\"\"Products\"\"]\"\r\n\"Show me customer orders with their details\",\"[\"\"Customers\"\", \"\"Orders\"\"]\"\r\n\"What products are in each order?\",\"[\"\"Orders\"\", \"\"Order_Items\"\", \"\"Products\"\"]\"\r\n\"Show shipment tracking information\",\"[\"\"Shipments\"\"]\"\r\n\"Which items are in each shipment?\",\"[\"\"Shipments\"\", \"\"Shipment_Items\"\", \"\"Order_Items\"\"]\"\r\n\"Show order status and customer information\",\"[\"\"Orders\"\", \"\"Customers\"\"]\"\r\n\"What invoices were created this month?\",\"[\"\"Invoices\"\"]\"\r\n\"Show complete order fulfillment chain\",\"[\"\"Orders\"\", \"\"Order_Items\"\", \"\"Products\"\", \"\"Shipments\"\", \"\"Invoices\"\"]\"\r\n"
  },
  {
    "path": "openchatbi/__init__.py",
    "content": "\"\"\"OpenChatBI core module initialization.\"\"\"\n\nimport os\n\nfrom langgraph.graph.state import CompiledStateGraph\n\nfrom openchatbi.config_loader import ConfigLoader\n\n# Global configuration instance\nconfig = ConfigLoader()\n# Skip config loading during documentation build\nif not os.environ.get(\"SPHINX_BUILD\"):\n    config.load()\nelse:\n    config.set({})\n\n\ndef get_default_graph():\n    \"\"\"\n    Build the synchronous mode of the agent graph using default catalog in config.\n\n    Returns:\n        CompiledStateGraph: Compiled agent graph ready for execution.\n    \"\"\"\n    if os.environ.get(\"SPHINX_BUILD\"):\n        return None\n\n    from langgraph.checkpoint.memory import MemorySaver\n\n    from openchatbi.agent_graph import build_agent_graph_sync\n    from openchatbi.tool.memory import get_sync_memory_store\n\n    checkpointer = MemorySaver()\n    return build_agent_graph_sync(\n        config.get().catalog_store, checkpointer=checkpointer, memory_store=get_sync_memory_store()\n    )\n"
  },
  {
    "path": "openchatbi/agent_graph.py",
    "content": "\"\"\"Main agent graph construction and execution logic.\"\"\"\n\nimport datetime\nimport logging\nimport traceback\nfrom collections.abc import Callable\nfrom typing import Any\n\nfrom langchain_core.language_models import BaseChatModel\nfrom langchain_core.messages import AIMessage, HumanMessage, SystemMessage\nfrom langchain_core.tools import StructuredTool\nfrom langchain_openai.chat_models.base import BaseChatOpenAI\nfrom langgraph.constants import START\nfrom langgraph.errors import GraphInterrupt\nfrom langgraph.graph import END, StateGraph\nfrom langgraph.graph.state import CompiledStateGraph\nfrom langgraph.prebuilt import ToolNode\nfrom langgraph.store.base import BaseStore\nfrom langgraph.types import Checkpointer, Send, interrupt\nfrom pydantic import BaseModel, Field\n\nfrom openchatbi import config\nfrom openchatbi.catalog import CatalogStore\nfrom openchatbi.constants import datetime_format\nfrom openchatbi.context_config import get_context_config\nfrom openchatbi.context_manager import ContextManager\nfrom openchatbi.graph_state import AgentState, InputState, OutputState\nfrom openchatbi.llm.llm import call_llm_chat_model_with_retry, get_llm\nfrom openchatbi.prompts.system_prompt import get_agent_prompt_template\nfrom openchatbi.text2sql.sql_graph import build_sql_graph\nfrom openchatbi.tool.ask_human import AskHuman\nfrom openchatbi.tool.mcp_tools import create_mcp_tools_sync, get_mcp_tools_async\nfrom openchatbi.tool.memory import get_memory_tools\nfrom openchatbi.tool.run_python_code import run_python_code\nfrom openchatbi.tool.save_report import save_report\nfrom openchatbi.tool.search_knowledge import search_knowledge, show_schema\nfrom openchatbi.tool.timeseries_forecast import check_forecast_service_health, timeseries_forecast\nfrom openchatbi.utils import log, recover_incomplete_tool_calls\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_mcp_servers():\n    \"\"\"Get MCP servers from config with fallback for 
tests.\"\"\"\n    try:\n        return config.get().mcp_servers\n    except ValueError:\n        return []\n\n\ndef ask_human(state: AgentState) -> dict[str, Any]:\n    \"\"\"Node function to ask human for additional information or clarification.\n\n    Args:\n        state (AgentState): The current graph state containing messages and context.\n\n    Returns:\n        dict: Updated state with human feedback as a tool message and user input.\n    \"\"\"\n    tool_call = state[\"messages\"][-1].tool_calls[0]\n    tool_call_id = tool_call[\"id\"]\n    args = tool_call[\"args\"]\n    user_feedback = interrupt({\"text\": args[\"question\"], \"buttons\": args.get(\"options\", None)})\n    tool_message = [{\"tool_call_id\": tool_call_id, \"type\": \"tool\", \"content\": user_feedback}]\n    return {\n        \"messages\": tool_message,\n        \"history_messages\": [AIMessage(args[\"question\"]), HumanMessage(user_feedback)],\n        \"user_input\": user_feedback,\n    }\n\n\nclass CallSQLGraphInput(BaseModel):\n    reasoning: str = Field(\n        description=\"Explanation of why Text2SQL tool is needed\",\n    )\n    context: str = Field(\n        description=\"\"\"The full context pass to Text2SQL tool, make sure do not miss any potential information that related to user's question.\n        Following the format: History Conversation: (user and assistant history dialog)\n        Information: (the knowledge you retrival that is relevant, like metrics and dimensions)\n        User's latest question:\"\"\",\n    )\n\n\n# Description for SQL tools\nTEXT2SQL_TOOL_DESCRIPTION = \"\"\"Text2SQL tool to generate and execute SQL query and build visualization DSL for UI\nbased on user's question and context.\n\nReturns:\n    str: A formatted response containing SQL, data, and visualization status.\n\nImportant notes:\n- If user want to change the visualization chart type or style, add the requirement in the question\n- Make sure to provide question in English\n\"\"\"\n\n\ndef 
_format_sql_response(sql_graph_response: dict) -> str:\n    \"\"\"Format SQL graph response into a standardized string format.\n\n    Args:\n        sql_graph_response: The response dictionary from the SQL graph\n\n    Returns:\n        str: Formatted response string\n    \"\"\"\n    sql = sql_graph_response.get(\"sql\", \"\")\n    data = sql_graph_response.get(\"data\", \"\")\n    visualization_dsl = sql_graph_response.get(\"visualization_dsl\", {})\n\n    response_parts = []\n    if sql:\n        response_parts.append(f\"SQL Query:\\n```sql\\n{sql}\\n```\")\n    if data:\n        response_parts.append(f\"\\nQuery Results (CSV format):\\n```csv\\n{data}\\n```\")\n\n    # Include visualization status\n    if visualization_dsl and \"error\" not in visualization_dsl:\n        chart_type = visualization_dsl.get(\"chart_type\", \"unknown\")\n        response_parts.append(\n            f\"\\nVisualization Created: {chart_type} chart has been automatically generated and will be displayed in the UI.\"\n        )\n    elif visualization_dsl and \"error\" in visualization_dsl:\n        response_parts.append(f\"\\nVisualization Error: {visualization_dsl['error']}\")\n\n    return \"\\n\\n\".join(response_parts) if response_parts else \"No results returned.\"\n\n\ndef get_sql_tools(sql_graph: CompiledStateGraph, sync_mode: bool = False) -> Callable:\n    \"\"\"Create SQL generation tool from compiled SQL graph.\n\n    Args:\n        sql_graph (CompiledStateGraph): The compiled SQL generation subgraph.\n        sync_mode (bool): Whether to create synchronous or asynchronous tools\n\n    Returns:\n        function: Tool function for SQL generation.\n    \"\"\"\n\n    def call_sql_graph_sync(reasoning: str, context: str) -> str:\n        \"\"\"Sync node function for Text2SQL tool\"\"\"\n        log(f\"Call SQL graph (sync) with reasoning: {reasoning}, context: {context}\")\n        try:\n            sql_graph_response = sql_graph.invoke({\"messages\": context})\n            
return _format_sql_response(sql_graph_response)\n        except GraphInterrupt as e:\n            log(f\"Sql graph interrupted:\\n{repr(e)}\")\n            raise e\n        except Exception as e:\n            log(f\"Run sql graph error:\\n{repr(e)}\")\n            traceback.print_exc()\n        return \"Error occurred when calling Text2SQL tool.\"\n\n    async def call_sql_graph_async(reasoning: str, context: str) -> str:\n        \"\"\"Async node function for Text2SQL tool\"\"\"\n        log(f\"Call SQL graph (async) with reasoning: {reasoning}, context: {context}\")\n        try:\n            sql_graph_response = await sql_graph.ainvoke({\"messages\": context})\n            return _format_sql_response(sql_graph_response)\n        except GraphInterrupt as e:\n            log(f\"Sql graph interrupted:\\n{repr(e)}\")\n            raise e\n        except Exception as e:\n            log(f\"Run sql graph error:\\n{repr(e)}\")\n            traceback.print_exc()\n        return \"Error occurred when calling Text2SQL tool.\"\n\n    if sync_mode:\n        return StructuredTool.from_function(\n            func=call_sql_graph_sync,\n            name=\"text2sql\",\n            description=TEXT2SQL_TOOL_DESCRIPTION,\n            args_schema=CallSQLGraphInput,\n            return_direct=False,\n        )\n    else:\n        return StructuredTool.from_function(\n            coroutine=call_sql_graph_async,\n            name=\"text2sql\",\n            description=TEXT2SQL_TOOL_DESCRIPTION,\n            args_schema=CallSQLGraphInput,\n            return_direct=False,\n        )\n\n\ndef agent_llm_call(llm: BaseChatModel, tools: list, context_manager: ContextManager = None) -> Callable:\n    \"\"\"Create llm call function to generate reasoning and determine next node based on tool calls in LLM response.\n\n    Args:\n        llm (BaseChatModel): The LLM for agent decision-making.\n        tools: List of tools.\n        context_manager: Optional context manager for handling long 
conversations.\n\n    Returns:\n        function: function that processes state and determines next node.\n    \"\"\"\n\n    # OpenAI models support strict tool calling\n    if isinstance(llm, BaseChatOpenAI):\n        llm_with_tools = llm.bind_tools(tools, strict=True)\n    else:\n        llm_with_tools = llm.bind_tools(tools)\n\n    def _call_model(state: AgentState):\n        # First, check and recover any incomplete tool calls\n        recovery_ops = recover_incomplete_tool_calls(state)\n        if recovery_ops:\n            return {\"messages\": recovery_ops, \"agent_next_node\": \"llm_node\"}\n\n        messages = state[\"messages\"]\n        final_messages = []\n        if isinstance(messages[-1], HumanMessage):\n            final_messages.append(messages[-1])\n\n        # Apply context management if available (before processing)\n        if context_manager:\n            original_count = len(messages)\n            context_manager.manage_context_messages(messages)\n            if len(messages) != original_count:\n                logger.info(f\"Context management: modified messages from {original_count} to {len(messages)}\")\n\n        system_prompt = get_agent_prompt_template().replace(\n            \"[time_field_placeholder]\", datetime.datetime.now().strftime(datetime_format)\n        )\n\n        response = call_llm_chat_model_with_retry(\n            llm_with_tools,\n            ([SystemMessage(system_prompt)] + messages),\n            streaming_tokens=True,\n            bound_tools=tools,\n            parallel_tool_call=True,\n        )\n        if isinstance(response, AIMessage):\n            tool_calls = response.tool_calls\n            print(\"Tool Call:\", \", \".join(tool[\"name\"] for tool in tool_calls))\n            if tool_calls:\n                # Group tool calls by type for parallel routing\n                ask_human_calls = [call for call in tool_calls if call[\"name\"] == \"AskHuman\"]\n                normal_tool_calls = [call for call in 
tool_calls if call[\"name\"] != \"AskHuman\"]\n\n                # Create Send objects for parallel routing\n                sends = []\n                if ask_human_calls:\n                    # Create message with only AskHuman calls\n                    ask_human_msg = AIMessage(content=response.content, tool_calls=ask_human_calls)\n                    sends.append(Send(\"ask_human\", {\"messages\": [ask_human_msg]}))\n\n                if normal_tool_calls:\n                    # Create message with only normal tool calls\n                    tool_msg = AIMessage(content=response.content, tool_calls=normal_tool_calls)\n                    sends.append(Send(\"use_tool\", {\"messages\": [tool_msg]}))\n\n                return {\"messages\": [response], \"history_messages\": final_messages, \"sends\": sends}\n            else:\n                final_messages.append(AIMessage(response.content))\n                return {\n                    \"messages\": [response],\n                    \"final_answer\": response.content,\n                    \"history_messages\": final_messages,\n                    \"agent_next_node\": END,\n                }\n        elif response is None:\n            return {\n                \"messages\": [AIMessage(\"Sorry, the LLM service is currently unavailable.\")],\n                \"history_messages\": final_messages,\n                \"agent_next_node\": END,\n            }\n        else:\n            return {\"messages\": [response], \"history_messages\": final_messages, \"agent_next_node\": END}\n\n    return _call_model\n\n\ndef _build_graph_core(\n    catalog: CatalogStore,\n    sync_mode: bool,\n    checkpointer: Checkpointer,\n    memory_store: BaseStore,\n    memory_tools: list[Callable] | None,\n    mcp_tools: list,\n    enable_context_management: bool = True,\n    llm_provider: str | None = None,\n) -> CompiledStateGraph:\n    \"\"\"Core graph building logic shared by both sync and async versions.\n\n    Args:\n        
catalog: Catalog store containing schema information\n        sync_mode: Whether to use synchronous mode for tools and operations\n        checkpointer: The Checkpointer for state persistence\n        memory_store: The BaseStore to use for long-term memory\n        memory_tools: List of memory tools (manage_memory_tool, search_memory_tool)\n        mcp_tools: Pre-initialized MCP tools\n        enable_context_management: Whether to enable context management\n\n    Returns:\n        CompiledStateGraph: Compiled agent graph ready for execution\n    \"\"\"\n    sql_graph = build_sql_graph(catalog, checkpointer, memory_store, llm_provider=llm_provider)\n    call_sql_graph_tool = get_sql_tools(sql_graph=sql_graph, sync_mode=sync_mode)\n\n    # Use provided memory tools or create them\n    if not memory_tools:\n        memory_tools = get_memory_tools(get_llm(llm_provider), sync_mode=sync_mode, store=memory_store)\n\n    log(str(mcp_tools))\n    normal_tools = [\n        search_knowledge,\n        show_schema,\n        call_sql_graph_tool,\n        run_python_code,\n        save_report,\n    ]\n    if memory_tools:\n        normal_tools.extend(memory_tools)\n    if check_forecast_service_health():\n        normal_tools.append(timeseries_forecast)\n    else:\n        logger.warning(\"Time series forecasting service is not healthy. 
Skipping timeseries_forecast tool.\")\n    normal_tools.extend(mcp_tools)\n\n    # Initialize context manager if enabled\n    context_manager = None\n    if enable_context_management:\n        context_manager = ContextManager(llm=get_llm(llm_provider), config=get_context_config())\n\n    tool_node = ToolNode(normal_tools)\n\n    # Define the agent graph\n    graph = StateGraph(AgentState, input_schema=InputState, output_schema=OutputState)\n\n    # Add nodes to the graph\n    graph.add_node(\"llm_node\", agent_llm_call(get_llm(llm_provider), normal_tools + [AskHuman], context_manager))\n    graph.add_node(\"ask_human\", ask_human)\n    graph.add_node(\"use_tool\", tool_node)\n\n    # Add edges between nodes\n    graph.add_edge(START, \"llm_node\")\n    graph.add_edge(\"ask_human\", \"llm_node\")\n    graph.add_edge(\"use_tool\", \"llm_node\")\n\n    # Add conditional routing from llm node\n    def route_tools(state: AgentState):\n        # Only use sends if the last message came from the llm node (has tool_calls)\n        last_message = state[\"messages\"][-1] if state[\"messages\"] else None\n        if (\n            last_message\n            and isinstance(last_message, AIMessage)\n            and last_message.tool_calls\n            and \"sends\" in state\n            and state[\"sends\"]\n        ):\n            return state[\"sends\"]  # Return Send objects for parallel execution\n        elif \"agent_next_node\" in state:\n            return state[\"agent_next_node\"]  # Return single node name\n        else:\n            return END\n\n    graph.add_conditional_edges(\n        \"llm_node\",\n        route_tools,\n        # mapping of paths to node names (for single routing)\n        {\n            \"llm_node\": \"llm_node\",\n            \"ask_human\": \"ask_human\",\n            \"use_tool\": \"use_tool\",\n            END: END,\n        },\n    )\n\n    graph = graph.compile(name=\"agent_graph\", checkpointer=checkpointer, store=memory_store)\n    return 
graph\n\n\ndef build_agent_graph_sync(\n    catalog: CatalogStore,\n    checkpointer: Checkpointer = None,\n    memory_store: BaseStore = None,\n    enable_context_management: bool = True,\n    llm_provider: str | None = None,\n) -> CompiledStateGraph:\n    \"\"\"Build the main agent graph with all nodes and edges (sync version).\n\n    Args:\n        catalog: Catalog store containing schema information.\n        checkpointer: The Checkpointer for state persistence (short memory). If None, no short memory.\n        memory_store: The BaseStore to use for long-term memory. If None, will auto assign according to sync_mode.\n        enable_context_management: Whether to enable context management for long conversations.\n\n    Returns:\n        CompiledStateGraph: Compiled agent graph ready for execution.\n    \"\"\"\n    # Get MCP tools for sync context\n    mcp_tools = create_mcp_tools_sync(get_mcp_servers())\n\n    return _build_graph_core(\n        catalog=catalog,\n        sync_mode=True,\n        checkpointer=checkpointer,\n        memory_store=memory_store,\n        memory_tools=None,  # Always None for sync version - creates its own\n        mcp_tools=mcp_tools,\n        enable_context_management=enable_context_management,\n        llm_provider=llm_provider,\n    )\n\n\nasync def build_agent_graph_async(\n    catalog: CatalogStore,\n    checkpointer: Checkpointer = None,\n    memory_store: BaseStore = None,\n    memory_tools: list[Callable] = None,\n    enable_context_management: bool = True,\n    llm_provider: str | None = None,\n) -> CompiledStateGraph:\n    \"\"\"Build the main agent graph with all nodes and edges (async version).\n\n    This function is identical to build_agent_graph_sync but properly handles\n    async MCP tool initialization when called from async contexts.\n\n    Args:\n        catalog: Catalog store containing schema information.\n        checkpointer: The Checkpointer for state persistence (short memory). 
If None, no short memory.\n        memory_store: The BaseStore to use for long-term memory. If None, will auto assign according to sync_mode.\n        memory_tools: List of memory tools (manage_memory_tool, search_memory_tool). If None, creates async tools.\n        enable_context_management: Whether to enable context management for long conversations.\n\n    Returns:\n        CompiledStateGraph: Compiled agent graph ready for execution.\n    \"\"\"\n    # Get MCP tools for async context\n    mcp_tools = await get_mcp_tools_async(get_mcp_servers())\n\n    return _build_graph_core(\n        catalog=catalog,\n        sync_mode=False,\n        checkpointer=checkpointer,\n        memory_store=memory_store,\n        memory_tools=memory_tools,\n        mcp_tools=mcp_tools,\n        enable_context_management=enable_context_management,\n        llm_provider=llm_provider,\n    )\n"
  },
  {
    "path": "openchatbi/catalog/__init__.py",
    "content": "\"\"\"Data catalog management module for OpenChatBI.\"\"\"\n\nfrom openchatbi.catalog.catalog_loader import (\n    DataCatalogLoader,\n    load_catalog_from_data_warehouse,\n)\nfrom openchatbi.catalog.catalog_store import CatalogStore\nfrom openchatbi.catalog.factory import create_catalog_store\n\n__all__ = [\n    \"CatalogStore\",\n    \"DataCatalogLoader\",\n    \"load_catalog_from_data_warehouse\",\n]\n"
  },
  {
    "path": "openchatbi/catalog/catalog_loader.py",
    "content": "import logging\nfrom typing import Any\n\nfrom sqlalchemy import MetaData, inspect\nfrom sqlalchemy.engine import Engine\n\nfrom .catalog_store import CatalogStore\n\nlogger = logging.getLogger(__name__)\n\n\nclass DataCatalogLoader:\n    \"\"\"\n    The loader to load data catalog from data warehouse metadata and save to catalog store.\n    \"\"\"\n\n    def __init__(self, engine: Engine, include_tables: list[str] | None = None):\n        \"\"\"\n        Initialize catalog loader.\n\n        Args:\n            engine (Engine): SQLAlchemy engine instance\n            include_tables (Optional[List[str]]): List of table names to include, None for all\n        \"\"\"\n        self.engine = engine\n        self.include_tables = include_tables\n        self.metadata = MetaData()\n        self.inspector = inspect(engine)\n\n    def get_tables_and_columns(self) -> dict[str, list[dict[str, Any]]]:\n        \"\"\"\n        Extract table and column metadata including comments using SQLAlchemy inspector.\n\n        Returns:\n            Dict[str, List[Dict[str, Any]]]: Dictionary mapping table names to list of column information\n        \"\"\"\n        try:\n            tables_columns = {}\n\n            # Get all table names\n            table_names = self.inspector.get_table_names()\n\n            # Filter to specific tables if configured\n            if self.include_tables:\n                table_names = [name for name in table_names if name in self.include_tables]\n\n            logger.info(f\"Found {len(table_names)} tables to process\")\n\n            for table_name in table_names:\n                try:\n                    # Get column information for the table\n                    columns = self.inspector.get_columns(table_name)\n                    column_list = []\n                    for column in columns:\n                        is_common_column = column not in (\"id\", \"name\", \"type\", \"status\")\n                        column_info = {\n    
                        \"column_name\": column[\"name\"],\n                            \"display_name\": \"\",\n                            \"alias\": \"\",\n                            \"type\": str(column[\"type\"]),\n                            \"category\": \"\",\n                            \"tag\": \"\",\n                            \"description\": column.get(\"comment\", \"\") or \"\",\n                            \"dimension_table\": \"\",\n                            \"default\": str(column.get(\"default\", \"\")) if column.get(\"default\") is not None else \"\",\n                            \"is_common\": is_common_column,\n                        }\n                        column_list.append(column_info)\n\n                    tables_columns[table_name] = column_list\n                    logger.debug(f\"Processed table {table_name} with {len(column_list)} columns\")\n\n                except Exception as e:\n                    logger.error(f\"Failed to process table {table_name}: {e}\")\n                    continue\n\n            logger.info(f\"Successfully processed {len(tables_columns)} tables\")\n            return tables_columns\n\n        except Exception as e:\n            logger.error(f\"Failed to get tables and columns from data warehouse: {e}\")\n            return {}\n\n    def get_table_indexes(self, table_name: str) -> list[dict[str, Any]]:\n        \"\"\"\n        Get index information for a specific table.\n\n        Args:\n            table_name (str): Name of the table\n\n        Returns:\n            List[Dict[str, Any]]: List of index information\n        \"\"\"\n        try:\n            indexes = self.inspector.get_indexes(table_name)\n            return indexes\n        except Exception as e:\n            logger.warning(f\"Failed to get indexes for table {table_name}: {e}\")\n            return []\n\n    def get_foreign_keys(self, table_name: str) -> list[dict[str, Any]]:\n        \"\"\"\n        Get foreign key information for a 
specific table.\n\n        Args:\n            table_name (str): Name of the table\n\n        Returns:\n            List[Dict[str, Any]]: List of foreign key information\n        \"\"\"\n        try:\n            foreign_keys = self.inspector.get_foreign_keys(table_name)\n            return foreign_keys\n        except Exception as e:\n            logger.warning(f\"Failed to get foreign keys for table {table_name}: {e}\")\n            return []\n\n    def save_to_catalog_store(\n        self, catalog_store: CatalogStore, database_name: str | None = None, update: bool = False\n    ) -> bool:\n        \"\"\"\n        Extract warehouse metadata and save to catalog store.\n\n        Args:\n            catalog_store (CatalogStore): Target catalog store to load data to\n            database_name (Optional[str]): Database name in catalog, defaults to 'default'\n            update (bool): Update existing catalog store to sync with data warehouse\n\n        Returns:\n            bool: True if load was successful, False otherwise\n        \"\"\"\n        try:\n            if database_name is None:\n                database_name = \"default\"\n\n            # Get tables and columns from data warehouse\n            tables_columns = self.get_tables_and_columns()\n\n            if not tables_columns:\n                logger.warning(\"No tables found in data warehouse\")\n                return True\n\n            # Import each table\n            success_count = 0\n            total_count = len(tables_columns)\n\n            for table_name, columns in tables_columns.items():\n                try:\n                    # Get table comment if available\n                    table_comment = \"\"\n                    try:\n                        table_info = self.inspector.get_table_comment(table_name)\n                        table_comment = table_info.get(\"text\", \"\") if table_info else \"\"\n                    except Exception:\n                        # Some databases don't 
support table comments\n                        pass\n\n                    table_info = {\"description\": table_comment, \"selection_rule\": \"\", \"sql_rule\": \"\"}\n                    if catalog_store.save_table_information(table_name, table_info, columns, database_name):\n                        success_count += 1\n                        logger.info(f\"Successfully loaded table: {database_name}.{table_name}\")\n                    else:\n                        logger.error(f\"Failed to load table: {database_name}.{table_name}\")\n\n                    # init null SQL examples\n                    catalog_store.save_table_sql_examples(\n                        table_name, [{\"question\": \"null\", \"answer\": \"null\"}], database_name\n                    )\n\n                except Exception as e:\n                    logger.error(f\"Error loading table {table_name}: {e}\")\n\n            # init empty table selection examples\n            catalog_store.save_table_selection_examples([(\"\", [])])\n\n            logger.info(f\"Load completed: {success_count}/{total_count} tables loaded successfully\")\n            return success_count == total_count\n\n        except Exception as e:\n            logger.error(f\"Failed to load data warehouse to catalog store: {e}\")\n            return False\n\n\ndef load_catalog_from_data_warehouse(catalog_store: CatalogStore) -> bool:\n    \"\"\"\n    Load catalog data from data warehouse using SQLAlchemy based on data warehouse config (URI)\n\n    Main entry point for catalog loading.\n\n    Args:\n        catalog_store (CatalogStore): Target catalog store\n\n    Returns:\n        bool: True if load was successful, False otherwise\n    \"\"\"\n    try:\n        data_warehouse_config = catalog_store.get_data_warehouse_config()\n        database_uri = data_warehouse_config.get(\"uri\")\n        include_tables = data_warehouse_config.get(\"include_tables\")\n        database_name = data_warehouse_config.get(\"database_name\", 
\"default\")\n        engine = catalog_store.get_sql_engine()\n\n        loader = DataCatalogLoader(engine, include_tables)\n        return loader.save_to_catalog_store(catalog_store, database_name)\n\n    except Exception as e:\n        logger.error(f\"Failed to import catalog from data warehouse URI {database_uri}: {e}\")\n        return False\n"
  },
  {
    "path": "openchatbi/catalog/catalog_store.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Any\n\nfrom sqlalchemy import Engine\n\n\nclass CatalogStore(ABC):\n    \"\"\"\n    Abstract base class defining the storage interface for data catalog (database, table, column definitions, descriptions, and additional prompts).\n\n    Common columns which have same meanings across tables will be store centralized to avoid data duplication.\n\n    Column attribute:\n\n        - column_name: the name of the column\n        - display_name: the display name of the column\n        - type: the data type of the column\n        - category: dimension or metric\n        - description: the description of the column\n        - is_common: is common column or not\n    \"\"\"\n\n    @abstractmethod\n    def get_data_warehouse_config(self) -> dict:\n        \"\"\"\n        Get the data warehouse configuration\n\n        Returns:\n            dict: Data warehouse configuration\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_sql_engine(self) -> Engine:\n        \"\"\"\n        Get the SQLAlchemy engine for the catalog\n\n        Returns:\n            Engine: SQLAlchemy engine\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_database_list(self) -> list[str]:\n        \"\"\"\n        Get a list of all databases\n\n        Returns:\n            List[str]: List of database names\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_table_list(self, database: str | None = None) -> list[str]:\n        \"\"\"\n        Get a list of all tables in the specified database, if database is None, return all tables\n\n        Args:\n            database (Optional[str]): Database name\n\n        Returns:\n            List[str]: List of table names\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_column_list(self, table: str | None = None, database: str | None = None) -> list[dict[str, Any]]:\n        \"\"\"\n        Get all column information for the specified 
table, if table is None, return all common columns in the catalog\n\n        Args:\n            table (Optional[str]): Table name\n            database (Optional[str]): Database name\n\n        Returns:\n            List[Dict[str, Any]]: List of column information, each column contains name, type, description, etc.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_table_information(self, table: str, database: str | None = None) -> dict[str, Any]:\n        \"\"\"\n        Get the information for the specified table\n\n        Args:\n            table (str): Table name\n            database (Optional[str]): Database name\n\n        Returns:\n            Dict[str, Any]: Table information, including description text, selection rules, etc.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_sql_examples(\n        self, table: str | None = None, database: str | None = None\n    ) -> list[tuple[str, str, list[str]]]:\n        \"\"\"\n        Get SQL examples\n\n        Args:\n            table (Optional[str]): Table name\n            database (Optional[str]): Database name\n\n        Returns:\n            List[Tuple[str, str, List[str]]]: List of SQL examples, each example is a Tuple-3 as (question, SQL, full_table_names)\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_table_selection_examples(self) -> list[tuple[str, list[str]]]:\n        \"\"\"\n        Get table selection examples\n\n        Returns:\n            List[Tuple[str, List[str]]]: List of table selection examples, each example is a Tuple-2 as (question, selected tables)\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def save_table_information(\n        self,\n        table: str,\n        information: dict[str, Any],\n        columns: list[dict[str, Any]],\n        database: str | None = None,\n        update_existing: bool = False,\n    ) -> bool:\n        \"\"\"\n        Save the information and columns for a table\n\n        Args:\n            
table (str): Table name\n            information (Dict[str, Any]): Table information\n            columns (List[Dict[str, Any]]): List of column information, each column dict contains at least\n                column_name, type, category, description\n            database (Optional[str]): Database name\n            update_existing (bool): Update existing table and column information\n\n        Returns:\n            bool: Whether the save was successful\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def save_table_sql_examples(self, table: str, examples: list[dict[str, str]], database: str | None = None) -> bool:\n        \"\"\"\n        Save SQL examples for a table\n\n        Args:\n            table (str): Table name\n            examples (List[Dict[str, str]]): List of SQL examples\n            database (Optional[str]): Database name\n\n        Returns:\n            bool: Whether the save was successful\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def save_table_selection_examples(self, examples: list[tuple[str, list[str]]]) -> bool:\n        \"\"\"\n        Save table selection examples\n\n        Args:\n            examples (List[Tuple[str, List[str]]]): List of table selection examples\n\n        Returns:\n            bool: Whether the save was successful\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def check_exists(self) -> bool:\n        \"\"\"\n        Check if the catalog store has existing data/content\n\n        Returns:\n            bool: True if catalog store has existing data, False if empty or missing essential files\n        \"\"\"\n        pass\n\n\ndef split_db_table_name(table: str, database: str | None = None) -> tuple[str, str, str]:\n    \"\"\"\n    Split full table name into db name and table name\n    Args:\n        table (str): if database is None, should be full table name like `db.table`, otherwise should be only table name\n        database (Optional[str]): Database name\n    Returns:\n        
Tuple[str, str, str]: full_table_name, db_name, table_name\n\n    \"\"\"\n    full_table_name = table\n    if database is not None and \".\" not in table:\n        full_table_name = f\"{database}.{table}\"\n    if \".\" in full_table_name:\n        db_name, table_name = full_table_name.rsplit(\".\", 1)\n    else:\n        db_name = \"\"\n        table_name = full_table_name\n    return full_table_name, db_name, table_name\n"
  },
  {
    "path": "openchatbi/catalog/factory.py",
    "content": "import logging\nimport os\n\nfrom openchatbi.catalog.catalog_loader import load_catalog_from_data_warehouse\nfrom openchatbi.catalog.catalog_store import CatalogStore\nfrom openchatbi.catalog.store.file_system import FileSystemCatalogStore\n\nlogger = logging.getLogger(__name__)\n\n\n# Factory function for creating CatalogStore instances\ndef create_catalog_store(\n    store_type: str, auto_load: bool = True, data_warehouse_config: dict = None, **kwargs\n) -> CatalogStore:\n    \"\"\"\n    Create a CatalogStore instance\n\n    Args:\n        store_type (str): Storage type, supports 'file_system'\n        auto_load (bool): Whether to autoload from database if catalog files don't exist\n        data_warehouse_config (dict): Data warehouse configuration dictionary\n        **kwargs: Other parameters\n\n    Returns:\n        CatalogStore: CatalogStore instance\n\n    Raises:\n        ValueError: If the storage type is not supported\n    \"\"\"\n    if store_type == \"file_system\":\n        data_path = kwargs.get(\"data_path\", \"data\")\n        # convert relative path to absolute path\n        if not data_path.startswith(\"/\"):\n            data_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), data_path)\n        catalog_store = FileSystemCatalogStore(data_path, data_warehouse_config)\n\n        # Check if autoload is enabled and if catalog files are missing\n        if auto_load:\n            _auto_load_catalog_if_needed(catalog_store)\n\n        return catalog_store\n    else:\n        raise ValueError(f\"Unsupported storage type: {store_type}\")\n\n\ndef _auto_load_catalog_if_needed(catalog_store: CatalogStore) -> None:\n    \"\"\"\n    Autoload catalog from data warehouse if catalog files are missing or empty\n\n    Args:\n        catalog_store (CatalogStore): The catalog store instance\n    \"\"\"\n\n    # Check if catalog store has existing data using the store's own check_exists method\n    if not 
catalog_store.check_exists():\n        logger.info(\"Catalog files missing or empty, attempting to load from data warehouse...\")\n\n        try:\n            # Get data warehouse config from loaded configuration\n            data_warehouse_config = catalog_store.get_data_warehouse_config()\n            if not data_warehouse_config:\n                logger.warning(\"No data warehouse configuration found, skipping autoload\")\n                return\n\n            warehouse_uri = data_warehouse_config.get(\"uri\")\n            if not warehouse_uri:\n                logger.warning(\"No data warehouse URI found in configuration, skipping autoload\")\n                return\n\n            # load catalog from data warehouse\n            success = load_catalog_from_data_warehouse(catalog_store)\n\n            if success:\n                logger.info(\"Successfully loaded catalog from data warehouse\")\n            else:\n                logger.error(\"Failed to load catalog from data warehouse\")\n                raise Exception(\"Failed to load catalog from data warehouse\")\n\n        except Exception as e:\n            logger.warning(f\"Autoload from data warehouse failed: {e}\")\n            raise Exception(\"Failed to load catalog from data warehouse\") from e\n"
  },
  {
    "path": "openchatbi/catalog/helper.py",
    "content": "from typing import Any\n\nimport requests\nfrom sqlalchemy import Engine, create_engine\n\nfrom openchatbi.catalog.token_service import apply_token_for_user\nfrom openchatbi.utils import log\n\n\ndef get_requests_session(token: str, header_extra_params: dict) -> requests.Session:\n    \"\"\"Create HTTP session with bearer token authentication.\"\"\"\n    session = requests.Session()\n    session.headers.update({\"Authorization\": f\"Bearer {token}\"})\n    if header_extra_params:\n        session.headers.update(header_extra_params)\n    return session\n\n\ndef create_sqlalchemy_engine_instance(data_warehouse_config: dict[str, Any]) -> Engine:\n    \"\"\"\n    Create SQLAlchemy engine instance from data warehouse config\n\n    Args:\n        data_warehouse_config: Config dict with 'uri' and optional 'token_service'\n\n    Returns:\n        Configured SQLAlchemy engine\n    \"\"\"\n    database_uri = data_warehouse_config.get(\"uri\")\n\n    engine_args = {\"echo\": True}\n\n    # Handle Presto authentication\n    if \"presto\" in database_uri and \"token_service\" in data_warehouse_config:\n        token_service = data_warehouse_config.get(\"token_service\")\n        user_name = data_warehouse_config.get(\"user_name\")\n        password = data_warehouse_config.get(\"password\")\n        header_extra_params = data_warehouse_config.get(\"header_extra_params\", {})\n        token = apply_token_for_user(token_service, user_name, password)\n        log(f\"Applied presto token: {token} for user: {user_name}\")\n        engine_args[\"connect_args\"] = {\n            \"protocol\": \"https\",\n            \"requests_session\": get_requests_session(token, header_extra_params),\n        }\n        database_uri = database_uri.format(user_name=user_name)\n\n    engine = create_engine(database_uri, **engine_args)\n\n    return engine\n"
  },
  {
    "path": "openchatbi/catalog/retrival_helper.py",
    "content": "\"\"\"Helper functions for building column retrieval systems.\"\"\"\n\nfrom rank_bm25 import BM25Okapi\n\nfrom openchatbi.llm.llm import get_embedding_model\nfrom openchatbi.text_segmenter import _segmenter\nfrom openchatbi.utils import create_vector_db, log\n\n\ndef get_columns_metadata(catalog):\n    \"\"\"Extract column metadata for indexing.\n\n    Args:\n        catalog: Catalog store instance.\n\n    Returns:\n        tuple: (columns, col_dict, column_tokens, embedding_keys)\n    \"\"\"\n    columns = catalog.get_column_list()\n    col_dict = {}\n    column_tokens = []\n    embedding_keys = []\n    for column in columns:\n        col_dict[column[\"column_name\"]] = column\n        text_parts = [\n            column.get(\"column_name\", \"\"),\n            column.get(\"display_name\", \"\"),\n            column.get(\"alias\", \"\"),\n            column.get(\"tag\", \"\"),\n            column.get(\"description\", \"\"),\n        ]\n        text = \" \".join(text_parts)\n        tokens = [token for token in _segmenter.cut(text) if token not in (\"_\", \" \")]\n        column_tokens.append(tokens)\n        embedding_key = f\"{column['column_name']}: {column['display_name']}\"\n        embedding_keys.append(embedding_key)\n    return columns, col_dict, column_tokens, embedding_keys\n\n\ndef build_column_tables_mapping(catalog):\n    \"\"\"Build a mapping of column names to their corresponding table names.\"\"\"\n    column_tables_mapping = {}\n    for table_name in catalog.get_table_list():\n        for column in catalog.get_column_list(table_name):\n            column_name = column[\"column_name\"]\n            if column_name not in column_tables_mapping:\n                column_tables_mapping[column_name] = []\n            column_tables_mapping[column_name].append(table_name)\n    return column_tables_mapping\n\n\ndef build_columns_retriever(catalog, vector_db_path: str = None):\n    \"\"\"Build BM25 and vector retrievers for columns.\n\n    
Args:\n        catalog: Catalog store instance.\n        vector_db_path: Path to the vector database file.\n\n    Returns:\n        tuple: (bm25, vector_db, columns, col_dict)\n    \"\"\"\n    columns, col_dict, column_tokens, embedding_keys = get_columns_metadata(catalog)\n\n    bm25 = BM25Okapi(column_tokens)\n\n    log(\"Building vector database for columns...\")\n    vector_db = create_vector_db(\n        embedding_keys,\n        get_embedding_model(),\n        metadatas=columns,\n        collection_name=\"columns\",\n        collection_metadata={\"hnsw:space\": \"cosine\"},\n        chroma_db_path=vector_db_path,\n    )\n\n    return bm25, vector_db, columns, col_dict\n"
  },
  {
    "path": "openchatbi/catalog/schema_retrival.py",
    "content": "\"\"\"Schema and column retrieval functionality for finding relevant database structures.\"\"\"\n\nimport os\nimport re\n\nimport Levenshtein\n\nfrom openchatbi import config\nfrom openchatbi.catalog.retrival_helper import build_column_tables_mapping, build_columns_retriever\nfrom openchatbi.text_segmenter import _segmenter\nfrom openchatbi.utils import log\n\n# Skip build during documentation build\nif not os.environ.get(\"SPHINX_BUILD\"):\n    try:\n        _catalog_store = config.get().catalog_store\n    except ValueError:\n        _catalog_store = None\nelse:\n    _catalog_store = None\n\nif _catalog_store:\n    bm25, vector_db, columns, col_dict = build_columns_retriever(_catalog_store, config.get().vector_db_path)\n    column_tables_mapping = build_column_tables_mapping(_catalog_store)\nelse:\n    bm25, vector_db, columns, col_dict = None, None, [], {}\n    column_tables_mapping = {}\n\n\ndef column_retrieval(query, db, k=10, threshold=0.5, filter=None):\n    \"\"\"Retrieves relevant columns based on a similarity search.\n\n    Args:\n        query (str): The query string to search for.\n        db: The vector database to search in.\n        k (int, optional): The number of top results to return. Defaults to 10.\n        threshold (float, optional): The similarity threshold for filtering results. Defaults to 0.5.\n        filter (dict, optional): A filter to apply to the search. 
Defaults to None.\n\n    Returns:\n        list: List of relevant column names.\n    \"\"\"\n    log(f\"Get the top relevant columns for query: {query}\")\n    similar_column_key_scores = db.similarity_search_with_score(query, k=k, filter=filter)\n    # log(f\"similar_column_key_scores: {similar_column_key_scores}\")\n    column_names = [key.metadata[\"column_name\"] for (key, score) in similar_column_key_scores if score < threshold]\n    log(f\"Filtered relevant columns: {column_names}\")\n    return column_names\n\n\ndef merge_list(list1, list2):\n    return list(set(list1 + list2))\n\n\ndef edit_distance_score(key1, key2):\n    \"\"\"Calculate normalized edit distance score between two strings.\n\n    Returns:\n        float: Score between 0 (identical) and 1 (completely different).\n    \"\"\"\n    dist = Levenshtein.distance(key1, key2)\n    max_len = max(len(key1), len(key2))\n    return dist / max_len if max_len > 0 else 1\n\n\ndef edit_distance_search(keywords_list, top_k=10, threshold=0.5):\n    \"\"\"Searches for columns using edit distance similarity.\n\n    Args:\n        keywords_list (list): List of keywords to search for.\n        top_k (int, optional): The number of top results to return per keyword. Defaults to 10.\n        threshold (float, optional): The maximum edit distance score to consider. 
Defaults to 0.5.\n\n    Returns:\n        list: List of relevant column names.\n    \"\"\"\n    keys = set([re.sub(r\"(_id|_name| id| name)$\", \"\", key.lower()) for key in keywords_list])\n    column_similarity_score = set()\n    for key in keys:\n        key_column_similarity_score = {}\n        for column_name, row in col_dict.items():\n            column_name_score = edit_distance_score(\n                key, re.sub(r\"(_id|_name| id| name)$\", \"\", row.get(\"column_name\", \"\"))\n            )\n            display_score = edit_distance_score(\n                key, re.sub(r\"(_id|_name| id| name)$\", \"\", row.get(\"display_name\", \"\").lower())\n            )\n            if column_name_score < threshold or display_score < threshold:\n                key_column_similarity_score[column_name] = min(column_name_score, display_score)\n        key_top_column = [\n            key for key, _ in sorted(key_column_similarity_score.items(), key=lambda x: x[1], reverse=True)[:top_k]\n        ]\n        column_similarity_score.update(key_top_column)\n    return list(column_similarity_score)\n\n\ndef bm25_search(query_list, top_k=5, score_threshold=0.5):\n    \"\"\"Performs a BM25 search on columns based on the query.\n\n    Args:\n        query_list (list): List of query terms.\n        top_k (int, optional): The number of top results to return. Defaults to 5.\n        score_threshold (float, optional): The minimum BM25 score to consider. 
Defaults to 0.5.\n\n    Returns:\n        list: List of relevant column names.\n    \"\"\"\n    query_tokens = [token for token in _segmenter.cut(\" \".join(query_list)) if token not in (\"_\", \" \")]\n    scores = bm25.get_scores(query_tokens)\n    ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)\n    results = []\n    for idx, score in ranked[:top_k]:\n        if score_threshold and score < score_threshold:\n            continue\n        results.append(columns[idx][\"column_name\"])\n    return results\n\n\ndef get_relevant_columns(keywords_list, dimensions, metrics):\n    \"\"\"Get the most relevant columns for given keywords, dimensions, and metrics.\n\n    Uses multiple retrieval methods (BM25, edit distance, vector similarity)\n    to find the best matching columns.\n\n    Args:\n        keywords_list (list): General keywords to search for.\n        dimensions (list): Dimension-specific keywords.\n        metrics (list): Metric-specific keywords.\n\n    Returns:\n        list: Relevant column names.\n    \"\"\"\n    # 1. BM25 search for general keywords\n    total_results = bm25_search(keywords_list, top_k=len(keywords_list) * 4)\n\n    # 2. Edit distance search for exact matches\n    keyword_len = len(keywords_list + dimensions + metrics)\n    ed_results = edit_distance_search(keywords_list + dimensions + metrics, top_k=keyword_len, threshold=0.3)\n    total_results = merge_list(total_results, ed_results)\n\n    # 3. Vector similarity search for dimensions\n    if dimensions:\n        d_results = column_retrieval(\" \".join(dimensions), vector_db, k=10, filter={\"category\": \"dimension\"})\n        total_results = merge_list(total_results, d_results)\n\n    # 4. 
Vector similarity search for metrics\n    if metrics:\n        m_results = column_retrieval(\" \".join(metrics), vector_db, k=10, threshold=0.55, filter={\"category\": \"metric\"})\n        total_results = merge_list(total_results, m_results)\n\n    log(f\"Relevant columns: {total_results}\")\n    return total_results\n"
  },
  {
    "path": "openchatbi/catalog/store/__init__.py",
    "content": "\"\"\"Catalog store implementations.\"\"\"\n\nfrom .file_system import FileSystemCatalogStore\n"
  },
  {
    "path": "openchatbi/catalog/store/file_system.py",
    "content": "\"\"\"File system-based catalog store implementation.\"\"\"\n\nimport csv\nimport logging\nimport os\nimport re\nimport traceback\nfrom typing import Any\n\nimport yaml\nfrom sqlalchemy import Engine\n\nfrom ..catalog_store import CatalogStore, split_db_table_name\nfrom ..helper import create_sqlalchemy_engine_instance\n\nlogger = logging.getLogger(__name__)\n\n\nclass FileSystemCatalogStore(CatalogStore):\n    \"\"\"File system-based data catalog storage implementation.\n\n    Stores catalog data in CSV and YAML files on the local filesystem.\n    \"\"\"\n\n    data_path: str\n    table_info_file: str\n    sql_example_file: str\n    table_selection_example_file: str\n    table_columns_file: str\n    common_columns_file: str\n    table_spec_columns_file: str\n\n    _table_info_cache: dict | None\n    _table_columns_cache: dict | None\n    _common_columns_cache: dict | None\n    _table_spec_columns_cache: dict | None\n    _sql_example_cache: dict | None\n    _table_selection_example_cache: dict | None\n\n    _data_warehouse_config: dict\n    _sql_engine: Engine\n\n    def __init__(self, data_path: str, data_warehouse_config: dict):\n        \"\"\"Initialize filesystem catalog store.\n\n        Args:\n            data_path (str): Directory absolute path for storing catalog files.\n            data_warehouse_config (dict): Data warehouse configuration dictionary with keys:\n                - uri (str): Database connection URI\n                - include_tables (Optional[List[str]]): List of tables to include, if None include all\n                - database_name (Optional[str]): Database name to use in catalog\n        \"\"\"\n        if not isinstance(data_path, str) or not data_path.strip():\n            raise ValueError(\"data_path must be a non-empty string\")\n\n        if data_warehouse_config is None:\n            data_warehouse_config = {}\n        elif not isinstance(data_warehouse_config, dict):\n            raise 
ValueError(\"data_warehouse_config must be a dictionary\")\n\n        self.data_path = data_path.strip()\n        self.table_info_file = os.path.join(data_path, \"table_info.yaml\")\n        self.sql_example_file = os.path.join(data_path, \"sql_example.yaml\")\n        self.table_selection_example_file = os.path.join(data_path, \"table_selection_example.csv\")\n        self.table_columns_file = os.path.join(data_path, \"table_columns.csv\")\n        self.common_columns_file = os.path.join(data_path, \"common_columns.csv\")\n        self.table_spec_columns_file = os.path.join(data_path, \"table_spec_columns.csv\")\n\n        # Ensure directory exists with proper error handling\n        try:\n            os.makedirs(self.data_path, exist_ok=True)\n        except (OSError, PermissionError) as e:\n            raise RuntimeError(f\"Failed to create data directory '{self.data_path}': {e}\") from e\n\n        # Initialize cache\n        self._table_info_cache = None\n        self._table_columns_cache = None\n        self._common_columns_cache = None\n        self._table_spec_columns_cache = None\n        self._sql_example_cache = None\n        self._table_selection_example_cache = None\n\n        self._data_warehouse_config = data_warehouse_config\n        try:\n            self._sql_engine = create_sqlalchemy_engine_instance(data_warehouse_config)\n        except Exception as e:\n            logger.warning(f\"Failed to create SQL engine: {e}. 
Some catalog operations may not work.\")\n            self._sql_engine = None\n\n    def _clear_cache(self) -> None:\n        \"\"\"\n        Clear all cached data to ensure consistency after data modifications\n        \"\"\"\n        self._table_info_cache = None\n        self._table_columns_cache = None\n        self._common_columns_cache = None\n        self._table_spec_columns_cache = None\n        self._sql_example_cache = None\n        self._table_selection_example_cache = None\n        logger.debug(\"Cleared all caches\")\n\n    def get_data_warehouse_config(self) -> dict:\n        return self._data_warehouse_config\n\n    def get_sql_engine(self) -> Engine:\n        if self._sql_engine is None:\n            raise RuntimeError(\"SQL engine is not available. Check data warehouse configuration.\")\n        return self._sql_engine\n\n    def _validate_table_name(self, table: str) -> bool:\n        \"\"\"\n        Validate table name\n\n        Args:\n            table (str): Table name\n\n        Returns:\n            bool: Whether the table name is valid\n\n        Raises:\n            ValueError: If table name is invalid\n        \"\"\"\n        if not table or not isinstance(table, str):\n            raise ValueError(\"Table name must be a non-empty string\")\n\n        # Check for invalid characters (allow dots for db.table format)\n        invalid_chars = [\"/\", \"\\\\\", \"*\", \"?\", \"<\", \">\", \"|\", '\"', \"'\"]\n        if any(char in table for char in invalid_chars):\n            raise ValueError(f\"Table name contains invalid characters: {table}\")\n\n        return True\n\n    def _validate_column_data(self, columns: list[dict[str, Any]]) -> bool:\n        \"\"\"\n        Validate column data format\n\n        Args:\n            columns (List[Dict[str, Any]]): List of column information\n\n        Returns:\n            bool: Whether the column data is valid\n\n        Raises:\n            ValueError: If column data is invalid\n        \"\"\"\n 
       if not isinstance(columns, list):\n            raise ValueError(\"Columns must be a list\")\n\n        required_fields = {\"column_name\", \"type\"}\n\n        for i, column in enumerate(columns):\n            if not isinstance(column, dict):\n                raise ValueError(f\"Column {i} must be a dictionary\")\n\n            # Check required fields\n            missing_fields = required_fields - set(column.keys())\n            if missing_fields:\n                raise ValueError(f\"Column {i} missing required fields: {missing_fields}\")\n\n            # Validate column_name\n            column_name = column.get(\"column_name\")\n            if not isinstance(column_name, str) or not column_name.strip():\n                raise ValueError(f\"Column {i}: column_name must be a non-empty string\")\n\n            # Validate type\n            column_type = column.get(\"type\")\n            if not isinstance(column_type, str) or not column_type.strip():\n                raise ValueError(f\"Column {i}: type must be a non-empty string\")\n\n        return True\n\n    def _validate_table_information(self, information: dict[str, Any]) -> bool:\n        \"\"\"\n        Validate table information format\n\n        Args:\n            information (Dict[str, Any]): Table information\n\n        Returns:\n            bool: Whether the table information is valid\n\n        Raises:\n            ValueError: If table information is invalid\n        \"\"\"\n        if not isinstance(information, dict):\n            raise ValueError(\"Table information must be a dictionary\")\n\n        # Validate optional string fields\n        string_fields = [\"description\", \"selection_rule\"]\n        for field in string_fields:\n            if field in information:\n                value = information[field]\n                if value is not None and not isinstance(value, str):\n                    raise ValueError(f\"Table information field '{field}' must be a string or None\")\n\n        
return True\n\n    def _validate_sql_examples(self, examples: list[dict[str, str]]) -> bool:\n        \"\"\"\n        Validate SQL examples format\n\n        Args:\n            examples (List[Dict[str, str]]): List of SQL examples\n\n        Returns:\n            bool: Whether the SQL examples are valid\n\n        Raises:\n            ValueError: If SQL examples are invalid\n        \"\"\"\n        if not isinstance(examples, list):\n            raise ValueError(\"Examples must be a list\")\n\n        required_fields = {\"question\", \"answer\"}\n\n        for i, example in enumerate(examples):\n            if not isinstance(example, dict):\n                raise ValueError(f\"Example {i} must be a dictionary\")\n\n            # Check required fields\n            missing_fields = required_fields - set(example.keys())\n            if missing_fields:\n                raise ValueError(f\"Example {i} missing required fields: {missing_fields}\")\n\n            # Validate fields are non-empty strings\n            for field in required_fields:\n                value = example.get(field)\n                if not isinstance(value, str) or not value.strip():\n                    raise ValueError(f\"Example {i}: {field} must be a non-empty string\")\n\n        return True\n\n    @staticmethod\n    def _load_yaml_file(file_path: str) -> dict:\n        \"\"\"\n        Load YAML file\n\n        Args:\n            file_path (str): File path\n\n        Returns:\n            Dict: YAML content\n        \"\"\"\n        if not os.path.exists(file_path):\n            logger.debug(f\"YAML file does not exist: {file_path}\")\n            return {}\n\n        try:\n            with open(file_path, encoding=\"utf-8\") as f:\n                data = yaml.safe_load(f) or {}\n                logger.debug(f\"Successfully loaded YAML file: {file_path}\")\n                return data\n        except Exception as e:\n            logger.error(f\"Failed to load YAML file {file_path}: {e}\")\n        
    logger.error(traceback.format_stack())\n            return {}\n\n    @staticmethod\n    def _load_csv_file(file_path: str) -> list[dict[str, str]]:\n        \"\"\"\n        Load CSV file\n\n        Args:\n            file_path (str): File path\n\n        Returns:\n            List[Dict[str, str]]: List of rows as dictionaries\n        \"\"\"\n        if not os.path.exists(file_path):\n            logger.debug(f\"CSV file does not exist: {file_path}\")\n            return []\n\n        try:\n            result = []\n            with open(file_path, encoding=\"utf-8\") as f:\n                reader = csv.DictReader(f)\n                for row in reader:\n                    result.append(row)\n            logger.debug(f\"Successfully loaded CSV file: {file_path} with {len(result)} rows\")\n            return result\n        except Exception as e:\n            logger.error(f\"Failed to load CSV file {file_path}: {e}\")\n            logger.error(traceback.format_stack())\n            return []\n\n    @staticmethod\n    def _save_yaml_file(file_path: str, data: dict) -> bool:\n        \"\"\"\n        Save YAML file\n\n        Args:\n            file_path (str): File path\n            data (Dict): Data to save\n\n        Returns:\n            bool: Whether the save was successful\n        \"\"\"\n        try:\n            with open(file_path, \"w\", encoding=\"utf-8\") as f:\n                yaml.dump(data, f, default_flow_style=False, allow_unicode=True)\n            return True\n        except Exception as e:\n            logger.error(f\"Failed to save YAML file {file_path}: {e}\")\n            logger.error(traceback.format_stack())\n            return False\n\n    @staticmethod\n    def _save_csv_file(file_path: str, data: list[dict[str, str]], headers: list[str] = None) -> bool:\n        \"\"\"\n        Save CSV file\n\n        Args:\n            file_path (str): File path\n            data (List[Dict[str, str]]): List of rows as dictionaries\n            headers 
(List[str]): List of header names in sequence\n\n        Returns:\n            bool: Whether the save was successful\n        \"\"\"\n        try:\n            if not data:\n                return True\n\n            # Get all possible headers from all rows\n            all_headers = set()\n            for row in data:\n                all_headers.update(row.keys())\n\n            # If specify field_names, make sure all keys are in field_names\n            if headers is not None:\n                for key in all_headers:\n                    if key not in headers:\n                        headers.append(key)\n\n            with open(file_path, \"w\", encoding=\"utf-8\", newline=\"\") as f:\n                writer = csv.DictWriter(f, fieldnames=headers)\n                writer.writeheader()\n                for row in data:\n                    writer.writerow(row)\n\n            return True\n        except Exception as e:\n            logger.error(f\"Failed to save CSV file {file_path}: {e}\")\n            logger.error(traceback.format_stack())\n            return False\n\n    def _load_tables(self) -> dict[str, list[str]]:\n        # Load table_columns.csv\n        table_columns_csv = self._load_csv_file(self.table_columns_file)\n\n        # Get unique db_name.table_name combinations\n        table_dict = {}\n        for row in table_columns_csv:\n            if \"db_name\" in row and \"table_name\" in row and \"column_name\" in row:\n                db_name = row[\"db_name\"]\n                table_name = row[\"table_name\"]\n                column_name = row[\"column_name\"]\n                full_table_name = f\"{db_name}.{table_name}\"\n                if full_table_name not in table_dict:\n                    table_dict[full_table_name] = []\n                table_dict[full_table_name].append(column_name)\n        return table_dict\n\n    def _load_common_columns(self) -> dict[str, dict[str, Any]]:\n        # Load common_columns.csv to get column details\n      
  columns_csv = self._load_csv_file(self.common_columns_file)\n\n        # Filter and return column details\n        column_dict = {}\n        for row in columns_csv:\n            if row.get(\"column_name\") and row.get(\"type\"):\n                # Convert row to Dict[str, Any]\n                column_info = {}\n                for key, value in row.items():\n                    if key != \"\":\n                        column_info[key] = value\n                column_dict[row[\"column_name\"]] = column_info\n\n        return column_dict\n\n    def _load_table_spec_columns(self) -> dict[str, dict[str, Any]]:\n        \"\"\"\n        Load info of table spec columns\n        Returns:\n            Dict[str, Dict[str, Any]]: Dictionary of table specific columns information, keyed by \"full_table_name:column_name\"\n        \"\"\"\n        # Load table_spec_columns.csv to get table specific column details\n        columns_csv = self._load_csv_file(self.table_spec_columns_file)\n\n        # Filter and return column details\n        column_dict = {}\n        for row in columns_csv:\n            if \"db_name\" in row and \"table_name\" in row and \"column_name\" in row and row[\"column_name\"]:\n                # Convert row to Dict[(str, str), Any]\n                full_table_name = f\"{row['db_name']}.{row['table_name']}\"\n                column_info = {}\n                for key, value in row.items():\n                    if key != \"\":\n                        column_info[key] = value\n                column_dict[f\"{full_table_name}:{row['column_name']}\"] = column_info\n\n        return column_dict\n\n    def _parse_example_text(self, example_text: str) -> list[tuple[str, str]]:\n        \"\"\"\n        Parse example text, format is Q: ... 
A: ...\n\n        Args:\n            example_text (str): Example text\n\n        Returns:\n            List[Tuple[str, str]]: List of parsed question-answer pairs\n        \"\"\"\n        examples = []\n        lines = example_text.strip().split(\"\\n\")\n\n        question = \"\"\n        answer = \"\"\n        current_type = None\n\n        for line in lines:\n            if line.startswith(\"Q:\"):\n                # If there is already a complete question-answer pair, add it to the results\n                if question and answer:\n                    examples.append((question.strip(), answer.strip()))\n                    question = \"\"\n                    answer = \"\"\n\n                question = line[2:]\n                current_type = \"Q\"\n            elif line.startswith(\"A:\"):\n                answer = line[2:]\n                current_type = \"A\"\n            else:\n                # Continue adding to the current type\n                if current_type == \"Q\":\n                    question += \"\\n\" + line\n                elif current_type == \"A\":\n                    answer += \"\\n\" + line\n\n        # Add the last question-answer pair\n        if question and answer:\n            examples.append((question.strip(), answer.strip()))\n\n        return examples\n\n    def get_database_list(self) -> list[str]:\n        # Extract unique database names\n        databases = set()\n        for table in self._get_all_table_schema().keys():\n            full_table_name, db_name, table_name = split_db_table_name(table)\n            databases.add(db_name)\n\n        return list(databases)\n\n    def _get_all_table_schema(self) -> dict[str, list[str]]:\n        \"\"\"\n        Get all tables schema (columns of table)\n        Returns:\n            Dict[str, List[str]]: Tables schema (columns) dict, keyed by table name\n        \"\"\"\n        if self._table_columns_cache is None:\n            self._table_columns_cache = self._load_tables()\n        # 
Return a deep copy to prevent external modifications\n        return {k: v.copy() for k, v in self._table_columns_cache.items()}\n\n    def get_table_list(self, database: str | None = None) -> list[str]:\n        tables = self._get_all_table_schema()\n        if database is None:\n            return list(tables.keys())\n\n        # Filter by database\n        filtered_tables = []\n        for full_table_name in tables.keys():\n            _, db_name, table_name = split_db_table_name(full_table_name)\n            if db_name == database:\n                filtered_tables.append(full_table_name)\n\n        return filtered_tables\n\n    def _get_common_columns(self) -> dict[str, dict[str, Any]]:\n        \"\"\"\n        Get information of all common columns\n        Returns:\n            Dict[str, Dict[str, Any]]: Dictionary of columns information, keyed by column name\n        \"\"\"\n        if self._common_columns_cache is None:\n            self._common_columns_cache = self._load_common_columns()\n        # Return a deep copy to prevent external modifications\n        return {k: v.copy() for k, v in self._common_columns_cache.items()}\n\n    def _get_table_spec_columns(self) -> dict[str, dict[str, Any]]:\n        \"\"\"\n        Get information of all table specific columns\n        Returns:\n            Dict[str, Dict[str, Any]]: Dictionary of table specific columns information, keyed by \"full_table_name:column_name\"\n        \"\"\"\n        if self._table_spec_columns_cache is None:\n            self._table_spec_columns_cache = self._load_table_spec_columns()\n        # Return a deep copy to prevent external modifications\n        return {k: v.copy() for k, v in self._table_spec_columns_cache.items()}\n\n    def get_column_list(self, table: str | None = None, database: str | None = None) -> list[dict[str, Any]]:\n        _common_columns = self._get_common_columns()\n        if table is None:\n            return list(_common_columns.values())\n\n        # Get the 
full table name\n        full_table_name, db_name, table_name = split_db_table_name(table, database)\n\n        # Filter table columns\n        tables_dict = self._get_all_table_schema()\n        if full_table_name not in tables_dict:\n            return []\n\n        table_columns = tables_dict[full_table_name]\n\n        # If no columns found, return empty list\n        if not table_columns:\n            return []\n\n        # Filter and return column details\n        result = []\n        _table_spec_columns = self._get_table_spec_columns()\n        for column in table_columns:\n            # check if the column is table specific\n            key = f\"{full_table_name}:{column}\"\n            if key in _table_spec_columns:\n                column_info = _table_spec_columns[key]\n                column_info[\"is_common\"] = False\n                result.append(column_info)\n            else:\n                column_info = _common_columns.get(column)\n                if column_info:\n                    column_info[\"is_common\"] = True\n                    result.append(column_info)\n        return result\n\n    def get_table_information(self, table: str, database: str | None = None) -> dict[str, Any]:\n        full_table_name, db_name, table_name = split_db_table_name(table, database)\n\n        if self._table_info_cache is None:\n            self._table_info_cache = self._load_yaml_file(self.table_info_file)\n\n        if db_name in self._table_info_cache and table_name in self._table_info_cache[db_name]:\n            # Return a copy to prevent external modifications\n            return self._table_info_cache[db_name][table_name].copy()\n\n        return {}\n\n    def get_sql_examples(\n        self, table: str | None = None, database: str | None = None\n    ) -> list[tuple[str, str, list[str]]]:\n        if self._sql_example_cache is None:\n            self._sql_example_cache = self._load_yaml_file(self.sql_example_file)\n\n        if table is None:\n           
 # If no table specified, return all examples\n            examples = []\n            for db_name, tables in self._sql_example_cache.items():\n                for table_name, example_text in tables.items():\n                    qa_pairs = self._parse_example_text(example_text)\n                    examples.extend([(q, a, [f\"{db_name}.{table_name}\"]) for (q, a) in qa_pairs])\n            return examples\n\n        full_table_name, db_name, table_name = split_db_table_name(table, database)\n\n        # Find examples that include this table\n        examples = []\n\n        # Check the fact section\n        if db_name in self._sql_example_cache:\n            if table_name in self._sql_example_cache[db_name]:\n                # Parse example text, format is Q: ... A: ...\n                qa_pairs = self._parse_example_text(self._sql_example_cache[db_name][table_name])\n                examples.extend([(q, a, [full_table_name]) for (q, a) in qa_pairs])\n\n        return examples\n\n    @staticmethod\n    def _load_table_selection_examples_from_csv(file_path: str) -> list[tuple[str, list[str]]]:\n        examples = []\n        try:\n            with open(file_path, encoding=\"utf-8\") as f:\n                reader = csv.DictReader(f)\n                for row in reader:\n                    question = row.get(\"question\", \"\").strip()\n                    selected_tables = row.get(\"selected_tables\", \"\").strip()\n                    if question and selected_tables:\n                        table_list = [p.strip() for p in re.split(r\"[ ,\\n]\", selected_tables) if p.strip()]\n                        examples.append((question, table_list))\n        except (FileNotFoundError, PermissionError, UnicodeDecodeError) as e:\n            logger.warning(f\"Failed to load table selection examples from {file_path}: {e}\")\n        return examples\n\n    def get_table_selection_examples(self) -> list[tuple[str, list[str]]]:\n        if self._table_selection_example_cache is 
None:\n            self._table_selection_example_cache = self._load_table_selection_examples_from_csv(\n                self.table_selection_example_file\n            )\n        return self._table_selection_example_cache\n\n    def save_table_information(\n        self,\n        table: str,\n        information: dict[str, Any],\n        columns: list[dict[str, Any]],\n        database: str | None = None,\n        update_existing: bool = False,\n    ) -> bool:\n        # Validate input data (let validation errors propagate)\n        self._validate_table_name(table)\n        self._validate_table_information(information)\n        self._validate_column_data(columns)\n\n        try:\n            full_table_name, db_name, table_name = split_db_table_name(table, database)\n\n            table_info = self._load_yaml_file(self.table_info_file)\n\n            # Save columns first\n            if not self._save_columns(table_name, columns, db_name, update_existing):\n                logger.error(f\"Failed to save columns for table {full_table_name}\")\n                return False\n\n            # Save table information (ensure proper structure)\n            if db_name not in table_info:\n                table_info[db_name] = {}\n            if update_existing or table_name not in table_info[db_name]:\n                table_info[db_name][table_name] = information\n            success = self._save_yaml_file(self.table_info_file, table_info)\n\n            if success:\n                logger.info(f\"Successfully saved table information for {full_table_name}\")\n                # Clear cache to ensure consistency\n                self._clear_cache()\n\n            return success\n        except Exception as e:\n            logger.error(f\"Unexpected error when saving table information: {e}\")\n            logger.error(traceback.format_stack())\n            return False\n\n    def _save_columns(\n        self, table_name: str, columns: list[dict[str, Any]], db_name: str = \"\", 
update_existing: bool = False\n    ) -> bool:\n        \"\"\"\n        Save columns information to common_columns.csv and columns of tables to table_columns.csv\n\n        Args:\n            table_name (str): Table name\n            columns (List[Dict[str, Any]]): List of column information\n            db_name (str): Database name\n            update_existing (bool): Update existing column information\n\n        Returns:\n            bool: Whether the save was successful\n        \"\"\"\n        full_table_name, db_name, table_name = split_db_table_name(table_name, db_name)\n        # Load existing data\n        tables_data = self._load_csv_file(self.table_columns_file)\n        common_columns_dict = self._load_common_columns()\n        table_spec_columns_dict = self._load_table_spec_columns()\n\n        # Create a set of existing table-column combinations\n        existing_table_columns = set()\n        for row in tables_data:\n            if \"db_name\" in row and \"table_name\" in row and \"column_name\" in row:\n                key = f\"{row['db_name']}.{row['table_name']}:{row['column_name']}\"\n                existing_table_columns.add(key)\n\n        # Update table_columns.csv and track new columns to add\n\n        for column in columns:\n            if \"column_name\" not in column:\n                continue\n\n            column_name = column[\"column_name\"]\n            is_common_column = column.get(\"is_common\", False)\n\n            key = f\"{full_table_name}:{column_name}\"\n            column_info = {k: str(v) for k, v in column.items() if k != \"is_common\"}\n            if not is_common_column:\n                column_info[\"db_name\"] = db_name\n                column_info[\"table_name\"] = table_name\n\n            # New column of the table -> add to table_columns.csv\n            if key not in existing_table_columns:\n                tables_data.append({\"db_name\": db_name, \"table_name\": table_name, \"column_name\": column_name})\n        
        existing_table_columns.add(key)\n                if is_common_column:\n                    # Handle common_columns.csv - avoid duplicates\n                    if column_name not in common_columns_dict:\n                        # Add new columns to columns_data\n                        logger.info(f\"Add new column column {column_name}\")\n                        common_columns_dict[column_name] = column_info\n                else:\n                    table_spec_columns_dict[key] = column_info\n            # Apply updates to existing columns in columns_data\n            elif update_existing:\n                if is_common_column:\n                    common_columns_dict[column_name] = column_info\n                else:\n                    table_spec_columns_dict[key] = column_info\n\n        # Save updated data\n        tables_success = self._save_csv_file(\n            self.table_columns_file, tables_data, [\"db_name\", \"table_name\", \"column_name\"]\n        )\n        common_columns_success = self._save_csv_file(\n            self.common_columns_file,\n            list(common_columns_dict.values()),\n            [\"column_name\", \"display_name\", \"alias\", \"type\", \"category\", \"tag\", \"description\"],\n        )\n        table_spec_columns_success = self._save_csv_file(\n            self.table_spec_columns_file,\n            list(table_spec_columns_dict.values()),\n            [\"db_name\", \"table_name\", \"column_name\", \"display_name\", \"alias\", \"type\", \"category\", \"tag\", \"description\"],\n        )\n\n        success = tables_success and common_columns_success and table_spec_columns_success\n        if success:\n            # Clear cache to ensure consistency\n            self._clear_cache()\n            logger.debug(f\"Successfully saved columns for table {table_name}\")\n\n        return success\n\n    def save_table_sql_examples(self, table: str, examples: list[dict[str, str]], database: str | None = None) -> bool:\n        # 
Validate input data (let validation errors propagate)\n        self._validate_table_name(table)\n        self._validate_sql_examples(examples)\n\n        try:\n            full_table_name, db_name, table_name = split_db_table_name(table, database)\n\n            sql_examples = self._load_yaml_file(self.sql_example_file)\n\n            # Ensure database exists in structure\n            if db_name not in sql_examples:\n                sql_examples[db_name] = {}\n\n            # example text\n            example_text = \"\"\n            for example in examples:\n                example_text += f\"Q: {example['question']}\\nA: {example['answer']}\\n\\n\"\n\n            sql_examples[db_name][table_name] = example_text.strip()\n\n            success = self._save_yaml_file(self.sql_example_file, sql_examples)\n\n            if success:\n                logger.info(f\"Successfully saved {len(examples)} examples for table {full_table_name}\")\n                # Update cache\n                self._sql_example_cache = sql_examples\n\n            return success\n        except Exception as e:\n            logger.error(f\"Unexpected error when saving table examples: {e}\")\n            logger.error(traceback.format_stack())\n            return False\n\n    def save_table_selection_examples(self, examples: list[tuple[str, list[str]]]) -> bool:\n        example_data = []\n        for example in examples:\n            example_data.append({\"question\": example[0], \"selected_tables\": example[1]})\n        save_success = self._save_csv_file(\n            self.table_selection_example_file, example_data, [\"question\", \"selected_tables\"]\n        )\n        if save_success:\n            logger.info(f\"Successfully saved {len(examples)} table selection examples.\")\n        return save_success\n\n    def check_exists(self) -> bool:\n        try:\n            # Check if essential catalog files exist and have content\n            files_missing = (\n                not 
os.path.exists(self.table_columns_file)\n                or not os.path.exists(self.common_columns_file)\n                or os.path.getsize(self.table_columns_file) <= 1  # Empty or just header\n                or os.path.getsize(self.common_columns_file) <= 1\n            )\n\n            return not files_missing\n\n        except Exception as e:\n            logger.warning(f\"Error checking catalog existence: {e}\")\n            logger.error(traceback.format_stack())\n            return False\n"
  },
  {
    "path": "openchatbi/catalog/token_service.py",
    "content": "\"\"\"Token service for authentication with external services.\"\"\"\n\nimport json\n\nimport requests\n\n\nclass TokenService:\n    \"\"\"Service for managing authentication tokens.\n\n    Handles token application, validation, and authentication\n    with external services.\n    \"\"\"\n\n    base_url = None\n    token = None\n    user_name = None\n    password = None\n\n    def __init__(self, user_name: str, password: str):\n        \"\"\"Initialize token service.\"\"\"\n        self.user_name = user_name\n        self.password = password\n\n    def apply_token(self):\n        \"\"\"Apply for authentication token using credentials.\"\"\"\n        response = requests.post(\n            self.base_url + \"/apply_token\", data=json.dumps({\"user_name\": self.user_name, \"password\": self.password})\n        )\n        resp_json = response.json()\n        self.token = resp_json.get(\"token\")\n\n\ndef apply_token_for_user(token_url: str, user_name: str, password: str):\n    \"\"\"Apply for token and return token with username.\n\n    Args:\n        token_url (str): Base URL for token service.\n        user_name (str): The user name.\n        password (str): The password.\n\n    Returns:\n        token\n    \"\"\"\n    token_service = TokenService(user_name, password)\n    token_service.base_url = token_url\n    token_service.apply_token()\n    return token_service.token\n"
  },
  {
    "path": "openchatbi/code/docker_executor.py",
    "content": "import os\nimport shutil\nimport subprocess\nimport tempfile\nfrom pathlib import Path\n\nimport docker\nfrom docker.errors import ContainerError\n\nfrom openchatbi.code.executor_base import ExecutorBase\n\n\ndef check_docker_status() -> tuple[bool, str]:\n    \"\"\"\n    Check Docker installation and status without initializing DockerExecutor.\n\n    Returns:\n        Tuple[bool, str]: (is_available, status_message)\n    \"\"\"\n    try:\n        # Check if Docker CLI is installed\n        if not shutil.which(\"docker\"):\n            return False, \"Docker is not installed. Please install Docker.\"\n\n        # Check if Docker daemon is running\n        result = subprocess.run([\"docker\", \"info\"], capture_output=True, text=True, timeout=10)\n\n        if result.returncode == 0:\n            return True, \"Docker is installed and running\"\n        else:\n            if \"Cannot connect to the Docker daemon\" in result.stderr:\n                return False, \"Docker is installed but not running. Please start the Docker daemon.\"\n            else:\n                return False, f\"Docker is not available: {result.stderr.strip()}\"\n\n    except subprocess.TimeoutExpired:\n        return False, \"Docker command timed out. Docker may not be running properly.\"\n    except FileNotFoundError:\n        return False, \"Docker command not found. 
Please install Docker.\"\n    except Exception as e:\n        return False, f\"Error checking Docker status: {str(e)}\"\n\n\nclass DockerExecutor(ExecutorBase):\n    \"\"\"Docker-based Python code executor for isolated execution.\"\"\"\n\n    def __init__(self, variable: dict = None):\n        super().__init__(variable)\n        self.image_name = \"python-executor\"\n        self.dockerfile_path = Path(__file__).parent.parent.parent / \"Dockerfile.python-executor\"\n\n        # Check Docker installation and status\n        self._check_docker_availability()\n\n        try:\n            self.client = docker.from_env()\n            # Build Docker image if it doesn't exist\n            self._ensure_image_exists()\n        except Exception as e:\n            self._handle_docker_error(e)\n\n    @staticmethod\n    def _check_docker_availability():\n        \"\"\"Check if Docker is installed and available.\"\"\"\n        # Check if Docker CLI is installed\n        if not shutil.which(\"docker\"):\n            raise RuntimeError(\"Docker is not installed. Please install Docker and ensure it's in your system PATH.\")\n\n        # Check if Docker daemon is running\n        try:\n            result = subprocess.run([\"docker\", \"info\"], capture_output=True, text=True, timeout=10)\n            if result.returncode != 0:\n                if \"Cannot connect to the Docker daemon\" in result.stderr:\n                    raise RuntimeError(\n                        \"Docker is installed but not running. Please start the Docker daemon and try again.\"\n                    )\n                else:\n                    raise RuntimeError(\n                        f\"Docker is not available. Please check Docker installation and status. \"\n                        f\"Error: {result.stderr.strip()}\"\n                    )\n        except subprocess.TimeoutExpired:\n            raise RuntimeError(\"Docker command timed out. 
Please check if Docker is running properly.\")\n        except FileNotFoundError:\n            raise RuntimeError(\"Docker command not found. Please install Docker and ensure it's in your system PATH.\")\n\n    @staticmethod\n    def _handle_docker_error(error: Exception):\n        \"\"\"Handle Docker-related errors with specific error messages.\"\"\"\n        error_str = str(error).lower()\n\n        if \"connection aborted\" in error_str and \"no such file or directory\" in error_str:\n            raise RuntimeError(\"Docker is not running. Please start the Docker daemon and try again.\")\n        elif \"permission denied\" in error_str:\n            raise RuntimeError(\n                \"Permission denied accessing Docker. Please ensure your user has Docker permissions \"\n                \"or try running with appropriate privileges.\"\n            )\n        elif \"docker daemon\" in error_str or \"connection refused\" in error_str:\n            raise RuntimeError(\"Cannot connect to Docker daemon. Please start the Docker daemon and try again.\")\n        else:\n            raise RuntimeError(\n                f\"Failed to initialize Docker client. Please ensure Docker is installed and running. 
\"\n                f\"Error: {str(error)}\"\n            )\n\n    def _ensure_image_exists(self):\n        \"\"\"Build Docker image if it doesn't exist.\"\"\"\n        try:\n            self.client.images.get(self.image_name)\n        except docker.errors.ImageNotFound:\n            print(f\"Building Docker image '{self.image_name}'...\")\n            self.client.images.build(\n                path=str(self.dockerfile_path.parent),\n                dockerfile=self.dockerfile_path.name,\n                tag=self.image_name,\n                rm=True,\n            )\n            print(f\"Docker image '{self.image_name}' built successfully.\")\n\n    def run_code(self, code: str) -> tuple[bool, str]:\n        \"\"\"Execute Python code in a Docker container.\"\"\"\n        try:\n            # Create a temporary file with the code\n            with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".py\", delete=False) as f:\n                # Add variable definitions to the code\n                variable_code = \"\"\n                for key, value in self._variable.items():\n                    if isinstance(value, str):\n                        variable_code += f'{key} = \"{value}\"\\n'\n                    else:\n                        variable_code += f\"{key} = {repr(value)}\\n\"\n\n                full_code = variable_code + \"\\n\" + code\n                f.write(full_code)\n                temp_file_path = f.name\n\n            try:\n                # Run the code in a Docker container\n                container = self.client.containers.run(\n                    self.image_name,\n                    command=[\"python3\", f\"/app/{os.path.basename(temp_file_path)}\"],\n                    volumes={temp_file_path: {\"bind\": f\"/app/{os.path.basename(temp_file_path)}\", \"mode\": \"ro\"}},\n                    remove=True,\n                    detach=False,\n                    stdout=True,\n                    stderr=True,\n                    
network_mode=\"none\",  # Disable network access for security\n                )\n\n                # Get the output\n                output = container.decode(\"utf-8\")\n                return True, output\n\n            except ContainerError as e:\n                # Container exited with non-zero code\n                error_output = e.stderr if e.stderr else str(e)\n                return False, f\"Container execution failed: {error_output}\"\n\n        except Exception as e:\n            return False, f\"Docker execution error: {str(e)}\"\n\n        finally:\n            # Clean up temporary file\n            if \"temp_file_path\" in locals() and os.path.exists(temp_file_path):\n                try:\n                    os.unlink(temp_file_path)\n                except (OSError, PermissionError) as e:\n                    # Log but don't fail the operation for cleanup issues\n                    print(f\"Warning: Failed to clean up temporary file {temp_file_path}: {e}\")\n\n    def __del__(self):\n        \"\"\"Clean up Docker client on deletion.\"\"\"\n        try:\n            if hasattr(self, \"client\") and self.client is not None:\n                self.client.close()\n        except Exception:\n            # Ignore cleanup errors during object destruction\n            pass\n"
  },
  {
    "path": "openchatbi/code/executor_base.py",
    "content": "from typing import Any\n\n\nclass ExecutorBase:\n    \"\"\"Base class for executing python code.\"\"\"\n\n    _variable: dict\n\n    def __init__(self, variable: dict = None):\n        if variable is None:\n            self._variable = {}\n        else:\n            self._variable = variable\n\n    def run_code(self, code: str) -> (bool, str):\n        \"\"\"Execute python code.\"\"\"\n        raise NotImplementedError()\n\n    def set_variable(self, key: str, value: Any) -> None:\n        \"\"\"Set variable.\"\"\"\n        self._variable[key] = value\n"
  },
  {
    "path": "openchatbi/code/local_executor.py",
    "content": "import sys\nfrom io import StringIO\n\nfrom openchatbi.code.executor_base import ExecutorBase\n\n\nclass LocalExecutor(ExecutorBase):\n\n    def run_code(self, code: str) -> str:\n        safe_globals = {\"__builtins__\": __builtins__}\n        original_stdout = sys.stdout\n        output_buffer = StringIO()\n        sys.stdout = output_buffer\n        try:\n            exec(code, safe_globals, safe_globals)\n            output = output_buffer.getvalue()\n            return True, output\n        except Exception as e:\n            return False, str(e)\n        finally:\n            sys.stdout = original_stdout\n"
  },
  {
    "path": "openchatbi/code/restricted_local_executor.py",
    "content": "import sys\nfrom io import StringIO\n\nfrom RestrictedPython import compile_restricted, safe_globals, utility_builtins\nfrom RestrictedPython.Guards import safe_builtins, safer_getattr\n\nfrom openchatbi.code.executor_base import ExecutorBase\n\n\nclass RestrictedLocalExecutor(ExecutorBase):\n\n    def run_code(self, code: str) -> (bool, str):\n        try:\n            # compile restricted code\n            byte_code = compile_restricted(code, \"<string>\", \"exec\")\n            if byte_code is None:\n                return False, \"Failed to compile restricted code\"\n\n            restricted_locals = {}\n            restricted_globals = safe_globals.copy()\n\n            # Set up restricted environment with necessary functions\n            restricted_globals.update(safe_builtins)\n            restricted_globals[\"_getattr_\"] = safer_getattr\n            restricted_globals[\"__builtins__\"] = utility_builtins\n\n            # Add variable definitions to the restricted locals\n            for key, value in self._variable.items():\n                restricted_locals[key] = value\n\n            # Capture print output\n            original_stdout = sys.stdout\n            output_buffer = StringIO()\n            sys.stdout = output_buffer\n\n            # Use the standard print function for RestrictedPython\n            restricted_globals[\"_print_\"] = lambda *args, **kwargs: print(*args, **kwargs)\n\n            exec(byte_code, restricted_globals, restricted_locals)\n            output = output_buffer.getvalue()\n\n            return True, output\n\n        except Exception as e:\n            return False, str(e)\n        finally:\n            if \"original_stdout\" in locals():\n                sys.stdout = original_stdout\n"
  },
  {
    "path": "openchatbi/config.yaml.template",
    "content": "organization: The Company\ndialect: presto\nbi_config_file: example/bi.yaml\n\n# Python Code Execution Configuration\n# Options: \"local\", \"restricted_local\", \"docker\"\n# - local: Run code in the current Python process (fastest, least secure)\n# - restricted_local: Run code with RestrictedPython (moderate security, some limitations)\n# - docker: Run code in isolated Docker containers (slowest, most secure, requires Docker to be installed)\npython_executor: local\n\n# Visualization configuration\n# Options: \"rule\" (rule-based), \"llm\" (LLM-based), or null (skip visualization)\n# visualization_mode: llm\n\n# Context management configuration\n# Controls how conversation context is managed and compressed when it becomes too long\ncontext_config:\n  # Enable/disable context management entirely\n  enabled: true\n\n  # Token limit that triggers context management (when conversation exceeds this, compression starts)\n  summary_trigger_tokens: 12000\n\n  # Number of recent messages to always preserve in full (never compress these)\n  keep_recent_messages: 20\n\n  # Historical tool output compression limits\n  max_tool_output_length: 2000  # Max length for historical tool outputs\n  max_sql_result_rows: 50       # Max rows to keep in CSV results\n  max_code_output_lines: 50     # Max lines for code execution output\n\n  # Conversation summarization settings\n  enable_summarization: true         # Enable conversation summarization\n  enable_conversation_summary: true  # Enable detailed conversation summary\n  summary_max_messages: 50           # Max messages to include in summary context\n\n  # Content preservation settings\n  preserve_tool_errors: true    # Always preserve error messages in full\n  preserve_recent_sql: true     # Preserve SQL content (less aggressive compression)\n\n# Time Series Forecasting Service Configuration\n# URL for the time series forecasting service endpoint, adjust based on your deployment scenario:\n# - Local development 
(OpenChatBI on host, Forecasting service in Docker): \"http://localhost:8765\"\n# - Remote service: \"http://your-service-host:8765\"\ntimeseries_forecasting_service_url: \"http://localhost:8765\"\n\n# Catalog store configuration\ncatalog_store:\n  store_type: file_system\n  data_path: ./example\n\n# Data warehouse configuration\ndata_warehouse_config:\n  uri: \"presto://{user_name}@domain:8080/db/default\"\n  include_tables:\n    - null  # null means include all tables, or specify yaml list\n  database_name: \"db.default\"  # database name to use in catalog\n  token_service: \"https://tokens-domain:8080/v1\"\n  user_name: TOKEN_SERVICE_USER_NAME\n  password: TOKEN_SERVICE_PASSWORD\n\n# Vector database (chroma) path\n# vector_db_path: ./.chroma_db\n\n# LLM configurations (multiple providers)\n#\n# 1) Define providers under `llm_providers`\n# 2) Select which one to use by setting `default_llm: <provider_name>`\ndefault_llm: openai\nllm_providers:\n  openai:\n    default_llm:\n      class: langchain_openai.ChatOpenAI\n      params:\n        api_key: YOUR_API_KEY_HERE\n        model: gpt-4.1\n        temperature: 0.01\n        max_tokens: 8192\n    embedding_model:\n      class: langchain_openai.OpenAIEmbeddings\n      params:\n        api_key: YOUR_API_KEY_HERE\n        model: text-embedding-3-large\n        chunk_size: 1024\n    # Optional\n    text2sql_llm:\n      class: langchain_openai.ChatOpenAI\n      params:\n        api_key: YOUR_API_KEY_HERE\n        model: gpt-4.1\n        temperature: 0.0\n        max_tokens: 8192\n  # anthropic:\n  #   default_llm:\n  #     class: langchain_anthropic.ChatAnthropic\n  #     params:\n  #       api_key: YOUR_API_KEY_HERE\n  #       model: claude-3-5-sonnet-latest\n\n# MCP (Model Context Protocol) server configurations\nmcp_servers:\n  # File system MCP server (stdio transport)\n  - name: filesystem\n    transport: stdio\n    command: [\"npx\", \"-y\", \"@modelcontextprotocol/server-filesystem\"]\n    args: [\"--path\", 
\"/tmp\"]\n    enabled: false\n    timeout: 30\n  \n\n  # Example HTTP-based MCP server (streamable_http transport)\n  - name: weather\n    transport: streamable_http\n    url: \"http://localhost:8000/mcp/\"\n    headers:\n      Authorization: \"Bearer YOUR_TOKEN\"\n    enabled: false\n    timeout: 30\n"
  },
  {
    "path": "openchatbi/config_loader.py",
    "content": "import importlib\nimport os\nfrom importlib.util import find_spec\nfrom typing import Any\nfrom unittest.mock import MagicMock\n\nfrom langchain_core.language_models import BaseChatModel\nfrom pydantic import BaseModel\n\nfrom openchatbi.catalog.factory import create_catalog_store\nfrom openchatbi.utils import log\n\n\nclass LLMProviderConfig(BaseModel):\n    \"\"\"Resolved LLM objects for a single provider.\"\"\"\n\n    model_config = {\"arbitrary_types_allowed\": True}\n\n    default_llm: BaseChatModel | MagicMock\n    embedding_model: BaseModel | MagicMock | None = None\n    text2sql_llm: BaseChatModel | MagicMock | None = None\n\n\nclass Config(BaseModel):\n    \"\"\"Configuration model for the OpenChatBI application.\n\n    Attributes:\n        organization (str): Organization name. Defaults to \"The Company\".\n        dialect (str): SQL dialect to use. Defaults to \"presto\".\n        default_llm (BaseChatModel): Default language model for general tasks.\n        embedding_model (BaseModel): Language model for embedding generation.\n        text2sql_llm (Optional[BaseChatModel]): Language model specifically for text-to-SQL tasks.\n        bi_config (Dict[str, Any]): BI configuration loaded from YAML file. Defaults to empty dict.\n        data_warehouse_config (Dict[str, Any]): Data warehouse configuration. 
Defaults to empty dict.\n    \"\"\"\n\n    model_config = {\"arbitrary_types_allowed\": True}\n\n    # General Configurations\n    organization: str = \"The Company\"\n    dialect: str = \"presto\"\n\n    # LLM Configurations\n    default_llm: BaseChatModel | MagicMock\n    embedding_model: BaseModel | MagicMock | None = None\n    text2sql_llm: BaseChatModel | MagicMock | None = None\n    # Multiple LLM providers (optional)\n    llm_provider: str | None = None\n    llm_providers: dict[str, LLMProviderConfig] = {}\n\n    # BI Configuration\n    bi_config: dict[str, Any] = {}\n\n    # Data Warehouse Configuration\n    data_warehouse_config: dict[str, Any] = {}\n\n    # Catalog Store\n    catalog_store: Any = None\n\n    # Path to the vector database file\n    vector_db_path: str = None\n\n    # MCP Servers Configuration\n    mcp_servers: list[dict[str, Any]] = []\n\n    # Report Configuration\n    report_directory: str = \"./data\"\n\n    # Code Execution Configuration\n    python_executor: str = \"local\"  # Options: \"local\", \"restricted_local\", \"docker\"\n\n    # Visualization Configuration\n    visualization_mode: str | None = \"rule\"  # Options: \"rule\", \"llm\", None (skip visualization)\n\n    # Context Management Configuration\n    context_config: dict[str, Any] = {}\n\n    # Time Series Service Configuration\n    timeseries_forecasting_service_url: str = \"http://localhost:8765\"\n\n    @classmethod\n    def from_dict(cls, config: dict[str, Any]) -> \"Config\":\n        \"\"\"Creates a Config instance from a dictionary.\n\n        Args:\n            config (Dict[str, Any]): Dictionary containing configuration values.\n\n        Returns:\n            Config: A new Config instance with the provided values.\n        \"\"\"\n        return cls(**config)\n\n\nclass ConfigLoader:\n    \"\"\"Singleton class to load and manage configuration settings for OpenChatBI.\n\n    This class provides methods to load, get, and set configuration parameters\n    for the 
application, including LLM models, SQL dialect, and other settings.\n    \"\"\"\n\n    _instance = None\n    _config: Config = None\n\n    def __new__(cls):\n        if cls._instance is None:\n            cls._instance = super().__new__(cls)\n        return cls._instance\n\n    llm_configs = [\"default_llm\", \"embedding_model\", \"text2sql_llm\"]\n\n    def get(self) -> Config:\n        \"\"\"Get the current configuration.\n\n        Returns:\n            Config: The current configuration instance.\n\n        Raises:\n            ValueError: If the configuration has not been loaded.\n        \"\"\"\n        if self._config is None:\n            raise ValueError(\"Configuration has not been loaded. Please call load() or set() first.\")\n        return self._config\n\n    def load(self, config_file: str = None) -> None:\n        \"\"\"Load configuration from a YAML file.\n\n        Args:\n            config_file (str, optional): Path to configuration file. Uses CONFIG_FILE\n                environment variable or 'openchatbi/config.yaml' if not provided.\n\n        Raises:\n            ImportError: If pyyaml is not installed.\n            FileNotFoundError: If the configuration file cannot be found.\n        \"\"\"\n        if config_file is None:\n            config_file = os.getenv(\"CONFIG_FILE\", \"openchatbi/config.yaml\")\n\n        if not find_spec(\"yaml\"):\n            raise ImportError(\"Please install pyyaml to use this feature.\")\n\n        import yaml\n\n        try:\n            with open(config_file, encoding=\"utf-8\") as file:\n                config_data = yaml.safe_load(file)\n                if config_data is None:\n                    config_data = {}\n        except FileNotFoundError:\n            log(f\"Configuration file not found: {config_file}, leave config un-loaded.\")\n            return\n        except yaml.YAMLError as e:\n            raise ValueError(f\"Invalid YAML in configuration file {config_file}: {e}\")\n        except 
Exception as e:\n            raise RuntimeError(f\"Failed to read configuration file {config_file}: {e}\")\n\n        self._process_config_dict(config_data)\n        self._config = Config.from_dict(config_data)\n\n    def _process_config_dict(self, config_data: dict[str, Any]) -> None:\n        \"\"\"\n        Processes a configuration dictionary.\n        \"\"\"\n        self._process_llm_providers(config_data)\n\n        providers = config_data.get(\"llm_providers\", {})\n        selected_provider = None\n\n        default_llm_value = config_data.get(\"default_llm\")\n        if isinstance(default_llm_value, str):\n            # Simplified multi-provider config: default_llm: <provider_name>\n            if not providers:\n                raise ValueError(\"default_llm is a provider name but llm_providers is missing.\")\n            selected_provider = default_llm_value\n        elif providers:\n            # Backwards-compat: allow selecting provider via llm_provider\n            legacy_provider = config_data.get(\"llm_provider\")\n            if isinstance(legacy_provider, str):\n                selected_provider = legacy_provider\n            elif \"default_llm\" not in config_data:\n                # Pick the first provider in config order for backwards-compatible YAML behavior\n                selected_provider = next(iter(providers.keys()), None)\n            elif isinstance(default_llm_value, dict):\n                raise ValueError(\n                    \"When using llm_providers, set default_llm to a provider name (e.g. default_llm: openai), \"\n                    \"not a class config.\"\n                )\n\n        if providers:\n            if not selected_provider or selected_provider not in providers:\n                raise ValueError(f\"Unknown LLM provider '{selected_provider}'. 
Available: {sorted(providers.keys())}\")\n            # Store selected provider for runtime lookups (UI/API can still override per-request)\n            config_data[\"llm_provider\"] = selected_provider\n            # Populate top-level LLM objects for legacy call sites\n            config_data[\"default_llm\"] = providers[selected_provider].default_llm\n            config_data.setdefault(\"embedding_model\", providers[selected_provider].embedding_model)\n            config_data.setdefault(\"text2sql_llm\", providers[selected_provider].text2sql_llm)\n        elif \"default_llm\" not in config_data:\n            raise ValueError(\"Missing LLM config key: default_llm\")\n\n        if not config_data.get(\"embedding_model\"):\n            log(\"WARN: Missing LLM config key: embedding_model, will use BM25 based retrival only\")\n        if \"data_warehouse_config\" not in config_data:\n            raise ValueError(\"Missing Data Warehouse config key: data_warehouse_config\")\n\n        # Load BI configuration\n        if \"bi_config_file\" in config_data:\n            bi_config = self.load_bi_config(config_data[\"bi_config_file\"])\n            bi_config.update(config_data.get(\"bi_config\", {}))\n            config_data[\"bi_config\"] = bi_config\n\n        if \"catalog_store\" in config_data:\n            if \"store_type\" not in config_data[\"catalog_store\"]:\n                raise ValueError(\"catalog_store must have a store_type field.\")\n            catalog_store = create_catalog_store(\n                **config_data[\"catalog_store\"],\n                auto_load=config_data[\"catalog_store\"].get(\"auto_load\", True),\n                data_warehouse_config=config_data.get(\"data_warehouse_config\"),\n            )\n        else:\n            log(\"Catalog store config key `catalog_store` not found. 
Using default file system store.\")\n            catalog_store = create_catalog_store(\n                store_type=\"file_system\",\n                auto_load=True,\n                data_warehouse_config=config_data.get(\"data_warehouse_config\"),\n            )\n        config_data[\"catalog_store\"] = catalog_store\n\n        for config_key in self.llm_configs:\n            config_item = config_data.get(config_key)\n            if not isinstance(config_item, dict) or \"class\" not in config_item:\n                continue\n            config_data[config_key] = self._instantiate_from_config_dict(config_item, config_key=config_key)\n\n    def _instantiate_from_config_dict(self, config_item: dict[str, Any], *, config_key: str) -> Any:\n        try:\n            class_path = config_item[\"class\"]\n            if \".\" not in class_path:\n                raise ValueError(f\"Invalid class path format: {class_path}\")\n            module_name, class_name = class_path.rsplit(\".\", 1)\n            module = importlib.import_module(module_name)\n            llm_cls = getattr(module, class_name)\n            params = config_item.get(\"params\", {})\n            return llm_cls(**params)\n        except (ImportError, AttributeError, ValueError, TypeError) as e:\n            raise RuntimeError(f\"Failed to load {config_key} class '{config_item.get('class', '')}': {e}\") from e\n\n    def _process_llm_providers(self, config_data: dict[str, Any]) -> None:\n        \"\"\"Resolve llm_providers into instantiated provider configs (if present).\"\"\"\n        raw_providers = config_data.get(\"llm_providers\")\n        if not raw_providers:\n            return\n        if not isinstance(raw_providers, dict):\n            raise ValueError(\"llm_providers must be a mapping of provider_name -> config\")\n\n        providers: dict[str, LLMProviderConfig] = {}\n        for provider_name, provider_cfg in raw_providers.items():\n            if isinstance(provider_cfg, LLMProviderConfig):\n  
              providers[str(provider_name)] = provider_cfg\n                continue\n            if not isinstance(provider_cfg, dict):\n                raise ValueError(f\"llm_providers.{provider_name} must be a mapping\")\n\n            resolved_cfg: dict[str, Any] = dict(provider_cfg)\n            for config_key in self.llm_configs:\n                config_item = resolved_cfg.get(config_key)\n                if not isinstance(config_item, dict) or \"class\" not in config_item:\n                    continue\n                resolved_cfg[config_key] = self._instantiate_from_config_dict(\n                    config_item, config_key=f\"llm_providers.{provider_name}.{config_key}\"\n                )\n\n            if \"default_llm\" not in resolved_cfg or resolved_cfg[\"default_llm\"] is None:\n                raise ValueError(f\"llm_providers.{provider_name} missing default_llm\")\n\n            providers[str(provider_name)] = LLMProviderConfig(**resolved_cfg)\n\n        config_data[\"llm_providers\"] = providers\n\n    def load_bi_config(self, bi_config_file: str) -> dict[str, Any]:\n        \"\"\"Load BI configuration from a YAML file.\n\n        Args:\n            bi_config_file (str): Path to the BI configuration file.\n                Defaults to 'example/bi.yaml'.\n\n        Returns:\n            Dict[str, Any]: The loaded BI configuration as a dictionary.\n\n        Raises:\n            ImportError: If pyyaml is not installed.\n            FileNotFoundError: If the BI configuration file cannot be found.\n        \"\"\"\n        if not find_spec(\"yaml\"):\n            raise ImportError(\"Please install pyyaml to use this feature.\")\n\n        import yaml\n\n        bi_config_data = {}\n\n        try:\n            with open(bi_config_file, encoding=\"utf-8\") as file:\n                bi_config_data = yaml.safe_load(file) or {}\n        except FileNotFoundError:\n            log(f\"Warning: BI config file '{bi_config_file}' not found. 
Ignore load BI config from yaml file.\")\n        except yaml.YAMLError as e:\n            log(f\"Warning: Invalid YAML in BI config file '{bi_config_file}': {e}. Using empty config.\")\n        except Exception as e:\n            log(f\"Warning: Failed to read BI config file '{bi_config_file}': {e}. Using empty config.\")\n\n        return bi_config_data\n\n    def set(self, config: dict[str, Any]) -> None:\n        \"\"\"Set the configuration from a dictionary.\n\n        Args:\n            config (Dict[str, Any]): Dictionary containing configuration values.\n        \"\"\"\n        self._process_config_dict(config)\n        self._config = Config.from_dict(config)\n"
  },
  {
    "path": "openchatbi/constants.py",
    "content": "\"\"\"Constants used throughout the OpenChatBI application.\"\"\"\n\n# Date/time format strings\ndatetime_format = \"%Y-%m-%d %H:%M:%S\"\ndate_format = \"%Y-%m-%d\"\ndatetime_format_ms = \"%Y-%m-%d %H:%M:%S.%f\"\ndatetime_format_ms_T = \"%Y-%m-%dT%H:%M:%S.%fZ\"\n\n# SQL execution status codes\nSQL_NA = \"SQL_NA\"\nSQL_SUCCESS = \"SQL_SUCCESS\"\nSQL_EXECUTE_TIMEOUT = \"SQL_CHECK_TIMEOUT\"\nSQL_SYNTAX_ERROR = \"SQL_SYNTAX_ERROR\"\nSQL_UNKNOWN_ERROR = \"SQL_UNKNOWN_ERROR\"\n\n\nMCP_TOOL_DEFAULT_TIMEOUT_SECONDS = 60\n"
  },
  {
    "path": "openchatbi/context_config.py",
    "content": "\"\"\"Configuration for context management settings.\"\"\"\n\nfrom dataclasses import dataclass\n\nfrom openchatbi import config\n\n\n@dataclass\nclass ContextConfig:\n    \"\"\"Configuration class for context management settings.\"\"\"\n\n    # Enable/disable context management\n    enabled: bool = True\n\n    # Token limits for triggering context management\n    summary_trigger_tokens: int = 12000\n\n    # Message retention (how many recent messages to always preserve)\n    keep_recent_messages: int = 20\n\n    # Historical tool output compression limits\n    max_tool_output_length: int = 2000  # Max length for historical tool outputs\n    max_sql_result_rows: int = 50  # Max rows to keep in CSV results\n    max_code_output_lines: int = 50  # Max lines for code execution output\n\n    # Conversation summarization\n    enable_summarization: bool = True\n    enable_conversation_summary: bool = True\n    summary_max_messages: int = 50  # Max messages to include in summary context\n\n    # Content preservation settings\n    preserve_tool_errors: bool = True  # Always preserve error messages in full\n    preserve_recent_sql: bool = True  # Preserve SQL content (less aggressive compression)\n\n\ndef get_context_config() -> ContextConfig:\n    \"\"\"Get the current context configuration.\n\n    This function loads context configuration from the main config system.\n    Falls back to default configuration if not available.\n\n    Returns:\n        ContextConfig: The current context configuration\n    \"\"\"\n    try:\n        main_config = config.get()\n\n        # Check if context_config exists in the main config\n        if hasattr(main_config, \"context_config\") and main_config.context_config:\n            context_config_dict = main_config.context_config\n            # Create ContextConfig from the loaded configuration\n            context_config = ContextConfig()\n            for key, value in context_config_dict.items():\n                if 
hasattr(context_config, key):\n                    setattr(context_config, key, value)\n            return context_config\n    except (ImportError, ValueError, AttributeError):\n        # Fall back to default if config system is not available or configured\n        pass\n\n    return ContextConfig()\n\n\ndef update_context_config(**kwargs) -> ContextConfig:\n    \"\"\"Build a context configuration with the given values overridden.\n\n    The overrides are applied to the ContextConfig returned by get_context_config(),\n    which is a new object on each call; they are not written back to the main config.\n\n    Args:\n        **kwargs: Configuration parameters to override; unknown keys are ignored\n\n    Returns:\n        ContextConfig: Configuration with the overrides applied\n    \"\"\"\n    config = get_context_config()\n    for key, value in kwargs.items():\n        if hasattr(config, key):\n            setattr(config, key, value)\n    return config\n"
  },
  {
    "path": "openchatbi/context_manager.py",
    "content": "\"\"\"Context management utilities for handling long conversations.\"\"\"\n\nimport json\nimport re\nimport uuid\n\nfrom langchain_core.language_models import BaseChatModel\nfrom langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage\n\nfrom openchatbi.context_config import ContextConfig, get_context_config\nfrom openchatbi.llm.llm import call_llm_chat_model_with_retry\nfrom openchatbi.prompts.system_prompt import get_summary_prompt_template\nfrom openchatbi.utils import log\n\n\nclass ContextManager:\n    \"\"\"Manages conversation context to prevent token limit issues.\"\"\"\n\n    def __init__(self, llm: BaseChatModel, config: ContextConfig = None):\n        \"\"\"Initialize context manager.\n\n        Args:\n            llm: Language model for summarization\n            config: Context configuration. If None, uses default config.\n        \"\"\"\n        self.llm = llm\n        self.config = config or get_context_config()\n\n    # ============================================================================\n    # PUBLIC API METHODS\n    # ============================================================================\n\n    def manage_context_messages(self, messages: list) -> None:\n        \"\"\"Main context management function that directly modifies messages list.\n\n        Args:\n            messages: The list of messages to manage (modified in place)\n        \"\"\"\n        if not self.config.enabled:\n            return\n\n        if not messages:\n            return\n\n        # Check if we need to manage context\n        estimated_tokens = self.estimate_message_tokens(messages)\n        if estimated_tokens <= self.config.summary_trigger_tokens:\n            return  # No action needed\n\n        log(f\"Context management triggered: {estimated_tokens} tokens > {self.config.summary_trigger_tokens}\")\n\n        # Apply historical tool message compression directly\n        
self._compress_historical_tool_messages(messages)\n\n        # Check if we still need summarization after compression\n        remaining_tokens = self.estimate_message_tokens(messages)\n        if remaining_tokens > self.config.summary_trigger_tokens and self.config.enable_summarization:\n            self._apply_conversation_summarization(messages)\n\n        log(\"Context management completed\")\n\n    # ============================================================================\n    # TOKEN ESTIMATION METHODS\n    # ============================================================================\n\n    @staticmethod\n    def estimate_tokens(text: str) -> int:\n        \"\"\"Rough token estimation (1 token ≈ 4 characters for most languages).\"\"\"\n        return len(text) // 4\n\n    def estimate_message_tokens(self, messages: list[BaseMessage]) -> int:\n        \"\"\"Estimate total tokens in a list of messages.\"\"\"\n        total = 0\n        for msg in messages:\n            total += self.estimate_tokens(str(msg.content))\n            # Add tokens for metadata and structure\n            total += 50\n        return total\n\n    # ============================================================================\n    # TOOL OUTPUT TRIMMING METHODS\n    # ============================================================================\n\n    def trim_tool_output(self, content: str, tool_name: str = \"\") -> str:\n        \"\"\"Trim tool output to manageable size while preserving key information.\"\"\"\n        if len(content) <= self.config.max_tool_output_length:\n            return content\n\n        # Preserve full error messages if configured\n        if self.config.preserve_tool_errors and (\"Error:\" in content or \"Traceback\" in content):\n            return content\n\n        # For SQL results, preserve structure\n        if \"```sql\" in content or \"```csv\" in content:\n            return self._trim_structured_output(content)\n\n        # For code execution 
results\n        if \"```python\" in content or \"Traceback\" in content:\n            return self._trim_code_output(content)\n\n        # Generic trimming\n        max_len = self.config.max_tool_output_length\n        trimmed = content[: max_len // 2] + \"\\n\\n... [Output truncated] ...\\n\\n\" + content[-max_len // 2 :]\n        return trimmed\n\n    def _trim_structured_output(self, content: str) -> str:\n        \"\"\"Trim SQL/CSV output while preserving structure.\"\"\"\n        parts = []\n\n        # Extract SQL query (always keep)\n        sql_match = re.search(r\"```sql\\n(.*?)\\n```\", content, re.DOTALL)\n        if sql_match:\n            parts.append(f\"```sql\\n{sql_match.group(1)}\\n```\")\n\n        # Extract and trim CSV data\n        csv_match = re.search(r\"```csv\\n(.*?)\\n```\", content, re.DOTALL)\n        if csv_match:\n            csv_data = csv_match.group(1)\n            lines = csv_data.split(\"\\n\")\n            max_rows = self.config.max_sql_result_rows\n\n            if len(lines) > max_rows:  # Keep header + first half + last quarter\n                keep_start = max_rows // 2\n                keep_end = max_rows // 4\n                trimmed_csv = \"\\n\".join(\n                    lines[: keep_start + 1]\n                    + [f\"... 
[{len(lines) - keep_start - keep_end - 1} rows omitted] ...\"]\n                    + lines[-keep_end:]\n                )\n                parts.append(f\"```csv\\n{trimmed_csv}\\n```\")\n            else:\n                parts.append(f\"```csv\\n{csv_data}\\n```\")\n\n        # Keep visualization info\n        viz_match = re.search(r\"Visualization Created:.*\", content)\n        if viz_match:\n            parts.append(viz_match.group(0))\n\n        return \"\\n\\n\".join(parts)\n\n    def _trim_code_output(self, content: str) -> str:\n        \"\"\"Trim Python code execution output.\"\"\"\n        # Keep error messages (full) if configured\n        if self.config.preserve_tool_errors and (\"Traceback\" in content or \"Error:\" in content):\n            return content\n\n        lines = content.split(\"\\n\")\n        max_lines = self.config.max_code_output_lines\n\n        if len(lines) <= max_lines:\n            return content\n\n        # Keep first half and last quarter\n        keep_start = max_lines // 2\n        keep_end = max_lines // 4\n        return \"\\n\".join(lines[:keep_start] + [\"... 
[Output truncated] ...\"] + lines[-keep_end:])\n\n    # ============================================================================\n    # CONVERSATION SUMMARIZATION METHODS\n    # ============================================================================\n\n    def summarize_conversation(self, messages: list[BaseMessage]) -> str:\n        \"\"\"Create a summary of conversation history.\"\"\"\n        if not self.config.enable_conversation_summary:\n            return \"\"\n\n        # Filter out system messages for summarization\n        # Note: The messages passed in are already historical messages (split point already calculated)\n        messages_to_summarize = []\n        for msg in messages:\n            if not isinstance(msg, SystemMessage):\n                messages_to_summarize.append(msg)\n\n        if not messages_to_summarize:\n            return \"\"\n\n        # Create summarization prompt\n        conversation_text = self._format_messages_for_summary(messages_to_summarize)\n\n        # Get the summary prompt template from the file and replace placeholder\n        summary_prompt = get_summary_prompt_template().replace(\"[conversation_text]\", conversation_text)\n\n        try:\n            response = call_llm_chat_model_with_retry(\n                self.llm, [HumanMessage(content=summary_prompt)], parallel_tool_call=False\n            )\n\n            if isinstance(response, AIMessage):\n                return f\"[Conversation Summary]: {response.content}\"\n            return \"[Summary generation failed]\"\n\n        except Exception as e:\n            log(f\"Failed to generate conversation summary: {e}\")\n            return \"[Summary generation failed]\"\n\n    def _truncate_text(self, text: str, truncate_len: int = 500) -> str:\n        # do not truncate Conversation Summary\n        if text.startswith(\"[Conversation Summary]\"):\n            return text\n        if len(text) > truncate_len:\n            return text[:truncate_len] + \"... 
[truncated]\"\n        return text\n\n    def _truncate_text_or_list(self, content):\n        results = []\n        if isinstance(content, str):\n            results.append(self._truncate_text(content))\n        elif isinstance(content, list):\n            for item in content:\n                if isinstance(item, str):\n                    results.append(self._truncate_text(item))\n                elif isinstance(item, dict):\n                    if item[\"type\"] == \"text\":\n                        results.append(self._truncate_text(item[\"text\"]))\n                    elif item[\"type\"] == \"tool_use\":\n                        results.append(json.dumps(item))\n        return results\n\n    def _format_messages_for_summary(self, messages: list[BaseMessage]) -> str:\n        \"\"\"Format messages for summary generation.\"\"\"\n        formatted = []\n        max_messages = self.config.summary_max_messages\n\n        # Limit messages for summary context\n        for msg in messages[-max_messages:]:\n            if isinstance(msg, HumanMessage):\n                formatted.append(f\"<user> {msg.content} </user>\")\n            elif isinstance(msg, AIMessage):\n                content = msg.content or \"\"\n                formatted.append(\"<assistant>\")\n                formatted.extend(self._truncate_text_or_list(content))\n                formatted.append(\"</assistant>\")\n            elif isinstance(msg, ToolMessage):\n                formatted.append(\n                    f\"<tool_result> tool_call_id: {msg.tool_call_id},  \"\n                    f\"tool: {msg.name}, \"\n                    f\"status: {msg.status}, \"\n                    f\"result: {self._truncate_text_or_list(msg.content)} </tool_result>\"\n                )\n\n        return \"\\n\".join(formatted)\n\n    # ============================================================================\n    # CONTEXT MANAGEMENT IMPLEMENTATION METHODS\n    # 
============================================================================\n\n    def _compress_historical_tool_messages(self, messages: list[BaseMessage]) -> None:\n        \"\"\"Compress historical (not recent) tool messages in place.\"\"\"\n        # Find a safe split point\n        recent_start_index = self._find_safe_split_point(messages)\n\n        # Find tool messages in historical part (before recent_start_index) that need compression\n        for i in range(recent_start_index):\n            msg = messages[i]\n            if isinstance(msg, ToolMessage):\n                original_content = str(msg.content)\n\n                # Apply intelligent filtering for tool message compression\n                if self._should_compress_historical_tool_message(msg, original_content):\n                    trimmed_content = self.trim_tool_output(original_content)\n\n                    if len(trimmed_content) < len(original_content):\n                        # Update message content directly\n                        messages[i] = ToolMessage(\n                            content=trimmed_content,\n                            tool_call_id=msg.tool_call_id,\n                            id=msg.id,  # Keep original ID to preserve position\n                        )\n\n                        log(\n                            f\"Compressed historical tool message: {len(original_content)} -> {len(trimmed_content)} chars\"\n                        )\n\n    def _apply_conversation_summarization(self, messages: list[BaseMessage]) -> None:\n        \"\"\"Apply conversation summarization by modifying messages list in place.\"\"\"\n        if not self.config.enable_conversation_summary:\n            return\n\n        # Find a safe split point that doesn't separate AI messages with tool calls from their ToolMessages\n        recent_start_index = self._find_safe_split_point(messages)\n\n        if recent_start_index == 0:\n            return  # No historical messages to summarize\n\n  
      historical_messages = messages[:recent_start_index]\n        recent_messages = messages[recent_start_index:]\n\n        if len(historical_messages) == 1:\n            msg = historical_messages[0]\n            if isinstance(msg, AIMessage) and msg.content.startswith(\"[Conversation Summary]\"):\n                return\n\n        # Generate summary\n        summary_text = self.summarize_conversation(historical_messages)\n\n        if summary_text:\n            # Rebuild messages list in place: summary + recent\n            new_messages = [AIMessage(content=summary_text, id=str(uuid.uuid4()))] + recent_messages\n\n            # Clear and repopulate the list in place\n            messages.clear()\n            messages.extend(new_messages)\n\n            log(f\"Applied conversation summary, removed {len(historical_messages)} historical messages\")\n\n    def _find_safe_split_point(self, messages: list[BaseMessage]) -> int:\n        \"\"\"Find a safe split point that start at HumanMessage\n\n        Returns the index where recent messages should start (everything before this index is historical).\n        \"\"\"\n        if len(messages) <= self.config.keep_recent_messages:\n            return 0  # Keep all messages as recent\n\n        # If keep_recent_messages is 0, return all messages as historical\n        if self.config.keep_recent_messages <= 0:\n            return len(messages)\n\n        # Start from the naive split point\n        naive_split = len(messages) - self.config.keep_recent_messages\n\n        # Find the nearest HumanMessage\n        for i in range(naive_split, -1, -1):\n            msg = messages[i]\n            if isinstance(msg, HumanMessage) or isinstance(msg, dict) and msg[\"role\"] == \"user\":\n                return i  # Split before this HumanMessage\n\n        return naive_split\n\n    # ============================================================================\n    # CONTENT ANALYSIS HELPER METHODS\n    # 
============================================================================\n\n    def _should_compress_historical_tool_message(self, tool_msg: ToolMessage, content: str) -> bool:\n        \"\"\"Determine if a historical tool message should be compressed.\n\n        Args:\n            tool_msg: The tool message to evaluate\n            content: The content of the tool message\n\n        Returns:\n            bool: True if the message should be compressed\n        \"\"\"\n        # Don't compress if content is already short\n        if len(content) <= self.config.max_tool_output_length:\n            return False\n\n        # Always preserve error messages if configured\n        if self.config.preserve_tool_errors and self._is_error_content(content):\n            return False\n\n        # Don't compress recent SQL results if configured\n        if self.config.preserve_recent_sql and self._is_sql_content(content):\n            return False\n\n        # Compress large outputs from specific tools more aggressively\n        if self._is_data_query_result(content):\n            return True\n\n        # Compress Python execution results but preserve errors\n        if self._is_python_execution_result(content):\n            return not self._is_error_content(content)\n\n        # Default: compress if content is long\n        return True\n\n    def _is_error_content(self, content: str) -> bool:\n        \"\"\"Check if content contains error information.\"\"\"\n        error_indicators = [\n            \"error:\",\n            \"Error:\",\n            \"ERROR:\",\n            \"exception:\",\n            \"Exception:\",\n            \"EXCEPTION:\",\n            \"traceback\",\n            \"Traceback\",\n            \"TRACEBACK\",\n            \"failed\",\n            \"Failed\",\n            \"FAILED\",\n            \"KeyError\",\n            \"ValueError\",\n            \"TypeError\",\n            \"AttributeError\",\n            \"FileNotFoundError\",\n            
\"ConnectionError\",\n        ]\n        return any(indicator in content for indicator in error_indicators)\n\n    def _is_sql_content(self, content: str) -> bool:\n        \"\"\"Check if content contains SQL query results.\"\"\"\n        sql_indicators = [\n            \"```sql\",\n            \"query results\",\n            \"sql query:\",\n            \"select \",\n            \"insert \",\n            \"update \",\n            \"delete \",\n            \"create table\",\n            \"alter table\",\n        ]\n        content_lower = content.lower()\n        return any(indicator in content_lower for indicator in sql_indicators)\n\n    def _is_data_query_result(self, content: str) -> bool:\n        \"\"\"Check if content is a data query result that can be safely compressed.\"\"\"\n        indicators = [\n            \"```csv\",\n            \"query results\",\n            \"rows returned\",\n            \"records found\",\n            \"records in the database\",\n            \"found records\",\n            \"csv format\",\n        ]\n        content_lower = content.lower()\n        return any(indicator in content_lower for indicator in indicators)\n\n    def _is_python_execution_result(self, content: str) -> bool:\n        \"\"\"Check if content is Python code execution result.\"\"\"\n        indicators = [\n            \"```python\",\n            \"execution completed\",\n            \"output:\",\n            \"result:\",\n            \"print(\",\n        ]\n        return any(indicator.lower() in content.lower() for indicator in indicators)\n"
  },
  {
    "path": "openchatbi/graph_state.py",
    "content": "\"\"\"State classes for OpenChatBI graph execution.\"\"\"\n\nfrom typing import Annotated, Any\n\nfrom langchain_core.messages import AIMessage, HumanMessage\nfrom langgraph.graph import MessagesState\nfrom langgraph.types import Send\n\n\ndef add_history_messages(left: list, right: list):\n    if left:\n        total_messages = left + right\n    else:\n        total_messages = right\n    return total_messages\n\n\nclass AgentState(MessagesState):\n    \"\"\"State for the main agent graph execution.\n\n    Extends MessagesState with additional fields for routing and responses.\n    \"\"\"\n\n    history_messages: Annotated[list[HumanMessage | AIMessage], add_history_messages]\n    agent_next_node: str\n    sends: list[Send]\n    sql: str\n    final_answer: str\n\n\nclass SQLGraphState(MessagesState):\n    \"\"\"State for SQL generation subgraph.\n\n    Contains rewritten question, table selection, extracted entities, and generated SQL.\n    \"\"\"\n\n    rewrite_question: str\n    tables: list[dict[str, Any]]\n    info_entities: dict[str, Any]\n    sql: str\n    sql_retry_count: int\n    sql_execution_result: str\n    schema_info: dict[str, Any]  # Data schema analysis results\n    data: str  # CSV data for display\n    previous_sql_errors: list[dict[str, Any]]\n    visualization_dsl: dict[str, Any]\n\n\nclass InputState(MessagesState):\n    \"\"\"Input state schema for the main graph.\"\"\"\n\n    pass\n\n\nclass OutputState(MessagesState):\n    \"\"\"Output state schema for the main graph.\"\"\"\n\n    pass\n\n\nclass SQLOutputState(MessagesState):\n    \"\"\"Output state schema for the SQL generation subgraph.\"\"\"\n\n    rewrite_question: str\n    tables: list[dict[str, Any]]\n    sql: str\n    schema_info: dict[str, Any]  # Data schema analysis results\n    data: str  # CSV data for display\n    visualization_dsl: dict[str, Any]\n"
  },
  {
    "path": "openchatbi/llm/llm.py",
    "content": "import time\nimport traceback\n\nfrom langchain_core.language_models import BaseChatModel\nfrom langchain_core.runnables.base import RunnableBinding\nfrom langchain_core.tools import StructuredTool\n\nfrom openchatbi import config\nfrom openchatbi.tool.ask_human import AskHuman\nfrom openchatbi.utils import log\n\n\ndef list_llm_providers() -> list[str]:\n    \"\"\"List configured LLM provider names (if any).\"\"\"\n    try:\n        providers = getattr(config.get(), \"llm_providers\", None) or {}\n    except ValueError:\n        return []\n    return sorted(providers.keys())\n\n\ndef _get_provider_config(provider: str | None):\n    cfg = config.get()\n    providers = getattr(cfg, \"llm_providers\", None) or {}\n    if not provider:\n        provider = getattr(cfg, \"llm_provider\", None)\n    if not provider:\n        return None\n    if provider not in providers:\n        raise ValueError(f\"Unknown llm_provider '{provider}'. Available: {sorted(providers.keys())}\")\n    return providers[provider]\n\n\ndef get_embedding_model(provider: str | None = None):\n    \"\"\"Get embedding model from config (optionally scoped to a provider).\"\"\"\n    provider_cfg = _get_provider_config(provider)\n    if provider_cfg and getattr(provider_cfg, \"embedding_model\", None) is not None:\n        return provider_cfg.embedding_model\n    return config.get().embedding_model\n\n\ndef get_default_llm(provider: str | None = None):\n    \"\"\"Get default LLM from config (optionally scoped to a provider).\"\"\"\n    provider_cfg = _get_provider_config(provider)\n    if provider_cfg:\n        return provider_cfg.default_llm\n    return config.get().default_llm\n\n\ndef get_llm(provider: str | None = None):\n    \"\"\"Get the chat model to use (alias for `get_default_llm`).\"\"\"\n    return get_default_llm(provider)\n\n\ndef get_text2sql_llm(provider: str | None = None):\n    \"\"\"Get text2sql LLM from config (optionally scoped to a provider).\"\"\"\n    provider_cfg = 
_get_provider_config(provider)\n    if provider_cfg:\n        return provider_cfg.text2sql_llm or provider_cfg.default_llm\n    return config.get().text2sql_llm or get_default_llm()\n\n\ndef _invalid_tool_names(valid_tools, tool_calls) -> str:\n    invalid_tools = []\n    for tool in tool_calls:\n        if tool[\"name\"] not in valid_tools:\n            invalid_tools.append(tool[\"name\"])\n    return \",\".join(invalid_tools)\n\n\ndef call_llm_chat_model_with_retry(\n    chat_model: BaseChatModel, messages, streaming_tokens=False, bound_tools=None, parallel_tool_call=False\n):\n    \"\"\"Calls a language model chat endpoint with retry logic.\n\n    Retries up to 3 times if there are errors or invalid tool calls.\n\n    Args:\n        chat_model: The chat model to invoke.\n        messages (list): List of messages to send to the model.\n        streaming_tokens (bool, optional): flag to indicate whether or not to show streaming tokens in UI.\n        bound_tools (list, optional): List of valid tool names that can be called.\n        parallel_tool_call (bool, optional): whether or not to call multiple tools in parallel.\n\n    Returns:\n        AIMessage or None: The model response or None if all retries failed.\n    \"\"\"\n    new_messages = list(messages)\n    valid_tools = []\n    if bound_tools:\n        for tool in bound_tools:\n            if isinstance(tool, str):\n                valid_tools.append(tool)\n            elif isinstance(tool, StructuredTool):\n                valid_tools.append(tool.name)\n            elif tool == AskHuman:\n                valid_tools.append(\"AskHuman\")\n    elif isinstance(chat_model, RunnableBinding) and \"tools\" in chat_model.kwargs:\n        valid_tools += [tool[\"name\"] for tool in chat_model.kwargs[\"tools\"] if \"name\" in tool]\n    extra_prompt = (\n        \" Please select the `AskHuman` tool if you need to confirm with user.\" if \"AskHuman\" in valid_tools else \"\"\n    )\n    response = None\n    retry = 0\n 
   # retry 3 times\n    while retry < 3:\n        start_time = time.time()\n        try:\n            log(f\"Call LLM chat model with retry {retry} times.\")\n            response = chat_model.invoke(new_messages, config={\"metadata\": {\"streaming_tokens\": streaming_tokens}})\n            run_time = int(time.time() - start_time)\n            log(f\"LLM response after {run_time} seconds.\")\n        except Exception:\n            run_time = int(time.time() - start_time)\n            retry += 1\n            log(f\"LLM response error after {run_time} seconds, retry {retry} times.\")\n            log(\"===== Messages:\")\n            log(str(messages))\n            traceback.print_exc()\n            continue\n\n        if response.tool_calls:\n            if len(response.tool_calls) > 1 and not parallel_tool_call:\n                retry += 1\n                log(f\"More than one tool {response.tool_calls}, retry {retry} times.\")\n                new_messages += [{\"role\": \"user\", \"content\": \"You should only response with one tool call.\"}]\n                response = None\n                continue\n            invalid_tools = _invalid_tool_names(valid_tools, response.tool_calls)\n            if invalid_tools:\n                retry += 1\n                log(f\"Invalid tool {invalid_tools}, retry {retry} times.\")\n                new_messages += [\n                    {\n                        \"role\": \"user\",\n                        \"content\": f\"You should not use tool that does not exist:`{invalid_tools}`.\"\n                        f\"Available tools are: {valid_tools}. Please choose a valid tool and try again.\"\n                        f\"{extra_prompt}\",\n                    }\n                ]\n                response = None\n                continue\n        break\n    return response\n"
  },
  {
    "path": "openchatbi/prompts/agent_prompt.md",
    "content": "You are a helpful BI assistant that can answer user's question. \nUse the instructions below and the tools available to you to assist the user.\n\n# Capabilities:\n1. Answer general question.\n2. Answer question based on knowledge base.\n3. Answer question regarding data query by call the SQL graph to write SQL to answer the question.\n4. Answer question that need to analyze the data by write and execute python code\n\n# Guidelines:\n- You should be concise, direct, and to the point.\n- No fabricate information, if you don't know, just say you don't know.\n- Summarize the information you found to answer the question.\n- When data analysis results include \"Visualization Created\" message, acknowledge that an interactive chart has been automatically generated and focus on interpreting the data insights rather than creating additional charts.\n\n\n# Tool usage policy\n- If you cannot answer the question, call tools that are available.\n- For `run_python_code` tool, you can use these libs when writing python code: pandas numpy matplotlib seaborn requests json5\n- IMPORTANT: DO NOT create charts/visualizations with Python code if the text2sql tool response already indicates \"Visualization Created\". The interactive chart is automatically generated and displayed in the UI. 
Simply summarize the results without duplicating the visualization.\n- If the user provides personalized information that needs to be remembered, or wants to forget or correct something mentioned before, use the `manage_memory` tool to save, delete, or update the long-term memory\n- If the question is related to user information, characteristics, or preferences, proactively use the `search_memory` tool to get the long-term memory\n- If the question is not clear, or some information is missing, ask the user to clarify by calling the AskHuman tool.\n- When generating reports, analysis results, or data summaries that users might want to save or share, use the `save_report` tool to save the content to a file and provide a download link.\n- **When text2sql tool returns empty SQL**: This indicates the current data capabilities cannot support the requested query. Explain to the user that the requested data or analysis is not available in the current system, and suggest alternative queries that might be supported based on available data sources.\n\n## Knowledge Search Optimization\n- **AVOID excessive knowledge searches** for data queries that contain standard business terms already covered in your basic knowledge\n- **ONLY search knowledge** when:\n  - User asks about unfamiliar business terms, metrics, or dimensions not in basic knowledge\n  - Question contains ambiguous terminology that needs clarification\n  - Need to understand complex business relationships or derived metrics\n  - User explicitly asks \"what is [term]\" or requests definitions\n- **SKIP knowledge search** for straightforward data queries since `text2sql` tool will handle it\n- **Prioritize direct SQL execution** over knowledge lookup for routine data analysis requests\n\n[extra_tool_use_rule]\n\n# Basic Business Knowledge:\n[basic_knowledge_glossary]\n\n# Realtime Environment\n\nCurrent time is [time_field_placeholder] (format 'yyyy-MM-dd HH:mm:ss')\n\nReview current state and decide what to do next.\nIf the information is 
sufficient to answer the question, generate the well summarized final answer.\n\n"
  },
  {
    "path": "openchatbi/prompts/extraction_prompt.md",
    "content": "You are a specialized language expert responsible for analyzing user questions and extracting structured information for business intelligence queries. \nYour task is to process natural language questions and convert them into structured data that can be used for SQL generation and data analysis.\n\n# Context\nYou will be provided with:\n- Business knowledge glossary of [organization]\n- User question\n- Chat history (if available)\n\n[basic_knowledge_glossary]\n\n# Core Processing Steps\n\n## Step 1: Information Extraction\nExtract and categorize the following information from the user's question and context:\n\n### 1.1 Keywords (Required Array)\nExtract all relevant business terms, including:\n- Dimension names and aliases\n- Metric names and aliases\n- Entity types (exclude specific IDs/values)\n\n**Example**: \"Show revenue for order 10001\" → Extract: [\"revenue\", \"order\"] (exclude \"10001\")\n\n### 1.2 Dimensions (Required Array)\nIdentify categorical data fields that can be used for grouping or filtering:\n- Database column names (e.g., \"order_id\", \"country\", \"site_id\")\n- Distinguish between ID fields (numeric identifiers) and name fields (text labels)\n\n### 1.3 Metrics (Optional Array)\nIdentify measurable quantities that can be aggregated:\n- Numeric values that can be summed, averaged, counted, etc.\n- For derived metrics (defined in glossary), extract all component parts\n  - Example: For \"click-through rate\", extract [\"click-through rate\", \"clicks\", \"impressions\"]\n\n### 1.4 Time Range (Optional)\n**start_time** and **end_time**: Convert relative time expressions to absolute timestamps if the question is related to date/time like trends, aggregated metric, etc.\n- Format: `'%Y-%m-%d %H:%M:%S'`\n- Handle expressions like \"yesterday\", \"last 7 days\", \"from X to Y\"\n- Default to \"last 7 days\" if no time range and granularity specified\n- Specific default if user mentioned granularity:\n  - Weekly -> \"last 12 
weeks\"\n  - Monthly -> \"last 12 months\"\n  - Yearly -> \"Full data\"\n\n**Example**:\n```\nQuestion: \"show top 10 ads by CTR yesterday\" (today = 2025-05-11)\nstart_time: \"2025-05-10 00:00:00\"\nend_time: \"2025-05-10 23:59:59\"\n```\n\n### 1.5 Timezone (Optional)\nExtract timezone information using this priority:\n1. Explicit mention in current question (e.g., \"in CET\", \"EST time\")\n2. Previously mentioned timezone in conversation history\n3. Reset timezone requests → \"UTC\"\n\n**Common formats**: \"America/New_York\", \"CET\", \"UTC\", \"Europe/London\"\n\n## Step 2: Filter Conditions\nGenerate SQL-compatible filter expressions:\n\n**Rules**:\n- **Text matching**: Use `LIKE '%text%'` for partial name matches\n- **Exact IDs**: Use `=` for numeric identifiers\n- **Missing context**: Generate `AskHuman` tool call for clarification\n\n**Examples**:\n- \"profile 1234\" → `[\"profile_id=1234\"]`\n- \"exam sites\" → `[\"site_name LIKE '%exam%'\"]`\n- \"the site\" (no context) → Ask for clarification\n\n## Step 3: Question Rewriting\nTransform the original question into a clear, comprehensive query specification.\n\n**Process**:\n1. **Analysis**: Break down each component of the user's request\n2. **Verification**: Confirm all elements are understood and unambiguous\n3. 
**Rewrite**: Create detailed, explicit version with no ambiguity\n\n**Enhancement Rules**:\n- Add metric definitions in brackets: \"CTR\" → \"click-through rate (clicks/impressions)\"\n- Include default time range if none specified\n- Include visualization preference if provided by user\n- Preserve user intent while adding necessary context\n- Use conversation history to fill gaps\n\n# Knowledge Search Decision\n\nBefore extracting information, determine if knowledge search is needed:\n\n## When to Search Knowledge (use `search_knowledge` tool):\n- **Unfamiliar terms**: Business-specific jargon, custom metrics, or domain acronyms not in basic knowledge\n- **Ambiguous terminology**: Terms that could have multiple meanings in business context\n- **Complex derived metrics**: Multi-component calculations requiring formula understanding\n- **Explicit requests**: User asks \"what is [term]\" or requests definitions\n\n## When to Skip Knowledge Search (proceed with JSON extraction):\n- **Standard business terms**: Common metrics (revenue, orders, users, clicks, CTR, conversion rate)\n- **Basic dimensions**: Standard fields (date, time, location, category, status, id)\n- **Clear data requests**: Simple queries with well-understood terminology\n- **Routine analytics**: Top N, totals, averages, trends with common business terms\n\n**Decision rule**: Only search knowledge if you encounter terms that are NOT covered in your basic business knowledge or if terminology is genuinely ambiguous in the business context.\n\n# Output Format\n\nReturn a JSON object with the following structure:\n\n```json\n{\n  \"reasoning\": \"Step-by-step analysis of user input and decision-making process\",\n  \"keywords\": [\"array\", \"of\", \"extracted\", \"keywords\"],\n  \"dimensions\": [\"array\", \"of\", \"dimension\", \"names\"],\n  \"metrics\": [\"array\", \"of\", \"metric\", \"names\"],\n  \"filter\": [\"array\", \"of\", \"sql\", \"expressions\"],\n  \"start_time\": \"YYYY-MM-DD 
HH:MM:SS\",\n  \"end_time\": \"YYYY-MM-DD HH:MM:SS\",\n  \"timezone\": \"timezone_identifier\",\n  \"rewrite_question\": \"Complete and detailed question rewrite\"\n}\n```\n\n# Quality Guidelines\n\n## Data Consistency\n- If a dimension appears in filters, include it in the dimensions array\n- Extract all aliases for derived metrics as defined in the glossary\n\n## Accuracy Rules\n- **No fabrication**: Only use information present in context or glossary\n- **Prioritization**: Current question takes precedence over chat history\n- **Completeness**: Use chat history to fill gaps when current question lacks detail\n\n## Output Formatting\n- **Standard response**: JSON wrapped in ```json code blocks\n- **Clarification needed**: Generate `AskHuman` tool call instead of JSON\n- **Required fields**: Always include `reasoning`, `keywords`, `dimensions`, `filter`, `rewrite_question`\n\n# Comprehensive Example\n\n**Input Question**: \"Show me site 1001's CTR trend from 2024-04-01 to 2024-04-10\"\n\n**Expected Output**:\n```json\n{\n  \"reasoning\": \"User wants to analyze click-through rate trends for a specific site. Breaking down the request: 1) Site identifier: 1001 (numeric ID), 2) Metric: CTR (click-through rate, calculated as clicks/impressions), 3) Analysis type: trend (time-based progression), 4) Time range: 2024-04-01 to 2024-04-10 (10-day period). Since it's a short time range, hourly granularity is most appropriate for trend analysis. 
All components are clear and complete.\",\n  \"keywords\": [\"site\", \"click-through rate\", \"CTR\", \"clicks\", \"impressions\", \"trend\"],\n  \"dimensions\": [\"site_id\"],\n  \"metrics\": [\"click-through rate\", \"clicks\", \"impressions\"],\n  \"filter\": [\"site_id=1001\"],\n  \"start_time\": \"2024-04-01 00:00:00\",\n  \"end_time\": \"2024-04-11 00:00:00\",\n  \"rewrite_question\": \"Show me the hourly click-through rate (calculated as clicks/impressions) trend for site_id = 1001 from 2024-04-01 to 2024-04-10\"\n}\n```\n\n# Special Cases\n\n## Case 1: Insufficient Information\n**Input**: \"Show me revenue trends for the site\"\n**Action**: Generate `AskHuman` tool call requesting site identification\n\n## Case 2: Conversation Context Usage\n**Previous**: \"Let's analyze site ABC performance\"\n**Current**: \"Show me CTR for last week\"\n**Result**: Inherit site \"ABC\" context\n\n## Case 3: Timezone Handling\n**Input**: \"Yesterday's metrics in EST\"\n**Result**: Extract timezone=\"America/New_York\", calculate yesterday in EST\n\n# Environment Variables\n- Current date: `[time_field_placeholder]`\\\n"
  },
  {
    "path": "openchatbi/prompts/schema_linking_prompt.md",
    "content": "You are a language expert and professional SQL engineer tasked with analyzing questions from [organization] users and selecting the appropriate table to write SQL. \n- You need to analyze the user's question, find the possible dimensions and metrics, and then select the tables and all required columns related to the query. \n- I will give you the business knowledge introduction and the glossaries of [organization] for reference.\n- I will give you the data warehouse introduction about how these tables are generated and organized.\n- I will give you the candidate tables and their schema, read the table description and rule carefully to understand the purpose and capability of the table, and select the appropriate tables and columns.\n\n[basic_knowledge_glossary]\n\n[data_warehouse_introduction]\n\n# Candidate Tables\nI found the following tables and their relevant columns and descriptions that might contain the data the user is looking for.\n[tables]\n\n\n# Examples\nHere are some examples of questions and selected tables related to the user's question\n[examples]\n\n\n# General Rules\n- Must follow the table description and rule to select the table first\n- If it is not clear which table to select, you can check the columns in the table to find the columns most related to the question\n- The \"Candidate Tables\" contain all the tables and columns you can use, NEVER make up columns or tables.\n- VERY IMPORTANT: the columns you outputted **MUST** be contained in the table you selected, as described in the \"# Candidate Tables\" section.\n- If the question is asking about the metadata of an entity only, you should find a suitable dimension table\n- If the question needs to join the fact table with the dimension table, you should also output the dimension table\n- If there are very similar questions in examples, you can refer to the selected tables in examples.\n- If there are multiple tables that both meet the requirements, you should select the most 
relevant one.\n- Select and output multiple tables when a single table does not contain all required fields and a join across multiple tables is needed.\n\n\n# Output Format \nYou should output a JSON object, it should include:\n   - tables: JSON array of selected tables and columns\n     - table: The selected table\n     - columns: The columns in the table that are related to the question\n   - reasoning: The reasoning behind the table selection\nStrictly only output the format of JSON below, and do not output any extra description content.\n\n\n## Example\n```json\n{\n    \"reasoning\": \"the reason you select the two tables and columns\",\n    \"tables\": [\n      {\n        \"table\": \"table_name1\",\n        \"columns\": [\"column1\", \"column2\", \"column3\"]\n      },\n      {\n        \"table\": \"table_name2\",\n        \"columns\": [\"column4\", \"column5\"]\n      }]\n}\n```\n"
  },
  {
    "path": "openchatbi/prompts/sql_dialect/presto.md",
    "content": "# Rules for Presto SQL\n- Use 'LIKE' instead of 'ILIKE' in the Presto SQL.\n- If there is a 'GROUP BY' clause in the user query, you can use serial number(1,2,3..) instead of the column names in the 'GROUP BY' clause.\n- When filter Array type dimension, use ARRAYS_OVERLAP, e.g. ARRAYS_OVERLAP(states, ARRAY['CA'])\n- If you have to write two SQL statements, ensure to separate them with a semicolon `;`.\n- If there is no 'limit' or 'top' count mentioned in user question, default use \"LIMIT 10000\" in the SQL query. \n- If the SQL you provide includes fuzzy matching filters (e.g., 'name LIKE ...'), you should apply a GROUP BY clause for this dimension to handle cases where multiple rows have similar names.\n- If a table name is referenced multiple times in the SQL query (e.g., table_name.column), assign an alias to the table to simplify the query, such as (a.column).\n- If you need to write an SQL statement that involves division between two columns, use CAST to convert the numerator to a double type and ensure the denominator is not zero. 
For example: CASE WHEN SUM(column2)=0 THEN 0 ELSE CAST(SUM(column1) AS DOUBLE) / SUM(column2) END\n\n## Datetime filter related rules\n- Please use \"INTERVAL '7' DAY\" instead of \"INTERVAL '1' WEEK\" in the Presto SQL you generate.\n- Do not use \"DATE_SUB\" or \"DATE_ADD\", You can use only datetime calculation like \"NOW() - INTERVAL '1' DAY\".\n- Use `NOW()` instead of `CURRENT_DATE` when you need to get the current date.\n- If user ask date range measured in days or months, you need to make sure the end date is included, example: \"from 2025-03-12 to 2025-03-16\", the condition should be `WHERE event_date >= timestamp '2025-03-12 00:00:00' and event_date < timestamp '2025-03-17 00:00:00'`\n- If user ask for time range measured in hours, the end time is not included, example: \"from 2025-03-12 00:00:00 to 2025-03-16 00:00:00\", the condition should be `WHERE event_date >= timestamp '2025-03-12 00:00:00' and event_date < timestamp '2025-03-16 00:00:00'`\n- If user ask for daily breakdown, make sure the whole day are correctly filtered, e.g. \"last 7 days per day\" should be `WHERE event_date >= DATE_TRUNC('day', (NOW() - INTERVAL '7' DAY)) AND event_date < DATE_TRUNC('day', NOW())`\n\n## Rules for Timezone\n### 1. Default Timezone\nAll event_date in the table are stored in **UTC**. If the user specifies a timezone (e.g., CET, PST), convert between timezones accordingly.\n### 2. Timezone Conversion Syntax \n- Use `AT TIME ZONE` to convert event_date to other timezone. Example, to convert to CET: `event_date_expr AT TIME ZONE 'CET'`\n- Use `with_timezone` function to define a constant timestamp with timezone. Example, 2025-05-06 00:00 at CET: `with_timezone(timestamp '2025-05-06 00:00:00', 'CET')`\n### 3. WHERE Clause Conversion\n- If the user query includes absolute(constant) filters with a specific timezone, convert the timestamp with user timezone to UTC. 
Keep relative time filters unchanged.\n  - Example: `WHERE event_date >= timestamp '2025-01-01 00:00:00' and event_date < NOW() - INTERVAL '1' DAY` ->\n    `WHERE event_date >= with_timezone(timestamp '2025-01-01 00:00:00', 'CET') AT TIME ZONE 'UTC' AND event_date < NOW() - INTERVAL '1' DAY`\n- Instruction when user ask for daily breakdown with timezone\n  - If ask for relative date, the filter condition should use the date as \"trunc date at that timezone first, then convert to UTC\"\n    - Example: \"last 7 days per day in NY time\" should be `AND event_date >= DATE_TRUNC('day', (NOW() - INTERVAL '7' DAY) AT TIME ZONE 'America/New_York') AT TIME ZONE 'UTC' AND event_date < DATE_TRUNC('day', NOW() AT TIME ZONE 'America/New_York') AT TIME ZONE 'UTC'`\n  - If ask for absolute(constant) date, the filter condition should convert the 00:00 timestamp with user timezone to UTC\n    - Example: \"during 2025-05-04 and 2025-05-14 per day in NY time\" should be `AND event_date >= with_timezone(timestamp '2025-05-04 00:00:00', 'America/New_York') AT TIME ZONE 'UTC' AND event_date < with_timezone(timestamp '2025-05-15 00:00:00', 'America/New_York') AT TIME ZONE 'UTC'`\n### 4. SELECT Clause Conversion\n- If applying a function f to a datetime column (e.g.date_trunc, format_datetime), convert the event_date from UTC to the user’s timezone before applying f, then cast the result as TIMESTAMP.\n- Example: `SELECT f(event_date) AS event_date`-> `SELECT CAST(f(event_date AT TIME ZONE 'CET') AS TIMESTAMP) AS event_date`\n### 5. 
Full Example\n- User Question: \"Show me hourly pv using table fact_table from 2025-01-01 to yesterday in CET\"\n- Generated SQL:\n```\nSELECT \n  CAST(date_trunc('hour', event_date AT TIME ZONE 'CET') AS TIMESTAMP) AS event_date,\n  SUM(pv) AS \"PV\"\nFROM fact_table\nWHERE \n  event_date >= with_timezone(timestamp '2025-01-01 00:00:00', 'CET') AT TIME ZONE 'UTC'\n  AND event_date < (NOW() - INTERVAL '1' DAY)\n```\n\n## Rules for Array Dimension\n- Filtering: Use ARRAYS_OVERLAP \n  - When filter value in Array type dimension , use ARRAYS_OVERLAP, e.g. ARRAYS_OVERLAP(states, ARRAY['CA'])\n- Flattening Arrays: Use CROSS JOIN UNNEST \n  - When flattening Array type dimension items from a row to multi rows, use `CROSS JOIN UNNEST(COALESCE(NULLIF(id_array, ARRAY[]), ARRAY[-1])) AS t(id)`; MAKE SURE the unnested alias has format like `t(id)`\n  - Additionally, when the subquery uses CROSS JOIN UNNEST, do not sum the metrics for the total array items without group by the unnested id.\n- Avoid UNNEST if the user didn’t request it \n  - If the user refers to an array-type dimension without specifying any particular item, or does not ask to expand it into individual elements, you should not use CROSS JOIN UNNEST.\n"
  },
  {
    "path": "openchatbi/prompts/summary_prompt.md",
    "content": "Create a concise summary of this conversation for continuing the data analysis work. Focus on:\n\n1. **User's Main Questions and Objectives**: What data insights or analysis the user is seeking, business questions they want answered\n2. **Key Data Analysis Results**: Important findings from SQL queries, relevant tables/columns discovered, key metrics or patterns identified\n3. **Tools and Data Sources Overview**:\n   - Key databases/tables accessed\n   - Main analysis tools used (SQL, Python, etc.)\n   - Data export formats generated\n4. **Business Context**: Important business concepts, domain-specific terms, data definitions that were clarified\n5. **Conversation Flow**: List ALL user messages with corresponding response summaries. These are critical for understanding user feedback and changing intent:\n   - User message -> Response summary (what was done/analyzed)\n   - User feedback -> How the analysis was adjusted\n   - Follow-up questions -> Additional insights provided\n6. **Current Progress**: What analysis was completed, any ongoing tasks, user feedback on results or requested modifications\n\n\nHere's an example of how your output should be structured:\n\n<example>\n1. **User's Main Questions and Objectives**:\n   [User's data analysis goals and business questions]\n\n2. **Key Data Analysis Results**:\n   - [Important SQL query results]\n   - [Key tables and relationships discovered]\n   - [Metrics and insights found]\n\n3. **Tools and Data Sources Overview**:\n   - [Databases: customer_db, sales_warehouse]\n   - [Analysis tools: SQL, Python pandas]\n   - [Exports: CSV reports, dashboard charts]\n\n4. **Business Context**:\n   - [Business concepts and terminology]\n   - [Data field definitions]\n   - [Domain knowledge gained]\n\n5. 
**Conversation Flow**:\n   - [User message 1: original request] -> [Response: SQL query executed, found X insights]\n   - [User message 2: clarification about metric Y] -> [Response: adjusted analysis, discovered Z pattern]\n   - [User message 3: follow-up question] -> [Response: additional exploration, provided interpretation]\n\n6. **Current Progress**:\n   [What was completed, ongoing work, and user feedback]\n</example>\n\nConversation to summarize:\n[conversation_text]\n\nPlease provide your summary based on the conversation so far, following this structure and ensuring precision and thoroughness in your response.\n"
  },
  {
    "path": "openchatbi/prompts/system_prompt.py",
    "content": "\"\"\"System prompt templates and business configuration.\"\"\"\n\nimport importlib.resources\n\nfrom openchatbi import config\n\n# Global cache variables for lazy loading (only for file I/O operations)\n_dialect_rules_cache = None\n_agent_prompt_template_cache = None\n_extraction_prompt_template_cache = None\n_table_selection_prompt_template_cache = None\n_text2sql_prompt_template_cache = None\n_visualization_prompt_template_cache = None\n_summary_prompt_template_cache = None\n\n\ndef get_basic_knowledge():\n    \"\"\"Get basic knowledge from config.\"\"\"\n    try:\n        return config.get().bi_config.get(\"basic_knowledge_glossary\", \"\")\n    except ValueError:\n        return \"\"\n\n\ndef get_data_warehouse_introduction():\n    \"\"\"Get data warehouse introduction from config.\"\"\"\n    try:\n        return config.get().bi_config.get(\"data_warehouse_introduction\", \"\")\n    except ValueError:\n        return \"\"\n\n\ndef get_agent_extra_tool_use_rule():\n    \"\"\"Get agent extra tool use rule from config.\"\"\"\n    try:\n        return config.get().bi_config.get(\"extra_tool_use_rule\", \"\")\n    except ValueError:\n        return \"\"\n\n\ndef get_organization():\n    \"\"\"Get organization from config.\"\"\"\n    try:\n        return config.get().organization\n    except ValueError:\n        return \"The Company\"\n\n\ndef get_dialect_rules():\n    \"\"\"Get SQL dialect rules with lazy loading and caching.\"\"\"\n    global _dialect_rules_cache\n    if _dialect_rules_cache is None:\n        dialect_dir = importlib.resources.files(\"openchatbi.prompts.sql_dialect\")\n        _dialect_rules_cache = {}\n\n        for item in dialect_dir.iterdir():\n            if item.is_file() and item.name.endswith(\".md\"):\n                dialect_name = item.name[:-3]\n                with item.open() as f:\n                    prompt = f.read()\n                    _dialect_rules_cache[dialect_name] = prompt\n    return 
_dialect_rules_cache\n\n\ndef get_agent_prompt_template() -> str:\n    \"\"\"Get agent prompt template with caching.\"\"\"\n    global _agent_prompt_template_cache\n    if _agent_prompt_template_cache is None:\n        with importlib.resources.files(\"openchatbi.prompts\").joinpath(\"agent_prompt.md\").open(\"r\") as f:\n            prompt = f.read()\n\n        _agent_prompt_template_cache = (\n            prompt.replace(\"[organization]\", get_organization())\n            .replace(\"[basic_knowledge_glossary]\", get_basic_knowledge())\n            .replace(\"[extra_tool_use_rule]\", get_agent_extra_tool_use_rule())\n        )\n    return _agent_prompt_template_cache\n\n\ndef get_extraction_prompt_template() -> str:\n    \"\"\"Get extraction prompt template with caching.\"\"\"\n    global _extraction_prompt_template_cache\n    if _extraction_prompt_template_cache is None:\n        with importlib.resources.files(\"openchatbi.prompts\").joinpath(\"extraction_prompt.md\").open(\"r\") as f:\n            prompt = f.read()\n\n        _extraction_prompt_template_cache = prompt.replace(\"[organization]\", get_organization()).replace(\n            \"[basic_knowledge_glossary]\", get_basic_knowledge()\n        )\n    return _extraction_prompt_template_cache\n\n\ndef get_table_selection_prompt_template() -> str:\n    \"\"\"Get table selection prompt template with caching.\"\"\"\n    global _table_selection_prompt_template_cache\n    if _table_selection_prompt_template_cache is None:\n        with importlib.resources.files(\"openchatbi.prompts\").joinpath(\"schema_linking_prompt.md\").open(\"r\") as f:\n            prompt = f.read()\n        _table_selection_prompt_template_cache = prompt.replace(\"[organization]\", get_organization()).replace(\n            \"[basic_knowledge_glossary]\", get_basic_knowledge()\n        )\n    return _table_selection_prompt_template_cache\n\n\ndef get_text2sql_prompt_template() -> str:\n    \"\"\"Get text2sql prompt template with 
caching.\"\"\"\n    global _text2sql_prompt_template_cache\n    if _text2sql_prompt_template_cache is None:\n        with importlib.resources.files(\"openchatbi.prompts\").joinpath(\"text2sql_prompt.md\").open(\"r\") as f:\n            prompt = f.read()\n        _text2sql_prompt_template_cache = (\n            prompt.replace(\"[organization]\", get_organization())\n            .replace(\"[basic_knowledge_glossary]\", get_basic_knowledge())\n            .replace(\"[data_warehouse_introduction]\", get_data_warehouse_introduction())\n        )\n    return _text2sql_prompt_template_cache\n\n\ndef get_visualization_prompt_template() -> str:\n    \"\"\"Get visualization prompt template with caching.\"\"\"\n    global _visualization_prompt_template_cache\n    if _visualization_prompt_template_cache is None:\n        with importlib.resources.files(\"openchatbi.prompts\").joinpath(\"visualization_prompt.md\").open(\"r\") as f:\n            _visualization_prompt_template_cache = f.read()\n    return _visualization_prompt_template_cache\n\n\ndef get_summary_prompt_template() -> str:\n    \"\"\"Get summary prompt template with caching.\"\"\"\n    global _summary_prompt_template_cache\n    if _summary_prompt_template_cache is None:\n        with importlib.resources.files(\"openchatbi.prompts\").joinpath(\"summary_prompt.md\").open(\"r\") as f:\n            _summary_prompt_template_cache = f.read()\n    return _summary_prompt_template_cache\n\n\ndef get_text2sql_dialect_prompt_template(dialect: str) -> str:\n    \"\"\"Get text2sql prompt template for specific SQL dialect.\"\"\"\n    prompt = get_text2sql_prompt_template()\n    if not prompt:\n        prompt = \"Generate SQL query for the given question in [dialect] dialect.\"\n\n    dialect_rules = get_dialect_rules()\n    prompt = prompt.replace(\"[dialect]\", dialect).replace(\"[sql_dialect_rules]\", dialect_rules.get(dialect, \"\"))\n    return prompt\n\n\ndef reset_cache():\n    \"\"\"Reset all cached values. 
Useful for testing.\"\"\"\n    global _dialect_rules_cache, _agent_prompt_template_cache\n    global _extraction_prompt_template_cache, _table_selection_prompt_template_cache\n    global _text2sql_prompt_template_cache, _visualization_prompt_template_cache\n    global _summary_prompt_template_cache\n\n    _dialect_rules_cache = None\n    _agent_prompt_template_cache = None\n    _extraction_prompt_template_cache = None\n    _table_selection_prompt_template_cache = None\n    _text2sql_prompt_template_cache = None\n    _visualization_prompt_template_cache = None\n    _summary_prompt_template_cache = None\n"
  },
  {
    "path": "openchatbi/prompts/text2sql_prompt.md",
    "content": "You are a professional SQL engineer, your task is to transform user query into [dialect] SQL. \n- I will give you the business knowledge introduction and the glossaries of [organization] for reference.\n- I will give you the selected tables, you need to analyze the user query, read the table description, schema, constraints and examples carefully to write [dialect] SQL to answer user's question.\n- You are a read-only analytics assistant. NEVER generate DELETE, DROP, UPDATE, or INSERT statements. \n\n[basic_knowledge_glossary]\n\n# Tables\n[table_schema]\n\n# Examples\n[examples]\n\n# Rules for [dialect] SQL\n[sql_dialect_rules]\n\n# Rules for Task\n- I will provide you with data schema definition and the explanation and usage scenario of each field.\n- You can only use the tables listed in \"# Tables\". \n- You can only use the metrics, dimension, columns from the schema I provided.\n- You should only use the display name as alias in query if provided in schema.\n- Never create or assume additional tables or columns, even if they were mentioned in history message.\n- Do not use any id or date in example SQL.\n- Do not output any explanations or comment.\n- If the query asks for a metric or field not explicitly defined in the table schema, do not generate a SQL query with an invented field, instead, you should output \"NULL\".\n- You can only answer when you are very confident, otherwise, please output \"NULL\"\n\n# Output format(case sensitive)\n```sql\n<SQL>\n```\n\n# Realtime Environment \nCurrent time is [time_field_placeholder] (format 'yyyy-MM-dd HH:mm:ss')\n\nBased on the Tables, Columns, take your time to think user query carefully, transform it into [dialect] SQL and reply following Output format.\n"
  },
  {
    "path": "openchatbi/prompts/visualization_prompt.md",
    "content": "You are a data visualization expert. Analyze the user's question and data to recommend the most appropriate chart type.\n\n## User Question\n[question]\n\n## Data Schema\n- Columns: [columns]\n- Numeric columns: [numeric_columns]\n- Categorical columns: [categorical_columns]\n- DateTime columns: [datetime_columns]\n- Row count: [row_count]\n\n## Data Sample\n[data_sample]\n\n## Available Chart Types\n1. **line** - For trends over time, time series data\n2. **bar** - For comparing categories, discrete comparisons\n3. **pie** - For showing proportions, parts of a whole (best for <= 6 categories)\n4. **scatter** - For showing relationships between two numeric variables\n5. **histogram** - For showing distribution of a single numeric variable\n6. **box** - For showing statistical distribution, outliers, quartiles\n7. **heatmap** - For showing correlation or intensity across two dimensions\n8. **table** - For detailed data examination, small datasets, or when charts aren't suitable\n\n## Analysis Guidelines\nConsider:\n- The user's intent and question keywords\n- Data types and structure\n- Number of data points and categories\n- What insights the user is likely seeking\n\n## Response Format\nRespond with ONLY the chart type name (line, bar, pie, scatter, histogram, box, heatmap, or table). No explanation needed."
  },
  {
    "path": "openchatbi/text2sql/__init__.py",
    "content": "\"\"\"Text-to-SQL conversion module for OpenChatBI.\"\"\"\n"
  },
  {
    "path": "openchatbi/text2sql/data.py",
    "content": "import os\n\nfrom openchatbi import config\nfrom openchatbi.text2sql.text2sql_utils import init_sql_example_retriever, init_table_selection_example_dict\n\n# Skip init during documentation build\nif not os.environ.get(\"SPHINX_BUILD\"):\n    try:\n        _catalog_store = config.get().catalog_store\n    except ValueError:\n        _catalog_store = None\nelse:\n    _catalog_store = None\n\nif _catalog_store:\n    sql_example_retriever, sql_example_dicts = init_sql_example_retriever(_catalog_store, config.get().vector_db_path)\n    table_selection_retriever, table_selection_example_dict = init_table_selection_example_dict(\n        _catalog_store, config.get().vector_db_path\n    )\nelse:\n    sql_example_retriever, sql_example_dicts = None, {}\n    table_selection_retriever, table_selection_example_dict = None, {}\n"
  },
  {
    "path": "openchatbi/text2sql/extraction.py",
    "content": "\"\"\"Information extraction module for text2sql processing.\"\"\"\n\nimport traceback\nfrom collections.abc import Callable\nfrom datetime import date\nfrom typing import Any\n\nfrom langchain_core.language_models import BaseChatModel\nfrom langchain_core.messages import AIMessage, HumanMessage, SystemMessage\n\nfrom openchatbi.graph_state import SQLGraphState\nfrom openchatbi.llm.llm import call_llm_chat_model_with_retry\nfrom openchatbi.prompts.system_prompt import get_basic_knowledge, get_extraction_prompt_template\nfrom openchatbi.utils import extract_json_from_answer, get_text_from_content, log\n\n\ndef generate_extraction_prompt() -> str:\n    \"\"\"Generate extraction prompt.\n\n    Returns:\n        str: Generated prompt with placeholders replaced.\n    \"\"\"\n    prompt = get_extraction_prompt_template()\n\n    date_str = date.today().strftime(\"%Y-%m-%d\")\n    prompt = prompt.replace(\"[time_field_placeholder]\", date_str)\n    prompt = prompt.replace(\"[basic_knowledge_glossary]\", get_basic_knowledge())\n    return prompt\n\n\ndef parse_extracted_info_json(llm_answer_content: Any) -> dict[str, Any]:\n    \"\"\"Extract and parse JSON from LLM response.\n\n    Args:\n        llm_answer_content: LLM response containing JSON.\n\n    Returns:\n        dict: Parsed JSON or empty dict if parsing fails.\n    \"\"\"\n    try:\n        text = get_text_from_content(llm_answer_content)\n        result = extract_json_from_answer(text)\n    except Exception:\n        log(traceback.format_exc())\n        result = {}\n    return result\n\n\ndef information_extraction(llm: BaseChatModel) -> Callable:\n    \"\"\"Create function to extract information from questions.\n\n    Args:\n        llm (BaseChatModel): Language model for information extraction.\n\n    Returns:\n        function: Node function that extracts information from questions.\n    \"\"\"\n\n    def _extract(state: SQLGraphState):\n        \"\"\"Extract information from question in 
state.\n\n        Args:\n            state (SQLGraphState): Current SQL graph state with question.\n\n        Returns:\n            dict: Updated state with extracted information.\n        \"\"\"\n        messages = state[\"messages\"]\n        last_message = messages[-1]\n        user_input = last_message.content\n        log(f\"information_extraction: {user_input}\")\n        system_prompt = generate_extraction_prompt()\n        prompt = \"Please extract the information according to the context.\"\n        response = call_llm_chat_model_with_retry(\n            llm, ([SystemMessage(system_prompt)] + messages + [HumanMessage(prompt)]), [\"search_knowledge\", \"AskHuman\"]\n        )\n        if response:\n            log(response)\n            if response.tool_calls:\n                return {\"messages\": [response]}\n            else:\n                llm_answer_content = response.content\n                parsed_result = parse_extracted_info_json(llm_answer_content)\n                return {\n                    \"messages\": [response],\n                    \"rewrite_question\": parsed_result.get(\"rewrite_question\"),\n                    \"info_entities\": parsed_result,\n                }\n        else:\n            return {\"messages\": [AIMessage(role=\"system\", content=\"{}\")]}\n\n    return _extract\n\n\ndef information_extraction_conditional_edges(state: SQLGraphState):\n    \"\"\"Determine next node after information extraction.\n\n    Args:\n        state (SQLGraphState): Current SQL graph state.\n\n    Returns:\n        str: Next node ('ask_human', 'search_knowledge', 'next', or 'end').\n    \"\"\"\n    messages = state[\"messages\"]\n    last_message = messages[-1]\n    tool_calls = None\n    if isinstance(last_message, AIMessage):\n        tool_calls = last_message.tool_calls\n        log(f\"tool_calls: {tool_calls}\")\n    if tool_calls:\n        if tool_calls[0][\"name\"] == \"AskHuman\":\n            return \"ask_human\"\n        elif 
tool_calls[0][\"name\"] == \"search_knowledge\":\n            return \"search_knowledge\"\n        else:\n            print(f\"Unknown tool call: {tool_calls[0]['name']}\")\n            return \"end\"\n    else:\n        if \"rewrite_question\" in state:\n            return \"next\"\n        else:\n            return \"end\"\n"
  },
  {
    "path": "openchatbi/text2sql/generate_sql.py",
    "content": "import datetime\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport pandas as pd\nfrom langchain_core.language_models import BaseChatModel\nfrom langchain_core.messages import AIMessage, HumanMessage, SystemMessage\nfrom sqlalchemy import text\nfrom sqlalchemy.exc import DatabaseError, OperationalError, ProgrammingError, TimeoutError\n\nfrom openchatbi.catalog import CatalogStore\nfrom openchatbi.constants import (\n    SQL_EXECUTE_TIMEOUT,\n    SQL_NA,\n    SQL_SUCCESS,\n    SQL_SYNTAX_ERROR,\n    SQL_UNKNOWN_ERROR,\n    datetime_format,\n)\nfrom openchatbi.graph_state import SQLGraphState\nfrom openchatbi.prompts.system_prompt import get_text2sql_dialect_prompt_template\nfrom openchatbi.text2sql.data import sql_example_dicts, sql_example_retriever\nfrom openchatbi.text2sql.visualization import VisualizationService\nfrom openchatbi.utils import get_text_from_content, log\n\nCOLUMN_PROMPT_TEMPLATE = \"\"\"### Columns\nColumn(Name, Type, Display Name, Description):\n[\n{}\n]\n\"\"\"\n\n\ndef create_sql_nodes(\n    llm: BaseChatModel, catalog: CatalogStore, dialect: str, visualization_mode: str | None = \"rule\"\n) -> tuple[Callable, Callable, Callable, Callable]:\n    \"\"\"Creates the four SQL processing nodes for LangGraph.\n\n    Args:\n        llm (BaseChatModel): The language model to use for SQL generation.\n        catalog (CatalogStore): The catalog store containing schema information.\n        dialect (str): The SQL dialect to use (e.g., 'presto', 'mysql').\n        visualization_mode (str | None): Visualization analysis mode (\"rule\", \"llm\", or None to skip).\n\n    Returns:\n        tuple: Four node functions (generate_sql_node, execute_sql_node, regenerate_sql_node, generate_visualization_node)\n    \"\"\"\n\n    # Initialize visualization service based on configuration\n    visualization_service = VisualizationService(llm if visualization_mode == \"llm\" else None)\n\n    def _get_column_prompt(column: dict[str, 
Any]) -> str:\n        alias_prompt = f\"alias({column['alias']})\" if \"alias\" in column and column[\"alias\"] else \"\"\n        return (\n            f\"\"\"    Column(\"{column['column_name']}\", {column['type']}, {column['display_name']},\"\"\"\n            f\"\"\" \"{alias_prompt}{column['description']}\"),\"\"\"\n        )\n\n    def _get_table_schema_prompt(tables_columns: list[dict[str, Any]]) -> str:\n        \"\"\"Generates a prompt string for table schemas, including table description,\n        columns, derived metrics and rules when writting SQL\n\n        Args:\n            tables_columns (List[Dict[str, Any]]): List of tables with selected columns.\n\n        Returns:\n            str: Formatted table schema prompt string.\n        \"\"\"\n        schema_prompt = []\n        for table_dict in tables_columns:\n            table_name = table_dict[\"table\"]\n            # TODO maybe use columns in prompt\n            columns = table_dict[\"columns\"]\n            table_info = catalog.get_table_information(table_name)\n            single_table_schema_prompt = f\"## Table {table_name}\\n{table_info['description']}\\n\"\n            columns = catalog.get_column_list(table_name)\n            single_table_schema_prompt += COLUMN_PROMPT_TEMPLATE.format(\n                \"\\n\".join([_get_column_prompt(column) for column in columns])\n            )\n            single_table_schema_prompt += table_info.get(\"derived_metric\", \"\")\n            single_table_schema_prompt += table_info[\"sql_rule\"]\n            schema_prompt.append(single_table_schema_prompt)\n        return \"\\n\".join(schema_prompt)\n\n    def _get_relevant_sql_examples_prompt(question, tables_columns: list[dict[str, Any]]) -> str:\n        \"\"\"Retrieves relevant SQL examples based on the question and selected tables.\n\n        Args:\n            question (str): The natural language question.\n            tables_columns (List[str]): List of selected tables with selected columns.\n\n    
    Returns:\n            str: Formatted string of relevant SQL examples.\n        \"\"\"\n        tables = [d[\"table\"] for d in tables_columns]\n        relevant_questions = sql_example_retriever.invoke(question)\n        # log(f\"Retrieved examples for question: {question} \\n Relevant questions: {relevant_questions}\")\n        # filter examples that only use the selected tables\n        examples = []\n        for relevant_document in relevant_questions:\n            question = relevant_document.page_content\n            example_sql, used_tables = sql_example_dicts[question]\n            if all(table in tables for table in used_tables):\n                examples.append(f\"<example>\\nQ: {question}\\nA: {example_sql}\\n</example>\\n\")\n        log(f\"Examples using selected tables: {examples}\")\n        return \"\\n\".join(examples)\n\n    def _analyze_dataframe_schema(df: pd.DataFrame) -> dict[str, Any]:\n        \"\"\"Analyze DataFrame to understand column types and characteristics.\"\"\"\n        try:\n            schema_info = {\n                \"columns\": list(df.columns),\n                \"column_types\": {},\n                \"row_count\": len(df),\n                \"numeric_columns\": [],\n                \"categorical_columns\": [],\n                \"datetime_columns\": [],\n            }\n\n            for col in df.columns:\n                dtype = str(df[col].dtype)\n                schema_info[\"column_types\"][col] = dtype\n\n                # Classify column types\n                if df[col].dtype in [\"int64\", \"float64\", \"int32\", \"float32\"]:\n                    schema_info[\"numeric_columns\"].append(col)\n                elif df[col].dtype == \"object\":\n                    # Check if it could be datetime\n                    try:\n                        pd.to_datetime(df[col].head(10))\n                        schema_info[\"datetime_columns\"].append(col)\n                    except:\n                        
schema_info[\"categorical_columns\"].append(col)\n\n            # Calculate unique value counts for categorical columns\n            schema_info[\"unique_counts\"] = {}\n            for col in schema_info[\"categorical_columns\"]:\n                schema_info[\"unique_counts\"][col] = df[col].nunique()\n\n            return schema_info\n        except Exception as e:\n            return {\"error\": f\"Failed to analyze data schema: {str(e)}\"}\n\n    def _execute_sql(sql: str) -> tuple[dict, str]:\n        \"\"\"Executes the generated SQL query and returns the result with schema analysis.\n\n        Args:\n            sql (str): The SQL query to execute.\n\n        Returns:\n            Tuple[dict, str]: A tuple containing (schema_info, CSV string).\n        \"\"\"\n        with catalog.get_sql_engine().connect() as connection:\n            result = connection.execute(text(sql))\n\n            # Fetch all rows from the result\n            rows = result.fetchall()\n\n            # Get column names\n            columns = list(result.keys())\n\n            # Create DataFrame for analysis\n            df = pd.DataFrame(rows, columns=columns)\n\n            # Analyze data schema\n            schema_info = _analyze_dataframe_schema(df)\n\n            # Format as CSV\n            csv_data = df.to_csv(index=False)\n\n            connection.commit()\n            return schema_info, csv_data\n\n    def generate_sql_node(state: SQLGraphState) -> dict:\n        \"\"\"First node: Generates initial SQL query based on the state.\n\n        Args:\n            state (SQLGraphState): The current SQL graph state containing the question and tables.\n\n        Returns:\n            dict: Updated state with generated SQL query.\n        \"\"\"\n        if \"rewrite_question\" not in state:\n            log(\"Missing rewrite question, skipping SQL generation.\")\n            return {}\n        if \"tables\" not in state or len(state[\"tables\"]) == 0:\n            log(\"Missing tables, 
skipping SQL generation.\")\n            return {}\n\n        question = state[\"rewrite_question\"]\n        tables_columns = state[\"tables\"]\n        system_prompt = (\n            get_text2sql_dialect_prompt_template(dialect)\n            .replace(\"[table_schema]\", _get_table_schema_prompt(tables_columns))\n            .replace(\"[examples]\", _get_relevant_sql_examples_prompt(question, tables_columns))\n            .replace(\"[time_field_placeholder]\", datetime.datetime.now().strftime(datetime_format))\n        )\n\n        user_prompt = f\"\"\"Generate a SQL query for the question: {question}\"\"\"\n        messages = [SystemMessage(system_prompt)] + list(state[\"messages\"]) + [HumanMessage(user_prompt)]\n\n        response = llm.invoke(messages)\n        response_content = get_text_from_content(response.content)\n        sql_query = response_content.replace(\"```sql\", \"\").replace(\"```\", \"\").strip()\n\n        if not sql_query or sql_query.lower() == \"null\":\n            log(f\"Generated SQL query is empty. 
LLM output: {response.content}\")\n            return {\n                \"messages\": [AIMessage(response_content)],\n                \"sql\": sql_query,\n                \"sql_retry_count\": 0,\n                \"sql_execution_result\": \"\",\n                \"previous_sql_errors\": [],\n            }\n\n        return {\"sql\": sql_query, \"sql_retry_count\": 0, \"sql_execution_result\": \"\", \"previous_sql_errors\": []}\n\n    def execute_sql_node(state: SQLGraphState) -> dict:\n        \"\"\"Second node: Executes the SQL query and returns result or error.\n\n        Args:\n            state (SQLGraphState): The current SQL graph state containing the SQL query.\n\n        Returns:\n            dict: Updated state with execution result or error information.\n        \"\"\"\n        sql_query = state.get(\"sql\", \"\").strip()\n        if not sql_query:\n            return {\"sql_execution_result\": SQL_NA, \"messages\": [AIMessage(\"No SQL query to execute\")]}\n\n        try:\n            schema_info, csv_result = _execute_sql(sql_query)\n            result = f\"```sql\\n{sql_query}\\n```\\nSQL Result:\\n```csv\\n{csv_result}\\n```\"\n            return {\n                \"sql_execution_result\": SQL_SUCCESS,\n                \"schema_info\": schema_info,\n                \"data\": csv_result,\n                \"messages\": [AIMessage(result)],\n            }\n        except (OperationalError, TimeoutError) as e:\n            log(f\"Database connection/timeout error: {str(e)}\")\n            error_result = (\n                f\"```sql\\n{sql_query}\\n```\\nDatabase Connection Timeout: {str(e)}\\nPlease check database connectivity.\"\n            )\n            return {\"sql_execution_result\": SQL_EXECUTE_TIMEOUT, \"messages\": [AIMessage(error_result)]}\n        except Exception as e:\n            error_type = \"Unexpected error\"\n            if isinstance(e, ProgrammingError):\n                error_type = \"SQL syntax error\"\n            elif 
isinstance(e, DatabaseError):\n                error_type = \"Database error\"\n\n            log(f\"{error_type}: {str(e)}\")\n\n            # Add error to previous errors list\n            previous_errors = list(state.get(\"previous_sql_errors\", []))\n            previous_errors.append({\"sql\": sql_query, \"error\": f\"{error_type}: {str(e)}\", \"error_type\": error_type})\n\n            return {\n                \"sql_execution_result\": SQL_UNKNOWN_ERROR if error_type == \"Unexpected error\" else SQL_SYNTAX_ERROR,\n                \"previous_sql_errors\": previous_errors,\n            }\n\n    def regenerate_sql_node(state: SQLGraphState) -> dict:\n        \"\"\"Third node: Regenerates SQL based on previous errors.\n\n        Args:\n            state (SQLGraphState): The current SQL graph state containing error information.\n\n        Returns:\n            dict: Updated state with regenerated SQL query.\n        \"\"\"\n        question = state[\"rewrite_question\"]\n        tables = state[\"tables\"]\n        previous_errors = state.get(\"previous_sql_errors\", [])\n        retry_count = state.get(\"sql_retry_count\", 0) + 1\n\n        system_prompt = (\n            get_text2sql_dialect_prompt_template(dialect)\n            .replace(\"[table_schema]\", _get_table_schema_prompt(tables))\n            .replace(\"[examples]\", _get_relevant_sql_examples_prompt(question, tables))\n            .replace(\"[time_field_placeholder]\", datetime.datetime.now().strftime(datetime_format))\n        )\n\n        user_prompt = f\"\"\"Generate a SQL query for the question: {question}\"\"\"\n        if previous_errors:\n            user_prompt += \"\\n\\nPrevious attempts failed with errors:\"\n            for i, error_info in enumerate(previous_errors, 1):\n                user_prompt += f\"\\n\\nAttempt {i}:\\nSQL: {error_info['sql']}\\nError: {error_info['error']}\"\n            user_prompt += \"\\n\\nPlease analyze the errors above and generate a corrected SQL 
query.\"\n\n        messages = [SystemMessage(system_prompt)] + list(state[\"messages\"]) + [HumanMessage(user_prompt)]\n\n        response = llm.invoke(messages)\n        response_content = get_text_from_content(response.content)\n        sql_query = response_content.replace(\"```sql\", \"\").replace(\"```\", \"\").strip()\n\n        if not sql_query:\n            log(f\"Generated SQL query is empty. LLM output: {response.content}\")\n            error_result = f\"Failed to regenerate valid SQL after {retry_count} attempts.\"\n            return {\n                \"messages\": [AIMessage(error_result)],\n                \"sql\": \"\",\n                \"sql_retry_count\": retry_count,\n                \"sql_execution_result\": SQL_NA,\n            }\n\n        return {\"sql\": sql_query, \"sql_retry_count\": retry_count, \"sql_execution_result\": \"\"}\n\n    def generate_visualization_node(state: SQLGraphState) -> dict:\n        \"\"\"Fourth node: Generates visualization DSL based on successful SQL execution result.\n\n        Args:\n            state (SQLGraphState): The current SQL graph state containing query data and results.\n\n        Returns:\n            dict: Updated state with visualization DSL.\n        \"\"\"\n        execution_result = state.get(\"sql_execution_result\", \"\")\n        if execution_result != SQL_SUCCESS:\n            # No visualization for failed queries\n            return {\"visualization_dsl\": {}}\n\n        question = state.get(\"rewrite_question\", \"\")\n        schema_info = state.get(\"schema_info\", {})\n        data = state.get(\"data\", \"\")\n\n        if not question or not schema_info or not data or not visualization_mode:\n            return {\"visualization_dsl\": {}}\n\n        try:\n            # Generate visualization DSL using configured service\n            viz_dsl = visualization_service.generate_visualization(question, schema_info, data)\n\n            # Handle case where visualization is skipped\n            
if viz_dsl is None:\n                return {\"visualization_dsl\": {}}\n\n            # Update the AI message to include visualization information\n            messages = list(state.get(\"messages\", []))\n            if messages and hasattr(messages[-1], \"content\"):\n                current_content = messages[-1].content\n                viz_info = f\"\\n\\n**Visualization Generated**: {viz_dsl.chart_type.title()} chart with {len(viz_dsl.data_columns)} column(s)\"\n                messages[-1] = AIMessage(current_content + viz_info)\n\n            return {\"visualization_dsl\": viz_dsl.to_dict(), \"messages\": messages}\n        except Exception as e:\n            log(f\"Visualization generation error: {str(e)}\")\n            return {\"visualization_dsl\": {\"error\": f\"Failed to generate visualization: {str(e)}\"}}\n\n    return generate_sql_node, execute_sql_node, regenerate_sql_node, generate_visualization_node\n\n\ndef should_retry_sql(state: SQLGraphState) -> str:\n    \"\"\"Conditional edge function to determine if SQL should be retried.\n\n    Args:\n        state (SQLGraphState): Current state\n\n    Returns:\n        str: Next node name - \"regenerate_sql\" if retry needed, \"end\" if done\n    \"\"\"\n    execution_result = state.get(\"sql_execution_result\", \"\")\n    retry_count = state.get(\"sql_retry_count\", 0)\n    max_retries = 3\n\n    if execution_result in (SQL_SUCCESS, SQL_EXECUTE_TIMEOUT):\n        return \"end\"\n    elif retry_count < max_retries:\n        return \"regenerate_sql\"\n    else:\n        # Max retries reached or other terminal state\n        if retry_count >= max_retries:\n            previous_errors = state.get(\"previous_sql_errors\", [])\n            if previous_errors:\n                last_error = previous_errors[-1]\n                error_result = f\"```sql\\n{last_error['sql']}\\n```\\n{last_error['error']}\\nFailed to generate valid SQL after {max_retries} attempts.\"\n            else:\n                
error_result = f\"Failed to generate valid SQL after {max_retries} attempts.\"\n\n            # Update state with final error message\n            state[\"messages\"] = [AIMessage(error_result)]\n            state[\"sql_execution_result\"] = SQL_NA\n        return \"end\"\n\n\ndef should_execute_sql(state: SQLGraphState) -> str:\n    \"\"\"Conditional edge function to determine if SQL should be executed.\n\n    Args:\n        state (SQLGraphState): Current state\n\n    Returns:\n        str: Next node name - \"execute_sql\" if SQL is generated, \"end\" if done\n    \"\"\"\n    sql = state.get(\"sql\", \"\")\n    if not sql:\n        return \"end\"\n    else:\n        return \"execute_sql\"\n"
  },
  {
    "path": "openchatbi/text2sql/schema_linking.py",
    "content": "\"\"\"Schema linking module for table and column selection in text2sql.\"\"\"\n\nfrom datetime import datetime\n\nfrom langchain_core.language_models import BaseChatModel\nfrom langchain_core.messages import HumanMessage, SystemMessage\n\nfrom openchatbi.catalog import CatalogStore\nfrom openchatbi.catalog.schema_retrival import col_dict, column_tables_mapping, get_relevant_columns\nfrom openchatbi.constants import datetime_format\nfrom openchatbi.graph_state import SQLGraphState\nfrom openchatbi.prompts.system_prompt import get_table_selection_prompt_template\nfrom openchatbi.text2sql.data import table_selection_example_dict, table_selection_retriever\nfrom openchatbi.utils import extract_json_from_answer, log\n\n\ndef schema_linking(llm: BaseChatModel, catalog: CatalogStore):\n    \"\"\"Create function for schema linking: select appropriate tables and columns for a question.\n\n    Args:\n        llm (BaseChatModel): Language model for table selection.\n        catalog (CatalogStore): Catalog store with schema information.\n\n    Returns:\n        function: Node function for schema linking based on question.\n    \"\"\"\n\n    def _get_related_tables_and_columns(keywords_list, dimensions, metrics, start_time=None, invalid_table=None):\n        \"\"\"Retrieves tables and columns related to the given keywords, dimensions, and metrics.\n\n        Args:\n            keywords_list (list): List of keywords extracted from the question.\n            dimensions (list): List of dimensions mentioned in the question.\n            metrics (list): List of metrics mentioned in the question.\n            start_time (str, optional): Start time for filtering tables.\n            invalid_table (list, optional): List of tables to exclude.\n\n        Returns:\n            dict: Dictionary mapping table names to their information and related columns.\n        \"\"\"\n        # 1. 
Get the top similar columns\n        relevant_columns = get_relevant_columns(keywords_list, dimensions, metrics)\n\n        # 2. Get all the related tables\n        candidate_tables = set()\n        for column in relevant_columns:\n            table_list = column_tables_mapping.get(column, [])\n            candidate_tables.update(table_list)\n        if start_time:\n            try:\n                start_time = datetime.strptime(start_time, datetime_format)\n            except ValueError:\n                start_time = None\n\n        # 3. Get all the table's related column\n        related_table_column_dict = {}\n        for table_name in candidate_tables:\n            if table_name in invalid_table:\n                continue\n            table_info = catalog.get_table_information(table_name)\n            if not table_info:\n                continue\n            if start_time and \"start_time\" in table_info:\n                if datetime.strptime(table_info.get(\"start_time\"), datetime_format) > start_time:\n                    continue\n            columns = []\n            for column_name in relevant_columns:\n                column_dict = col_dict[column_name].copy()\n                if table_name not in column_tables_mapping.get(column_name, []):\n                    continue\n                columns.append(column_dict)\n            related_table_column_dict[table_name] = (table_info, columns)\n\n        return related_table_column_dict\n\n    def _example_retrieval(query, candidate_tables):\n        \"\"\"Retrieves example questions and their selected tables that match the candidate tables.\n\n        Args:\n            query (str): The natural language question.\n            candidate_tables (list): List of candidate table names.\n\n        Returns:\n            dict: Dictionary mapping example questions to their selected tables.\n        \"\"\"\n        similar_questions = table_selection_retriever.invoke(query)\n        valid_examples = {}\n        for 
question_doc in similar_questions:\n            question = question_doc.page_content\n            if not question:\n                continue\n            expected_tables = table_selection_example_dict[question]\n            expected_tables = [table for table in expected_tables if table in candidate_tables]\n            if expected_tables:\n                valid_examples[question] = expected_tables\n        return valid_examples\n\n    def _build_table_selection_prompt(related_table_column_dict, similar_examples):\n        \"\"\"Builds a prompt for table selection based on related tables and examples.\n\n        Args:\n            related_table_column_dict (dict): Dictionary of tables with their information and columns.\n            similar_examples (dict): Dictionary of example questions and their selected tables.\n\n        Returns:\n            str: Formatted prompt for table selection.\n        \"\"\"\n        similar_examples = [\n            f\"- Question: {example}   Selected Tables: [{','.join(selected_tables)}]\"\n            for example, selected_tables in similar_examples.items()\n        ]\n\n        table_column_descs = []\n        for table_name, (table_info, columns) in related_table_column_dict.items():\n            columns_desc = \"\\n\".join(\n                [\n                    f\"- {column['category']}({column['column_name']}, {column['display_name']}, \\\"{column['description']}\\\")\"\n                    for column in columns\n                ]\n            )\n            desc_part = f\"\\n### Table Description: \\n{table_info['description']}\"\n            rule_part = f\"\\n### Rule: \\n{table_info.get('selection_rule')}\" if table_info.get(\"selection_rule\") else \"\"\n            table_desc = (\n                f\"\\n## Table: {table_name} {desc_part} {rule_part}\"\n                \"\\n### Columns: \\nCategory(Name, Display Name, Description): \"\n                f\"\\n{columns_desc}\"\n                \"\"\n            )\n            
table_column_descs.append(table_desc)\n\n        # Build the LLM prompt\n        prompt = (\n            get_table_selection_prompt_template()\n            .replace(\"[tables]\", \"\\n\\n\".join(table_column_descs))\n            .replace(\"[examples]\", \"\\n\".join(similar_examples))\n        )\n        return prompt\n\n    def _verify_table(selected_tables, candidate_tables):\n        \"\"\"Verifies that selected tables are valid candidates.\n\n        Args:\n            selected_tables (list): List of tables selected by the model.\n            candidate_tables (list): List of candidate tables.\n\n        Returns:\n            bool: True if all selected tables are valid candidates.\n        \"\"\"\n        if not selected_tables:\n            return False\n        for table in selected_tables:\n            if table.get(\"table\") not in candidate_tables:\n                return False\n        return True\n\n    def _call_llm_select(llm: BaseChatModel, system_prompt, messages, question, candidate_tables):\n        \"\"\"Calls the language model to select appropriate tables for the question.\n\n        Retries up to 3 times if the LLM's answer is invalid.\n\n        Args:\n            llm (BaseChatModel): The language model to use.\n            system_prompt (str): The system prompt for table selection.\n            messages (list): List of previous messages.\n            question (str): The natural language question.\n            candidate_tables (list): List of candidate tables.\n\n        Returns:\n            dict: Dictionary containing selected tables.\n        \"\"\"\n        log(\"Selecting appropriate tables...\")\n        # print(f\"candidate_tables: {candidate_tables}\")\n        prompt = f\"\"\"Please select the appropriate tables for the question: {question}\"\"\"\n        messages.append(HumanMessage(prompt))\n        retry_flag = True\n        retry_cnt = 1\n        while retry_flag:\n            try:\n                log(\"Ask LLM to select the 
table...\")\n                # print(\"_call_llm_select\")\n                # print(messages)\n                response = llm.invoke([SystemMessage(system_prompt)] + messages)\n                result = extract_json_from_answer(response.content)\n                selected_tables = result.get(\"tables\")\n                log(result)\n                if _verify_table(selected_tables, candidate_tables):\n                    return {\"tables\": selected_tables}\n                else:\n                    messages.append(\n                        HumanMessage(\n                            f'The selected table {\",\".join([table.get(\"table\") for table in result.get(\"tables\")])} is not valid. '\n                            f\"Do not select this table, please try again.\"\n                        )\n                    )\n                retry_cnt += 1\n                if retry_cnt > 3:\n                    retry_flag = False\n                if retry_flag:\n                    log(\n                        f\"The selected table {','.join([table.get('table') for table in result.get('tables')])} is not in the candidate tables.\"\n                    )\n                    log(\"Retry Table Selection...\")\n\n            except Exception as e:\n                log(str(e))\n                retry_cnt += 1\n                if retry_cnt > 3:\n                    retry_flag = False\n        return {}\n\n    def _select(state: SQLGraphState) -> dict:\n        if not state.get(\"rewrite_question\"):\n            log(\"Missing rewrite question, skipping schema linking.\")\n            return {}\n\n        messages = state[\"messages\"]\n        question = state[\"rewrite_question\"]\n        info_entities = state[\"info_entities\"]\n        keywords_list = info_entities.get(\"keywords\", [])\n        dimensions = info_entities.get(\"dimensions\", [])\n        metrics = info_entities.get(\"metrics\", [])\n        start_time = info_entities.get(\"start_time\")\n\n        
invalid_table = []\n        log(\"Retrieving related table schema...\")\n        # 1. Get related tables and columns\n        related_table_column_dict = _get_related_tables_and_columns(\n            keywords_list, dimensions, metrics, start_time, invalid_table\n        )\n        candidate_tables = related_table_column_dict.keys()\n\n        # 2. Get the similar examples\n        similar_examples = _example_retrieval(\" \".join(keywords_list), related_table_column_dict.keys())\n\n        # 3. Build tables prompt\n        system_prompt = _build_table_selection_prompt(related_table_column_dict, similar_examples)\n\n        # 4. Call LLM to select the table\n        return _call_llm_select(llm, system_prompt, messages, question, candidate_tables)\n\n    return _select\n"
  },
  {
    "path": "openchatbi/text2sql/sql_graph.py",
    "content": "\"\"\"SQL generation graph construction and execution.\"\"\"\n\nfrom langchain_openai.chat_models.base import BaseChatOpenAI\nfrom langgraph.constants import END, START\nfrom langgraph.graph import StateGraph\nfrom langgraph.graph.state import CompiledStateGraph\nfrom langgraph.prebuilt import ToolNode\nfrom langgraph.store.base import BaseStore\nfrom langgraph.types import Checkpointer, interrupt\n\nfrom openchatbi import config\nfrom openchatbi.catalog import CatalogStore\nfrom openchatbi.constants import SQL_SUCCESS\nfrom openchatbi.graph_state import InputState, SQLGraphState, SQLOutputState\nfrom openchatbi.llm.llm import get_llm, get_text2sql_llm\nfrom openchatbi.text2sql.extraction import information_extraction, information_extraction_conditional_edges\nfrom openchatbi.text2sql.generate_sql import create_sql_nodes, should_execute_sql\nfrom openchatbi.text2sql.schema_linking import schema_linking\nfrom openchatbi.tool.ask_human import AskHuman\nfrom openchatbi.tool.search_knowledge import search_knowledge\n\n\ndef ask_human(state):\n    \"\"\"Node function to ask human for additional information or clarification.\n\n    Args:\n        state (SQLGraphState): The current SQL graph state containing messages and context.\n\n    Returns:\n        dict: Updated state with human feedback as a tool message and user input.\n    \"\"\"\n    tool_call = state[\"messages\"][-1].tool_calls[0]\n    tool_call_id = tool_call[\"id\"]\n    args = tool_call[\"args\"]\n    user_feedback = interrupt({\"text\": args[\"question\"], \"buttons\": args.get(\"options\", None)})\n    tool_message = [{\"tool_call_id\": tool_call_id, \"type\": \"tool\", \"content\": user_feedback}]\n    return {\"messages\": tool_message, \"user_input\": user_feedback}\n\n\ndef should_generate_visualization_or_retry(state: SQLGraphState) -> str:\n    \"\"\"Conditional edge function to determine next action after execute_sql.\n\n    Args:\n        state (SQLGraphState): Current state\n\n    
Returns:\n        str: Next node name - \"generate_visualization\" if SQL succeeded, \"regenerate_sql\" if retry needed, \"end\" if done\n    \"\"\"\n    execution_result = state.get(\"sql_execution_result\", \"\")\n    retry_count = state.get(\"sql_retry_count\", 0)\n    max_retries = 3\n\n    if execution_result == SQL_SUCCESS:\n        return \"generate_visualization\"\n    elif retry_count < max_retries and execution_result not in (\"SQL_EXECUTE_TIMEOUT\",):\n        return \"regenerate_sql\"\n    else:\n        return \"end\"\n\n\ndef build_sql_graph(\n    catalog: CatalogStore, checkpointer: Checkpointer, memory_store: BaseStore, llm_provider: str | None = None\n) -> CompiledStateGraph:\n    \"\"\"Build SQL generation graph with all nodes and edges.\n\n    Args:\n        catalog: Catalog store containing schema information.\n        checkpointer: The Checkpointer for state persistence (short-term memory). If None, no short-term memory.\n        memory_store: The BaseStore to use for long-term memory. 
If None, no long-term memory.\n\n    Returns:\n        CompiledStateGraph: Compiled SQL graph ready for execution.\n    \"\"\"\n    tools = [search_knowledge, AskHuman]\n    search_tool_node = ToolNode([search_knowledge])\n    default_llm = get_llm(llm_provider)\n    if isinstance(default_llm, BaseChatOpenAI):\n        llm_with_tools = default_llm.bind_tools(tools, strict=True).bind(response_format={\"type\": \"json_object\"})\n    else:\n        llm_with_tools = default_llm.bind_tools(tools)\n    # Create SQL processing nodes with visualization configuration\n    generate_sql_node, execute_sql_node, regenerate_sql_node, generate_visualization_node = create_sql_nodes(\n        get_text2sql_llm(llm_provider),\n        catalog,\n        dialect=config.get().dialect,\n        visualization_mode=config.get().visualization_mode,\n    )\n\n    # Define the SQL generation graph\n    graph = StateGraph(SQLGraphState, input_schema=InputState, output_schema=SQLOutputState)\n\n    # Add nodes to the graph\n    graph.add_node(\"search_knowledge\", search_tool_node)\n    graph.add_node(\"ask_human\", ask_human)\n    graph.add_node(\"information_extraction\", information_extraction(llm_with_tools))\n    graph.add_node(\"table_selection\", schema_linking(default_llm, catalog))\n    graph.add_node(\"generate_sql\", generate_sql_node)\n    graph.add_node(\"execute_sql\", execute_sql_node)\n    graph.add_node(\"regenerate_sql\", regenerate_sql_node)\n    graph.add_node(\"generate_visualization\", generate_visualization_node)\n\n    # Add basic edges\n    graph.add_edge(START, \"information_extraction\")\n    graph.add_edge(\"ask_human\", \"information_extraction\")\n    graph.add_edge(\"search_knowledge\", \"information_extraction\")\n    graph.add_edge(\"table_selection\", \"generate_sql\")\n\n    # Add conditional routing from information extraction\n    graph.add_conditional_edges(\n        \"information_extraction\",\n        information_extraction_conditional_edges,\n        # 
mapping of paths to node names\n        {\n            \"ask_human\": \"ask_human\",\n            \"search_knowledge\": \"search_knowledge\",\n            \"next\": \"table_selection\",\n            \"end\": END,\n        },\n    )\n\n    # Add conditional edges for generate_sql\n    graph.add_conditional_edges(\n        \"generate_sql\",\n        should_execute_sql,\n        {\n            \"execute_sql\": \"execute_sql\",\n            \"end\": END,\n        },\n    )\n\n    # Add conditional edges for regenerate_sql\n    graph.add_conditional_edges(\n        \"regenerate_sql\",\n        should_execute_sql,\n        {\n            \"execute_sql\": \"execute_sql\",\n            \"end\": END,\n        },\n    )\n\n    # Add conditional edges for execute_sql - either retry, generate visualization, or end\n    graph.add_conditional_edges(\n        \"execute_sql\",\n        should_generate_visualization_or_retry,\n        {\n            \"generate_visualization\": \"generate_visualization\",\n            \"regenerate_sql\": \"regenerate_sql\",\n            \"end\": END,\n        },\n    )\n\n    # Add edge from visualization to end\n    graph.add_edge(\"generate_visualization\", END)\n\n    graph = graph.compile(name=\"text2sql_graph\", checkpointer=checkpointer, store=memory_store)\n    return graph\n"
  },
  {
    "path": "openchatbi/text2sql/text2sql_utils.py",
    "content": "\"\"\"Utility functions for text2sql retrieval systems.\"\"\"\n\nfrom openchatbi.llm.llm import get_embedding_model\nfrom openchatbi.utils import create_vector_db\n\n\ndef init_sql_example_retriever(catalog, vector_db_path: str = None):\n    \"\"\"Initialize SQL example retriever from catalog.\n\n    Args:\n        catalog: Catalog store containing SQL examples.\n        vector_db_path: Path to the vector database file.\n\n    Returns:\n        tuple: (retriever, sql_example_dict)\n    \"\"\"\n    sql_examples = catalog.get_sql_examples()\n    sql_example_dict = {q: (sql, table) for q, sql, table in sql_examples}\n\n    texts = list(sql_example_dict.keys())\n    vector_db = create_vector_db(\n        texts,\n        get_embedding_model(),\n        collection_name=\"text2sql\",\n        collection_metadata={\"hnsw:space\": \"cosine\"},\n        chroma_db_path=vector_db_path,\n    )\n    retriever = vector_db.as_retriever(\n        search_type=\"mmr\", search_kwargs={\"distance_metric\": \"cosine\", \"fetch_k\": 30, \"k\": 10}\n    )\n    return retriever, sql_example_dict\n\n\ndef init_table_selection_example_dict(catalog, vector_db_path: str = None):\n    \"\"\"Initialize table selection example retriever from catalog.\n\n    Args:\n        catalog: Catalog store containing table selection examples.\n        vector_db_path: Path to the vector database file.\n\n    Returns:\n        tuple: (retriever, table_selection_example_dict)\n    \"\"\"\n    sql_examples = catalog.get_table_selection_examples()\n    table_selection_example_dict = dict((q, tables) for q, tables in sql_examples)\n\n    texts = list(table_selection_example_dict.keys())\n    if not texts:\n        texts = [\"\"]  # Empty text as fallback\n\n    vector_db = create_vector_db(\n        texts,\n        get_embedding_model(),\n        collection_name=\"table_selection_example\",\n        collection_metadata={\"hnsw:space\": \"cosine\"},\n        chroma_db_path=vector_db_path,\n    )\n   
 retriever = vector_db.as_retriever(\n        search_type=\"mmr\", search_kwargs={\"distance_metric\": \"cosine\", \"fetch_k\": 30, \"k\": 10}\n    )\n    return retriever, table_selection_example_dict\n"
  },
  {
    "path": "openchatbi/text2sql/visualization.py",
    "content": "\"\"\"Visualization generation for SQL query results using Plotly.\"\"\"\n\nfrom dataclasses import dataclass\nfrom enum import Enum\nfrom io import StringIO\nfrom typing import Any\n\nimport pandas as pd\nfrom langchain_core.language_models import BaseChatModel\nfrom langchain_core.messages import HumanMessage\n\nfrom openchatbi.prompts.system_prompt import get_visualization_prompt_template\n\n\nclass ChartType(Enum):\n    \"\"\"Supported chart types for data visualization.\"\"\"\n\n    LINE = \"line\"\n    BAR = \"bar\"\n    PIE = \"pie\"\n    SCATTER = \"scatter\"\n    HISTOGRAM = \"histogram\"\n    BOX = \"box\"\n    HEATMAP = \"heatmap\"\n    TABLE = \"table\"\n\n\n@dataclass\nclass VisualizationConfig:\n    \"\"\"Configuration for generating visualization DSL.\"\"\"\n\n    chart_type: ChartType\n    x_column: str | None = None\n    y_column: str | None = None\n    color_column: str | None = None\n    size_column: str | None = None\n    title: str | None = None\n    x_title: str | None = None\n    y_title: str | None = None\n    show_legend: bool = True\n    width: int | None = None\n    height: int | None = None\n\n\n@dataclass\nclass VisualizationDSL:\n    \"\"\"Plotly-friendly DSL for data visualization.\"\"\"\n\n    chart_type: str\n    data_columns: list[str]\n    config: dict[str, Any]\n    layout: dict[str, Any]\n\n    def to_dict(self) -> dict[str, Any]:\n        \"\"\"Convert to dictionary for JSON serialization.\"\"\"\n        return {\n            \"chart_type\": self.chart_type,\n            \"data_columns\": self.data_columns,\n            \"config\": self.config,\n            \"layout\": self.layout,\n        }\n\n\nclass VisualizationService:\n    \"\"\"Service class to handle visualization generation with configurable analysis method.\"\"\"\n\n    # Chart type mapping for LLM responses\n    CHART_TYPE_MAPPING = {\n        \"line\": ChartType.LINE,\n        \"bar\": ChartType.BAR,\n        \"pie\": ChartType.PIE,\n        
\"scatter\": ChartType.SCATTER,\n        \"histogram\": ChartType.HISTOGRAM,\n        \"box\": ChartType.BOX,\n        \"heatmap\": ChartType.HEATMAP,\n        \"table\": ChartType.TABLE,\n    }\n\n    def __init__(self, llm: BaseChatModel | None = None):\n        \"\"\"Initialize visualization service.\n\n        Args:\n            llm: BaseChatModel LLM instance, will skip using LLM if None\n        \"\"\"\n        self.llm = llm\n\n    def _get_chart_type_by_rule(self, question: str, schema_info: dict[str, Any]) -> ChartType:\n        \"\"\"Recommend chart type based on user question and data schema using rules.\"\"\"\n        question_lower = question.lower()\n\n        # Get data characteristics\n        numeric_cols = schema_info.get(\"numeric_columns\", [])\n        categorical_cols = schema_info.get(\"categorical_columns\", [])\n        datetime_cols = schema_info.get(\"datetime_columns\", [])\n        row_count = schema_info.get(\"row_count\", 0)\n\n        # Question-based heuristics\n        if any(keyword in question_lower for keyword in [\"trend\", \"over time\", \"timeline\", \"time series\"]):\n            return ChartType.LINE\n        elif any(keyword in question_lower for keyword in [\"distribution\", \"frequency\", \"histogram\"]):\n            return ChartType.HISTOGRAM\n        elif any(keyword in question_lower for keyword in [\"correlation\", \"relationship\", \"scatter\"]):\n            return ChartType.SCATTER\n        elif any(keyword in question_lower for keyword in [\"proportion\", \"percentage\", \"share\", \"pie\"]):\n            return ChartType.PIE\n        elif any(keyword in question_lower for keyword in [\"compare\", \"comparison\", \"vs\", \"versus\", \"bar\"]):\n            return ChartType.BAR\n        elif any(keyword in question_lower for keyword in [\"summary\", \"range\", \"quartile\", \"box\"]):\n            return ChartType.BOX\n\n        # Data-based heuristics\n        if len(datetime_cols) > 0 and len(numeric_cols) > 
0:\n            return ChartType.LINE\n        elif len(categorical_cols) == 1 and len(numeric_cols) == 1:\n            unique_count = schema_info.get(\"unique_counts\", {}).get(categorical_cols[0], 0)\n            if unique_count <= 10:\n                return ChartType.PIE if unique_count <= 6 else ChartType.BAR\n            else:\n                return ChartType.BAR\n        elif len(numeric_cols) == 2:\n            return ChartType.SCATTER\n        elif len(numeric_cols) == 1 and len(categorical_cols) == 0:\n            return ChartType.HISTOGRAM\n        elif row_count <= 20:  # Changed from 50 to 20\n            return ChartType.TABLE\n        else:\n            return ChartType.BAR\n\n    def generate_visualization_dsl(\n        self, question: str, schema_info: dict[str, Any], chart_type: ChartType | None = None\n    ) -> VisualizationDSL:\n        \"\"\"Generate visualization DSL based on question and schema info.\"\"\"\n        if \"error\" in schema_info:\n            # Return table view for error cases\n            return VisualizationDSL(\n                chart_type=\"table\",\n                data_columns=[\"error\"],\n                config={\"error\": schema_info[\"error\"]},\n                layout={\"title\": \"Data Analysis Error\"},\n            )\n\n        # Determine chart type\n        if chart_type is None:\n            chart_type = self._get_chart_type_by_rule(question, schema_info)\n\n        columns = schema_info[\"columns\"]\n        numeric_cols = schema_info[\"numeric_columns\"]\n        categorical_cols = schema_info[\"categorical_columns\"]\n        datetime_cols = schema_info[\"datetime_columns\"]\n\n        # Generate DSL based on chart type\n        if chart_type == ChartType.LINE:\n            x_col = datetime_cols[0] if datetime_cols else (categorical_cols[0] if categorical_cols else columns[0])\n            # For line charts, include all numeric columns for multiple metrics\n            y_cols = numeric_cols if numeric_cols 
else [columns[-1]]\n            data_columns = [x_col] + y_cols\n\n            # Support multiple y-axis columns\n            config = {\"x\": x_col, \"mode\": \"lines+markers\"}\n            if len(y_cols) == 1:\n                config[\"y\"] = y_cols[0]\n                title = f\"Line Chart: {y_cols[0]} over {x_col}\"\n            else:\n                config[\"y\"] = y_cols  # Multiple metrics\n                title = f\"Line Chart: {', '.join(y_cols)} over {x_col}\"\n\n            return VisualizationDSL(\n                chart_type=\"line\",\n                data_columns=data_columns,\n                config=config,\n                layout={\"title\": title, \"xaxis_title\": x_col, \"yaxis_title\": \"Value\"},\n            )\n\n        elif chart_type == ChartType.BAR:\n            x_col = categorical_cols[0] if categorical_cols else columns[0]\n            # For bar charts, include all numeric columns for multiple metrics\n            y_cols = numeric_cols if numeric_cols else [columns[-1]]\n            data_columns = [x_col] + y_cols\n\n            config = {\"x\": x_col}\n            if len(y_cols) == 1:\n                config[\"y\"] = y_cols[0]\n                title = f\"Bar Chart: {y_cols[0]} by {x_col}\"\n            else:\n                config[\"y\"] = y_cols  # Multiple metrics\n                title = f\"Bar Chart: {', '.join(y_cols)} by {x_col}\"\n\n            return VisualizationDSL(\n                chart_type=\"bar\",\n                data_columns=data_columns,\n                config=config,\n                layout={\"title\": title, \"xaxis_title\": x_col, \"yaxis_title\": \"Value\"},\n            )\n\n        elif chart_type == ChartType.PIE:\n            label_col = categorical_cols[0] if categorical_cols else columns[0]\n            value_col = numeric_cols[0] if numeric_cols else columns[-1]\n            return VisualizationDSL(\n                chart_type=\"pie\",\n                data_columns=[label_col, value_col],\n                
config={\"labels\": label_col, \"values\": value_col},\n                layout={\"title\": f\"Pie Chart: {value_col} by {label_col}\"},\n            )\n\n        elif chart_type == ChartType.SCATTER:\n            x_col = numeric_cols[0] if len(numeric_cols) > 0 else columns[0]\n            y_col = numeric_cols[1] if len(numeric_cols) > 1 else columns[-1]\n            return VisualizationDSL(\n                chart_type=\"scatter\",\n                data_columns=[x_col, y_col],\n                config={\"x\": x_col, \"y\": y_col, \"mode\": \"markers\"},\n                layout={\"title\": f\"Scatter Plot: {y_col} vs {x_col}\", \"xaxis_title\": x_col, \"yaxis_title\": y_col},\n            )\n\n        elif chart_type == ChartType.HISTOGRAM:\n            col = numeric_cols[0] if numeric_cols else columns[0]\n            return VisualizationDSL(\n                chart_type=\"histogram\",\n                data_columns=[col],\n                config={\"x\": col, \"nbins\": 20},\n                layout={\"title\": f\"Histogram: Distribution of {col}\", \"xaxis_title\": col, \"yaxis_title\": \"Frequency\"},\n            )\n\n        elif chart_type == ChartType.BOX:\n            y_col = numeric_cols[0] if numeric_cols else columns[0]\n            x_col = categorical_cols[0] if categorical_cols else None\n            config = {\"y\": y_col}\n            if x_col:\n                config[\"x\"] = x_col\n            return VisualizationDSL(\n                chart_type=\"box\",\n                data_columns=[col for col in [x_col, y_col] if col],\n                config=config,\n                layout={\n                    \"title\": f\"Box Plot: {y_col}\" + (f\" by {x_col}\" if x_col else \"\"),\n                    \"xaxis_title\": x_col if x_col else \"\",\n                    \"yaxis_title\": y_col,\n                },\n            )\n\n        else:  # TABLE or fallback\n            return VisualizationDSL(\n                chart_type=\"table\", data_columns=columns, 
config={\"columns\": columns}, layout={\"title\": \"Data Table\"}\n            )\n\n    def _llm_recommend_chart_type(self, question: str, schema_info: dict[str, Any], data_sample: str) -> ChartType:\n        \"\"\"Use LLM to recommend chart type based on question and data analysis.\n\n        Args:\n            question: User's question or intent\n            schema_info: Data schema information\n            data_sample: Sample of the data\n\n        Returns:\n            ChartType: Recommended chart type\n        \"\"\"\n        try:\n            prompt = (\n                get_visualization_prompt_template()\n                .replace(\"[question]\", question)\n                .replace(\"[columns]\", str(schema_info.get(\"columns\", [])))\n                .replace(\"[numeric_columns]\", str(schema_info.get(\"numeric_columns\", [])))\n                .replace(\"[categorical_columns]\", str(schema_info.get(\"categorical_columns\", [])))\n                .replace(\"[datetime_columns]\", str(schema_info.get(\"datetime_columns\", [])))\n                .replace(\"[row_count]\", str(schema_info.get(\"row_count\", 0)))\n                .replace(\"[data_sample]\", data_sample)\n            )\n\n            # Call LLM with the formatted prompt\n            response = self.llm.invoke([HumanMessage(content=prompt)])\n            chart_type_str = response.content.strip().lower()\n            return self.CHART_TYPE_MAPPING.get(chart_type_str, ChartType.TABLE)\n\n        except Exception:\n            # Fallback to rule-based recommendation on other LLM errors\n            return self._get_chart_type_by_rule(question, schema_info)\n\n    def generate_visualization(\n        self, question: str, schema_info: dict[str, Any], csv_data: str, chart_type: ChartType | None = None\n    ) -> VisualizationDSL | None:\n        \"\"\"Generate visualization using the configured analysis method.\n\n        Args:\n            question: User's question or intent\n            schema_info: 
Pre-analyzed schema information\n            csv_data: CSV data string for LLM analysis if needed\n            chart_type: Optional specific chart type to use\n\n        Returns:\n            VisualizationDSL or None: Generated visualization configuration, or None if skipped\n        \"\"\"\n        # Use existing DSL generation if chart type is already specified\n        if chart_type is not None:\n            return self.generate_visualization_dsl(question, schema_info, chart_type)\n\n        # Determine chart type based on configured method\n        if self.llm:\n            if \"error\" in schema_info:\n                return VisualizationDSL(\n                    chart_type=\"table\",\n                    data_columns=[\"error\"],\n                    config={\"error\": schema_info[\"error\"]},\n                    layout={\"title\": \"Data Analysis Error\"},\n                )\n\n            # Prepare data sample for LLM analysis\n            try:\n                df = pd.read_csv(StringIO(csv_data))\n                data_sample = df.head(3).to_string() if len(df) > 0 else \"No data available\"\n            except Exception:\n                data_sample = \"Unable to parse data\"\n\n            chart_type = self._llm_recommend_chart_type(question, schema_info, data_sample)\n\n        # Generate DSL using determined or recommended chart type\n        return self.generate_visualization_dsl(question, schema_info, chart_type)\n"
  },
  {
    "path": "openchatbi/text_segmenter.py",
    "content": "\"\"\"Text segmentation utility with jieba support.\"\"\"\n\nimport re\nimport string\nimport sys\n\n# Try to import jieba, fallback to None if not available\n# Note: jieba is not compatible with Python 3.12+\n_jieba_available = False\nif sys.version_info < (3, 12):\n    try:\n        import jieba\n\n        _jieba_available = True\n    except ImportError:\n        _jieba_available = False\n\n\nclass TextSegmenter:\n    \"\"\"A text segmenter that uses jieba for Chinese text and simple splitting for others.\n\n    This segmenter tries to use jieba for better Chinese word segmentation.\n    If jieba is not available or Python version is 3.12+, it falls back to simple\n    punctuation/whitespace splitting.\n\n    Note: jieba is not compatible with Python 3.12+, so simple segmentation will be\n    used on Python 3.12 and higher versions.\n    \"\"\"\n\n    def __init__(self, use_jieba: bool = True):\n        \"\"\"Initialize the text segmenter.\n\n        Args:\n            use_jieba: Whether to use jieba for Chinese text segmentation.\n                Defaults to True. 
Will automatically fall back to simple\n                segmentation if jieba is not available.\n        \"\"\"\n        self.use_jieba = use_jieba and _jieba_available\n\n        # Include both English and Chinese punctuation\n        chinese_punctuation = \"，。！？；：\" \"''（）【】《》〈〉「」『』〔〕\"\n        all_separators = string.punctuation + chinese_punctuation + \" \\t\\n\\r\"\n        # Create regex pattern to split on any separator\n        self.split_pattern = \"[\" + re.escape(all_separators) + \"]+\"\n\n    @staticmethod\n    def _contains_chinese(text: str) -> bool:\n        \"\"\"Check if text contains Chinese characters.\n\n        Args:\n            text: Input text to check\n\n        Returns:\n            True if text contains Chinese characters, False otherwise\n        \"\"\"\n        return any(\"\\u4e00\" <= char <= \"\\u9fff\" for char in text)\n\n    def _simple_cut(self, text: str) -> list[str]:\n        \"\"\"Simple segmentation by splitting on punctuation and whitespace.\n\n        Args:\n            text: Input text to be segmented\n\n        Returns:\n            List of tokens\n        \"\"\"\n        if not text:\n            return []\n\n        # Split by separators and filter empty strings\n        tokens = re.split(self.split_pattern, text)\n        return [token for token in tokens if token.strip()]\n\n    def cut(self, text: str) -> list[str]:\n        \"\"\"Segment text into tokens.\n\n        For Chinese text with jieba available, uses jieba for word segmentation.\n        Otherwise, splits by punctuation and whitespace.\n\n        Args:\n            text: Input text to be segmented\n\n        Returns:\n            List of tokens\n        \"\"\"\n        if not text:\n            return []\n\n        # Use jieba for Chinese text if available\n        if self.use_jieba and self._contains_chinese(text):\n            return list(jieba.cut(text))\n\n        # Fall back to simple segmentation\n        return self._simple_cut(text)\n\n\nclass 
SimpleSegmenter:\n    \"\"\"A simple text segmenter that splits text by punctuation and whitespace.\n\n    This is a lightweight text segmentation tool that provides basic\n    functionality without external dependencies.\n\n    Note: This class is kept for backward compatibility. Consider using\n    TextSegmenter instead for better Chinese text support.\n    \"\"\"\n\n    def __init__(self):\n        # Include both English and Chinese punctuation\n        chinese_punctuation = \"，。！？；：\" \"''（）【】《》〈〉「」『』〔〕\"\n        all_separators = string.punctuation + chinese_punctuation + \" \\t\\n\\r\"\n        # Create regex pattern to split on any separator\n        self.split_pattern = \"[\" + re.escape(all_separators) + \"]+\"\n\n    def cut(self, text: str) -> list[str]:\n        \"\"\"Segment text into tokens by splitting on punctuation and whitespace.\n\n        Args:\n            text: Input text to be segmented\n\n        Returns:\n            List of tokens\n        \"\"\"\n        if not text:\n            return []\n\n        # Split by separators and filter empty strings\n        tokens = re.split(self.split_pattern, text)\n        return [token for token in tokens if token.strip()]\n\n\n# Global instance - use TextSegmenter with jieba support\n_segmenter = TextSegmenter()\n"
  },
  {
    "path": "openchatbi/tool/ask_human.py",
    "content": "\"\"\"Tool for asking human clarification when information is ambiguous.\"\"\"\n\nfrom pydantic import BaseModel, Field\n\n\nclass AskHuman(BaseModel):\n    \"\"\"Ask user for clarification when data is missing or ambiguous.\n\n    Use this tool ONLY when you are STRONGLY certain that information is\n    ambiguous or missing. First try to solve the question with available\n    user input before calling this tool.\n    \"\"\"\n\n    question: str = Field(description=\"Question to ask the user for clarification\")\n    options: list[str] = Field(description=\"Options for user to choose (max 3). Empty if not a choice question.\")\n"
  },
  {
    "path": "openchatbi/tool/mcp_tools.py",
    "content": "\"\"\"MCP (Model Context Protocol) tools integration for OpenChatBI.\n\nThis module provides integration with MCP servers using langchain-mcp-adapters,\nallowing the agent to use external tools through the Model Context Protocol.\n\"\"\"\n\nimport asyncio\nimport logging\nfrom concurrent.futures import ThreadPoolExecutor\nfrom typing import Any\n\nfrom langchain_core.tools import StructuredTool\nfrom langchain_mcp_adapters.client import MultiServerMCPClient\nfrom pydantic import BaseModel, Field\n\nfrom openchatbi.constants import MCP_TOOL_DEFAULT_TIMEOUT_SECONDS\n\nlogger = logging.getLogger(__name__)\n\n\ndef make_tool_sync_compatible(tool: StructuredTool, timeout: int) -> StructuredTool:\n    \"\"\"Make an async-only StructuredTool compatible with sync invocation.\n\n    This wraps the async coroutine with a sync function that runs it in an event loop.\n\n    Args:\n        tool: The StructuredTool to make sync-compatible\n        timeout: Timeout in seconds for tool execution\n\n    Returns:\n        StructuredTool with sync compatibility\n    \"\"\"\n    if tool.func is not None:\n        # Tool already has sync support\n        return tool\n\n    if tool.coroutine is None:\n        # Tool has no async function either, can't help\n        return tool\n\n    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:\n        \"\"\"Synchronous wrapper for async tool function.\"\"\"\n        try:\n            loop = asyncio.get_event_loop()\n            if loop.is_running():\n                # We're in an async context, can't use run_until_complete\n                # Create a new thread with its own event loop\n                with ThreadPoolExecutor(max_workers=1) as executor:\n\n                    def run_in_new_loop() -> Any:\n                        new_loop = asyncio.new_event_loop()\n                        asyncio.set_event_loop(new_loop)\n                        try:\n                            return 
new_loop.run_until_complete(tool.coroutine(*args, **kwargs))  # type: ignore\n                        finally:\n                            new_loop.close()\n\n                    future = executor.submit(run_in_new_loop)\n                    return future.result(timeout=timeout)\n            else:\n                # No running loop, we can use run_until_complete\n                return loop.run_until_complete(tool.coroutine(*args, **kwargs))  # type: ignore\n        except RuntimeError:\n            # No event loop exists, create one\n            loop = asyncio.new_event_loop()\n            try:\n                return loop.run_until_complete(tool.coroutine(*args, **kwargs))  # type: ignore\n            finally:\n                loop.close()\n\n    # Create a new StructuredTool with both sync and async functions\n    return StructuredTool(\n        name=tool.name,\n        description=tool.description,\n        args_schema=tool.args_schema,\n        func=sync_wrapper,\n        coroutine=tool.coroutine,\n    )\n\n\nclass MCPServerConfig(BaseModel):\n    \"\"\"Configuration for MCP server connection.\"\"\"\n\n    name: str = Field(description=\"Name of the MCP server\")\n    transport: str = Field(default=\"stdio\", description=\"Transport type: stdio, sse, or streamable_http\")\n\n    # For stdio transport\n    command: list[str] = Field(default_factory=list, description=\"Command to start the MCP server\")\n    args: list[str] = Field(default_factory=list, description=\"Arguments for the MCP server\")\n    env: dict[str, str] = Field(default_factory=dict, description=\"Environment variables\")\n\n    # For HTTP transports (sse, streamable_http)\n    url: str = Field(default=\"\", description=\"URL for HTTP-based transports\")\n    headers: dict[str, str] = Field(default_factory=dict, description=\"HTTP headers\")\n\n    # Common settings\n    enabled: bool = Field(default=True, description=\"Whether this MCP server is enabled\")\n    timeout: int = 
Field(default=MCP_TOOL_DEFAULT_TIMEOUT_SECONDS, description=\"Connection timeout in seconds\")\n\n\nasync def create_mcp_tools_async(server_configs: list[dict[str, Any]]) -> list[StructuredTool]:\n    \"\"\"Create MCP tools asynchronously from server configurations.\n\n    This function processes MCP server configurations, establishes connections to enabled\n    servers, retrieves available tools, and makes them sync-compatible with proper\n    timeout configuration.\n\n    Args:\n        server_configs: List of MCP server configuration dictionaries containing\n                       server connection details, transport settings, and timeouts\n\n    Returns:\n        List of LangChain StructuredTool instances with mcp_ prefixes and sync compatibility\n    \"\"\"\n    if not server_configs:\n        return []\n\n    # Filter enabled servers and convert to MCPServerConfig\n    enabled_servers = {}\n    max_timeout = MCP_TOOL_DEFAULT_TIMEOUT_SECONDS  # Default from constants\n\n    for config_dict in server_configs:\n        try:\n            config = MCPServerConfig(**config_dict)\n            if not config.enabled:\n                continue\n\n            server_name = config.name\n\n            # Track the maximum timeout across all servers\n            max_timeout = max(max_timeout, config.timeout)\n\n            # Build server configuration for MultiServerMCPClient\n            if config.transport == \"stdio\":\n                if not config.command:\n                    logger.warning(f\"MCP server {server_name}: command required for stdio transport\")\n                    continue\n\n                enabled_servers[server_name] = {\n                    \"transport\": \"stdio\",\n                    \"command\": config.command[0] if config.command else \"\",\n                    \"args\": config.command[1:] + config.args if len(config.command) > 1 else config.args,\n                    \"env\": config.env,\n                }\n            elif config.transport in 
[\"sse\", \"streamable_http\"]:\n                if not config.url:\n                    logger.warning(f\"MCP server {server_name}: url required for {config.transport} transport\")\n                    continue\n\n                server_config: dict[str, Any] = {\n                    \"transport\": config.transport,\n                    \"url\": config.url,\n                }\n                if config.headers:\n                    server_config[\"headers\"] = config.headers\n                enabled_servers[server_name] = server_config\n            else:\n                logger.warning(f\"MCP server {server_name}: unsupported transport {config.transport}\")\n                continue\n\n        except Exception as e:\n            logger.error(f\"Invalid MCP server configuration: {e}\")\n            continue\n\n    if not enabled_servers:\n        logger.info(\"No enabled MCP servers found\")\n        return []\n\n    try:\n        # Create MultiServerMCPClient and get tools with timeout\n        client = MultiServerMCPClient(enabled_servers)\n        tools = await asyncio.wait_for(client.get_tools(), timeout=max_timeout)\n\n        logger.info(f\"Successfully loaded {len(tools)} MCP tools from {len(enabled_servers)} servers\")\n\n        # Add server prefix to tool names and make sync-compatible\n        prefixed_tools = []\n        for tool in tools:\n            # Get server name from tool metadata or guess from tool name\n            original_name = tool.name\n            if not original_name.startswith(\"mcp_\"):\n                tool.name = f\"mcp_{original_name}\"\n\n            # Make tool sync-compatible with configured timeout\n            sync_compatible_tool = make_tool_sync_compatible(tool, timeout=max_timeout)\n            prefixed_tools.append(sync_compatible_tool)\n\n        return prefixed_tools\n\n    except Exception as e:\n        logger.error(f\"Failed to initialize MCP client: {e}\")\n        return []\n\n\ndef 
create_mcp_tools_sync(server_configs: list[dict[str, Any]]) -> list[StructuredTool]:\n    \"\"\"Create MCP tools from server configurations synchronously.\n\n    This function initializes MCP tools in a separate thread with its own event loop\n    to avoid conflicts with existing async contexts.\n\n    Args:\n        server_configs: List of MCP server configuration dictionaries\n\n    Returns:\n        List of LangChain StructuredTool instances with sync compatibility\n    \"\"\"\n\n    if not server_configs:\n        return []\n\n    # For sync mode, run async initialization in a thread\n    def sync_initialize() -> list[StructuredTool]:\n        # Create new event loop for this thread\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n        try:\n            return loop.run_until_complete(create_mcp_tools_async(server_configs))\n        except Exception as e:\n            logger.error(f\"Failed to create MCP tools in sync mode: {e}\")\n            return []\n        finally:\n            loop.close()\n\n    try:\n        with ThreadPoolExecutor(max_workers=1) as executor:\n            future = executor.submit(sync_initialize)\n            return future.result(timeout=MCP_TOOL_DEFAULT_TIMEOUT_SECONDS)\n    except Exception as e:\n        logger.error(f\"MCP tools sync initialization failed: {e}\")\n        return []\n\n\n# Global variable to store async-initialized tools\n_async_mcp_tools = None\n\n\nasync def get_mcp_tools_async(server_configs: list[dict[str, Any]]) -> list[StructuredTool]:\n    \"\"\"Get MCP tools asynchronously, using cached version if available.\n\n    Args:\n        server_configs: List of MCP server configuration dictionaries\n\n    Returns:\n        List of cached or newly created LangChain StructuredTool instances\n    \"\"\"\n    global _async_mcp_tools\n\n    if _async_mcp_tools is None:\n        _async_mcp_tools = await create_mcp_tools_async(server_configs)\n\n    return _async_mcp_tools\n\n\ndef 
reset_mcp_tools_cache() -> None:\n    \"\"\"Reset the async MCP tools cache.\"\"\"\n    global _async_mcp_tools\n    _async_mcp_tools = None\n"
  },
  {
    "path": "openchatbi/tool/memory.py",
    "content": "import functools\nimport sys\nfrom typing import Any\n\ntry:\n    import pysqlite3 as sqlite3\nexcept ImportError:  # pragma: no cover\n    import sqlite3\n\n# Make sure langgraph sqlite connector uses the same sqlite module.\nsys.modules[\"sqlite3\"] = sqlite3\n\nfrom langchain.tools import StructuredTool\nfrom langchain_core.language_models import BaseChatModel\nfrom langchain_openai.chat_models.base import BaseChatOpenAI\nfrom langgraph.store.sqlite import SqliteStore\nfrom langgraph.store.sqlite.aio import AsyncSqliteStore\nfrom langmem import (\n    create_manage_memory_tool,\n    create_memory_store_manager,\n    create_search_memory_tool,\n)\n\nfrom openchatbi import config\n\ntry:\n    from pydantic import BaseModel, ConfigDict\nexcept ImportError:\n    ConfigDict = None\n\n# Use AsyncSqliteStore for async operations\nasync_memory_store = None\nasync_store_context_manager = None\nsync_memory_store = None\nmemory_manager = None\n\n\n# Define profile structure\nclass UserProfile(BaseModel):\n    \"\"\"Represents the full representation of a user.\"\"\"\n\n    name: str | None = None\n    language: str | None = None\n    timezone: str | None = None\n    jargon: str | None = None\n\n\ndef get_sync_memory_store() -> SqliteStore | None:\n    global sync_memory_store\n    embedding_model = config.get().embedding_model\n    if not embedding_model:\n        return None\n    if sync_memory_store is None:\n        # For backwards compatibility and sync operations\n        conn = sqlite3.connect(\"memory.db\", check_same_thread=False)\n        conn.isolation_level = None\n        sync_memory_store = SqliteStore(\n            conn,\n            index={\n                \"dims\": 1536,\n                \"embed\": embedding_model,\n                \"fields\": [\"text\"],  # specify which fields to embed\n            },\n        )\n        try:\n            sync_memory_store.setup()\n        except Exception:\n            pass\n    return 
sync_memory_store\n\n\nasync def get_async_memory_store() -> AsyncSqliteStore | None:\n    \"\"\"Get or create the async memory store.\"\"\"\n    global async_memory_store, async_store_context_manager\n    embedding_model = config.get().embedding_model\n    if not embedding_model:\n        return None\n    if async_memory_store is None:\n        # AsyncSqliteStore.from_conn_string returns an async context manager\n        async_store_context_manager = AsyncSqliteStore.from_conn_string(\n            \"memory.db\",\n            index={\n                \"dims\": 1536,\n                \"embed\": embedding_model,\n                \"fields\": [\"text\"],  # specify which fields to embed\n            },\n        )\n        async_memory_store = await async_store_context_manager.__aenter__()\n    return async_memory_store\n\n\nasync def cleanup_async_memory_store() -> None:\n    \"\"\"Cleanup async memory store resources.\"\"\"\n    global async_memory_store, async_store_context_manager\n    if async_memory_store is not None and async_store_context_manager is not None:\n        try:\n            await async_store_context_manager.__aexit__(None, None, None)\n        except Exception as e:\n            print(f\"Error cleaning up async memory store: {e}\")\n        finally:\n            async_memory_store = None\n            async_store_context_manager = None\n\n\nasync def setup_async_memory_store() -> Any:\n    \"\"\"Setup async memory store for langmem.\"\"\"\n    await get_async_memory_store()\n\n\ndef fix_schema_for_openai(schema: dict) -> None:\n    props = schema.get(\"properties\", {})\n    schema[\"required\"] = list(props.keys())\n\n    # Since Pydantic 2.11, it will always add `additionalProperties: True` for arbitrary dictionary schemas\n    # If it is already set to True, we need override it to False\n    # Can remove this fix when the patch release: https://github.com/langchain-ai/langchain/pull/32879\n    def fix(obj):\n        if isinstance(obj, dict):\n      
      if obj.get(\"type\") == \"object\" and \"additionalProperties\" in obj and obj[\"additionalProperties\"]:\n                obj[\"additionalProperties\"] = False\n            for v in obj.values():\n                fix(v)\n        elif isinstance(obj, list):\n            for item in obj:\n                fix(item)\n\n    fix(schema)\n\n\ndef get_memory_manager() -> Any:\n    global memory_manager\n    if memory_manager is None:\n        memory_manager = create_memory_store_manager(\n            config.get().default_llm,\n            schemas=[UserProfile],\n            instructions=\"Extract user profile information\",\n            enable_inserts=False,\n        )\n    return memory_manager\n\n\nclass StructuredToolWithRequired(StructuredTool):\n    def __init__(self, orig_tool: StructuredTool):\n        name = getattr(orig_tool, \"name\", None)\n        super().__init__(\n            name=name,\n            description=orig_tool.description,\n            args_schema=orig_tool.args_schema,\n            func=orig_tool.func,\n            coroutine=orig_tool.coroutine,\n        )\n\n    @functools.cached_property\n    def tool_call_schema(self) -> \"ArgsSchema\":\n        tcs = super().tool_call_schema\n        try:\n            if tcs.model_config:\n                tcs.model_config[\"json_schema_extra\"] = fix_schema_for_openai\n            elif ConfigDict is not None:\n                tcs.model_config = ConfigDict(json_schema_extra=fix_schema_for_openai)\n        except Exception:\n            pass\n        return tcs\n\n\ndef get_memory_tools(\n    llm: BaseChatModel, sync_mode: bool = False, store: Any | None = None\n) -> list[StructuredTool] | None:\n    # Get the appropriate store based on mode\n    if not store:\n        if sync_mode:\n            store = get_sync_memory_store()\n        else:\n            store = None\n    if not store:\n        return None\n\n    # create langmem manage memory tool with {user_id} template\n    manage_memory_tool = 
create_manage_memory_tool(namespace=(\"memories\", \"{user_id}\"), store=store)\n    search_memory_tool = create_search_memory_tool(namespace=(\"memories\", \"{user_id}\"), store=store)\n\n    if isinstance(llm, BaseChatOpenAI):\n        manage_memory_tool = StructuredToolWithRequired(manage_memory_tool)\n        search_memory_tool = StructuredToolWithRequired(search_memory_tool)\n    return [manage_memory_tool, search_memory_tool]\n\n\nasync def get_async_memory_tools(llm: BaseChatModel) -> list[StructuredTool]:\n    \"\"\"Get memory tools configured with async store.\"\"\"\n    async_store = await get_async_memory_store()\n    return get_memory_tools(llm, sync_mode=False, store=async_store)\n"
  },
  {
    "path": "openchatbi/tool/run_python_code.py",
    "content": "\"\"\"Tool for running python code.\"\"\"\n\nfrom langchain.tools import tool\nfrom pydantic import BaseModel, Field\n\nfrom openchatbi.code.docker_executor import DockerExecutor, check_docker_status\nfrom openchatbi.code.local_executor import LocalExecutor\nfrom openchatbi.code.restricted_local_executor import RestrictedLocalExecutor\nfrom openchatbi.config_loader import ConfigLoader\nfrom openchatbi.utils import log\n\n\n# Argument schema passed to the `run_python_code` tool via `args_schema`.\nclass PythonCodeInput(BaseModel):\n    reasoning: str = Field(description=\"Reason for using this run python code tool\")\n    code: str = Field(description=\"The python code to execute\")\n\n\ndef _create_executor():\n    \"\"\"Create appropriate executor based on configuration.\n\n    Returns:\n        A DockerExecutor, RestrictedLocalExecutor or LocalExecutor chosen from the\n        configured `python_executor` value. LocalExecutor is the fallback when the\n        configuration is not loaded, Docker is unavailable, or the type is unknown.\n    \"\"\"\n    config_loader = ConfigLoader()\n    try:\n        config = config_loader.get()\n        executor_type = config.python_executor.lower()\n    except ValueError:\n        # Configuration not loaded, use default local executor\n        log(\"Configuration not loaded, using default LocalExecutor\")\n        return LocalExecutor()\n\n    log(f\"Creating executor of type: {executor_type}\")\n\n    if executor_type == \"docker\":\n        # Check if Docker is available before creating DockerExecutor\n        is_available, status_message = check_docker_status()\n        if not is_available:\n            log(f\"Docker is not available ({status_message}), falling back to LocalExecutor\")\n            return LocalExecutor()\n        log(\"Docker is available, creating DockerExecutor\")\n        return DockerExecutor()\n    elif executor_type == \"restricted_local\":\n        log(\"Creating RestrictedLocalExecutor\")\n        return RestrictedLocalExecutor()\n    elif executor_type == \"local\":\n        log(\"Creating LocalExecutor\")\n        return LocalExecutor()\n    else:\n        log(f\"Unknown executor type '{executor_type}', using LocalExecutor as fallback\")\n        return LocalExecutor()\n\n\n@tool(\"run_python_code\", args_schema=PythonCodeInput, return_direct=False, infer_schema=True)\ndef run_python_code(reasoning: str, code: str) -> str:\n    \"\"\"Run python code string. Note: Only print outputs are visible, function return values will be ignored. Use print statements to see results.\n    Returns:\n        str: The print outputs of the python code\n    \"\"\"\n    log(f\"Run Python Code, Reasoning: {reasoning}\")\n\n    # Executor creation and code execution are guarded separately so that only a\n    # creation failure triggers the LocalExecutor fallback.\n    try:\n        executor = _create_executor()\n    except Exception as e:\n        log(f\"Failed to create executor: {e}\")\n        # Fallback to LocalExecutor if configuration fails\n        log(\"Falling back to LocalExecutor\")\n        executor = LocalExecutor()\n\n    log(f\"Using {executor.__class__.__name__} for code execution\")\n    try:\n        success, output = executor.run_code(code)\n    except Exception as e:\n        # Never re-run the code in a different executor: that would execute it a\n        # second time and could move it out of the configured Docker / restricted\n        # executor into the plain LocalExecutor.\n        log(f\"Code execution failed in {executor.__class__.__name__}: {e}\")\n        return f\"Error: {e}\"\n\n    return output if success else f\"Error: {output}\"\n"
  },
  {
    "path": "openchatbi/tool/save_report.py",
    "content": "\"\"\"Tool for saving reports to files.\"\"\"\n\nimport datetime\nfrom pathlib import Path\n\nfrom langchain.tools import tool\nfrom pydantic import BaseModel, Field\n\nfrom openchatbi import config\nfrom openchatbi.utils import log\n\n\n# Argument schema passed to the `save_report` tool via `args_schema`.\nclass SaveReportInput(BaseModel):\n    content: str = Field(description=\"The content of the report to save\")\n    title: str = Field(description=\"The title of the report (will be used in filename)\")\n    file_format: str = Field(\n        description=\"The file format/extension, only support 'md', 'csv', 'txt', 'json', 'html', 'xml'\"\n    )\n\n\n@tool(\"save_report\", args_schema=SaveReportInput, return_direct=False, infer_schema=True)\ndef save_report(content: str, title: str, file_format: str = \"md\") -> str:\n    \"\"\"Save a report to a file with timestamp and title in filename.\n\n    Args:\n        content: The content of the report to save\n        title: The title of the report (will be used in filename)\n        file_format: The file format/extension, only support 'md', 'csv', 'txt', 'json', 'html', 'xml'\n\n    Returns:\n        str: Success message with download link or error message\n    \"\"\"\n    # An unsupported format is rejected with an exception; every later failure is\n    # reported through the returned message instead.\n    allowed_formats = {\"md\", \"csv\", \"txt\", \"json\", \"html\", \"xml\"}\n    if file_format not in allowed_formats:\n        raise ValueError(f\"Unsupported file format: {file_format}\")\n\n    try:\n        # Get report directory from config\n        report_dir = Path(config.get().report_directory)\n\n        # Create directory if it doesn't exist\n        report_dir.mkdir(parents=True, exist_ok=True)\n\n        # Generate timestamp for filename\n        timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n\n        # Clean title for filename: keep only alphanumerics, spaces and hyphens,\n        # then turn spaces into underscores\n        clean_title = \"\".join(c for c in title if c.isalnum() or c in (\" \", \"-\")).rstrip()\n        clean_title = clean_title.replace(\" \", \"_\")\n        if not clean_title:\n            # Nothing usable was left of the title; avoid a \"<timestamp>_.<ext>\" name\n            clean_title = \"report\"\n\n        # Create filename\n        filename = f\"{timestamp}_{clean_title}.{file_format}\"\n        file_path = report_dir / filename\n\n        # Write content to file\n        with open(file_path, \"w\", encoding=\"utf-8\") as f:\n            f.write(content)\n\n        log(f\"Report saved: {file_path}\")\n\n        # Return success message with download link\n        download_url = f\"/api/download/report/{filename}\"\n        return f\"Report saved successfully! Download link: {download_url}\"\n\n    except Exception as e:\n        error_msg = f\"Failed to save report: {str(e)}\"\n        log(error_msg)\n        return error_msg\n"
  },
  {
    "path": "openchatbi/tool/search_knowledge.py",
    "content": "\"\"\"Tools for searching knowledge bases and schema information.\"\"\"\n\nfrom langchain.tools import tool\nfrom pydantic import BaseModel, Field\n\nfrom openchatbi import config\nfrom openchatbi.catalog.schema_retrival import col_dict, column_tables_mapping, get_relevant_columns\nfrom openchatbi.utils import log\n\n\nclass SearchInput(BaseModel):\n    \"\"\"Input schema for knowledge search tool.\"\"\"\n\n    reasoning: str = Field(description=\"Reason for using this search tool\")\n    query_list: list[str] = Field(description=\"Query terms to search (max 5, avoid duplicates)\")\n    knowledge_bases: list[str] = Field(\n        description=\"\"\"Knowledge bases to search, options are:\n            - `\"columns\"`: The description, alias of columns, including dimensions and metrics.\n            - `\"business\"`: The business knowledge.\"\"\"\n    )\n    with_table_list: bool = Field(\n        description=\"Include table list for columns (only set to True when user asks about table-column relationships)\"\n    )\n\n\n@tool(\"search_knowledge\", args_schema=SearchInput, return_direct=False, infer_schema=True)\ndef search_knowledge(\n    reasoning: str, query_list: list[str], knowledge_bases: list[str], with_table_list: bool = False\n) -> dict[str, str]:\n    \"\"\"Search relevant knowledge from knowledge bases.\n    Returns:\n        Dict[str, str]: Search results for each knowledge base.\n    \"\"\"\n    log(f\"Search knowledge, query_list={query_list}, knowledge_bases={knowledge_bases}, reasoning={reasoning}\")\n    final_results = {}\n    # NOTE: only the \"columns\" knowledge base is handled here; any other requested\n    # name (including \"business\") adds no entry to the result.\n    if \"columns\" in knowledge_bases:\n        column_results = search_column_from_catalog(query_list, with_table_list)\n        final_results[\"columns\"] = f\"# Relevant Columns and Description:\\n{column_results}\"\n    return final_results\n\n\nclass ShowSchemaInput(BaseModel):\n    \"\"\"Input schema for show schema tool.\"\"\"\n\n    reasoning: str = Field(description=\"Reason for showing schema\")\n    tables: list[str] = Field(description=\"Full table names to show (max 5)\")\n\n\n@tool(\"show_schema\", args_schema=ShowSchemaInput, return_direct=False, infer_schema=True)\ndef show_schema(reasoning: str, tables: list[str]) -> list[str]:\n    \"\"\"Show table schemas including description, columns, and derived metrics.\n    Returns:\n        list[str]: Schema information for each table.\n    \"\"\"\n    log(f\"Show schema, tables={tables}, reasoning={reasoning}\")\n    result = list_table_from_catalog(tables)\n    return result\n\n\ndef search_column_from_catalog(query_list: list[str], with_table_list: bool) -> str:\n    \"\"\"Search columns from catalog based on query list.\n\n    Args:\n        query_list: Queries; each one is split on single spaces into keywords.\n        with_table_list: Whether each rendered column lists its related tables.\n\n    Returns:\n        str: Rendered column descriptions joined by newlines, de-duplicated and kept\n        in the order get_relevant_columns yields them.\n    \"\"\"\n    # A dict serves as an insertion-ordered set: iterating a plain set of strings\n    # can give a different order on every run because of string hash randomization.\n    relevant_column_names: dict[str, None] = {}\n    for keywords in query_list:\n        # Three separate lists are passed, as before, in case the callee mutates them\n        relevant_columns = get_relevant_columns(keywords.split(\" \"), keywords.split(\" \"), keywords.split(\" \"))\n        relevant_column_names.update(dict.fromkeys(relevant_columns))\n    column_results = render_column_result(list(relevant_column_names), with_table_list)\n    return \"\\n\".join(column_results)\n\n\ndef list_table_from_catalog(tables: list[str]) -> list[str]:\n    \"\"\"Get table information from catalog.\n\n    Args:\n        tables: Table names to look up in the configured catalog store.\n\n    Returns:\n        list[str]: One formatted description per table found; tables without\n        table information in the catalog store are skipped.\n    \"\"\"\n    result = []\n    catalog_store = config.get().catalog_store\n\n    for table_name in tables:\n        table_info = catalog_store.get_table_information(table_name)\n        if not table_info:\n            continue\n        table_desc = f\"Table: `{table_name}` \\n# Description: {table_info['description']}\\n\"\n        columns = catalog_store.get_column_list(table_name)\n        column_names = [info[\"column_name\"] for info in columns]\n        column_results = render_column_result(column_names)\n        table_desc += \"# Columns:\\n\"\n        table_desc += \"\\n\".join(column_results)\n        if table_info.get(\"derived_metric\"):\n            # The joined column block has no trailing newline, so start the header\n            # on its own line instead of gluing it to the last column entry\n            table_desc += \"\\n## Derived metrics:\\n\"\n            table_desc += table_info[\"derived_metric\"]\n        result.append(table_desc)\n    return result\n\n\ndef render_column_result(column_list: list[str], with_table_list: bool = False) -> list[str]:\n    \"\"\"Render column information as formatted strings.\n\n    Args:\n        column_list: Column names to render; names missing from col_dict are skipped.\n        with_table_list: Append the tables mapped to each column when True.\n\n    Returns:\n        list[str]: One markdown-style description per known column.\n    \"\"\"\n    column_results = []\n    for column_name in column_list:\n        if column_name not in col_dict:\n            continue\n        table_list = column_tables_mapping.get(column_name, [])\n        column = col_dict[column_name]\n        column_desc = (\n            f\"## {column['column_name']}\"\n            f\"\\n- Column Category: {column['category']}\"\n            f\"\\n- Display Name: {column['display_name']} \"\n            f\"\\n- Description \\\"{column['description']}\\\"\"\n        )\n        if with_table_list:\n            column_desc += f\"\\n- Related Tables: {table_list}\"\n        column_results.append(column_desc)\n    return column_results\n"
  },
  {
    "path": "openchatbi/tool/timeseries_forecast.py",
    "content": "\"\"\"Tool for time series forecasting.\"\"\"\n\nimport logging\nfrom typing import Any\n\nimport requests\nfrom langchain.tools import tool\nfrom pydantic import BaseModel, Field\n\nfrom openchatbi import config\nfrom openchatbi.utils import log\n\nlogger = logging.getLogger(__name__)\n\n\nclass TimeseriesForecastInput(BaseModel):\n    \"\"\"Input schema for time series forecasting tool.\"\"\"\n\n    reasoning: str = Field(description=\"Reason for using time series forecasting and what insights you expect to gain\")\n    input_data: list[float | int | dict[str, Any]] = Field(\n        description=\"Time series data as list of numbers or structured data with timestamps and values\"\n    )\n    forecast_window: int = Field(\n        default=24, description=\"Number of future time points to predict (1-200)\", ge=1, le=200\n    )\n    frequency: str = Field(default=\"hourly\", description=\"Time series frequency: hourly, daily, weekly, monthly, etc.\")\n    input_length: int | None = Field(\n        default=None, description=\"Optional limit on input data length to use for prediction\"\n    )\n    target_column: str = Field(\n        default=\"value\", description=\"Column name to forecast for structured data (default: 'value')\"\n    )\n\n\ndef _check_service_health(service_url: str) -> bool:\n    \"\"\"Check if time series forecasting service is available.\n\n    Args:\n        service_url: Base URL of the forecasting service.\n\n    Returns:\n        bool: True only when the `/health` endpoint answers 200 with a JSON object\n        whose `model_initialized` field is truthy.\n    \"\"\"\n    try:\n        response = requests.get(f\"{service_url}/health\", timeout=5)\n        if response.status_code != 200:\n            return False\n        health_data = response.json()\n    except (requests.exceptions.RequestException, ValueError):\n        # ValueError covers a 200 response whose body is not valid JSON\n        return False\n    if not isinstance(health_data, dict):\n        # A JSON body that is not an object carries no `model_initialized` flag\n        return False\n    return bool(health_data.get(\"model_initialized\", False))\n\n\ndef check_forecast_service_health() -> bool:\n    \"\"\"Check the health of the forecasting service configured in the app config.\n\n    Returns:\n        bool: Result of the health check, or False when the configuration is not loaded.\n    \"\"\"\n    try:\n        service_url = config.get().timeseries_forecasting_service_url\n        return _check_service_health(service_url)\n    except ValueError:\n        # Configuration not loaded yet (e.g., in tests)\n        return False\n\n\ndef _call_timeseries_service(\n    service_url: str,\n    input_data: list[float | int | dict[str, Any]],\n    forecast_window: int,\n    frequency: str,\n    input_length: int | None = None,\n    target_column: str = \"value\",\n) -> dict[str, Any]:\n    \"\"\"Call time series forecasting service.\n\n    Args:\n        service_url: Base URL of the forecasting service.\n        input_data: Historical data sent as the `input` field.\n        forecast_window: Number of points to predict.\n        frequency: Frequency label forwarded to the service.\n        input_length: Sent as `input_len` only when not None.\n        target_column: Sent only when it differs from \"value\".\n\n    Returns:\n        dict[str, Any]: The decoded JSON body on HTTP 200; otherwise a dict with an\n        `error` message and `status` set to \"http_error\" (non-200) or \"error\".\n    \"\"\"\n    try:\n        # Prepare request payload\n        payload = {\"input\": input_data, \"forecast_window\": forecast_window, \"frequency\": frequency}\n\n        if input_length is not None:\n            payload[\"input_len\"] = input_length\n\n        if target_column != \"value\":\n            payload[\"target_column\"] = target_column\n\n        # Make request to time series forecasting service\n        response = requests.post(f\"{service_url}/predict\", json=payload, timeout=30)\n\n        if response.status_code == 200:\n            return response.json()\n        else:\n            return {\n                \"error\": f\"Service returned status {response.status_code}: {response.text}\",\n                \"status\": \"http_error\",\n                \"status_code\": response.status_code,\n            }\n\n    except requests.exceptions.Timeout:\n        return {\"error\": \"Request timeout - forecasting service took too long to respond\", \"status\": \"error\"}\n    except requests.exceptions.RequestException as e:\n        return {\"error\": f\"Failed to connect to forecasting service: {str(e)}\", \"status\": \"error\"}\n    except Exception as e:\n        return {\"error\": f\"Unexpected error: {str(e)}\", \"status\": \"error\"}\n\n\ndef _format_forecast_result(result: dict[str, Any], reasoning: str, input_data_length: int) -> str:\n    \"\"\"Format the forecasting result for the agent.\n\n    Args:\n        result: Dict returned by _call_timeseries_service.\n        reasoning: Accepted for interface compatibility; not used in the output.\n        input_data_length: Number of input points, echoed in the summary.\n\n    Returns:\n        str: An error explanation, a \"no predictions\" notice, or a summary with\n        basic statistics followed by one line per predicted value.\n    \"\"\"\n    if result.get(\"status\") == \"error\":\n        return f\"\"\"Time Series Forecasting Error: {result.get('error', 'Unknown error occurred')}\nPlease check:\n1. Time series forecasting service is running (docker run -p 8765:8765 timeseries-forecasting)\n2. Model load successfully\n3. Try again if timeout\"\"\"\n    elif result.get(\"status\") == \"http_error\":\n        if result.get(\"status_code\") == 400:\n            return f\"\"\"Time Series Forecasting Error: {result.get('error', 'Unknown error occurred')}\nPlease check:\n1. Input data format is correct\n2. input_len is set to larger when the input data length is not enough\n3. Forecast window is reasonable (1-200)\"\"\"\n        else:\n            return f\"\"\"Time Series Forecasting Error: {result.get('error', 'Unknown error occurred')}\"\"\"\n\n    predictions = result.get(\"predictions\", [])\n    forecast_window = result.get(\"forecast_window\", len(predictions))\n    frequency = result.get(\"frequency\", \"unknown\")\n\n    if not predictions:\n        return \"No predictions were generated. Please check your input data.\"\n\n    # Calculate basic statistics (predictions is non-empty past the guard above)\n    sum_predictions = sum(predictions)\n    avg_prediction = sum_predictions / len(predictions)\n    min_prediction = min(predictions)\n    max_prediction = max(predictions)\n\n    # Create formatted response\n    response_parts = [\n        \"✅ Time Series Forecasting Completed\",\n        \"\",\n        \"Forecast Summary:\",\n        f\"  • Input data points: {input_data_length}\",\n        f\"  • Forecast window: {forecast_window} {frequency.lower()} periods\",\n        \"\",\n        \"Predictions:\",\n        f\"  • Average forecast: {avg_prediction:.2f}\",\n        f\"  • Sum: {sum_predictions:.2f}\",\n        f\"  • Range: {min_prediction:.2f} to {max_prediction:.2f}\",\n        f\"  • Total periods forecasted: {len(predictions)}\",\n        \"\",\n        \"Detailed Forecast Values:\",\n    ]\n\n    for i, pred in enumerate(predictions):\n        period_label = f\"Period {i + 1}\"\n        response_parts.append(f\"  • {period_label}: {pred:.2f}\")\n\n    return \"\\n\".join(response_parts)\n\n\n@tool(\"timeseries_forecast\", args_schema=TimeseriesForecastInput, return_direct=False, infer_schema=True)\ndef timeseries_forecast(\n    reasoning: str,\n    input_data: list[float | int | dict[str, Any]],\n    forecast_window: int = 24,\n    frequency: str = \"hourly\",\n    input_length: int | None = None,\n    target_column: str = \"value\",\n) -> str:\n    \"\"\"Forecast future values for time series data using advanced deep learning models.\n\n    This tool uses state-of-the-art deep learning models (currently transformer based) to predict future values based on historical time series data.\n    Perfect for sales forecasting, demand planning, trend analysis, and business intelligence.\n\n    Args:\n        reasoning: Explanation of why forecasting is needed and what insights are expected\n        input_data: Historical time series data as list of numbers or structured data with timestamps\n        forecast_window: Number of future time points to predict (1-200, default: 24)\n        frequency: Time series frequency - hourly, daily, weekly, monthly, etc.\n        input_length: Optional limit on how much historical data to use for prediction\n        target_column: Column name to forecast for structured data (default: 'value')\n\n    Returns:\n        str: Formatted forecast results with predictions, statistics, and interpretation guidance\n\n    Examples:\n        - Sales forecasting: Predict next month's daily sales based on historical data\n        - Demand planning: Forecast product demand for inventory management\n        - Financial planning: Predict revenue, costs, or other financial metrics\n        - Operational planning: Forecast website traffic, resource usage, etc.\n    \"\"\"\n\n    log(f\"Time Series Forecast: {reasoning}\")\n    log(f\"Input data points: {len(input_data)}, Forecast window: {forecast_window}, Frequency: {frequency}\")\n\n    # Validate input data\n    if not input_data:\n        return \"Error: Input data cannot be empty. Please provide historical time series data.\"\n\n    if len(input_data) < 3:\n        return \"Error: Need at least 3 data points for reliable forecasting. Please provide more historical data.\"\n\n    service_unavailable = \"\"\"Time Series Forecasting Service Unavailable. The time series forecasting service is not running or not in service. \"\"\"\n\n    # Get service URL from config\n    try:\n        service_url = config.get().timeseries_forecasting_service_url\n    except ValueError:\n        # Configuration not loaded yet: answer like an unavailable service instead of\n        # letting the tool call crash (same handling as check_forecast_service_health)\n        return service_unavailable\n\n    # Check service availability\n    if not _check_service_health(service_url):\n        return service_unavailable\n\n    # Call the forecasting service\n    result = _call_timeseries_service(\n        service_url=service_url,\n        input_data=input_data,\n        forecast_window=forecast_window,\n        frequency=frequency,\n        input_length=input_length,\n        target_column=target_column,\n    )\n\n    # Format and return the result\n    return _format_forecast_result(result, reasoning, len(input_data))\n"
  },
  {
    "path": "openchatbi/utils.py",
    "content": "\"\"\"Utility functions for OpenChatBI.\"\"\"\n\nimport json\nimport sys\nimport uuid\nfrom pathlib import Path\nfrom typing import Any\n\nfrom fastapi import HTTPException\nfrom fastapi.responses import FileResponse\nfrom langchain_chroma import Chroma\nfrom langchain_core.documents import Document\nfrom langchain_core.messages import AIMessage, AIMessageChunk, RemoveMessage, ToolMessage\nfrom langchain_core.vectorstores import VectorStore\nfrom rank_bm25 import BM25Okapi\nfrom regex import regex\n\nfrom openchatbi.graph_state import AgentState\nfrom openchatbi.text_segmenter import _segmenter\n\n\ndef log(args) -> None:\n    \"\"\"Log messages to stderr for debugging.\"\"\"\n    print(args, file=sys.stderr, flush=True)\n\n\ndef get_text_from_content(content: str | list[str | dict]) -> str:\n    \"\"\"Extract text from various content formats.\n\n    Args:\n        content: String, list of strings, or list of dicts with 'text' key.\n\n    Returns:\n        str: Extracted text content.\n    \"\"\"\n    if isinstance(content, str):\n        return content\n    elif isinstance(content, list):\n        if isinstance(content[0], str):\n            return \"\".join(content)\n        elif isinstance(content[0], dict):\n            return \"\".join([item.get(\"text\", \"\") for item in content])\n    return \"\"\n\n\ndef get_text_from_message_chunk(chunk: AIMessageChunk) -> str:\n    \"\"\"Extract content from an AIMessageChunk.\n\n    Args:\n        chunk (AIMessageChunk): The message chunk to extract text from.\n\n    Returns:\n        str: Extracted text content or empty string.\n    \"\"\"\n    if not isinstance(chunk, AIMessageChunk) or not hasattr(chunk, \"content\") or not chunk.content:\n        return \"\"\n    return get_text_from_content(chunk.content)\n\n\ndef extract_json_from_answer(answer: str) -> dict:\n    \"\"\"Extract the first JSON object from a string answer.\n\n    Args:\n        answer (str): String that may contain JSON objects.\n\n 
   Returns:\n        dict: Parsed JSON object or empty dict if none found.\n    \"\"\"\n    pattern = regex.compile(r\"\\{(?:[^{}]+|(?R))*\\}\")\n    matches = pattern.findall(answer)\n    json_result = matches[0] if matches else \"{}\"\n    return json.loads(json_result)\n\n\ndef get_report_download_response(filename: str) -> FileResponse:\n    \"\"\"Get FileResponse for downloading a report file.\n\n    Args:\n        filename: The filename of the report to download\n\n    Returns:\n        FileResponse: Response object for file download\n\n    Raises:\n        HTTPException: Various HTTP errors for invalid requests\n    \"\"\"\n    try:\n        # Import config here to avoid circular imports\n        from openchatbi import config\n\n        # Get report directory from config\n        report_dir = config.get().report_directory\n        file_path = Path(report_dir) / filename\n\n        # Check if file exists and is within the report directory\n        if not file_path.exists():\n            raise HTTPException(status_code=404, detail=\"Report file not found\")\n\n        if not file_path.is_file():\n            raise HTTPException(status_code=400, detail=\"Invalid file path\")\n\n        # Ensure the file is within the report directory (security check)\n        try:\n            file_path.resolve().relative_to(Path(report_dir).resolve())\n        except ValueError:\n            raise HTTPException(status_code=403, detail=\"Access denied\") from None\n\n        # Determine media type based on file extension\n        media_type_map = {\n            \".md\": \"text/markdown\",\n            \".csv\": \"text/csv\",\n            \".txt\": \"text/plain\",\n            \".json\": \"application/json\",\n            \".html\": \"text/html\",\n            \".xml\": \"application/xml\",\n        }\n\n        file_extension = file_path.suffix.lower()\n        media_type = media_type_map.get(file_extension, \"application/octet-stream\")\n\n        return 
FileResponse(path=str(file_path), media_type=media_type, filename=filename)\n\n    except HTTPException:\n        raise\n    except Exception as e:\n        raise HTTPException(status_code=500, detail=f\"Failed to download report: {str(e)}\") from e\n\n\ndef _create_chroma_from_texts(\n    texts: list[str],\n    embedding,\n    collection_name: str,\n    metadatas,\n    collection_metadata: dict,\n    chroma_dir: str,\n):\n    \"\"\"Helper function to create Chroma client from texts.\"\"\"\n    return Chroma.from_texts(\n        texts,\n        embedding,\n        metadatas=metadatas,\n        collection_name=collection_name,\n        collection_metadata=collection_metadata,\n        persist_directory=chroma_dir,\n    )\n\n\ndef create_vector_db(\n    texts: list[str],\n    embedding=None,\n    collection_name: str = \"langchain\",\n    metadatas=None,\n    collection_metadata: dict = None,\n    chroma_db_path: str = None,\n) -> VectorStore:\n    \"\"\"Create or reuse a Chroma vector database.\n\n    Args:\n        texts (List[str]): Text documents to index.\n        embedding: Embedding function to use.\n        collection_name (str): Name of the collection.\n        metadatas: Metadata for each document.\n        collection_metadata (dict): Collection-level metadata.\n        chroma_db_path (str): Path to chroma database file.\n\n    Returns:\n        Chroma: Vector database instance.\n    \"\"\"\n    # fallback to Simple vector store using BM25 if no embedding model configured\n    if not embedding:\n        return SimpleStore(texts, metadatas)\n\n    chroma_dir = chroma_db_path or \"./.chroma_db\"\n    client = Chroma(\n        collection_name,\n        persist_directory=chroma_dir,\n        embedding_function=embedding,\n        collection_metadata=collection_metadata,\n    )\n    use_cache = False\n    existing_docs = None\n    try:\n        # Try to get documents to check if collection exists and has content\n        existing_docs = client.get()\n        if 
not existing_docs[\"documents\"]:\n            print(f\"Init new client from text for {collection_name}...\")\n        else:\n            # Check if cached texts match the input texts\n            cached_texts = existing_docs[\"documents\"]\n            # Compare texts: check count first, then content\n            if len(cached_texts) != len(texts):\n                print(\n                    f\"Texts count mismatch for {collection_name} \"\n                    f\"(cached: {len(cached_texts)}, input: {len(texts)}). Recreating collection...\"\n                )\n            else:\n                # Compare content by sorting both lists to handle order differences\n                sorted_cached = sorted(cached_texts)\n                sorted_input = sorted(texts)\n                if sorted_cached != sorted_input:\n                    print(f\"Cache content mismatch for {collection_name}. Recreating collection...\")\n                else:\n                    print(f\"Re-use collection for {collection_name}\")\n                    use_cache = True\n    except Exception:\n        # If collection doesn't exist or any error, create new one\n        print(f\"Init new client from text for {collection_name}...\")\n    if not use_cache:\n        # Clear existing collection before recreating to avoid data duplication\n        if existing_docs and existing_docs[\"documents\"]:\n            try:\n                client.reset_collection()\n                print(f\"Cleared existing collection {collection_name} before recreating...\")\n            except Exception as e:\n                # If reset fails, log and continue with recreation\n                print(f\"Warning: Failed to clear collection {collection_name}: {e}\")\n        client = _create_chroma_from_texts(\n            texts, embedding, collection_name, metadatas, collection_metadata, chroma_dir\n        )\n    return client\n\n\ndef recover_incomplete_tool_calls(state: AgentState) -> list:\n    \"\"\"Recover from 
incomplete tool calls by creating message operations to insert ToolMessages correctly.\n\n    When the graph execution is interrupted (e.g., by kill or app restart)\n    during tool execution, the state can end up with AIMessage containing\n    tool_calls but no corresponding ToolMessage responses. This function\n    detects such cases and creates the necessary message operations to insert\n    failure ToolMessages in the correct position (right after the AIMessage).\n\n    Args:\n        state (AgentState): The current graph state containing messages.\n\n    Returns:\n        list: Message operations to insert recovery ToolMessages, or empty list if no recovery needed.\n    \"\"\"\n    messages = state.get(\"messages\", [])\n    if not messages:\n        return []\n\n    # Find the last AIMessage with tool_calls\n    last_ai_message = None\n    last_ai_index = -1\n\n    for i in range(len(messages) - 1, -1, -1):\n        if isinstance(messages[i], AIMessage) and messages[i].tool_calls:\n            last_ai_message = messages[i]\n            last_ai_index = i\n            break\n\n    if not last_ai_message:\n        return []\n\n    # Check if there are any ToolMessages after this AIMessage\n    tool_call_ids = {call[\"id\"] for call in last_ai_message.tool_calls}\n    handled_tool_call_ids = set()\n\n    # Look for ToolMessages that respond to these tool calls\n    for msg in messages[last_ai_index + 1 :]:\n        if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:\n            handled_tool_call_ids.add(msg.tool_call_id)\n\n    # Find unhandled tool calls\n    unhandled_tool_call_ids = tool_call_ids - handled_tool_call_ids\n\n    if not unhandled_tool_call_ids:\n        return []  # All tool calls have responses\n\n    # Create failure ToolMessages for unhandled tool calls\n    recovery_messages = []\n    for tool_call in last_ai_message.tool_calls:\n        if tool_call[\"id\"] in unhandled_tool_call_ids:\n            failure_msg = 
ToolMessage(\n                content=f\"Tool `{tool_call['name']}` execution was interrupted due to system restart or process termination. Please retry the operation.\",\n                tool_call_id=tool_call[\"id\"],\n            )\n            recovery_messages.append(failure_msg)\n\n    # Build operations to insert recovery messages in the correct position\n    operations = []\n    messages_after_ai = messages[last_ai_index + 1 :]\n\n    # Collect IDs that will be removed\n    removed_ids = set()\n\n    # If there are messages after the AIMessage, we need to remove them first\n    if messages_after_ai:\n        for msg in messages_after_ai:\n            operations.append(RemoveMessage(id=msg.id))\n            removed_ids.add(msg.id)\n\n    # Add recovery messages (they will be inserted right after the AIMessage)\n    operations.extend(recovery_messages)\n\n    # Re-add the messages that were after the AIMessage (if any)\n    # CRITICAL: a re-added message whose ID matches a RemoveMessage must get a new ID;\n    # otherwise the RemoveMessage would be cancelled\n    if messages_after_ai:\n        for msg in messages_after_ai:\n            # Only regenerate ID if this message's ID was removed\n            if msg.id in removed_ids:\n                # Create a copy with new ID to prevent the RemoveMessage from being discarded\n                new_msg = msg.model_copy(update={\"id\": str(uuid.uuid4())})\n                operations.append(new_msg)\n            else:\n                # Keep original message as-is if ID wasn't removed\n                operations.append(msg)\n\n    log(f\"Recovered {len(recovery_messages)} incomplete tool calls\")\n    return operations\n\n\nclass SimpleStore(VectorStore):\n    \"\"\"Simple vector store using BM25 for text retrieval without embeddings.\"\"\"\n\n    def __init__(\n        self,\n        texts: list[str],\n        metadatas: list[dict] | None = None,\n        ids: list[str] | None = None,\n    ):\n        \"\"\"Initialize SimpleStore with texts.\n\n 
       Args:\n            texts: List of text documents to store.\n            metadatas: Optional list of metadata dicts for each document.\n            ids: Optional list of IDs for each document.\n        \"\"\"\n        self.texts = texts\n        self.metadatas = metadatas or [{} for _ in texts]\n        self.ids = ids or [str(uuid.uuid4()) for _ in texts]\n\n        # Create Document objects\n        self.documents = [\n            Document(id=doc_id, page_content=text, metadata=meta)\n            for doc_id, text, meta in zip(self.ids, self.texts, self.metadatas)\n        ]\n\n        # Tokenize texts and create BM25 index\n        self.tokenized_corpus = [self._tokenize(text) for text in texts]\n        # BM25Okapi doesn't support empty corpus, so set to None if empty\n        self.bm25 = BM25Okapi(self.tokenized_corpus) if texts else None\n\n    def _tokenize(self, text: str) -> list[str]:\n        \"\"\"Tokenize text for BM25 indexing using TextSegmenter.\n\n        Args:\n            text: Text to tokenize.\n\n        Returns:\n            List of tokens.\n        \"\"\"\n        return _segmenter.cut(text)\n\n    def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> list[Document]:\n        \"\"\"Search for documents similar to the query using BM25.\n\n        Args:\n            query: Query text.\n            k: Number of documents to return.\n            **kwargs: Additional arguments (unused).\n\n        Returns:\n            List of most similar Document objects.\n        \"\"\"\n        if not self.texts:\n            return []\n\n        # Tokenize query\n        tokenized_query = self._tokenize(query)\n\n        # Get BM25 scores\n        scores = self.bm25.get_scores(tokenized_query)\n\n        # Get top-k indices\n        top_k = min(k, len(scores))\n        top_k_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]\n\n        # Return corresponding documents\n        return [self.documents[i] for 
i in top_k_indices]\n\n    def similarity_search_with_score(self, query: str, k: int = 4, **kwargs: Any) -> list[tuple[Document, float]]:\n        \"\"\"Search for documents similar to the query with BM25 scores.\n\n        Args:\n            query: Query text.\n            k: Number of documents to return.\n            **kwargs: Additional arguments (unused).\n\n        Returns:\n            List of (Document, score) tuples.\n        \"\"\"\n        if not self.texts:\n            return []\n\n        # Tokenize query\n        tokenized_query = self._tokenize(query)\n\n        # Get BM25 scores\n        scores = self.bm25.get_scores(tokenized_query)\n\n        # Get top-k items\n        top_k = min(k, len(scores))\n        top_k_items = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:top_k]\n\n        # Return (Document, score) tuples\n        return [(self.documents[i], score) for i, score in top_k_items]\n\n    def _select_relevance_score_fn(self):\n        \"\"\"Return relevance score function for BM25.\n\n        BM25 scores are already relevance scores, so return identity function.\n        \"\"\"\n        return lambda score: score\n\n    def add_texts(\n        self,\n        texts: list[str],\n        metadatas: list[dict] | None = None,\n        *,\n        ids: list[str] | None = None,\n        **kwargs: Any,\n    ) -> list[str]:\n        \"\"\"Add texts to the store.\n\n        Args:\n            texts: Texts to add.\n            metadatas: Optional metadata for each text.\n            ids: Optional IDs for each text.\n            **kwargs: Additional arguments (unused).\n\n        Returns:\n            List of IDs of added texts.\n        \"\"\"\n\n        if metadatas is None:\n            metadatas = [{} for _ in texts]\n\n        if ids is None:\n            ids = [str(uuid.uuid4()) for _ in texts]\n\n        # Add to existing data\n        self.texts.extend(texts)\n        self.metadatas.extend(metadatas)\n        
self.ids.extend(ids)\n\n        # Create new Document objects\n        new_documents = [\n            Document(id=doc_id, page_content=text, metadata=meta) for doc_id, text, meta in zip(ids, texts, metadatas)\n        ]\n        self.documents.extend(new_documents)\n\n        # Update BM25 index\n        new_tokenized = [self._tokenize(text) for text in texts]\n        self.tokenized_corpus.extend(new_tokenized)\n        self.bm25 = BM25Okapi(self.tokenized_corpus)\n\n        return ids\n\n    def delete(self, ids: list[str] | None = None, **kwargs: Any) -> bool | None:\n        \"\"\"Delete documents by IDs.\n\n        Args:\n            ids: List of document IDs to delete.\n            **kwargs: Additional arguments (unused).\n\n        Returns:\n            True if deletion successful, False otherwise.\n        \"\"\"\n        if ids is None:\n            return False\n\n        # Find indices to delete\n        indices_to_delete = [i for i, doc_id in enumerate(self.ids) if doc_id in ids]\n\n        if not indices_to_delete:\n            return False\n\n        # Remove items in reverse order to maintain indices\n        for idx in sorted(indices_to_delete, reverse=True):\n            del self.texts[idx]\n            del self.metadatas[idx]\n            del self.ids[idx]\n            del self.documents[idx]\n            del self.tokenized_corpus[idx]\n\n        # Rebuild BM25 index\n        if self.tokenized_corpus:\n            self.bm25 = BM25Okapi(self.tokenized_corpus)\n        else:\n            self.bm25 = None\n\n        return True\n\n    def get_by_ids(self, ids: list[str], /) -> list[Document]:\n        \"\"\"Get documents by their IDs.\n\n        Args:\n            ids: List of document IDs to retrieve.\n\n        Returns:\n            List of Document objects.\n        \"\"\"\n        id_to_doc = {doc.id: doc for doc in self.documents}\n        return [id_to_doc[doc_id] for doc_id in ids if doc_id in id_to_doc]\n\n    @classmethod\n    def 
from_texts(\n        cls,\n        texts: list[str],\n        embedding: Any = None,  # Unused but required by interface\n        metadatas: list[dict] | None = None,\n        *,\n        ids: list[str] | None = None,\n        **kwargs: Any,\n    ) -> \"SimpleStore\":\n        \"\"\"Create SimpleStore from texts.\n\n        Args:\n            texts: List of texts.\n            embedding: Unused (SimpleStore doesn't use embeddings).\n            metadatas: Optional metadata for each text.\n            ids: Optional IDs for each text.\n            **kwargs: Additional arguments (unused).\n\n        Returns:\n            SimpleStore instance.\n        \"\"\"\n        return cls(texts, metadatas, ids)\n\n    def max_marginal_relevance_search(\n        self,\n        query: str,\n        k: int = 4,\n        fetch_k: int = 20,\n        lambda_mult: float = 0.5,\n        **kwargs: Any,\n    ) -> list[Document]:\n        \"\"\"Return docs selected using the maximal marginal relevance.\n\n        Maximal marginal relevance optimizes for similarity to query AND diversity\n        among selected documents.\n\n        Args:\n            query: Text to look up documents similar to.\n            k: Number of `Document` objects to return.\n            fetch_k: Number of `Document` objects to fetch to pass to MMR algorithm.\n            lambda_mult: Number between `0` and `1` that determines the degree\n                of diversity among the results with `0` corresponding\n                to maximum diversity and `1` to minimum diversity.\n            **kwargs: Arguments to pass to the search method.\n\n        Returns:\n            List of `Document` objects selected by maximal marginal relevance.\n        \"\"\"\n        if not self.texts:\n            return []\n\n        # Get initial candidates using BM25 similarity search\n        candidates = self.similarity_search_with_score(query, k=fetch_k, **kwargs)\n\n        if not candidates:\n            return []\n\n        if 
len(candidates) <= k:\n            return [doc for doc, _ in candidates]\n\n        # Normalize BM25 scores to [0, 1] for proper MMR calculation\n        scores = [score for _, score in candidates]\n        min_score = min(scores) if scores else 0\n        max_score = max(scores) if scores else 1\n        score_range = max_score - min_score if max_score > min_score else 1\n\n        normalized_candidates = [(doc, (score - min_score) / score_range) for doc, score in candidates]\n\n        # MMR implementation following standard algorithm\n        selected = []\n        remaining = list(range(len(normalized_candidates)))\n\n        # Select documents iteratively using MMR formula\n        while len(selected) < k and remaining:\n            best_mmr_score = float(\"-inf\")\n            best_idx = -1\n            best_remaining_idx = -1\n\n            for i, doc_idx in enumerate(remaining):\n                candidate_doc, relevance_score = normalized_candidates[doc_idx]\n\n                # Calculate maximum similarity to already selected documents\n                max_similarity = 0.0\n                if selected:\n                    max_similarity = max(\n                        self._calculate_similarity(candidate_doc, normalized_candidates[sel_idx][0])\n                        for sel_idx in selected\n                    )\n\n                # Standard MMR formula: λ * Sim(q, d) - (1-λ) * max(Sim(d, s)) for s in selected\n                mmr_score = lambda_mult * relevance_score - (1 - lambda_mult) * max_similarity\n\n                if mmr_score > best_mmr_score:\n                    best_mmr_score = mmr_score\n                    best_idx = doc_idx\n                    best_remaining_idx = i\n\n            if best_idx != -1:\n                selected.append(best_idx)\n                remaining.pop(best_remaining_idx)\n\n        return [normalized_candidates[idx][0] for idx in selected]\n\n    def _calculate_similarity(self, doc1: Document, doc2: Document) -> 
float:\n        \"\"\"Calculate similarity between two documents using Jaccard similarity.\n\n        Args:\n            doc1: First document.\n            doc2: Second document.\n\n        Returns:\n            Similarity score between 0 and 1 (higher means more similar).\n        \"\"\"\n        tokens1 = set(self._tokenize(doc1.page_content))\n        tokens2 = set(self._tokenize(doc2.page_content))\n\n        # Calculate Jaccard similarity\n        intersection = len(tokens1 & tokens2)\n        union = len(tokens1 | tokens2)\n\n        return intersection / union if union > 0 else 0.0\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"openchatbi\"\nversion = \"0.2.2\"\ndescription = \"OpenChatBI - Natural language business intelligence powered by LLMs for intuitive data analysis and SQL generation\"\nauthors = [\n    { name = \"Yu Zhong\", email = \"zhongyu8@gmail.com\" },\n]\nlicense = { text = \"MIT\" }\nreadme = \"README.md\"\nkeywords = [\n    \"business intelligence\",\n    \"bi\",\n    \"analytics\",\n    \"llm\",\n    \"gpt\",\n    \"ai\",\n    \"machine learning\",\n    \"nlp\",\n    \"text2sql\",\n    \"agent\",\n    \"query data\",\n    \"talk to data\",\n    \"analyze data\",\n    \"data agent\",\n    \"database\",\n    \"langchain\",\n    \"langgraph\",\n    \"natural language\",\n    \"conversational ai\",\n    \"timeseries\",\n    \"forecasting\",\n    \"prediction\"\n    ]\nclassifiers = [\n    \"Development Status :: 3 - Alpha\",\n    \"Intended Audience :: Developers\",\n    \"Intended Audience :: Science/Research\",\n    \"Intended Audience :: End Users/Desktop\",\n    \"Topic :: Software Development :: Libraries :: Python Modules\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    \"Topic :: Database\",\n    \"Topic :: Scientific/Engineering :: Information Analysis\",\n    \"Topic :: Office/Business\",\n    \"Programming Language :: Python :: 3\",\n    \"Programming Language :: Python :: 3.11\",\n    \"Operating System :: OS Independent\",\n    \"License :: OSI Approved :: MIT License\",\n]\nrequires-python = \">=3.11,<4.0\"\n\ndependencies = [\n    \"requests>=2.31.0,<3.0.0\",\n    \"langgraph>=0.4.7,<1.0.0\",\n    \"langchain-openai>=0.3.18,<1.0.0\",\n    \"langchain-anthropic>=0.3.13,<1.0.0\",\n    \"langchain-community>=0.3.27,<1.0.0\",\n    \"langgraph-checkpoint-sqlite>=2.0.11\",\n    \"langchain-chroma>=0.2.5\",\n    \"langchain-mcp-adapters>=0.1.9,<0.2.0\",\n    \"langmem>=0.0.29\",\n    \"sqlalchemy>=2.0.41,<3.0.0\",\n    \"sqlalchemy-trino>=0.5.0\",\n    \"aiosqlite>=0.21.0\",\n    
\"pyhive[presto]>=0.7.0\",\n    \"rank-bm25>=0.2.2,<1.0.0\",\n    \"python-levenshtein>=0.27.1\",\n    \"gradio>=5.43.1,<6.0.0\",\n    \"streamlit>=1.49.1,<2.0.0\",\n    \"RestrictedPython>=8.0,<9.0\",\n    \"docker>=7.0.0,<8.0.0\",\n    \"pandas>=2.2.0,<3.0.0\",\n    \"numpy>=2.3.0,<3.0.0\",\n    \"matplotlib>=3.10.6,<4.0.0\",\n    \"seaborn>=0.13.0,<1.0.0\",\n    \"plotly>=5.17.0,<6.0.0\",\n    \"json5>=0.10.0,<1.0.0\",\n    \"jieba>=0.42.1\", # Note: jieba is not compatible with Python 3.12+\n]\n\n[project.urls]\nHomepage = \"https://github.com/zhongyu09/openchatbi\"\nRepository = \"https://github.com/zhongyu09/openchatbi\"\nDocumentation = \"https://github.com/zhongyu09/openchatbi/tree/main\"\n\"Bug Tracker\" = \"https://github.com/zhongyu09/openchatbi/issues\"\n\n[project.optional-dependencies]\ndocs = [\n    \"sphinx>=8.2.3,<9.0.0\",\n    \"sphinx-rtd-theme>=3.0.0,<4.0.0\",\n    \"sphinx-autodoc-typehints>=2.5.0,<3.0.0\",\n    \"myst_parser\",\n    \"autodoc-pydantic\",\n]\ntest = [\n    \"pytest>=7.4.0,<9.0.0\",\n    \"pytest-mock>=3.14.0,<4.0.0\",\n    \"pytest-asyncio>=0.23.8,<1.0.0\",\n    \"pytest-sugar>=1.0.0,<2.0.0\",\n    \"pytest-cov>=6.0.0,<7.0.0\",\n    \"aioresponses>=0.7.7,<1.0.0\",\n    \"responses>=0.25.3,<1.0.0\",\n    \"langsmith[pytest]>=0.4.8,<1.0.0\",\n    \"openevals>=0.1.0,<1.0.0\",\n]\ndev = [\n    \"openchatbi[test,docs]\",\n    \"black>=24.10.0,<25.0.0\",\n    \"mypy>=1.13.0,<2.0.0\",\n    \"ruff>=0.8.0,<1.0.0\",\n    \"pre-commit>=4.0.1,<5.0.0\",\n    \"bandit>=1.8.6,<2.0.0\",\n    \"types-setuptools>=75.6.0.20241126\",\n    \"twine>=6.0.0,<7.0.0\",\n]\n\n[tool.uv]\nmanaged = true\ndev-dependencies = [\n    \"openchatbi[dev]\",\n]\n\n\n[build-system]\nrequires = [\"hatchling>=1.26.0\"]\nbuild-backend = \"hatchling.build\"\n\n[tool.hatch.build.targets.wheel]\npackages = [\"openchatbi\"]\n\n[tool.hatch.build.targets.sdist]\ninclude = [\n    \"/openchatbi\",\n    \"/tests\",\n    \"/README.md\",\n    
\"/LICENSE\",\n]\n\n[tool.hatch.metadata]\nallow-direct-references = true\n\n[tool.black]\nline-length = 120\ntarget-version = [\"py311\"]\ninclude = '\\.pyi?$'\nexclude = '''\n/(\n    \\.git\n  | \\.mypy_cache\n  | \\.tox\n  | \\.venv\n  | _build\n  | buck-out\n  | build\n  | dist\n)/\n'''\nskip-string-normalization = false\nskip-magic-trailing-comma = false\npreview = false\n\n\n[tool.ruff]\nline-length = 120\ntarget-version = \"py311\"\nexclude = [\n    \".git\",\n    \".mypy_cache\",\n    \".tox\",\n    \".venv\",\n    \"_build\",\n    \"buck-out\",\n    \"build\",\n    \"dist\",\n]\n\n[tool.ruff.lint]\nselect = [\n    \"E\",   # pycodestyle errors\n    \"W\",   # pycodestyle warnings\n    \"F\",   # pyflakes\n    \"I\",   # isort\n    \"C\",   # flake8-comprehensions\n    \"B\",   # flake8-bugbear\n    \"UP\",  # pyupgrade\n]\nignore = [\n    \"E501\",  # line too long, handled by black\n    \"B008\",  # do not perform function calls in argument defaults\n    \"C901\",  # too complex\n]\n\n[tool.ruff.lint.per-file-ignores]\n\"__init__.py\" = [\"F401\"]\n\"tests/**/*\" = [\"B011\"]\n\n[tool.mypy]\npython_version = \"3.11\"\nstrict = true\nwarn_return_any = true\nwarn_unused_configs = true\ndisallow_untyped_defs = true\ndisallow_incomplete_defs = true\ncheck_untyped_defs = true\ndisallow_untyped_decorators = true\nno_implicit_optional = true\nwarn_redundant_casts = true\nwarn_unused_ignores = true\nwarn_no_return = true\nwarn_unreachable = true\nignore_missing_imports = true\nshow_error_codes = true\n\n[tool.pytest.ini_options]\nminversion = \"7.0\"\naddopts = [\n    \"--strict-markers\",\n    \"--strict-config\",\n    \"--cov=openchatbi\",\n    \"--cov-report=term-missing\",\n    \"--cov-report=html\",\n    \"--cov-report=xml\",\n]\ntestpaths = [\"tests\"]\nmarkers = [\n    \"unit: Unit tests\",\n    \"integration: Integration tests\",\n    \"slow: Slow tests that may take several seconds\",\n    \"requires_db: Tests that require database connection\",\n    
\"requires_llm: Tests that require LLM service\",\n    \"asyncio: Asynchronous tests\"\n]\nfilterwarnings = [\n    \"error\",\n    \"ignore::UserWarning\",\n    \"ignore::DeprecationWarning\",\n]\n\n[tool.coverage.run]\nsource = [\"openchatbi\"]\nomit = [\n    \"*/tests/*\",\n    \"*/test_*.py\",\n    \"setup.py\",\n]\n\n[tool.coverage.report]\nexclude_lines = [\n    \"pragma: no cover\",\n    \"def __repr__\",\n    \"if self.debug:\",\n    \"if settings.DEBUG\",\n    \"raise AssertionError\",\n    \"raise NotImplementedError\",\n    \"if 0:\",\n    \"if __name__ == .__main__.:\",\n    \"class .*\\\\bProtocol\\\\):\",\n    \"@(abc\\\\.)?abstractmethod\",\n]\n"
  },
  {
    "path": "run_streamlit_ui.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nLaunch script for the Streamlit-based OpenChatBI interface.\n\nUsage:\n    python run_streamlit_ui.py\n\nThis will start the Streamlit server on http://localhost:8501\n\"\"\"\n\nimport os\nimport subprocess\nimport sys\n\n\ndef main():\n    \"\"\"Launch the Streamlit UI\"\"\"\n    # Change to the project directory\n    project_dir = os.path.dirname(os.path.abspath(__file__))\n    os.chdir(project_dir)\n\n    print(\"🚀 Starting OpenChatBI Streamlit UI...\")\n    print(\"📍 URL: http://localhost:8501\")\n    print(\"⏹️  Press Ctrl+C to stop the server\")\n    print(\"-\" * 50)\n\n    try:\n        # Run streamlit with the new UI file\n        subprocess.run(\n            [\n                sys.executable,\n                \"-m\",\n                \"streamlit\",\n                \"run\",\n                \"sample_ui/streamlit_ui.py\",\n                \"--server.port=8501\",\n                \"--server.address=localhost\",\n            ],\n            check=True,\n        )\n    except KeyboardInterrupt:\n        print(\"\\n👋 Stopping Streamlit server...\")\n    except subprocess.CalledProcessError as e:\n        print(f\"❌ Error starting Streamlit: {e}\")\n        print(\"\\n💡 Make sure Streamlit is installed:\")\n        print(\"   pip install streamlit\")\n    except FileNotFoundError:\n        print(\"❌ Python or Streamlit not found\")\n        print(\"\\n💡 Make sure Python and Streamlit are installed:\")\n        print(\"   pip install streamlit\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "run_tests.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Test runner script for OpenChatBI.\"\"\"\n\nimport argparse\nimport subprocess\nimport sys\n\n\ndef run_command(cmd, description):\n    \"\"\"Run a command and return the result.\"\"\"\n    print(f\"\\n{'=' * 60}\")\n    print(f\"Running: {description}\")\n    print(f\"Command: {' '.join(cmd)}\")\n    print(f\"{'=' * 60}\")\n\n    result = subprocess.run(cmd, capture_output=True, text=True)\n\n    if result.stdout:\n        print(\"STDOUT:\")\n        print(result.stdout)\n\n    if result.stderr:\n        print(\"STDERR:\")\n        print(result.stderr)\n\n    if result.returncode != 0:\n        print(f\"❌ {description} failed with return code {result.returncode}\")\n        return False\n    else:\n        print(f\"✅ {description} passed\")\n        return True\n\n\ndef main():\n    \"\"\"Main test runner function.\"\"\"\n    parser = argparse.ArgumentParser(description=\"Run OpenChatBI tests\")\n    parser.add_argument(\"--unit\", action=\"store_true\", help=\"Run only unit tests\")\n    parser.add_argument(\"--integration\", action=\"store_true\", help=\"Run only integration tests\")\n    parser.add_argument(\"--coverage\", action=\"store_true\", help=\"Run with coverage report\")\n    parser.add_argument(\"--verbose\", \"-v\", action=\"store_true\", help=\"Verbose output\")\n    parser.add_argument(\"--fast\", action=\"store_true\", help=\"Skip slow tests\")\n    parser.add_argument(\"--lint\", action=\"store_true\", help=\"Run linting checks\")\n    parser.add_argument(\"--type-check\", action=\"store_true\", help=\"Run type checking\")\n    parser.add_argument(\"--all\", action=\"store_true\", help=\"Run all checks (tests, lint, type-check)\")\n    parser.add_argument(\"--file\", help=\"Run specific test file\")\n\n    args = parser.parse_args()\n\n    # Determine test command\n    base_cmd = [\"uv\", \"run\", \"pytest\"]\n\n    if args.verbose:\n        base_cmd.append(\"-v\")\n\n    if args.coverage:\n        
base_cmd.extend([\"--cov=openchatbi\", \"--cov-report=html\", \"--cov-report=term-missing\"])\n\n    if args.unit:\n        base_cmd.extend([\"-m\", \"unit\"])\n    elif args.integration:\n        base_cmd.extend([\"-m\", \"integration\"])\n    elif args.fast:\n        base_cmd.extend([\"-m\", \"not slow\"])\n\n    if args.file:\n        base_cmd.append(f\"tests/{args.file}\")\n\n    success = True\n\n    # Run tests\n    if not args.lint and not args.type_check:\n        success &= run_command(base_cmd, \"Unit Tests\")\n\n    # Run linting if requested\n    if args.lint or args.all:\n        lint_commands = [\n            ([\"uv\", \"run\", \"black\", \"--check\", \".\"], \"Black formatting check\"),\n            ([\"uv\", \"run\", \"isort\", \"--check-only\", \".\"], \"Import sorting check\"),\n            ([\"uv\", \"run\", \"ruff\", \"check\", \".\"], \"Ruff linting\"),\n            ([\"uv\", \"run\", \"bandit\", \"-r\", \"openchatbi/\"], \"Security scanning\"),\n        ]\n\n        for cmd, desc in lint_commands:\n            success &= run_command(cmd, desc)\n\n    # Run type checking if requested\n    if args.type_check or args.all:\n        success &= run_command([\"uv\", \"run\", \"mypy\", \"openchatbi/\"], \"Type checking\")\n\n    # Run all tests if --all is specified\n    if args.all:\n        test_commands = [\n            ([\"uv\", \"run\", \"pytest\", \"-m\", \"unit\", \"-v\"], \"Unit Tests\"),\n            ([\"uv\", \"run\", \"pytest\", \"-m\", \"integration\", \"-v\"], \"Integration Tests\"),\n            ([\"uv\", \"run\", \"pytest\", \"--cov=openchatbi\", \"--cov-report=html\"], \"Coverage Report\"),\n        ]\n\n        for cmd, desc in test_commands:\n            success &= run_command(cmd, desc)\n\n    # Print summary\n    print(f\"\\n{'=' * 60}\")\n    if success:\n        print(\"🎉 All checks passed!\")\n        sys.exit(0)\n    else:\n        print(\"❌ Some checks failed!\")\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    
main()\n"
  },
  {
    "path": "sample_api/async_api.py",
    "content": "\"\"\"Async API for streaming chat responses from OpenChatBI.\"\"\"\n\nimport asyncio\nfrom typing import Any\nfrom collections import defaultdict\nfrom contextlib import asynccontextmanager\n\nfrom fastapi import FastAPI, HTTPException\nfrom fastapi.responses import StreamingResponse\nfrom langchain_core.messages import AIMessageChunk\nfrom pydantic import BaseModel\n\nfrom openchatbi import config\nfrom openchatbi.agent_graph import build_agent_graph_async\nfrom openchatbi.utils import get_report_download_response\n\n# Session state storage: session_id -> state\nsessions = defaultdict(dict)\n\n# Graphs keyed by provider name\ngraphs: dict[str, Any] = {}\ngraphs_lock = asyncio.Lock()\n\n\nasync def get_or_build_graph(provider: str | None):\n    \"\"\"Get (or lazily build) a graph for the requested provider.\"\"\"\n    key = provider or \"__default__\"\n    if key in graphs:\n        return graphs[key]\n    async with graphs_lock:\n        if key in graphs:\n            return graphs[key]\n        graphs[key] = await build_agent_graph_async(config.get().catalog_store, llm_provider=provider)\n        return graphs[key]\n\n\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n    \"\"\"Manage application lifespan events.\"\"\"\n    # Startup: Initialize the async graph\n    graphs[\"__default__\"] = await build_agent_graph_async(config.get().catalog_store)\n    yield\n    # Shutdown: cleanup if needed\n    graphs.clear()\n\n\napp = FastAPI(lifespan=lifespan)\n\n\nclass UserRequest(BaseModel):\n    \"\"\"Request model for streaming chat.\"\"\"\n\n    input: str\n    user_id: str | None = \"default\"\n    session_id: str | None = \"default\"\n    provider: str | None = None\n\n\n@app.post(\"/chat/stream\")\nasync def chat_stream(req: UserRequest):\n    \"\"\"Stream chat responses from the agent graph.\"\"\"\n    user_id = req.user_id or \"default\"\n    session_id = req.session_id or \"default\"\n    provider = req.provider\n\n    # Create 
user-session ID just like in UI\n    user_session_id = f\"{user_id}-{session_id}\"\n\n    stream_input = {\"messages\": [(\"user\", req.input)]}\n    config = {\"configurable\": {\"thread_id\": user_session_id, \"user_id\": user_id}}\n\n    try:\n        graph = await get_or_build_graph(provider)\n    except ValueError as e:\n        raise HTTPException(status_code=400, detail=str(e)) from e\n\n    async def event_generator():\n        \"\"\"Generate streaming events from the graph.\"\"\"\n        async for _namespace, event_type, event_value in graph.astream(\n            stream_input, config=config, stream_mode=[\"updates\", \"messages\"], subgraphs=True\n        ):\n            text = \"\"\n            if event_type == \"messages\":\n                message_chunk = event_value[0]\n                if isinstance(message_chunk, AIMessageChunk):\n                    text = message_chunk.content\n            elif event_value.get(\"llm_node\") and event_value[\"llm_node\"].get(\"final_answer\"):\n                text = event_value[\"llm_node\"][\"final_answer\"]\n            if text:\n                yield text\n\n    return StreamingResponse(event_generator(), media_type=\"text/plain\")\n\n\n@app.get(\"/user/{user_id}/memories\")\nasync def get_user_memories(user_id: str):\n    \"\"\"Get all memories for a specific user.\"\"\"\n    try:\n        # Import required modules for memory access\n        import json\n\n        from openchatbi.tool.memory import get_async_memory_store\n\n        # Get the async memory store\n        memory_store = await get_async_memory_store()\n\n        memories = []\n        namespace = (\"memories\", user_id)\n\n        try:\n            # Search for all memories for this user\n            search_results = memory_store.search(namespace)\n\n            for item in search_results:\n                # Parse the memory data\n                try:\n                    content = json.loads(item.value.decode(\"utf-8\")) if isinstance(item.value, 
bytes) else item.value\n                except (json.JSONDecodeError, AttributeError):\n                    content = str(item.value)\n\n                memory_data = {\n                    \"key\": item.key,\n                    \"content\": content,\n                    \"namespace\": str(namespace),\n                    \"created_at\": getattr(item, \"created_at\", \"Unknown\"),\n                    \"updated_at\": getattr(item, \"updated_at\", \"Unknown\"),\n                }\n                memories.append(memory_data)\n\n            return {\"user_id\": user_id, \"total_memories\": len(memories), \"memories\": memories}\n\n        except Exception as e:\n            raise HTTPException(status_code=500, detail=f\"Error retrieving memories: {str(e)}\") from e\n\n    except Exception as e:\n        raise HTTPException(status_code=500, detail=f\"Failed to access memory store: {str(e)}\") from e\n\n\n@app.get(\"/api/download/report/{filename}\")\nasync def download_report(filename: str):\n    \"\"\"Download a saved report file.\"\"\"\n    return get_report_download_response(filename)\n\n\nif __name__ == \"__main__\":\n    import uvicorn\n\n    uvicorn.run(app, host=\"0.0.0.0\", port=8000)\n"
  },
  {
    "path": "sample_ui/async_graph_manager.py",
    "content": "\"\"\"Common AsyncGraphManager for UIs.\"\"\"\n\nfrom typing import Any\n\nfrom langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver\n\nfrom openchatbi import config\nfrom openchatbi.agent_graph import build_agent_graph_async\nfrom openchatbi.tool.memory import cleanup_async_memory_store, get_async_memory_store, setup_async_memory_store\nfrom openchatbi.utils import log\n\n\nclass AsyncGraphManager:\n    \"\"\"Manages the async graph and checkpointer lifecycle\"\"\"\n\n    def __init__(self):\n        self.checkpointer = None\n        self.graph = None  # Default graph (backwards compatible)\n        self.graphs: dict[str, Any] = {}\n        self._context_manager = None\n        self._memory_store = None\n        self._initialized = False\n\n    async def initialize(self):\n        \"\"\"Initialize the graph and checkpointer\"\"\"\n        if self._initialized:\n            return\n\n        try:\n            # Setup async memory store\n            await setup_async_memory_store()\n\n            # Initialize checkpointer\n            self._context_manager = AsyncSqliteSaver.from_conn_string(\"checkpoints.db\")\n            self.checkpointer = await self._context_manager.__aenter__()\n\n            # Cache store for graph builds\n            self._memory_store = await get_async_memory_store()\n\n            self._initialized = True\n\n            # Build default graph for backwards compatibility\n            self.graph = await self.get_graph()\n\n            log(\"Graph initialized successfully\")\n\n        except Exception as e:\n            self._initialized = False\n            log(f\"Failed to initialize graph: {e}\")\n            raise\n\n    async def get_graph(self, llm_provider: str | None = None):\n        \"\"\"Get or build a graph for the requested LLM provider.\"\"\"\n        if not self._initialized:\n            await self.initialize()\n\n        key = llm_provider or \"__default__\"\n        if key in self.graphs:\n            
return self.graphs[key]\n\n        graph = await build_agent_graph_async(\n            config.get().catalog_store,\n            checkpointer=self.checkpointer,\n            memory_store=self._memory_store,\n            memory_tools=None,  # Let graph builder create provider-appropriate tools\n            llm_provider=llm_provider,\n        )\n        self.graphs[key] = graph\n        return graph\n\n    async def cleanup(self):\n        \"\"\"Cleanup resources\"\"\"\n        if self.checkpointer is not None and self._context_manager is not None:\n            try:\n                await self._context_manager.__aexit__(None, None, None)\n                await cleanup_async_memory_store()\n                log(\"Graph cleaned up successfully\")\n            except Exception as e:\n                log(f\"Error during cleanup: {e}\")\n            finally:\n                self.checkpointer = None\n                self.graph = None\n                self.graphs = {}\n                self._context_manager = None\n                self._memory_store = None\n                self._initialized = False\n"
  },
  {
    "path": "sample_ui/memory_ui.py",
    "content": "\"\"\"Memory listing UI for OpenChatBI using FastAPI and Gradio.\"\"\"\n\nimport json\nfrom typing import Any\n\nimport gradio as gr\nimport uvicorn\nfrom fastapi import FastAPI\n\nfrom sample_ui.style import custom_css\n\n\ndef get_thread_memory_store() -> Any:\n    \"\"\"Create a thread-safe memory store connection.\"\"\"\n    try:\n        import pysqlite3 as sqlite3\n    except ImportError:\n        import sqlite3\n    from langgraph.store.sqlite import SqliteStore\n\n    from openchatbi import config\n\n    conn = sqlite3.connect(\"memory.db\", check_same_thread=False)\n    conn.isolation_level = None  # Use autocommit mode to avoid transaction conflicts\n    store = SqliteStore(conn, index={\"dims\": 1536, \"embed\": config.get().embedding_model, \"fields\": [\"text\"]})\n    try:\n        store.setup()\n    except Exception:\n        pass  # Store might already be set up\n    return store, conn\n\n\ndef list_all_memories() -> list[dict[str, Any]]:\n    \"\"\"\n    Retrieve all memories from the memory store.\n\n    Returns:\n        List of memory items with their metadata\n    \"\"\"\n    try:\n        memory_store, conn = get_thread_memory_store()\n        memories = []\n\n        try:\n            # Use search with partial namespace to find all memory items\n            items = memory_store.search((\"memories\",), limit=1000)\n            for item in items:\n                memory_data = {\n                    \"namespace\": item.namespace,\n                    \"key\": item.key,\n                    \"value\": item.value,\n                    \"created_at\": getattr(item, \"created_at\", \"Unknown\"),\n                    \"updated_at\": getattr(item, \"updated_at\", \"Unknown\"),\n                }\n                memories.append(memory_data)\n        except Exception as e:\n            return [{\"error\": f\"Failed to retrieve memories: {str(e)}\"}]\n        finally:\n            conn.close()\n\n        return memories\n\n    except 
Exception as e:\n        return [{\"error\": f\"Failed to access memory store: {str(e)}\"}]\n\n\ndef format_memories_for_display(memories: list[dict[str, Any]]) -> str:\n    \"\"\"\n    Format memories for display in the Gradio interface.\n\n    Args:\n        memories: List of memory items\n\n    Returns:\n        Formatted string for display\n    \"\"\"\n    if not memories:\n        return \"No memories found.\"\n\n    if len(memories) == 1 and \"error\" in memories[0]:\n        return f\"Error: {memories[0]['error']}\"\n\n    formatted = []\n    for i, memory in enumerate(memories, 1):\n        if \"error\" in memory:\n            formatted.append(f\"**Error:** {memory['error']}\")\n            continue\n\n        formatted.append(f\"## Memory {i}\")\n        formatted.append(f\"**Namespace:** {memory['namespace']}\")\n        formatted.append(f\"**Key:** {memory['key']}\")\n\n        # Format the value nicely\n        value = memory[\"value\"]\n        if isinstance(value, dict):\n            try:\n                value_str = json.dumps(value, indent=2)\n                formatted.append(f\"**Content:**\\n```json\\n{value_str}\\n```\")\n            except:\n                formatted.append(f\"**Content:** {str(value)}\")\n        else:\n            formatted.append(f\"**Content:** {str(value)}\")\n\n        formatted.append(f\"**Created:** {memory['created_at']}\")\n        formatted.append(f\"**Updated:** {memory['updated_at']}\")\n        formatted.append(\"---\")\n\n    return \"\\n\".join(formatted)\n\n\ndef refresh_memories() -> list[list[str]]:\n    \"\"\"Refresh and return formatted memories.\"\"\"\n    memories = list_all_memories()\n    return format_memories_for_display(memories)\n\n\ndef delete_memory_by_key(namespace_str: str, key: str) -> str:\n    \"\"\"\n    Delete a memory by namespace and key.\n\n    Args:\n        namespace_str: String representation of namespace (e.g., \"('memories', 'user1')\")\n        key: Memory key to delete\n\n    
Returns:\n        Status message\n    \"\"\"\n    try:\n        import ast\n\n        memory_store, conn = get_thread_memory_store()\n\n        try:\n            # Parse namespace string back to tuple\n            namespace = ast.literal_eval(namespace_str)\n\n            # Delete the item\n            memory_store.delete(namespace, key)\n            return f\"Successfully deleted memory: {key} from namespace {namespace}\"\n        finally:\n            conn.close()\n    except Exception as e:\n        return f\"Failed to delete memory: {str(e)}\"\n\n\n# ---------- FastAPI ----------\napp = FastAPI()\n\n# ---------- Gradio UI ----------\nwith gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:\n    gr.Markdown(\"## 🧠 Memory Store Viewer\")\n    gr.Markdown(\"View and manage long-term memories stored in the OpenChatBI system.\")\n\n    with gr.Row():\n        with gr.Column(scale=3):\n            memories_display = gr.Markdown(value=refresh_memories(), elem_id=\"memories-display\")\n\n        with gr.Column(scale=1):\n            gr.Markdown(\"### Actions\")\n            refresh_btn = gr.Button(\"🔄 Refresh Memories\", variant=\"primary\")\n\n            gr.Markdown(\"### Delete Memory\")\n            namespace_input = gr.Textbox(\n                label=\"Namespace\",\n                placeholder=\"('memories', 'user_id')\",\n                info=\"Copy the exact namespace from the memory list\",\n            )\n            key_input = gr.Textbox(\n                label=\"Key\", placeholder=\"memory_key\", info=\"Copy the exact key from the memory list\"\n            )\n            delete_btn = gr.Button(\"🗑️ Delete Memory\", variant=\"stop\")\n            delete_status = gr.Textbox(label=\"Status\", interactive=False)\n\n    # Event handlers\n    refresh_btn.click(fn=refresh_memories, outputs=[memories_display])\n\n    delete_btn.click(fn=delete_memory_by_key, inputs=[namespace_input, key_input], outputs=[delete_status]).then(\n        fn=refresh_memories, 
outputs=[memories_display]\n    )\n\n# ---------- Application Startup ----------\n# Mount Gradio app to FastAPI\napp = gr.mount_gradio_app(app, demo, path=\"/memory\")\n\nif __name__ == \"__main__\":\n    uvicorn.run(app, host=\"0.0.0.0\", port=8001)\n"
  },
  {
    "path": "sample_ui/plotly_utils.py",
    "content": "\"\"\"Plotly utilities for generating charts from visualization DSL.\"\"\"\n\nfrom io import StringIO\nfrom typing import Any\n\nimport pandas as pd\nimport plotly.express as px\nimport plotly.graph_objects as go\n\n\ndef create_plotly_chart(data_csv: str, visualization_dsl: dict[str, Any]) -> go.Figure:\n    \"\"\"Create a plotly chart from CSV data and visualization DSL.\n\n    Args:\n        data_csv: CSV string containing the data\n        visualization_dsl: Dictionary containing chart configuration\n\n    Returns:\n        Plotly Figure object\n    \"\"\"\n    if not data_csv or not visualization_dsl:\n        return create_empty_chart(\"No data available\")\n\n    if \"error\" in visualization_dsl:\n        return create_empty_chart(f\"Visualization error: {visualization_dsl['error']}\")\n\n    try:\n        # Parse CSV data\n        df = pd.read_csv(StringIO(data_csv))\n\n        if df.empty:\n            return create_empty_chart(\"No data to visualize\")\n\n        chart_type = visualization_dsl.get(\"chart_type\", \"table\")\n        config = visualization_dsl.get(\"config\", {})\n        layout = visualization_dsl.get(\"layout\", {})\n\n        # Create chart based on type\n        if chart_type == \"line\":\n            return create_line_chart(df, config, layout)\n        elif chart_type == \"bar\":\n            return create_bar_chart(df, config, layout)\n        elif chart_type == \"pie\":\n            return create_pie_chart(df, config, layout)\n        elif chart_type == \"scatter\":\n            return create_scatter_chart(df, config, layout)\n        elif chart_type == \"histogram\":\n            return create_histogram_chart(df, config, layout)\n        elif chart_type == \"box\":\n            return create_box_chart(df, config, layout)\n        elif chart_type == \"table\":\n            return create_table_chart(df, config, layout)\n        else:\n            return create_empty_chart(f\"Unsupported chart type: 
{chart_type}\")\n\n    except Exception as e:\n        return create_empty_chart(f\"Chart generation error: {str(e)}\")\n\n\ndef create_line_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:\n    \"\"\"Create a line chart.\"\"\"\n    x_col = config.get(\"x\")\n    y_col = config.get(\"y\")\n    color_col = config.get(\"color\")\n\n    if not x_col or x_col not in df.columns:\n        return create_empty_chart(\"Missing required x column for line chart\")\n\n    # Handle multiple y columns case\n    if isinstance(y_col, list):\n        # Multiple metrics - need to melt the data\n        if not all(col in df.columns for col in y_col):\n            return create_empty_chart(\"Some y columns missing from data\")\n\n        # Melt the dataframe to long format for multiple series\n        melted_df = df.melt(id_vars=[x_col], value_vars=y_col, var_name=\"metric\", value_name=\"value\")\n        fig = px.line(melted_df, x=x_col, y=\"value\", color=\"metric\")\n\n    else:\n        # Single y column\n        if not y_col or y_col not in df.columns:\n            return create_empty_chart(\"Missing required y column for line chart\")\n\n        # Check if color column exists and is valid\n        if color_col and color_col in df.columns:\n            fig = px.line(df, x=x_col, y=y_col, color=color_col)\n        else:\n            fig = px.line(df, x=x_col, y=y_col)\n\n    fig.update_layout(**layout)\n    return fig\n\n\ndef create_bar_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:\n    \"\"\"Create a bar chart.\"\"\"\n    x_col = config.get(\"x\")\n    y_col = config.get(\"y\")\n\n    if not x_col or x_col not in df.columns:\n        return create_empty_chart(\"Missing required x column for bar chart\")\n\n    # Handle multiple y columns case\n    if isinstance(y_col, list):\n        # Multiple metrics - need to melt the data\n        if not all(col in df.columns for col in y_col):\n            return 
create_empty_chart(\"Some y columns missing from data\")\n\n        # Melt the dataframe to long format for multiple series\n        melted_df = df.melt(id_vars=[x_col], value_vars=y_col, var_name=\"metric\", value_name=\"value\")\n        fig = px.bar(melted_df, x=x_col, y=\"value\", color=\"metric\")\n\n    else:\n        # Single y column\n        if not y_col or y_col not in df.columns:\n            return create_empty_chart(\"Missing required y column for bar chart\")\n\n        fig = px.bar(df, x=x_col, y=y_col)\n\n    fig.update_layout(**layout)\n    return fig\n\n\ndef create_pie_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:\n    \"\"\"Create a pie chart.\"\"\"\n    labels_col = config.get(\"labels\")\n    values_col = config.get(\"values\")\n\n    if not labels_col or not values_col or labels_col not in df.columns or values_col not in df.columns:\n        return create_empty_chart(\"Missing required columns for pie chart\")\n\n    fig = px.pie(df, names=labels_col, values=values_col)\n    fig.update_layout(**layout)\n    return fig\n\n\ndef create_scatter_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:\n    \"\"\"Create a scatter plot.\"\"\"\n    x_col = config.get(\"x\")\n    y_col = config.get(\"y\")\n\n    if not x_col or not y_col or x_col not in df.columns or y_col not in df.columns:\n        return create_empty_chart(\"Missing required columns for scatter plot\")\n\n    fig = px.scatter(df, x=x_col, y=y_col)\n    fig.update_layout(**layout)\n    return fig\n\n\ndef create_histogram_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:\n    \"\"\"Create a histogram.\"\"\"\n    x_col = config.get(\"x\")\n    nbins = config.get(\"nbins\", 20)\n\n    if not x_col or x_col not in df.columns:\n        return create_empty_chart(\"Missing required column for histogram\")\n\n    fig = px.histogram(df, x=x_col, nbins=nbins)\n    
fig.update_layout(**layout)\n    return fig\n\n\ndef create_box_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:\n    \"\"\"Create a box plot.\"\"\"\n    y_col = config.get(\"y\")\n    x_col = config.get(\"x\")\n\n    if not y_col or y_col not in df.columns:\n        return create_empty_chart(\"Missing required column for box plot\")\n\n    if x_col and x_col in df.columns:\n        fig = px.box(df, x=x_col, y=y_col)\n    else:\n        fig = px.box(df, y=y_col)\n\n    fig.update_layout(**layout)\n    return fig\n\n\ndef create_table_chart(df: pd.DataFrame, config: dict[str, Any], layout: dict[str, Any]) -> go.Figure:\n    \"\"\"Create a table display.\"\"\"\n    columns = config.get(\"columns\", list(df.columns))\n\n    # Limit to first 100 rows for display\n    display_df = df.head(100)\n\n    fig = go.Figure(\n        data=[\n            go.Table(\n                header=dict(values=columns, fill_color=\"lightblue\", align=\"left\"),\n                cells=dict(\n                    values=[display_df[col] for col in columns if col in display_df.columns],\n                    fill_color=\"white\",\n                    align=\"left\",\n                ),\n            )\n        ]\n    )\n\n    fig.update_layout(**layout)\n    return fig\n\n\ndef create_empty_chart(message: str) -> go.Figure:\n    \"\"\"Create an empty chart with a message.\"\"\"\n    fig = go.Figure()\n    fig.add_annotation(\n        text=message,\n        xref=\"paper\",\n        yref=\"paper\",\n        x=0.5,\n        y=0.5,\n        xanchor=\"center\",\n        yanchor=\"middle\",\n        showarrow=False,\n        font=dict(size=16),\n    )\n    fig.update_layout(\n        title=\"Chart Generation Issue\",\n        xaxis=dict(showgrid=False, showticklabels=False, zeroline=False),\n        yaxis=dict(showgrid=False, showticklabels=False, zeroline=False),\n    )\n    return fig\n\n\ndef visualization_dsl_to_gradio_plot(data_csv: str, visualization_dsl: 
dict[str, Any]) -> tuple[go.Figure, str]:\n    \"\"\"Convert visualization DSL to Gradio-compatible plotly figure.\n\n    Args:\n        data_csv: CSV string containing the data\n        visualization_dsl: Dictionary containing chart configuration\n\n    Returns:\n        Tuple of (plotly figure, description string)\n    \"\"\"\n    fig = create_plotly_chart(data_csv, visualization_dsl)\n\n    if visualization_dsl:\n        chart_type = visualization_dsl.get(\"chart_type\", \"unknown\")\n        layout_title = visualization_dsl.get(\"layout\", {}).get(\"title\", f\"{chart_type.title()} Chart\")\n        description = f\"Generated {chart_type} visualization: {layout_title}\"\n    else:\n        description = \"Data table view\"\n\n    return fig, description\n\n\ndef create_inline_chart_markdown(data_csv: str, visualization_dsl: dict[str, Any]) -> str:\n    \"\"\"Create a simplified markdown representation of the chart for inline display.\n\n    This creates a text-based summary with a clickable link to show the interactive chart.\n    \"\"\"\n    if not data_csv or not visualization_dsl:\n        return \"📊 *No visualization data available*\"\n\n    if \"error\" in visualization_dsl:\n        return f\"⚠️ *Visualization error: {visualization_dsl['error']}*\"\n\n    try:\n        from io import StringIO\n\n        import pandas as pd\n\n        df = pd.read_csv(StringIO(data_csv))\n        chart_type = visualization_dsl.get(\"chart_type\", \"table\")\n        layout = visualization_dsl.get(\"layout\", {})\n        title = layout.get(\"title\", f\"{chart_type.title()} Chart\")\n\n        # Create a text summary with key data points and view instruction\n        summary_lines = [\n            f\"📊 **{title}**\",\n            \"\",\n            f\"*Chart Type: {chart_type.title()}* | *Data Points: {len(df)} rows, {len(df.columns)} columns*\",\n            \"\",\n            \"✨ **Interactive chart will appear automatically in the chart panel →**\",\n            \"\",\n 
       ]\n\n        # Add a small data sample\n        if len(df) > 0:\n            summary_lines.append(\"**Sample Data:**\")\n            summary_lines.append(\"```\")\n            # Show first few rows in a clean format\n            sample_df = df.head(3)\n            summary_lines.append(sample_df.to_string(index=False))\n            if len(df) > 3:\n                summary_lines.append(f\"... and {len(df) - 3} more rows\")\n            summary_lines.append(\"```\\n\")\n\n        return \"\\n\".join(summary_lines)\n\n    except Exception as e:\n        return f\"⚠️ *Chart generation error: {str(e)}*\"\n"
  },
  {
    "path": "sample_ui/simple_ui.py",
    "content": "\"\"\"Simple web UI for OpenChatBI using FastAPI and Gradio.\"\"\"\n\nfrom collections import defaultdict\n\nimport gradio as gr\nimport uvicorn\nfrom fastapi import FastAPI\nfrom langgraph.checkpoint.sqlite import SqliteSaver\nfrom langgraph.types import Command\n\nfrom openchatbi import config\nfrom openchatbi.agent_graph import build_agent_graph_sync\nfrom openchatbi.tool.memory import get_sync_memory_store\nfrom openchatbi.utils import get_report_download_response, log\nfrom sample_ui.style import custom_css\n\n# Session state storage: session_id -> state\nsession_interrupt = defaultdict(bool)\n\n# Use SqliteSaver for persistence\nsqlite_checkpointer_cm = SqliteSaver.from_conn_string(\"checkpoints.db\")\nsqlite_checkpointer = sqlite_checkpointer_cm.__enter__()\ngraph = build_agent_graph_sync(\n    config.get().catalog_store, checkpointer=sqlite_checkpointer, memory_store=get_sync_memory_store()\n)\n\n# ---------- FastAPI ----------\napp = FastAPI()\n\n\n# ---------- Gradio UI ----------\ndef chat_fn(message: str, history: list[tuple[str, str]], user_id: str = \"default\", session_id: str = \"default\") -> str:\n    \"\"\"Chat function for Gradio interface.\"\"\"\n    user_session_id = f\"{user_id}-{session_id}\"\n    config = {\"configurable\": {\"thread_id\": user_session_id, \"user_id\": user_id}}\n\n    if session_interrupt[user_session_id]:\n        inputs = Command(resume=message)\n    else:\n        inputs = {\"messages\": [{\"role\": \"user\", \"content\": message}]}\n\n    # Use synchronous call\n    result = graph.invoke(inputs, config=config)\n    state = graph.get_state(config)\n    if state.interrupts:\n        log(f\"state.interrupts: {state.interrupts}\")\n        output_content = state.interrupts[0].value.get(\"text\")\n        session_interrupt[user_session_id] = True\n    else:\n        session_interrupt[user_session_id] = False\n        output_content = result[\"messages\"][-1].content\n\n    return output_content\n\n\n# Create 
Gradio interface with custom CSS and theme\nwith gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:\n    gr.Markdown(\"## 💬 OpenChatBI Agent Chatbot\")\n\n    with gr.Row():\n        with gr.Column(scale=4):\n            chatbot = gr.Chatbot(elem_id=\"chatbot\", label=\"\", bubble_full_width=False, height=600)\n            msg = gr.Textbox(placeholder=\"Type a message and press Enter\", label=\"Input\", elem_id=\"msg\")\n        with gr.Column(scale=1):\n            user_box = gr.Textbox(value=\"default\", label=\"User ID\", interactive=True)\n            session_box = gr.Textbox(value=\"default\", label=\"Session ID\", interactive=True)\n            gr.Markdown(\n                \"\"\"\n            **Instructions**\n            - Type a message and press Enter to send\n            - User ID is used for memory isolation between users\n            - Session ID can be used to differentiate between conversations\n            \"\"\",\n                elem_id=\"description\",\n            )\n\n    def respond(\n        message: str, chat_history: list[tuple[str, str]], user_id: str, session_id: str\n    ) -> tuple[str, list[tuple[str, str]]]:\n        \"\"\"Handle response in Gradio chat interface.\"\"\"\n        response = chat_fn(message, chat_history, user_id, session_id)\n        chat_history.append((message, response))\n        return \"\", chat_history\n\n    msg.submit(respond, [msg, chatbot, user_box, session_box], [msg, chatbot])\n\n\n# ---------- API Endpoints ----------\n@app.get(\"/api/download/report/{filename}\")\ndef download_report(filename: str):\n    \"\"\"Download a saved report file.\"\"\"\n    return get_report_download_response(filename)\n\n\n# ---------- Application Startup ----------\n# Mount Gradio app to FastAPI\napp = gr.mount_gradio_app(app, demo, path=\"/ui\")\n\nif __name__ == \"__main__\":\n    try:\n        uvicorn.run(app, host=\"0.0.0.0\", port=8000)\n    finally:\n        # Cleanup checkpointer\n        
sqlite_checkpointer_cm.__exit__(None, None, None)\n"
  },
  {
    "path": "sample_ui/streaming_ui.py",
    "content": "\"\"\"Gradio-based Streaming UI for OpenChatBI with real-time chat interface.\"\"\"\n\nimport asyncio\nimport sys\nfrom collections import defaultdict\nfrom contextlib import asynccontextmanager\n\nimport gradio as gr\n\ntry:\n    import pysqlite3 as sqlite3\nexcept ImportError:  # pragma: no cover\n    import sqlite3\nfrom fastapi import FastAPI\nfrom langchain_core.messages import AIMessage\n\nsys.modules[\"sqlite3\"] = sqlite3\n\nfrom langgraph.types import Command\n\nfrom openchatbi.utils import get_report_download_response, get_text_from_message_chunk, log\nfrom sample_ui.async_graph_manager import AsyncGraphManager\nfrom sample_ui.plotly_utils import create_inline_chart_markdown, visualization_dsl_to_gradio_plot\nfrom sample_ui.style import custom_css\n\n# Session state storage: user_session_id -> state\nsession_interrupt = defaultdict(bool)\n\n# Global event loop for async operations (similar to Streamlit approach)\nglobal_event_loop = None\n\n\n# Global graph manager (similar to Streamlit approach)\ngraph_manager = AsyncGraphManager()\n\n\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n    \"\"\"Async context manager for FastAPI lifespan\"\"\"\n    # Startup\n    await graph_manager.initialize()\n    yield\n    # Shutdown\n    await graph_manager.cleanup()\n\n\n# ---------- FastAPI ----------\napp = FastAPI(lifespan=lifespan)\n\n\n# ---------- Gradio UI functions ----------\n\n\ndef get_or_create_event_loop():\n    \"\"\"Get or create an independent event loop\"\"\"\n    global global_event_loop\n\n    if global_event_loop is None or global_event_loop.is_closed():\n        global_event_loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(global_event_loop)\n\n    return global_event_loop\n\n\nasync def _async_respond_helper(message, chat_history, user_id, session_id):\n    \"\"\"\n    Helper async function that contains the actual async logic.\n    This will be run in an independent event loop.\n    Collects all 
responses and returns them as a list.\n    \"\"\"\n    responses = []  # Collect all yield values\n\n    user_session_id = f\"{user_id}-{session_id}\"\n    full_response = \"\"\n    plot_figure = None\n    chart_panel_update = gr.update()\n\n    if session_interrupt[user_session_id]:\n        stream_input = Command(resume=message)\n    else:\n        stream_input = {\"messages\": [{\"role\": \"user\", \"content\": message}]}\n\n    config = {\"configurable\": {\"thread_id\": user_session_id, \"user_id\": user_id}}\n\n    # Ensure graph is available\n    if not graph_manager._initialized:\n        try:\n            await graph_manager.initialize()\n        except Exception as e:\n            log(f\"Failed to initialize graph: {e}\")\n            chat_history[-1] = (chat_history[-1][0], f\"Error: Failed to initialize system - {str(e)}\")\n            responses.append((\"\", chat_history, plot_figure, chart_panel_update))\n            return responses\n\n    data_csv = None\n    # Asynchronously iterate through LangGraph stream\n    async for _namespace, event_type, event_value in graph_manager.graph.astream(\n        stream_input, config=config, stream_mode=[\"updates\", \"messages\"], subgraphs=True, debug=True\n    ):\n        token = \"\"\n        if event_type == \"messages\":\n            chunk = event_value[0]\n            metadata = event_value[1]\n            # Keep llm node messages only to avoid duplicates\n            if metadata[\"langgraph_node\"] != \"llm_node\" or not metadata.get(\"streaming_tokens\", False):\n                continue\n            token = get_text_from_message_chunk(chunk)\n        else:\n            # Process intermediate graph node updates\n            if event_value.get(\"llm_node\"):\n                message_obj = event_value[\"llm_node\"].get(\"messages\")[0]\n                if message_obj and isinstance(message_obj, AIMessage) and message_obj.tool_calls:\n                    token = f\"\\nUse tool: {', '.join(tool['name'] for 
tool in message_obj.tool_calls)}\\n\"\n                else:\n                    token = \"\\n\"\n            elif event_value.get(\"information_extraction\"):\n                message_obj = event_value[\"information_extraction\"].get(\"messages\")[0]\n                if message_obj.tool_calls:\n                    token = f\"Use tool: {message_obj.tool_calls[0]['name']}\\n\"\n                else:\n                    token = f\"Rewrite question: {event_value['information_extraction'].get('rewrite_question')}\\n\"\n            elif event_value.get(\"table_selection\"):\n                token = f\"Selected tables: {event_value['table_selection'].get('tables')}\\n\"\n            elif event_value.get(\"generate_sql\"):\n                token = f\"SQL: \\n ```sql \\n{event_value['generate_sql'].get('sql')}\\n```\\n\"\n            elif event_value.get(\"execute_sql\"):\n                token = \"Running SQL...\\n\"\n                data_csv = event_value[\"execute_sql\"].get(\"data\")\n            elif event_value.get(\"regenerate_sql\"):\n                token = f\"SQL: \\n ```sql \\n{event_value['regenerate_sql'].get('sql')}\\n```\\n\"\n            elif event_value.get(\"generate_visualization\"):\n                visualization_dsl = event_value[\"generate_visualization\"].get(\"visualization_dsl\")\n                # Check for visualization data in the final state and embed in response\n                if visualization_dsl and \"error\" not in visualization_dsl and data_csv:\n                    try:\n                        plot_figure, plot_description = visualization_dsl_to_gradio_plot(data_csv, visualization_dsl)\n                        # Add markdown representation to the chat\n                        chart_markdown = create_inline_chart_markdown(data_csv, visualization_dsl)\n                        full_response += f\"\\n\\n{chart_markdown}\"\n                        chat_history[-1] = (chat_history[-1][0], full_response)\n                        # Auto-show 
chart panel when plot is generated\n                        chart_panel_update = gr.update(visible=True)\n                        responses.append((\"\", chat_history, plot_figure, chart_panel_update))\n                    except Exception as e:\n                        log(f\"Visualization generation error: {str(e)}\")\n                        full_response += f\"\\n\\n⚠️ Visualization error: {str(e)}\"\n                        chat_history[-1] = (chat_history[-1][0], full_response)\n                        responses.append((\"\", chat_history, plot_figure, chart_panel_update))\n\n        # Update chat history with new tokens and collect response\n        if token:\n            full_response += token\n            chat_history[-1] = (chat_history[-1][0], full_response)\n            responses.append((\"\", chat_history, plot_figure, chart_panel_update))\n\n    # Get final state and check for visualization data\n    state = await graph_manager.graph.aget_state(config)\n    final_state_values = state.values\n\n    if state.interrupts:\n        log(f\"state.interrupts: {state.interrupts}\")\n        output_content = state.interrupts[0].value.get(\"text\")\n        if \"buttons\" in state.interrupts[0].value:\n            output_content += str(state.interrupts[0].value.get(\"buttons\"))\n        full_response += output_content\n        chat_history[-1] = (chat_history[-1][0], full_response)\n        session_interrupt[user_session_id] = True\n        responses.append((\"\", chat_history, plot_figure, chart_panel_update))\n    else:\n        session_interrupt[user_session_id] = False\n\n    return responses\n\n\ndef respond(message, chat_history, user_id, session_id=\"default\"):\n    \"\"\"\n    Synchronous callback for Gradio Chatbot with streaming updates.\n\n    This function processes user input and streams responses from the LangGraph agent.\n    Returns: message_input, chat_history, plot_figure, chart_panel_visibility\n    \"\"\"\n    # Add a placeholder in chat 
history\n    chat_history.append((message, \"\"))\n    plot_figure = None\n    chart_panel_update = gr.update()\n    yield \"\", chat_history, plot_figure, chart_panel_update  # Stream updates to UI\n\n    # Get or create independent event loop\n    loop = get_or_create_event_loop()\n\n    # Run the async helper in the independent event loop\n    try:\n        responses = loop.run_until_complete(_async_respond_helper(message, chat_history, user_id, session_id))\n\n        # Yield all collected responses\n        for response in responses:\n            yield response\n\n    except Exception as e:\n        log(f\"Error in respond: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        chat_history[-1] = (chat_history[-1][0], f\"Error: {str(e)}\")\n        yield \"\", chat_history, plot_figure, chart_panel_update\n\n\n# ---------- Memory Management Functions ----------\ndef list_user_memories(user_id: str) -> str:\n    \"\"\"List all memories for a specific user.\"\"\"\n    try:\n        import json\n\n        try:\n            import pysqlite3 as sqlite3\n        except ImportError:\n            import sqlite3\n        from langgraph.store.sqlite import SqliteStore\n\n        from openchatbi import config\n\n        # Create a new connection in this thread to avoid SQLite threading issues\n        conn = sqlite3.connect(\"memory.db\", check_same_thread=False)\n        conn.isolation_level = None  # Use autocommit mode to avoid transaction conflicts\n        thread_memory_store = SqliteStore(\n            conn, index={\"dims\": 1536, \"embed\": config.get().embedding_model, \"fields\": [\"text\"]}\n        )\n        try:\n            thread_memory_store.setup()\n        except Exception:\n            pass  # Store might already be set up\n\n        memories = []\n        namespace = (\"memories\", user_id)\n\n        try:\n            # Use search with namespace to find all items for this user\n            items = 
thread_memory_store.search(namespace, limit=1000)\n            for item in items:\n                memory_data = {\n                    \"key\": item.key,\n                    \"value\": item.value,\n                    \"created_at\": getattr(item, \"created_at\", \"Unknown\"),\n                    \"updated_at\": getattr(item, \"updated_at\", \"Unknown\"),\n                }\n                memories.append(memory_data)\n        except Exception as e:\n            return f\"No memories found for user {user_id} or error: {str(e)}\"\n        finally:\n            conn.close()\n\n        if not memories:\n            return f\"No memories found for user {user_id}\"\n\n        formatted = [f\"## Memories for User: {user_id}\\n\"]\n        for i, memory in enumerate(memories, 1):\n            formatted.append(f\"### Memory {i}\")\n            formatted.append(f\"**Key:** {memory['key']}\")\n\n            value = memory[\"value\"]\n            if isinstance(value, dict):\n                try:\n                    value_str = json.dumps(value, indent=2)\n                    formatted.append(f\"**Content:**\\n```json\\n{value_str}\\n```\")\n                except ValueError:\n                    formatted.append(f\"**Content:** {str(value)}\")\n            else:\n                formatted.append(f\"**Content:** {str(value)}\")\n\n            formatted.append(f\"**Created:** {memory['created_at']}\")\n            formatted.append(f\"**Updated:** {memory['updated_at']}\")\n            formatted.append(\"---\")\n\n        return \"\\n\".join(formatted)\n\n    except Exception as e:\n        return f\"Error accessing memories: {str(e)}\"\n\n\n# ---------- Gradio UI Blocks ----------\n\n# Create Gradio interface with custom CSS and theme\nwith gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:\n    gr.Markdown(\"## 💬 OpenChatBI Agent Chatbot with Streaming & On-Demand Visualization\")\n\n    with gr.Tabs():\n        with gr.TabItem(\"💬 Chat\"):\n            with 
gr.Row():\n                with gr.Column(scale=4):\n                    chatbot = gr.Chatbot(\n                        elem_id=\"chatbot\",\n                        label=\"Chat\",\n                        bubble_full_width=False,\n                        height=500,\n                        show_label=False,\n                        sanitize_html=False,\n                        render_markdown=True,\n                    )\n                    msg = gr.Textbox(placeholder=\"Type a message and press Enter\", label=\"Input\", elem_id=\"msg\")\n\n                with gr.Column(scale=2, visible=False) as chart_panel:\n                    with gr.Row():\n                        with gr.Column(scale=3):\n                            gr.Markdown(\"### 📊 Interactive Chart\")\n                        with gr.Column(scale=1):\n                            hide_chart_btn = gr.Button(\"✖️ Hide\", elem_id=\"hide-chart-btn\", size=\"sm\")\n                    plot = gr.Plot(label=\"\", visible=True, show_label=False)\n\n                with gr.Column(scale=1):\n                    user_box = gr.Textbox(value=\"default\", label=\"User ID\", interactive=True)\n                    session_box = gr.Textbox(value=\"default\", label=\"Session ID\", interactive=True)\n                    show_chart_btn = gr.Button(\"📊 Show Chart Panel\", variant=\"secondary\")\n                    gr.Markdown(\n                        \"\"\"\n                    **Instructions**  \n                    - Type a data question and press Enter\n                    - Supports streaming output (real-time display)\n                    - Click chart links in chat to view interactive charts\n                    - Use 'Show Chart Panel' to make panel visible\n                    - Session ID can be used to differentiate between conversations\n                    \"\"\",\n                        elem_id=\"description\",\n                    )\n\n            def show_chart_panel():\n                \"\"\"Show the 
chart panel.\"\"\"\n                return gr.update(visible=True)\n\n            def hide_chart_panel():\n                \"\"\"Hide the chart panel.\"\"\"\n                return gr.update(visible=False)\n\n            # Register async submit handler for message input with plot output\n            msg.submit(respond, [msg, chatbot, user_box, session_box], [msg, chatbot, plot, chart_panel])\n            show_chart_btn.click(show_chart_panel, outputs=[chart_panel])\n            hide_chart_btn.click(hide_chart_panel, outputs=[chart_panel])\n\n        with gr.TabItem(\"🧠 Memory Store\"):\n            gr.Markdown(\"### Long-term Memory Viewer\")\n            gr.Markdown(\"View memories stored for each user in the system.\")\n\n            with gr.Row():\n                with gr.Column(scale=3):\n                    memory_display = gr.Markdown(\n                        value=\"Enter a User ID and click 'Load Memories' to view stored memories.\",\n                        elem_id=\"memory-display\",\n                    )\n\n                with gr.Column(scale=1):\n                    memory_user_input = gr.Textbox(label=\"User ID\", placeholder=\"default\", value=\"default\")\n                    load_memories_btn = gr.Button(\"🔍 Load Memories\", variant=\"primary\")\n\n            # Event handler for loading memories\n            load_memories_btn.click(fn=list_user_memories, inputs=[memory_user_input], outputs=[memory_display])\n\n\n# ---------- API Endpoints ----------\n@app.get(\"/api/download/report/{filename}\")\nasync def download_report(filename: str):\n    \"\"\"Download a saved report file.\"\"\"\n    return get_report_download_response(filename)\n\n\n# ---------- Application Startup ----------\n# Mount Gradio app to FastAPI\napp = gr.mount_gradio_app(app, demo, path=\"/ui\")\n\nif __name__ == \"__main__\":\n    import uvicorn\n\n    uvicorn.run(app, host=\"0.0.0.0\", port=8000)\n"
  },
  {
    "path": "sample_ui/streamlit_ui.py",
    "content": "\"\"\"Streamlit-based Streaming UI for OpenChatBI with collapsible thinking sections.\"\"\"\n\nimport asyncio\nimport sys\nimport traceback\nimport uuid\nfrom pathlib import Path\n\nimport plotly.graph_objects as go\n\ntry:\n    import pysqlite3 as sqlite3\nexcept ImportError:  # pragma: no cover\n    import sqlite3\nimport streamlit as st\nfrom langchain_core.messages import AIMessage\n\nsys.modules[\"sqlite3\"] = sqlite3\n\nfrom langgraph.types import Command\n\nfrom openchatbi import config as openchatbi_config\nfrom openchatbi.llm.llm import list_llm_providers\nfrom openchatbi.utils import get_text_from_message_chunk, log\nfrom sample_ui.plotly_utils import visualization_dsl_to_gradio_plot\nfrom sample_ui.async_graph_manager import AsyncGraphManager\n\n# Configuration\nst.set_page_config(page_title=\"OpenChatBI - Streamlit Interface\", page_icon=\"💬\", layout=\"wide\")\n\n# Initialize session state\nif \"messages\" not in st.session_state:\n    st.session_state.messages = []\nif \"graph_manager\" not in st.session_state:\n    st.session_state.graph_manager = AsyncGraphManager()\nif \"session_interrupts\" not in st.session_state:\n    st.session_state.session_interrupts = {}\nif \"event_loop\" not in st.session_state:\n    st.session_state.event_loop = None\n\n\nasync def process_user_message_stream(\n    message: str, user_id: str, session_id: str, llm_provider: str | None, thinking_container, response_container\n):\n    \"\"\"\n    Process user message through the OpenChatBI graph with real-time updates\n    Updates the thinking_container and response_container as processing happens\n    \"\"\"\n    thinking_steps = []\n    final_response = \"\"\n    plot_figure = None\n\n    # Initialize graph if needed\n    if not st.session_state.graph_manager._initialized:\n        await st.session_state.graph_manager.initialize()\n    graph = await st.session_state.graph_manager.get_graph(llm_provider)\n\n    user_session_id = 
f\"{user_id}-{session_id}\"\n\n    # Check for interrupts\n    if st.session_state.session_interrupts.get(user_session_id, False):\n        stream_input = Command(resume=message)\n    else:\n        stream_input = {\"messages\": [{\"role\": \"user\", \"content\": message}]}\n\n    config = {\"configurable\": {\"thread_id\": user_session_id, \"user_id\": user_id}}\n\n    data_csv = None\n\n    # Use empty container for real-time updates\n    thinking_placeholder = thinking_container.empty()\n\n    # Build content chronologically - all events in time order\n    base_content = \"🔄 **Processing...**\\n\\n\"\n    chronological_content = \"\"  # All content in time order\n\n    def update_display():\n        full_content = base_content + chronological_content\n        thinking_placeholder.markdown(full_content)\n\n    # Initial display\n    update_display()\n\n    # Stream through the graph\n    async for _namespace, event_type, event_value in graph.astream(\n        stream_input, config=config, stream_mode=[\"updates\", \"messages\"], subgraphs=True, debug=True\n    ):\n        if event_type == \"messages\":\n            chunk = event_value[0]\n            metadata = event_value[1]\n            # Keep llm node messages only to avoid duplicates\n            if metadata[\"langgraph_node\"] != \"llm_node\" or not metadata.get(\"streaming_tokens\", False):\n                continue\n            token = get_text_from_message_chunk(chunk)\n            if token:\n                final_response += token\n\n                # Add to thinking content during processing\n                if len(final_response) == len(token):\n                    chronological_content += \"\\n**🤖 AI Response:** \"\n                chronological_content += token\n                update_display()\n\n        else:\n            # Process tool calls and intermediate steps\n            step_description = \"\"\n            if event_value.get(\"llm_node\"):\n                message_obj = 
event_value[\"llm_node\"].get(\"messages\")[0]\n                if message_obj and isinstance(message_obj, AIMessage) and message_obj.tool_calls:\n                    step_description = f\"🛠️ Using tools: {', '.join(tool['name'] for tool in message_obj.tool_calls)}\"\n\n            elif event_value.get(\"information_extraction\"):\n                message_obj = event_value[\"information_extraction\"].get(\"messages\")[0]\n                if message_obj and message_obj.tool_calls:\n                    step_description = f\"🛠️ Using tool: {message_obj.tool_calls[0]['name']}\"\n                else:\n                    rewrite_q = event_value[\"information_extraction\"].get(\"rewrite_question\")\n                    if rewrite_q:\n                        step_description = f\"📝 Rewriting question: {rewrite_q}\"\n\n            elif event_value.get(\"table_selection\"):\n                tables = event_value[\"table_selection\"].get(\"tables\")\n                if tables:\n                    step_description = f\"🗂️ Selected tables: {tables}\"\n\n            elif event_value.get(\"generate_sql\"):\n                sql = event_value[\"generate_sql\"].get(\"sql\")\n                if sql:\n                    step_description = f\"💾 Generated SQL:\\n```sql\\n{sql}\\n```\"\n\n            elif event_value.get(\"execute_sql\"):\n                step_description = \"⚡ Executing SQL query...\"\n                data_csv = event_value[\"execute_sql\"].get(\"data\")\n\n            elif event_value.get(\"regenerate_sql\"):\n                sql = event_value[\"regenerate_sql\"].get(\"sql\")\n                if sql:\n                    step_description = f\"🔄 Regenerated SQL:\\n```sql\\n{sql}\\n```\"\n\n            elif event_value.get(\"generate_visualization\"):\n                visualization_dsl = event_value[\"generate_visualization\"].get(\"visualization_dsl\")\n                if visualization_dsl and \"error\" not in visualization_dsl and data_csv:\n                    
try:\n                        plot_figure, plot_description = visualization_dsl_to_gradio_plot(data_csv, visualization_dsl)\n                        step_description = f\"📊 Generated visualization: {plot_description}\"\n                    except Exception as e:\n                        step_description = f\"⚠️ Visualization error: {str(e)}\"\n\n            if step_description:\n                thinking_steps.append(step_description)\n\n                # Append new step to chronological content in time order\n                step_number = len(thinking_steps)\n                # Ensure proper spacing before new step\n                if chronological_content and not chronological_content.endswith(\"\\n\\n\"):\n                    chronological_content += \"\\n\\n\"\n                chronological_content += f\"**Step {step_number}:** {step_description}\\n\\n\"\n\n                update_display()\n\n    # Check for interrupts in final state\n    state = await graph.aget_state(config)\n    if state.interrupts:\n        log(f\"State interrupts: {state.interrupts}\")\n        output_content = state.interrupts[0].value.get(\"text\", \"\")\n        if \"buttons\" in state.interrupts[0].value:\n            output_content += str(state.interrupts[0].value.get(\"buttons\"))\n        final_response += output_content\n\n        # Append interrupt content to chronological content\n        chronological_content += output_content\n        update_display()\n\n        st.session_state.session_interrupts[user_session_id] = True\n    else:\n        st.session_state.session_interrupts[user_session_id] = False\n\n    # Final update - add completion message to chronological content\n    # Add some spacing if the last content didn't end with newlines\n    if not chronological_content.endswith(\"\\n\\n\"):\n        chronological_content += \"\\n\\n\"\n    chronological_content += \"✅ **Analysis complete!**\"\n    update_display()\n\n    # Extract final answer (last part without tool calls) 
and display outside thinking\n    if final_response:\n        # Find the last occurrence of tool usage to separate final answer\n        lines = final_response.split(\"\\n\")\n        final_answer_lines = []\n        collecting_final = False\n\n        for line in reversed(lines):\n            if \"Use tool:\" in line or \"Using tools:\" in line or \"Using tool:\" in line:\n                break\n            final_answer_lines.append(line)\n            collecting_final = True\n\n        if collecting_final and final_answer_lines:\n            # Reverse back to correct order\n            final_answer_lines.reverse()\n            final_answer_text = \"\\n\".join(final_answer_lines).strip()\n\n            if final_answer_text:\n                with response_container:\n                    processed_final_answer_text = process_download_links(final_answer_text)\n                    render_content_with_downloads(processed_final_answer_text)\n\n    # Final update to response container - only show plot if available (text response is in thinking container)\n    with response_container:\n        if plot_figure:\n            st.plotly_chart(plot_figure, use_container_width=True, key=str(uuid.uuid4()))\n\n    # Extract final answer for separate storage\n    final_answer_text = \"\"\n    if final_response:\n        lines = final_response.split(\"\\n\")\n        final_answer_lines = []\n        collecting_final = False\n\n        for line in reversed(lines):\n            if \"Use tool:\" in line or \"Using tools:\" in line or \"Using tool:\" in line:\n                break\n            final_answer_lines.append(line)\n            collecting_final = True\n\n        if collecting_final and final_answer_lines:\n            final_answer_lines.reverse()\n            final_answer_text = \"\\n\".join(final_answer_lines).strip()\n\n    return final_response, plot_figure, thinking_steps, chronological_content, final_answer_text\n\n\ndef get_available_reports() -> list[str]:\n    
\"\"\"Get list of available report files for download.\"\"\"\n    try:\n        # Import config here to avoid circular imports\n        from openchatbi import config\n\n        report_dir = Path(config.get().report_directory)\n        if not report_dir.exists():\n            return []\n\n        # Get all files in the report directory\n        report_files = []\n        for file_path in report_dir.iterdir():\n            if file_path.is_file():\n                report_files.append(file_path.name)\n\n        return sorted(report_files)\n    except Exception as e:\n        st.error(f\"Error accessing reports: {str(e)}\")\n        return []\n\n\ndef get_report_file_content(filename: str) -> tuple[bytes | None, str | None]:\n    \"\"\"Get report file content for download.\n\n    Returns:\n        tuple: (file_content_bytes, mime_type) or (None, None) if error\n    \"\"\"\n    try:\n        # Import config here to avoid circular imports\n        from openchatbi import config\n\n        report_dir = Path(config.get().report_directory)\n        file_path = report_dir / filename\n\n        # Security check - ensure file is within report directory\n        if not file_path.exists() or not file_path.is_file():\n            st.error(f\"Report file not found: {filename}\")\n            return None, None\n\n        try:\n            file_path.resolve().relative_to(report_dir.resolve())\n        except ValueError:\n            st.error(\"Access denied to file\")\n            return None, None\n\n        # Determine MIME type\n        mime_type_map = {\n            \".md\": \"text/markdown\",\n            \".csv\": \"text/csv\",\n            \".txt\": \"text/plain\",\n            \".json\": \"application/json\",\n            \".html\": \"text/html\",\n            \".xml\": \"application/xml\",\n        }\n\n        file_extension = file_path.suffix.lower()\n        mime_type = mime_type_map.get(file_extension, \"application/octet-stream\")\n\n        # Read file content\n        
with open(file_path, \"rb\") as f:\n            content = f.read()\n\n        return content, mime_type\n\n    except Exception as e:\n        st.error(f\"Error reading report file: {str(e)}\")\n        return None, None\n\n\ndef process_download_links(content: str) -> str:\n    \"\"\"Process download links in content and replace them with Streamlit-compatible ones.\n\n    Args:\n        content: Message content that may contain download links\n\n    Returns:\n        str: Content with download links replaced\n    \"\"\"\n    import re\n\n    if not content:\n        return content\n\n    # Pattern to match both full URLs and path-only download links\n    # Matches: http://localhost:8501/api/download/report/filename.ext or /api/download/report/filename.ext\n    download_pattern = r\"(?:https?://[^/\\s]+)?/api/download/report/([^)\\s\\]<>]+)\"\n\n    def replace_link(match):\n        filename = match.group(1)\n        # Return a placeholder that we'll replace with actual download button\n        return f\"[DOWNLOAD_LINK:{filename}]\"\n\n    processed_content = re.sub(download_pattern, replace_link, content)\n\n    # Debug log to see if processing worked\n    if processed_content != content:\n        st.write(f\"🔍 Debug: Processed download links - found {content.count('/api/download/report/')} links\")\n\n    return processed_content\n\n\ndef render_content_with_downloads(content: str) -> None:\n    \"\"\"Render content and replace download placeholders with actual download buttons.\"\"\"\n    import re\n\n    # Split content by download placeholders\n    download_pattern = r\"\\[DOWNLOAD_LINK:([^)]+)\\]\"\n    parts = re.split(download_pattern, content)\n\n    for i, part in enumerate(parts):\n        if i % 2 == 0:\n            # Regular content\n            if part.strip():\n                st.markdown(part)\n        else:\n            # Download link filename\n            filename = part\n            file_content, mime_type = get_report_file_content(filename)\n\n 
           if file_content is not None:\n                st.download_button(\n                    label=f\"📥 Download {filename}\",\n                    data=file_content,\n                    file_name=filename,\n                    mime=mime_type,\n                    key=f\"inline_download_{filename}_{hash(content)}\",\n                )\n            else:\n                st.error(f\"❌ Could not load report: {filename}\")\n\n\ndef display_message_with_thinking(\n    role: str, content: str, thinking_steps: list[str] = None, plot_figure: go.Figure = None\n):\n    \"\"\"Display a message with collapsible thinking section\"\"\"\n    with st.chat_message(role):\n        if thinking_steps and role == \"assistant\":\n            # Create thinking section with all content inside\n            with st.expander(\"💭 AI Thinking Process\", expanded=False):\n                for i, step in enumerate(thinking_steps, 1):\n                    st.markdown(f\"**Step {i}:** {step}\")\n\n                if content:\n                    st.markdown(\"**🤖 AI Response:**\")\n                    render_content_with_downloads(content)\n\n                st.success(\"✅ Analysis complete\")\n\n        # For non-assistant messages, display content normally\n        elif content and role != \"assistant\":\n            render_content_with_downloads(content)\n\n        # Display plot if available (outside thinking container)\n        if plot_figure:\n            st.plotly_chart(plot_figure, use_container_width=True, key=str(uuid.uuid4()))\n\n\n# Main UI\nst.title(\"💬 OpenChatBI - Streamlit UI\")\nst.markdown(\"*AI-powered Business Intelligence Chat with Thinking*\")\n\n# Sidebar for configuration\nwith st.sidebar:\n    st.header(\"⚙️ Configuration\")\n    user_id = st.text_input(\"User ID\", value=\"default\", help=\"Unique identifier for the user session\")\n    session_id = st.text_input(\"Session ID\", value=\"default\", help=\"Session identifier for conversation continuity\")\n\n    # 
Optional multi-provider support\n    llm_provider = None\n    provider_options = list_llm_providers()\n    if provider_options:\n        try:\n            default_provider = getattr(openchatbi_config.get(), \"llm_provider\", None)\n        except Exception:\n            default_provider = None\n        default_index = provider_options.index(default_provider) if default_provider in provider_options else 0\n        llm_provider = st.selectbox(\n            \"LLM Provider\",\n            options=provider_options,\n            index=default_index,\n            help=\"Select which configured LLM provider to use for this session\",\n        )\n\n    st.markdown(\"---\")\n    st.markdown(\n        \"\"\"\n    **💡 How to use:**\n    - Type your business questions\n    - Watch the AI thinking process in collapsible sections\n    - View generated charts and analyses\n    - Use different session IDs for separate conversations\n    \"\"\"\n    )\n\n    if st.button(\"🗑️ Clear Chat History\"):\n        st.session_state.messages = []\n        st.rerun()\n\n    st.markdown(\"---\")\n    st.markdown(\"### 📁 Report Downloads\")\n\n    # Get available reports\n    available_reports = get_available_reports()\n\n    if available_reports:\n        selected_report = st.selectbox(\n            \"Select a report to download:\", options=[\"\"] + available_reports, help=\"Choose a report file to download\"\n        )\n\n        if selected_report and st.button(\"📥 Download Report\"):\n            file_content, mime_type = get_report_file_content(selected_report)\n            if file_content is not None:\n                st.download_button(\n                    label=f\"💾 Save {selected_report}\",\n                    data=file_content,\n                    file_name=selected_report,\n                    mime=mime_type,\n                    key=f\"download_{selected_report}\",\n                )\n                st.success(f\"✅ {selected_report} is ready for download!\")\n    else:\n        
st.info(\"No reports available for download.\")\n\n# Display chat history\nfor msg in st.session_state.messages:\n    if msg[\"type\"] == \"chronological_message\":\n        # Display chronological content in expander - all collapsed after completion\n        with st.chat_message(msg[\"role\"]):\n            with st.expander(\"💭 AI Thinking Process\", expanded=False):\n                st.markdown(msg[\"chronological_content\"])\n\n            # Extract and display final answer text outside thinking\n            if msg.get(\"final_answer\"):\n                render_content_with_downloads(msg[\"final_answer\"])\n\n            # Display plot if available (outside thinking container)\n            if msg.get(\"plot_figure\"):\n                st.plotly_chart(msg[\"plot_figure\"], use_container_width=True, key=str(uuid.uuid4()))\n\n    elif msg[\"type\"] == \"thinking_message\":\n        display_message_with_thinking(\n            msg[\"role\"], msg[\"content\"], msg.get(\"thinking_steps\", []), msg.get(\"plot_figure\")\n        )\n    else:\n        with st.chat_message(msg[\"role\"]):\n            if msg[\"type\"] == \"text\":\n                render_content_with_downloads(msg[\"content\"])\n            elif msg[\"type\"] == \"plot\" and msg.get(\"plot_figure\"):\n                st.plotly_chart(msg[\"plot_figure\"], use_container_width=True, key=str(uuid.uuid4()))\n\n# Chat input\nif prompt := st.chat_input(\"Ask me anything about your data...\"):\n    # Add user message\n    st.session_state.messages.append({\"role\": \"user\", \"type\": \"text\", \"content\": prompt})\n\n    # Display user message immediately\n    with st.chat_message(\"user\"):\n        st.markdown(prompt)\n\n    # Process assistant response with real-time streaming\n    with st.chat_message(\"assistant\"):\n        # Create thinking and response containers\n        thinking_expander = st.expander(\"💭 AI Thinking Process...\", expanded=True)\n        thinking_container = 
thinking_expander.container()\n        response_container = st.container()\n\n        # Process the message asynchronously with real-time updates\n        try:\n            # Reuse the same event loop to avoid binding issues\n            if st.session_state.event_loop is None or st.session_state.event_loop.is_closed():\n                st.session_state.event_loop = asyncio.new_event_loop()\n                asyncio.set_event_loop(st.session_state.event_loop)\n\n            loop = st.session_state.event_loop\n            final_response, plot_figure, thinking_steps, full_chronological_content, final_answer = (\n                loop.run_until_complete(\n                    process_user_message_stream(\n                        prompt, user_id, session_id, llm_provider, thinking_container, response_container\n                    )\n                )\n            )\n\n            # No need to create another expander - content is already shown in real-time\n            # Process download links in the content before storing\n            processed_chronological_content = process_download_links(full_chronological_content)\n            processed_final_answer = process_download_links(final_answer) if final_answer else final_answer\n\n            # Store the complete message with the processed content\n            st.session_state.messages.append(\n                {\n                    \"role\": \"assistant\",\n                    \"type\": \"chronological_message\",\n                    \"chronological_content\": processed_chronological_content,\n                    \"final_answer\": processed_final_answer,\n                    \"plot_figure\": plot_figure,\n                }\n            )\n\n            # Trigger rerun to collapse the thinking section\n            st.rerun()\n\n        except Exception as e:\n            traceback.print_exc()\n            st.error(f\"❌ Error processing request: {str(e)}\")\n            error_content = f\"❌ Error: {str(e)}\"\n            
processed_error_content = process_download_links(error_content)\n            st.session_state.messages.append({\"role\": \"assistant\", \"type\": \"text\", \"content\": processed_error_content})\n\n\n# Cleanup on session end\ndef cleanup_session():\n    \"\"\"Cleanup resources when session ends\"\"\"\n    if \"graph_manager\" in st.session_state:\n        try:\n            # Use the same event loop for cleanup\n            if st.session_state.event_loop and not st.session_state.event_loop.is_closed():\n                loop = st.session_state.event_loop\n                loop.run_until_complete(st.session_state.graph_manager.cleanup())\n                loop.close()\n                st.session_state.event_loop = None\n        except Exception as e:\n            log(f\"Error during session cleanup: {e}\")\n\n\n# Register cleanup (this is a simplified approach - in production you might want more robust cleanup)\nimport atexit\n\natexit.register(cleanup_session)\n"
  },
  {
    "path": "sample_ui/style.py",
    "content": "# Custom CSS for styling the chat interface\ncustom_css = \"\"\"\n#chatbot {\n    height: 600px !important;\n    font-family: \"Inter\", \"Helvetica Neue\", sans-serif;\n}\n#chatbot .wrap.svelte-1cl84sx {\n    background: #f5f7fa;\n    border-radius: 12px;\n    padding: 8px;\n}\n.message.user {\n    background-color: #d1e9ff !important;\n    border-radius: 12px 12px 0 12px;\n    margin: 4px 0;\n    padding: 10px 14px;\n    font-size: 15px;\n}\n.message.bot {\n    background-color: #ffffff !important;\n    border-radius: 12px 12px 12px 0;\n    margin: 4px 0;\n    padding: 10px 14px;\n    font-size: 15px;\n    box-shadow: 0px 1px 3px rgba(0,0,0,0.08);\n}\n#msg {\n    font-family: \"Inter\", \"Helvetica Neue\", sans-serif;\n    font-size: 15px;\n}\n#description {\n    font-family: \"Inter\", \"Helvetica Neue\", sans-serif;\n    font-size: 14px;\n    color: #374151;\n    line-height: 1.6;\n}\n\"\"\"\n"
  },
  {
    "path": "tests/README.md",
    "content": "# OpenChatBI Test Suite\n\nThis directory contains comprehensive unit tests for the OpenChatBI project. The test suite is built using pytest and follows modern Python testing best practices.\n\n## Test Structure\n\n```\ntests/\n├── __init__.py                          # Test package initialization\n├── conftest.py                          # Shared fixtures and configuration\n├── README.md                            # This file\n│\n├── Core Module Tests\n├── test_config_loader.py                # Configuration loading tests\n├── test_graph_state.py                  # State management tests\n├── test_utils.py                        # Utility function tests\n│\n├── Catalog System Tests\n├── test_catalog_store.py                # Catalog store interface tests\n├── test_catalog_loader.py               # Database catalog loading tests\n│\n├── Text2SQL Pipeline Tests\n├── test_text2sql_extraction.py          # Information extraction tests\n├── test_text2sql_generate_sql.py        # SQL generation tests\n├── test_text2sql_schema_linking.py      # Schema linking tests\n├── test_text2sql_visualization.py       # Data visualization tests\n│\n├── Tool Tests\n├── test_tools_ask_human.py              # Human interaction tool tests\n├── test_tools_run_python_code.py        # Python code execution tests\n├── test_tools_search_knowledge.py       # Knowledge search tests\n│\n├── Additional Module Tests\n├── test_memory.py                       # Memory management tests\n├── test_plotly_utils.py                 # Plotly utilities tests\n├── test_incomplete_tool_calls.py        # Incomplete tool call handling tests\n│\n└── Context Management Tests\n    └── context_management/              # Context management test suite (see context_management/README.md)\n```\n\n## Running Tests\n\n### Prerequisites\n\nEnsure you have the development dependencies installed:\n\n```bash\n# Using uv (recommended)\nuv sync --group dev\n\n# Or using pip\npip install -e 
\".[dev]\"\n```\n\n### Basic Test Execution\n\n```bash\n# Run all tests\nuv run pytest\n\n# Run tests with verbose output\nuv run pytest -v\n\n# Run specific test file\nuv run pytest tests/test_config_loader.py\n\n# Run specific test class\nuv run pytest tests/test_config_loader.py::TestConfigLoader\n\n# Run specific test method\nuv run pytest tests/test_config_loader.py::TestConfigLoader::test_load_config_from_file\n```\n\n### Test Coverage\n\n```bash\n# Run tests with coverage report\nuv run pytest --cov=openchatbi --cov-report=html --cov-report=term-missing\n\n# View HTML coverage report\nopen htmlcov/index.html\n```\n\n### Test Categories\n\n```bash\n# Run only fast unit tests (exclude slow integration tests)\nuv run pytest -m \"not slow\"\n\n# Run tests for specific components\nuv run pytest tests/test_catalog* -k \"catalog\"\nuv run pytest tests/test_text2sql* -k \"text2sql\"\nuv run pytest tests/test_tools* -k \"tools\"\n\n# Run context management tests\nuv run pytest tests/context_management/\n\n# Run memory and utility tests\nuv run pytest tests/test_memory.py tests/test_plotly_utils.py\n\n# Run incomplete tool call tests\nuv run pytest tests/test_incomplete_tool_calls.py\n```\n\n## Test Configuration\n\n### Environment Variables\n\nThe test suite uses several environment variables that can be set to customize test behavior:\n\n- `OPENCHATBI_TEST_MODE=true` - Enables test mode\n- `OPENCHATBI_CONFIG_PATH` - Path to test configuration file\n- `PYTEST_TIMEOUT=300` - Test timeout in seconds\n\n### Fixtures\n\nThe `conftest.py` file provides shared fixtures used across tests:\n\n#### Core Fixtures\n- `test_config` - Test configuration dictionary\n- `temp_dir` - Temporary directory for test files\n- `mock_llm` - Mocked language model for testing\n- `sample_agent_state` - Sample AgentState for testing\n\n#### Catalog Fixtures\n- `mock_catalog_store` - Mocked catalog store with sample data\n- `mock_database_engine` - Mocked database engine\n- `sample_table_info` - 
Sample table metadata\n\n#### Database Fixtures\n- `mock_presto_connection` - Mocked Presto database connection\n- `mock_token_service` - Mocked authentication token service\n\n## Writing Tests\n\n### Test Naming Conventions\n\nFollow these naming conventions for consistency:\n\n```python\n# Test files\ntest_<module_name>.py\n\n# Test classes\nclass TestModuleName:\n\n# Test methods\ndef test_specific_functionality(self):\ndef test_error_condition_handling(self):\ndef test_edge_case_scenario(self):\n```\n\n### Test Categories\n\nUse pytest marks to categorize tests:\n\n```python\nimport pytest\n\n@pytest.mark.unit\ndef test_basic_functionality():\n    \"\"\"Unit test for basic functionality.\"\"\"\n    pass\n\n@pytest.mark.integration\ndef test_database_integration():\n    \"\"\"Integration test with database.\"\"\"\n    pass\n\n@pytest.mark.slow\ndef test_performance_benchmark():\n    \"\"\"Slow performance test.\"\"\"\n    pass\n\n@pytest.mark.parametrize(\"input,expected\", [\n    (\"test1\", \"result1\"),\n    (\"test2\", \"result2\"),\n])\ndef test_multiple_scenarios(input, expected):\n    \"\"\"Test multiple input/output scenarios.\"\"\"\n    pass\n```\n\n### Mocking Best Practices\n\nUse proper mocking for external dependencies:\n\n```python\nfrom unittest.mock import Mock, patch, MagicMock\n\n# Mock external services\n@patch('openchatbi.module.external_service')\ndef test_with_external_service(mock_service):\n    mock_service.return_value = \"expected_result\"\n    # Test implementation\n\n# Mock LLM responses\ndef test_llm_integration(mock_llm):\n    mock_llm.invoke.return_value = AIMessage(content=\"Mock response\")\n    # Test implementation\n```\n\n### Async Test Support\n\nFor testing async functionality:\n\n```python\nimport pytest\nimport asyncio\n\n@pytest.mark.asyncio\nasync def test_async_functionality():\n    \"\"\"Test asynchronous operations.\"\"\"\n    result = await some_async_function()\n    assert result is not None\n```\n\n## Test 
Data\n\n### Sample Data Files\n\nTest data is managed through fixtures and temporary files:\n\n```python\ndef test_with_sample_data(temp_dir):\n    \"\"\"Test using temporary sample data.\"\"\"\n    # Create test data file\n    data_file = temp_dir / \"test_data.csv\"\n    data_file.write_text(\"col1,col2\\nval1,val2\")\n    \n    # Test with the data\n    assert data_file.exists()\n```\n\n### Mock Responses\n\nCommon mock responses are defined in fixtures:\n\n```python\n# SQL generation mock response\nmock_llm.invoke.return_value = AIMessage(\n    content=\"SELECT COUNT(*) FROM test_table;\"\n)\n\n# Catalog search mock response\nmock_catalog.search_tables.return_value = [\n    {\"table_name\": \"users\", \"description\": \"User data\"}\n]\n```\n\n## Continuous Integration\n\n### GitHub Actions\n\nTests run automatically on:\n- Pull requests\n- Pushes to main branch\n- Scheduled runs (daily)\n\n### Test Matrix\n\nTests run against multiple configurations:\n- Python versions: 3.11+\n- Dependencies: Minimum and latest versions\n\n## Debugging Tests\n\n### Common Issues\n\n1. **Import Errors**\n   ```bash\n   # Ensure package is installed in development mode\n   pip install -e .\n   ```\n\n2. **Missing Dependencies**\n   ```bash\n   # Install test dependencies\n   pip install -e \".[test]\"\n   ```\n\n3. 
**Configuration Issues**\n   ```bash\n   # Set test environment variables\n   export OPENCHATBI_TEST_MODE=true\n   ```\n\n### Debug Output\n\nEnable debug output for failing tests:\n\n```bash\n# Run with debug output\nuv run pytest -v -s --tb=long\n\n# Run with pdb on failures\nuv run pytest --pdb\n\n# Run with coverage debug\nuv run pytest --cov-report=term-missing -v\n```\n\n## Performance Testing\n\n### Benchmarks\n\nPerformance tests are marked with `@pytest.mark.slow`:\n\n```bash\n# Run performance tests\nuv run pytest -m slow\n\n# Skip performance tests\nuv run pytest -m \"not slow\"\n```\n\n### Memory Profiling\n\nFor memory usage testing:\n\n```bash\n# Install memory profiler\npip install memory-profiler\n\n# Run with memory profiling\nuv run pytest --profile-mem\n```\n\n## Contributing\n\n### Adding New Tests\n\n1. Create test file following naming conventions\n2. Import required fixtures from `conftest.py`\n3. Write comprehensive test cases covering:\n   - Happy path scenarios\n   - Error conditions\n   - Edge cases\n   - Performance considerations\n\n4. Use appropriate mocking for external dependencies\n5. Add docstrings explaining test purpose\n6. Run tests locally before submitting PR\n\n### Test Review Guidelines\n\nWhen reviewing test PRs:\n- Ensure adequate test coverage\n- Verify mock usage is appropriate\n- Check for test independence\n- Validate error case handling\n- Confirm performance test categorization\n\n## Resources\n\n- [Pytest Documentation](https://docs.pytest.org/)\n- [Python unittest.mock](https://docs.python.org/3/library/unittest.mock.html)\n- [Coverage.py Documentation](https://coverage.readthedocs.io/)\n- [pytest-asyncio](https://pytest-asyncio.readthedocs.io/)"
  },
  {
    "path": "tests/__init__.py",
    "content": "\"\"\"Test package for OpenChatBI.\"\"\"\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "\"\"\"Pytest configuration and shared fixtures.\"\"\"\n\nimport tempfile\nfrom collections.abc import Generator\nfrom pathlib import Path\nfrom typing import Any\nfrom unittest.mock import Mock\n\nimport pytest\nfrom langchain_core.language_models import FakeListChatModel\nfrom langchain_core.messages import AIMessage, HumanMessage\nfrom sqlalchemy import create_engine\n\nfrom openchatbi.catalog.store.file_system import FileSystemCatalogStore\nfrom openchatbi.config_loader import ConfigLoader\nfrom openchatbi.graph_state import AgentState\n\n\n@pytest.fixture(scope=\"session\")\ndef test_config() -> dict[str, Any]:\n    \"\"\"Test configuration fixture.\"\"\"\n    return {\n        \"organization\": \"TestOrg\",\n        \"dialect\": \"presto\",\n        \"bi_config_file\": \"test_bi.yaml\",\n        \"catalog_store\": {\"store_type\": \"file_system\", \"data_path\": \"./test_data\"},\n        \"default_llm\": {\n            \"class\": \"langchain_core.language_models.FakeListChatModel\",\n            \"params\": {\"responses\": [\"Test response\"]},\n        },\n    }\n\n\n@pytest.fixture\ndef temp_dir() -> Generator[Path, None, None]:\n    \"\"\"Temporary directory fixture.\"\"\"\n    with tempfile.TemporaryDirectory() as temp_dir:\n        yield Path(temp_dir)\n\n\n@pytest.fixture\ndef mock_llm() -> FakeListChatModel:\n    \"\"\"Mock LLM fixture for testing.\"\"\"\n    return FakeListChatModel(\n        responses=[\"SELECT COUNT(*) FROM test_table;\", \"This is a test SQL query.\", \"Test analysis result.\"]\n    )\n\n\n@pytest.fixture\ndef sample_agent_state() -> AgentState:\n    \"\"\"Sample agent state for testing.\"\"\"\n    return AgentState(\n        messages=[HumanMessage(content=\"Test query\")],\n        sql=\"SELECT * FROM test_table;\",\n        agent_next_node=\"sql_generation\",\n        final_answer=\"Test data results\",\n    )\n\n\n@pytest.fixture\ndef mock_catalog_store(temp_dir: Path) -> FileSystemCatalogStore:\n    \"\"\"Mock 
catalog store fixture.\"\"\"\n    # Create test data files\n    test_data_dir = temp_dir / \"test_data\"\n    test_data_dir.mkdir(exist_ok=True)\n\n    # Create sample table_info.yaml\n    tables_info_file = test_data_dir / \"table_info.yaml\"\n    tables_info_file.write_text(\n        \"\"\"test:\n  test_table:\n    type: fact\n    description: A test table for unit tests\n  user_data:\n    type: fact\n    description: User information table\"\"\"\n    )\n\n    # Create sample table_columns.csv\n    tables_file = test_data_dir / \"table_columns.csv\"\n    tables_file.write_text(\n        \"\"\"db_name,table_name,column_name\ntest,test_table,id\ntest,test_table,name\ntest,user_data,user_id\"\"\"\n    )\n\n    # Create sample table_spec_columns.csv\n    columns_file = test_data_dir / \"table_spec_columns.csv\"\n    columns_file.write_text(\n        \"\"\"db_name,table_name,column_name,type,display_name,description\ntest,test_table,id,bigint,Id,Primary key\ntest,test_table,name,varchar,Name,User name\ntest,user_data,user_id,bigint,User Id,User identifier\"\"\"\n    )\n\n    # Create sample common_columns.csv\n    common_columns_file = test_data_dir / \"common_columns.csv\"\n    common_columns_file.write_text(\n        \"\"\"column_name,type,display_name,description\nstatus,varchar,Status,Record status\ncreated_at,timestamp,Created At,Creation timestamp\nupdated_at,timestamp,Updated At,Last update timestamp\"\"\"\n    )\n\n    # Mock data warehouse config\n    data_warehouse_config = {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"}\n\n    return FileSystemCatalogStore(data_path=str(test_data_dir), data_warehouse_config=data_warehouse_config)\n\n\n@pytest.fixture\ndef mock_database_engine():\n    \"\"\"Mock database engine fixture.\"\"\"\n    engine = create_engine(\"sqlite:///:memory:\")\n\n    # Create test tables\n    with engine.connect() as conn:\n        conn.execute(\"CREATE TABLE test_table (id INTEGER, name 
TEXT)\")\n        conn.execute(\"INSERT INTO test_table VALUES (1, 'Test User')\")\n        conn.commit()\n\n    return engine\n\n\n@pytest.fixture\ndef sample_table_info() -> dict[str, Any]:\n    \"\"\"Sample table information fixture.\"\"\"\n    return {\n        \"test_table\": {\n            \"columns\": [\n                {\"name\": \"id\", \"type\": \"bigint\", \"description\": \"Primary key\"},\n                {\"name\": \"name\", \"type\": \"varchar\", \"description\": \"User name\"},\n            ],\n            \"description\": \"A test table for unit tests\",\n            \"sql_rule\": \"Always filter by active status\",\n        }\n    }\n\n\n@pytest.fixture\ndef sample_messages() -> list:\n    \"\"\"Sample message history fixture.\"\"\"\n    return [\n        HumanMessage(content=\"What's the user count?\"),\n        AIMessage(content=\"I'll help you get the user count from the database.\"),\n        HumanMessage(content=\"Show me the SQL query\"),\n    ]\n\n\n@pytest.fixture(autouse=True)\ndef reset_config_loader():\n    \"\"\"Reset ConfigLoader singleton state before each test.\"\"\"\n    # Reset the singleton instance to ensure clean state\n    ConfigLoader._instance = None\n    ConfigLoader._config = None\n    yield\n    # Clean up after test\n    ConfigLoader._instance = None\n    ConfigLoader._config = None\n\n\n@pytest.fixture\ndef mock_config():\n    \"\"\"Provide a mock configuration for tests that need it.\"\"\"\n    from unittest.mock import MagicMock\n\n    config_dict = {\n        \"organization\": \"Test Company\",\n        \"dialect\": \"presto\",\n        \"default_llm\": MagicMock(),\n        \"embedding_model\": MagicMock(),\n        \"data_warehouse_config\": {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"},\n        \"report_directory\": \"./data\",\n        \"python_executor\": \"local\",\n        \"visualization_mode\": \"rule\",\n        \"context_config\": {},\n    }\n\n    loader = 
ConfigLoader()\n    loader.set(config_dict)\n    return loader.get()\n\n\n@pytest.fixture(autouse=True)\ndef setup_test_env(monkeypatch, temp_dir):\n    \"\"\"Setup test environment variables.\"\"\"\n    monkeypatch.setenv(\"OPENCHATBI_CONFIG_PATH\", str(temp_dir / \"config.yaml\"))\n    monkeypatch.setenv(\"OPENCHATBI_TEST_MODE\", \"true\")\n\n\nclass MockTokenService:\n    \"\"\"Mock token service for testing.\"\"\"\n\n    def __init__(self):\n        self.token = \"mock_token_12345\"\n\n    def get_token(self) -> str:\n        return self.token\n\n\n@pytest.fixture\ndef mock_token_service() -> MockTokenService:\n    \"\"\"Mock token service fixture.\"\"\"\n    return MockTokenService()\n\n\n@pytest.fixture\ndef sample_sql_examples() -> list:\n    \"\"\"Sample SQL examples fixture.\"\"\"\n    return [\n        {\"question\": \"How many users are there?\", \"sql\": \"SELECT COUNT(*) FROM users;\", \"tables\": [\"users\"]},\n        {\n            \"question\": \"What's the average age?\",\n            \"sql\": \"SELECT AVG(age) FROM users WHERE age IS NOT NULL;\",\n            \"tables\": [\"users\"],\n        },\n    ]\n\n\n@pytest.fixture\ndef mock_presto_connection():\n    \"\"\"Mock Presto connection fixture.\"\"\"\n    mock_conn = Mock()\n    mock_cursor = Mock()\n\n    # Setup cursor behavior\n    mock_cursor.fetchall.return_value = [(\"table1\", \"Test table 1\"), (\"table2\", \"Test table 2\")]\n    mock_cursor.description = [(\"table_name\",), (\"description\",)]\n\n    mock_conn.cursor.return_value = mock_cursor\n    mock_conn.execute.return_value = mock_cursor\n\n    return mock_conn\n"
  },
  {
    "path": "tests/context_management/README.md",
    "content": "# Context Management Test Suite\n\nThis directory contains comprehensive tests for the context management functionality in OpenChatBI.\n\n## Test Structure\n\n### 📁 Test Files\n\n- **`test_context_manager.py`** - Unit tests for the `ContextManager` class\n- **`test_context_config.py`** - Tests for context configuration management\n- **`test_agent_graph_integration.py`** - Integration tests for agent graph with context management\n- **`test_edge_cases.py`** - Edge case handling\n- **`test_state_operations.py`** - Tests for state operations and message processing\n- **`conftest.py`** - Shared pytest fixtures and configuration\n- **`test_runner.py`** - Custom test runner script\n\n## 🧪 Test Categories\n\n### Unit Tests (`test_context_manager.py`)\n\nTests core functionality of the `ContextManager` class:\n\n- ✅ Token estimation and message token calculation\n- ✅ Tool output trimming (generic, SQL, Python code)\n- ✅ Conversation summarization\n- ✅ Context management with sliding window\n- ✅ Tool wrapper functionality\n- ✅ Configuration-based behavior\n\n**Key test cases:**\n- `test_trim_sql_output()` - Tests intelligent SQL result trimming\n- `test_conversation_summary_success()` - Tests LLM-based summarization\n- `test_manage_context_with_summarization()` - Tests full context management flow\n\n### Configuration Tests (`test_context_config.py`)\n\nTests configuration management and validation:\n\n- ✅ Default configuration values\n- ✅ Custom configuration creation\n- ✅ Configuration updates\n- ✅ Edge cases (zero/negative values)\n- ✅ Different configuration presets\n\n**Key test cases:**\n- `test_update_context_config_multiple_values()` - Tests configuration updates\n- `test_production_optimized_config()` - Tests realistic production settings\n\n### Integration Tests (`test_agent_graph_integration.py`)\n\nTests integration with the agent graph system:\n\n- ✅ Agent router with context management\n- ✅ Graph building with/without context management\n- ✅ 
Tool wrapping in graph context\n- ✅ Full conversation flow testing\n- ✅ System message preservation\n\n**Key test cases:**\n- `test_agent_router_with_context_manager()` - Tests router integration\n- `test_full_context_management_flow()` - Tests end-to-end functionality\n\n### State Operations Tests (`test_state_operations.py`)\n\nTests state manipulation and message processing operations:\n\n- ✅ Message trimming and truncation logic\n- ✅ State updates and modifications\n- ✅ Message type handling and conversion\n- ✅ Context state preservation during operations\n- ✅ Error handling in state operations\n\n**Key test cases:**\n- `test_trim_messages_by_token_count()` - Tests message trimming logic\n- `test_state_message_processing()` - Tests state message operations\n- `test_context_state_updates()` - Tests context state modifications\n\n### Edge Cases (`test_edge_cases.py`)\n\nTests system behavior under stress and edge conditions:\n\n- ✅ Unicode and encoding edge cases\n- ✅ Malformed input handling\n\n**Key test cases:**\n- `test_sql_output_edge_cases()` - SQL edge cases\n- `test_extremely_nested_or_complex_structures()` - Complex data structures\n\n## 🚀 Running Tests\n\n### Using the Test Runner\n\n```bash\n# Run all tests\npython tests/context_management/test_runner.py\n\n# Run only unit tests\npython tests/context_management/test_runner.py --type unit\n\n# Run with coverage reporting\npython tests/context_management/test_runner.py --coverage\n```\n\n### Using Pytest Directly\n\n```bash\n# Run all context management tests\npytest tests/context_management/\n\n# Run specific test file\npytest tests/context_management/test_context_manager.py\n\n# Run with verbose output\npytest tests/context_management/ -v\n\n# Run with coverage\npytest tests/context_management/ --cov=openchatbi.context_manager --cov-report=html\n```\n\n## 📊 Test Markers\n\nTests are organized using pytest markers:\n\n- `@pytest.mark.integration` - Integration tests\n- `@pytest.mark.slow` - Slow-running 
tests (can be excluded)\n\n## 🎯 Test Coverage Areas\n\n### Core Functionality\n- [x] Token estimation and management\n- [x] Message processing and trimming\n- [x] Conversation summarization\n- [x] Context compression strategies\n\n### Tool Output Management\n- [x] SQL output trimming with structure preservation\n- [x] Python code output handling\n- [x] Error message preservation\n- [x] Generic output trimming\n\n### Configuration Management\n- [x] Default and custom configurations\n- [x] Configuration validation\n- [x] Runtime configuration updates\n- [x] Edge case configurations\n\n### Integration Points\n- [x] Agent router integration\n- [x] Graph building integration\n- [x] Tool wrapper integration\n- [x] LLM service integration\n\n### Edge Cases\n- [x] Unicode and encoding issues\n- [x] Malformed input handling\n\n## 🧩 Fixtures\n\n### Common Fixtures (in `conftest.py`)\n\n- `mock_llm` - Mock language model for testing\n- `standard_config` - Standard test configuration\n- `minimal_config` - Minimal configuration for edge testing\n- `sample_conversation` - Sample conversation data\n- `large_sql_output` - Large SQL output for trimming tests\n- `error_output` - Sample error output for preservation tests\n\n## 🔧 Extending Tests\n\n### Adding New Test Cases\n\n1. Choose the appropriate test file based on the functionality\n2. Use existing fixtures from `conftest.py`\n3. Follow the naming convention: `test_feature_description()`\n4. Add appropriate markers for categorization\n\n### Adding New Fixtures\n\nAdd shared fixtures to `conftest.py` if they'll be used across multiple test files.\n\n## 🐛 Debugging Tests\n\n### Common Issues\n\n1. **Mock LLM failures**: Ensure proper mocking of LLM responses\n2. **Configuration conflicts**: Use isolated config instances\n3. 
**Memory leaks in large tests**: Force garbage collection with `gc.collect()`\n\n### Debugging Tools\n\n```bash\n# Run with debugging output\npytest tests/context_management/ -v -s\n\n# Run single test with full traceback\npytest tests/context_management/test_name.py::test_function -v --tb=long\n\n# Profile test performance\npytest tests/context_management/ --profile\n```\n\n## 📋 Test Results\n\nExpected test results:\n- **Total tests**: ~100+ test cases across 6 test files\n- **Coverage target**: >95% for context management modules\n- **State operations tests**: All message processing should work correctly\n- **Edge cases**: All should handle gracefully without crashes"
  },
  {
    "path": "tests/context_management/__init__.py",
    "content": "\"\"\"Context management test package.\"\"\"\n\n# Test package initialization\n"
  },
  {
    "path": "tests/context_management/conftest.py",
    "content": "\"\"\"Pytest configuration and fixtures for context management tests.\"\"\"\n\nfrom unittest.mock import Mock\n\nimport pytest\nfrom langchain_core.messages import AIMessage\n\nfrom openchatbi.context_config import ContextConfig\n\n\n@pytest.fixture\ndef mock_llm():\n    \"\"\"Mock LLM for testing across all test modules.\"\"\"\n    llm = Mock()\n    llm.bind_tools = Mock(return_value=llm)\n    return llm\n\n\n@pytest.fixture\ndef mock_llm_with_summary_response():\n    \"\"\"Mock LLM that returns a summary response.\"\"\"\n    llm = Mock()\n    llm.bind_tools = Mock(return_value=llm)\n    return llm\n\n\n@pytest.fixture\ndef standard_config():\n    \"\"\"Standard test configuration.\"\"\"\n    return ContextConfig(\n        enabled=True,\n        summary_trigger_tokens=12000,\n        keep_recent_messages=10,\n        max_tool_output_length=2000,\n        max_sql_result_rows=20,\n        max_code_output_lines=50,\n        enable_summarization=True,\n        enable_conversation_summary=True,\n        preserve_tool_errors=True,\n    )\n\n\n@pytest.fixture\ndef minimal_config():\n    \"\"\"Minimal test configuration.\"\"\"\n    return ContextConfig(\n        enabled=True,\n        summary_trigger_tokens=800,\n        keep_recent_messages=3,\n        max_tool_output_length=200,\n        max_sql_result_rows=5,\n        max_code_output_lines=10,\n    )\n\n\n@pytest.fixture\ndef disabled_config():\n    \"\"\"Configuration with context management disabled.\"\"\"\n    return ContextConfig(\n        enabled=False, summary_trigger_tokens=12000, keep_recent_messages=10, max_tool_output_length=2000\n    )\n\n\n@pytest.fixture\ndef sample_conversation():\n    \"\"\"Sample conversation for testing.\"\"\"\n    from langchain_core.messages import HumanMessage, ToolMessage\n\n    return [\n        HumanMessage(content=\"Can you analyze our sales data?\"),\n        AIMessage(content=\"I'll help you analyze the sales data. 
Let me query the database.\"),\n        ToolMessage(content=\"Query executed successfully. Found 1000 records.\", tool_call_id=\"query_1\"),\n        HumanMessage(content=\"What are the top trends?\"),\n        AIMessage(content=\"Based on the data, I can see several key trends...\"),\n        HumanMessage(content=\"Can you create a visualization?\"),\n        AIMessage(\n            content=\"I'll create a chart for you.\",\n            tool_calls=[{\"name\": \"create_chart\", \"args\": {\"type\": \"bar\"}, \"id\": \"chart_1\"}],\n        ),\n        ToolMessage(content=\"Chart created successfully.\", tool_call_id=\"chart_1\"),\n    ]\n\n\n@pytest.fixture\ndef large_sql_output():\n    \"\"\"Large SQL output for testing trimming.\"\"\"\n    csv_data = \"id,name,value,date\\n\"\n    csv_data += \"\\n\".join([f\"{i},Customer_{i},{i*100},2023-01-{i%30+1:02d}\" for i in range(100)])\n\n    return f\"\"\"SQL Query:\n```sql\nSELECT id, name, value, date\nFROM customers\nORDER BY value DESC\nLIMIT 100;\n```\n\nQuery Results (CSV format):\n```csv\n{csv_data}\n```\n\nVisualization Created: bar chart has been automatically generated and will be displayed in the UI.\"\"\"\n\n\n@pytest.fixture\ndef large_python_output():\n    \"\"\"Large Python code execution output.\"\"\"\n    output_lines = []\n    output_lines.append(\"Processing data...\")\n    for i in range(100):\n        output_lines.append(f\"Step {i}: Processing record {i} - Status: OK\")\n    output_lines.append(\"Processing complete!\")\n\n    return \"\\n\".join(output_lines)\n\n\n@pytest.fixture\ndef error_output():\n    \"\"\"Sample error output.\"\"\"\n    return \"\"\"Traceback (most recent call last):\n  File \"/app/analysis.py\", line 42, in analyze_sales\n    df = pd.read_csv('nonexistent_file.csv')\n  File \"/usr/local/lib/python3.9/site-packages/pandas/io/parsers/readers.py\", line 912, in read_csv\n    return _read(filepath_or_buffer, kwds)\nFileNotFoundError: [Errno 2] No such file or directory: 
'nonexistent_file.csv'\n\nError: Could not load the sales data file. Please check that the file exists and is accessible.\"\"\"\n\n\n# Pytest configuration\ndef pytest_configure(config):\n    \"\"\"Configure pytest with custom markers.\"\"\"\n    config.addinivalue_line(\"markers\", \"slow: marks tests as slow (deselect with '-m \\\"not slow\\\"')\")\n    config.addinivalue_line(\"markers\", \"integration: marks tests as integration tests\")\n\n\ndef pytest_collection_modifyitems(config, items):\n    \"\"\"Modify test items to add markers.\"\"\"\n    for item in items:\n        # Mark integration tests\n        if \"integration\" in item.nodeid.lower():\n            item.add_marker(pytest.mark.integration)\n\n        # Mark slow tests based on certain patterns\n        if any(pattern in item.nodeid.lower() for pattern in [\"large\", \"stress\", \"concurrent\"]):\n            item.add_marker(pytest.mark.slow)\n"
  },
  {
    "path": "tests/context_management/test_agent_graph_integration.py",
    "content": "\"\"\"Integration tests for agent graph with context management.\"\"\"\n\nfrom unittest.mock import Mock, patch\n\nimport pytest\nfrom langchain_core.messages import AIMessage, HumanMessage, ToolMessage\nfrom langchain_core.tools import StructuredTool\n\nfrom openchatbi.agent_graph import _build_graph_core, agent_llm_call, build_agent_graph_async, build_agent_graph_sync\nfrom openchatbi.context_config import ContextConfig\nfrom openchatbi.context_manager import ContextManager\nfrom openchatbi.graph_state import AgentState\n\n\nclass TestAgentGraphIntegration:\n    \"\"\"Integration tests for agent graph with context management.\"\"\"\n\n    @pytest.fixture\n    def mock_catalog(self):\n        \"\"\"Mock catalog store for testing.\"\"\"\n        catalog = Mock()\n        catalog.get_schema = Mock(return_value={\"tables\": []})\n        return catalog\n\n    @pytest.fixture\n    def mock_llm(self):\n        \"\"\"Mock LLM for testing.\"\"\"\n        llm = Mock()\n        llm.bind_tools = Mock(return_value=llm)\n        return llm\n\n    @pytest.fixture\n    def mock_tools(self):\n        \"\"\"Mock tools for testing.\"\"\"\n\n        def mock_tool_func(query: str) -> str:\n            return \"Mock tool result\"\n\n        tool = StructuredTool.from_function(func=mock_tool_func, name=\"mock_tool\", description=\"Mock tool for testing\")\n        return [tool]\n\n    @pytest.fixture\n    def test_config(self):\n        \"\"\"Test configuration for context management.\"\"\"\n        return ContextConfig(\n            enabled=True,\n            summary_trigger_tokens=800,\n            keep_recent_messages=3,\n            max_tool_output_length=100,\n        )\n\n    def test_agent_llm_node_with_context_manager(self, mock_llm, mock_tools, test_config):\n        \"\"\"Test agent llm_node with context manager integration.\"\"\"\n        context_manager = ContextManager(llm=mock_llm, config=test_config)\n\n        # Mock LLM response\n        mock_response 
= AIMessage(content=\"Test response\", tool_calls=[])\n        with patch(\"openchatbi.agent_graph.call_llm_chat_model_with_retry\", return_value=mock_response):\n            llm_node_func = agent_llm_call(mock_llm, mock_tools, context_manager)\n\n            # Create test state with long messages to trigger context management\n            long_messages = [\n                HumanMessage(content=\"A\" * 500),  # Long message\n                AIMessage(content=\"B\" * 500),  # Long message\n                ToolMessage(content=\"C\" * 200, tool_call_id=\"123\"),  # Long tool output\n                HumanMessage(content=\"Recent question\"),\n            ]\n\n            state = AgentState(messages=long_messages)\n            result = llm_node_func(state)\n\n            # Should have processed the state\n            assert \"messages\" in result\n            assert isinstance(result[\"messages\"][0], AIMessage)\n\n    def test_agent_llm_node_without_context_manager(self, mock_llm, mock_tools):\n        \"\"\"Test agent llm_node without context manager.\"\"\"\n        mock_response = AIMessage(content=\"Test response\", tool_calls=[])\n        with patch(\"openchatbi.agent_graph.call_llm_chat_model_with_retry\", return_value=mock_response):\n            llm_node_func = agent_llm_call(mock_llm, mock_tools, context_manager=None)\n\n            state = AgentState(messages=[HumanMessage(content=\"Test\")])\n            result = llm_node_func(state)\n\n            assert \"messages\" in result\n            assert isinstance(result[\"messages\"][0], AIMessage)\n\n    def test_build_graph_core_with_context_management(self, mock_catalog, mock_llm):\n        \"\"\"Test core graph building with context management enabled.\"\"\"\n\n        def create_mock_tool(name):\n            def mock_func(input_str: str) -> str:\n                return f\"Mock {name} result\"\n\n            return StructuredTool.from_function(func=mock_func, name=name, description=f\"Mock {name} tool\")\n\n   
     # Mock all the tool imports directly\n        with (\n            patch(\"openchatbi.agent_graph.search_knowledge\", create_mock_tool(\"search_knowledge\")),\n            patch(\"openchatbi.agent_graph.show_schema\", create_mock_tool(\"show_schema\")),\n            patch(\"openchatbi.agent_graph.run_python_code\", create_mock_tool(\"run_python_code\")),\n            patch(\"openchatbi.agent_graph.save_report\", create_mock_tool(\"save_report\")),\n            patch(\"openchatbi.agent_graph.timeseries_forecast\", create_mock_tool(\"timeseries_forecast\")),\n            patch(\"openchatbi.agent_graph.get_sql_tools\") as mock_get_sql_tools,\n            patch(\"openchatbi.agent_graph.build_sql_graph\") as mock_sql_graph,\n            patch(\"openchatbi.agent_graph.get_memory_tools\") as mock_memory_tools,\n            patch(\"openchatbi.agent_graph.create_mcp_tools_sync\") as mock_mcp_tools,\n            patch(\"openchatbi.agent_graph.get_llm\", return_value=mock_llm),\n        ):\n\n            # Setup function-based mocks\n            mock_get_sql_tools.return_value = create_mock_tool(\"call_sql_graph_tool\")\n            mock_sql_graph.return_value = Mock()\n            mock_memory_tools.return_value = (\n                create_mock_tool(\"manage_memory_tool\"),\n                create_mock_tool(\"search_memory_tool\"),\n            )\n            mock_mcp_tools.return_value = []\n\n            graph = _build_graph_core(\n                catalog=mock_catalog,\n                sync_mode=True,\n                checkpointer=None,\n                memory_store=None,\n                memory_tools=None,\n                mcp_tools=[],\n                enable_context_management=True,\n            )\n\n            # Should create a compiled graph\n            assert graph is not None\n            # Verify that SQL graph was initialized\n            mock_sql_graph.assert_called_once()\n\n    def test_build_graph_core_without_context_management(self, mock_catalog, 
mock_llm):\n        \"\"\"Test core graph building with context management disabled.\"\"\"\n\n        def create_mock_tool(name):\n            def mock_func(input_str: str) -> str:\n                return f\"Mock {name} result\"\n\n            return StructuredTool.from_function(func=mock_func, name=name, description=f\"Mock {name} tool\")\n\n        # Mock all the tool imports directly - same pattern as with context management\n        with (\n            patch(\"openchatbi.agent_graph.search_knowledge\", create_mock_tool(\"search_knowledge\")),\n            patch(\"openchatbi.agent_graph.show_schema\", create_mock_tool(\"show_schema\")),\n            patch(\"openchatbi.agent_graph.run_python_code\", create_mock_tool(\"run_python_code\")),\n            patch(\"openchatbi.agent_graph.save_report\", create_mock_tool(\"save_report\")),\n            patch(\"openchatbi.agent_graph.timeseries_forecast\", create_mock_tool(\"timeseries_forecast\")),\n            patch(\"openchatbi.agent_graph.get_sql_tools\") as mock_get_sql_tools,\n            patch(\"openchatbi.agent_graph.build_sql_graph\") as mock_sql_graph,\n            patch(\"openchatbi.agent_graph.get_memory_tools\") as mock_memory_tools,\n            patch(\"openchatbi.agent_graph.create_mcp_tools_sync\") as mock_mcp_tools,\n            patch(\"openchatbi.agent_graph.get_llm\", return_value=mock_llm),\n        ):\n\n            # Setup function-based mocks\n            mock_get_sql_tools.return_value = create_mock_tool(\"call_sql_graph_tool\")\n            mock_sql_graph.return_value = Mock()\n            mock_memory_tools.return_value = (\n                create_mock_tool(\"manage_memory_tool\"),\n                create_mock_tool(\"search_memory_tool\"),\n            )\n            mock_mcp_tools.return_value = []\n\n            graph = _build_graph_core(\n                catalog=mock_catalog,\n                sync_mode=True,\n                checkpointer=None,\n                memory_store=None,\n               
 memory_tools=None,\n                mcp_tools=[],\n                enable_context_management=False,\n            )\n\n            # Should still create a compiled graph\n            assert graph is not None\n\n    def test_build_agent_graph_sync_with_context_management(self, mock_catalog):\n        \"\"\"Test sync graph building with context management.\"\"\"\n        with (\n            patch(\"openchatbi.agent_graph.create_mcp_tools_sync\") as mock_mcp_tools,\n            patch(\"openchatbi.agent_graph._build_graph_core\") as mock_build_core,\n        ):\n\n            mock_build_core.return_value = Mock()\n            mock_mcp_tools.return_value = []\n\n            graph = build_agent_graph_sync(catalog=mock_catalog, enable_context_management=True)\n\n            # Verify _build_graph_core was called with correct parameters\n            mock_build_core.assert_called_once()\n            call_args = mock_build_core.call_args\n            assert call_args[1][\"enable_context_management\"] is True\n\n            # Should return the graph\n            assert graph is not None\n\n    @pytest.mark.asyncio\n    async def test_build_agent_graph_async_with_context_management(self, mock_catalog):\n        \"\"\"Test async graph building with context management.\"\"\"\n        with (\n            patch(\"openchatbi.agent_graph.get_mcp_tools_async\") as mock_mcp_tools,\n            patch(\"openchatbi.agent_graph._build_graph_core\") as mock_build_core,\n        ):\n\n            mock_build_core.return_value = Mock()\n            # Mock async function\n            mock_mcp_tools.return_value = []\n\n            graph = await build_agent_graph_async(catalog=mock_catalog, enable_context_management=True)\n\n            # Verify _build_graph_core was called with correct parameters\n            mock_build_core.assert_called_once()\n            call_args = mock_build_core.call_args\n            assert call_args[1][\"enable_context_management\"] is True\n\n            # Should 
return the graph\n            assert graph is not None\n\n    @patch(\"openchatbi.agent_graph.call_llm_chat_model_with_retry\")\n    def test_full_context_management_flow(self, mock_llm_call, mock_catalog):\n        \"\"\"Test full context management flow in agent graph.\"\"\"\n        # Mock LLM responses\n        mock_llm_call.side_effect = [\n            AIMessage(content=\"Response 1\"),\n            AIMessage(content=\"Summary of conversation\"),  # For summarization\n            AIMessage(content=\"Final response\"),\n        ]\n\n        context_manager = ContextManager(\n            llm=Mock(),\n            config=ContextConfig(\n                enabled=True,\n                summary_trigger_tokens=80,\n                keep_recent_messages=2,\n            ),\n        )\n\n        # Create many messages to trigger context management\n        messages = []\n        for i in range(10):\n            messages.extend(\n                [\n                    HumanMessage(content=f\"Question {i}\" * 10),  # Make messages longer\n                    AIMessage(content=f\"Response {i}\" * 10),\n                    ToolMessage(content=f\"Tool result {i}\" * 20, tool_call_id=f\"tool_{i}\"),\n                ]\n            )\n\n        # Test context management\n        original_count = len(messages)\n        context_manager.manage_context_messages(messages)\n        managed_messages = messages\n\n        # Should have fewer messages than input\n        assert len(managed_messages) < original_count\n\n        # Should preserve recent messages\n        assert any(\"Question 9\" in str(msg.content) for msg in managed_messages if hasattr(msg, \"content\"))\n\n\nclass TestContextManagementEdgeCases:\n    \"\"\"Test edge cases for context management in agent graph.\"\"\"\n\n    def test_empty_message_handling(self):\n        \"\"\"Test handling of empty messages.\"\"\"\n        config = ContextConfig(enabled=True)\n        context_manager = ContextManager(llm=Mock(), 
config=config)\n\n        messages = []\n        context_manager.manage_context_messages(messages)\n        result = messages\n        assert result == []\n\n    def test_state_message_type_validation(self):\n        \"\"\"Test that only valid state message types are maintained during context management.\"\"\"\n        config = ContextConfig(enabled=True)\n        context_manager = ContextManager(llm=Mock(), config=config)\n\n        # State should only contain valid message types (no SystemMessage)\n        messages = [\n            HumanMessage(content=\"A\" * 100),  # Long message\n            AIMessage(content=\"B\" * 100),  # Long message\n            HumanMessage(content=\"Recent question\"),\n        ]\n\n        with patch(\n            \"openchatbi.context_manager.call_llm_chat_model_with_retry\", return_value=AIMessage(content=\"Summary\")\n        ):\n            context_manager.manage_context_messages(messages)\n            result = messages\n\n        # Should only contain valid state message types\n        valid_types = {HumanMessage, AIMessage, ToolMessage}\n        assert all(type(msg) in valid_types for msg in result), \"Should only contain valid state message types\"\n\n    def test_context_management_with_tool_calls(self):\n        \"\"\"Test context management when AI messages have tool calls.\"\"\"\n        config = ContextConfig(enabled=True)\n        context_manager = ContextManager(llm=Mock(), config=config)\n\n        ai_message_with_tools = AIMessage(\n            content=\"I'll help you with that.\",\n            tool_calls=[{\"name\": \"search_tool\", \"args\": {\"query\": \"test\"}, \"id\": \"call_123\"}],\n        )\n\n        messages = [ai_message_with_tools, HumanMessage(content=\"Follow up\")]\n\n        context_manager.manage_context_messages(messages)\n        result = messages\n\n        # AI message with tool calls should be preserved\n        ai_msgs = [msg for msg in result if isinstance(msg, AIMessage)]\n        assert 
len(ai_msgs) > 0\n        assert any(hasattr(msg, \"tool_calls\") and msg.tool_calls for msg in ai_msgs)\n\n    @patch(\"openchatbi.context_manager.call_llm_chat_model_with_retry\")\n    def test_summarization_failure_fallback(self, mock_llm_call):\n        \"\"\"Test fallback behavior when summarization fails.\"\"\"\n        # Mock LLM failure\n        mock_llm_call.side_effect = Exception(\"LLM unavailable\")\n\n        config = ContextConfig(enabled=True)\n        context_manager = ContextManager(llm=Mock(), config=config)\n\n        # Create messages that would trigger summarization (no SystemMessage in state)\n        messages = [\n            HumanMessage(content=\"A\" * 100),  # Long messages to trigger\n            AIMessage(content=\"B\" * 100),\n            HumanMessage(content=\"C\" * 100),\n            AIMessage(content=\"D\" * 100),\n            HumanMessage(content=\"Recent\"),\n        ]\n\n        context_manager.manage_context_messages(messages)\n        result = messages\n\n        # Should fallback to sliding window\n        assert len(result) <= len(messages)\n        # Should preserve recent messages and only contain valid state message types\n        assert any(\"Recent\" in str(msg.content) for msg in result if hasattr(msg, \"content\"))\n        valid_types = {HumanMessage, AIMessage, ToolMessage}\n        assert all(type(msg) in valid_types for msg in result), \"Should only contain valid state message types\"\n"
  },
  {
    "path": "tests/context_management/test_context_config.py",
    "content": "\"\"\"Unit tests for context configuration.\"\"\"\n\nfrom openchatbi.context_config import ContextConfig, get_context_config, update_context_config\n\n\nclass TestContextConfig:\n    \"\"\"Test cases for ContextConfig class.\"\"\"\n\n    def test_default_config_values(self):\n        \"\"\"Test that default configuration has expected values.\"\"\"\n        config = ContextConfig()\n\n        # Test default values\n        assert config.enabled is True\n        assert config.summary_trigger_tokens == 12000\n        assert config.keep_recent_messages == 20\n        assert config.max_tool_output_length == 2000\n        assert config.max_sql_result_rows == 50\n        assert config.max_code_output_lines == 50\n\n        # Test boolean flags\n        assert config.enable_summarization is True\n        assert config.enable_conversation_summary is True\n        assert config.preserve_tool_errors is True\n        assert config.preserve_recent_sql is True\n\n    def test_custom_config_values(self):\n        \"\"\"Test creating config with custom values.\"\"\"\n        config = ContextConfig(\n            enabled=False,\n            summary_trigger_tokens=8000,\n            keep_recent_messages=5,\n            max_tool_output_length=1000,\n            enable_summarization=False,\n        )\n\n        assert config.enabled is False\n        assert config.summary_trigger_tokens == 8000\n        assert config.keep_recent_messages == 5\n        assert config.max_tool_output_length == 1000\n        assert config.enable_summarization is False\n\n        # Other values should use defaults\n        assert config.max_sql_result_rows == 50\n        assert config.preserve_tool_errors is True\n\n    def test_config_validation_logic(self):\n        \"\"\"Test logical relationships in configuration.\"\"\"\n        config = ContextConfig()\n\n        # Keep recent messages should be reasonable\n        assert config.keep_recent_messages > 0\n        assert 
config.keep_recent_messages < 100  # Sanity check\n\n        # Output limits should be positive\n        assert config.max_tool_output_length > 0\n        assert config.max_sql_result_rows > 0\n        assert config.max_code_output_lines > 0\n\n        # Token limits should be reasonable\n        assert config.summary_trigger_tokens > 0\n\n    def test_get_context_config(self):\n        \"\"\"Test getting context configuration.\"\"\"\n        config = get_context_config()\n        assert isinstance(config, ContextConfig)\n\n    def test_update_context_config_single_value(self):\n        \"\"\"Test updating a single configuration value.\"\"\"\n        original_trigger_tokens = get_context_config().summary_trigger_tokens\n\n        updated_config = update_context_config(summary_trigger_tokens=15000)\n\n        assert updated_config.summary_trigger_tokens == 15000\n        # Other values should remain unchanged\n        assert updated_config.keep_recent_messages == get_context_config().keep_recent_messages\n\n    def test_update_context_config_multiple_values(self):\n        \"\"\"Test updating multiple configuration values.\"\"\"\n        updated_config = update_context_config(\n            summary_trigger_tokens=20000,\n            keep_recent_messages=15,\n            enable_summarization=False,\n            max_tool_output_length=3000,\n        )\n\n        assert updated_config.summary_trigger_tokens == 20000\n        assert updated_config.keep_recent_messages == 15\n        assert updated_config.enable_summarization is False\n        assert updated_config.max_tool_output_length == 3000\n\n    def test_update_context_config_invalid_attribute(self):\n        \"\"\"Test updating config with invalid attribute name.\"\"\"\n        # Should not raise error, just ignore invalid attributes\n        config = update_context_config(invalid_attribute=123)\n        assert not hasattr(config, \"invalid_attribute\")\n\n    def test_update_context_config_returns_copy(self):\n   
     \"\"\"Test that update_context_config returns a modified copy.\"\"\"\n        original_config = get_context_config()\n        updated_config = update_context_config(summary_trigger_tokens=30000)\n\n        # Original should be unchanged (if it's designed that way)\n        # Updated should have new values\n        assert updated_config.summary_trigger_tokens == 30000\n\n\nclass TestContextConfigPresets:\n    \"\"\"Test different configuration presets for common scenarios.\"\"\"\n\n    def test_minimal_context_config(self):\n        \"\"\"Test configuration for minimal context management.\"\"\"\n        config = ContextConfig(\n            enabled=True,\n            enable_summarization=False,\n            enable_conversation_summary=False,\n            max_tool_output_length=500,\n        )\n\n        assert config.enabled is True\n        assert config.enable_summarization is False\n        assert config.enable_conversation_summary is False\n\n    def test_aggressive_compression_config(self):\n        \"\"\"Test configuration for aggressive context compression.\"\"\"\n        config = ContextConfig(\n            enabled=True,\n            summary_trigger_tokens=6000,\n            keep_recent_messages=5,\n            max_tool_output_length=1000,\n            max_sql_result_rows=10,\n            max_code_output_lines=20,\n            enable_summarization=True,\n        )\n\n        assert config.summary_trigger_tokens == 6000\n        assert config.keep_recent_messages == 5\n        assert config.max_tool_output_length == 1000\n        assert config.max_sql_result_rows == 10\n        assert config.max_code_output_lines == 20\n\n    def test_development_debug_config(self):\n        \"\"\"Test configuration suitable for development/debugging.\"\"\"\n        config = ContextConfig(\n            enabled=True,\n            summary_trigger_tokens=40000,\n            keep_recent_messages=20,\n            max_tool_output_length=10000,  # Don't trim much\n            
preserve_tool_errors=True,  # Always preserve errors\n        )\n\n        assert config.preserve_tool_errors is True\n\n    def test_production_optimized_config(self):\n        \"\"\"Test configuration optimized for production use.\"\"\"\n        config = ContextConfig(\n            enabled=True,\n            summary_trigger_tokens=10000,\n            keep_recent_messages=8,\n            max_tool_output_length=1500,\n            max_sql_result_rows=15,\n            enable_summarization=True,\n            preserve_tool_errors=True,\n        )\n\n        assert config.summary_trigger_tokens == 10000\n        assert config.enable_summarization is True\n        assert config.preserve_tool_errors is True\n\n\nclass TestContextConfigEdgeCases:\n    \"\"\"Test edge cases and boundary conditions for context configuration.\"\"\"\n\n    def test_zero_values(self):\n        \"\"\"Test configuration with zero values.\"\"\"\n        config = ContextConfig(\n            summary_trigger_tokens=0,\n            keep_recent_messages=0,\n            max_tool_output_length=0,\n        )\n\n        # Should accept zero values (might cause issues in practice)\n        assert config.keep_recent_messages == 0\n        assert config.summary_trigger_tokens == 0\n        assert config.max_tool_output_length == 0\n\n    def test_very_large_values(self):\n        \"\"\"Test configuration with very large values.\"\"\"\n        config = ContextConfig(\n            summary_trigger_tokens=900000,\n            keep_recent_messages=1000,\n            max_tool_output_length=100000,\n        )\n\n        assert config.keep_recent_messages == 1000\n\n    def test_inconsistent_token_limits(self):\n        \"\"\"Test configuration where summary trigger > max tokens.\"\"\"\n\n        # Should accept but might cause logical issues\n\n    def test_all_features_disabled(self):\n        \"\"\"Test configuration with all features disabled.\"\"\"\n        config = ContextConfig(\n            enabled=False,\n   
         enable_summarization=False,\n            enable_conversation_summary=False,\n        )\n\n        assert config.enabled is False\n        assert config.enable_summarization is False\n        assert config.enable_conversation_summary is False\n\n    def test_config_serialization(self):\n        \"\"\"Test that config can be converted to/from dict (if needed).\"\"\"\n        config = ContextConfig(summary_trigger_tokens=15000, enable_summarization=True)\n\n        # Test converting to dict-like representation\n        config_dict = {\n            \"summary_trigger_tokens\": config.summary_trigger_tokens,\n            \"enable_summarization\": config.enable_summarization,\n            \"enabled\": config.enabled,\n        }\n\n        assert config_dict[\"summary_trigger_tokens\"] == 15000\n        assert config_dict[\"enable_summarization\"] is True\n        assert config_dict[\"enabled\"] is True\n\n    def test_config_immutability_simulation(self):\n        \"\"\"Test that config behaves consistently across operations.\"\"\"\n        config1 = ContextConfig(summary_trigger_tokens=10000)\n        config2 = ContextConfig(summary_trigger_tokens=10000)\n\n        # Same values should be equal\n        assert config1.summary_trigger_tokens == config2.summary_trigger_tokens\n        assert config1.enabled == config2.enabled\n\n    def test_realistic_configuration_scenarios(self):\n        \"\"\"Test realistic configuration scenarios.\"\"\"\n\n        # Small dataset scenario\n        small_config = ContextConfig()\n\n        # Large dataset scenario\n        large_config = ContextConfig()\n\n        # Interactive analysis scenario\n        interactive_config = ContextConfig(\n            keep_recent_messages=50,  # Keep more context for back-and-forth\n            preserve_tool_errors=True,\n            max_tool_output_length=5000,  # Don't trim too aggressively\n        )\n\n        assert interactive_config.keep_recent_messages > 
small_config.keep_recent_messages\n        assert interactive_config.preserve_tool_errors is True\n"
  },
  {
    "path": "tests/context_management/test_context_manager.py",
    "content": "\"\"\"Unit tests for ContextManager class.\"\"\"\n\nfrom unittest.mock import Mock, patch\n\nimport pytest\nfrom langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage\n\nfrom openchatbi.context_config import ContextConfig\nfrom openchatbi.context_manager import ContextManager\n\n\nclass TestContextManager:\n    \"\"\"Test cases for ContextManager class.\"\"\"\n\n    @pytest.fixture\n    def mock_llm(self):\n        \"\"\"Mock LLM for testing.\"\"\"\n        llm = Mock()\n        # Mock response for summarization\n        llm_response = AIMessage(content=\"This is a test summary of the conversation.\")\n        with patch(\"openchatbi.context_manager.call_llm_chat_model_with_retry\", return_value=llm_response):\n            yield llm\n\n    @pytest.fixture\n    def default_config(self):\n        \"\"\"Default context configuration for testing.\"\"\"\n        return ContextConfig(\n            enabled=True,\n            summary_trigger_tokens=900,\n            keep_recent_messages=3,\n            max_tool_output_length=200,\n            max_sql_result_rows=5,\n            max_code_output_lines=10,\n            enable_conversation_summary=True,\n            enable_summarization=True,\n        )\n\n    @pytest.fixture\n    def context_manager(self, mock_llm, default_config):\n        \"\"\"Context manager instance for testing.\"\"\"\n        return ContextManager(llm=mock_llm, config=default_config)\n\n    def test_token_estimation(self, context_manager):\n        \"\"\"Test token estimation functionality.\"\"\"\n        # Test basic token estimation\n        short_text = \"Hello world\"\n        assert context_manager.estimate_tokens(short_text) == len(short_text) // 4\n\n        # Test longer text\n        long_text = \"This is a longer text that should have more tokens estimated.\"\n        assert context_manager.estimate_tokens(long_text) > context_manager.estimate_tokens(short_text)\n\n    def 
test_message_token_estimation(self, context_manager):\n        \"\"\"Test token estimation for messages.\"\"\"\n        messages = [\n            HumanMessage(content=\"Hello\"),\n            AIMessage(content=\"Hi there!\"),\n            ToolMessage(content=\"Tool result\", tool_call_id=\"123\"),\n        ]\n\n        total_tokens = context_manager.estimate_message_tokens(messages)\n        assert total_tokens > 0\n        # Should include content tokens plus metadata overhead\n        assert total_tokens > sum(len(str(msg.content)) // 4 for msg in messages)\n\n    def test_trim_short_tool_output(self, context_manager):\n        \"\"\"Test trimming tool output that's already short enough.\"\"\"\n        short_output = \"This is a short output.\"\n        result = context_manager.trim_tool_output(short_output)\n        assert result == short_output\n\n    def test_trim_long_generic_output(self, context_manager):\n        \"\"\"Test trimming long generic tool output.\"\"\"\n        long_output = \"A\" * 500  # Much longer than max_tool_output_length (200)\n        result = context_manager.trim_tool_output(long_output)\n\n        assert len(result) < len(long_output)\n        assert \"... 
[Output truncated] ...\" in result\n        assert result.startswith(\"A\")\n        assert result.endswith(\"A\")\n\n    def test_trim_sql_output(self, context_manager):\n        \"\"\"Test trimming SQL output with structured data.\"\"\"\n        sql_output = \"\"\"SQL Query:\n```sql\nSELECT * FROM users WHERE age > 18;\n```\n\nQuery Results (CSV format):\n```csv\nid,name,age,email\n1,John,25,john@example.com\n2,Jane,30,jane@example.com\n3,Bob,22,bob@example.com\n4,Alice,28,alice@example.com\n5,Charlie,35,charlie@example.com\n6,Diana,27,diana@example.com\n7,Eve,31,eve@example.com\n```\n\nVisualization Created: bar chart has been automatically generated.\"\"\"\n\n        result = context_manager.trim_tool_output(sql_output)\n\n        # Should preserve SQL query\n        assert \"SELECT * FROM users WHERE age > 18;\" in result\n        # Should preserve visualization info\n        assert \"Visualization Created:\" in result\n        # Should trim CSV data but keep structure\n        assert \"```csv\" in result\n        assert \"rows omitted\" in result\n\n    def test_trim_code_output(self, context_manager):\n        \"\"\"Test trimming Python code execution output.\"\"\"\n        # Test long output without errors\n        long_code_output = \"\\n\".join([f\"Line {i}: Some output here\" for i in range(50)])\n        result = context_manager.trim_tool_output(long_code_output)\n\n        assert len(result.split(\"\\n\")) < 50\n        assert \"... 
[Output truncated] ...\" in result\n\n    def test_preserve_error_output(self, context_manager):\n        \"\"\"Test that error outputs are preserved when configured.\"\"\"\n        error_output = \"\"\"Traceback (most recent call last):\n  File \"test.py\", line 1, in <module>\n    print(undefined_variable)\nNameError: name 'undefined_variable' is not defined\"\"\"\n\n        # With preserve_tool_errors=True (default in test config)\n        result = context_manager.trim_tool_output(error_output)\n        assert result == error_output  # Should be preserved in full\n\n        # Test with preserve_tool_errors=False\n        context_manager.config.preserve_tool_errors = False\n        result = context_manager.trim_tool_output(error_output)\n        # Should still preserve because it's an error, but could be trimmed based on length\n\n    # Tool output trimming disable test removed - trimming is always enabled now\n\n    def test_conversation_summary_disabled(self, context_manager):\n        \"\"\"Test conversation summary when disabled.\"\"\"\n        context_manager.config.enable_conversation_summary = False\n        messages = [HumanMessage(content=\"Hello\"), AIMessage(content=\"Hi\")]\n\n        summary = context_manager.summarize_conversation(messages)\n        assert summary == \"\"\n\n    @patch(\"openchatbi.context_manager.call_llm_chat_model_with_retry\")\n    def test_conversation_summary_success(self, mock_llm_call, context_manager):\n        \"\"\"Test successful conversation summarization.\"\"\"\n        # Mock successful LLM response\n        mock_response = AIMessage(content=\"Summary: User asked about data analysis.\")\n        mock_llm_call.return_value = mock_response\n\n        messages = [\n            HumanMessage(content=\"Can you analyze our sales data?\"),\n            AIMessage(content=\"I'll help you analyze the sales data.\"),\n            ToolMessage(content=\"Query results: 100 records\", tool_call_id=\"123\"),\n            
HumanMessage(content=\"What are the trends?\"),\n            AIMessage(content=\"The trends show increasing sales.\"),\n            HumanMessage(content=\"Recent question\"),  # This should be excluded from summary\n        ]\n\n        summary = context_manager.summarize_conversation(messages)\n        assert summary.startswith(\"[Conversation Summary]:\")\n        assert \"Summary: User asked about data analysis.\" in summary\n        mock_llm_call.assert_called_once()\n\n    @patch(\"openchatbi.context_manager.call_llm_chat_model_with_retry\")\n    def test_conversation_summary_failure(self, mock_llm_call, context_manager):\n        \"\"\"Test conversation summary when LLM call fails.\"\"\"\n        # Mock LLM failure\n        mock_llm_call.side_effect = Exception(\"LLM service unavailable\")\n\n        # Need more messages than keep_recent_messages (3) to trigger summarization\n        messages = [\n            HumanMessage(content=\"First message\"),\n            AIMessage(content=\"First response\"),\n            HumanMessage(content=\"Second message\"),\n            AIMessage(content=\"Second response\"),\n            HumanMessage(content=\"Third message\"),\n            AIMessage(content=\"Third response\"),\n        ]\n        summary = context_manager.summarize_conversation(messages)\n        assert summary == \"[Summary generation failed]\"\n\n    def test_manage_context_disabled(self, context_manager):\n        \"\"\"Test context management when disabled.\"\"\"\n        context_manager.config.enabled = False\n        messages = [HumanMessage(content=\"Test\")]\n\n        context_manager.manage_context_messages(messages)\n        result = messages\n        assert result == messages  # Should return unchanged\n\n    def test_manage_context_empty_messages(self, context_manager):\n        \"\"\"Test context management with empty message list.\"\"\"\n        messages = []\n        context_manager.manage_context_messages(messages)\n        result = messages\n 
       assert result == []\n\n    def test_manage_context_tool_message_trimming(self, context_manager):\n        \"\"\"Test that tool messages are trimmed during context management.\"\"\"\n        long_content = \"A\" * 500\n        # Add enough messages to trigger context management, with ToolMessage in historical part\n        # keep_recent_messages=3, so we need more than 3 messages after the ToolMessage\n        messages = [\n            HumanMessage(content=\"This is a long question that helps reach the token threshold \" * 10),\n            ToolMessage(content=long_content, tool_call_id=\"123\"),  # This should be in historical part\n            AIMessage(content=\"This is a long response that helps reach the token threshold \" * 10),\n            HumanMessage(content=\"Another long question to increase token count \" * 10),\n            AIMessage(content=\"Response \" * 20),\n            HumanMessage(content=\"Final question\"),\n            AIMessage(content=\"Final response\"),\n        ]\n\n        context_manager.manage_context_messages(messages)\n        result = messages\n\n        # Find the tool message in results\n        tool_msg = next(msg for msg in result if isinstance(msg, ToolMessage))\n        assert len(str(tool_msg.content)) < len(long_content)\n        assert \"... 
[Output truncated] ...\" in str(tool_msg.content)\n\n    @patch(\"openchatbi.context_manager.call_llm_chat_model_with_retry\")\n    def test_manage_context_with_summarization(self, mock_llm_call, context_manager):\n        \"\"\"Test context management triggering summarization.\"\"\"\n        # Mock successful summarization\n        mock_response = AIMessage(content=\"Conversation summary here.\")\n        mock_llm_call.return_value = mock_response\n\n        # Create many messages to trigger summarization\n        messages = []\n        for i in range(10):\n            messages.extend(\n                [\n                    HumanMessage(content=f\"Question {i}\"),\n                    AIMessage(content=f\"Response {i}\" * 100),  # Long responses to increase token count\n                ]\n            )\n\n        original_length = len(messages)\n        context_manager.manage_context_messages(messages)\n        result = messages\n\n        # Should have fewer messages than input\n        assert len(result) < original_length\n        # Should contain summary message\n        assert any(\"[Conversation Summary]:\" in str(msg.content) for msg in result if hasattr(msg, \"content\"))\n        # Verify LLM was called for summarization\n        mock_llm_call.assert_called_once()\n\n    # Tool wrapper tests removed - we now handle context at state level\n\n    def test_format_messages_for_summary(self, context_manager):\n        \"\"\"Test message formatting for summary generation.\"\"\"\n        messages = [\n            HumanMessage(content=\"User question\"),\n            AIMessage(content=\"AI response\"),\n            ToolMessage(content=\"Tool result with some data\", tool_call_id=\"123\"),\n            SystemMessage(content=\"System message\"),  # Should be excluded\n        ]\n\n        formatted = context_manager._format_messages_for_summary(messages)\n\n        assert \"<user> User question </user>\" in formatted\n        assert \"<assistant>\" in formatted and 
\"AI response\" in formatted\n        assert \"tool_result\" in formatted\n        assert \"System message\" not in formatted  # System messages excluded\n\n    def test_format_long_ai_message_for_summary(self, context_manager):\n        \"\"\"Test that long AI messages are truncated in summary formatting.\"\"\"\n        long_content = \"A\" * 1000\n        messages = [AIMessage(content=long_content)]\n\n        formatted = context_manager._format_messages_for_summary(messages)\n\n        assert len(formatted) < len(f\"Assistant: {long_content}\")\n        assert \"... [truncated]\" in formatted\n\n\n# Tool wrapping tests removed - we now handle context at state level instead of wrapping tools\n\n\n# Pytest fixtures and test data\n@pytest.fixture\ndef sample_sql_output():\n    \"\"\"Sample SQL output for testing.\"\"\"\n    return \"\"\"SQL Query:\n```sql\nSELECT customer_id, SUM(amount) as total\nFROM orders\nWHERE order_date >= '2023-01-01'\nGROUP BY customer_id\nORDER BY total DESC;\n```\n\nQuery Results (CSV format):\n```csv\ncustomer_id,total\n1001,15420.50\n1002,12350.75\n1003,11200.00\n1004,9875.25\n1005,8650.00\n1006,7500.50\n1007,6200.75\n1008,5800.00\n1009,4950.25\n1010,4200.00\n```\n\nVisualization Created: bar chart has been automatically generated and will be displayed in the UI.\"\"\"\n\n\n@pytest.fixture\ndef sample_error_output():\n    \"\"\"Sample error output for testing.\"\"\"\n    return \"\"\"Traceback (most recent call last):\n  File \"/app/code.py\", line 15, in analyze_data\n    result = df.groupby('nonexistent_column').sum()\n  File \"/usr/local/lib/python3.9/site-packages/pandas/core/groupby/groupby.py\", line 1647, in sum\n    return self._cython_transform(\"sum\", numeric_only=numeric_only, **kwargs)\nKeyError: 'nonexistent_column'\n\nError: Column 'nonexistent_column' not found in DataFrame. Available columns: ['customer_id', 'order_date', 'amount', 'product_id']\"\"\"\n"
  },
  {
    "path": "tests/context_management/test_edge_cases.py",
    "content": "\"\"\"Edge cases for context management.\"\"\"\n\nfrom unittest.mock import Mock, patch\n\nimport pytest\nfrom langchain_core.messages import AIMessage, HumanMessage, ToolMessage\n\nfrom openchatbi.context_config import ContextConfig\nfrom openchatbi.context_manager import ContextManager\n\n\nclass TestContextManagementEdgeCases:\n    \"\"\"Edge cases and boundary conditions for context management.\"\"\"\n\n    @pytest.fixture\n    def edge_case_config(self):\n        \"\"\"Configuration for edge case testing.\"\"\"\n        return ContextConfig(\n            enabled=True,\n            summary_trigger_tokens=800,\n            keep_recent_messages=2,\n            max_tool_output_length=50,\n        )\n\n    @pytest.fixture\n    def context_manager(self, edge_case_config):\n        \"\"\"Context manager for edge case testing.\"\"\"\n        return ContextManager(llm=Mock(), config=edge_case_config)\n\n    def test_empty_and_none_inputs(self, context_manager):\n        \"\"\"Test handling of empty and None inputs.\"\"\"\n        # Empty list\n        messages = []\n        context_manager.manage_context_messages(messages)\n        assert messages == []\n\n        # List with None elements (should be filtered out gracefully)\n        messages = [HumanMessage(content=\"Test\"), None, AIMessage(content=\"Response\")]\n        # Filter out None values before passing to context manager\n        filtered_messages = [msg for msg in messages if msg is not None]\n        context_manager.manage_context_messages(filtered_messages)\n        result = filtered_messages\n        assert len(result) == 2\n\n    def test_malformed_messages(self, context_manager):\n        \"\"\"Test handling of malformed messages.\"\"\"\n        # Message with None content\n        try:\n            malformed_msg = HumanMessage(content=None)\n            messages = [malformed_msg]\n            context_manager.manage_context_messages(messages)\n            result = messages\n            
# Should handle gracefully\n            assert isinstance(result, list)\n        except Exception as e:\n            # If it raises an exception, it should be a reasonable one\n            assert \"content\" in str(e).lower()\n\n    def test_extremely_long_single_message(self, context_manager):\n        \"\"\"Test handling of extremely long single messages.\"\"\"\n        # Create a message longer than the entire context limit\n        very_long_content = \"A\" * 100000  # Much longer than context limit\n        long_message = HumanMessage(content=very_long_content)\n\n        messages = [long_message]\n        context_manager.manage_context_messages(messages)\n        result = messages\n\n        # Should still return the message (context management doesn't trim individual message content)\n        assert len(result) == 1\n        assert isinstance(result[0], HumanMessage)\n\n    def test_tool_message_without_tool_call_id(self, context_manager):\n        \"\"\"Test handling of tool messages without proper tool_call_id.\"\"\"\n        try:\n            # This might raise an error depending on LangChain's validation\n            tool_msg = ToolMessage(content=\"Result\", tool_call_id=\"\")\n            messages = [tool_msg]\n            context_manager.manage_context_messages(messages)\n            result = messages\n            assert isinstance(result, list)\n        except Exception:\n            # If LangChain validates and raises, that's acceptable\n            pass\n\n    def test_circular_references_in_content(self, context_manager):\n        \"\"\"Test handling of complex content that might cause issues.\"\"\"\n        # Content with special characters and formatting\n        special_content = (\n            \"\"\"\n        Content with:\n        - Unicode: 🚀 中文 العربية\n        - Code blocks: ```python\\nprint(\"hello\")\\n```\n        - JSON: {\"key\": \"value\", \"nested\": {\"array\": [1,2,3]}}\n        - HTML: <div class=\"test\">content</div>\n        
- URLs: https://example.com/path?param=value\n        - Very long line: \"\"\"\n            + \"X\" * 1000\n        )\n\n        message = HumanMessage(content=special_content)\n        messages = [message]\n        context_manager.manage_context_messages(messages)\n        result = messages\n\n        assert len(result) == 1\n        assert isinstance(result[0], HumanMessage)\n\n    def test_zero_configuration_values(self):\n        \"\"\"Test behavior with zero configuration values.\"\"\"\n        zero_config = ContextConfig(\n            enabled=True,\n            summary_trigger_tokens=0,\n            keep_recent_messages=0,\n            max_tool_output_length=0,\n        )\n\n        context_manager = ContextManager(llm=Mock(), config=zero_config)\n        messages = [HumanMessage(content=\"Test\")]\n\n        # Should handle zero values gracefully\n        context_manager.manage_context_messages(messages)\n        result = messages\n        assert isinstance(result, list)\n\n    def test_negative_configuration_values(self):\n        \"\"\"Test behavior with negative configuration values.\"\"\"\n        negative_config = ContextConfig(\n            enabled=True,\n            summary_trigger_tokens=-50,\n            keep_recent_messages=-5,\n            max_tool_output_length=-10,\n        )\n\n        context_manager = ContextManager(llm=Mock(), config=negative_config)\n        messages = [HumanMessage(content=\"Test\")]\n\n        # Should handle negative values gracefully (might treat as disabled)\n        context_manager.manage_context_messages(messages)\n        result = messages\n        assert isinstance(result, list)\n\n    def test_unicode_and_encoding_edge_cases(self, context_manager):\n        \"\"\"Test handling of various Unicode and encoding scenarios.\"\"\"\n        unicode_messages = [\n            HumanMessage(content=\"English text\"),\n            HumanMessage(content=\"中文内容测试\"),\n            HumanMessage(content=\"العربية\"),\n            
HumanMessage(content=\"Русский текст\"),\n            HumanMessage(content=\"🚀🎉💡🔥\"),  # Emojis\n            HumanMessage(content=\"Mixed: Hello 世界 🌍\"),\n            ToolMessage(content=\"Unicode tool result: café naïve résumé\", tool_call_id=\"unicode_1\"),\n        ]\n\n        context_manager.manage_context_messages(unicode_messages)\n        result = unicode_messages\n\n        # Should handle all Unicode content\n        assert len(result) > 0\n        assert all(isinstance(msg, (HumanMessage, AIMessage, ToolMessage)) for msg in result)\n\n    def test_extremely_nested_or_complex_structures(self, context_manager):\n        \"\"\"Test handling of complex nested data structures in tool outputs.\"\"\"\n        # Simulate deeply nested JSON output\n        nested_data = {\"level1\": {\"level2\": {\"level3\": {\"data\": [\"item\"] * 1000}}}}\n        complex_output = str(nested_data) * 100  # Make it very large\n\n        # Create messages so the tool message is in historical part (not recent)\n        # keep_recent_messages=2, so add more than 2 messages after the tool message\n        messages = [\n            ToolMessage(content=complex_output, tool_call_id=\"complex_1\"),  # Historical part\n            HumanMessage(content=\"Question 1\"),\n            AIMessage(content=\"Response 1\"),\n            HumanMessage(content=\"Recent question\"),  # Recent part starts here\n        ]\n        context_manager.manage_context_messages(messages)\n        result = messages\n\n        # Should trim the complex output since it's in historical part\n        tool_msg = next(msg for msg in result if isinstance(msg, ToolMessage))\n        assert len(str(tool_msg.content)) < len(complex_output)\n\n    def test_sql_output_edge_cases(self, context_manager):\n        \"\"\"Test SQL output trimming with edge cases.\"\"\"\n        # SQL with no results\n        empty_sql_output = \"\"\"SQL Query:\n```sql\nSELECT * FROM users WHERE id = -1;\n```\n\nQuery Results (CSV 
format):\n```csv\nid,name\n```\"\"\"\n\n        # SQL with single row\n        single_row_sql = \"\"\"SQL Query:\n```sql\nSELECT COUNT(*) as total FROM users;\n```\n\nQuery Results (CSV format):\n```csv\ntotal\n42\n```\"\"\"\n\n        # Malformed SQL output\n        malformed_sql = \"\"\"Something that looks like SQL but isn't:\n```sql\nINVALID QUERY HERE\n```\nRandom text after\"\"\"\n\n        test_cases = [empty_sql_output, single_row_sql, malformed_sql]\n\n        for sql_output in test_cases:\n            tool_msg = ToolMessage(content=sql_output, tool_call_id=\"sql_test\")\n            messages = [tool_msg]\n            context_manager.manage_context_messages(messages)\n            result = messages\n\n            # Should handle all cases gracefully\n            assert len(result) == 1\n            assert isinstance(result[0], ToolMessage)\n\n    def test_conversation_state_consistency(self, context_manager):\n        \"\"\"Test that conversation state remains consistent through management.\"\"\"\n        # Create a conversation with specific patterns (no SystemMessage in state)\n        messages = [\n            HumanMessage(content=\"Question 1\"),\n            AIMessage(content=\"Response 1\"),\n            ToolMessage(content=\"Tool result 1\", tool_call_id=\"tool_1\"),\n            HumanMessage(content=\"Question 2\"),\n            AIMessage(\n                content=\"Response 2 with tool calls\",\n                tool_calls=[{\"name\": \"test_tool\", \"args\": {\"param\": \"value\"}, \"id\": \"call_1\"}],\n            ),\n            ToolMessage(content=\"Tool result 2\", tool_call_id=\"call_1\"),\n            HumanMessage(content=\"Final question\"),\n        ]\n\n        with patch(\n            \"openchatbi.context_manager.call_llm_chat_model_with_retry\", return_value=AIMessage(content=\"Summary\")\n        ):\n            context_manager.manage_context_messages(messages)\n            result = messages\n\n        # Should maintain message type 
consistency (only valid state message types)\n        message_types = [type(msg) for msg in result]\n        valid_types = {HumanMessage, AIMessage, ToolMessage}\n        assert all(\n            msg_type in valid_types for msg_type in message_types\n        ), \"Should only contain valid state message types\"\n\n        # Should not have orphaned tool messages without corresponding AI messages\n        for i, msg in enumerate(result):\n            if isinstance(msg, ToolMessage):\n                # There should be an AI message with tool calls before this\n                previous_ai_msgs = [m for m in result[:i] if isinstance(m, AIMessage)]\n                assert len(previous_ai_msgs) > 0, \"Tool message should have corresponding AI message\"\n"
  },
  {
    "path": "tests/context_management/test_runner.py",
    "content": "\"\"\"Test runner script for context management tests.\"\"\"\n\nimport argparse\nimport subprocess\nimport sys\nfrom pathlib import Path\n\n\ndef run_tests(test_type=\"all\", verbose=False, coverage=False):\n    \"\"\"Run context management tests.\n\n    Args:\n        test_type: Type of tests to run ('all', 'unit', 'integration', 'edge_cases')\n        verbose: Enable verbose output\n        coverage: Enable coverage reporting\n    \"\"\"\n    # Base pytest command\n    cmd = [\"python\", \"-m\", \"pytest\"]\n\n    # Test directory\n    test_dir = Path(__file__).parent\n\n    # Add specific test files based on type\n    if test_type == \"all\":\n        cmd.append(str(test_dir))\n    elif test_type == \"unit\":\n        cmd.extend([str(test_dir / \"test_context_manager.py\"), str(test_dir / \"test_context_config.py\")])\n    elif test_type == \"integration\":\n        cmd.append(str(test_dir / \"test_agent_graph_integration.py\"))\n    elif test_type == \"edge_cases\":\n        cmd.extend([str(test_dir / \"test_edge_cases.py\"), str(test_dir / \"test_state_operations.py\")])\n    else:\n        print(f\"Unknown test type: {test_type}\")\n        return False\n\n    # Add verbose flag\n    if verbose:\n        cmd.append(\"-v\")\n\n    # Add coverage\n    if coverage:\n        cmd.extend(\n            [\n                \"--cov=openchatbi.context_manager\",\n                \"--cov=openchatbi.context_config\",\n                \"--cov-report=html\",\n                \"--cov-report=term-missing\",\n            ]\n        )\n\n    # Add other useful flags\n    cmd.extend(\n        [\n            \"--tb=short\",  # Shorter traceback format\n            \"-x\",  # Stop on first failure\n            \"--strict-markers\",  # Strict marker checking\n        ]\n    )\n\n    print(f\"Running command: {' '.join(cmd)}\")\n    print(\"-\" * 50)\n\n    # Run the tests\n    try:\n        result = subprocess.run(cmd, check=False)\n        return result.returncode 
== 0\n    except KeyboardInterrupt:\n        print(\"\\nTests interrupted by user\")\n        return False\n\n\ndef main():\n    \"\"\"Main function for test runner.\"\"\"\n    parser = argparse.ArgumentParser(description=\"Run context management tests\")\n\n    parser.add_argument(\n        \"--type\",\n        \"-t\",\n        choices=[\"all\", \"unit\", \"integration\", \"edge_cases\"],\n        default=\"all\",\n        help=\"Type of tests to run (default: all)\",\n    )\n\n    parser.add_argument(\"--verbose\", \"-v\", action=\"store_true\", help=\"Enable verbose output\")\n\n    parser.add_argument(\"--coverage\", \"-c\", action=\"store_true\", help=\"Enable coverage reporting\")\n\n    args = parser.parse_args()\n\n    success = run_tests(test_type=args.type, verbose=args.verbose, coverage=args.coverage)\n\n    if success:\n        print(\"\\n✅ All tests passed!\")\n        sys.exit(0)\n    else:\n        print(\"\\n❌ Some tests failed!\")\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/context_management/test_state_operations.py",
    "content": "\"\"\"Tests for message-based context management operations.\"\"\"\n\nfrom unittest.mock import Mock, patch\n\nimport pytest\nfrom langchain_core.messages import AIMessage, HumanMessage, ToolMessage\n\nfrom openchatbi.context_config import ContextConfig\nfrom openchatbi.context_manager import ContextManager\n\n\nclass TestMessageBasedContextManagement:\n    \"\"\"Test message-based context management with direct modification.\"\"\"\n\n    @pytest.fixture\n    def test_config(self):\n        \"\"\"Configuration for testing message operations.\"\"\"\n        return ContextConfig(\n            enabled=True,\n            summary_trigger_tokens=300,  # Lower threshold to trigger management\n            keep_recent_messages=3,\n            max_tool_output_length=200,\n            preserve_tool_errors=True,\n            preserve_recent_sql=True,\n        )\n\n    @pytest.fixture\n    def context_manager(self, test_config):\n        \"\"\"Context manager for testing.\"\"\"\n        mock_llm = Mock()\n        return ContextManager(llm=mock_llm, config=test_config)\n\n    def test_no_operations_when_disabled(self, context_manager):\n        \"\"\"Test that no operations are performed when context management is disabled.\"\"\"\n        context_manager.config.enabled = False\n\n        messages = [HumanMessage(content=\"Test\", id=\"test_1\")]\n        original_messages = messages.copy()\n\n        context_manager.manage_context_messages(messages)\n        assert messages == original_messages  # Should be unchanged\n\n    def test_no_operations_when_under_limit(self, context_manager):\n        \"\"\"Test that no operations are performed when context is under token limit.\"\"\"\n        # Short messages that won't trigger context management\n        messages = [HumanMessage(content=\"Hi\", id=\"human_1\"), AIMessage(content=\"Hello\", id=\"ai_1\")]\n        original_messages = messages.copy()\n\n        context_manager.manage_context_messages(messages)\n        
assert messages == original_messages  # Should be unchanged\n\n    def test_historical_tool_compression(self, context_manager):\n        \"\"\"Test compression of historical tool messages.\"\"\"\n        # Disable conversation summarization to test only tool compression\n        context_manager.config.enable_conversation_summary = False\n        context_manager.config.enable_summarization = False\n\n        # Create messages with large historical tool outputs\n        messages = [\n            HumanMessage(content=\"Query data\", id=\"human_1\"),\n            AIMessage(content=\"Running query\", id=\"ai_1\"),\n            # Large historical tool message (should be compressed)\n            ToolMessage(content=\"A\" * 1000, tool_call_id=\"query_1\", id=\"tool_1_historical\"),  # Large content\n            HumanMessage(content=\"More analysis\", id=\"human_2\"),\n            AIMessage(content=\"Analyzing\", id=\"ai_2\"),\n            # Another large historical tool message\n            ToolMessage(content=\"B\" * 800, tool_call_id=\"query_2\", id=\"tool_2_historical\"),  # Large content\n            # Recent messages (should be preserved)\n            HumanMessage(content=\"Recent question\", id=\"human_recent\"),\n            AIMessage(content=\"Recent response\", id=\"ai_recent\"),\n            ToolMessage(content=\"Recent result\", tool_call_id=\"recent_1\", id=\"tool_recent\"),\n        ]\n\n        original_count = len(messages)\n        context_manager.manage_context_messages(messages)\n\n        # Should have same number of messages but some content should be compressed\n        assert len(messages) == original_count\n\n        # Check that historical tool messages are compressed\n        historical_tool_msgs = [\n            msg\n            for msg in messages\n            if isinstance(msg, ToolMessage) and msg.id in [\"tool_1_historical\", \"tool_2_historical\"]\n        ]\n        for msg in historical_tool_msgs:\n            assert len(str(msg.content)) < 
1000, \"Historical tool messages should be compressed\"\n\n    def test_error_message_preservation(self, context_manager):\n        \"\"\"Test that error messages are preserved even if they're historical.\"\"\"\n        error_content = \"\"\"Traceback (most recent call last):\n  File \"test.py\", line 1, in <module>\n    raise ValueError(\"Test error\")\nValueError: Test error\"\"\"\n\n        messages = [\n            HumanMessage(content=\"Run code\", id=\"human_1\"),\n            AIMessage(content=\"Executing\", id=\"ai_1\"),\n            # Historical error message (should be preserved)\n            ToolMessage(content=error_content, tool_call_id=\"code_1\", id=\"error_tool_historical\"),\n            # Recent messages\n            HumanMessage(content=\"What happened?\", id=\"human_recent\"),\n            AIMessage(content=\"There was an error\", id=\"ai_recent\"),\n        ]\n\n        original_error_content = messages[2].content\n        context_manager.manage_context_messages(messages)\n\n        # Error message should be preserved\n        error_msg = next(msg for msg in messages if msg.id == \"error_tool_historical\")\n        assert error_msg.content == original_error_content, \"Error messages should be preserved\"\n\n    def test_sql_content_preservation(self, context_manager):\n        \"\"\"Test that SQL content is preserved when configured.\"\"\"\n        sql_content = \"\"\"SQL Query:\n```sql\nSELECT * FROM users WHERE active = 1;\n```\n\nQuery Results (CSV format):\n```csv\nid,name,email\n1,John,john@example.com\n2,Jane,jane@example.com\n```\"\"\"\n\n        messages = [\n            HumanMessage(content=\"Get user data\", id=\"human_1\"),\n            AIMessage(content=\"Querying users\", id=\"ai_1\"),\n            # Historical SQL result (should be preserved if preserve_recent_sql=True)\n            ToolMessage(content=sql_content, tool_call_id=\"sql_1\", id=\"sql_tool_historical\"),\n            # Recent messages\n            
HumanMessage(content=\"Analyze results\", id=\"human_recent\"),\n            AIMessage(content=\"Analyzing\", id=\"ai_recent\"),\n        ]\n\n        # Test with SQL preservation enabled\n        context_manager.config.preserve_recent_sql = True\n        original_sql_content = messages[2].content\n        context_manager.manage_context_messages(messages)\n\n        # SQL should be preserved when preserve_recent_sql=True\n        sql_msg = next(msg for msg in messages if msg.id == \"sql_tool_historical\")\n        assert sql_msg.content == original_sql_content, \"SQL content should be preserved when configured\"\n\n    @patch(\"openchatbi.context_manager.call_llm_chat_model_with_retry\")\n    def test_conversation_summarization(self, mock_llm_call, context_manager):\n        \"\"\"Test conversation summarization with message modification.\"\"\"\n        # Mock LLM response for summarization\n        mock_llm_call.return_value = AIMessage(content=\"Summary of the conversation\")\n\n        # Create a long conversation that will trigger summarization\n        messages = []\n\n        # Add many historical messages\n        for i in range(20):\n            messages.extend(\n                [\n                    HumanMessage(content=f\"Question {i}\" * 10, id=f\"human_{i}\"),\n                    AIMessage(content=f\"Response {i}\" * 10, id=f\"ai_{i}\"),\n                ]\n            )\n\n        # Add recent messages\n        messages.extend(\n            [\n                HumanMessage(content=\"Recent question\", id=\"human_recent\"),\n                AIMessage(content=\"Recent response\", id=\"ai_recent\"),\n                ToolMessage(content=\"Recent result\", tool_call_id=\"recent_1\", id=\"tool_recent\"),\n            ]\n        )\n\n        original_count = len(messages)\n        context_manager.manage_context_messages(messages)\n\n        # Should have fewer messages due to summarization\n        assert len(messages) < original_count\n\n        # Should 
have a summary message\n        summary_msgs = [msg for msg in messages if isinstance(msg, AIMessage) and \"Summary\" in str(msg.content)]\n        assert len(summary_msgs) > 0, \"Should create a summary message\"\n\n    def test_content_type_detection(self, context_manager):\n        \"\"\"Test content type detection methods.\"\"\"\n        # Test error content detection\n        error_contents = [\n            \"Error: Something went wrong\",\n            \"Traceback (most recent call last):\\n  File test.py\",\n            \"ValueError: Invalid input\",\n            \"Connection failed with status 500\",\n        ]\n\n        for content in error_contents:\n            assert context_manager._is_error_content(content), f\"Should detect error in: {content[:50]}\"\n\n        # Test SQL content detection\n        sql_contents = [\n            \"```sql\\nSELECT * FROM users;\\n```\",\n            \"Query results: 100 rows returned\",\n            \"SQL Query:\\nSELECT id FROM table\",\n        ]\n\n        for content in sql_contents:\n            assert context_manager._is_sql_content(content), f\"Should detect SQL in: {content[:50]}\"\n\n        # Test data query result detection\n        data_contents = [\n            \"```csv\\nid,name\\n1,test\\n```\",\n            \"Query Results (CSV format):\",\n            \"Found 500 records in the database\",\n        ]\n\n        for content in data_contents:\n            assert context_manager._is_data_query_result(content), f\"Should detect data result in: {content[:50]}\"\n\n    def test_should_compress_logic(self, context_manager):\n        \"\"\"Test the logic for determining whether to compress historical tool messages.\"\"\"\n        # Short content should not be compressed\n        short_msg = ToolMessage(content=\"Short\", tool_call_id=\"test\", id=\"short\")\n        assert not context_manager._should_compress_historical_tool_message(short_msg, \"Short\")\n\n        # Long non-error content should be 
compressed\n        long_content = \"A\" * 1000\n        long_msg = ToolMessage(content=long_content, tool_call_id=\"test\", id=\"long\")\n        assert context_manager._should_compress_historical_tool_message(long_msg, long_content)\n\n        # Long error content should not be compressed (if preserve_tool_errors=True)\n        error_content = \"Error: \" + \"A\" * 1000\n        error_msg = ToolMessage(content=error_content, tool_call_id=\"test\", id=\"error\")\n        context_manager.config.preserve_tool_errors = True\n        assert not context_manager._should_compress_historical_tool_message(error_msg, error_content)\n\n        # But should be compressed if preserve_tool_errors=False\n        context_manager.config.preserve_tool_errors = False\n        assert context_manager._should_compress_historical_tool_message(error_msg, error_content)\n\n    def test_recent_messages_always_preserved(self, context_manager):\n        \"\"\"Test that recent messages are always preserved regardless of content.\"\"\"\n        # Create messages where recent ones are large but should still be preserved\n        messages = []\n\n        # Historical messages\n        for i in range(10):\n            messages.extend(\n                [\n                    HumanMessage(content=f\"Historical {i}\", id=f\"hist_human_{i}\"),\n                    ToolMessage(content=\"A\" * 500, tool_call_id=f\"hist_{i}\", id=f\"hist_tool_{i}\"),\n                ]\n            )\n\n        # Recent messages (including large tool output)\n        messages.extend(\n            [\n                HumanMessage(content=\"Recent question\", id=\"recent_human\"),\n                AIMessage(content=\"Recent response\", id=\"recent_ai\"),\n                ToolMessage(content=\"B\" * 1000, tool_call_id=\"recent\", id=\"recent_tool\"),  # Large but recent\n            ]\n        )\n\n        original_count = len(messages)\n        context_manager.manage_context_messages(messages)\n\n        # Recent messages 
should be preserved (even if content gets compressed due to summarization)\n        recent_ids = [\"recent_human\", \"recent_ai\", \"recent_tool\"]\n        remaining_recent = [msg for msg in messages if hasattr(msg, \"id\") and msg.id in recent_ids]\n\n        # All recent message IDs should still be present (even if summarization occurred)\n        assert len(remaining_recent) >= 2, \"Most recent messages should be preserved\"\n\n    def test_message_order_preservation(self, context_manager):\n        \"\"\"Test that message ordering is preserved during context management.\"\"\"\n        # Disable conversation summarization to test only tool compression\n        context_manager.config.enable_conversation_summary = False\n        context_manager.config.enable_summarization = False\n\n        # Create messages with specific order\n        messages = [\n            HumanMessage(content=\"Question 1\", id=\"human_1\"),\n            AIMessage(content=\"Response 1\", id=\"ai_1\"),\n            ToolMessage(content=\"A\" * 1000, tool_call_id=\"tool_1\", id=\"tool_1\"),  # Will be compressed\n            HumanMessage(content=\"Question 2\", id=\"human_2\"),\n            AIMessage(content=\"Response 2\", id=\"ai_2\"),\n            ToolMessage(content=\"B\" * 1000, tool_call_id=\"tool_2\", id=\"tool_2\"),  # Will be compressed\n            HumanMessage(content=\"Recent question\", id=\"human_recent\"),  # Recent, should not be compressed\n            AIMessage(content=\"Recent response\", id=\"ai_recent\"),  # Recent\n            ToolMessage(\n                content=\"C\" * 1000, tool_call_id=\"tool_recent\", id=\"tool_recent\"\n            ),  # Recent, should not be compressed\n        ]\n\n        original_order = [msg.id for msg in messages if hasattr(msg, \"id\")]\n        context_manager.manage_context_messages(messages)\n\n        # Extract the IDs in the new order\n        result_order = [msg.id for msg in messages if hasattr(msg, \"id\")]\n\n        # The order 
should be preserved\n        assert result_order == original_order, \"Message order should be preserved\"\n\n        # Verify that historical tool messages were actually compressed\n        historical_tools = [msg for msg in messages if isinstance(msg, ToolMessage) and msg.id in [\"tool_1\", \"tool_2\"]]\n        for msg in historical_tools:\n            assert len(str(msg.content)) < 1000, \"Historical tool messages should be compressed\"\n"
  },
  {
    "path": "tests/test_catalog_loader.py",
    "content": "\"\"\"Tests for catalog loader functionality.\"\"\"\n\nfrom unittest.mock import Mock, patch\n\nimport pytest\n\nfrom openchatbi.catalog.catalog_loader import DataCatalogLoader, load_catalog_from_data_warehouse\n\n\nclass TestDataCatalogLoader:\n    \"\"\"Test DataCatalogLoader functionality.\"\"\"\n\n    @pytest.fixture\n    def mock_engine(self):\n        \"\"\"Mock SQLAlchemy engine.\"\"\"\n        engine = Mock()\n        return engine\n\n    def test_catalog_loader_initialization(self, mock_engine):\n        \"\"\"Test DataCatalogLoader initialization.\"\"\"\n        with patch(\"openchatbi.catalog.catalog_loader.inspect\") as mock_inspect:\n            mock_inspect.return_value = Mock()\n\n            loader = DataCatalogLoader(engine=mock_engine, include_tables=[\"table1\", \"table2\"])\n\n            assert loader.engine == mock_engine\n            assert loader.include_tables == [\"table1\", \"table2\"]\n\n    def test_catalog_loader_without_include_tables(self, mock_engine):\n        \"\"\"Test DataCatalogLoader without include tables.\"\"\"\n        with patch(\"openchatbi.catalog.catalog_loader.inspect\") as mock_inspect:\n            mock_inspect.return_value = Mock()\n\n            loader = DataCatalogLoader(engine=mock_engine, include_tables=None)\n            assert loader.include_tables is None\n\n    def test_get_tables_and_columns(self, mock_engine):\n        \"\"\"Test getting tables and columns metadata.\"\"\"\n        # Mock inspector\n        mock_inspector = Mock()\n        mock_inspector.get_table_names.return_value = [\"table1\", \"table2\"]\n        mock_inspector.get_columns.return_value = [\n            {\"name\": \"col1\", \"type\": \"VARCHAR(50)\", \"comment\": \"Test column\", \"default\": None, \"primary_key\": False}\n        ]\n\n        with patch(\"openchatbi.catalog.catalog_loader.inspect\", return_value=mock_inspector):\n            loader = DataCatalogLoader(engine=mock_engine, include_tables=[\"table1\"])\n   
         result = loader.get_tables_and_columns()\n\n            assert \"table1\" in result\n            assert len(result[\"table1\"]) == 1\n            assert result[\"table1\"][0][\"column_name\"] == \"col1\"\n\n    def test_get_table_indexes(self, mock_engine):\n        \"\"\"Test getting table indexes.\"\"\"\n        mock_inspector = Mock()\n        mock_inspector.get_indexes.return_value = [{\"name\": \"idx_test\", \"column_names\": [\"col1\"]}]\n\n        with patch(\"openchatbi.catalog.catalog_loader.inspect\", return_value=mock_inspector):\n            loader = DataCatalogLoader(engine=mock_engine)\n            result = loader.get_table_indexes(\"table1\")\n\n            assert len(result) == 1\n            assert result[0][\"name\"] == \"idx_test\"\n\n    def test_get_foreign_keys(self, mock_engine):\n        \"\"\"Test getting foreign keys.\"\"\"\n        mock_inspector = Mock()\n        mock_inspector.get_foreign_keys.return_value = [\n            {\"name\": \"fk_test\", \"constrained_columns\": [\"col1\"], \"referred_table\": \"ref_table\"}\n        ]\n\n        with patch(\"openchatbi.catalog.catalog_loader.inspect\", return_value=mock_inspector):\n            loader = DataCatalogLoader(engine=mock_engine)\n            result = loader.get_foreign_keys(\"table1\")\n\n            assert len(result) == 1\n            assert result[0][\"name\"] == \"fk_test\"\n\n    def test_save_to_catalog_store_success(self, mock_engine):\n        \"\"\"Test saving to catalog store successfully.\"\"\"\n        mock_catalog_store = Mock()\n        mock_catalog_store.save_table_information.return_value = True\n        mock_catalog_store.save_table_sql_examples.return_value = True\n        mock_catalog_store.save_table_selection_examples.return_value = True\n\n        mock_inspector = Mock()\n        mock_inspector.get_table_names.return_value = [\"table1\"]\n        mock_inspector.get_columns.return_value = [\n            {\"name\": \"col1\", \"type\": \"VARCHAR(50)\", 
\"comment\": \"Test column\", \"default\": None, \"primary_key\": False}\n        ]\n        mock_inspector.get_table_comment.return_value = {\"text\": \"Test table\"}\n\n        with patch(\"openchatbi.catalog.catalog_loader.inspect\", return_value=mock_inspector):\n            loader = DataCatalogLoader(engine=mock_engine, include_tables=[\"table1\"])\n            result = loader.save_to_catalog_store(mock_catalog_store, \"test_db\")\n\n            assert result == True\n            mock_catalog_store.save_table_information.assert_called()\n            mock_catalog_store.save_table_sql_examples.assert_called()\n            mock_catalog_store.save_table_selection_examples.assert_called()\n\n    def test_save_to_catalog_store_failure(self, mock_engine):\n        \"\"\"Test handling catalog store save failures.\"\"\"\n        mock_catalog_store = Mock()\n        mock_catalog_store.save_table_information.return_value = False\n\n        mock_inspector = Mock()\n        mock_inspector.get_table_names.return_value = [\"table1\"]\n        mock_inspector.get_columns.return_value = []\n\n        with patch(\"openchatbi.catalog.catalog_loader.inspect\", return_value=mock_inspector):\n            loader = DataCatalogLoader(engine=mock_engine)\n            result = loader.save_to_catalog_store(mock_catalog_store)\n\n            assert result == False\n\n    def test_load_catalog_from_data_warehouse(self):\n        \"\"\"Test main entry point for catalog loading.\"\"\"\n        mock_catalog_store = Mock()\n        mock_catalog_store.get_data_warehouse_config.return_value = {\n            \"uri\": \"test://user@host/db\",\n            \"include_tables\": [\"table1\"],\n            \"database_name\": \"test_db\",\n        }\n        mock_catalog_store.get_sql_engine.return_value = Mock()\n\n        with patch(\"openchatbi.catalog.catalog_loader.DataCatalogLoader\") as mock_loader_class:\n            mock_loader = Mock()\n            mock_loader.save_to_catalog_store.return_value 
= True\n            mock_loader_class.return_value = mock_loader\n\n            result = load_catalog_from_data_warehouse(mock_catalog_store)\n\n            assert result == True\n            mock_loader.save_to_catalog_store.assert_called_once()\n\n    def test_error_handling_in_get_tables_and_columns(self, mock_engine):\n        \"\"\"Test error handling in get_tables_and_columns method.\"\"\"\n        mock_inspector = Mock()\n        mock_inspector.get_table_names.side_effect = Exception(\"Database error\")\n\n        with patch(\"openchatbi.catalog.catalog_loader.inspect\", return_value=mock_inspector):\n            loader = DataCatalogLoader(engine=mock_engine)\n            result = loader.get_tables_and_columns()\n\n            assert result == {}\n"
  },
  {
    "path": "tests/test_catalog_store.py",
    "content": "\"\"\"Tests for catalog store functionality.\"\"\"\n\nimport pytest\n\nfrom openchatbi.catalog.catalog_store import CatalogStore\nfrom openchatbi.catalog.store.file_system import FileSystemCatalogStore\n\n\nclass TestCatalogStore:\n    \"\"\"Test base CatalogStore functionality.\"\"\"\n\n    def test_catalog_store_is_abstract(self):\n        \"\"\"Test that CatalogStore cannot be instantiated directly.\"\"\"\n        with pytest.raises(TypeError):\n            CatalogStore()\n\n    def test_catalog_store_interface_methods(self):\n        \"\"\"Test that CatalogStore defines required interface methods.\"\"\"\n        # Check that abstract methods exist\n        assert hasattr(CatalogStore, \"get_table_list\")\n        assert hasattr(CatalogStore, \"get_column_list\")\n        assert hasattr(CatalogStore, \"get_table_information\")\n        assert hasattr(CatalogStore, \"get_data_warehouse_config\")\n        assert hasattr(CatalogStore, \"get_sql_engine\")\n        assert hasattr(CatalogStore, \"save_table_information\")\n\n\nclass TestFileSystemCatalogStore:\n    \"\"\"Test FileSystemCatalogStore functionality.\"\"\"\n\n    def test_filesystem_store_initialization(self, temp_dir):\n        \"\"\"Test FileSystemCatalogStore initialization.\"\"\"\n        data_warehouse_config = {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"}\n        data_path = str(temp_dir)\n        store = FileSystemCatalogStore(data_path=data_path, data_warehouse_config=data_warehouse_config)\n\n        assert store.data_path == data_path\n        assert isinstance(store, CatalogStore)\n\n    def test_get_tables_from_csv(self, mock_catalog_store):\n        \"\"\"Test getting tables from CSV file.\"\"\"\n        tables = mock_catalog_store.get_table_list()\n\n        assert isinstance(tables, list)\n        assert len(tables) >= 1\n\n    def test_get_columns_from_csv(self, mock_catalog_store):\n        \"\"\"Test getting columns from CSV 
file.\"\"\"\n        columns = mock_catalog_store.get_column_list(\"test_table\", \"test\")\n\n        assert isinstance(columns, list)\n        if columns:\n            column = columns[0]\n            assert \"column_name\" in column or \"name\" in column\n            assert \"data_type\" in column or \"type\" in column\n\n    def test_get_table_info(self, mock_catalog_store):\n        \"\"\"Test getting table information.\"\"\"\n        table_info = mock_catalog_store.get_table_information(\"test.test_table\")\n\n        assert isinstance(table_info, dict)\n\n    def test_get_tables_file_not_found(self, temp_dir):\n        \"\"\"Test handling when tables file doesn't exist.\"\"\"\n        empty_dir = temp_dir / \"empty\"\n        empty_dir.mkdir()\n\n        data_warehouse_config = {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"}\n        store = FileSystemCatalogStore(data_path=str(empty_dir), data_warehouse_config=data_warehouse_config)\n\n        # Should handle missing file gracefully\n        tables = store.get_table_list()\n        assert isinstance(tables, list)\n\n    def test_get_columns_file_not_found(self, temp_dir):\n        \"\"\"Test handling when columns file doesn't exist.\"\"\"\n        empty_dir = temp_dir / \"empty\"\n        empty_dir.mkdir()\n\n        data_warehouse_config = {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"}\n        store = FileSystemCatalogStore(data_path=str(empty_dir), data_warehouse_config=data_warehouse_config)\n\n        # Should handle missing file gracefully\n        columns = store.get_column_list(\"nonexistent_table\")\n        assert isinstance(columns, list)\n\n    def test_get_tables_malformed_csv(self, temp_dir):\n        \"\"\"Test handling malformed CSV files.\"\"\"\n        # Create malformed CSV\n        malformed_csv = temp_dir / \"table_columns.csv\"\n        
malformed_csv.write_text(\"invalid,csv,format\\\\nno,proper\\\\nheaders\")\n\n        data_warehouse_config = {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"}\n        store = FileSystemCatalogStore(data_path=str(temp_dir), data_warehouse_config=data_warehouse_config)\n\n        # Should handle malformed CSV gracefully\n        tables = store.get_table_list()\n        assert isinstance(tables, list)\n\n    def test_get_tables_pandas_error(self, temp_dir):\n        \"\"\"Test handling pandas errors.\"\"\"\n        data_warehouse_config = {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"}\n        store = FileSystemCatalogStore(data_path=str(temp_dir), data_warehouse_config=data_warehouse_config)\n\n        # Should handle pandas errors gracefully\n        tables = store.get_table_list()\n        assert isinstance(tables, list)\n\n    def test_get_table_schema(self, mock_catalog_store):\n        \"\"\"Test getting complete table schema.\"\"\"\n        # Use get_table_information instead of get_table_schema\n        schema = mock_catalog_store.get_table_information(\"test.test_table\")\n\n        assert isinstance(schema, dict)\n\n    def test_search_tables(self, mock_catalog_store):\n        \"\"\"Test searching for tables by keyword.\"\"\"\n        # This method might not exist in current implementation\n        # but it's a common catalog feature\n        if hasattr(mock_catalog_store, \"search_tables\"):\n            results = mock_catalog_store.search_tables(\"test\")\n            assert isinstance(results, list)\n\n    def test_get_all_table_names(self, mock_catalog_store):\n        \"\"\"Test getting all table names.\"\"\"\n        tables = mock_catalog_store.get_table_list()\n        # get_table_list() returns list of strings (table names), not dictionaries\n        assert isinstance(tables, list)\n        # Verify all items are strings\n        for table_name in tables:\n      
      assert isinstance(table_name, str)\n\n    def test_case_insensitive_table_lookup(self, mock_catalog_store):\n        \"\"\"Test case-insensitive table lookups.\"\"\"\n        # Test with different cases\n        test_cases = [\"test_table\", \"TEST_TABLE\", \"Test_Table\"]\n\n        for table_name in test_cases:\n            columns = mock_catalog_store.get_column_list(table_name)\n            assert isinstance(columns, list)\n\n    def test_data_path_validation(self):\n        \"\"\"Test data path validation.\"\"\"\n        data_warehouse_config = {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"}\n        # Test with None path\n        with pytest.raises((ValueError, TypeError)):\n            FileSystemCatalogStore(data_path=None, data_warehouse_config=data_warehouse_config)\n\n        # Test with empty string\n        with pytest.raises((ValueError, FileNotFoundError)):\n            FileSystemCatalogStore(data_path=\"\", data_warehouse_config=data_warehouse_config)\n\n    def test_concurrent_access(self, mock_catalog_store):\n        \"\"\"Test concurrent access to catalog store.\"\"\"\n        import threading\n        import time\n\n        results = []\n        errors = []\n\n        def worker():\n            try:\n                tables = mock_catalog_store.get_table_list()\n                results.append(len(tables))\n                time.sleep(0.01)\n                columns = mock_catalog_store.get_column_list(\"test_table\", \"test\")\n                results.append(len(columns))\n            except Exception as e:\n                errors.append(e)\n\n        # Create multiple threads\n        threads = []\n        for _ in range(5):\n            thread = threading.Thread(target=worker)\n            threads.append(thread)\n\n        # Start all threads\n        for thread in threads:\n            thread.start()\n\n        # Wait for completion\n        for thread in threads:\n            thread.join()\n\n   
     # Should not have errors from concurrent access\n        assert len(errors) == 0\n        assert len(results) > 0\n"
  },
  {
    "path": "tests/test_config_loader.py",
    "content": "\"\"\"Tests for configuration loading functionality.\"\"\"\n\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\nimport yaml\n\nfrom openchatbi.config_loader import Config, ConfigLoader\n\n\nclass TestConfigLoader:\n    \"\"\"Test configuration loading functionality.\"\"\"\n\n    def test_config_initialization(self):\n        \"\"\"Test Config model initialization.\"\"\"\n        from unittest.mock import MagicMock\n\n        mock_llm = MagicMock()\n        mock_embedding = MagicMock()\n        config = Config(organization=\"TestOrg\", dialect=\"presto\", default_llm=mock_llm, embedding_model=mock_embedding)\n\n        assert config.organization == \"TestOrg\"\n        assert config.dialect == \"presto\"\n        assert config.default_llm == mock_llm\n        assert config.embedding_model == mock_embedding\n\n    def test_config_from_dict(self):\n        \"\"\"Test creating Config from dictionary.\"\"\"\n        from unittest.mock import MagicMock\n\n        mock_llm = MagicMock()\n        mock_embedding = MagicMock()\n        config_dict = {\n            \"organization\": \"TestOrg\",\n            \"dialect\": \"mysql\",\n            \"default_llm\": mock_llm,\n            \"embedding_model\": mock_embedding,\n        }\n\n        config = Config.from_dict(config_dict)\n        assert config.organization == \"TestOrg\"\n        assert config.dialect == \"mysql\"\n        assert config.default_llm == mock_llm\n        assert config.embedding_model == mock_embedding\n\n    def test_config_loader_initialization(self):\n        \"\"\"Test ConfigLoader initialization.\"\"\"\n        loader = ConfigLoader()\n        # Initially, config should be None until loaded\n        # Don't assert _config state since it depends on previous tests\n\n    def test_load_config_from_file(self, temp_dir):\n        \"\"\"Test loading configuration from YAML file.\"\"\"\n        config_data = {\n            \"organization\": \"TestOrg\",\n            \"dialect\": 
\"presto\",\n            \"default_llm\": {\"class\": \"langchain_openai.ChatOpenAI\", \"params\": {\"model\": \"gpt-4\"}},\n            \"embedding_model\": {\n                \"class\": \"langchain_openai.OpenAIEmbeddings\",\n                \"params\": {\"model\": \"text-embedding-ada-002\"},\n            },\n            \"data_warehouse_config\": {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"},\n        }\n\n        config_file = temp_dir / \"test_config.yaml\"\n        with open(config_file, \"w\") as f:\n            yaml.dump(config_data, f)\n\n        loader = ConfigLoader()\n\n        with (\n            patch(\"langchain_openai.ChatOpenAI\") as mock_openai,\n            patch(\"langchain_openai.OpenAIEmbeddings\") as mock_embeddings,\n        ):\n            # Create a proper mock that satisfies BaseChatModel interface\n            from langchain_core.language_models import BaseChatModel\n\n            mock_llm_instance = MagicMock(spec=BaseChatModel)\n            mock_embedding_instance = MagicMock()\n            mock_openai.return_value = mock_llm_instance\n            mock_embeddings.return_value = mock_embedding_instance\n\n            loader.load(str(config_file))\n\n        config = loader.get()\n        assert config.organization == \"TestOrg\"\n        assert config.dialect == \"presto\"\n        assert config.default_llm == mock_llm_instance\n        assert config.embedding_model == mock_embedding_instance\n\n    def test_load_config_missing_file(self):\n        \"\"\"Test handling of missing configuration file.\"\"\"\n        loader = ConfigLoader()\n\n        # Reset the config to ensure clean state\n        loader._config = None\n\n        # The loader now logs and returns instead of raising FileNotFoundError\n        loader.load(\"/nonexistent/path.yaml\")\n\n        # Verify that the config was not loaded (remains None)\n        with pytest.raises(ValueError, match=\"Configuration has not been 
loaded\"):\n            loader.get()\n\n    def test_load_config_invalid_yaml(self, temp_dir):\n        \"\"\"Test handling of invalid YAML syntax.\"\"\"\n        config_file = temp_dir / \"invalid_config.yaml\"\n        config_file.write_text(\"invalid: yaml: content: [\")\n\n        loader = ConfigLoader()\n\n        with pytest.raises(ValueError, match=\"Invalid YAML in configuration file\"):\n            loader.load(str(config_file))\n\n    def test_load_config_with_bi_config_file(self, temp_dir):\n        \"\"\"Test loading configuration with BI config file.\"\"\"\n        bi_config_data = {\"metrics\": [\"revenue\", \"users\"], \"dimensions\": [\"date\", \"region\"]}\n\n        bi_config_file = temp_dir / \"bi_config.yaml\"\n        with open(bi_config_file, \"w\") as f:\n            yaml.dump(bi_config_data, f)\n\n        config_data = {\n            \"organization\": \"TestOrg\",\n            \"bi_config_file\": str(bi_config_file),\n            \"default_llm\": {\"class\": \"langchain_openai.ChatOpenAI\", \"params\": {\"model\": \"gpt-4\"}},\n            \"embedding_model\": {\n                \"class\": \"langchain_openai.OpenAIEmbeddings\",\n                \"params\": {\"model\": \"text-embedding-ada-002\"},\n            },\n            \"data_warehouse_config\": {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"},\n        }\n\n        config_file = temp_dir / \"test_config.yaml\"\n        with open(config_file, \"w\") as f:\n            yaml.dump(config_data, f)\n\n        loader = ConfigLoader()\n\n        with (\n            patch(\"langchain_openai.ChatOpenAI\") as mock_openai,\n            patch(\"langchain_openai.OpenAIEmbeddings\") as mock_embeddings,\n        ):\n            mock_llm_instance = MagicMock()\n            mock_embedding_instance = MagicMock()\n            mock_openai.return_value = mock_llm_instance\n            mock_embeddings.return_value = mock_embedding_instance\n\n            
loader.load(str(config_file))\n\n        config = loader.get()\n        assert config.bi_config[\"metrics\"] == [\"revenue\", \"users\"]\n        assert config.bi_config[\"dimensions\"] == [\"date\", \"region\"]\n\n    def test_load_config_with_catalog_store(self, temp_dir):\n        \"\"\"Test loading configuration with catalog store.\"\"\"\n        config_data = {\n            \"organization\": \"TestOrg\",\n            \"catalog_store\": {\"store_type\": \"file_system\", \"data_path\": str(temp_dir / \"catalog_data\")},\n            \"data_warehouse_config\": {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"},\n            \"default_llm\": {\"class\": \"langchain_openai.ChatOpenAI\", \"params\": {\"model\": \"gpt-4\"}},\n            \"embedding_model\": {\n                \"class\": \"langchain_openai.OpenAIEmbeddings\",\n                \"params\": {\"model\": \"text-embedding-ada-002\"},\n            },\n        }\n\n        config_file = temp_dir / \"test_config.yaml\"\n        with open(config_file, \"w\") as f:\n            yaml.dump(config_data, f)\n\n        loader = ConfigLoader()\n\n        with (\n            patch(\"langchain_openai.ChatOpenAI\") as mock_openai,\n            patch(\"langchain_openai.OpenAIEmbeddings\") as mock_embeddings,\n        ):\n            mock_llm_instance = MagicMock()\n            mock_embedding_instance = MagicMock()\n            mock_openai.return_value = mock_llm_instance\n            mock_embeddings.return_value = mock_embedding_instance\n\n            loader.load(str(config_file))\n\n        config = loader.get()\n        # Just verify that a catalog store was created\n        assert config.catalog_store is not None\n        assert hasattr(config.catalog_store, \"get_table_list\")\n\n    def test_load_config_with_llm_configs(self, temp_dir):\n        \"\"\"Test loading configuration with LLM configs.\"\"\"\n        config_data = {\n            \"organization\": \"TestOrg\",\n     
       \"default_llm\": {\"class\": \"langchain_openai.ChatOpenAI\", \"params\": {\"model\": \"gpt-4\", \"temperature\": 0.1}},\n            \"embedding_model\": {\n                \"class\": \"langchain_openai.OpenAIEmbeddings\",\n                \"params\": {\"model\": \"text-embedding-ada-002\"},\n            },\n            \"text2sql_llm\": {\"class\": \"langchain_openai.ChatOpenAI\", \"params\": {\"model\": \"gpt-3.5-turbo\"}},\n            \"data_warehouse_config\": {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"},\n        }\n\n        config_file = temp_dir / \"test_config.yaml\"\n        with open(config_file, \"w\") as f:\n            yaml.dump(config_data, f)\n\n        loader = ConfigLoader()\n\n        with (\n            patch(\"langchain_openai.ChatOpenAI\") as mock_openai,\n            patch(\"langchain_openai.OpenAIEmbeddings\") as mock_embeddings,\n        ):\n            # Create proper mocks that satisfy BaseChatModel interface\n            from langchain_core.language_models import BaseChatModel\n\n            mock_instance1 = MagicMock(spec=BaseChatModel)\n            mock_instance2 = MagicMock(spec=BaseChatModel)\n            mock_embedding_instance = MagicMock()\n            mock_openai.side_effect = [mock_instance1, mock_instance2]\n            mock_embeddings.return_value = mock_embedding_instance\n\n            loader.load(str(config_file))\n\n        config = loader.get()\n        assert config.default_llm == mock_instance1\n        assert config.embedding_model == mock_embedding_instance\n        assert config.text2sql_llm == mock_instance2\n\n    def test_load_config_with_llm_providers_selected_by_default_llm(self, temp_dir):\n        \"\"\"Test loading configuration using llm_providers with default_llm provider selector.\"\"\"\n        config_data = {\n            \"organization\": \"TestOrg\",\n            \"dialect\": \"presto\",\n            \"default_llm\": \"openai\",\n            
\"llm_providers\": {\n                \"openai\": {\n                    \"default_llm\": {\"class\": \"langchain_openai.ChatOpenAI\", \"params\": {\"model\": \"gpt-4\"}},\n                    \"embedding_model\": {\n                        \"class\": \"langchain_openai.OpenAIEmbeddings\",\n                        \"params\": {\"model\": \"text-embedding-ada-002\"},\n                    },\n                },\n                \"anthropic\": {\n                    \"default_llm\": {\"class\": \"langchain_anthropic.ChatAnthropic\", \"params\": {\"model\": \"claude\"}},\n                },\n            },\n            \"data_warehouse_config\": {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"},\n        }\n\n        config_file = temp_dir / \"test_config.yaml\"\n        with open(config_file, \"w\") as f:\n            yaml.dump(config_data, f)\n\n        loader = ConfigLoader()\n\n        with (\n            patch(\"langchain_openai.ChatOpenAI\") as mock_openai,\n            patch(\"langchain_openai.OpenAIEmbeddings\") as mock_embeddings,\n            patch(\"langchain_anthropic.ChatAnthropic\") as mock_anthropic,\n        ):\n            from langchain_core.language_models import BaseChatModel\n\n            mock_openai_instance = MagicMock(spec=BaseChatModel)\n            mock_anthropic_instance = MagicMock(spec=BaseChatModel)\n            mock_embedding_instance = MagicMock()\n            mock_openai.return_value = mock_openai_instance\n            mock_anthropic.return_value = mock_anthropic_instance\n            mock_embeddings.return_value = mock_embedding_instance\n\n            loader.load(str(config_file))\n\n        config = loader.get()\n        assert config.llm_provider == \"openai\"\n        assert config.default_llm == mock_openai_instance\n        assert config.embedding_model == mock_embedding_instance\n        assert set(config.llm_providers.keys()) == {\"openai\", \"anthropic\"}\n\n    def 
test_set_config(self):\n        \"\"\"Test setting configuration from dictionary.\"\"\"\n        config_dict = {\n            \"organization\": \"SetOrg\",\n            \"dialect\": \"postgresql\",\n            \"default_llm\": {\"class\": \"langchain_openai.ChatOpenAI\", \"params\": {\"model\": \"gpt-4\"}},\n            \"embedding_model\": {\n                \"class\": \"langchain_openai.OpenAIEmbeddings\",\n                \"params\": {\"model\": \"text-embedding-ada-002\"},\n            },\n            \"data_warehouse_config\": {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"},\n        }\n\n        loader = ConfigLoader()\n\n        with (\n            patch(\"langchain_openai.ChatOpenAI\") as mock_openai,\n            patch(\"langchain_openai.OpenAIEmbeddings\") as mock_embeddings,\n        ):\n            mock_llm_instance = MagicMock()\n            mock_embedding_instance = MagicMock()\n            mock_openai.return_value = mock_llm_instance\n            mock_embeddings.return_value = mock_embedding_instance\n\n            loader.set(config_dict)\n\n        config = loader.get()\n        assert config.organization == \"SetOrg\"\n        assert config.dialect == \"postgresql\"\n\n    def test_get_config_not_loaded(self):\n        \"\"\"Test getting configuration when not loaded.\"\"\"\n        loader = ConfigLoader()\n        loader._config = None\n\n        with pytest.raises(ValueError, match=\"Configuration has not been loaded\"):\n            loader.get()\n\n    def test_load_bi_config_missing_file(self, temp_dir):\n        \"\"\"Test loading missing BI config file.\"\"\"\n        nonexistent_file = temp_dir / \"nonexistent_bi.yaml\"\n\n        loader = ConfigLoader()\n\n        # Should not raise exception, just return empty dict\n        result = loader.load_bi_config(str(nonexistent_file))\n        assert result == {}\n\n    def test_catalog_store_missing_store_type(self, temp_dir):\n        \"\"\"Test 
catalog store configuration without store_type.\"\"\"\n        config_data = {\n            \"organization\": \"TestOrg\",\n            \"catalog_store\": {\n                \"data_path\": \"/test/path\"\n                # Missing store_type\n            },\n            \"default_llm\": {\"class\": \"langchain_openai.ChatOpenAI\", \"params\": {\"model\": \"gpt-4\"}},\n            \"embedding_model\": {\n                \"class\": \"langchain_openai.OpenAIEmbeddings\",\n                \"params\": {\"model\": \"text-embedding-ada-002\"},\n            },\n            \"data_warehouse_config\": {\"uri\": \"sqlite:///:memory:\", \"include_tables\": None, \"database_name\": \"test_db\"},\n        }\n\n        config_file = temp_dir / \"test_config.yaml\"\n        with open(config_file, \"w\") as f:\n            yaml.dump(config_data, f)\n\n        loader = ConfigLoader()\n\n        with pytest.raises(ValueError, match=\"catalog_store must have a store_type field\"):\n            loader.load(str(config_file))\n"
  },
  {
    "path": "tests/test_graph_state.py",
    "content": "\"\"\"Tests for graph state management.\"\"\"\n\nfrom langchain_core.messages import AIMessage, HumanMessage, ToolMessage\n\nfrom openchatbi.graph_state import AgentState, InputState, OutputState\n\n\nclass TestAgentState:\n    \"\"\"Test AgentState functionality.\"\"\"\n\n    def test_agent_state_with_data(self):\n        \"\"\"Test creating AgentState with initial data.\"\"\"\n        messages = [HumanMessage(content=\"Test message\")]\n        sql = \"SELECT * FROM test_table;\"\n        agent_next_node = \"sql_generation\"\n        final_answer = \"Here is your data\"\n\n        state = AgentState(messages=messages, sql=sql, agent_next_node=agent_next_node, final_answer=final_answer)\n\n        assert state[\"messages\"] == messages\n        assert state[\"sql\"] == sql\n        assert state[\"agent_next_node\"] == agent_next_node\n        assert state[\"final_answer\"] == final_answer\n\n    def test_agent_state_message_types(self):\n        \"\"\"Test AgentState with different message types.\"\"\"\n        messages = [\n            HumanMessage(content=\"User question\"),\n            AIMessage(content=\"AI response\"),\n            ToolMessage(content=\"Tool result\", tool_call_id=\"test_id\"),\n        ]\n\n        state = AgentState(messages=messages)\n\n        assert len(state[\"messages\"]) == 3\n        assert isinstance(state[\"messages\"][0], HumanMessage)\n        assert isinstance(state[\"messages\"][1], AIMessage)\n        assert isinstance(state[\"messages\"][2], ToolMessage)\n\n    def test_agent_state_immutability(self):\n        \"\"\"Test that AgentState behaves correctly with updates.\"\"\"\n        original_state = AgentState(\n            messages=[HumanMessage(content=\"Original\")],\n            sql=\"SELECT 1;\",\n            agent_next_node=\"original_node\",\n            final_answer=\"Original answer\",\n        )\n\n        # Create updated state\n        new_messages = original_state[\"messages\"] + 
[AIMessage(content=\"Response\")]\n        updated_state = AgentState(\n            messages=new_messages, sql=\"SELECT 2;\", agent_next_node=\"updated_node\", final_answer=\"Updated answer\"\n        )\n\n        # Original state should remain unchanged\n        assert len(original_state[\"messages\"]) == 1\n        assert original_state[\"sql\"] == \"SELECT 1;\"\n        assert original_state[\"agent_next_node\"] == \"original_node\"\n        assert original_state[\"final_answer\"] == \"Original answer\"\n\n        # Updated state should have new values\n        assert len(updated_state[\"messages\"]) == 2\n        assert updated_state[\"sql\"] == \"SELECT 2;\"\n        assert updated_state[\"agent_next_node\"] == \"updated_node\"\n        assert updated_state[\"final_answer\"] == \"Updated answer\"\n\n\nclass TestInputState:\n    \"\"\"Test InputState functionality.\"\"\"\n\n    def test_input_state_creation(self):\n        \"\"\"Test creating InputState.\"\"\"\n        messages = [HumanMessage(content=\"Input message\")]\n\n        state = InputState(messages=messages)\n\n        assert state[\"messages\"] == messages\n\n    def test_input_state_empty_messages(self):\n        \"\"\"Test InputState with empty messages.\"\"\"\n        state = InputState(messages=[])\n\n        assert state[\"messages\"] == []\n\n\nclass TestOutputState:\n    \"\"\"Test OutputState functionality.\"\"\"\n\n    def test_output_state_creation(self):\n        \"\"\"Test creating OutputState.\"\"\"\n        messages = [AIMessage(content=\"Output message\")]\n\n        state = OutputState(messages=messages)\n\n        assert state[\"messages\"] == messages\n\n    def test_output_state_with_multiple_messages(self):\n        \"\"\"Test OutputState with conversation history.\"\"\"\n        messages = [\n            HumanMessage(content=\"Question\"),\n            AIMessage(content=\"Answer\"),\n            HumanMessage(content=\"Follow-up\"),\n            AIMessage(content=\"Final 
response\"),\n        ]\n\n        state = OutputState(messages=messages)\n\n        assert len(state[\"messages\"]) == 4\n        assert state[\"messages\"] == messages\n\n\nclass TestStateIntegration:\n    \"\"\"Test integration between different state types.\"\"\"\n\n    def test_input_to_agent_state_conversion(self):\n        \"\"\"Test converting InputState to AgentState.\"\"\"\n        input_messages = [HumanMessage(content=\"User input\")]\n        input_state = InputState(messages=input_messages)\n\n        # Simulate conversion to AgentState\n        agent_state = AgentState(messages=input_state[\"messages\"], sql=\"\", agent_next_node=\"\", final_answer=\"\")\n\n        assert agent_state[\"messages\"] == input_messages\n        assert agent_state[\"sql\"] == \"\"\n\n    def test_agent_to_output_state_conversion(self):\n        \"\"\"Test converting AgentState to OutputState.\"\"\"\n        agent_messages = [HumanMessage(content=\"Question\"), AIMessage(content=\"Generated response\")]\n\n        agent_state = AgentState(\n            messages=agent_messages,\n            sql=\"SELECT * FROM test_table;\",\n            agent_next_node=\"output\",\n            final_answer=\"Generated response\",\n        )\n\n        # Simulate conversion to OutputState\n        output_state = OutputState(messages=agent_state[\"messages\"])\n\n        assert output_state[\"messages\"] == agent_messages\n\n    def test_state_serialization_compatibility(self):\n        \"\"\"Test that states can be serialized and deserialized.\"\"\"\n        original_state = AgentState(\n            messages=[HumanMessage(content=\"Test\"), AIMessage(content=\"Response\")],\n            sql=\"SELECT COUNT(*) FROM table1;\",\n            agent_next_node=\"final\",\n            final_answer=\"Count results\",\n        )\n\n        # Convert to dict (simulating serialization)\n        state_dict = {\n            \"messages\": original_state[\"messages\"],\n            \"sql\": 
original_state[\"sql\"],\n            \"agent_next_node\": original_state[\"agent_next_node\"],\n            \"final_answer\": original_state[\"final_answer\"],\n        }\n\n        # Recreate from dict (simulating deserialization)\n        recreated_state = AgentState(**state_dict)\n\n        assert recreated_state[\"messages\"] == original_state[\"messages\"]\n        assert recreated_state[\"sql\"] == original_state[\"sql\"]\n        assert recreated_state[\"agent_next_node\"] == original_state[\"agent_next_node\"]\n        assert recreated_state[\"final_answer\"] == original_state[\"final_answer\"]\n"
  },
  {
    "path": "tests/test_incomplete_tool_calls.py",
    "content": "\"\"\"Tests for incomplete tool call recovery functionality.\"\"\"\n\nfrom unittest.mock import Mock\n\nfrom langchain_core.messages import AIMessage, HumanMessage, ToolMessage\n\nfrom openchatbi.agent_graph import agent_llm_call\nfrom openchatbi.graph_state import AgentState\nfrom openchatbi.utils import recover_incomplete_tool_calls\n\n\nclass TestIncompleteToolCallRecovery:\n    \"\"\"Test cases for recover_incomplete_tool_calls function.\"\"\"\n\n    def test_no_messages(self):\n        \"\"\"Test recovery with empty message list.\"\"\"\n        state = AgentState(messages=[])\n        result = recover_incomplete_tool_calls(state)\n        assert result == []\n\n    def test_no_tool_calls(self):\n        \"\"\"Test recovery with messages but no tool calls.\"\"\"\n        messages = [HumanMessage(content=\"Hello\"), AIMessage(content=\"Hi there!\")]\n        state = AgentState(messages=messages)\n        result = recover_incomplete_tool_calls(state)\n        assert result == []\n\n    def test_complete_tool_calls(self):\n        \"\"\"Test recovery when all tool calls have responses.\"\"\"\n        messages = [\n            HumanMessage(content=\"Search for data\"),\n            AIMessage(\n                content=\"I'll search for that data.\",\n                tool_calls=[{\"name\": \"search\", \"args\": {\"query\": \"data\"}, \"id\": \"call_1\"}],\n            ),\n            ToolMessage(content=\"Search completed\", tool_call_id=\"call_1\"),\n        ]\n        state = AgentState(messages=messages)\n        result = recover_incomplete_tool_calls(state)\n        assert result == []\n\n    def test_incomplete_single_tool_call(self):\n        \"\"\"Test recovery when there's one incomplete tool call.\"\"\"\n        messages = [\n            HumanMessage(content=\"Search for data\"),\n            AIMessage(\n                content=\"I'll search for that data.\",\n                tool_calls=[{\"name\": \"search\", \"args\": {\"query\": \"data\"}, 
\"id\": \"call_1\"}],\n            ),\n        ]\n        state = AgentState(messages=messages)\n        result = recover_incomplete_tool_calls(state)\n\n        assert isinstance(result, list)\n        assert len(result) == 1  # Just the recovery message\n\n        failure_msg = result[0]\n        assert failure_msg.tool_call_id == \"call_1\"\n        assert \"interrupted\" in failure_msg.content.lower()\n\n    def test_incomplete_multiple_tool_calls(self):\n        \"\"\"Test recovery when there are multiple incomplete tool calls.\"\"\"\n        messages = [\n            HumanMessage(content=\"Search and analyze\"),\n            AIMessage(\n                content=\"I'll search and analyze.\",\n                tool_calls=[\n                    {\"name\": \"search\", \"args\": {\"query\": \"data\"}, \"id\": \"call_1\"},\n                    {\"name\": \"analyze\", \"args\": {\"data\": \"result\"}, \"id\": \"call_2\"},\n                ],\n            ),\n        ]\n        state = AgentState(messages=messages)\n        result = recover_incomplete_tool_calls(state)\n\n        assert isinstance(result, list)\n        assert len(result) == 2  # Just the recovery messages\n\n        # Check that both tool calls get failure messages\n        recovery_messages = result\n        tool_call_ids = {msg.tool_call_id for msg in recovery_messages}\n        assert tool_call_ids == {\"call_1\", \"call_2\"}\n\n        for msg in recovery_messages:\n            assert isinstance(msg, ToolMessage)\n            assert \"interrupted\" in msg.content.lower()\n\n    def test_partial_incomplete_tool_calls(self):\n        \"\"\"Test recovery when some tool calls are complete, others are not.\"\"\"\n        messages = [\n            HumanMessage(content=\"Search and analyze\"),\n            AIMessage(\n                content=\"I'll search and analyze.\",\n                tool_calls=[\n                    {\"name\": \"search\", \"args\": {\"query\": \"data\"}, \"id\": \"call_1\"},\n       
             {\"name\": \"analyze\", \"args\": {\"data\": \"result\"}, \"id\": \"call_2\"},\n                ],\n            ),\n            ToolMessage(content=\"Search completed\", tool_call_id=\"call_1\"),\n            # Missing ToolMessage for call_2\n        ]\n        state = AgentState(messages=messages)\n        result = recover_incomplete_tool_calls(state)\n\n        assert isinstance(result, list)\n        assert len(result) == 3  # RemoveMessage + recovery message + re-added message\n\n        # Should have: RemoveMessage, ToolMessage(recovery for call_2), ToolMessage(original for call_1)\n        operations = result\n        assert \"RemoveMessage\" in str(type(operations[0]))  # Remove the existing ToolMessage\n        assert isinstance(operations[1], ToolMessage)  # Recovery message for call_2\n        assert isinstance(operations[2], ToolMessage)  # Re-added original message for call_1\n\n        # The recovery message should be for call_2\n        recovery_msg = operations[1]\n        assert recovery_msg.tool_call_id == \"call_2\"\n        assert \"interrupted\" in recovery_msg.content.lower()\n\n        # The re-added message should be the original for call_1\n        original_msg = operations[2]\n        assert original_msg.tool_call_id == \"call_1\"\n        assert original_msg.content == \"Search completed\"\n\n    def test_multiple_ai_messages_with_tool_calls(self):\n        \"\"\"Test recovery considers only the last AIMessage with tool calls.\"\"\"\n        messages = [\n            HumanMessage(content=\"First task\"),\n            AIMessage(content=\"Doing first task.\", tool_calls=[{\"name\": \"task1\", \"args\": {}, \"id\": \"old_call\"}]),\n            ToolMessage(content=\"Task 1 done\", tool_call_id=\"old_call\"),\n            HumanMessage(content=\"Second task\"),\n            AIMessage(content=\"Doing second task.\", tool_calls=[{\"name\": \"task2\", \"args\": {}, \"id\": \"new_call\"}]),\n            # Missing ToolMessage for 
new_call\n        ]\n        state = AgentState(messages=messages)\n        result = recover_incomplete_tool_calls(state)\n\n        assert isinstance(result, list)\n        assert len(result) == 1  # Just the recovery message\n\n        # The recovery message should be for new_call only\n        recovery_msg = result[0]\n        assert recovery_msg.tool_call_id == \"new_call\"\n        assert \"interrupted\" in recovery_msg.content.lower()\n\n    def test_llm_node_integration_with_recovery(self):\n        \"\"\"Test that the llm_node handles recovery correctly and continues processing.\"\"\"\n        # Create a mock llm_node function for testing\n        mock_llm = Mock()\n        mock_tools = []\n        llm_node_func = agent_llm_call(mock_llm, mock_tools)\n\n        # State with incomplete tool calls\n        messages = [\n            HumanMessage(content=\"Search for data\"),\n            AIMessage(\n                content=\"I'll search for that data.\",\n                tool_calls=[{\"name\": \"search\", \"args\": {\"query\": \"data\"}, \"id\": \"call_1\"}],\n            ),\n        ]\n        state = AgentState(messages=messages)\n\n        # Call the llm node - it should detect incomplete tool calls and return recovery\n        result = llm_node_func(state)\n\n        # Should return message operations and continue to llm_node\n        assert \"messages\" in result\n        assert \"agent_next_node\" in result\n        assert result[\"agent_next_node\"] == \"llm_node\"\n\n        # Should have recovery ToolMessage operation for the incomplete call\n        operations = result[\"messages\"]\n        assert len(operations) == 1  # Only recovery message needed\n        assert isinstance(operations[0], ToolMessage)\n        assert operations[0].tool_call_id == \"call_1\"\n"
  },
  {
    "path": "tests/test_memory.py",
    "content": "\"\"\"Tests for memory tool functionality.\"\"\"\n\nfrom pathlib import Path\nfrom unittest.mock import AsyncMock, Mock, patch\n\nimport pytest\nfrom langchain_core.language_models import FakeListChatModel\nfrom langchain_openai import ChatOpenAI\n\n# Check if pysqlite3 is available, if not skip these tests\npysqlite3 = pytest.importorskip(\"pysqlite3\", reason=\"pysqlite3 not available\")\n\nfrom openchatbi.tool.memory import (\n    StructuredToolWithRequired,\n    UserProfile,\n    cleanup_async_memory_store,\n    fix_schema_for_openai,\n    get_async_memory_store,\n    get_async_memory_tools,\n    get_memory_manager,\n    get_memory_tools,\n    get_sync_memory_store,\n    setup_async_memory_store,\n)\n\n\nclass TestUserProfile:\n    \"\"\"Test UserProfile model functionality.\"\"\"\n\n    def test_user_profile_basic_initialization(self):\n        \"\"\"Test basic UserProfile model creation.\"\"\"\n        profile = UserProfile(name=\"John Doe\", language=\"English\", timezone=\"UTC\", jargon=\"Technical\")\n\n        assert profile.name == \"John Doe\"\n        assert profile.language == \"English\"\n        assert profile.timezone == \"UTC\"\n        assert profile.jargon == \"Technical\"\n\n    def test_user_profile_optional_fields(self):\n        \"\"\"Test UserProfile with optional fields.\"\"\"\n        profile = UserProfile()\n\n        assert profile.name is None\n        assert profile.language is None\n        assert profile.timezone is None\n        assert profile.jargon is None\n\n    def test_user_profile_partial_initialization(self):\n        \"\"\"Test UserProfile with partial field initialization.\"\"\"\n        profile = UserProfile(name=\"Jane Smith\", language=\"Spanish\")\n\n        assert profile.name == \"Jane Smith\"\n        assert profile.language == \"Spanish\"\n        assert profile.timezone is None\n        assert profile.jargon is None\n\n    def test_user_profile_serialization(self):\n        \"\"\"Test UserProfile 
model serialization.\"\"\"\n        profile = UserProfile(name=\"Test User\", timezone=\"EST\")\n\n        data = profile.model_dump()\n        assert data[\"name\"] == \"Test User\"\n        assert data[\"timezone\"] == \"EST\"\n        assert data[\"language\"] is None\n        assert data[\"jargon\"] is None\n\n\nclass TestMemoryStoreManagement:\n    \"\"\"Test memory store management functions.\"\"\"\n\n    @pytest.fixture(autouse=True)\n    def setup_test_env(self, tmp_path: Path):\n        \"\"\"Setup test environment with temporary database.\"\"\"\n        self.temp_db_path = tmp_path / \"test_memory.db\"\n        # Clean up any global state\n        import openchatbi.tool.memory as memory_module\n\n        memory_module.sync_memory_store = None\n        memory_module.async_memory_store = None\n        memory_module.async_store_context_manager = None\n\n    @patch(\"openchatbi.tool.memory.sqlite3.connect\")\n    @patch(\"openchatbi.tool.memory.config.get\")\n    def test_get_sync_memory_store(self, mock_config, mock_connect):\n        \"\"\"Test sync memory store creation.\"\"\"\n        mock_config.return_value.embedding_model = Mock()\n        mock_conn = Mock()\n        mock_connect.return_value = mock_conn\n\n        # Mock SqliteStore\n        with patch(\"openchatbi.tool.memory.SqliteStore\") as mock_store_class:\n            mock_store = Mock()\n            mock_store_class.return_value = mock_store\n\n            store = get_sync_memory_store()\n\n            assert store == mock_store\n            mock_store_class.assert_called_once()\n            mock_store.setup.assert_called_once()\n\n    @pytest.mark.asyncio\n    @patch(\"openchatbi.tool.memory.AsyncSqliteStore.from_conn_string\")\n    @patch(\"openchatbi.tool.memory.config.get\")\n    async def test_get_async_memory_store(self, mock_config, mock_from_conn_string):\n        \"\"\"Test async memory store creation.\"\"\"\n        mock_config.return_value.embedding_model = Mock()\n\n        # Mock 
the async context manager\n        mock_context_manager = AsyncMock()\n        mock_store = Mock()\n        mock_context_manager.__aenter__.return_value = mock_store\n        mock_from_conn_string.return_value = mock_context_manager\n\n        store = await get_async_memory_store()\n\n        assert store == mock_store\n        mock_from_conn_string.assert_called_once()\n        mock_context_manager.__aenter__.assert_called_once()\n\n    @pytest.mark.asyncio\n    @patch(\"openchatbi.tool.memory.async_memory_store\", new=Mock())\n    @patch(\"openchatbi.tool.memory.async_store_context_manager\")\n    async def test_cleanup_async_memory_store(self, mock_context_manager):\n        \"\"\"Test async memory store cleanup.\"\"\"\n        mock_context_manager.__aexit__ = AsyncMock()\n\n        await cleanup_async_memory_store()\n\n        mock_context_manager.__aexit__.assert_called_once_with(None, None, None)\n\n    @pytest.mark.asyncio\n    @patch(\"openchatbi.tool.memory.get_async_memory_store\")\n    async def test_setup_async_memory_store(self, mock_get_store):\n        \"\"\"Test async memory store setup.\"\"\"\n        mock_store = Mock()\n        mock_get_store.return_value = mock_store\n\n        result = await setup_async_memory_store()\n\n        mock_get_store.assert_called_once()\n        assert result is None\n\n\nclass TestMemoryTools:\n    \"\"\"Test memory tools creation and management.\"\"\"\n\n    @patch(\"openchatbi.tool.memory.create_manage_memory_tool\")\n    @patch(\"openchatbi.tool.memory.create_search_memory_tool\")\n    @patch(\"openchatbi.tool.memory.get_sync_memory_store\")\n    def test_get_memory_tools_sync_mode(self, mock_get_store, mock_search_tool, mock_manage_tool):\n        \"\"\"Test getting memory tools in sync mode.\"\"\"\n        mock_llm = FakeListChatModel(responses=[\"test\"])\n        mock_store = Mock()\n        mock_get_store.return_value = mock_store\n\n        mock_manage = Mock()\n        mock_search = Mock()\n        
mock_manage_tool.return_value = mock_manage\n        mock_search_tool.return_value = mock_search\n\n        memory_tools = get_memory_tools(mock_llm, sync_mode=True)\n        manage_tool, search_tool = memory_tools[0], memory_tools[1]\n\n        assert manage_tool == mock_manage\n        assert search_tool == mock_search\n        mock_manage_tool.assert_called_once_with(namespace=(\"memories\", \"{user_id}\"), store=mock_store)\n        mock_search_tool.assert_called_once_with(namespace=(\"memories\", \"{user_id}\"), store=mock_store)\n\n    @patch(\"openchatbi.tool.memory.create_manage_memory_tool\")\n    @patch(\"openchatbi.tool.memory.create_search_memory_tool\")\n    @patch(\"openchatbi.tool.memory.config.get\")\n    def test_get_memory_tools_with_openai_llm(self, mock_config, mock_search_tool, mock_manage_tool):\n        \"\"\"Test getting memory tools with OpenAI LLM (requires structured tool wrapper).\"\"\"\n        mock_llm = Mock(spec=ChatOpenAI)\n        mock_config.return_value.embedding_model = Mock()\n\n        mock_manage = Mock()\n        mock_search = Mock()\n        mock_manage_tool.return_value = mock_manage\n        mock_search_tool.return_value = mock_search\n\n        with patch(\"openchatbi.tool.memory.StructuredToolWithRequired\") as mock_wrapper:\n            mock_wrapped_manage = Mock()\n            mock_wrapped_search = Mock()\n            mock_wrapper.side_effect = [mock_wrapped_manage, mock_wrapped_search]\n\n            memory_tools = get_memory_tools(mock_llm, sync_mode=True)\n            manage_tool, search_tool = memory_tools[0], memory_tools[1]\n\n            assert manage_tool == mock_wrapped_manage\n            assert search_tool == mock_wrapped_search\n            assert mock_wrapper.call_count == 2\n\n    @pytest.mark.asyncio\n    @patch(\"openchatbi.tool.memory.get_async_memory_store\")\n    @patch(\"openchatbi.tool.memory.get_memory_tools\")\n    @patch(\"openchatbi.tool.memory.config.get\")\n    async def 
test_get_async_memory_tools(self, mock_config, mock_get_tools, mock_get_store):\n        \"\"\"Test getting async memory tools.\"\"\"\n        mock_llm = FakeListChatModel(responses=[\"test\"])\n        mock_store = Mock()\n        mock_get_store.return_value = mock_store\n        mock_config.return_value.embedding_model = Mock()\n\n        mock_manage = Mock()\n        mock_search = Mock()\n        mock_get_tools.return_value = (mock_manage, mock_search)\n\n        manage_tool, search_tool = await get_async_memory_tools(mock_llm)\n\n        assert manage_tool == mock_manage\n        assert search_tool == mock_search\n        mock_get_store.assert_called_once()\n        mock_get_tools.assert_called_once_with(mock_llm, sync_mode=False, store=mock_store)\n\n\nclass TestMemoryManager:\n    \"\"\"Test memory manager functionality.\"\"\"\n\n    @patch(\"openchatbi.tool.memory.create_memory_store_manager\")\n    @patch(\"openchatbi.tool.memory.config.get\")\n    def test_get_memory_manager(self, mock_config, mock_create_manager):\n        \"\"\"Test memory manager creation.\"\"\"\n        mock_llm = Mock()\n        mock_config.return_value.default_llm = mock_llm\n        mock_manager = Mock()\n        mock_create_manager.return_value = mock_manager\n\n        manager = get_memory_manager()\n\n        assert manager == mock_manager\n        mock_create_manager.assert_called_once_with(\n            mock_llm,\n            schemas=[UserProfile],\n            instructions=\"Extract user profile information\",\n            enable_inserts=False,\n        )\n\n    @patch(\"openchatbi.tool.memory.memory_manager\", new=Mock())\n    @patch(\"openchatbi.tool.memory.create_memory_store_manager\")\n    @patch(\"openchatbi.tool.memory.config.get\")\n    def test_get_memory_manager_singleton(self, mock_config, mock_create_manager):\n        \"\"\"Test memory manager singleton behavior.\"\"\"\n        # Reset the global variable for this test\n        import openchatbi.tool.memory as 
memory_module\n\n        existing_manager = Mock()\n        memory_module.memory_manager = existing_manager\n\n        manager = get_memory_manager()\n\n        # Should return existing manager without creating new one\n        assert manager == existing_manager\n        mock_create_manager.assert_not_called()\n\n\nclass TestSchemaFixer:\n    \"\"\"Test schema fixing functionality for OpenAI compatibility.\"\"\"\n\n    def test_fix_schema_for_openai_basic(self):\n        \"\"\"Test basic schema fixing.\"\"\"\n        schema = {\"properties\": {\"field1\": {\"type\": \"string\"}, \"field2\": {\"type\": \"number\"}}}\n\n        fix_schema_for_openai(schema)\n\n        assert schema[\"required\"] == [\"field1\", \"field2\"]\n\n    def test_fix_schema_for_openai_nested_object(self):\n        \"\"\"Test schema fixing with nested objects.\"\"\"\n        schema = {\n            \"properties\": {\n                \"nested\": {\"type\": \"object\", \"additionalProperties\": True, \"properties\": {\"inner\": {\"type\": \"string\"}}}\n            }\n        }\n\n        fix_schema_for_openai(schema)\n\n        assert schema[\"required\"] == [\"nested\"]\n        assert schema[\"properties\"][\"nested\"][\"additionalProperties\"] is False\n\n    def test_fix_schema_for_openai_with_arrays(self):\n        \"\"\"Test schema fixing with array properties.\"\"\"\n        schema = {\"properties\": {\"items\": {\"type\": \"array\", \"items\": {\"type\": \"object\", \"additionalProperties\": True}}}}\n\n        fix_schema_for_openai(schema)\n\n        assert schema[\"required\"] == [\"items\"]\n        assert schema[\"properties\"][\"items\"][\"items\"][\"additionalProperties\"] is False\n\n\nclass TestStructuredToolWithRequired:\n    \"\"\"Test StructuredToolWithRequired wrapper functionality.\"\"\"\n\n    def test_structured_tool_with_required_initialization(self):\n        \"\"\"Test StructuredToolWithRequired initialization.\"\"\"\n        mock_original_tool = Mock()\n        
mock_original_tool.name = \"test_tool\"\n        mock_original_tool.description = \"Test description\"\n        mock_original_tool.args_schema = Mock()\n        mock_original_tool.func = Mock()\n        mock_original_tool.coroutine = None\n\n        with patch(\"openchatbi.tool.memory.StructuredTool.__init__\", return_value=None) as mock_init:\n            wrapper = StructuredToolWithRequired(mock_original_tool)\n\n            # Verify the __init__ was called with correct parameters\n            mock_init.assert_called_once()\n            call_args = mock_init.call_args\n            assert call_args.kwargs[\"name\"] == \"test_tool\"\n            assert call_args.kwargs[\"description\"] == \"Test description\"\n\n    def test_tool_call_schema_property(self):\n        \"\"\"Test tool_call_schema cached property.\"\"\"\n        mock_original_tool = Mock()\n        mock_original_tool.name = \"test_tool\"\n        mock_original_tool.description = \"Test description\"\n        mock_original_tool.args_schema = Mock()\n        mock_original_tool.func = Mock()\n        mock_original_tool.coroutine = None\n\n        with patch(\"openchatbi.tool.memory.StructuredTool.__init__\", return_value=None):\n            wrapper = StructuredToolWithRequired(mock_original_tool)\n\n            # Mock the parent's tool_call_schema\n            mock_tcs = Mock()\n            mock_tcs.model_config = {}\n\n            with patch(\"openchatbi.tool.memory.StructuredTool.tool_call_schema\", new_callable=lambda: mock_tcs):\n                result = wrapper.tool_call_schema\n\n                assert result == mock_tcs\n                assert \"json_schema_extra\" in mock_tcs.model_config\n"
  },
  {
    "path": "tests/test_plotly_utils.py",
    "content": "\"\"\"Tests for plotly utilities in the UI.\"\"\"\n\nimport plotly.graph_objects as go\nimport pytest\n\nfrom sample_ui.plotly_utils import (\n    create_empty_chart,\n    create_plotly_chart,\n    visualization_dsl_to_gradio_plot,\n)\n\n\n@pytest.fixture\ndef sample_csv_data():\n    \"\"\"Sample CSV data for testing.\"\"\"\n    return \"\"\"product,sales,region,month\nWidget A,10000,North,Jan\nWidget B,15000,South,Jan  \nWidget C,8000,East,Jan\nWidget A,12000,North,Feb\nWidget B,18000,South,Feb\nWidget C,9000,East,Feb\"\"\"\n\n\n@pytest.fixture\ndef sample_line_dsl():\n    \"\"\"Sample DSL for line chart.\"\"\"\n    return {\n        \"chart_type\": \"line\",\n        \"data_columns\": [\"month\", \"sales\"],\n        \"config\": {\"x\": \"month\", \"y\": \"sales\", \"mode\": \"lines+markers\"},\n        \"layout\": {\"title\": \"Sales Over Time\", \"xaxis_title\": \"Month\", \"yaxis_title\": \"Sales\"},\n    }\n\n\n@pytest.fixture\ndef sample_bar_dsl():\n    \"\"\"Sample DSL for bar chart.\"\"\"\n    return {\n        \"chart_type\": \"bar\",\n        \"data_columns\": [\"region\", \"sales\"],\n        \"config\": {\"x\": \"region\", \"y\": \"sales\"},\n        \"layout\": {\"title\": \"Sales by Region\", \"xaxis_title\": \"Region\", \"yaxis_title\": \"Sales\"},\n    }\n\n\n@pytest.fixture\ndef sample_pie_dsl():\n    \"\"\"Sample DSL for pie chart.\"\"\"\n    return {\n        \"chart_type\": \"pie\",\n        \"data_columns\": [\"product\", \"sales\"],\n        \"config\": {\"labels\": \"product\", \"values\": \"sales\"},\n        \"layout\": {\"title\": \"Sales Distribution by Product\"},\n    }\n\n\nclass TestPlotlyChartCreation:\n    \"\"\"Tests for individual chart creation functions.\"\"\"\n\n    def test_create_line_chart_success(self, sample_csv_data, sample_line_dsl):\n        \"\"\"Test successful line chart creation.\"\"\"\n        fig = create_plotly_chart(sample_csv_data, sample_line_dsl)\n\n        assert isinstance(fig, go.Figure)\n 
       assert len(fig.data) > 0\n        assert fig.layout.title.text == \"Sales Over Time\"\n\n    def test_create_line_chart_with_color(self, sample_csv_data):\n        \"\"\"Test line chart creation with color parameter for multiple series.\"\"\"\n        multi_series_dsl = {\n            \"chart_type\": \"line\",\n            \"data_columns\": [\"month\", \"sales\", \"product\"],\n            \"config\": {\"x\": \"month\", \"y\": \"sales\", \"color\": \"product\"},\n            \"layout\": {\"title\": \"Sales Over Time by Product\", \"xaxis_title\": \"Month\", \"yaxis_title\": \"Sales\"},\n        }\n\n        fig = create_plotly_chart(sample_csv_data, multi_series_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert len(fig.data) > 0  # Should have multiple traces for different products\n        assert fig.layout.title.text == \"Sales Over Time by Product\"\n\n    def test_create_line_chart_with_multiple_y_columns(self):\n        \"\"\"Test line chart creation with multiple y columns.\"\"\"\n        multi_metric_data = \"\"\"date,revenue,profit,users\n2023-01-01,50000,15000,1000\n2023-02-01,55000,18000,1100\n2023-03-01,60000,20000,1200\"\"\"\n\n        multi_y_dsl = {\n            \"chart_type\": \"line\",\n            \"data_columns\": [\"date\", \"revenue\", \"profit\"],\n            \"config\": {\"x\": \"date\", \"y\": [\"revenue\", \"profit\"]},\n            \"layout\": {\"title\": \"Multiple Metrics Over Time\", \"xaxis_title\": \"Date\", \"yaxis_title\": \"Value\"},\n        }\n\n        fig = create_plotly_chart(multi_metric_data, multi_y_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert len(fig.data) > 0  # Should have multiple traces for different metrics\n        assert fig.layout.title.text == \"Multiple Metrics Over Time\"\n\n    def test_create_bar_chart_success(self, sample_csv_data, sample_bar_dsl):\n        \"\"\"Test successful bar chart creation.\"\"\"\n        fig = create_plotly_chart(sample_csv_data, 
sample_bar_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert len(fig.data) > 0\n        assert fig.layout.title.text == \"Sales by Region\"\n\n    def test_create_pie_chart_success(self, sample_csv_data, sample_pie_dsl):\n        \"\"\"Test successful pie chart creation.\"\"\"\n        fig = create_plotly_chart(sample_csv_data, sample_pie_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert len(fig.data) > 0\n        assert fig.layout.title.text == \"Sales Distribution by Product\"\n\n    def test_create_scatter_chart(self, sample_csv_data):\n        \"\"\"Test scatter chart creation.\"\"\"\n        scatter_dsl = {\n            \"chart_type\": \"scatter\",\n            \"data_columns\": [\"sales\", \"region\"],\n            \"config\": {\"x\": \"sales\", \"y\": \"region\", \"mode\": \"markers\"},\n            \"layout\": {\"title\": \"Sales Scatter Plot\"},\n        }\n\n        fig = create_plotly_chart(sample_csv_data, scatter_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert len(fig.data) > 0\n\n    def test_create_histogram_chart(self, sample_csv_data):\n        \"\"\"Test histogram chart creation.\"\"\"\n        histogram_dsl = {\n            \"chart_type\": \"histogram\",\n            \"data_columns\": [\"sales\"],\n            \"config\": {\"x\": \"sales\", \"nbins\": 10},\n            \"layout\": {\"title\": \"Sales Distribution\"},\n        }\n\n        fig = create_plotly_chart(sample_csv_data, histogram_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert len(fig.data) > 0\n\n    def test_create_box_chart(self, sample_csv_data):\n        \"\"\"Test box chart creation.\"\"\"\n        box_dsl = {\n            \"chart_type\": \"box\",\n            \"data_columns\": [\"sales\", \"region\"],\n            \"config\": {\"y\": \"sales\", \"x\": \"region\"},\n            \"layout\": {\"title\": \"Sales Distribution by Region\"},\n        }\n\n        fig = create_plotly_chart(sample_csv_data, box_dsl)\n\n  
      assert isinstance(fig, go.Figure)\n        assert len(fig.data) > 0\n\n    def test_create_table_chart(self, sample_csv_data):\n        \"\"\"Test table chart creation.\"\"\"\n        table_dsl = {\n            \"chart_type\": \"table\",\n            \"data_columns\": [\"product\", \"sales\", \"region\", \"month\"],\n            \"config\": {\"columns\": [\"product\", \"sales\", \"region\", \"month\"]},\n            \"layout\": {\"title\": \"Data Table\"},\n        }\n\n        fig = create_plotly_chart(sample_csv_data, table_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert len(fig.data) > 0\n        assert fig.data[0].type == \"table\"\n\n\nclass TestErrorHandling:\n    \"\"\"Tests for error handling in chart creation.\"\"\"\n\n    def test_empty_data(self):\n        \"\"\"Test handling of empty data.\"\"\"\n        fig = create_plotly_chart(\"\", {})\n\n        assert isinstance(fig, go.Figure)\n        # Should create an empty chart with error message\n\n    def test_invalid_csv_data(self, sample_bar_dsl):\n        \"\"\"Test handling of invalid CSV data.\"\"\"\n        invalid_csv = \"invalid,csv\\ndata\"\n\n        fig = create_plotly_chart(invalid_csv, sample_bar_dsl)\n\n        assert isinstance(fig, go.Figure)\n        # Should create an empty chart with error message\n\n    def test_missing_columns(self, sample_csv_data):\n        \"\"\"Test handling of missing columns in DSL.\"\"\"\n        invalid_dsl = {\n            \"chart_type\": \"line\",\n            \"data_columns\": [\"nonexistent_col\"],\n            \"config\": {\"x\": \"nonexistent_col\", \"y\": \"another_missing_col\"},\n            \"layout\": {\"title\": \"Invalid Chart\"},\n        }\n\n        fig = create_plotly_chart(sample_csv_data, invalid_dsl)\n\n        assert isinstance(fig, go.Figure)\n        # Should create an empty chart with error message\n\n    def test_unsupported_chart_type(self, sample_csv_data):\n        \"\"\"Test handling of unsupported chart 
types.\"\"\"\n        invalid_dsl = {\"chart_type\": \"unsupported_type\", \"data_columns\": [\"sales\"], \"config\": {}, \"layout\": {}}\n\n        fig = create_plotly_chart(sample_csv_data, invalid_dsl)\n\n        assert isinstance(fig, go.Figure)\n        # Should create an empty chart with error message\n\n    def test_visualization_dsl_error(self):\n        \"\"\"Test handling of DSL with error field.\"\"\"\n        error_dsl = {\"error\": \"Failed to generate visualization\"}\n\n        fig = create_plotly_chart(\"some,data\\n1,2\", error_dsl)\n\n        assert isinstance(fig, go.Figure)\n        # Should create an empty chart with error message\n\n\nclass TestVisualizationDslToGradioPlot:\n    \"\"\"Tests for the main interface function.\"\"\"\n\n    def test_successful_conversion(self, sample_csv_data, sample_line_dsl):\n        \"\"\"Test successful DSL to Gradio plot conversion.\"\"\"\n        fig, description = visualization_dsl_to_gradio_plot(sample_csv_data, sample_line_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert isinstance(description, str)\n        assert \"line\" in description.lower()\n        assert \"Sales Over Time\" in description\n\n    def test_empty_dsl(self, sample_csv_data):\n        \"\"\"Test conversion with empty DSL.\"\"\"\n        fig, description = visualization_dsl_to_gradio_plot(sample_csv_data, {})\n\n        assert isinstance(fig, go.Figure)\n        assert isinstance(description, str)\n        assert \"table\" in description.lower()\n\n    def test_no_data(self, sample_line_dsl):\n        \"\"\"Test conversion with no data.\"\"\"\n        fig, description = visualization_dsl_to_gradio_plot(\"\", sample_line_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert isinstance(description, str)\n\n\nclass TestCreateEmptyChart:\n    \"\"\"Tests for empty chart creation.\"\"\"\n\n    def test_create_empty_chart(self):\n        \"\"\"Test empty chart creation with message.\"\"\"\n        message = \"Test 
error message\"\n        fig = create_empty_chart(message)\n\n        assert isinstance(fig, go.Figure)\n        assert fig.layout.title.text == \"Chart Generation Issue\"\n        # Check if annotation contains the message\n        assert len(fig.layout.annotations) > 0\n        assert fig.layout.annotations[0].text == message\n\n\n@pytest.fixture\ndef sample_time_series_data():\n    \"\"\"Sample time series data for testing.\"\"\"\n    return \"\"\"date,revenue,users\n2023-01-01,50000,1000\n2023-02-01,55000,1100\n2023-03-01,60000,1200\n2023-04-01,52000,1050\n2023-05-01,58000,1150\"\"\"\n\n\nclass TestIntegrationScenarios:\n    \"\"\"Integration tests for complete visualization scenarios.\"\"\"\n\n    def test_sales_dashboard_scenario(self, sample_csv_data):\n        \"\"\"Test a complete sales dashboard scenario.\"\"\"\n        # Test multiple chart types with the same data\n        chart_configs = [\n            {\"chart_type\": \"bar\", \"config\": {\"x\": \"product\", \"y\": \"sales\"}, \"layout\": {\"title\": \"Sales by Product\"}},\n            {\n                \"chart_type\": \"pie\",\n                \"config\": {\"labels\": \"region\", \"values\": \"sales\"},\n                \"layout\": {\"title\": \"Sales by Region\"},\n            },\n        ]\n\n        for config in chart_configs:\n            config[\"data_columns\"] = list(config[\"config\"].values())\n            fig = create_plotly_chart(sample_csv_data, config)\n            assert isinstance(fig, go.Figure)\n            assert len(fig.data) > 0\n\n    def test_time_series_scenario(self, sample_time_series_data):\n        \"\"\"Test time series visualization scenario.\"\"\"\n        line_dsl = {\n            \"chart_type\": \"line\",\n            \"data_columns\": [\"date\", \"revenue\"],\n            \"config\": {\"x\": \"date\", \"y\": \"revenue\", \"mode\": \"lines+markers\"},\n            \"layout\": {\"title\": \"Revenue Trend Over Time\", \"xaxis_title\": \"Date\", \"yaxis_title\": 
\"Revenue\"},\n        }\n\n        fig, description = visualization_dsl_to_gradio_plot(sample_time_series_data, line_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert \"line\" in description.lower()\n        assert \"Revenue Trend Over Time\" in description\n\n    def test_multiple_metrics_scenario(self, sample_time_series_data):\n        \"\"\"Test scenario with multiple metrics.\"\"\"\n        scatter_dsl = {\n            \"chart_type\": \"scatter\",\n            \"data_columns\": [\"revenue\", \"users\"],\n            \"config\": {\"x\": \"revenue\", \"y\": \"users\", \"mode\": \"markers\"},\n            \"layout\": {\"title\": \"Revenue vs Users Correlation\", \"xaxis_title\": \"Revenue\", \"yaxis_title\": \"Users\"},\n        }\n\n        fig = create_plotly_chart(sample_time_series_data, scatter_dsl)\n\n        assert isinstance(fig, go.Figure)\n        assert len(fig.data) > 0\n        assert fig.layout.title.text == \"Revenue vs Users Correlation\"\n"
  },
  {
    "path": "tests/test_simple_store.py",
    "content": "\"\"\"Unit tests for SimpleStore.\"\"\"\n\nimport pytest\n\nfrom openchatbi.utils import SimpleStore\n\n\nclass TestSimpleStore:\n    \"\"\"Test suite for SimpleStore class.\"\"\"\n\n    @pytest.fixture\n    def sample_texts(self):\n        \"\"\"Sample texts for testing.\"\"\"\n        return [\n            \"Python is a programming language\",\n            \"Machine learning is a subset of AI\",\n            \"Deep learning uses neural networks\",\n            \"Natural language processing works with text\",\n        ]\n\n    @pytest.fixture\n    def sample_metadatas(self):\n        \"\"\"Sample metadata for testing.\"\"\"\n        return [\n            {\"category\": \"programming\"},\n            {\"category\": \"ai\"},\n            {\"category\": \"ai\"},\n            {\"category\": \"nlp\"},\n        ]\n\n    @pytest.fixture\n    def simple_store(self, sample_texts):\n        \"\"\"Create a SimpleStore instance for testing.\"\"\"\n        return SimpleStore(sample_texts)\n\n    def test_initialization_basic(self, sample_texts):\n        \"\"\"Test basic initialization.\"\"\"\n        store = SimpleStore(sample_texts)\n\n        assert len(store.texts) == len(sample_texts)\n        assert store.texts == sample_texts\n        assert len(store.documents) == len(sample_texts)\n        assert store.bm25 is not None\n\n    def test_initialization_with_metadata_and_ids(self, sample_texts, sample_metadatas):\n        \"\"\"Test initialization with metadata and custom IDs.\"\"\"\n        ids = [\"id1\", \"id2\", \"id3\", \"id4\"]\n        store = SimpleStore(sample_texts, sample_metadatas, ids)\n\n        assert store.texts == sample_texts\n        assert store.metadatas == sample_metadatas\n        assert store.ids == ids\n        # Check documents are created correctly\n        for doc, text, meta, doc_id in zip(store.documents, sample_texts, sample_metadatas, ids):\n            assert doc.page_content == text\n            assert doc.metadata == 
meta\n            assert doc.id == doc_id\n\n    def test_similarity_search(self, simple_store):\n        \"\"\"Test similarity search functionality.\"\"\"\n        query = \"programming\"\n        results = simple_store.similarity_search(query, k=2)\n\n        assert len(results) == 2\n        assert \"Python\" in results[0].page_content\n\n        # Test k parameter\n        results = simple_store.similarity_search(query, k=10)\n        assert len(results) == 4  # Should return all documents\n\n    def test_similarity_search_with_score(self, simple_store):\n        \"\"\"Test similarity search with scores.\"\"\"\n        query = \"programming\"\n        results = simple_store.similarity_search_with_score(query, k=2)\n\n        assert len(results) == 2\n        for doc, score in results:\n            assert hasattr(doc, \"page_content\")\n            assert isinstance(score, (int, float))\n            assert score >= 0\n\n        # Scores should be in descending order\n        scores = [score for _, score in results]\n        assert scores == sorted(scores, reverse=True)\n\n    def test_empty_store(self):\n        \"\"\"Test empty store operations.\"\"\"\n        store = SimpleStore([])\n\n        assert store.bm25 is None\n        assert store.similarity_search(\"test\", k=5) == []\n        assert store.similarity_search_with_score(\"test\", k=5) == []\n\n    def test_add_texts(self, simple_store):\n        \"\"\"Test adding texts with and without metadata.\"\"\"\n        initial_count = len(simple_store.texts)\n        new_texts = [\"Data science is important\", \"Statistics is fundamental\"]\n        new_metadatas = [{\"type\": \"test\"}, {\"type\": \"example\"}]\n\n        # Add with metadata and custom IDs\n        custom_ids = [\"custom_1\", \"custom_2\"]\n        returned_ids = simple_store.add_texts(new_texts, metadatas=new_metadatas, ids=custom_ids)\n\n        assert returned_ids == custom_ids\n        assert len(simple_store.texts) == initial_count + 
len(new_texts)\n        assert all(text in simple_store.texts for text in new_texts)\n\n        # Check metadata was added correctly\n        added_docs = [doc for doc in simple_store.documents if doc.id in custom_ids]\n        assert len(added_docs) == 2\n        assert added_docs[0].metadata == {\"type\": \"test\"}\n\n        # Verify BM25 index is updated\n        results = simple_store.similarity_search(\"data science\", k=1)\n        assert \"data\" in results[0].page_content.lower() or \"science\" in results[0].page_content.lower()\n\n    def test_delete(self):\n        \"\"\"Test deleting documents.\"\"\"\n        texts = [\"Text A\", \"Text B\", \"Text C\", \"Text D\"]\n        ids = [\"id1\", \"id2\", \"id3\", \"id4\"]\n        store = SimpleStore(texts, ids=ids)\n\n        # Delete specific IDs\n        result = store.delete([\"id2\", \"id3\"])\n        assert result is True\n        assert len(store.texts) == 2\n        assert store.texts == [\"Text A\", \"Text D\"]\n        assert store.ids == [\"id1\", \"id4\"]\n\n        # Delete non-existent IDs\n        result = store.delete([\"nonexistent\"])\n        assert result is False\n\n        # Delete with None\n        result = store.delete(None)\n        assert result is False\n\n        # Delete all remaining documents\n        result = store.delete([\"id1\", \"id4\"])\n        assert result is True\n        assert len(store.texts) == 0\n        assert store.bm25 is None\n\n    def test_get_by_ids(self, sample_texts):\n        \"\"\"Test retrieving documents by IDs.\"\"\"\n        ids = [\"id1\", \"id2\", \"id3\", \"id4\"]\n        store = SimpleStore(sample_texts, ids=ids)\n\n        # Get existing IDs\n        docs = store.get_by_ids([\"id1\", \"id3\"])\n        assert len(docs) == 2\n        assert docs[0].id == \"id1\"\n        assert docs[0].page_content == sample_texts[0]\n\n        # Get non-existent IDs\n        docs = store.get_by_ids([\"nonexistent\"])\n        assert len(docs) == 0\n\n        
# Mixed existent and non-existent\n        docs = store.get_by_ids([\"id1\", \"nonexistent\", \"id3\"])\n        assert len(docs) == 2\n\n    def test_from_texts(self, sample_texts, sample_metadatas):\n        \"\"\"Test creating store using from_texts class method.\"\"\"\n        ids = [\"id1\", \"id2\", \"id3\", \"id4\"]\n        store = SimpleStore.from_texts(sample_texts, embedding=None, metadatas=sample_metadatas, ids=ids)\n\n        assert isinstance(store, SimpleStore)\n        assert store.texts == sample_texts\n        assert store.metadatas == sample_metadatas\n        assert store.ids == ids\n\n    def test_as_retriever(self, simple_store):\n        \"\"\"Test creating a retriever from the store.\"\"\"\n        retriever = simple_store.as_retriever(search_kwargs={\"k\": 2})\n\n        results = retriever.invoke(\"programming\")\n        assert len(results) <= 2\n        assert all(hasattr(doc, \"page_content\") for doc in results)\n\n    def test_chinese_and_mixed_language(self):\n        \"\"\"Test search with Chinese and mixed language texts.\"\"\"\n        from openchatbi.text_segmenter import _jieba_available\n\n        mixed_texts = [\n            \"Python programming language\",\n            \"机器学习很重要\",\n            \"Deep learning neural networks\",\n            \"数据科学分析\",\n        ]\n        store = SimpleStore(mixed_texts)\n\n        # Search in English\n        en_results = store.similarity_search(\"programming\", k=1)\n        assert \"Python\" in en_results[0].page_content\n\n        # Search in Chinese - result depends on jieba availability\n        cn_results = store.similarity_search(\"机器学习\", k=2)\n        assert len(cn_results) > 0\n\n        # If jieba is available, expect better Chinese matching\n        if _jieba_available:\n            assert \"机器学习\" in cn_results[0].page_content\n        else:\n            # Without jieba, just verify results are returned\n            # (Chinese text may not be perfectly tokenized)\n            
assert any(\"机器学习\" in doc.page_content for doc in cn_results) or any(\n                \"数据科学\" in doc.page_content for doc in cn_results\n            )\n\n    def test_max_marginal_relevance_search(self, simple_store):\n        \"\"\"Test max_marginal_relevance_search method.\"\"\"\n        query = \"programming language\"\n\n        # Test basic MMR search\n        results = simple_store.max_marginal_relevance_search(query, k=2, fetch_k=4, lambda_mult=0.5)\n        assert len(results) == 2\n        assert all(hasattr(doc, \"page_content\") for doc in results)\n\n        # Test relevance-focused search (lambda_mult = 1.0)\n        results_relevant = simple_store.max_marginal_relevance_search(query, k=3, fetch_k=4, lambda_mult=1.0)\n        assert len(results_relevant) == 3\n\n        # Test diversity-focused search (lambda_mult = 0.0)\n        results_diverse = simple_store.max_marginal_relevance_search(query, k=3, fetch_k=4, lambda_mult=0.0)\n        assert len(results_diverse) == 3\n\n        # Verify different lambda values produce different results\n        # (unless there are ties in scoring)\n        assert len(results_relevant) == len(results_diverse)\n\n        # Test with k >= fetch_k\n        results = simple_store.max_marginal_relevance_search(query, k=5, fetch_k=3, lambda_mult=0.5)\n        assert len(results) == 3  # Should return fetch_k documents\n\n        # Test empty query\n        results = simple_store.max_marginal_relevance_search(\"\", k=2)\n        assert len(results) <= 2\n\n        # Test empty store\n        empty_store = SimpleStore([])\n        results = empty_store.max_marginal_relevance_search(query, k=2)\n        assert results == []\n\n    def test_calculate_similarity(self, simple_store):\n        \"\"\"Test _calculate_similarity method.\"\"\"\n        # Get two documents\n        doc1 = simple_store.documents[0]  # \"Python is a programming language\"\n        doc2 = simple_store.documents[1]  # \"Machine learning is a subset of 
AI\"\n        doc3 = simple_store.documents[0]  # Same as doc1\n\n        # Test similarity between different documents\n        similarity_diff = simple_store._calculate_similarity(doc1, doc2)\n        assert 0.0 <= similarity_diff <= 1.0\n\n        # Test similarity between identical documents\n        similarity_same = simple_store._calculate_similarity(doc1, doc3)\n        assert similarity_same == 1.0\n\n        # Test with empty documents\n        from langchain_core.documents import Document\n\n        empty_doc1 = Document(page_content=\"\", metadata={})\n        empty_doc2 = Document(page_content=\"\", metadata={})\n        similarity_empty = simple_store._calculate_similarity(empty_doc1, empty_doc2)\n        assert similarity_empty == 0.0  # Empty sets have 0 Jaccard similarity\n"
  },
  {
    "path": "tests/test_text2sql_extraction.py",
    "content": "\"\"\"Tests for text2sql information extraction functionality.\"\"\"\n\nimport json\nfrom datetime import date\nfrom unittest.mock import Mock, patch\n\nfrom langchain_core.messages import AIMessage, HumanMessage\n\nfrom openchatbi.graph_state import SQLGraphState\nfrom openchatbi.text2sql.extraction import (\n    generate_extraction_prompt,\n    information_extraction,\n    information_extraction_conditional_edges,\n    parse_extracted_info_json,\n)\n\n\nclass TestText2SQLExtraction:\n    \"\"\"Test text2sql information extraction functionality.\"\"\"\n\n    def test_generate_extraction_prompt(self):\n        \"\"\"Test extraction prompt generation.\"\"\"\n        prompt = generate_extraction_prompt()\n\n        # Should replace time placeholder with today's date\n        today_str = date.today().strftime(\"%Y-%m-%d\")\n        assert today_str in prompt\n\n        # Should contain basic knowledge\n        assert \"[basic_knowledge_glossary]\" not in prompt\n        assert \"[time_field_placeholder]\" not in prompt\n\n    def test_parse_extracted_info_json_valid(self):\n        \"\"\"Test parsing valid JSON from LLM response.\"\"\"\n        json_response = {\n            \"keywords\": [\"revenue\", \"sales\"],\n            \"dimensions\": [\"date\", \"region\"],\n            \"metrics\": [\"total_revenue\"],\n            \"filters\": [],\n        }\n\n        # Mock LLM response with JSON\n        llm_content = f\"```json\\n{json.dumps(json_response)}\\n```\"\n\n        with patch(\"openchatbi.text2sql.extraction.get_text_from_content\", return_value=llm_content):\n            with patch(\"openchatbi.text2sql.extraction.extract_json_from_answer\", return_value=json_response):\n                result = parse_extracted_info_json(llm_content)\n\n        assert result == json_response\n        assert \"keywords\" in result\n        assert \"dimensions\" in result\n\n    def test_parse_extracted_info_json_invalid(self):\n        \"\"\"Test parsing 
invalid JSON returns empty dict.\"\"\"\n        invalid_content = \"Not valid JSON content\"\n\n        with patch(\"openchatbi.text2sql.extraction.get_text_from_content\", return_value=invalid_content):\n            with patch(\"openchatbi.text2sql.extraction.extract_json_from_answer\", side_effect=Exception(\"Parse error\")):\n                result = parse_extracted_info_json(invalid_content)\n\n        assert result == {}\n\n    def test_information_extraction_function_creation(self):\n        \"\"\"Test creating information extraction function.\"\"\"\n        mock_llm = Mock()\n\n        extraction_func = information_extraction(mock_llm)\n\n        # Should return a callable function\n        assert callable(extraction_func)\n\n    def test_information_extraction_successful(self):\n        \"\"\"Test successful information extraction.\"\"\"\n        mock_llm = Mock()\n\n        # Mock LLM response\n        extracted_info = {\n            \"rewrite_question\": \"What is the total revenue by region?\",\n            \"keywords\": [\"revenue\", \"total\"],\n            \"dimensions\": [\"region\"],\n            \"metrics\": [\"revenue\"],\n            \"filters\": [],\n        }\n\n        mock_response = AIMessage(content=json.dumps(extracted_info))\n\n        with patch(\"openchatbi.text2sql.extraction.call_llm_chat_model_with_retry\", return_value=mock_response):\n            with patch(\"openchatbi.text2sql.extraction.parse_extracted_info_json\", return_value=extracted_info):\n                extraction_func = information_extraction(mock_llm)\n\n                state = SQLGraphState(\n                    messages=[HumanMessage(content=\"Show me revenue by region\")], question=\"Show me revenue by region\"\n                )\n\n                result = extraction_func(state)\n\n        assert \"info_entities\" in result\n        assert result[\"rewrite_question\"] == \"What is the total revenue by region?\"\n\n    def 
test_information_extraction_empty_response(self):\n        \"\"\"Test handling empty extraction response.\"\"\"\n        mock_llm = Mock()\n\n        mock_response = AIMessage(content=\"\")\n\n        with patch(\"openchatbi.text2sql.extraction.call_llm_chat_model_with_retry\", return_value=mock_response):\n            with patch(\"openchatbi.text2sql.extraction.parse_extracted_info_json\", return_value={}):\n                extraction_func = information_extraction(mock_llm)\n\n                state = SQLGraphState(messages=[HumanMessage(content=\"Test question\")], question=\"Test question\")\n\n                result = extraction_func(state)\n\n        # Should handle empty response gracefully\n        assert \"info_entities\" in result\n        assert result[\"info_entities\"] == {}\n\n    def test_information_extraction_conditional_edges_success(self):\n        \"\"\"Test conditional edges with successful extraction.\"\"\"\n        state = SQLGraphState(\n            messages=[HumanMessage(content=\"Test question\")],\n            question=\"Test question\",\n            rewrite_question=\"What is the total revenue by region?\",\n            info_entities={\"keywords\": [\"revenue\"], \"dimensions\": [\"date\"]},\n        )\n\n        result = information_extraction_conditional_edges(state)\n\n        # Should proceed to next when rewrite_question exists\n        assert result == \"next\"\n\n    def test_information_extraction_conditional_edges_failure(self):\n        \"\"\"Test conditional edges with failed extraction.\"\"\"\n        state = SQLGraphState(\n            messages=[HumanMessage(content=\"Test question\")], question=\"Test question\", info_entities={}\n        )\n\n        result = information_extraction_conditional_edges(state)\n\n        # Should end when no info extracted\n        assert result == \"end\"\n\n    def test_information_extraction_conditional_edges_missing(self):\n        \"\"\"Test conditional edges with missing 
info_entities.\"\"\"\n        state = SQLGraphState(messages=[HumanMessage(content=\"Test question\")], question=\"Test question\")\n\n        result = information_extraction_conditional_edges(state)\n\n        # Should end when info_entities not present\n        assert result == \"end\"\n\n    def test_information_extraction_with_retry_on_failure(self):\n        \"\"\"Test information extraction with retry mechanism.\"\"\"\n        mock_llm = Mock()\n\n        # First call fails, second succeeds\n        extracted_info = {\n            \"rewrite_question\": \"Test question\",\n            \"keywords\": [\"test\"],\n            \"dimensions\": [],\n            \"metrics\": [],\n            \"filters\": [],\n        }\n\n        mock_response = AIMessage(content=json.dumps(extracted_info))\n\n        with patch(\"openchatbi.text2sql.extraction.call_llm_chat_model_with_retry\", return_value=mock_response):\n            with patch(\"openchatbi.text2sql.extraction.parse_extracted_info_json\", return_value=extracted_info):\n                extraction_func = information_extraction(mock_llm)\n\n                state = SQLGraphState(messages=[HumanMessage(content=\"Test question\")], question=\"Test question\")\n\n                result = extraction_func(state)\n\n        assert \"info_entities\" in result\n        assert result[\"info_entities\"][\"keywords\"] == [\"test\"]\n\n    def test_information_extraction_time_period_detection(self):\n        \"\"\"Test time period detection in queries.\"\"\"\n        mock_llm = Mock()\n\n        extracted_info = {\n            \"rewrite_question\": \"Show data for the last 7 days\",\n            \"keywords\": [\"data\"],\n            \"dimensions\": [\"date\"],\n            \"metrics\": [],\n            \"filters\": [],\n            \"start_time\": \"2024-01-01\",\n        }\n\n        mock_response = AIMessage(content=json.dumps(extracted_info))\n\n        with 
patch(\"openchatbi.text2sql.extraction.call_llm_chat_model_with_retry\", return_value=mock_response):\n            with patch(\"openchatbi.text2sql.extraction.parse_extracted_info_json\", return_value=extracted_info):\n                extraction_func = information_extraction(mock_llm)\n\n                state = SQLGraphState(\n                    messages=[HumanMessage(content=\"Test question\")], question=\"Show data for last 7 days\"\n                )\n\n                result = extraction_func(state)\n\n        assert \"info_entities\" in result\n        assert \"start_time\" in result[\"info_entities\"]\n\n    def test_information_extraction_error_handling(self):\n        \"\"\"Test error handling in information extraction.\"\"\"\n        mock_llm = Mock()\n\n        # Mock call to raise exception\n        with patch(\"openchatbi.text2sql.extraction.call_llm_chat_model_with_retry\", side_effect=Exception(\"LLM error\")):\n            extraction_func = information_extraction(mock_llm)\n\n            state = SQLGraphState(messages=[HumanMessage(content=\"Test question\")], question=\"Test question\")\n\n            # Should raise exception as the function doesn't have try-catch\n            try:\n                result = extraction_func(state)\n                # Should not reach here\n                assert False, \"Expected exception to be raised\"\n            except Exception as e:\n                assert \"LLM error\" in str(e)\n"
  },
  {
    "path": "tests/test_text2sql_generate_sql.py",
    "content": "\"\"\"Tests for text2sql SQL generation functionality.\"\"\"\n\nfrom unittest.mock import Mock, patch\n\nimport pytest\nfrom langchain_core.messages import AIMessage\n\nfrom openchatbi.graph_state import SQLGraphState\nfrom openchatbi.text2sql.generate_sql import create_sql_nodes, should_execute_sql, should_retry_sql\n\n\nclass TestText2SQLGenerateSQL:\n    \"\"\"Test text2sql SQL generation functionality.\"\"\"\n\n    @pytest.fixture\n    def mock_llm(self):\n        \"\"\"Mock LLM for testing.\"\"\"\n        llm = Mock()\n        llm.invoke.return_value = AIMessage(content=\"SELECT * FROM users\")\n        return llm\n\n    @pytest.fixture\n    def mock_catalog(self):\n        \"\"\"Mock catalog store for testing.\"\"\"\n        catalog = Mock()\n        catalog.get_table_information.return_value = {\n            \"description\": \"User data table\",\n            \"sql_rule\": \"\",\n            \"derived_metric\": \"\",\n        }\n        catalog.get_column_list.return_value = [\n            {\n                \"column_name\": \"user_id\",\n                \"type\": \"bigint\",\n                \"display_name\": \"User ID\",\n                \"description\": \"Unique user identifier\",\n                \"alias\": \"\",\n            }\n        ]\n        # Mock SQL engine with proper context manager\n        mock_engine = Mock()\n        mock_connection = Mock()\n        mock_result = Mock()\n        mock_result.fetchall.return_value = [(\"1\", \"John\"), (\"2\", \"Jane\")]\n        mock_result.keys.return_value = [\"id\", \"name\"]\n        mock_connection.execute.return_value = mock_result\n\n        # Create a proper context manager mock using MagicMock\n        from unittest.mock import MagicMock\n\n        mock_context_manager = MagicMock()\n        mock_context_manager.__enter__.return_value = mock_connection\n        mock_context_manager.__exit__.return_value = None\n        mock_engine.connect.return_value = mock_context_manager\n\n       
 catalog.get_sql_engine.return_value = mock_engine\n\n        return catalog\n\n    def test_create_sql_nodes(self, mock_llm, mock_catalog):\n        \"\"\"Test creating SQL processing nodes.\"\"\"\n        generate_node, execute_node, regenerate_node, visualization_node = create_sql_nodes(\n            mock_llm, mock_catalog, \"presto\"\n        )\n\n        assert callable(generate_node)\n        assert callable(execute_node)\n        assert callable(regenerate_node)\n        assert callable(visualization_node)\n\n    def test_generate_sql_node_success(self, mock_llm, mock_catalog):\n        \"\"\"Test successful SQL generation.\"\"\"\n        generate_node, _, _, _ = create_sql_nodes(mock_llm, mock_catalog, \"presto\")\n\n        state = SQLGraphState(\n            messages=[],\n            question=\"Show all users\",\n            rewrite_question=\"Show all users\",\n            tables=[{\"table\": \"users\", \"columns\": []}],\n        )\n\n        with patch(\"openchatbi.text2sql.generate_sql.sql_example_retriever\") as mock_retriever:\n            mock_retriever.invoke.return_value = []\n\n            result = generate_node(state)\n\n        assert \"sql\" in result\n        assert result[\"sql\"] == \"SELECT * FROM users\"\n\n    def test_generate_sql_node_missing_rewrite_question(self, mock_llm, mock_catalog):\n        \"\"\"Test SQL generation with missing rewrite question.\"\"\"\n        generate_node, _, _, _ = create_sql_nodes(mock_llm, mock_catalog, \"presto\")\n\n        state = SQLGraphState(\n            messages=[],\n            question=\"Show all users\",\n            # Missing rewrite_question\n        )\n\n        result = generate_node(state)\n        assert result == {}\n\n    def test_generate_sql_node_missing_tables(self, mock_llm, mock_catalog):\n        \"\"\"Test SQL generation with missing tables.\"\"\"\n        generate_node, _, _, _ = create_sql_nodes(mock_llm, mock_catalog, \"presto\")\n\n        state = SQLGraphState(\n            
messages=[], question=\"Show all users\", rewrite_question=\"Show all users\", tables=[]  # Empty tables\n        )\n\n        result = generate_node(state)\n        assert result == {}\n\n    def test_execute_sql_node_success(self, mock_llm, mock_catalog):\n        \"\"\"Test successful SQL execution.\"\"\"\n        _, execute_node, _, _ = create_sql_nodes(mock_llm, mock_catalog, \"presto\")\n\n        state = SQLGraphState(messages=[], sql=\"SELECT * FROM users\")\n\n        result = execute_node(state)\n\n        assert \"sql_execution_result\" in result\n        from openchatbi.constants import SQL_SUCCESS\n\n        assert result[\"sql_execution_result\"] == SQL_SUCCESS\n        assert \"data\" in result\n\n    def test_execute_sql_node_empty_sql(self, mock_llm, mock_catalog):\n        \"\"\"Test SQL execution with empty SQL.\"\"\"\n        _, execute_node, _, _ = create_sql_nodes(mock_llm, mock_catalog, \"presto\")\n\n        state = SQLGraphState(messages=[], sql=\"\")  # Empty SQL\n\n        result = execute_node(state)\n\n        assert \"sql_execution_result\" in result\n        from openchatbi.constants import SQL_NA\n\n        assert result[\"sql_execution_result\"] == SQL_NA\n\n    def test_execute_sql_node_syntax_error(self, mock_llm, mock_catalog):\n        \"\"\"Test SQL execution with syntax error.\"\"\"\n        _, execute_node, _, _ = create_sql_nodes(mock_llm, mock_catalog, \"presto\")\n\n        # Mock SQL execution to raise syntax error\n        mock_engine = mock_catalog.get_sql_engine.return_value\n        mock_connection = mock_engine.connect.return_value.__enter__.return_value\n        from sqlalchemy.exc import ProgrammingError\n\n        mock_connection.execute.side_effect = ProgrammingError(\"\", \"\", \"Syntax error\")\n\n        state = SQLGraphState(messages=[], sql=\"SELECT * FRON users\")  # Intentional syntax error\n\n        result = execute_node(state)\n\n        assert \"sql_execution_result\" in result\n        from 
openchatbi.constants import SQL_SYNTAX_ERROR\n\n        assert result[\"sql_execution_result\"] == SQL_SYNTAX_ERROR\n        assert \"previous_sql_errors\" in result\n\n    def test_regenerate_sql_node_success(self, mock_llm, mock_catalog):\n        \"\"\"Test successful SQL regeneration.\"\"\"\n        _, _, regenerate_node, _ = create_sql_nodes(mock_llm, mock_catalog, \"presto\")\n\n        state = SQLGraphState(\n            messages=[],\n            question=\"Show all users\",\n            rewrite_question=\"Show all users\",\n            tables=[{\"table\": \"users\", \"columns\": []}],\n            previous_sql_errors=[\n                {\"sql\": \"SELECT * FRON users\", \"error\": \"Syntax error: FRON\", \"error_type\": \"SQL syntax error\"}\n            ],\n            sql_retry_count=1,\n        )\n\n        with patch(\"openchatbi.text2sql.generate_sql.sql_example_retriever\") as mock_retriever:\n            mock_retriever.invoke.return_value = []\n\n            result = regenerate_node(state)\n\n        assert \"sql\" in result\n        assert \"sql_retry_count\" in result\n        assert result[\"sql_retry_count\"] == 2\n\n    def test_should_retry_sql_success(self):\n        \"\"\"Test retry decision with successful execution.\"\"\"\n        # Import the constant from the module\n        from openchatbi.constants import SQL_SUCCESS\n\n        state = SQLGraphState(sql_execution_result=SQL_SUCCESS, sql_retry_count=1)\n\n        result = should_retry_sql(state)\n        assert result == \"end\"\n\n    def test_should_retry_sql_timeout(self):\n        \"\"\"Test retry decision with timeout.\"\"\"\n        # Import the constant from the module\n        from openchatbi.constants import SQL_EXECUTE_TIMEOUT\n\n        state = SQLGraphState(sql_execution_result=SQL_EXECUTE_TIMEOUT, sql_retry_count=1)\n\n        result = should_retry_sql(state)\n        assert result == \"end\"\n\n    def test_should_retry_sql_retry_needed(self):\n        \"\"\"Test retry 
decision when retry is needed.\"\"\"\n        state = SQLGraphState(sql_execution_result=\"SYNTAX_ERROR\", sql_retry_count=1)\n\n        result = should_retry_sql(state)\n        assert result == \"regenerate_sql\"\n\n    def test_should_retry_sql_max_retries_reached(self):\n        \"\"\"Test retry decision when max retries reached.\"\"\"\n        state = SQLGraphState(sql_execution_result=\"SYNTAX_ERROR\", sql_retry_count=3)\n\n        result = should_retry_sql(state)\n        assert result == \"end\"\n\n    def test_should_execute_sql_with_sql(self):\n        \"\"\"Test execute decision with SQL present.\"\"\"\n        state = SQLGraphState(sql=\"SELECT * FROM users\")\n\n        result = should_execute_sql(state)\n        assert result == \"execute_sql\"\n\n    def test_should_execute_sql_without_sql(self):\n        \"\"\"Test execute decision without SQL.\"\"\"\n        state = SQLGraphState(sql=\"\")\n\n        result = should_execute_sql(state)\n        assert result == \"end\"\n\n    def test_sql_generation_with_examples(self, mock_llm, mock_catalog):\n        \"\"\"Test SQL generation with relevant examples.\"\"\"\n        generate_node, _, _, _ = create_sql_nodes(mock_llm, mock_catalog, \"presto\")\n\n        state = SQLGraphState(\n            messages=[],\n            question=\"Show user count\",\n            rewrite_question=\"Show user count\",\n            tables=[{\"table\": \"users\", \"columns\": []}],\n        )\n\n        # Mock example retrieval\n        mock_document = Mock()\n        mock_document.page_content = \"How many users are there?\"\n\n        with patch(\"openchatbi.text2sql.generate_sql.sql_example_retriever\") as mock_retriever:\n            mock_retriever.invoke.return_value = [mock_document]\n\n            with patch(\n                \"openchatbi.text2sql.generate_sql.sql_example_dicts\",\n                {\"How many users are there?\": (\"SELECT COUNT(*) FROM users\", [\"users\"])},\n            ):\n                result = 
generate_node(state)\n\n        assert \"sql\" in result\n\n    def test_sql_error_handling_database_error(self, mock_llm, mock_catalog):\n        \"\"\"Test handling of database connection errors.\"\"\"\n        _, execute_node, _, _ = create_sql_nodes(mock_llm, mock_catalog, \"presto\")\n\n        # Mock database connection error\n        mock_engine = mock_catalog.get_sql_engine.return_value\n        mock_connection = mock_engine.connect.return_value.__enter__.return_value\n        from sqlalchemy.exc import OperationalError\n\n        mock_connection.execute.side_effect = OperationalError(\"\", \"\", \"Connection failed\")\n\n        state = SQLGraphState(messages=[], sql=\"SELECT * FROM users\")\n\n        result = execute_node(state)\n\n        assert \"sql_execution_result\" in result\n        from openchatbi.constants import SQL_EXECUTE_TIMEOUT\n\n        assert result[\"sql_execution_result\"] == SQL_EXECUTE_TIMEOUT\n\n    def test_regenerate_sql_empty_response(self, mock_llm, mock_catalog):\n        \"\"\"Test regeneration with empty LLM response.\"\"\"\n        mock_llm.invoke.return_value = AIMessage(content=\"\")\n\n        _, _, regenerate_node, _ = create_sql_nodes(mock_llm, mock_catalog, \"presto\")\n\n        state = SQLGraphState(\n            messages=[],\n            question=\"Show all users\",\n            rewrite_question=\"Show all users\",\n            tables=[{\"table\": \"users\", \"columns\": []}],\n            previous_sql_errors=[],\n            sql_retry_count=1,\n        )\n\n        with patch(\"openchatbi.text2sql.generate_sql.sql_example_retriever\") as mock_retriever:\n            mock_retriever.invoke.return_value = []\n\n            result = regenerate_node(state)\n\n        assert \"sql\" in result\n        assert result[\"sql\"] == \"\"\n        from openchatbi.constants import SQL_NA\n\n        assert result[\"sql_execution_result\"] == SQL_NA\n"
  },
  {
    "path": "tests/test_text2sql_schema_linking.py",
    "content": "\"\"\"Tests for text2sql schema linking functionality.\"\"\"\n\nfrom unittest.mock import Mock, patch\n\nimport pytest\nfrom langchain_core.messages import AIMessage\n\nfrom openchatbi.graph_state import SQLGraphState\nfrom openchatbi.text2sql.schema_linking import schema_linking\n\n\nclass TestText2SQLSchemaLinking:\n    \"\"\"Test text2sql schema linking functionality.\"\"\"\n\n    @pytest.fixture\n    def mock_llm(self):\n        \"\"\"Mock LLM for testing.\"\"\"\n        llm = Mock()\n        llm.invoke.return_value = AIMessage(content='{\"tables\": [{\"table\": \"users\", \"reason\": \"Contains user data\"}]}')\n        return llm\n\n    @pytest.fixture\n    def mock_catalog(self):\n        \"\"\"Mock catalog store for testing.\"\"\"\n        catalog = Mock()\n        catalog.get_table_information.return_value = {\n            \"description\": \"User data table\",\n            \"selection_rule\": \"Use for user-related queries\",\n        }\n        return catalog\n\n    def test_select_table_function_creation(self, mock_llm, mock_catalog):\n        \"\"\"Test creating table selection function.\"\"\"\n        select_func = schema_linking(mock_llm, mock_catalog)\n\n        assert callable(select_func)\n\n    def test_select_table_success(self, mock_llm, mock_catalog):\n        \"\"\"Test successful table selection.\"\"\"\n        with patch(\"openchatbi.text2sql.schema_linking.get_relevant_columns\") as mock_get_columns:\n            mock_get_columns.return_value = [\"user_id\", \"name\", \"email\"]\n\n            with patch(\n                \"openchatbi.text2sql.schema_linking.column_tables_mapping\",\n                {\"user_id\": [\"users\", \"profiles\"], \"name\": [\"users\"], \"email\": [\"users\", \"contacts\"]},\n            ):\n                with patch(\n                    \"openchatbi.text2sql.schema_linking.col_dict\",\n                    {\n                        \"user_id\": {\n                            \"column_name\": 
\"user_id\",\n                            \"category\": \"dimension\",\n                            \"display_name\": \"User ID\",\n                            \"description\": \"Unique user identifier\",\n                        },\n                        \"name\": {\n                            \"column_name\": \"name\",\n                            \"category\": \"dimension\",\n                            \"display_name\": \"Name\",\n                            \"description\": \"User full name\",\n                        },\n                        \"email\": {\n                            \"column_name\": \"email\",\n                            \"category\": \"dimension\",\n                            \"display_name\": \"Email\",\n                            \"description\": \"User email address\",\n                        },\n                    },\n                ):\n                    with patch(\"openchatbi.text2sql.schema_linking.table_selection_retriever\") as mock_retriever:\n                        mock_retriever.invoke.return_value = []\n\n                        with patch(\"openchatbi.text2sql.schema_linking.table_selection_example_dict\", {}):\n                            with patch(\"openchatbi.text2sql.schema_linking.extract_json_from_answer\") as mock_extract:\n                                mock_extract.return_value = {\n                                    \"tables\": [{\"table\": \"users\", \"reason\": \"Contains user data\"}]\n                                }\n\n                                select_func = schema_linking(mock_llm, mock_catalog)\n\n                                state = SQLGraphState(\n                                    messages=[],\n                                    question=\"Show user information\",\n                                    rewrite_question=\"Show user information\",\n                                    info_entities={\n                                        \"keywords\": [\"user\", 
\"information\"],\n                                        \"dimensions\": [\"name\", \"email\"],\n                                        \"metrics\": [],\n                                    },\n                                )\n\n                                result = select_func(state)\n\n        assert \"tables\" in result\n        assert len(result[\"tables\"]) == 1\n        assert result[\"tables\"][0][\"table\"] == \"users\"\n\n    def test_select_table_missing_rewrite_question(self, mock_llm, mock_catalog):\n        \"\"\"Test table selection with missing rewrite question.\"\"\"\n        select_func = schema_linking(mock_llm, mock_catalog)\n\n        state = SQLGraphState(\n            messages=[],\n            question=\"Show user information\",\n            # Missing rewrite_question\n        )\n\n        result = select_func(state)\n        assert result == {}\n\n    def test_select_table_with_examples(self, mock_llm, mock_catalog):\n        \"\"\"Test table selection with similar examples.\"\"\"\n        with patch(\"openchatbi.text2sql.schema_linking.get_relevant_columns\") as mock_get_columns:\n            mock_get_columns.return_value = [\"user_id\", \"revenue\"]\n\n            with patch(\n                \"openchatbi.text2sql.schema_linking.column_tables_mapping\", {\"user_id\": [\"users\"], \"revenue\": [\"sales\"]}\n            ):\n                with patch(\n                    \"openchatbi.text2sql.schema_linking.col_dict\",\n                    {\n                        \"user_id\": {\n                            \"column_name\": \"user_id\",\n                            \"category\": \"dimension\",\n                            \"display_name\": \"User ID\",\n                            \"description\": \"Unique user identifier\",\n                        },\n                        \"revenue\": {\n                            \"column_name\": \"revenue\",\n                            \"category\": \"metric\",\n                            
\"display_name\": \"Revenue\",\n                            \"description\": \"Total revenue amount\",\n                        },\n                    },\n                ):\n                    # Mock similar examples\n                    mock_document = Mock()\n                    mock_document.page_content = \"What is user revenue?\"\n\n                    with patch(\"openchatbi.text2sql.schema_linking.table_selection_retriever\") as mock_retriever:\n                        mock_retriever.invoke.return_value = [mock_document]\n\n                        with patch(\n                            \"openchatbi.text2sql.schema_linking.table_selection_example_dict\",\n                            {\"What is user revenue?\": [\"users\", \"sales\"]},\n                        ):\n                            with patch(\"openchatbi.text2sql.schema_linking.extract_json_from_answer\") as mock_extract:\n                                mock_extract.return_value = {\"tables\": [{\"table\": \"users\"}, {\"table\": \"sales\"}]}\n\n                                select_func = schema_linking(mock_llm, mock_catalog)\n\n                                state = SQLGraphState(\n                                    messages=[],\n                                    question=\"Show user revenue\",\n                                    rewrite_question=\"Show user revenue\",\n                                    info_entities={\n                                        \"keywords\": [\"user\", \"revenue\"],\n                                        \"dimensions\": [\"user_id\"],\n                                        \"metrics\": [\"revenue\"],\n                                    },\n                                )\n\n                                result = select_func(state)\n\n        assert \"tables\" in result\n        assert len(result[\"tables\"]) == 2\n\n    def test_select_table_invalid_table_selection(self, mock_llm, mock_catalog):\n        \"\"\"Test handling of invalid table 
selection.\"\"\"\n        with patch(\"openchatbi.text2sql.schema_linking.get_relevant_columns\") as mock_get_columns:\n            mock_get_columns.return_value = [\"user_id\"]\n\n            with patch(\"openchatbi.text2sql.schema_linking.column_tables_mapping\", {\"user_id\": [\"users\"]}):\n                with patch(\n                    \"openchatbi.text2sql.schema_linking.col_dict\",\n                    {\n                        \"user_id\": {\n                            \"column_name\": \"user_id\",\n                            \"category\": \"dimension\",\n                            \"display_name\": \"User ID\",\n                            \"description\": \"Unique user identifier\",\n                        }\n                    },\n                ):\n                    with patch(\"openchatbi.text2sql.schema_linking.table_selection_retriever\") as mock_retriever:\n                        mock_retriever.invoke.return_value = []\n\n                        with patch(\"openchatbi.text2sql.schema_linking.table_selection_example_dict\", {}):\n                            with patch(\"openchatbi.text2sql.schema_linking.extract_json_from_answer\") as mock_extract:\n                                # Return invalid table not in candidate list\n                                mock_extract.return_value = {\"tables\": [{\"table\": \"invalid_table\"}]}\n\n                                select_func = schema_linking(mock_llm, mock_catalog)\n\n                                state = SQLGraphState(\n                                    messages=[],\n                                    question=\"Show user info\",\n                                    rewrite_question=\"Show user info\",\n                                    info_entities={\"keywords\": [\"user\"], \"dimensions\": [\"user_id\"], \"metrics\": []},\n                                )\n\n                                result = select_func(state)\n\n        # Should return empty dict when invalid table 
selected\n        assert result == {}\n\n    def test_select_table_retry_mechanism(self, mock_llm, mock_catalog):\n        \"\"\"Test retry mechanism for table selection.\"\"\"\n        with patch(\"openchatbi.text2sql.schema_linking.get_relevant_columns\") as mock_get_columns:\n            mock_get_columns.return_value = [\"user_id\"]\n\n            with patch(\"openchatbi.text2sql.schema_linking.column_tables_mapping\", {\"user_id\": [\"users\"]}):\n                with patch(\n                    \"openchatbi.text2sql.schema_linking.col_dict\",\n                    {\n                        \"user_id\": {\n                            \"column_name\": \"user_id\",\n                            \"category\": \"dimension\",\n                            \"display_name\": \"User ID\",\n                            \"description\": \"Unique user identifier\",\n                        }\n                    },\n                ):\n                    with patch(\"openchatbi.text2sql.schema_linking.table_selection_retriever\") as mock_retriever:\n                        mock_retriever.invoke.return_value = []\n\n                        with patch(\"openchatbi.text2sql.schema_linking.table_selection_example_dict\", {}):\n                            with patch(\"openchatbi.text2sql.schema_linking.extract_json_from_answer\") as mock_extract:\n                                # First returns invalid, then valid\n                                mock_extract.side_effect = [\n                                    {\"tables\": [{\"table\": \"invalid_table\"}]},\n                                    {\"tables\": [{\"table\": \"users\"}]},\n                                ]\n\n                                select_func = schema_linking(mock_llm, mock_catalog)\n\n                                state = SQLGraphState(\n                                    messages=[],\n                                    question=\"Show user info\",\n                                    
rewrite_question=\"Show user info\",\n                                    info_entities={\"keywords\": [\"user\"], \"dimensions\": [\"user_id\"], \"metrics\": []},\n                                )\n\n                                result = select_func(state)\n\n        assert \"tables\" in result\n        assert result[\"tables\"][0][\"table\"] == \"users\"\n\n    def test_select_table_with_time_filter(self, mock_llm, mock_catalog):\n        \"\"\"Test table selection with time filtering.\"\"\"\n        # Mock table with start_time\n        mock_catalog.get_table_information.return_value = {\n            \"description\": \"User data table\",\n            \"selection_rule\": \"Use for user queries\",\n            \"start_time\": \"2024-01-01\",\n        }\n\n        with patch(\"openchatbi.text2sql.schema_linking.get_relevant_columns\") as mock_get_columns:\n            mock_get_columns.return_value = [\"user_id\"]\n\n            with patch(\"openchatbi.text2sql.schema_linking.column_tables_mapping\", {\"user_id\": [\"users\"]}):\n                with patch(\n                    \"openchatbi.text2sql.schema_linking.col_dict\",\n                    {\n                        \"user_id\": {\n                            \"column_name\": \"user_id\",\n                            \"category\": \"dimension\",\n                            \"display_name\": \"User ID\",\n                            \"description\": \"Unique user identifier\",\n                        }\n                    },\n                ):\n                    with patch(\"openchatbi.text2sql.schema_linking.table_selection_retriever\") as mock_retriever:\n                        mock_retriever.invoke.return_value = []\n\n                        with patch(\"openchatbi.text2sql.schema_linking.table_selection_example_dict\", {}):\n                            with patch(\"openchatbi.text2sql.schema_linking.extract_json_from_answer\") as mock_extract:\n                                
mock_extract.return_value = {\"tables\": [{\"table\": \"users\"}]}\n\n                                select_func = schema_linking(mock_llm, mock_catalog)\n\n                                state = SQLGraphState(\n                                    messages=[],\n                                    question=\"Show recent user info\",\n                                    rewrite_question=\"Show recent user info\",\n                                    info_entities={\n                                        \"keywords\": [\"user\"],\n                                        \"dimensions\": [\"user_id\"],\n                                        \"metrics\": [],\n                                        \"start_time\": \"2024-06-01\",  # Later than table start_time\n                                    },\n                                )\n\n                                result = select_func(state)\n\n        assert \"tables\" in result\n        assert result[\"tables\"][0][\"table\"] == \"users\"\n\n    def test_select_table_llm_error_handling(self, mock_llm, mock_catalog):\n        \"\"\"Test handling of LLM errors during table selection.\"\"\"\n        mock_llm.invoke.side_effect = Exception(\"LLM service error\")\n\n        with patch(\"openchatbi.text2sql.schema_linking.get_relevant_columns\") as mock_get_columns:\n            mock_get_columns.return_value = [\"user_id\"]\n\n            with patch(\"openchatbi.text2sql.schema_linking.column_tables_mapping\", {\"user_id\": [\"users\"]}):\n                with patch(\n                    \"openchatbi.text2sql.schema_linking.col_dict\",\n                    {\n                        \"user_id\": {\n                            \"column_name\": \"user_id\",\n                            \"category\": \"dimension\",\n                            \"display_name\": \"User ID\",\n                            \"description\": \"Unique user identifier\",\n                        }\n                    },\n                ):\n 
                   with patch(\"openchatbi.text2sql.schema_linking.table_selection_retriever\") as mock_retriever:\n                        mock_retriever.invoke.return_value = []\n\n                        select_func = schema_linking(mock_llm, mock_catalog)\n\n                        state = SQLGraphState(\n                            messages=[],\n                            question=\"Show user info\",\n                            rewrite_question=\"Show user info\",\n                            info_entities={\"keywords\": [\"user\"], \"dimensions\": [\"user_id\"], \"metrics\": []},\n                        )\n\n                        result = select_func(state)\n\n        # Should handle error gracefully and return empty dict\n        assert result == {}\n\n    def test_select_table_max_retries_exceeded(self, mock_llm, mock_catalog):\n        \"\"\"Test behavior when max retries are exceeded.\"\"\"\n        with patch(\"openchatbi.text2sql.schema_linking.get_relevant_columns\") as mock_get_columns:\n            mock_get_columns.return_value = [\"user_id\"]\n\n            with patch(\"openchatbi.text2sql.schema_linking.column_tables_mapping\", {\"user_id\": [\"users\"]}):\n                with patch(\n                    \"openchatbi.text2sql.schema_linking.col_dict\",\n                    {\n                        \"user_id\": {\n                            \"column_name\": \"user_id\",\n                            \"category\": \"dimension\",\n                            \"display_name\": \"User ID\",\n                            \"description\": \"Unique user identifier\",\n                        }\n                    },\n                ):\n                    with patch(\"openchatbi.text2sql.schema_linking.table_selection_retriever\") as mock_retriever:\n                        mock_retriever.invoke.return_value = []\n\n                        with patch(\"openchatbi.text2sql.schema_linking.table_selection_example_dict\", {}):\n                        
    with patch(\"openchatbi.text2sql.schema_linking.extract_json_from_answer\") as mock_extract:\n                                # Always return invalid table\n                                mock_extract.return_value = {\"tables\": [{\"table\": \"invalid_table\"}]}\n\n                                select_func = schema_linking(mock_llm, mock_catalog)\n\n                                state = SQLGraphState(\n                                    messages=[],\n                                    question=\"Show user info\",\n                                    rewrite_question=\"Show user info\",\n                                    info_entities={\"keywords\": [\"user\"], \"dimensions\": [\"user_id\"], \"metrics\": []},\n                                )\n\n                                result = select_func(state)\n\n        # Should return empty dict after max retries\n        assert result == {}\n"
  },
  {
    "path": "tests/test_text2sql_visualization.py",
    "content": "\"\"\"Tests for text2sql visualization functionality.\"\"\"\n\nimport pytest\n\nfrom openchatbi.text2sql.visualization import ChartType, VisualizationConfig, VisualizationDSL, VisualizationService\n\n\nclass TestVisualizationService:\n    \"\"\"Tests for the VisualizationService class.\"\"\"\n\n    def test_generate_visualization_dsl_basic(self):\n        \"\"\"Test basic DSL generation with schema info.\"\"\"\n        schema_info = {\n            \"columns\": [\"name\", \"age\", \"salary\", \"department\"],\n            \"row_count\": 4,\n            \"numeric_columns\": [\"age\", \"salary\"],\n            \"categorical_columns\": [\"name\", \"department\"],\n            \"datetime_columns\": [],\n            \"unique_counts\": {\"name\": 4, \"department\": 2},\n        }\n\n        service = VisualizationService()\n        question = \"Compare salary by department\"\n        dsl = service.generate_visualization_dsl(question, schema_info)\n\n        assert dsl.chart_type == \"bar\"\n        # Should use first categorical column which is \"name\"\n        assert \"name\" in dsl.data_columns\n        assert \"age\" in dsl.data_columns and \"salary\" in dsl.data_columns  # Both numeric columns should be included\n\n    def test_get_chart_type_by_rule_with_datetime(self):\n        \"\"\"Test chart type recommendation with datetime columns.\"\"\"\n        schema_info = {\n            \"columns\": [\"date\", \"sales\", \"region\"],\n            \"numeric_columns\": [\"sales\"],\n            \"categorical_columns\": [\"region\"],\n            \"datetime_columns\": [\"date\"],\n            \"row_count\": 3,\n        }\n\n        service = VisualizationService()\n        question = \"Show sales trend over time\"\n        chart_type = service._get_chart_type_by_rule(question, schema_info)\n\n        assert chart_type == ChartType.LINE\n\n    def test_generate_visualization_dsl_error_handling(self):\n        \"\"\"Test DSL generation with error in schema 
info.\"\"\"\n        schema_info = {\"error\": \"Failed to analyze data schema\"}\n\n        service = VisualizationService()\n        question = \"Show data\"\n        dsl = service.generate_visualization_dsl(question, schema_info)\n\n        assert dsl.chart_type == \"table\"\n        assert \"error\" in dsl.config\n\n    def test_get_chart_type_by_rule_line_chart(self):\n        \"\"\"Test recommendation for line chart based on question keywords.\"\"\"\n        question = \"Show me the sales trend over time\"\n        schema = {\n            \"numeric_columns\": [\"sales\"],\n            \"categorical_columns\": [\"region\"],\n            \"datetime_columns\": [\"date\"],\n            \"row_count\": 10,\n        }\n\n        service = VisualizationService()\n        chart_type = service._get_chart_type_by_rule(question, schema)\n\n        assert chart_type == ChartType.LINE\n\n    def test_get_chart_type_by_rule_pie_chart(self):\n        \"\"\"Test recommendation for pie chart based on question keywords.\"\"\"\n        question = \"What is the percentage breakdown by department?\"\n        schema = {\n            \"numeric_columns\": [\"count\"],\n            \"categorical_columns\": [\"department\"],\n            \"datetime_columns\": [],\n            \"row_count\": 5,\n            \"unique_counts\": {\"department\": 4},\n        }\n\n        service = VisualizationService()\n        chart_type = service._get_chart_type_by_rule(question, schema)\n\n        assert chart_type == ChartType.PIE\n\n    def test_get_chart_type_by_rule_bar_chart(self):\n        \"\"\"Test recommendation for bar chart based on question keywords.\"\"\"\n        question = \"Compare sales by region\"\n        schema = {\n            \"numeric_columns\": [\"sales\"],\n            \"categorical_columns\": [\"region\"],\n            \"datetime_columns\": [],\n            \"row_count\": 10,\n            \"unique_counts\": {\"region\": 4},\n        }\n\n        service = 
VisualizationService()\n        chart_type = service._get_chart_type_by_rule(question, schema)\n\n        assert chart_type == ChartType.BAR\n\n    def test_get_chart_type_by_rule_scatter_plot(self):\n        \"\"\"Test recommendation for scatter plot based on data characteristics.\"\"\"\n        question = \"Show relationship between age and salary\"\n        schema = {\n            \"numeric_columns\": [\"age\", \"salary\"],\n            \"categorical_columns\": [\"name\"],\n            \"datetime_columns\": [],\n            \"row_count\": 10,\n        }\n\n        service = VisualizationService()\n        chart_type = service._get_chart_type_by_rule(question, schema)\n\n        assert chart_type == ChartType.SCATTER\n\n    def test_get_chart_type_by_rule_histogram(self):\n        \"\"\"Test recommendation for histogram based on keywords.\"\"\"\n        question = \"What is the distribution of ages?\"\n        schema = {\"numeric_columns\": [\"age\"], \"categorical_columns\": [], \"datetime_columns\": [], \"row_count\": 100}\n\n        service = VisualizationService()\n        chart_type = service._get_chart_type_by_rule(question, schema)\n\n        assert chart_type == ChartType.HISTOGRAM\n\n    def test_get_chart_type_by_rule_data_based_priority(self):\n        \"\"\"Test that data characteristics take priority over row count.\"\"\"\n        question = \"Show all records\"\n        schema = {\n            \"numeric_columns\": [\"value\"],\n            \"categorical_columns\": [\"category\"],\n            \"datetime_columns\": [],\n            \"row_count\": 15,\n            \"unique_counts\": {\"category\": 5},  # Small number of categories\n        }\n\n        service = VisualizationService()\n        chart_type = service._get_chart_type_by_rule(question, schema)\n\n        # Should choose PIE because of categorical + numeric columns, not TABLE due to row count\n        assert chart_type == ChartType.PIE\n\n    def 
test_generate_visualization_dsl_line_chart(self):\n        \"\"\"Test DSL generation for line chart.\"\"\"\n        question = \"Show sales trend over time\"\n        schema_info = {\n            \"columns\": [\"date\", \"sales\"],\n            \"numeric_columns\": [\"sales\"],\n            \"categorical_columns\": [],\n            \"datetime_columns\": [\"date\"],\n            \"row_count\": 3,\n        }\n\n        service = VisualizationService()\n        dsl = service.generate_visualization_dsl(question, schema_info)\n\n        assert dsl.chart_type == \"line\"\n        assert \"date\" in dsl.data_columns\n        assert \"sales\" in dsl.data_columns\n        assert dsl.config[\"x\"] == \"date\"\n        assert dsl.config[\"y\"] == \"sales\"\n\n    def test_generate_visualization_dsl_bar_chart(self):\n        \"\"\"Test DSL generation for bar chart.\"\"\"\n        question = \"Compare sales by region\"\n        schema_info = {\n            \"columns\": [\"region\", \"sales\"],\n            \"numeric_columns\": [\"sales\"],\n            \"categorical_columns\": [\"region\"],\n            \"datetime_columns\": [],\n            \"row_count\": 4,\n            \"unique_counts\": {\"region\": 4},\n        }\n\n        service = VisualizationService()\n        dsl = service.generate_visualization_dsl(question, schema_info)\n\n        assert dsl.chart_type == \"bar\"\n        assert \"region\" in dsl.data_columns\n        assert \"sales\" in dsl.data_columns\n        assert dsl.config[\"x\"] == \"region\"\n        assert dsl.config[\"y\"] == \"sales\"\n\n    def test_generate_visualization_dsl_pie_chart(self):\n        \"\"\"Test DSL generation for pie chart.\"\"\"\n        question = \"Show percentage breakdown by department\"\n        schema_info = {\n            \"columns\": [\"department\", \"count\"],\n            \"numeric_columns\": [\"count\"],\n            \"categorical_columns\": [\"department\"],\n            \"datetime_columns\": [],\n            
\"row_count\": 4,\n            \"unique_counts\": {\"department\": 4},\n        }\n\n        service = VisualizationService()\n        dsl = service.generate_visualization_dsl(question, schema_info, ChartType.PIE)\n\n        assert dsl.chart_type == \"pie\"\n        assert dsl.config[\"labels\"] == \"department\"\n        assert dsl.config[\"values\"] == \"count\"\n\n    def test_generate_visualization_dsl_empty_data(self):\n        \"\"\"Test DSL generation with empty data.\"\"\"\n        question = \"Show data\"\n        schema_info = {\n            \"columns\": [],\n            \"numeric_columns\": [],\n            \"categorical_columns\": [],\n            \"datetime_columns\": [],\n            \"row_count\": 0,\n        }\n\n        service = VisualizationService()\n        dsl = service.generate_visualization_dsl(question, schema_info)\n\n        assert dsl.chart_type == \"table\"\n        assert \"columns\" in dsl.config\n\n    def test_visualization_config_dataclass(self):\n        \"\"\"Test VisualizationConfig dataclass.\"\"\"\n        config = VisualizationConfig(\n            chart_type=ChartType.BAR, x_column=\"category\", y_column=\"value\", title=\"Test Chart\"\n        )\n\n        assert config.chart_type == ChartType.BAR\n        assert config.x_column == \"category\"\n        assert config.y_column == \"value\"\n        assert config.title == \"Test Chart\"\n        assert config.show_legend is True  # default value\n\n    def test_visualization_dsl_to_dict(self):\n        \"\"\"Test VisualizationDSL to_dict method.\"\"\"\n        dsl = VisualizationDSL(\n            chart_type=\"bar\",\n            data_columns=[\"x\", \"y\"],\n            config={\"x\": \"category\", \"y\": \"value\"},\n            layout={\"title\": \"Test Chart\"},\n        )\n\n        result = dsl.to_dict()\n\n        assert result[\"chart_type\"] == \"bar\"\n        assert result[\"data_columns\"] == [\"x\", \"y\"]\n        assert result[\"config\"][\"x\"] == \"category\"\n 
       assert result[\"layout\"][\"title\"] == \"Test Chart\"\n\n\nclass TestChartType:\n    \"\"\"Tests for ChartType enum.\"\"\"\n\n    def test_chart_type_values(self):\n        \"\"\"Test ChartType enum values.\"\"\"\n        assert ChartType.LINE.value == \"line\"\n        assert ChartType.BAR.value == \"bar\"\n        assert ChartType.PIE.value == \"pie\"\n        assert ChartType.SCATTER.value == \"scatter\"\n        assert ChartType.HISTOGRAM.value == \"histogram\"\n        assert ChartType.BOX.value == \"box\"\n        assert ChartType.HEATMAP.value == \"heatmap\"\n        assert ChartType.TABLE.value == \"table\"\n\n\n@pytest.fixture\ndef sample_csv_data():\n    \"\"\"Fixture providing sample CSV data for testing.\"\"\"\n    return \"\"\"product,sales,region,quarter\nWidget A,10000,North,Q1\nWidget B,15000,South,Q1\nWidget C,8000,East,Q1\nWidget A,12000,North,Q2\nWidget B,18000,South,Q2\nWidget C,9000,East,Q2\"\"\"\n\n\n@pytest.fixture\ndef sample_time_series_data():\n    \"\"\"Fixture providing sample time series data for testing.\"\"\"\n    return \"\"\"date,revenue,users\n2023-01-01,50000,1000\n2023-02-01,55000,1100\n2023-03-01,60000,1200\n2023-04-01,52000,1050\n2023-05-01,58000,1150\"\"\"\n\n\nclass TestVisualizationIntegration:\n    \"\"\"Integration tests for visualization functionality.\"\"\"\n\n    def test_complete_workflow_line_chart(self, sample_time_series_data):\n        \"\"\"Test complete workflow for generating line chart.\"\"\"\n        question = \"Show revenue trend over time\"\n\n        # Mock schema info for time series data\n        schema_info = {\n            \"columns\": [\"date\", \"revenue\", \"users\"],\n            \"numeric_columns\": [\"revenue\", \"users\"],\n            \"categorical_columns\": [],\n            \"datetime_columns\": [\"date\"],\n            \"row_count\": 5,\n        }\n\n        service = VisualizationService()\n        # Recommend chart type\n        chart_type = 
service._get_chart_type_by_rule(question, schema_info)\n\n        # Generate DSL\n        dsl = service.generate_visualization_dsl(question, schema_info, chart_type)\n\n        assert chart_type == ChartType.LINE\n        assert dsl.chart_type == \"line\"\n        assert \"date\" in dsl.data_columns\n        assert \"revenue\" in dsl.data_columns\n\n    def test_complete_workflow_bar_chart(self, sample_csv_data):\n        \"\"\"Test complete workflow for generating bar chart.\"\"\"\n        question = \"Compare sales by product\"\n\n        # Mock schema info for sample CSV data\n        schema_info = {\n            \"columns\": [\"product\", \"sales\", \"region\", \"quarter\"],\n            \"numeric_columns\": [\"sales\"],\n            \"categorical_columns\": [\"product\", \"region\", \"quarter\"],\n            \"datetime_columns\": [],\n            \"row_count\": 6,\n            \"unique_counts\": {\"product\": 3, \"region\": 3, \"quarter\": 2},\n        }\n\n        service = VisualizationService()\n        # Generate DSL directly (no chart type passed; it is chosen internally from the schema info)\n        dsl = service.generate_visualization_dsl(question, schema_info)\n\n        assert dsl.chart_type in [\"bar\", \"line\"]  # Could be either based on heuristics\n        assert len(dsl.data_columns) >= 2\n        assert dsl.layout.get(\"title\") is not None\n"
  },
  {
    "path": "tests/test_tools_ask_human.py",
    "content": "\"\"\"Tests for ask_human tool functionality.\"\"\"\n\nimport pytest\nfrom pydantic import ValidationError\n\nfrom openchatbi.tool.ask_human import AskHuman\n\n\nclass TestAskHuman:\n    \"\"\"Test AskHuman model functionality.\"\"\"\n\n    def test_ask_human_basic_initialization(self):\n        \"\"\"Test basic AskHuman model creation.\"\"\"\n        question = \"What time period should I analyze?\"\n        options = [\"Last 7 days\", \"Last 30 days\", \"Last year\"]\n\n        ask_human = AskHuman(question=question, options=options)\n\n        assert ask_human.question == question\n        assert ask_human.options == options\n\n    def test_ask_human_empty_options(self):\n        \"\"\"Test AskHuman with empty options list.\"\"\"\n        ask_human = AskHuman(question=\"Simple question?\", options=[])\n\n        assert ask_human.question == \"Simple question?\"\n        assert ask_human.options == []\n\n    def test_ask_human_validation_error(self):\n        \"\"\"Test AskHuman model validation.\"\"\"\n        with pytest.raises(ValidationError):\n            AskHuman()  # Missing required fields\n\n        with pytest.raises(ValidationError):\n            AskHuman(question=\"Test\")  # Missing options field\n\n    def test_ask_human_serialization(self):\n        \"\"\"Test AskHuman model serialization.\"\"\"\n        ask_human = AskHuman(question=\"Which analysis method?\", options=[\"Statistical\", \"Machine Learning\"])\n\n        data = ask_human.model_dump()\n        assert data[\"question\"] == \"Which analysis method?\"\n        assert data[\"options\"] == [\"Statistical\", \"Machine Learning\"]\n"
  },
  {
    "path": "tests/test_tools_run_python_code.py",
    "content": "\"\"\"Tests for run_python_code tool functionality.\"\"\"\n\nfrom unittest.mock import patch\n\nfrom openchatbi.tool.run_python_code import run_python_code\n\n\nclass TestRunPythonCode:\n    \"\"\"Test run_python_code tool functionality.\"\"\"\n\n    def test_run_python_code_basic(self):\n        \"\"\"Test basic Python code execution.\"\"\"\n        reasoning = \"Testing basic print functionality\"\n        code = \"print('Hello, World!')\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (True, \"Hello, World!\\n\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Hello, World!\" in result\n\n    def test_run_python_code_with_variables(self):\n        \"\"\"Test Python code execution with variables.\"\"\"\n        reasoning = \"Testing variable operations\"\n        code = \"\"\"\nx = 10\ny = 20\nresult = x + y\nprint(f\"Result: {result}\")\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (True, \"Result: 30\\n\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Result: 30\" in result\n\n    def test_run_python_code_data_analysis(self):\n        \"\"\"Test Python code for data analysis operations.\"\"\"\n        reasoning = \"Performing data analysis calculations\"\n        code = \"\"\"\nimport math\ndata = [1, 2, 3, 4, 5]\nmean = sum(data) / len(data)\nprint(f\"Mean: {mean}\")\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n      
      mock_instance.run_code.return_value = (True, \"Mean: 3.0\\n\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Mean: 3.0\" in result\n\n    def test_run_python_code_matplotlib_plot(self):\n        \"\"\"Test Python code for creating plots.\"\"\"\n        reasoning = \"Creating a matplotlib visualization\"\n        code = \"\"\"\nimport matplotlib.pyplot as plt\nx = [1, 2, 3, 4, 5]\ny = [2, 4, 6, 8, 10]\nplt.plot(x, y)\nplt.title('Sample Plot')\nprint('Plot created successfully')\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (True, \"Plot created successfully\\n\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Plot created successfully\" in result\n\n    def test_run_python_code_syntax_error(self):\n        \"\"\"Test Python code execution with syntax errors.\"\"\"\n        reasoning = \"Testing error handling for syntax errors\"\n        code = \"print('Hello World'\"  # Missing closing parenthesis\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (False, \"SyntaxError: unexpected EOF while parsing\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Error:\" in result\n            assert \"SyntaxError\" in result\n\n    def test_run_python_code_runtime_error(self):\n        \"\"\"Test Python code execution with runtime errors.\"\"\"\n        reasoning = \"Testing error handling for runtime errors\"\n        code = \"\"\"\nx = 10\ny = 
0\nresult = x / y\nprint(result)\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (False, \"ZeroDivisionError: division by zero\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Error:\" in result\n            assert \"ZeroDivisionError\" in result\n\n    def test_run_python_code_import_error(self):\n        \"\"\"Test Python code execution with import errors.\"\"\"\n        reasoning = \"Testing error handling for import errors\"\n        code = \"\"\"\nimport nonexistent_module\nprint('This should not print')\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (False, \"ModuleNotFoundError: No module named 'nonexistent_module'\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Error:\" in result\n            assert \"ModuleNotFoundError\" in result\n\n    def test_run_python_code_multiline_output(self):\n        \"\"\"Test Python code with multiple print statements.\"\"\"\n        reasoning = \"Testing multiple output lines\"\n        code = \"\"\"\nfor i in range(3):\n    print(f\"Line {i + 1}\")\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (True, \"Line 1\\nLine 2\\nLine 3\\n\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Line 1\" in result\n            assert \"Line 2\" in 
result\n            assert \"Line 3\" in result\n\n    def test_run_python_code_with_sql_data(self):\n        \"\"\"Test Python code working with SQL-like data.\"\"\"\n        reasoning = \"Processing SQL query results\"\n        code = \"\"\"\ndata = [\n    {'name': 'Alice', 'age': 30, 'salary': 50000},\n    {'name': 'Bob', 'age': 25, 'salary': 45000},\n    {'name': 'Charlie', 'age': 35, 'salary': 55000}\n]\ntotal_salary = sum(row['salary'] for row in data)\nprint(f\"Total salary: ${total_salary}\")\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (True, \"Total salary: $150000\\n\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Total salary: $150000\" in result\n\n    def test_run_python_code_empty_code(self):\n        \"\"\"Test Python code execution with empty code.\"\"\"\n        reasoning = \"Testing empty code handling\"\n        code = \"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (True, \"\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert result == \"\"\n\n    def test_run_python_code_whitespace_only(self):\n        \"\"\"Test Python code execution with whitespace only.\"\"\"\n        reasoning = \"Testing whitespace-only code\"\n        code = \"   \\n  \\t  \\n   \"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (True, \"\")\n\n            result = run_python_code.run({\"reasoning\": 
reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert result == \"\"\n\n    def test_run_python_code_with_comments(self):\n        \"\"\"Test Python code execution with comments.\"\"\"\n        reasoning = \"Testing code with comments\"\n        code = \"\"\"\n# This is a comment\nx = 5  # Another comment\nprint(f\"Value: {x}\")  # Final comment\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (True, \"Value: 5\\n\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Value: 5\" in result\n\n    def test_run_python_code_security_restrictions(self):\n        \"\"\"Test Python code with potentially restricted operations.\"\"\"\n        reasoning = \"Testing security restrictions\"\n        code = \"\"\"\n# Attempting file operations\ntry:\n    with open('/etc/passwd', 'r') as f:\n        content = f.read()\n        print(\"File read successfully\")\nexcept Exception as e:\n    print(f\"Security restriction: {e}\")\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (\n                False,\n                \"PermissionError: [Errno 13] Permission denied: '/etc/passwd'\",\n            )\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Error:\" in result\n\n    def test_run_python_code_timeout_handling(self):\n        \"\"\"Test Python code execution timeout scenarios.\"\"\"\n        reasoning = \"Testing timeout handling\"\n        code = \"\"\"\nimport time\ntime.sleep(10)  # Long running operation\nprint(\"This 
might timeout\")\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (False, \"TimeoutError: Code execution timed out\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Error:\" in result\n            assert \"TimeoutError\" in result\n\n    def test_run_python_code_memory_limit(self):\n        \"\"\"Test Python code execution with memory limitations.\"\"\"\n        reasoning = \"Testing memory limit handling\"\n        code = \"\"\"\n# Creating a large list that might exceed memory limits\nlarge_list = [0] * (10**8)\nprint(f\"Created list with {len(large_list)} elements\")\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (False, \"MemoryError: Unable to allocate memory\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Error:\" in result\n            assert \"MemoryError\" in result\n\n    def test_run_python_code_return_values(self):\n        \"\"\"Test that return values are not captured (only prints).\"\"\"\n        reasoning = \"Testing return value handling\"\n        code = \"\"\"\ndef calculate():\n    return 42\n\nresult = calculate()\nprint(f\"Function returned: {result}\")\n# The return value itself should not be captured\ncalculate()\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (True, \"Function returned: 42\\n\")\n\n            result = run_python_code.run({\"reasoning\": 
reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Function returned: 42\" in result\n            # Should not contain the raw return value\n            assert result.strip() == \"Function returned: 42\"\n\n    def test_run_python_code_exception_details(self):\n        \"\"\"Test detailed exception information.\"\"\"\n        reasoning = \"Testing detailed exception handling\"\n        code = \"\"\"\ndef faulty_function():\n    raise ValueError(\"This is a custom error message\")\n\nfaulty_function()\n\"\"\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor:\n            mock_instance = mock_executor.return_value\n            mock_instance.run_code.return_value = (False, \"ValueError: This is a custom error message\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            assert isinstance(result, str)\n            assert \"Error:\" in result\n            assert \"ValueError\" in result\n            assert \"custom error message\" in result\n\n    def test_run_python_code_executor_selection(self):\n        \"\"\"Test that LocalExecutor is properly instantiated and used.\"\"\"\n        reasoning = \"Testing executor instantiation\"\n        code = \"print('Executor test')\"\n\n        with patch(\"openchatbi.tool.run_python_code.LocalExecutor\") as mock_executor_class:\n            mock_instance = mock_executor_class.return_value\n            mock_instance.run_code.return_value = (True, \"Executor test\\n\")\n\n            result = run_python_code.run({\"reasoning\": reasoning, \"code\": code})\n\n            # Verify LocalExecutor was instantiated\n            mock_executor_class.assert_called_once()\n            # Verify run_code was called with the correct code\n            mock_instance.run_code.assert_called_once_with(code)\n\n            assert isinstance(result, str)\n            assert \"Executor test\" in result\n"
  },
  {
    "path": "tests/test_tools_search_knowledge.py",
    "content": "\"\"\"Tests for search_knowledge tool functionality.\"\"\"\n\nfrom unittest.mock import patch\n\nimport pytest\n\nfrom openchatbi.tool.search_knowledge import search_knowledge, show_schema\n\n\nclass TestSearchKnowledge:\n    \"\"\"Test search_knowledge tool functionality.\"\"\"\n\n    def test_search_knowledge_basic(self):\n        \"\"\"Test basic knowledge search functionality.\"\"\"\n        reasoning = \"Looking for user information\"\n        query_list = [\"user\", \"information\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"user_id: User identifier\\nuser_name: User name\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n            assert \"columns\" in result\n            assert \"User identifier\" in result[\"columns\"]\n            mock_search.assert_called_once_with(query_list, False)\n\n    def test_search_knowledge_table_matching(self):\n        \"\"\"Test knowledge search with table matching.\"\"\"\n        reasoning = \"Finding table relationships\"\n        query_list = [\"user\", \"metrics\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"user_id: Unique identifier\\nmetrics_value: Metric value\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": 
True,\n                }\n            )\n\n            assert isinstance(result, dict)\n            assert \"columns\" in result\n            mock_search.assert_called_once_with(query_list, True)\n\n    def test_search_knowledge_empty_query(self):\n        \"\"\"Test knowledge search with empty query.\"\"\"\n        reasoning = \"Testing empty search\"\n        query_list = []\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n            assert \"columns\" in result\n            mock_search.assert_called_once_with(query_list, False)\n\n    def test_search_knowledge_no_matches(self):\n        \"\"\"Test knowledge search with no matches.\"\"\"\n        reasoning = \"Testing no matches\"\n        query_list = [\"nonexistent\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n            assert \"columns\" in result\n            assert result[\"columns\"] == \"# Relevant Columns and Description:\\n\"\n\n    def test_search_knowledge_multiple_matches(self):\n        \"\"\"Test knowledge search with multiple 
matches.\"\"\"\n        reasoning = \"Finding multiple matches\"\n        query_list = [\"user\", \"data\", \"profile\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"user_id: User ID\\nuser_name: Name\\nprofile_data: Profile\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n            assert \"columns\" in result\n            assert \"user_id\" in result[\"columns\"]\n            assert \"profile_data\" in result[\"columns\"]\n\n    def test_search_knowledge_with_synonyms(self):\n        \"\"\"Test knowledge search with synonym matching.\"\"\"\n        reasoning = \"Testing synonym search\"\n        query_list = [\"customer\", \"client\", \"user\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"customer_id: Customer identifier\\nclient_name: Client name\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n            assert \"columns\" in result\n\n    def test_search_knowledge_case_insensitive(self):\n        \"\"\"Test case insensitive knowledge search.\"\"\"\n        reasoning = \"Testing case sensitivity\"\n        query_list = [\"USER\", \"Data\", \"PROFILE\"]\n        knowledge_bases = 
[\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"user_data: User information\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n            assert \"columns\" in result\n\n    def test_search_knowledge_partial_matches(self):\n        \"\"\"Test knowledge search with partial matches.\"\"\"\n        reasoning = \"Testing partial matching\"\n        query_list = [\"usr\", \"prof\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"user_profile: User profile data\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n\n    def test_search_knowledge_error_handling(self):\n        \"\"\"Test knowledge search error handling.\"\"\"\n        reasoning = \"Testing error handling\"\n        query_list = [\"test\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.side_effect = Exception(\"Search error\")\n\n            # Errors raised by the catalog search are expected to propagate to the caller\n            with pytest.raises(Exception):\n                search_knowledge.run(\n                    {\n                        \"reasoning\": reasoning,\n                   
     \"query_list\": query_list,\n                        \"knowledge_bases\": knowledge_bases,\n                        \"with_table_list\": False,\n                    }\n                )\n\n    def test_show_schema_basic(self):\n        \"\"\"Test basic schema display functionality.\"\"\"\n        reasoning = \"Showing basic schema\"\n        tables = [\"user_data\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.list_table_from_catalog\") as mock_list:\n            mock_list.return_value = [\"Table: user_data\\n# Description: User information\\n# Columns:\\nuser_id: User ID\"]\n\n            result = show_schema.run({\"reasoning\": reasoning, \"tables\": tables})\n\n            assert isinstance(result, list)\n            assert len(result) == 1\n            assert \"user_data\" in result[0]\n            mock_list.assert_called_once_with(tables)\n\n    def test_show_schema_detailed_info(self):\n        \"\"\"Test detailed schema information.\"\"\"\n        reasoning = \"Showing detailed schema\"\n        tables = [\"user_data\", \"metrics\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.list_table_from_catalog\") as mock_list:\n            mock_list.return_value = [\n                \"Table: user_data\\n# Columns: user_id, name, email\",\n                \"Table: metrics\\n# Columns: metric_id, value, timestamp\",\n            ]\n\n            result = show_schema.run({\"reasoning\": reasoning, \"tables\": tables})\n\n            assert isinstance(result, list)\n            assert len(result) == 2\n            assert any(\"user_data\" in schema for schema in result)\n            assert any(\"metrics\" in schema for schema in result)\n\n    def test_show_schema_nonexistent_table(self):\n        \"\"\"Test schema display for nonexistent table.\"\"\"\n        reasoning = \"Testing nonexistent table\"\n        tables = [\"nonexistent_table\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.list_table_from_catalog\") as mock_list:\n    
        mock_list.return_value = []\n\n            result = show_schema.run({\"reasoning\": reasoning, \"tables\": tables})\n\n            assert isinstance(result, list)\n            assert len(result) == 0\n\n    def test_show_schema_table_error(self):\n        \"\"\"Test schema display error handling.\"\"\"\n        reasoning = \"Testing schema errors\"\n        tables = [\"error_table\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.list_table_from_catalog\") as mock_list:\n            mock_list.side_effect = Exception(\"Table access error\")\n\n            with pytest.raises(Exception):\n                show_schema.run({\"reasoning\": reasoning, \"tables\": tables})\n\n    def test_show_schema_complex_table(self):\n        \"\"\"Test schema display for complex table structure.\"\"\"\n        reasoning = \"Showing complex schema\"\n        tables = [\"complex_table\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.list_table_from_catalog\") as mock_list:\n            mock_list.return_value = [\n                \"Table: complex_table\\n# Description: Complex data structure\\n# Columns:\\nid: Primary key\\ndata: JSON data\\ncreated_at: Timestamp\"\n            ]\n\n            result = show_schema.run({\"reasoning\": reasoning, \"tables\": tables})\n\n            assert isinstance(result, list)\n            assert \"complex_table\" in result[0]\n            assert \"Primary key\" in result[0]\n\n    def test_search_knowledge_with_metrics(self):\n        \"\"\"Test knowledge search focusing on metrics.\"\"\"\n        reasoning = \"Finding metrics columns\"\n        query_list = [\"revenue\", \"clicks\", \"impressions\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"revenue: Revenue amount\\nclicks: Click count\\nimpressions: Impression count\"\n\n            result = search_knowledge.run(\n                
{\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n            assert \"revenue\" in result[\"columns\"]\n            assert \"clicks\" in result[\"columns\"]\n\n    def test_search_knowledge_contextual_search(self):\n        \"\"\"Test contextual knowledge search.\"\"\"\n        reasoning = \"Contextual search for user behavior\"\n        query_list = [\"user\", \"behavior\", \"tracking\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"user_behavior: User activity tracking\\ntracking_id: Tracking identifier\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n            assert \"behavior\" in result[\"columns\"]\n\n    def test_search_knowledge_with_aggregations(self):\n        \"\"\"Test knowledge search for aggregation columns.\"\"\"\n        reasoning = \"Finding aggregation metrics\"\n        query_list = [\"sum\", \"count\", \"average\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"total_count: Count aggregation\\naverage_value: Average calculation\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": 
knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n\n    def test_show_schema_with_examples(self):\n        \"\"\"Test schema display with usage examples.\"\"\"\n        reasoning = \"Showing schema with examples\"\n        tables = [\"example_table\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.list_table_from_catalog\") as mock_list:\n            mock_list.return_value = [\n                \"Table: example_table\\n# Description: Example usage\\n## Derived metrics:\\nSELECT COUNT(*) FROM example_table\"\n            ]\n\n            result = show_schema.run({\"reasoning\": reasoning, \"tables\": tables})\n\n            assert isinstance(result, list)\n            assert \"example_table\" in result[0]\n            assert \"Derived metrics\" in result[0]\n\n    def test_search_knowledge_performance(self):\n        \"\"\"Test knowledge search performance characteristics.\"\"\"\n        reasoning = \"Testing search performance\"\n        query_list = [\"performance\", \"speed\", \"optimization\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"performance_metric: Performance measurement\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n            # Smoke check only: the catalog search is mocked, so no timing is measured here\n\n    def test_search_knowledge_special_characters(self):\n        \"\"\"Test knowledge search with special characters.\"\"\"\n        reasoning = \"Testing special character handling\"\n        query_list = [\"user@domain\", 
\"data-point\", \"metric_value\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"user_email: User email address\\ndata_point: Data measurement\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n\n    def test_search_knowledge_unicode_support(self):\n        \"\"\"Test knowledge search with unicode characters.\"\"\"\n        reasoning = \"Testing unicode support\"\n        query_list = [\"utilización\", \"données\", \"用户\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"user_data: International user data\"\n\n            result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n\n    def test_knowledge_integration_with_state(self):\n        \"\"\"Test knowledge search integration with agent state.\"\"\"\n        reasoning = \"Testing state integration\"\n        query_list = [\"state\", \"integration\"]\n        knowledge_bases = [\"columns\"]\n\n        with patch(\"openchatbi.tool.search_knowledge.search_column_from_catalog\") as mock_search:\n            mock_search.return_value = \"state_data: Application state information\"\n\n            # Test that the tool can be called in the context of agent state\n            
result = search_knowledge.run(\n                {\n                    \"reasoning\": reasoning,\n                    \"query_list\": query_list,\n                    \"knowledge_bases\": knowledge_bases,\n                    \"with_table_list\": False,\n                }\n            )\n\n            assert isinstance(result, dict)\n            assert \"columns\" in result\n"
  },
  {
    "path": "tests/test_utils.py",
    "content": "\"\"\"Tests for utility functions.\"\"\"\n\nimport io\nfrom unittest.mock import patch\n\nimport pytest\n\nfrom openchatbi.utils import log\n\n\nclass TestUtilityFunctions:\n    \"\"\"Test utility functions.\"\"\"\n\n    def test_log_function_basic(self):\n        \"\"\"Test basic logging functionality.\"\"\"\n        # Capture stdout\n        captured_output = io.StringIO()\n\n        with patch(\"sys.stderr\", captured_output):\n            log(\"Test message\")\n\n        output = captured_output.getvalue()\n        assert \"Test message\" in output\n\n    def test_log_function_multiple_messages(self):\n        \"\"\"Test logging with multiple messages.\"\"\"\n        captured_output = io.StringIO()\n\n        with patch(\"sys.stderr\", captured_output):\n            log(\"First message\")\n            log(\"Second message\")\n\n        output = captured_output.getvalue()\n        assert \"First message\" in output\n        assert \"Second message\" in output\n\n    def test_log_function_empty_message(self):\n        \"\"\"Test logging with empty message.\"\"\"\n        captured_output = io.StringIO()\n\n        with patch(\"sys.stderr\", captured_output):\n            log(\"\")\n\n        output = captured_output.getvalue()\n        # Should handle empty messages gracefully\n        assert output is not None\n\n    def test_log_function_none_message(self):\n        \"\"\"Test logging with None message.\"\"\"\n        captured_output = io.StringIO()\n\n        with patch(\"sys.stderr\", captured_output):\n            log(None)\n\n        output = captured_output.getvalue()\n        # Should handle None messages gracefully\n        assert \"None\" in output or output == \"\"\n\n    def test_log_function_complex_objects(self):\n        \"\"\"Test logging with complex objects.\"\"\"\n        test_dict = {\"key\": \"value\", \"number\": 42}\n        test_list = [1, 2, 3, \"string\"]\n\n        captured_output = io.StringIO()\n\n        with 
patch(\"sys.stderr\", captured_output):\n            log(test_dict)\n            log(test_list)\n\n        output = captured_output.getvalue()\n        assert \"key\" in output or str(test_dict) in output\n        assert \"string\" in output or str(test_list) in output\n\n    def test_log_function_with_exception(self):\n        \"\"\"Test logging exception information.\"\"\"\n        try:\n            raise ValueError(\"Test exception\")\n        except ValueError as e:\n            captured_output = io.StringIO()\n\n            with patch(\"sys.stderr\", captured_output):\n                log(f\"Exception occurred: {e}\")\n\n            output = captured_output.getvalue()\n            assert \"Exception occurred\" in output\n            assert \"Test exception\" in output\n\n    @patch(\"sys.stderr\")\n    def test_log_function_stderr_error(self, mock_stderr):\n        \"\"\"Test logging when stderr has issues.\"\"\"\n        mock_stderr.write.side_effect = OSError(\"stderr error\")\n\n        # Current implementation raises exception when stderr fails - this is expected\n        with pytest.raises(OSError, match=\"stderr error\"):\n            log(\"Test message\")\n\n    def test_log_function_unicode_handling(self):\n        \"\"\"Test logging with unicode characters.\"\"\"\n        unicode_message = \"Test with émojis: 🚀 and spéciál characters: ñáéíóú\"\n\n        captured_output = io.StringIO()\n\n        with patch(\"sys.stderr\", captured_output):\n            log(unicode_message)\n\n        output = captured_output.getvalue()\n        # Should handle unicode characters properly\n        assert len(output) > 0\n\n    def test_log_function_large_message(self):\n        \"\"\"Test logging with very large messages.\"\"\"\n        large_message = \"A\" * 10000  # 10KB message\n\n        captured_output = io.StringIO()\n\n        with patch(\"sys.stderr\", captured_output):\n            log(large_message)\n\n        output = captured_output.getvalue()\n        
assert len(output) > 0\n        assert \"A\" in output\n\n    def test_log_function_newline_handling(self):\n        \"\"\"Test logging with messages containing newlines.\"\"\"\n        multiline_message = \"Line 1\\\\nLine 2\\\\nLine 3\"\n\n        captured_output = io.StringIO()\n\n        with patch(\"sys.stderr\", captured_output):\n            log(multiline_message)\n\n        output = captured_output.getvalue()\n        assert \"Line 1\" in output\n        assert \"Line 2\" in output\n        assert \"Line 3\" in output\n\n    def test_log_function_timestamp_format(self):\n        \"\"\"Test that log includes timestamp information.\"\"\"\n        captured_output = io.StringIO()\n\n        with patch(\"sys.stderr\", captured_output):\n            log(\"Timestamp test\")\n\n        output = captured_output.getvalue()\n        # Check if output contains timestamp-like format (basic check)\n        # The actual implementation might vary\n        assert len(output) > len(\"Timestamp test\")\n\n    def test_log_function_concurrent_calls(self):\n        \"\"\"Test logging with concurrent-like calls.\"\"\"\n        import threading\n        import time\n\n        captured_output = io.StringIO()\n\n        def log_worker(message):\n            log(f\"Worker: {message}\")\n            time.sleep(0.01)  # Small delay\n\n        # Patch stderr for all threads\n        with patch(\"sys.stderr\", captured_output):\n            # Create multiple threads (simulating concurrency)\n            threads = []\n            for i in range(5):\n                thread = threading.Thread(target=log_worker, args=(f\"message_{i}\",))\n                threads.append(thread)\n\n            # Start all threads\n            for thread in threads:\n                thread.start()\n\n            # Wait for all threads\n            for thread in threads:\n                thread.join()\n\n        output = captured_output.getvalue()\n        # Should handle concurrent access gracefully\n        
assert len(output) > 0\n"
  },
  {
    "path": "timeseries_forecasting/Dockerfile",
    "content": "FROM python:3.10-slim\n\n# Install only essential build tools\nRUN apt-get update && apt-get install -y --no-install-recommends \\\n    curl \\\n    wget \\\n    build-essential \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\n\n# Install Python dependencies for time series forecasting\nRUN pip3 --no-cache-dir install \\\n    fastapi==0.120.4 \\\n    uvicorn==0.38.0 \\\n    transformers==4.40.1 \\\n    torch==2.9.0 \\\n    numpy==2.2.6 \\\n    pandas==2.3.3 \\\n    pydantic==2.12.3\n\n# Set working directory\nWORKDIR /home/model-server\n\n# Copy the model\nCOPY ../hf_model /home/model-server/hf_model\n\n# Copy application files\nCOPY app.py model_handler.py /home/model-server/\n\n# Set environment variables\nENV PYTHONPATH=/home/model-server\nENV PYTHONUNBUFFERED=1\n\n# Expose port\nEXPOSE 8765\n\n# Define entrypoint and default command\nCMD [\"uvicorn\", \"app:app\", \"--host\", \"0.0.0.0\", \"--port\", \"8765\"]"
  },
  {
    "path": "timeseries_forecasting/README.md",
    "content": "# Transformer Time Series Forecasting Service\n\nA Docker-based time series forecasting service using Transformer based models for accurate time series prediction. This service provides a FastAPI-based REST API for easy integration with various applications.\n\n## Features\n\n- **Transformer Model Integration**: Uses state-of-the-art Transformer models for time series forecasting\n- **FastAPI Backend**: Modern, fast web framework with automatic API documentation\n- **Docker Support**: Fully containerized service for easy deployment\n- **Flexible Input**: Supports both simple numeric arrays and structured data with timestamps\n- **Multiple Forecast Horizons**: Configure prediction length from 1 to 200 time steps\n- **GPU Support**: Automatic GPU detection and utilization when available\n\n## Prerequisites\n\n- Docker installed and running\n- Transformer model files (compatible with Hugging Face transformers library)\n\n## Quick Start\n\n### 1. Download Transformer Model\n\nDownload a pre-trained model from Hugging Face and place it in the `hf_model` directory. For example, use the recommended `timer-base-84m` model from https://huggingface.co/thuml/timer-base-84m:\n\n> **Note**: The `timer-base-84m` model requires at least 96 time points in the input data. When integrating with OpenChatBI, add this restriction to your `extra_tool_use_rule` in bi.yaml:\n> ```\n> - timeseries_forecast tool requires at least 96 time points in input data. If there is not enough input data, set input_len to 96 to pad with zeros.\n> ```\n\n### 2. Build and Run\n\n```bash\ncd timeseries_forecasting\nchmod +x build_and_run.sh\n./build_and_run.sh\n```\n\nThe service will be available at:\n- **Predictions**: `http://localhost:8765/predict`\n- **Health Check**: `http://localhost:8765/health`\n- **API Documentation**: `http://localhost:8765/docs`\n- **Model Info**: `http://localhost:8765/model/info`\n\n### 3. 
Make a Prediction\n\n```bash\ncurl -X POST http://localhost:8765/predict \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"input\": [100, 102, 98, 105, 107, 103, 99, 101, 104, 106],\n    \"input_len\": 100,\n    \"forecast_window\": 5,\n    \"frequency\": \"H\"\n  }'\n```\n\n### 4. Test the Service\n\nRun the comprehensive test suite:\n\n```bash\npython test_forecasting.py --url http://localhost:8765\n```\n\n## API Reference\n\n### Prediction Endpoint\n\n**POST** `/predict`\n\n#### Request Format\n\n```json\n{\n  \"input\": [...],              // Time series data (required)\n  \"forecast_window\": 24,       // Number of future points to predict (default: 24, max: 200)\n  \"frequency\": \"H\",           // Frequency: \"H\" (hourly), \"D\" (daily), etc. (default: \"hourly\")\n  \"input_len\": null,          // Optional input length: input is truncated to its last input_len points or left-padded with zeros (default: null)\n  \"target_column\": \"value\"    // Column name for structured data (default: \"value\")\n}\n```\n\n#### Input Data Formats\n\n**Simple Numeric Array:**\n```json\n{\n  \"input\": [100, 102, 98, 105, 107, 103, 99, 101], \n  \"input_len\": 100,\n  \"forecast_window\": 12\n}\n```\n\n**Structured Data with Timestamps:**\n```json\n{\n  \"input\": [\n    {\"timestamp\": \"2024-01-01T00:00:00\", \"value\": 100},\n    {\"timestamp\": \"2024-01-01T01:00:00\", \"value\": 102},\n    {\"timestamp\": \"2024-01-01T02:00:00\", \"value\": 98}\n  ],\n  \"input_len\": 100,\n  \"forecast_window\": 24,\n  \"target_column\": \"value\"\n}\n```\n\n#### Response Format\n\n```json\n{\n  \"predictions\": [101.5, 103.2, 99.8, ...],\n  \"forecast_window\": 24,\n  \"frequency\": \"H\",\n  \"status\": \"success\"\n}\n```\n\n## Configuration\n\n### Environment Variables\n\n- `PYTHONPATH`: Python path for modules (default: /home/model-server)\n- `PYTHONUNBUFFERED`: Disable Python output buffering (default: 1)\n\n### Docker Run Options\n\n```bash\n# Basic run\ndocker run -p 8765:8765 
timeseries-forecasting\n\n# With volume mount for models\ndocker run -p 8765:8765 \\\n  -v /path/to/model:/home/model-server/hf_model \\\n  timeseries-forecasting\n\n# With custom environment variables\ndocker run -p 8765:8765 \\\n  -e PYTHONPATH=/home/model-server \\\n  timeseries-forecasting\n```\n\n## Testing\n\n### Service Tests\n\nRun the test script to validate the service:\n\n```bash\n# Make test script executable\nchmod +x test_forecasting.py\n\n# Install test dependencies\npip install requests numpy\n\n# Run tests\npython test_forecasting.py --url http://localhost:8765\n```\n\n## Model Information\n\n- **Recommended Models**: https://huggingface.co/thuml/timer-base-84m\n- **Model Type**: Transformer-based Causal Language Model for Time Series\n- **Framework**: Hugging Face Transformers\n- **Architecture**: AutoModelForCausalLM\n- **Device Support**: Automatic GPU/CPU detection\n- **Capabilities**: Univariate time series forecasting with automatic normalization\n\n## Troubleshooting\n\n### Common Issues\n\n1. **Service Not Starting**\n   - Check if port 8765 is available: `lsof -i :8765`\n   - Verify Docker has sufficient memory allocated (minimum 4GB recommended)\n   - Check logs: `docker logs time-series-forecasting-service`\n\n2. **Model Loading Errors**\n   - Ensure model files are properly copied during build\n   - Check available disk space (models can be several GB)\n   - Verify Hugging Face transformers library compatibility\n\n3. 
**Prediction Errors**\n   - Validate input data format\n   - Check that the forecast window is within the supported range (1-200)\n   - Ensure input data has sufficient length\n\n### Debug Mode\n\nEnable debug logging:\n\n```bash\ndocker run -p 8765:8765 \\\n  -e PYTHONPATH=/home/model-server \\\n  -e LOGGING_LEVEL=DEBUG \\\n  timeseries-forecasting\n```\n\n## Performance\n\n- **Cold Start**: ~10 seconds (model loading)\n- **Inference Time**: ~100-300ms per request (varies by input size and model)\n- **Memory Usage**: ~2-4GB (depending on input size and model)\n- **Concurrent Requests**: Supported (configure workers)\n\n## Limitations\n\n- Maximum forecast window: 200 time points\n- Univariate forecasting (single time series)\n- Requires a minimum amount of input data for reliable predictions; `timer-base-84m` needs at least 96 time points\n- Model-specific context length limitations may apply\n"
  },
  {
    "path": "timeseries_forecasting/app.py",
    "content": "\"\"\"app.py: FastAPI application for Transformer time series forecasting.\"\"\"\n\nimport logging\nimport time\nfrom typing import Any\n\nimport uvicorn\nfrom fastapi import FastAPI, HTTPException\nfrom fastapi.responses import JSONResponse\nfrom pydantic import BaseModel, Field\nfrom starlette.requests import Request\n\nfrom model_handler import TransformerModelHandler, get_model_handler\n\n# Configure logging\nlogging.basicConfig(level=logging.INFO)\nlogger = logging.getLogger(__name__)\n\n# Create FastAPI app\napp = FastAPI(\n    title=\"Transformer Time Series Forecasting API\",\n    description=\"A REST API for time series forecasting using Transformer model\",\n    version=\"1.0.0\",\n)\n\n\n# Request models\nclass ForecastRequest(BaseModel):\n    \"\"\"Request model for forecasting.\"\"\"\n\n    input: list[float | int | dict[str, Any]] = Field(\n        ...,\n        description=\"Time series data as list of numbers or structured data\",\n        example=[100, 102, 98, 105, 107, 103, 99, 101],\n    )\n    forecast_window: int = Field(default=24, ge=1, le=200, description=\"Number of future points to predict\")\n    input_len: int | None = Field(default=None, description=\"Optional input length limit\")\n    frequency: str = Field(default=\"hourly\", description=\"Frequency of the time series (hourly, daily, etc.)\")\n    target_column: str = Field(default=\"value\", description=\"Column name for structured data\")\n\n\nclass ForecastResponse(BaseModel):\n    \"\"\"Response model for forecasting.\"\"\"\n\n    predictions: list[float] = Field(description=\"Forecasted values\")\n    forecast_window: int = Field(description=\"Number of predictions\")\n    frequency: str = Field(description=\"Time series frequency\")\n    status: str = Field(description=\"Response status\")\n\n\nclass ErrorResponse(BaseModel):\n    \"\"\"Error response model.\"\"\"\n\n    error: str = Field(description=\"Error message\")\n    status: str = 
Field(description=\"Response status\")\n\n\n# Global variables\nmodel_handler: TransformerModelHandler | None = None\nstartup_time: float | None = None\n\n\n@app.on_event(\"startup\")\nasync def startup_event():\n    \"\"\"Initialize model on startup.\"\"\"\n    global model_handler, startup_time\n    startup_time = time.time()\n    logger.info(\"Starting Transformer Forecasting API...\")\n\n    try:\n        # Initialize model handler\n        model_handler = get_model_handler()\n        model_success = model_handler.initialize()\n\n        if model_success:\n            logger.info(\"Model initialized successfully\")\n        else:\n            logger.error(\"Failed to initialize model\")\n\n    except Exception as e:\n        logger.error(f\"Startup failed: {str(e)}\")\n\n\n@app.get(\"/health\")\nasync def health_check():\n    \"\"\"Health check endpoint.\"\"\"\n    uptime = time.time() - startup_time if startup_time else 0\n\n    return {\n        \"status\": \"healthy\",\n        \"model_initialized\": model_handler.initialized if model_handler else False,\n        \"uptime_seconds\": round(uptime, 2),\n    }\n\n\n@app.get(\"/ping\")\nasync def ping():\n    \"\"\"Simple ping endpoint.\"\"\"\n    return {\"status\": \"ok\"}\n\n\n@app.post(\n    \"/predict\",\n    response_model=ForecastResponse | ErrorResponse,\n    responses={\n        400: {\"model\": ErrorResponse, \"description\": \"Bad Request\"},\n        422: {\"model\": ErrorResponse, \"description\": \"Validation Error\"},\n        500: {\"model\": ErrorResponse, \"description\": \"Internal Error\"},\n    },\n)\nasync def predict(request: ForecastRequest):\n    \"\"\"\n    Main forecasting endpoint.\n\n    Args:\n        request: Forecast request containing time series data and parameters\n\n    Returns:\n        Forecast response with predictions or error\n    \"\"\"\n    try:\n        logger.info(f\"Received prediction request: {len(request.input)} data points, 
horizon={request.forecast_window}\")\n\n        # Check if model is initialized\n        if not model_handler or not model_handler.initialized:\n            raise HTTPException(status_code=500, detail=\"Model not initialized\")\n\n        # Validate input\n        if len(request.input) == 0:\n            raise HTTPException(status_code=400, detail=\"Input data cannot be empty\")\n\n        # Make prediction\n        result = model_handler.predict(\n            time_series_data=request.input,\n            forecast_window=request.forecast_window,\n            input_len=request.input_len,\n            frequency=request.frequency,\n            target_column=request.target_column,\n        )\n\n        # Check if prediction was successful\n        if result.get(\"status\") == \"error\":\n            raise HTTPException(status_code=result.get(\"code\", 500), detail=result.get(\"error\", \"Prediction failed\"))\n\n        logger.info(f\"Prediction successful: {len(result['predictions'])} predictions generated\")\n\n        return ForecastResponse(**result)\n\n    except HTTPException as e:\n        return JSONResponse(status_code=e.status_code, content=ErrorResponse(error=str(e), status=\"error\").model_dump())\n    except Exception as e:\n        logger.error(f\"Prediction error: {str(e)}\")\n        return JSONResponse(status_code=500, content=ErrorResponse(error=str(e), status=\"error\").model_dump())\n\n\n@app.get(\"/model/info\")\nasync def model_info():\n    \"\"\"Get model information.\"\"\"\n    if not model_handler or not model_handler.initialized:\n        return {\"error\": \"Model not initialized\", \"status\": \"error\"}\n\n    return {\n        \"model_path\": model_handler.model_path,\n        \"device\": str(model_handler.device),\n        \"initialized\": model_handler.initialized,\n        \"config\": str(model_handler.config) if model_handler.config else None,\n    }\n\n\n@app.get(\"/\")\nasync def root():\n    \"\"\"Root endpoint with API 
information.\"\"\"\n    return {\n        \"name\": \"Transformer Time Series Forecasting API\",\n        \"version\": \"1.0.0\",\n        \"description\": \"REST API for time series forecasting using Transformer model\",\n        \"endpoints\": {\n            \"predict\": \"/predict\",\n            \"health\": \"/health\",\n            \"ping\": \"/ping\",\n            \"model_info\": \"/model/info\",\n            \"docs\": \"/docs\",\n        },\n    }\n\n\n# Error handlers\n@app.exception_handler(HTTPException)\nasync def http_exception_handler(request: Request, exc: HTTPException):\n    \"\"\"Handle HTTP exceptions.\"\"\"\n    return JSONResponse(\n        status_code=exc.status_code,\n        content={\"status\": \"error\", \"message\": exc.detail},\n    )\n\n\n@app.exception_handler(Exception)\nasync def general_exception_handler(request: Request, exc: Exception):\n    \"\"\"Handle general exceptions.\"\"\"\n    logger.error(f\"Unhandled exception: {str(exc)}\")\n    return JSONResponse(\n        status_code=500,\n        content={\"status\": \"error\", \"message\": \"Internal server error\"},\n    )\n\n\nif __name__ == \"__main__\":\n    # For development\n    uvicorn.run(\"app:app\", host=\"0.0.0.0\", port=8765, reload=True, log_level=\"info\")\n"
  },
  {
    "path": "timeseries_forecasting/build_and_run.sh",
    "content": "#!/bin/bash\n\n# Build and run script for time series forecasting service\nset -e\n\necho \"=== Building Timeseries Forecasting Docker Container ===\"\n\n# Check if the hf_model model directory exists\nMODEL_DIR=\"../hf_model\"\nif [ ! -d \"$MODEL_DIR\" ]; then\n    echo \"Error: Hugging face model directory not found at $MODEL_DIR\"\n    echo \"Please ensure the model is downloaded and available at this location\"\n    exit 1\nfi\n\necho \"✓ Found Hugging face model at: $MODEL_DIR\"\n\nrm -rf hf_model\ncp -r $MODEL_DIR .\n\n# Build the Docker image\necho \"Building Docker image...\"\ndocker build -t timeseries-forecasting .\n\nif [ $? -eq 0 ]; then\n    echo \"✓ Docker image built successfully\"\nelse\n    echo \"✗ Failed to build Docker image\"\n    exit 1\nfi\n\n# Check if container is already running\nCONTAINER_NAME=\"time-series-forecasting-service\"\nif [ \"$(docker ps -q -f name=$CONTAINER_NAME)\" ]; then\n    echo \"Stopping existing container...\"\n    docker stop $CONTAINER_NAME\n    docker rm $CONTAINER_NAME\nfi\n\necho \"=== Starting Timeseries Forecasting Service ===\"\n\n# Run the container\ndocker run -d \\\n    --name $CONTAINER_NAME \\\n    -p 8765:8765 \\\n    timeseries-forecasting\n\nif [ $? 
-eq 0 ]; then\n    echo \"✓ Container started successfully\"\n    echo \"\"\n    echo \"Service endpoints:\"\n    echo \"  - Predictions: http://localhost:8765/predict\"\n    echo \"  - Health Check: http://localhost:8765/health\"\n    echo \"  - API Docs: http://localhost:8765/docs\"\n    echo \"\"\n    echo \"Container logs:\"\n    echo \"  docker logs -f $CONTAINER_NAME\"\n    echo \"\"\n    echo \"To test the service:\"\n    echo \"  python test_forecasting.py\"\n    echo \"\"\n    echo \"To stop the service:\"\n    echo \"  docker stop $CONTAINER_NAME\"\nelse\n    echo \"✗ Failed to start container\"\n    exit 1\nfi\n\n# Wait a moment and check if container is still running\nsleep 5\nif [ \"$(docker ps -q -f name=$CONTAINER_NAME)\" ]; then\n    echo \"✓ Service is running\"\n\n    # Show few logs\n    echo \"\"\n    echo \"=== Initial Service Logs ===\"\n    docker logs \"$CONTAINER_NAME\" | head -n 50\nelse\n    echo \"✗ Service failed to start\"\n    echo \"Checking logs...\"\n    docker logs $CONTAINER_NAME\n    exit 1\nfi"
  },
  {
    "path": "timeseries_forecasting/model_handler.py",
    "content": "\"\"\"model_handler.py: Transformer based model handler for time series forecasting.\"\"\"\n\nimport logging\nfrom typing import Any\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\n# Configure logging\nlogging.basicConfig(level=logging.INFO)\nlogger = logging.getLogger(__name__)\n\n\nclass TransformerModelHandler:\n    \"\"\"\n    Transformer based Model handler for time series forecasting.\n    \"\"\"\n\n    def __init__(self, model_path: str = \"hf_model\"):\n        \"\"\"Initialize the model handler.\"\"\"\n        logger.info(\"Initializing Transformer Model Handler\")\n        self.model_path = model_path\n        self.model = None\n        self.config = None\n        self.device = None\n        self.initialized = False\n\n    def initialize(self) -> bool:\n        \"\"\"\n        Initialize model.\n\n        Returns:\n            bool: True if initialization successful\n        \"\"\"\n        try:\n            logger.info(\"Starting model initialization\")\n\n            # Set device\n            self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n            logger.info(f\"Using device: {self.device}\")\n\n            logger.info(f\"Loading model from: {self.model_path}\")\n\n            # Load model configuration\n            self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)\n\n            # Load the pretrained model\n            self.model = AutoModelForCausalLM.from_pretrained(\n                self.model_path, config=self.config, trust_remote_code=True\n            )\n\n            # Move model to device\n            self.model.to(self.device)\n            self.model.eval()\n\n            self.initialized = True\n            logger.info(\"Transformer model loaded successfully\")\n            logger.info(f\"Model config: {self.config}\")\n\n            return True\n\n        except Exception as e:\n           
 logger.error(f\"Failed to initialize model: {str(e)}\")\n            self.initialized = False\n            return False\n\n    def preprocess(\n        self,\n        time_series_data: list,\n        forecast_window: int = 24,\n        input_len: int | None = None,\n        frequency: str = \"hourly\",\n        target_column: str = \"value\",\n    ) -> tuple[torch.Tensor, dict[str, Any]]:\n        \"\"\"\n        Transform raw input into model input data.\n\n        Args:\n            time_series_data: Input time series data\n            forecast_window: Number of future points to predict\n            input_len: Optional input length limit\n            frequency: Frequency of the time series\n            target_column: Column name for structured data\n\n        Returns:\n            Tuple of (processed_tensor, metadata)\n        \"\"\"\n        try:\n            logger.info(f\"Input data length: {len(time_series_data) if isinstance(time_series_data, list) else 'N/A'}\")\n            logger.info(f\"Forecast window: {forecast_window}\")\n\n            # Convert input to numpy array\n            if isinstance(time_series_data, list):\n                if len(time_series_data) > 0 and isinstance(time_series_data[0], dict):\n                    # Handle structured data (with timestamps)\n                    df = pd.DataFrame(time_series_data)\n                    if target_column in df.columns:\n                        values = df[target_column].values\n                    else:\n                        # Use the first numeric column\n                        numeric_cols = df.select_dtypes(include=[np.number]).columns\n                        if len(numeric_cols) > 0:\n                            values = df[numeric_cols[0]].values\n                        else:\n                            values = np.array([float(x) for x in time_series_data])\n                else:\n                    # Handle simple numeric list\n                    values = np.array([float(x) for 
x in time_series_data])\n            else:\n                values = np.array(time_series_data)\n\n            # Handle input length constraint\n            if input_len is not None:\n                if input_len > len(values):\n                    # Pad with zeros if input is shorter than required\n                    values = np.pad(values, (input_len - len(values), 0), mode=\"constant\", constant_values=0)\n                elif input_len < len(values):\n                    # Take the last input_len values\n                    values = values[-input_len:]\n\n            # Normalize the data (simple z-score normalization)\n            mean_val = np.mean(values)\n            std_val = np.std(values)\n            if std_val > 0:\n                normalized_values = (values - mean_val) / std_val\n            else:\n                normalized_values = values - mean_val\n\n            # Convert to tensor\n            tensor = torch.tensor(normalized_values, dtype=torch.float32).unsqueeze(0)\n            tensor = tensor.to(self.device)\n\n            # Store metadata for post-processing\n            metadata = {\n                \"mean\": mean_val,\n                \"std\": std_val,\n                \"forecast_window\": forecast_window,\n                \"frequency\": frequency,\n                \"original_length\": len(values),\n            }\n\n            logger.info(f\"Preprocessed tensor shape: {tensor.shape}\")\n\n            return tensor, metadata\n\n        except Exception as e:\n            logger.error(f\"Preprocessing failed: {str(e)}\")\n            raise e\n\n    def inference(self, input_tensor: torch.Tensor, metadata: dict[str, Any]) -> torch.Tensor:\n        \"\"\"\n        Run inference on the model.\n\n        Args:\n            input_tensor: Preprocessed input tensor\n            metadata: Preprocessing metadata\n\n        Returns:\n            Model output tensor\n        \"\"\"\n        try:\n            if not self.initialized:\n                
raise RuntimeError(\"Model not initialized\")\n\n            with torch.no_grad():\n                forecast_window = metadata.get(\"forecast_window\", 24)\n\n                # Use generate method\n                output = self.model.generate(input_tensor, max_new_tokens=forecast_window)\n\n                logger.info(f\"Model output shape: {output.shape}\")\n                return output\n\n        except ValueError as e:\n            logger.error(f\"Inference failed due to ValueError: {str(e)}\")\n            raise e\n        except Exception as e:\n            logger.error(f\"Inference failed: {str(e)}\")\n            raise e\n\n    def postprocess(self, output_tensor: torch.Tensor, metadata: dict[str, Any]) -> list[float]:\n        \"\"\"\n        Transform model output to final prediction format.\n\n        Args:\n            output_tensor: Raw model output\n            metadata: Preprocessing metadata\n\n        Returns:\n            Final predictions as list\n        \"\"\"\n        try:\n            # Extract predictions from tensor\n            if output_tensor.dim() > 1:\n                predictions = output_tensor[0].cpu().numpy()\n            else:\n                predictions = output_tensor.cpu().numpy()\n\n            # Denormalize the predictions\n            mean_val = metadata.get(\"mean\", 0)\n            std_val = metadata.get(\"std\", 1)\n\n            if std_val > 0:\n                denormalized_predictions = predictions * std_val + mean_val\n            else:\n                denormalized_predictions = predictions + mean_val\n\n            # Convert to list and ensure it's the right length\n            forecast_window = metadata.get(\"forecast_window\", 24)\n            result = denormalized_predictions[:forecast_window].tolist()\n\n            logger.info(f\"Final predictions length: {len(result)}\")\n\n            return result\n\n        except Exception as e:\n            logger.error(f\"Postprocessing failed: {str(e)}\")\n            
raise e\n\n    def predict(\n        self,\n        time_series_data: list,\n        forecast_window: int = 24,\n        input_len: int | None = None,\n        frequency: str = \"hourly\",\n        target_column: str = \"value\",\n    ) -> dict[str, Any]:\n        \"\"\"\n        Main prediction method.\n\n        Args:\n            time_series_data: Input time series data\n            forecast_window: Number of future points to predict\n            input_len: Optional input length limit\n            frequency: Frequency of the time series\n            target_column: Column name for structured data\n\n        Returns:\n            Dictionary containing predictions and metadata\n        \"\"\"\n        try:\n            # Ensure model is initialized\n            if not self.initialized:\n                if not self.initialize():\n                    raise RuntimeError(\"Failed to initialize model\")\n\n            # Preprocess input\n            input_tensor, metadata = self.preprocess(\n                time_series_data, forecast_window, input_len, frequency, target_column\n            )\n\n            # Run inference\n            output_tensor = self.inference(input_tensor, metadata)\n\n            # Postprocess output\n            predictions = self.postprocess(output_tensor, metadata)\n\n            # Format result\n            result = {\n                \"predictions\": predictions,\n                \"forecast_window\": metadata.get(\"forecast_window\", 24),\n                \"frequency\": metadata.get(\"frequency\", \"hourly\"),\n                \"status\": \"success\",\n            }\n\n            return result\n\n        except ValueError as e:\n            logger.error(f\"Prediction failed due to ValueError: {str(e)}\")\n            return {\"error\": str(e), \"code\": 400, \"status\": \"error\"}\n        except Exception as e:\n            logger.error(f\"Prediction failed: {str(e)}\")\n            return {\"error\": str(e), \"status\": \"error\"}\n\n\n# 
Global model handler instance\n_model_handler = None\n\n\ndef get_model_handler() -> TransformerModelHandler:\n    \"\"\"Get or create global model handler instance.\"\"\"\n    global _model_handler\n    if _model_handler is None:\n        _model_handler = TransformerModelHandler()\n    return _model_handler\n"
  },
  {
    "path": "timeseries_forecasting/test_forecasting.py",
    "content": "#!/usr/bin/env python3\n\"\"\"test_forecasting.py: Test script for Timer forecasting service.\"\"\"\n\nimport time\nfrom datetime import datetime, timedelta\n\nimport numpy as np\nimport requests\nfrom requests.exceptions import RequestException\n\n\nclass TimeseriesForecastingTester:\n    \"\"\"Test class for Timer forecasting service.\"\"\"\n\n    def __init__(self, base_url=\"http://localhost:8765\"):\n        \"\"\"Initialize the tester.\"\"\"\n        self.base_url = base_url\n        self.predictions_endpoint = f\"{base_url}/predict\"\n        self.health_endpoint = f\"{base_url}/health\"\n\n    def generate_sample_data(self, length=100, frequency=\"H\"):\n        \"\"\"Generate sample time series data for testing.\"\"\"\n        # Generate synthetic time series with trend and seasonality\n        t = np.arange(length)\n\n        # Add trend\n        trend = 0.1 * t\n\n        # Add seasonality (daily pattern for hourly data)\n        if frequency == \"H\":\n            seasonality = 5 * np.sin(2 * np.pi * t / 24)\n        else:\n            seasonality = 3 * np.sin(2 * np.pi * t / 7)  # Weekly pattern for daily data\n\n        # Add noise\n        noise = np.random.normal(0, 1, length)\n\n        # Combine components\n        values = 100 + trend + seasonality + noise\n\n        return values.tolist()\n\n    def test_basic_forecasting(self):\n        \"\"\"Test basic time series forecasting.\"\"\"\n        print(\"\\n=== Testing Basic Forecasting ===\")\n\n        # Generate sample data\n        sample_data = self.generate_sample_data(length=168, frequency=\"H\")  # 1 week of hourly data\n\n        # Prepare request payload\n        payload = {\n            \"input\": sample_data,\n            \"forecast_window\": 24,  # Forecast next 24 hours\n            \"frequency\": \"H\",\n            \"input_len\": 168,  # Use last week of data\n        }\n\n        # Send request\n        try:\n            response = requests.post(\n                
self.predictions_endpoint, json=payload, headers={\"Content-Type\": \"application/json\"}, timeout=30\n            )\n\n            if response.status_code == 200:\n                result = response.json()\n                print(\"✓ Basic forecasting successful\")\n                print(f\"  - Input length: {len(sample_data)}\")\n                print(f\"  - Forecast Window: {payload['forecast_window']}\")\n                print(f\"  - Predictions length: {len(result.get('predictions', []))}\")\n                print(f\"  - Sample predictions: {result.get('predictions', [])[:5]}\")\n                return True\n            else:\n                print(f\"✗ Request failed with status: {response.status_code}\")\n                print(f\"  Response: {response.text}\")\n                return False\n\n        except requests.exceptions.RequestException as e:\n            print(f\"✗ Request failed: {str(e)}\")\n            return False\n\n    def test_structured_data(self):\n        \"\"\"Test forecasting with structured data (timestamps + values).\"\"\"\n        print(\"\\n=== Testing Structured Data Forecasting ===\")\n\n        # Generate structured data with timestamps\n        start_time = datetime.now() - timedelta(days=7)\n        structured_data = []\n\n        for i in range(168):  # 1 week of hourly data\n            timestamp = start_time + timedelta(hours=i)\n            value = 100 + 0.1 * i + 5 * np.sin(2 * np.pi * i / 24) + np.random.normal(0, 1)\n\n            structured_data.append({\"timestamp\": timestamp.isoformat(), \"value\": value})\n\n        # Prepare request payload\n        payload = {\n            \"input\": structured_data,\n            \"forecast_window\": 48,  # Forecast next 48 hours\n            \"frequency\": \"H\",\n            \"target_column\": \"value\",\n        }\n\n        # Send request\n        try:\n            response = requests.post(\n                self.predictions_endpoint, json=payload, headers={\"Content-Type\": 
\"application/json\"}, timeout=30\n            )\n\n            if response.status_code == 200:\n                result = response.json()\n                print(\"✓ Structured data forecasting successful\")\n                print(f\"  - Input records: {len(structured_data)}\")\n                print(f\"  - Forecast Window: {payload['forecast_window']}\")\n                print(f\"  - Predictions length: {len(result.get('predictions', []))}\")\n                return True\n            else:\n                print(f\"✗ Request failed with status: {response.status_code}\")\n                print(f\"  Response: {response.text}\")\n                return False\n\n        except requests.exceptions.RequestException as e:\n            print(f\"✗ Request failed: {str(e)}\")\n            return False\n\n    def test_different_windows(self):\n        \"\"\"Test forecasting with different forecast windows.\"\"\"\n        print(\"\\n=== Testing Different Forecast Horizons ===\")\n\n        sample_data = self.generate_sample_data(length=100)\n        windows = [1, 12, 24, 48, 72]\n\n        for window in windows:\n            payload = {\"input\": sample_data, \"forecast_window\": window, \"frequency\": \"H\"}\n\n            try:\n                response = requests.post(\n                    self.predictions_endpoint, json=payload, headers={\"Content-Type\": \"application/json\"}, timeout=30\n                )\n\n                if response.status_code == 200:\n                    result = response.json()\n                    predictions_len = len(result.get(\"predictions\", []))\n                    print(f\"✓ Window {window}: {predictions_len} predictions\")\n                else:\n                    print(f\"✗ Window {window}: Failed with status {response.status_code}\")\n                    return False\n\n            except requests.exceptions.RequestException as e:\n                print(f\"✗ Window {window}: Request failed - {str(e)}\")\n                return False\n  
      return True\n\n    def test_error_handling(self):\n        \"\"\"Test error handling with invalid inputs.\"\"\"\n        print(\"\\n=== Testing Error Handling ===\")\n\n        # Test empty input\n        try:\n            response = requests.post(\n                self.predictions_endpoint, json={\"input\": []}, headers={\"Content-Type\": \"application/json\"}, timeout=10\n            )\n            print(f\"Empty input: Status {response.status_code}\")\n            if response.status_code != 400:\n                print(\"✗ Empty input: Expected 400 status code\")\n                return False\n        except RequestException:\n            print(\"Empty input: exception occurred not expected\")\n            return False\n\n        # Test invalid JSON\n        try:\n            response = requests.post(\n                self.predictions_endpoint, data=\"invalid json\", headers={\"Content-Type\": \"application/json\"}, timeout=10\n            )\n            print(f\"Invalid JSON: Status {response.status_code}\")\n            if response.status_code != 422:\n                print(\"✗ Invalid JSON: Expected 422 status code\")\n                return False\n        except RequestException:\n            print(\"Invalid JSON: exception occurred not expected\")\n            return False\n        return True\n\n    def test_health_check(self):\n        \"\"\"Test service health check.\"\"\"\n        print(\"\\n=== Testing Service Health ===\")\n\n        try:\n            # Test health endpoint\n            response = requests.get(self.health_endpoint, timeout=5)\n\n            if response.status_code == 200:\n                result = response.json()\n                print(\"✓ Service health check passed\")\n                print(f\"  - Model initialized: {result.get('model_initialized', 'Unknown')}\")\n                print(f\"  - Uptime: {result.get('uptime_seconds', 'Unknown')} seconds\")\n                return True\n            else:\n                print(f\"✗ 
Health check failed: {response.status_code}\")\n                return False\n\n        except RequestException as e:\n            print(f\"Health check failed: {str(e)}\")\n            return False\n\n    def run_all_tests(self):\n        \"\"\"Run all tests.\"\"\"\n        print(\"=\" * 50)\n        print(\"TIMER FORECASTING SERVICE TESTS\")\n        print(\"=\" * 50)\n\n        # Wait for service to be ready\n        print(\"Waiting for service to be ready...\")\n        for _i in range(30):  # Wait up to 30 seconds\n            try:\n                response = requests.get(self.health_endpoint, timeout=2)\n                if response.status_code == 200:\n                    result = response.json()\n                    if result.get(\"model_initialized\", False):\n                        print(\"✓ Service is ready\")\n                        break\n            except RequestException:\n                pass\n            time.sleep(1)\n        else:\n            print(\"✗ Service not ready after 30 seconds\")\n            return False\n\n        # Run tests\n        tests = [\n            self.test_health_check,\n            self.test_basic_forecasting,\n            self.test_structured_data,\n            self.test_different_windows,\n            self.test_error_handling,\n        ]\n\n        passed = 0\n        total = len(tests)\n\n        for test in tests:\n            try:\n                if test():\n                    passed += 1\n            except Exception as e:\n                print(f\"✗ Test {test.__name__} failed with exception: {str(e)}\")\n\n        print(\"\\n\" + \"=\" * 50)\n        print(f\"TESTS COMPLETED: {passed}/{total} passed\")\n        print(\"=\" * 50)\n\n        return passed == total\n\n\ndef main():\n    \"\"\"Main function.\"\"\"\n    import argparse\n\n    parser = argparse.ArgumentParser(description=\"Test Timer forecasting service\")\n    parser.add_argument(\n        \"--url\", default=\"http://localhost:8765\", help=\"Base 
URL of the service (default: http://localhost:8765)\"\n    )\n\n    args = parser.parse_args()\n\n    tester = TimeseriesForecastingTester(args.url)\n    success = tester.run_all_tests()\n\n    exit(0 if success else 1)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]