Repository: myscale/ChatData
Branch: main
Commit: fcfc236dc5d4
Files: 52
Total size: 156.1 KB
Directory structure:
gitextract_zniflvil/
├── .github/
│   └── workflows/
│       └── sync-with-huggingface.yml
├── .gitignore
├── LICENSE
├── README.md
├── app/
│   ├── app.py
│   ├── backend/
│   │   ├── __init__.py
│   │   ├── callbacks/
│   │   │   ├── __init__.py
│   │   │   ├── arxiv_callbacks.py
│   │   │   ├── llm_thought_with_table.py
│   │   │   ├── self_query_callbacks.py
│   │   │   └── vector_sql_callbacks.py
│   │   ├── chains/
│   │   │   ├── __init__.py
│   │   │   ├── retrieval_qa_with_sources.py
│   │   │   └── stuff_documents.py
│   │   ├── chat_bot/
│   │   │   ├── __init__.py
│   │   │   ├── chat.py
│   │   │   ├── json_decoder.py
│   │   │   ├── message_converter.py
│   │   │   ├── private_knowledge_base.py
│   │   │   ├── session_manager.py
│   │   │   └── tools.py
│   │   ├── constants/
│   │   │   ├── __init__.py
│   │   │   ├── myscale_tables.py
│   │   │   ├── prompts.py
│   │   │   ├── streamlit_keys.py
│   │   │   └── variables.py
│   │   ├── construct/
│   │   │   ├── __init__.py
│   │   │   ├── build_agents.py
│   │   │   ├── build_all.py
│   │   │   ├── build_chains.py
│   │   │   ├── build_chat_bot.py
│   │   │   ├── build_retriever_tool.py
│   │   │   └── build_retrievers.py
│   │   ├── retrievers/
│   │   │   ├── __init__.py
│   │   │   ├── self_query.py
│   │   │   ├── vector_sql_output_parser.py
│   │   │   └── vector_sql_query.py
│   │   ├── types/
│   │   │   ├── __init__.py
│   │   │   ├── chains_and_retrievers.py
│   │   │   ├── global_config.py
│   │   │   └── table_config.py
│   │   └── vector_store/
│   │       ├── __init__.py
│   │       └── myscale_without_metadata.py
│   ├── logger.py
│   ├── requirements.txt
│   └── ui/
│       ├── __init__.py
│       ├── chat_page.py
│       ├── home.py
│       ├── retrievers.py
│       └── utils.py
└── docs/
    ├── self-query.md
    └── vector-sql.md
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/sync-with-huggingface.yml
================================================
name: Sync with Hugging Face
on:
  push:
    branches:
      - main
    paths:
      - .github/workflows/sync-with-huggingface.yml
      - app/**
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Sync with Hugging Face
        uses: nateraw/huggingface-sync-action@v0.0.5
        with:
          # The github repo you are syncing from. Required.
          github_repo_id: 'myscale/ChatData'
          # The Hugging Face repo id you want to sync to. (ex. 'username/reponame')
          # A repo with this name will be created if it doesn't exist. Required.
          huggingface_repo_id: 'myscale/ChatData'
          # Hugging Face token with write access. Required.
          # Here, we provide a token that we called `HF_TOKEN` when we added the secret to our GitHub repo.
          hf_token: ${{ secrets.HF_TOKEN }}
          # The type of repo you are syncing to: model, dataset, or space.
          # Defaults to space.
          repo_type: 'space'
          # If true and the Hugging Face repo doesn't already exist, it will be created
          # as a private repo.
          #
          # Note: this param has no effect if the repo already exists.
          private: false
          # If repo type is space, specify a space_sdk. One of: streamlit, gradio, or static
          #
          # This option is especially important if the repo has not been created yet.
          # It won't really be used if the repo already exists.
          space_sdk: 'streamlit'
          # If provided, subdirectory will determine which directory of the repo will be synced.
          # By default, this action syncs the entire GitHub repo.
          #
          # An example using this option can be seen here:
          # https://github.com/huggingface/fuego/blob/830ed98/.github/workflows/sync-with-huggingface.yml
          subdirectory: app
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
.idea/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# dataset files
data/
.streamlit/
#*.ipynb
.DS_Store
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 MyScale
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# ChatData 🔍 📖
***We are constantly improving LangChain's self-query retriever. Some of the features are not merged yet.***
[Discord](https://discord.gg/D2qpkqc4Jq)
[Twitter](https://twitter.com/myscaledb)
Yet another chat-with-documents app, but one that supports querying millions of files with [MyScale](https://myscale.com) and [LangChain](https://github.com/hwchase17/langchain/).
## Introduction 📖
### Overview
ChatData is a robust chat-with-documents application designed to extract information and provide answers by querying the MyScale free knowledge base or your uploaded documents.
Powered by the Retrieval Augmented Generation (RAG) framework, ChatData leverages millions of Wikipedia pages and arXiv papers as its external knowledge base, with MyScale managing all data hosting tasks. Simply input your questions in natural language, and ChatData takes care of generating SQL, querying the data, and presenting the results.
Enhancing your chat experience, ChatData introduces three key features. Let's delve into each of them in detail.
#### Feature 1: Retriever Type
MyScale works closely with LangChain, providing the easiest interface for building complex queries with LLMs.
**Self-querying retriever:** MyScale augmented LangChain's Self Querying Retriever so that the LLM can use more data types, for instance timestamps and arrays of strings, to build filters for the query.
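As a rough illustration of what such a filter can look like once it reaches SQL, the sketch below turns a small structured filter into a `WHERE` clause. The filter triples and the `to_where_clause` helper are hypothetical and written only for this example; they are not LangChain's or MyScale's actual translator. The sketch assumes `has(...)` for array membership and `parseDateTime32BestEffort(...)` for timestamp parsing on the SQL side, and borrows column names from the `default.ChatArXiv` table described later in this README.

```python
from typing import List, Tuple

# Hypothetical filter triple: (column, operator, value).
Condition = Tuple[str, str, str]


def to_where_clause(conditions: List[Condition]) -> str:
    """Translate simple filter triples into a SQL WHERE clause (illustrative only)."""
    parts = []
    for column, op, value in conditions:
        # Escape single quotes so the value stays inside its SQL string literal.
        escaped = value.replace("'", "\\'")
        if op == "contain":
            # Array membership, e.g. a category inside a list of strings.
            parts.append(f"has({column}, '{escaped}')")
        elif op in ("gt", "lt"):
            # Timestamp comparison; the value is expected to be an ISO 8601 string.
            sign = ">" if op == "gt" else "<"
            parts.append(f"{column} {sign} parseDateTime32BestEffort('{escaped}')")
        else:
            raise ValueError(f"unsupported operator: {op}")
    return " AND ".join(parts)


where = to_where_clause([
    ("metadata.categories", "contain", "cs.CV"),
    ("metadata.pubdate", "gt", "2023-01-01"),
])
print(where)
```

The point of the sketch is only that richer data types translate into richer filter functions; the real retriever decides which operators and functions are allowed.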
**VectorSQL:** SQL is powerful and can be used to construct complex search queries. Vector Structured Query Language (Vector SQL) is designed to teach LLMs how to query SQL vector databases. Besides the general data types and functions, Vector SQL adds extra functions such as `DISTANCE(column, query_vector)` and `NeuralArray(entity)`, which extend standard SQL for vector search.
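To make the idea concrete, here is a minimal sketch that assembles a Vector SQL statement of the kind an LLM is expected to write. The helper name `build_vector_sql`, the table and column names, and the clause layout are assumptions for illustration; only `DISTANCE(column, query_vector)` and `NeuralArray(entity)` come from the description above. The statement is presumably not executed as-is: something (in this repository, likely the Vector SQL output parser) has to replace `NeuralArray(...)` with a real embedding first.

```python
def build_vector_sql(table: str, vector_column: str, entity: str, top_k: int = 5) -> str:
    """Compose an illustrative Vector SQL statement ordered by vector distance."""
    if top_k <= 0:
        raise ValueError("top_k must be positive")
    return (
        f"SELECT * FROM {table} "
        # NeuralArray(entity) is a placeholder for the embedding of `entity`.
        f"ORDER BY DISTANCE({vector_column}, NeuralArray({entity})) "
        f"LIMIT {top_k}"
    )


sql = build_vector_sql(
    "default.ChatArXiv", "vector", "papers about vision transformers", top_k=3
)
print(sql)
```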
#### Feature 2: Session Management
To enhance your experience and seamlessly continue interactions with existing sessions, ChatData has introduced the Session Management feature. You can easily customize your session ID and modify your prompt to guide ChatData in addressing your queries. With just a few clicks, you can enjoy smooth and personalized session interactions.
#### Feature 3: Building Your Own Knowledge Base
In addition to tapping into ChatData's external knowledge base powered by MyScale for answers, you also have the option to upload your own files and establish a personalized knowledge base. We've implemented the Unstructured API for this purpose, ensuring that only processed texts from your documents are stored, prioritizing your data privacy.
In conclusion, ChatData lets you navigate vast amounts of data and reach precisely what you need. Whether you're a researcher, a student, or a knowledge enthusiast, ChatData empowers you to explore academic papers and research documents like never before. Unlock the true potential of information retrieval with ChatData and discover a world of knowledge at your fingertips.
➡️ Dive in and experience ChatData on [Hugging Face](https://huggingface.co/spaces/myscale/ChatData)🤗

### Data schema
Database credentials:
```toml
MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com"
MYSCALE_PORT = 443
MYSCALE_USER = "chatdata"
MYSCALE_PASSWORD = "myscale_rocks"
```
#### *[NEW]* Table `wiki.Wikipedia`
ChatData also gives you access to Wikipedia, a large knowledge base that contains about 36 million paragraphs from 5 million wiki pages. The knowledge base is a snapshot from 2022-12.
You can query from this table with the public account [here](#data-schema).
```sql
CREATE TABLE wiki.Wikipedia (
-- Record ID
`id` String,
-- Page title to this paragraph
`title` String,
-- Paragraph text
`text` String,
-- Page URL
`url` String,
-- Wiki page ID
`wiki_id` UInt64,
-- View statistics
`views` Float32,
-- Paragraph ID
`paragraph_id` UInt64,
-- Language ID
`langs` UInt32,
-- Feature vector to this paragraph
`emb` Array(Float32),
-- Vector Index
VECTOR INDEX emb_idx emb TYPE MSTG('metric_type=Cosine'),
CONSTRAINT emb_len CHECK length(emb) = 768)
ENGINE = ReplacingMergeTree ORDER BY id SETTINGS index_granularity = 8192
```
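A query against this table could be composed along the following lines. This is a sketch under assumptions: the helper `build_wiki_search_sql` is hypothetical, it relies on MyScale's `distance` function for vector search, and it expects a query embedding produced by the same model that generated `emb` (the model is not named in this section). The 768-dimension check mirrors the `emb_len` constraint in the schema above.

```python
from typing import Sequence

EMBEDDING_DIM = 768  # taken from the `emb_len` CHECK constraint in the schema


def build_wiki_search_sql(query_embedding: Sequence[float], top_k: int = 5) -> str:
    """Build a nearest-neighbour query over wiki.Wikipedia (illustrative only)."""
    if len(query_embedding) != EMBEDDING_DIM:
        raise ValueError(
            f"expected {EMBEDDING_DIM} dimensions, got {len(query_embedding)}"
        )
    # Render the embedding as a SQL array literal.
    vector_literal = "[" + ",".join(f"{x:.6f}" for x in query_embedding) + "]"
    return (
        "SELECT title, text, url, "
        f"distance(emb, {vector_literal}) AS dist "
        "FROM wiki.Wikipedia "
        # With a cosine metric, a smaller distance means a closer match.
        "ORDER BY dist ASC "
        f"LIMIT {int(top_k)}"
    )


# A zero vector stands in for a real embedding here.
sql = build_wiki_search_sql([0.0] * EMBEDDING_DIM, top_k=3)
```

The resulting string can be passed to any client connected with the public credentials listed under Data schema.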
#### Table `default.ChatArXiv`
ChatData brings millions of papers into your knowledge base. We imported 2.2 million papers with their metadata, stored in the following columns:
1. `id`: paper's arXiv ID
2. `abstract`: paper's abstract, used as the ranking criterion (with InstructorXL)
3. `vector`: column that contains the vector array in `Array(Float32)`
4. `metadata`: LangChain VectorStore-compatible columns
    1. `metadata.authors`: paper's authors in *list of strings*
    2. `metadata.abstract`: paper's abstract, used as the ranking criterion (with InstructorXL)
    3. `metadata.titles`: paper's title
    4. `metadata.categories`: paper's categories in *list of strings* like ["cs.CV"]
    5. `metadata.pubdate`: paper's date of publication in *ISO 8601 formatted strings*
    6. `metadata.primary_category`: paper's primary category in *strings* defined by arXiv
    7. `metadata.comment`: some additional comment to the paper

*The columns below are native columns in MyScale and can only be used with SQLDatabase:*

5. `authors`: paper's authors in *list of strings*
6. `titles`: paper's title
7. `categories`: paper's categories in *list of strings* like ["cs.CV"]
8. `pubdate`: paper's date of publication in *Date32 data type* (faster)
9. `primary_category`: paper's primary category in *strings* defined by arXiv
10. `comment`: some additional comment to the paper
For the overall table schema, please refer to the [table creation section in docs/self-query.md](docs/self-query.md#table-creation).
If you want to use this database with `langchain.chains.sql_database.base.SQLDatabaseChain` or `langchain.retrievers.SQLDatabaseRetriever`, please follow the guides in the [data preparation section](docs/vector-sql.md#prepare-the-database) and the [chain creation section](docs/vector-sql.md#create-the-sqldatabasechain) of docs/vector-sql.md.
### Where can I get the arXiv data?
- [From parquet files on S3](docs/self-query.md#insert-data)
- Or directly use the MyScale database as a service... for **FREE** ✨
```python
import clickhouse_connect
client = clickhouse_connect.get_client(
host='msc-950b9f1f.us-east-1.aws.myscale.com',
port=443,
username='chatdata',
password='myscale_rocks'
)
```
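Once connected, the native columns listed above can be filtered with plain SQL. The following is a minimal sketch, not part of the app: `recent_titles` is a hypothetical helper, and the column names (`titles`, `categories`, `pubdate`) are taken from the README's column list rather than from a verified schema. It uses `client.query(..., parameters=...)` and `result_rows` from `clickhouse_connect`, and accepts any object with a compatible `query` method.

```python
from typing import Any, List, Sequence


def recent_titles(client: Any, category: str, limit: int = 5) -> List[Sequence[Any]]:
    """Return (id, titles, pubdate) rows for the newest papers in one category."""
    sql = (
        "SELECT id, titles, pubdate FROM default.ChatArXiv "
        # `categories` is a list of strings, so membership is tested with has().
        "WHERE has(categories, %(category)s) "
        "ORDER BY pubdate DESC "
        "LIMIT %(limit)s"
    )
    # Values are bound by the client instead of being formatted into the SQL.
    result = client.query(sql, parameters={"category": category, "limit": limit})
    return result.result_rows


# Usage with the client created above:
# rows = recent_titles(client, "cs.CV", limit=3)
```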
## Monthly Updates 🔥 (November-2023)
- 🚀 Upload your documents and chat with your own knowledge bases with MyScale!
- 💬 Chat with RAG-enabled agents on both ArXiv and Wikipedia knowledge base!
- 📖 Wikipedia is available as a knowledge base!! Feel FREE 💰 to ask about 36 million paragraphs under 5 million titles! 💫
- 🤖 LLMs are now capable of writing **Vector SQL** - an extended SQL with vector search! Vector SQL allows you to **access MyScale faster and stronger**! This will **be added to LangChain** soon! ([PR 7454](https://github.com/hwchase17/langchain/pull/7454))
- 🌏 Customized Retrieval QA Chain that gives you **more information** on each PDF and **answers questions in your native language**!
- 🔧 Our contribution to LangChain that helps self-query retrievers [**filter with more types and functions**](https://python.langchain.com/docs/modules/data_connection/retrievers/how_to/self_query/myscale_self_query)
- 🌟 **We just opened a FREE pod hosting data for arXiv papers.** Anyone can try their own SQL with vector search!!! Feel the power when SQL meets vector search! See how to access the pod [here](#data-schema).
- 📚 We collected about **2 million papers on arXiv**! We are collecting more and we need your advice!
- More coming...
## How to build your own app from scratch 🧱
### Quickstart
1. Enter directory `app/`
```bash
cd app/
```
2. Create a virtual environment
```bash
python3 -m venv venv
source venv/bin/activate
```
3. Install dependencies
```bash
python3 -m pip install -r requirements.txt
```
4. Run the app!
```bash
# fill in your OpenAI key in .streamlit/secrets.toml
cp .streamlit/secrets.example.toml .streamlit/secrets.toml
# start the app
python3 -m streamlit run app.py
```
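Before starting the app, it can help to check that the secrets file defines everything `app/app.py` reads from `st.secrets`. The sketch below is an optional helper, not part of the repository: the function `missing_secrets` is hypothetical, the key list is copied from `prepare_environment()` in `app/app.py`, and the example values are placeholders.

```python
from typing import List, Mapping

# Keys that prepare_environment() in app/app.py reads without a default.
REQUIRED_SECRETS = [
    "OPENAI_API_BASE",
    "OPENAI_API_KEY",
    "AUTH0_CLIENT_ID",
    "AUTH0_DOMAIN",
    "MYSCALE_USER",
    "MYSCALE_PASSWORD",
    "MYSCALE_HOST",
    "MYSCALE_PORT",
    "UNSTRUCTURED_API",
]


def missing_secrets(secrets: Mapping[str, object]) -> List[str]:
    """Return the required keys that are absent from a secrets mapping."""
    return [key for key in REQUIRED_SECRETS if key not in secrets]


# Placeholder values; a plain dict stands in for the parsed secrets.toml.
example = {
    "OPENAI_API_BASE": "https://api.openai.com/v1",
    "OPENAI_API_KEY": "sk-...",
    "MYSCALE_HOST": "msc-950b9f1f.us-east-1.aws.myscale.com",
    "MYSCALE_PORT": 443,
    "MYSCALE_USER": "chatdata",
    "MYSCALE_PASSWORD": "myscale_rocks",
}
print(missing_secrets(example))
```

`MYSCALE_ENABLE_HTTPS` is left out of the list because `app/app.py` reads it with a default of `True`.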
### With LangChain SQLDatabaseRetrievers
[*Read the full article*](https://myscale.com/blog/teach-your-llm-vector-sql/)
- [Why Vector SQL?](https://myscale.com/blog/teach-your-llm-vector-sql/#automate-the-whole-process-with-sql-and-vector-search)
- [How did LangChain and MyScale convert natural language to structured filters?](https://myscale.com/docs/en/advanced-applications/chatdata/#selfqueryretriever)
- [How to make chain execution more responsive in LangChain?](https://myscale.com/docs/en/advanced-applications/chatdata/#add-callbacks)
### With LangChain Self-Query Retrievers
[*Read the full article*](https://myscale.com/docs/en/advanced-applications/chatdata/)
- [How is this app built?](https://docs.myscale.com/en/advanced-applications/chatdata)
- [What is the overview pipeline?](https://docs.myscale.com/en/advanced-applications/chatdata/#design-the-query-pipeline)
- [How did LangChain and MyScale convert natural language to structured filters?](https://docs.myscale.com/en/advanced-applications/chatdata/#selfqueryretriever)
- [How to make chain execution more responsive in LangChain?](https://docs.myscale.com/en/advanced-applications/chatdata/#add-callbacks)
## Community 🌍
- You are welcome to join our #ChatData channel on [Discord](https://discord.gg/jGCq2yZH) to discuss anything about ChatData.
- Feel free to file an issue or open a PR against this repository.
## Special Thanks 👏 (Ordered Alphabetically)
- [arXiv API](https://info.arxiv.org/help/api/index.html) for its open access interoperability to pre-printed papers.
- [InstructorXL](https://huggingface.co/hkunlp/instructor-xl) for its promptable embeddings that improve retrieval performance.
- [LangChain🦜️🔗](https://github.com/hwchase17/langchain/) for its easy-to-use and composable API designs and prompts.
- [OpenChatPaper](https://github.com/liuyixin-louis/OpenChatPaper) for prompt design reference.
- [The Alexandria Index](https://alex.macrocosm.so/download) for providing arXiv data index to the public.
================================================
FILE: app/app.py
================================================
import os
import time

import streamlit as st

from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STATED, DATA_INITIALIZE_COMPLETED, \
    DATA_INITIALIZE_STARTED
from backend.constants.variables import DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \
    TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config
from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools
from backend.types.global_config import GlobalConfig
from logger import logger
from ui.chat_page import chat_page
from ui.home import render_home
from ui.retrievers import render_retrievers


# warnings.filterwarnings("ignore", category=UserWarning)

def prepare_environment():
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    os.environ["LANGCHAIN_TRACING_V2"] = "false"
    # os.environ["LANGCHAIN_API_KEY"] = ""
    os.environ["OPENAI_API_BASE"] = st.secrets['OPENAI_API_BASE']
    os.environ["OPENAI_API_KEY"] = st.secrets['OPENAI_API_KEY']
    os.environ["AUTH0_CLIENT_ID"] = st.secrets['AUTH0_CLIENT_ID']
    os.environ["AUTH0_DOMAIN"] = st.secrets['AUTH0_DOMAIN']
    update_global_config(GlobalConfig(
        openai_api_base=st.secrets['OPENAI_API_BASE'],
        openai_api_key=st.secrets['OPENAI_API_KEY'],
        auth0_client_id=st.secrets['AUTH0_CLIENT_ID'],
        auth0_domain=st.secrets['AUTH0_DOMAIN'],
        myscale_user=st.secrets['MYSCALE_USER'],
        myscale_password=st.secrets['MYSCALE_PASSWORD'],
        myscale_host=st.secrets['MYSCALE_HOST'],
        myscale_port=st.secrets['MYSCALE_PORT'],
        query_model="gpt-3.5-turbo-0125",
        chat_model="gpt-3.5-turbo-0125",
        untrusted_api=st.secrets['UNSTRUCTURED_API'],
        myscale_enable_https=st.secrets.get('MYSCALE_ENABLE_HTTPS', True),
    ))


# When the browser is refreshed, all session keys are cleared.
def initialize_session_state():
    if DATA_INITIALIZE_STATUS not in st.session_state:
        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_NOT_STATED
        logger.info(f"Initialize session state key: {DATA_INITIALIZE_STATUS}")
    if JUMP_QUERY_ASK not in st.session_state:
        st.session_state[JUMP_QUERY_ASK] = False
        logger.info(f"Initialize session state key: {JUMP_QUERY_ASK}")


def initialize_chat_data():
    if st.session_state[DATA_INITIALIZE_STATUS] != DATA_INITIALIZE_COMPLETED:
        start_time = time.time()
        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_STARTED
        st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models()
        st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers()
        st.session_state[RETRIEVER_TOOLS] = update_retriever_tools()
        # Mark data initialization as finished.
        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED
        end_time = time.time()
        logger.info(f"ChatData initialization finished in {round(end_time - start_time, 3)} seconds, "
                    f"session state keys: {list(st.session_state.keys())}")


st.set_page_config(
    page_title="ChatData",
    page_icon="https://myscale.com/favicon.ico",
    initial_sidebar_state="expanded",
    layout="wide",
)
prepare_environment()
initialize_session_state()
initialize_chat_data()

if USER_NAME in st.session_state:
    chat_page()
else:
    if st.session_state[JUMP_QUERY_ASK]:
        render_retrievers()
    else:
        render_home()
================================================
FILE: app/backend/__init__.py
================================================
================================================
FILE: app/backend/callbacks/__init__.py
================================================
================================================
FILE: app/backend/callbacks/arxiv_callbacks.py
================================================
import json
import textwrap
from typing import Dict, Any, List

from langchain.callbacks.streamlit.streamlit_callback_handler import (
    LLMThought,
    StreamlitCallbackHandler,
)


class LLMThoughtWithKnowledgeBase(LLMThought):
    def on_tool_end(
        self,
        output: str,
        color=None,
        observation_prefix=None,
        llm_prefix=None,
        **kwargs: Any,
    ) -> None:
        try:
            self._container.markdown(
                "\n\n".join(
                    ["### Retrieved Documents:"]
                    + [
                        f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
                        for i, r in enumerate(json.loads(output))
                    ]
                )
            )
        except Exception as e:
            super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)


class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        if self._current_thought is None:
            self._current_thought = LLMThoughtWithKnowledgeBase(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )
        self._current_thought.on_llm_start(serialized, prompts)
================================================
FILE: app/backend/callbacks/llm_thought_with_table.py
================================================
from typing import Any, Dict, List

import streamlit as st
from langchain_core.outputs import LLMResult
from streamlit.external.langchain import StreamlitCallbackHandler


class ChatDataSelfQueryCallBack(StreamlitCallbackHandler):
    def __init__(self):
        super().__init__(st.container())
        self._current_thought = None
        self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery CallBack...")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        self.progress_bar.progress(value=0.35, text="Communicate with LLM...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        if len(kwargs['tags']) == 0:
            self.progress_bar.progress(value=0.75, text="Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        st.markdown("### Generate filter by LLM \n"
                    "> Here we get `query_constructor` results \n\n")
        self.progress_bar.progress(value=0.5, text="Generate filter by LLM...")
        for item in response.generations:
            st.markdown(f"{item[0].text}")
================================================
FILE: app/backend/callbacks/self_query_callbacks.py
================================================
from typing import Dict, Any, List

import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult


class CustomSelfQueryRetrieverCallBackHandler(StreamlitCallbackHandler):
    def __init__(self):
        super().__init__(st.container())
        self._current_thought = None
        self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery...")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        self.progress_bar.progress(value=0.35, text="Communicate with LLM...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        if len(kwargs['tags']) == 0:
            self.progress_bar.progress(value=0.75, text="Searching in DB...")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        st.markdown("### Generate filter by LLM \n"
                    "> Here we get `query_constructor` results \n\n")
        self.progress_bar.progress(value=0.5, text="Generate filter by LLM...")
        for item in response.generations:
            st.markdown(f"{item[0].text}")


class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
    def __init__(self) -> None:
        super().__init__(st.container())
        self.progress_bar = st.progress(value=0.2, text="Executing ChatData SelfQuery Chain...")

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if len(kwargs['tags']) != 0:
            self.progress_bar.progress(value=0.5, text="We got filter info from LLM...")
            st.markdown("### Generate filter by LLM \n"
                        "> Here we get `query_constructor` results \n\n")
            for item in response.generations:
                st.markdown(f"{item[0].text}")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join(serialized["id"])
        if cid.endswith(".CustomStuffDocumentChain"):
            self.progress_bar.progress(value=0.7, text="Asking LLM with related documents...")
================================================
FILE: app/backend/callbacks/vector_sql_callbacks.py
================================================
import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult
from sql_formatter.core import format_sql


class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = 0.2

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ):
        text = response.generations[0][0].text
        if text.replace(" ", "").upper().startswith("SELECT"):
            st.markdown("### Generated Vector Search SQL Statement \n"
                        "> This sql statement is generated by LLM \n\n")
            st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
            self.prog_value += self.prog_interval
            self.progress_bar.progress(
                value=self.prog_value, text="Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join(serialized["id"])
        self.prog_value += self.prog_interval
        self.progress_bar.progress(
            value=self.prog_value, text=f"Running Chain `{cid}`..."
        )

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass


class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler):
    def __init__(self, table: str) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = 0.1
        self.table = table
================================================
FILE: app/backend/chains/__init__.py
================================================
================================================
FILE: app/backend/chains/retrieval_qa_with_sources.py
================================================
import inspect
from typing import Dict, Any, Optional, List

from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.docstore.document import Document

from logger import logger


class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
    """QA with source chain for Chat ArXiv app with references

    This chain will automatically assign reference number to the article,
    Then parse it back to titles or anything else.
    """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        logger.info(f"\033[91m\033[1m{self._chain_type}\033[0m")
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs: List[Document] = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs: List[Document] = self._get_docs(inputs)  # type: ignore[call-arg]

        answer = self.combine_documents_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )
        # parse source with ref_id
        sources = []
        ref_cnt = 1
        for d in docs:
            ref_id = d.metadata['ref_id']
            if f"Doc #{ref_id}" in answer:
                answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}")
            if f"#{ref_id}" in answer:
                title = d.metadata['title'].replace('\n', '')
                d.metadata['ref_id'] = ref_cnt
                answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
                sources.append(d)
                ref_cnt += 1

        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        raise NotImplementedError

    @property
    def _chain_type(self) -> str:
        return "custom_retrieval_qa_with_sources_chain"
================================================
FILE: app/backend/chains/stuff_documents.py
================================================
from typing import Any, List, Tuple
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.docstore.document import Document
from langchain.schema.prompt_template import format_document
class CustomStuffDocumentChain(StuffDocumentsChain):
"""Combine arxiv documents with PDF reference number"""
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
"""Construct inputs from kwargs and docs.
Format and the join all the documents together into one input with name
`self.document_variable_name`. The pluck any additional variables
from **kwargs.
Args:
docs: List of documents to format and then join into single input
**kwargs: additional inputs to chain, will pluck any other required
arguments from here.
Returns:
dictionary of inputs to LLMChain
"""
# Format each document according to the prompt
doc_strings = []
for doc_id, doc in enumerate(docs):
# add temp reference number in metadata
doc.metadata.update({'ref_id': doc_id})
doc.page_content = doc.page_content.replace('\n', ' ')
doc_strings.append(format_document(doc, self.document_prompt))
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in kwargs.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings)
return inputs
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM.
Args:
docs: List of documents to join together into one variable
callbacks: Optional callbacks to pass along
**kwargs: additional parameters to use to get inputs to LLMChain.
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
output = self.llm_chain.predict(callbacks=callbacks, **inputs)
return output, {}
@property
def _chain_type(self) -> str:
return "custom_stuff_document_chain"
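# Illustration only: a minimal, dependency-free sketch of what `_get_inputs` does to the
# documents before they reach the LLM. It assumes each document is a (page_content, metadata)
# pair and that the document prompt behaves like an ordinary `str.format` template.
# `render_docs` and `DOC_TEMPLATE` are hypothetical names, not part of this repository
# or of LangChain.

```python
from typing import Dict, List, Tuple

# Hypothetical template, loosely modelled on the `doc_prompt` templates in this repo.
DOC_TEMPLATE = "Title for Doc #{ref_id}: {title}\n\tcontent: {page_content}"


def render_docs(docs: List[Tuple[str, Dict[str, str]]], separator: str = "\n\n") -> str:
    """Number each document, flatten newlines in its text, and join the results."""
    rendered = []
    for ref_id, (page_content, metadata) in enumerate(docs):
        # The reference number is positional, so it is only stable within one call.
        fields = dict(metadata, ref_id=ref_id, page_content=page_content.replace("\n", " "))
        rendered.append(DOC_TEMPLATE.format(**fields))
    return separator.join(rendered)


if __name__ == "__main__":
    print(render_docs([("first\nparagraph", {"title": "A"}), ("second", {"title": "B"})]))
```

# The `Doc #<n>` labels produced this way are what the combine prompt asks the model to cite.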
================================================
FILE: app/backend/chat_bot/__init__.py
================================================
================================================
FILE: app/backend/chat_bot/chat.py
================================================
import time
from os import environ
from time import sleep
import streamlit as st
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION_MANAGER, \
CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, USER_PRIVATE_FILES, \
EL_BUILD_KB_WITH_FILES, \
EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \
USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \
EL_UPLOAD_FILES_STATUS, EL_SELECTED_KBS, EL_UPLOAD_FILES
from backend.constants.variables import USER_INFO, USER_NAME, JUMP_QUERY_ASK, RETRIEVER_TOOLS
from backend.construct.build_agents import build_agents
from backend.chat_bot.session_manager import SessionManager
from backend.callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
from logger import logger
environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]
TOOL_NAMES = {
"langchain_retriever_tool": "Self-querying retriever",
"vecsql_retriever_tool": "Vector SQL",
}
def on_chat_submit():
with st.session_state.next_round.container():
with st.chat_message("user"):
st.write(st.session_state.chat_input)
with st.chat_message("assistant"):
container = st.container()
st_callback = ChatDataAgentCallBackHandler(
container, collapse_completed_thoughts=False
)
ret = st.session_state.agent(
{"input": st.session_state.chat_input}, callbacks=[st_callback]
)
logger.info(f"ret:{ret}")
def clear_history():
if "agent" in st.session_state:
st.session_state.agent.memory.clear()
def back_to_main():
if USER_INFO in st.session_state:
del st.session_state[USER_INFO]
if USER_NAME in st.session_state:
del st.session_state[USER_NAME]
if JUMP_QUERY_ASK in st.session_state:
del st.session_state[JUMP_QUERY_ASK]
if EL_SESSION_SELECTOR in st.session_state:
del st.session_state[EL_SESSION_SELECTOR]
if CHAT_CURRENT_USER_SESSIONS in st.session_state:
del st.session_state[CHAT_CURRENT_USER_SESSIONS]
def refresh_sessions():
chat_session_manager: SessionManager = st.session_state[CHAT_SESSION_MANAGER]
current_user_name = st.session_state[USER_NAME]
current_user_sessions = chat_session_manager.list_sessions(current_user_name)
if not isinstance(current_user_sessions, list) or not current_user_sessions:
# generate a default session for current user.
chat_session_manager.add_session(
user_id=current_user_name,
session_id=f"{current_user_name}?default",
system_prompt=DEFAULT_SYSTEM_PROMPT,
)
st.session_state[CHAT_CURRENT_USER_SESSIONS] = chat_session_manager.list_sessions(current_user_name)
current_user_sessions = st.session_state[CHAT_CURRENT_USER_SESSIONS]
else:
st.session_state[CHAT_CURRENT_USER_SESSIONS] = current_user_sessions
# load current user files.
st.session_state[USER_PRIVATE_FILES] = st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(
current_user_name
)
# load current user private knowledge bases.
st.session_state[USER_PERSONAL_KNOWLEDGE_BASES] = \
st.session_state[CHAT_KNOWLEDGE_TABLE].list_private_knowledge_bases(current_user_name)
logger.info(f"current user name: {current_user_name}, "
f"user private knowledge bases: {st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]}, "
f"user private files: {st.session_state[USER_PRIVATE_FILES]}")
st.session_state[AVAILABLE_RETRIEVAL_TOOLS] = {
# public retrieval tools
**st.session_state[RETRIEVER_TOOLS],
# private retrieval tools
**st.session_state[CHAT_KNOWLEDGE_TABLE].as_retrieval_tools(current_user_name),
}
logger.info(f"current_user_sessions is {current_user_sessions}")
st.session_state[EL_SESSION_SELECTOR] = current_user_sessions[0]
# process for session add and delete.
def on_session_change_submit():
if "session_manager" in st.session_state and "session_editor" in st.session_state:
try:
for elem in st.session_state.session_editor["added_rows"]:
if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem:
if elem["session_id"] != "" and "?" not in elem["session_id"]:
st.session_state.session_manager.add_session(
user_id=st.session_state.user_name,
session_id=f"{st.session_state.user_name}?{elem['session_id']}",
system_prompt=elem["system_prompt"],
)
else:
st.toast("`session_id` must be non-empty and must not contain the character `?`.", icon="❌")
raise KeyError(
"`session_id` must be non-empty and must not contain the character `?`."
)
else:
st.toast("You should fill both `session_id` and `system_prompt` to add a row!", icon="❌")
raise KeyError(
"You should fill both `session_id` and `system_prompt` to add a row!"
)
for elem in st.session_state.session_editor["deleted_rows"]:
user_name = st.session_state[USER_NAME]
session_id = st.session_state[CHAT_CURRENT_USER_SESSIONS][elem]['session_id']
user_with_session_id = f"{user_name}?{session_id}"
st.session_state.session_manager.remove_session(session_id=user_with_session_id)
st.toast(f"session `{user_with_session_id}` removed.", icon="✅")
refresh_sessions()
except Exception as e:
sleep(2)
st.error(f"{type(e)}: {str(e)}")
finally:
st.session_state.session_editor["added_rows"] = []
st.session_state.session_editor["deleted_rows"] = []
refresh_agent()
def create_private_knowledge_base_as_tool():
current_user_name = st.session_state[USER_NAME]
if (
EL_PERSONAL_KB_NAME in st.session_state
and EL_PERSONAL_KB_DESCRIPTION in st.session_state
and EL_BUILD_KB_WITH_FILES in st.session_state
and len(st.session_state[EL_PERSONAL_KB_NAME]) > 0
and len(st.session_state[EL_PERSONAL_KB_DESCRIPTION]) > 0
and len(st.session_state[EL_BUILD_KB_WITH_FILES]) > 0
):
st.session_state[CHAT_KNOWLEDGE_TABLE].create_private_knowledge_base(
user_id=current_user_name,
tool_name=st.session_state[EL_PERSONAL_KB_NAME],
tool_description=st.session_state[EL_PERSONAL_KB_DESCRIPTION],
files=[f["file_name"] for f in st.session_state[EL_BUILD_KB_WITH_FILES]],
)
refresh_sessions()
else:
st.session_state[EL_UPLOAD_FILES_STATUS].error(
"You should fill all fields to build up a tool!"
)
sleep(2)
def remove_private_knowledge_bases():
if EL_PERSONAL_KB_NEEDS_REMOVE in st.session_state and st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]:
private_knowledge_bases_needs_remove = st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]
private_knowledge_base_names = [item["tool_name"] for item in private_knowledge_bases_needs_remove]
# remove these private knowledge bases.
st.session_state[CHAT_KNOWLEDGE_TABLE].remove_private_knowledge_bases(
user_id=st.session_state[USER_NAME],
private_knowledge_bases=private_knowledge_base_names
)
refresh_sessions()
else:
st.session_state[EL_UPLOAD_FILES_STATUS].error(
"You should specify at least one private knowledge base to delete!"
)
time.sleep(2)
def refresh_agent():
with st.spinner("Initializing session..."):
user_name = st.session_state[USER_NAME]
session_id = st.session_state[EL_SESSION_SELECTOR]['session_id']
user_with_session_id = f"{user_name}?{session_id}"
if EL_SELECTED_KBS in st.session_state:
selected_knowledge_bases = st.session_state[EL_SELECTED_KBS]
else:
selected_knowledge_bases = ["Wikipedia + Vector SQL"]
logger.info(f"selected_knowledge_bases: {selected_knowledge_bases}")
if EL_SESSION_SELECTOR in st.session_state:
system_prompt = st.session_state[EL_SESSION_SELECTOR]["system_prompt"]
else:
system_prompt = DEFAULT_SYSTEM_PROMPT
st.session_state["agent"] = build_agents(
session_id=user_with_session_id,
tool_names=selected_knowledge_bases,
system_prompt=system_prompt
)
def add_file():
user_name = st.session_state[USER_NAME]
if EL_UPLOAD_FILES not in st.session_state or len(st.session_state[EL_UPLOAD_FILES]) == 0:
st.session_state[EL_UPLOAD_FILES_STATUS].error("Please upload files!", icon="⚠️")
sleep(2)
return
try:
st.session_state[EL_UPLOAD_FILES_STATUS].info("Uploading...")
st.session_state[CHAT_KNOWLEDGE_TABLE].add_by_file(
user_id=user_name,
files=st.session_state[EL_UPLOAD_FILES]
)
refresh_sessions()
except ValueError as e:
st.session_state[EL_UPLOAD_FILES_STATUS].error("Failed to upload! " + str(e))
sleep(2)
def clear_files():
st.session_state[CHAT_KNOWLEDGE_TABLE].clear(user_id=st.session_state[USER_NAME])
refresh_sessions()
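# Illustration only: the callbacks above store sessions under ids of the form
# `<user_name>?<session_name>` and reject empty names or names containing `?`.
# A minimal sketch of that convention follows; `compose_session_id` and
# `split_session_id` are hypothetical helpers, not functions defined in this repository.

```python
from typing import Tuple

SEPARATOR = "?"  # same separator the callbacks use between user name and session name


def compose_session_id(user_name: str, session_name: str) -> str:
    """Build the stored id, rejecting names the session editor would also reject."""
    if session_name == "" or SEPARATOR in session_name:
        raise ValueError("session name must be non-empty and must not contain '?'")
    return f"{user_name}{SEPARATOR}{session_name}"


def split_session_id(session_id: str) -> Tuple[str, str]:
    """Recover (user_name, session_name); assumes the user name itself contains no '?'."""
    user_name, session_name = session_id.split(SEPARATOR, 1)
    return user_name, session_name
```

# Splitting on the first separator only keeps the round trip well defined as long as
# user names never contain `?`, which is an assumption of this sketch.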
================================================
FILE: app/backend/chat_bot/json_decoder.py
================================================
import json
import datetime
class CustomJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, datetime.datetime):
return datetime.datetime.isoformat(obj)
return json.JSONEncoder.default(self, obj)
class CustomJSONDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(
self, object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, source):
for k, v in source.items():
if isinstance(v, str):
try:
source[k] = datetime.datetime.fromisoformat(str(v))
except ValueError:
# Not an ISO 8601 timestamp; leave the string unchanged.
pass
return source
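# Illustration only: a self-contained round trip showing how an encoder/decoder pair like
# the one above carries `datetime` values through JSON. `IsoDateTimeEncoder`,
# `revive_datetimes` and `round_trip` are hypothetical stand-ins written for this sketch;
# only the standard library is used.

```python
import datetime
import json


class IsoDateTimeEncoder(json.JSONEncoder):
    """Serialize datetime objects as ISO 8601 strings."""

    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            return obj.isoformat()
        return super().default(obj)


def revive_datetimes(source: dict) -> dict:
    """Object hook: turn every string value that parses as ISO 8601 back into a datetime."""
    for key, value in source.items():
        if isinstance(value, str):
            try:
                source[key] = datetime.datetime.fromisoformat(value)
            except ValueError:
                pass  # ordinary string, keep as-is
    return source


def round_trip(payload: dict) -> dict:
    text = json.dumps(payload, cls=IsoDateTimeEncoder)
    return json.loads(text, object_hook=revive_datetimes)
```

# Caveat of this approach: any string that happens to parse as an ISO timestamp is
# converted, whether or not it was a datetime before encoding.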
================================================
FILE: app/backend/chat_bot/message_converter.py
================================================
import hashlib
import json
import time
from typing import Any
from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage, ChatMessage, FunctionMessage
from langchain.schema.messages import ToolMessage
from sqlalchemy.orm import declarative_base
from backend.chat_bot.tools import create_message_history_table
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
elif _type == "function":
return FunctionMessage(**message["data"])
elif _type == "tool":
return ToolMessage(**message["data"])
elif _type == "AIMessageChunk":
message["data"]["type"] = "ai"
return AIMessage(**message["data"])
else:
raise ValueError(f"Got unexpected message type: {_type}")
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
"""The default message converter for SQLChatMessageHistory."""
def __init__(self, table_name: str):
super().__init__(table_name)
self.model_class = create_message_history_table(table_name, declarative_base())
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
time_stamp = time.time()
msg_id = hashlib.sha256(
f"{session_id}_{message}_{time_stamp}".encode('utf-8')).hexdigest()
user_id, _ = session_id.split("?")
return self.model_class(
id=time_stamp,
msg_id=msg_id,
user_id=user_id,
session_id=session_id,
type=message.type,
addtionals=json.dumps(message.additional_kwargs),
message=json.dumps({
"type": message.type,
"additional_kwargs": {"timestamp": time_stamp},
"data": message.dict()})
)
def from_sql_model(self, sql_message: Any) -> BaseMessage:
msg_dump = json.loads(sql_message.message)
msg = _message_from_dict(msg_dump)
msg.additional_kwargs = msg_dump["additional_kwargs"]
return msg
def get_sql_model_class(self) -> Any:
return self.model_class
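# Illustration only: `to_sql_model` derives `msg_id` by hashing the session id, the message
# and a timestamp. The sketch below isolates that step; `make_msg_id` is a hypothetical
# helper and takes the message as plain text rather than a LangChain message object.

```python
import hashlib
import time
from typing import Optional


def make_msg_id(session_id: str, message_text: str, time_stamp: Optional[float] = None) -> str:
    """Return a SHA-256 hex digest over session id, message text and timestamp."""
    if time_stamp is None:
        time_stamp = time.time()
    raw = f"{session_id}_{message_text}_{time_stamp}"
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()
```

# Including the timestamp means two identical messages in the same session still get
# distinct ids, provided they are not stored at exactly the same `time.time()` value.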
================================================
FILE: app/backend/chat_bot/private_knowledge_base.py
================================================
import hashlib
from datetime import datetime
from typing import List, Optional
import pandas as pd
from clickhouse_connect import get_client
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
from streamlit.runtime.uploaded_file_manager import UploadedFile
from backend.chat_bot.tools import parse_files, extract_embedding
from backend.construct.build_retriever_tool import create_retriever_tool
from logger import logger
class ChatBotKnowledgeTable:
def __init__(self, host, port, username, password,
embedding: Embeddings, parser_api_key: str, db="chat",
kb_table="private_kb", tool_table="private_tool") -> None:
super().__init__()
personal_files_schema_ = f"""
CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
entity_id String,
file_name String,
text String,
user_id String,
created_by DateTime,
vector Array(Float32),
CONSTRAINT cons_vec_len CHECK length(vector) = 768,
VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
) ENGINE = ReplacingMergeTree ORDER BY entity_id
"""
# `tool_name` represents the name of the private knowledge base.
private_knowledge_base_schema_ = f"""
CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
tool_id String,
tool_name String,
file_names Array(String),
user_id String,
created_by DateTime,
tool_description String
) ENGINE = ReplacingMergeTree ORDER BY tool_id
"""
self.personal_files_table = kb_table
self.private_knowledge_base_table = tool_table
config = MyScaleSettings(
host=host,
port=port,
username=username,
password=password,
database=db,
table=kb_table,
)
self.client = get_client(
host=config.host,
port=config.port,
username=config.username,
password=config.password,
)
self.client.command("SET allow_experimental_object_type=1")
self.client.command(personal_files_schema_)
self.client.command(private_knowledge_base_schema_)
self.parser_api_key = parser_api_key
self.vector_store = MyScaleWithoutJSON(
embedding=embedding,
config=config,
must_have_cols=["file_name", "text", "created_by"],
)
# List all files with given `user_id`
def list_files(self, user_id: str):
query = f"""
SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
FROM {self.vector_store.config.database}.{self.personal_files_table}
WHERE user_id = '{user_id}' GROUP BY file_name
"""
return [r for r in self.vector_store.client.query(query).named_results()]
# Parse and embed files
def add_by_file(self, user_id, files: List[UploadedFile]):
data = parse_files(self.parser_api_key, user_id, files)
data = extract_embedding(self.vector_store.embeddings, data)
self.vector_store.client.insert_df(
table=self.personal_files_table,
df=pd.DataFrame(data),
database=self.vector_store.config.database,
)
# Remove all files and private_knowledge_bases with given `user_id`
def clear(self, user_id: str):
self.vector_store.client.command(
f"DELETE FROM {self.vector_store.config.database}.{self.personal_files_table} "
f"WHERE user_id='{user_id}'"
)
query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
WHERE user_id = '{user_id}'"""
self.vector_store.client.command(query)
def create_private_knowledge_base(
self, user_id: str, tool_name: str, tool_description: str, files: Optional[List[str]] = None
):
self.vector_store.client.insert_df(
self.private_knowledge_base_table,
pd.DataFrame(
[
{
"tool_id": hashlib.sha256(
(user_id + tool_name).encode("utf-8")
).hexdigest(),
"tool_name": tool_name, # tool_name represent user's private knowledge base.
"file_names": files,
"user_id": user_id,
"created_by": datetime.now(),
"tool_description": tool_description,
}
]
),
database=self.vector_store.config.database,
)
# Show all private knowledge bases with given `user_id`
def list_private_knowledge_bases(self, user_id: str, private_knowledge_base=None):
extended_where = f"AND tool_name = '{private_knowledge_base}'" if private_knowledge_base else ""
query = f"""
SELECT tool_name, tool_description, length(file_names)
FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
WHERE user_id = '{user_id}' {extended_where}
"""
return [r for r in self.vector_store.client.query(query).named_results()]
def remove_private_knowledge_bases(self, user_id: str, private_knowledge_bases: List[str]):
unique_list = list(set(private_knowledge_bases))
unique_list = ",".join([f"'{t}'" for t in unique_list])
query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
WHERE user_id = '{user_id}' AND tool_name IN [{unique_list}]"""
self.vector_store.client.command(query)
def as_retrieval_tools(self, user_id, tool_name=None):
logger.info(f"building retrieval tools for user_id {user_id}, tool_name {tool_name}")
private_knowledge_bases = self.list_private_knowledge_bases(user_id=user_id, private_knowledge_base=tool_name)
retrievers = {}
for private_kb in private_knowledge_bases:
file_names_sql = f"""
SELECT arrayJoin(file_names) FROM (
SELECT file_names
FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
WHERE user_id = '{user_id}' AND tool_name = '{private_kb["tool_name"]}'
)
"""
logger.info(f"user_id is {user_id}, file_names_sql is {file_names_sql}")
res = self.client.query(file_names_sql)
file_names = []
for line in res.result_rows:
file_names.append(line[0])
file_names = ', '.join(f"'{item}'" for item in file_names)
logger.info(f"user_id is {user_id}, file_names is {file_names}")
retrievers[private_kb["tool_name"]] = create_retriever_tool(
self.vector_store.as_retriever(
search_kwargs={"where_str": f"user_id='{user_id}' AND file_name IN ({file_names})"},
),
tool_name=private_kb["tool_name"],
description=private_kb["tool_description"],
)
return retrievers
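# Illustration only: the queries in this class interpolate `user_id` directly into SQL text,
# so a value containing a quote would break (or alter) the statement. One possible
# alternative is to keep the value out of the SQL string and pass it separately.
# This sketch assumes the `parameters=` argument and the `{name:Type}` server-side binding
# placeholders of clickhouse-connect; `build_list_files_query` is a hypothetical helper,
# and the database and table names are assumed to come from trusted configuration.

```python
from typing import Any, Dict, Tuple


def build_list_files_query(database: str, table: str, user_id: str) -> Tuple[str, Dict[str, Any]]:
    """Return (sql, parameters) with the user id bound as a parameter, not interpolated."""
    query = (
        "SELECT file_name, COUNT(entity_id) AS num_paragraph "
        f"FROM {database}.{table} "
        # Not an f-string: the braces below are a literal placeholder for the driver.
        "WHERE user_id = {user_id:String} GROUP BY file_name"
    )
    return query, {"user_id": user_id}
```

# Assumed usage, given a connected client:
#     sql, params = build_list_files_query("chat", "private_kb", user_id)
#     rows = client.query(sql, parameters=params).named_results()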
================================================
FILE: app/backend/chat_bot/session_manager.py
================================================
import json
from backend.chat_bot.tools import create_session_table, create_message_history_table
from backend.constants.variables import GLOBAL_CONFIG
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
from sqlalchemy import orm, create_engine
from logger import logger
def get_sessions(engine, model_class, user_id):
with orm.sessionmaker(engine)() as session:
result = (
session.query(model_class)
.where(
model_class.user_id == user_id
)
.order_by(model_class.create_by.desc())
)
return result.all()
class SessionManager:
def __init__(
self,
session_state,
host,
port,
username,
password,
db='chat',
session_table='sessions',
msg_table='chat_memory'
) -> None:
if GLOBAL_CONFIG.myscale_enable_https == False:
conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=http'
else:
conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
self.engine = create_engine(conn_str, echo=False)
self.session_model_class = create_session_table(
session_table, declarative_base())
self.session_model_class.metadata.create_all(self.engine)
self.msg_model_class = create_message_history_table(msg_table, declarative_base())
self.msg_model_class.metadata.create_all(self.engine)
self.session_orm = orm.sessionmaker(self.engine)
self.session_state = session_state
def list_sessions(self, user_id: str):
with self.session_orm() as session:
result = (
session.query(self.session_model_class)
.where(
self.session_model_class.user_id == user_id
)
.order_by(self.session_model_class.create_by.desc())
)
sessions = []
for r in result:
sessions.append({
"session_id": r.session_id.split("?")[-1],
"system_prompt": r.system_prompt,
})
return sessions
# Update sys_prompt with given session_id
def modify_system_prompt(self, session_id, sys_prompt):
with self.session_orm() as session:
obj = session.query(self.session_model_class).where(
self.session_model_class.session_id == session_id).first()
if obj:
obj.system_prompt = sys_prompt
session.commit()
else:
logger.warning(f"Session {session_id} not found")
# Add a session(session_id, sys_prompt)
def add_session(self, user_id: str, session_id: str, system_prompt: str, **kwargs):
with self.session_orm() as session:
elem = self.session_model_class(
user_id=user_id, session_id=session_id, system_prompt=system_prompt,
create_by=datetime.now(), additionals=json.dumps(kwargs)
)
session.add(elem)
session.commit()
# Remove a session and related chat history.
def remove_session(self, session_id: str):
with self.session_orm() as session:
# remove session
session.query(self.session_model_class).where(self.session_model_class.session_id == session_id).delete()
# remove related chat history.
session.query(self.msg_model_class).where(self.msg_model_class.session_id == session_id).delete()
session.commit()
================================================
FILE: app/backend/chat_bot/tools.py
================================================
import hashlib
from datetime import datetime
from multiprocessing.pool import ThreadPool
from typing import List
import requests
from clickhouse_sqlalchemy import types, engines
from langchain.schema.embeddings import Embeddings
from sqlalchemy import Column, Text
from streamlit.runtime.uploaded_file_manager import UploadedFile
def parse_files(api_key, user_id, files: List[UploadedFile]):
def parse_file(file: UploadedFile):
headers = {
"accept": "application/json",
"unstructured-api-key": api_key,
}
data = {"strategy": "auto", "ocr_languages": ["eng"]}
file_hash = hashlib.sha256(file.read()).hexdigest()
file_data = {"files": (file.name, file.getvalue(), file.type)}
response = requests.post(
url="https://api.unstructured.io/general/v0/general",
headers=headers,
data=data,
files=file_data
)
json_response = response.json()
if response.status_code != 200:
raise ValueError(str(json_response))
texts = [
{
"text": t["text"],
"file_name": t["metadata"]["filename"],
"entity_id": hashlib.sha256(
(file_hash + t["text"]).encode()
).hexdigest(),
"user_id": user_id,
"created_by": datetime.now(),
}
for t in json_response
if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
]
return texts
with ThreadPool(8) as p:
rows = []
for r in p.imap_unordered(parse_file, files):
rows.extend(r)
return rows
def extract_embedding(embeddings: Embeddings, texts):
if len(texts) > 0:
vectors = embeddings.embed_documents(
[t["text"] for t in texts])
for text, vector in zip(texts, vectors):
text["vector"] = vector
return texts
raise ValueError("No texts extracted!")
def create_message_history_table(table_name: str, base_class):
class Message(base_class):
__tablename__ = table_name
id = Column(types.Float64)
session_id = Column(Text)
user_id = Column(Text)
msg_id = Column(Text, primary_key=True)
type = Column(Text)
# Should be `additions`; a former developer misspelled it.
addtionals = Column(Text)
message = Column(Text)
__table_args__ = (
engines.MergeTree(
partition_by='session_id',
order_by=('id', 'msg_id')
),
{'comment': 'Store Chat History'}
)
return Message
def create_session_table(table_name: str, DynamicBase):
class Session(DynamicBase):
__tablename__ = table_name
user_id = Column(Text)
session_id = Column(Text, primary_key=True)
system_prompt = Column(Text)
# Creation time.
create_by = Column(types.DateTime)
# Should be `additions`; a former developer misspelled it.
additionals = Column(Text)
__table_args__ = (
engines.MergeTree(order_by=session_id),
{'comment': 'Store Session and Prompts'}
)
return Session
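# Illustration only: `parse_files` keeps only `NarrativeText` elements with more than ten
# space-separated words and derives `entity_id` from the file hash plus the paragraph text.
# The sketch below reproduces that filtering without calling the parsing API.
# `select_paragraphs` is a hypothetical helper, and `elements` is assumed to be a list of
# dicts with `type` and `text` keys, as in the response handled above.

```python
import hashlib
from typing import Dict, List


def select_paragraphs(
    file_bytes: bytes, elements: List[Dict[str, str]], min_words: int = 10
) -> List[Dict[str, str]]:
    """Keep narrative paragraphs longer than `min_words` and give each a stable id."""
    file_hash = hashlib.sha256(file_bytes).hexdigest()
    rows = []
    for element in elements:
        text = element["text"]
        if element["type"] == "NarrativeText" and len(text.split(" ")) > min_words:
            rows.append({
                "text": text,
                # Same file content + same paragraph text -> same id on re-upload.
                "entity_id": hashlib.sha256((file_hash + text).encode()).hexdigest(),
            })
    return rows
```

# Because the id is content-derived, re-uploading an unchanged file yields the same ids,
# which fits a table engine that deduplicates rows by its ordering key.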
================================================
FILE: app/backend/constants/__init__.py
================================================
================================================
FILE: app/backend/constants/myscale_tables.py
================================================
from typing import Dict, List
import streamlit as st
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain_community.embeddings import SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate
from backend.types.table_config import TableConfig
def hint_arxiv():
st.markdown("Here we provide some query samples.")
st.markdown("- If you want to search papers with filters")
st.markdown("1. ```What is a Bayesian network? Please use articles published later than Feb 2018 and with more "
"than 2 categories and whose title like `computer` and must have `cs.CV` in its category. ```")
st.markdown("2. ```What is a Bayesian network? Please use articles published later than Feb 2018```")
st.markdown("- If you want to ask questions based on arxiv papers stored in MyScaleDB")
st.markdown("1. ```Did Geoffrey Hinton write a paper about Capsule Neural Networks?```")
st.markdown("2. ```Introduce some applications of GANs published around 2019.```")
st.markdown("3. ```请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些```")
def hint_sql_arxiv():
st.markdown('''```sql
CREATE TABLE default.ChatArXiv (
`abstract` String,
`id` String,
`vector` Array(Float32),
`metadata` Object('JSON'),
`pubdate` DateTime,
`title` String,
`categories` Array(String),
`authors` Array(String),
`comment` String,
`primary_category` String,
VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')
def hint_wiki():
st.markdown("Here we provide some query samples.")
st.markdown("1. ```Which company did Elon Musk found?```")
st.markdown("2. ```What is Iron Gwazi?```")
st.markdown("3. ```苹果的发源地是哪里?```")
st.markdown("4. ```What is a Ring in mathematics?```")
st.markdown("5. ```The producer of Rick and Morty.```")
st.markdown("6. ```How low is the temperature on Pluto?```")
def hint_sql_wiki():
st.markdown('''```sql
CREATE TABLE wiki.Wikipedia (
`id` String,
`title` String,
`text` String,
`url` String,
`wiki_id` UInt64,
`views` Float32,
`paragraph_id` UInt64,
`langs` UInt32,
`emb` Array(Float32),
VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
CONSTRAINT emb_len CHECK length(emb) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')
MYSCALE_TABLES: Dict[str, TableConfig] = {
'Wikipedia': TableConfig(
database="wiki",
table="Wikipedia",
table_contents="Snapshot from Wikipedia for 2022. All in English.",
hint=hint_wiki,
hint_sql=hint_sql_wiki,
# doc_prompt is used by the QA-with-sources chain
doc_prompt=PromptTemplate(
input_variables=["page_content", "url", "title", "ref_id", "views"],
template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"
),
metadata_col_attributes=[
AttributeInfo(name="title", description="title of the wikipedia page", type="string"),
AttributeInfo(name="text", description="paragraph from this wiki page", type="string"),
AttributeInfo(name="views", description="number of views", type="float")
],
must_have_col_names=['id', 'title', 'url', 'text', 'views'],
vector_col_name="emb",
text_col_name="text",
metadata_col_name="metadata",
emb_model=lambda: SentenceTransformerEmbeddings(
model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
),
tool_desc=("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages")
),
'ArXiv Papers': TableConfig(
database="default",
table="ChatArXiv",
table_contents="Snapshot of papers from ArXiv. All in English.",
hint=hint_arxiv,
hint_sql=hint_sql_arxiv,
doc_prompt=PromptTemplate(
input_variables=["page_content", "id", "title", "ref_id", "authors", "pubdate", "categories"],
template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\t"
"Date of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"
),
metadata_col_attributes=[
AttributeInfo(name="pubdate", description="The year the paper is published", type="timestamp"),
AttributeInfo(name="authors", description="List of author names", type="list[string]"),
AttributeInfo(name="title", description="Title of the paper", type="string"),
AttributeInfo(name="categories", description="arxiv categories to this paper", type="list[string]"),
AttributeInfo(name="length(categories)", description="length of arxiv categories to this paper", type="int")
],
must_have_col_names=['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
vector_col_name="vector",
text_col_name="abstract",
metadata_col_name="metadata",
emb_model=lambda: HuggingFaceInstructEmbeddings(
model_name='hkunlp/instructor-xl',
embed_instruction="Represent the question for retrieving supporting scientific papers: "
),
tool_desc=(
"search_among_scientific_papers",
"Searches among scientific papers from ArXiv and returns research papers"
)
)
}
ALL_TABLE_NAME: List[str] = [config.table for config in MYSCALE_TABLES.values()]
================================================
FILE: app/backend/constants/prompts.py
================================================
from langchain.prompts import ChatPromptTemplate, \
SystemMessagePromptTemplate, HumanMessagePromptTemplate
DEFAULT_SYSTEM_PROMPT = (
"Do your best to answer the questions. "
"Feel free to use any tools available to look up "
"relevant information. Please keep all details in query "
"when calling search functions."
)
COMBINE_PROMPT_TEMPLATE = (
"You are a helpful document assistant. "
"Your task is to provide information and answer any questions related to documents given below. "
"You should use the sections, title and abstract of the selected documents as your source of information "
"and try to provide concise and accurate answers to any questions asked by the user. "
"If you are unable to find relevant information in the given sections, "
"you will need to let the user know that the source does not contain relevant information but still try to "
"provide an answer based on your general knowledge. You must refer to the corresponding section name and page "
"that you refer to when answering. "
"The following is the related information about the document that will help you answer users' questions, "
"you MUST answer it using question's language:\n\n {summaries} "
"Now you should answer user's question. Remember you must use `Doc #` to refer papers:\n\n"
)
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
string_messages=[(SystemMessagePromptTemplate, COMBINE_PROMPT_TEMPLATE),
(HumanMessagePromptTemplate, '{question}')])
MYSCALE_PROMPT = """
You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
MyScale queries have a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
When the query is asking for {top_k} closest row, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows.
*NOTICE*: `DISTANCE(column, array)` only accepts an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user-defined function called `NeuralArray(entity)` to retrieve the entity's array.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MyScale. You should only order according to the distance function.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema.
Pay attention to the data type when using functions. Always use `AND` to connect conditions in `WHERE` and never use comma.
Make sure you never write an isolated `WHERE` keyword and never use undesired conditions to constrain the query.
Use the following format:
======== table info ========
Question: "Question here"
SQLQuery: "SQL Query to run"
Here are some examples:
======== table info ========
CREATE TABLE "ChatPaper" (
abstract String,
id String,
vector Array(Float32),
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id
Question: What is Feature Pyramid Network?
SQLQuery: SELECT ChatPaper.abstract, ChatPaper.id FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(Feature Pyramid Network)) LIMIT {top_k}
======== table info ========
CREATE TABLE "ChatPaper" (
abstract String,
id String,
vector Array(Float32),
categories Array(String),
pubdate DateTime,
title String,
authors Array(String),
primary_category String
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id
Question: What is PaperRank? What is the contribution of those works? Use paper with more than 2 categories.
SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper WHERE length(categories) > 2 ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}
======== table info ========
CREATE TABLE "ChatArXiv" (
primary_category String,
categories Array(String),
pubdate DateTime,
abstract String,
title String,
paper_id String,
vector Array(Float32),
authors Array(String),
) ENGINE = MergeTree()
ORDER BY paper_id
PRIMARY KEY paper_id
Question: Did Geoffrey Hinton write about Capsule Neural Networks? Please use articles published later than 2021.
SQLQuery: SELECT ChatArXiv.title, ChatArXiv.paper_id, ChatArXiv.authors FROM ChatArXiv WHERE has(authors, 'Geoffrey Hinton') AND pubdate > parseDateTimeBestEffort('2021-01-01') ORDER BY DISTANCE(vector, NeuralArray(Capsule Neural Networks)) LIMIT {top_k}
======== table info ========
CREATE TABLE "PaperDatabase" (
abstract String,
categories Array(String),
vector Array(Float32),
pubdate DateTime,
id String,
comments String,
title String,
authors Array(String),
primary_category String
) ENGINE = MergeTree()
ORDER BY id
PRIMARY KEY id
Question: Find papers whose abstract has Mutual Information in it.
SQLQuery: SELECT PaperDatabase.title, PaperDatabase.id FROM PaperDatabase WHERE abstract ILIKE '%Mutual Information%' ORDER BY DISTANCE(vector, NeuralArray(Mutual Information)) LIMIT {top_k}
Let's begin:
======== table info ========
{table_info}
Question: {input}
SQLQuery: """
================================================
FILE: app/backend/constants/streamlit_keys.py
================================================
DATA_INITIALIZE_NOT_STATED = "data_initialize_not_started"
DATA_INITIALIZE_STARTED = "data_initialize_started"
DATA_INITIALIZE_COMPLETED = "data_initialize_completed"
CHAT_SESSION = "sel_sess"
CHAT_KNOWLEDGE_TABLE = "private_kb"
CHAT_SESSION_MANAGER = "session_manager"
CHAT_CURRENT_USER_SESSIONS = "current_sessions"
EL_SESSION_SELECTOR = "el_session_selector"
# all personal knowledge bases under a specific user.
USER_PERSONAL_KNOWLEDGE_BASES = "user_tools"
# all personal files under a specific user.
USER_PRIVATE_FILES = "user_files"
# public and personal knowledge bases.
AVAILABLE_RETRIEVAL_TOOLS = "tools_with_users"
EL_PERSONAL_KB_NEEDS_REMOVE = "el_personal_kb_needs_remove"
# files that need to be uploaded
EL_UPLOAD_FILES = "el_upload_files"
EL_UPLOAD_FILES_STATUS = "el_upload_files_status"
# use these files to build private knowledge base
EL_BUILD_KB_WITH_FILES = "el_build_kb_with_files"
# build a personal kb, given name.
EL_PERSONAL_KB_NAME = "el_personal_kb_name"
# build a personal kb, given description.
EL_PERSONAL_KB_DESCRIPTION = "el_personal_kb_description"
# knowledge bases selected by user.
EL_SELECTED_KBS = "el_selected_kbs"
================================================
FILE: app/backend/constants/variables.py
================================================
from backend.types.global_config import GlobalConfig
# ***** str variables ***** #
EMBEDDING_MODEL_PREFIX = "embedding_model"
CHAINS_RETRIEVERS_MAPPING = "sel_map_obj"
LANGCHAIN_RETRIEVER = "langchain_retriever"
VECTOR_SQL_RETRIEVER = "vecsql_retriever"
TABLE_EMBEDDINGS_MAPPING = "embeddings"
RETRIEVER_TOOLS = "tools"
DATA_INITIALIZE_STATUS = "data_initialized"
UI_INITIALIZED = "ui_initialized"
JUMP_QUERY_ASK = "jump_query_ask"
USER_NAME = "user_name"
USER_INFO = "user_info"
DIVIDER_HTML = """
"""
DIVIDER_THIN_HTML = """
"""
class RetrieverButtons:
    vector_sql_query_from_db = "vector_sql_query_from_db"
    vector_sql_query_with_llm = "vector_sql_query_with_llm"
    self_query_from_db = "self_query_from_db"
    self_query_with_llm = "self_query_with_llm"


GLOBAL_CONFIG = GlobalConfig()


def update_global_config(new_config: GlobalConfig):
    global GLOBAL_CONFIG
    GLOBAL_CONFIG.openai_api_base = new_config.openai_api_base
    GLOBAL_CONFIG.openai_api_key = new_config.openai_api_key
    GLOBAL_CONFIG.auth0_client_id = new_config.auth0_client_id
    GLOBAL_CONFIG.auth0_domain = new_config.auth0_domain
    GLOBAL_CONFIG.myscale_user = new_config.myscale_user
    GLOBAL_CONFIG.myscale_password = new_config.myscale_password
    GLOBAL_CONFIG.myscale_host = new_config.myscale_host
    GLOBAL_CONFIG.myscale_port = new_config.myscale_port
    GLOBAL_CONFIG.query_model = new_config.query_model
    GLOBAL_CONFIG.chat_model = new_config.chat_model
    GLOBAL_CONFIG.untrusted_api = new_config.untrusted_api
    GLOBAL_CONFIG.myscale_enable_https = new_config.myscale_enable_https
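
# Illustrative sketch, not part of the original module: the field-by-field copy in
# `update_global_config` could also be written as a loop over `dataclasses.fields`, so a
# field added to the config later is copied without touching the updater. `DemoConfig`
# and `copy_config` are hypothetical names used only for this example.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class DemoConfig:  # hypothetical stand-in for GlobalConfig
    myscale_host: Optional[str] = ""
    myscale_port: Optional[int] = 443
    myscale_enable_https: Optional[bool] = True


def copy_config(target: DemoConfig, source: DemoConfig) -> None:
    """Copy every dataclass field from `source` onto `target`, in place."""
    for field in fields(source):
        setattr(target, field.name, getattr(source, field.name))


shared = DemoConfig()
copy_config(
    shared,
    DemoConfig(myscale_host="localhost", myscale_port=8123, myscale_enable_https=False),
)
print(shared)
```

# Mutating the shared instance in place (rather than rebinding the name) keeps modules
# that did `from ... import GLOBAL_CONFIG` pointing at the updated object.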
================================================
FILE: app/backend/construct/__init__.py
================================================
================================================
FILE: app/backend/construct/build_agents.py
================================================
import os
from typing import Sequence, List
import streamlit as st
from langchain.agents import AgentExecutor
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from backend.chat_bot.message_converter import DefaultClickhouseMessageConverter
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
from backend.constants.streamlit_keys import AVAILABLE_RETRIEVAL_TOOLS
from backend.constants.variables import GLOBAL_CONFIG, RETRIEVER_TOOLS
from logger import logger
try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import MessagesPlaceholder
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.schema.messages import SystemMessage
from langchain.memory import SQLChatMessageHistory


def create_agent_executor(
        agent_name: str,
        session_id: str,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        system_prompt: str,
        **kwargs
) -> AgentExecutor:
    agent_name = agent_name.replace(" ", "_")
    conn_str = f'clickhouse://{os.environ["MYSCALE_USER"]}:{os.environ["MYSCALE_PASSWORD"]}@{os.environ["MYSCALE_HOST"]}:{os.environ["MYSCALE_PORT"]}'
    # Pick the wire protocol once instead of duplicating the whole connection string.
    protocol = "http" if GLOBAL_CONFIG.myscale_enable_https == False else "https"
    chat_memory = SQLChatMessageHistory(
        session_id,
        connection_string=f'{conn_str}/chat?protocol={protocol}',
        custom_message_converter=DefaultClickhouseMessageConverter(agent_name))
    memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=system_prompt),
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
        **kwargs
    )


def build_agents(
        session_id: str,
        tool_names: List[str],
        model: str = "gpt-3.5-turbo-0125",
        temperature: float = 0.6,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT
):
    chat_llm = ChatOpenAI(
        model_name=model,
        temperature=temperature,
        base_url=GLOBAL_CONFIG.openai_api_base,
        api_key=GLOBAL_CONFIG.openai_api_key,
        streaming=True
    )
    # Prefer the per-user tool set; fall back to the public retriever tools.
    tools = st.session_state.get(AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS))
    selected_tools = [tools[k] for k in tool_names]
    logger.info(f"create agent, use tools: {selected_tools}")
    agent = create_agent_executor(
        agent_name="chat_memory",
        session_id=session_id,
        llm=chat_llm,
        tools=selected_tools,
        system_prompt=system_prompt
    )
    return agent
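
# Illustrative sketch, not part of the original module: one way to assemble the ClickHouse
# connection string used for chat memory. `build_chat_connection_string` is a hypothetical
# helper. URL-encoding the credentials with `quote_plus` is an added assumption (the
# original interpolates them as-is); it keeps characters such as "@" in a password from
# being read as part of the URL structure.

```python
from urllib.parse import quote_plus


def build_chat_connection_string(
    user: str,
    password: str,
    host: str,
    port: int,
    enable_https: bool = True,
    database: str = "chat",
) -> str:
    """Return a clickhouse:// URL with an explicit http/https protocol flag."""
    protocol = "https" if enable_https else "http"
    return (
        f"clickhouse://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}:{port}/{database}?protocol={protocol}"
    )


url = build_chat_connection_string("demo", "p@ss", "localhost", 8123, enable_https=False)
print(url)
```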
================================================
FILE: app/backend/construct/build_all.py
================================================
from logger import logger
from typing import Dict, Any, Union
import streamlit as st
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING
from backend.construct.build_chains import build_retrieval_qa_with_sources_chain
from backend.construct.build_retriever_tool import create_retriever_tool
from backend.construct.build_retrievers import build_self_query_retriever, build_vector_sql_db_chain_retriever
from backend.types.chains_and_retrievers import ChainsAndRetrievers, MetadataColumn
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, \
SentenceTransformerEmbeddings


@st.cache_resource
def load_embedding_model_for_table(table_name: str) -> \
        Union[SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings]:
    with st.spinner(f"Loading embedding models for [{table_name}] ..."):
        embeddings = MYSCALE_TABLES[table_name].emb_model()
    return embeddings


@st.cache_resource
def load_embedding_models() -> Dict[str, Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings]]:
    embedding_models = {}
    for table in MYSCALE_TABLES:
        embedding_models[table] = load_embedding_model_for_table(table)
    return embedding_models


@st.cache_resource
def update_retriever_tools():
    retrievers_tools = {}
    for table in MYSCALE_TABLES:
        logger.info(f"Updating retriever tools [Self Querying, Vector SQL] for table {table}")
        retrievers_tools.update(
            {
                f"{table} + Self Querying": create_retriever_tool(
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["retriever"],
                    *MYSCALE_TABLES[table].tool_desc
                ),
                f"{table} + Vector SQL": create_retriever_tool(
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["sql_retriever"],
                    *MYSCALE_TABLES[table].tool_desc
                ),
            })
    return retrievers_tools


@st.cache_resource
def build_chains_retriever_for_table(table_name: str) -> ChainsAndRetrievers:
    metadata_col_attributes = MYSCALE_TABLES[table_name].metadata_col_attributes

    self_query_retriever = build_self_query_retriever(table_name)
    self_query_chain = build_retrieval_qa_with_sources_chain(
        table_name=table_name,
        retriever=self_query_retriever,
        chain_name="Self Query Retriever"
    )

    vector_sql_retriever = build_vector_sql_db_chain_retriever(table_name)
    vector_sql_chain = build_retrieval_qa_with_sources_chain(
        table_name=table_name,
        retriever=vector_sql_retriever,
        chain_name="Vector SQL DB Retriever"
    )

    metadata_columns = [
        MetadataColumn(
            name=attribute.name,
            desc=attribute.description,
            type=attribute.type
        )
        for attribute in metadata_col_attributes
    ]
    return ChainsAndRetrievers(
        metadata_columns=metadata_columns,
        # for self query
        retriever=self_query_retriever,
        chain=self_query_chain,
        # for vector sql
        sql_retriever=vector_sql_retriever,
        sql_chain=vector_sql_chain
    )


@st.cache_resource
def build_chains_and_retrievers() -> Dict[str, Dict[str, Any]]:
    chains_and_retrievers = {}
    for table in MYSCALE_TABLES:
        logger.info(f"Building chains, retrievers for table {table}")
        chains_and_retrievers[table] = build_chains_retriever_for_table(table).to_dict()
    return chains_and_retrievers
================================================
FILE: app/backend/construct/build_chains.py
================================================
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.schema import BaseRetriever
import streamlit as st
from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.chains.stuff_documents import CustomStuffDocumentChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.prompts import COMBINE_PROMPT
from backend.constants.variables import GLOBAL_CONFIG


def build_retrieval_qa_with_sources_chain(
        table_name: str,
        retriever: BaseRetriever,
        chain_name: str = ""
) -> CustomRetrievalQAWithSourcesChain:
    with st.spinner(f'Building QA source chain named `{chain_name}` for MyScaleDB/{table_name} ...'):
        # Assign ref_id for documents
        custom_stuff_document_chain = CustomStuffDocumentChain(
            llm_chain=LLMChain(
                prompt=COMBINE_PROMPT,
                llm=ChatOpenAI(
                    model_name=GLOBAL_CONFIG.chat_model,
                    openai_api_key=GLOBAL_CONFIG.openai_api_key,
                    temperature=0.6
                ),
            ),
            document_prompt=MYSCALE_TABLES[table_name].doc_prompt,
            document_variable_name="summaries",
        )
        chain = CustomRetrievalQAWithSourcesChain(
            retriever=retriever,
            combine_documents_chain=custom_stuff_document_chain,
            return_source_documents=True,
            max_tokens_limit=12000,
        )
    return chain
================================================
FILE: app/backend/construct/build_chat_bot.py
================================================
from backend.chat_bot.private_knowledge_base import ChatBotKnowledgeTable
from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION, CHAT_SESSION_MANAGER
import streamlit as st
from backend.constants.variables import GLOBAL_CONFIG, TABLE_EMBEDDINGS_MAPPING
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
from backend.chat_bot.session_manager import SessionManager


def build_chat_knowledge_table():
    if CHAT_KNOWLEDGE_TABLE not in st.session_state:
        st.session_state[CHAT_KNOWLEDGE_TABLE] = ChatBotKnowledgeTable(
            host=GLOBAL_CONFIG.myscale_host,
            port=GLOBAL_CONFIG.myscale_port,
            username=GLOBAL_CONFIG.myscale_user,
            password=GLOBAL_CONFIG.myscale_password,
            # embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["Wikipedia"],
            embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["ArXiv Papers"],
            parser_api_key=GLOBAL_CONFIG.untrusted_api,
        )


def initialize_session_manager():
    if CHAT_SESSION not in st.session_state:
        st.session_state[CHAT_SESSION] = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }
    if CHAT_SESSION_MANAGER not in st.session_state:
        st.session_state[CHAT_SESSION_MANAGER] = SessionManager(
            st.session_state,
            host=GLOBAL_CONFIG.myscale_host,
            port=GLOBAL_CONFIG.myscale_port,
            username=GLOBAL_CONFIG.myscale_user,
            password=GLOBAL_CONFIG.myscale_password,
        )
================================================
FILE: app/backend/construct/build_retriever_tool.py
================================================
import json
from typing import List
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema import BaseRetriever, Document
from langchain.tools import Tool
from backend.chat_bot.json_decoder import CustomJSONEncoder


class RetrieverInput(BaseModel):
    query: str = Field(description="query to look up in retriever")


def create_retriever_tool(
        retriever: BaseRetriever,
        tool_name: str,
        description: str
) -> Tool:
    """Create a tool to do retrieval of documents.

    Args:
        retriever: The retriever to use for the retrieval
        tool_name: The name for the tool. This will be passed to the language model,
            so should be unique and somewhat descriptive.
        description: The description for the tool. This will be passed to the language
            model, so should be descriptive.

    Returns:
        Tool class to pass to an agent
    """

    def wrap(func):
        # Serialize the retrieved documents to a JSON string for the agent.
        def wrapped_retrieve(*args, **kwargs):
            docs: List[Document] = func(*args, **kwargs)
            return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)

        return wrapped_retrieve

    return Tool(
        name=tool_name,
        description=description,
        func=wrap(retriever.get_relevant_documents),
        coroutine=retriever.aget_relevant_documents,
        args_schema=RetrieverInput,
    )
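
# Illustrative sketch, not part of the original module: the `wrap` pattern above turns a
# function returning documents into one returning a JSON string. `DemoDocument` and
# `DemoJSONEncoder` are hypothetical stand-ins for langchain's `Document` and the project's
# `CustomJSONEncoder`. The datetime handling shown is an assumption about what such an
# encoder might need to do, not a description of the real one.

```python
import json
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any, Callable, Dict, List


@dataclass
class DemoDocument:  # hypothetical stand-in for langchain's Document
    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


class DemoJSONEncoder(json.JSONEncoder):  # hypothetical stand-in for CustomJSONEncoder
    def default(self, o):
        # json cannot serialize datetime natively, so emit ISO-8601 text instead.
        if isinstance(o, datetime):
            return o.isoformat()
        return super().default(o)


def wrap(func: Callable[..., List[DemoDocument]]) -> Callable[..., str]:
    def wrapped_retrieve(*args, **kwargs) -> str:
        docs = func(*args, **kwargs)
        return json.dumps([asdict(d) for d in docs], cls=DemoJSONEncoder)

    return wrapped_retrieve


def fake_retrieve(query: str) -> List[DemoDocument]:
    # Pretend retriever returning a single document with a datetime in its metadata.
    return [DemoDocument(page_content=f"result for {query}",
                         metadata={"pubdate": datetime(2023, 1, 1)})]


payload = wrap(fake_retrieve)("capsule networks")
print(payload)
```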
================================================
FILE: app/backend/construct/build_retrievers.py
================================================
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.utilities.sql_database import SQLDatabase
from langchain.vectorstores import MyScaleSettings
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
from sqlalchemy import create_engine, MetaData
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.prompts import MYSCALE_PROMPT
from backend.constants.variables import TABLE_EMBEDDINGS_MAPPING, GLOBAL_CONFIG
from backend.retrievers.vector_sql_output_parser import VectorSQLRetrieveOutputParser
from backend.vector_store.myscale_without_metadata import MyScaleWithoutMetadataJson
from logger import logger


@st.cache_resource
def build_self_query_retriever(table_name: str) -> SelfQueryRetriever:
    with st.spinner(f"Building VectorStore for MyScaleDB/{table_name} ..."):
        myscale_connection = {
            "host": GLOBAL_CONFIG.myscale_host,
            "port": GLOBAL_CONFIG.myscale_port,
            "username": GLOBAL_CONFIG.myscale_user,
            "password": GLOBAL_CONFIG.myscale_password,
        }
        myscale_settings = MyScaleSettings(
            **myscale_connection,
            database=MYSCALE_TABLES[table_name].database,
            table=MYSCALE_TABLES[table_name].table,
            column_map={
                "id": "id",
                "text": MYSCALE_TABLES[table_name].text_col_name,
                "vector": MYSCALE_TABLES[table_name].vector_col_name,
                # TODO refine MyScaleDB metadata in langchain.
                "metadata": MYSCALE_TABLES[table_name].metadata_col_name
            }
        )
        myscale_vector_store = MyScaleWithoutMetadataJson(
            embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name],
            config=myscale_settings,
            must_have_cols=MYSCALE_TABLES[table_name].must_have_col_names
        )

    with st.spinner(f"Building SelfQueryRetriever for MyScaleDB/{table_name} ..."):
        retriever: SelfQueryRetriever = SelfQueryRetriever.from_llm(
            llm=ChatOpenAI(
                model_name=GLOBAL_CONFIG.query_model,
                base_url=GLOBAL_CONFIG.openai_api_base,
                api_key=GLOBAL_CONFIG.openai_api_key,
                temperature=0
            ),
            vectorstore=myscale_vector_store,
            document_contents=MYSCALE_TABLES[table_name].table_contents,
            metadata_field_info=MYSCALE_TABLES[table_name].metadata_col_attributes,
            use_original_query=False,
            structured_query_translator=MyScaleTranslator()
        )
    return retriever


@st.cache_resource
def build_vector_sql_db_chain_retriever(table_name: str) -> VectorSQLDatabaseChainRetriever:
    """Get a group of relevant docs from MyScaleDB"""
    with st.spinner(f'Building Vector SQL Database Retriever for MyScaleDB/{table_name}...'):
        # The two connection URLs differ only in the protocol flag.
        protocol = "http" if GLOBAL_CONFIG.myscale_enable_https == False else "https"
        engine = create_engine(
            f'clickhouse://{GLOBAL_CONFIG.myscale_user}:{GLOBAL_CONFIG.myscale_password}@'
            f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}'
            f'/{MYSCALE_TABLES[table_name].database}?protocol={protocol}'
        )
        metadata = MetaData(bind=engine)
        logger.info(f"{table_name} metadata is : {metadata}")
        prompt = PromptTemplate(
            input_variables=["input", "table_info", "top_k"],
            template=MYSCALE_PROMPT,
        )
        # The custom output parser rewrites the generated SQL so that custom columns can be queried.
        output_parser = VectorSQLRetrieveOutputParser.from_embeddings(
            model=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name],
            # Columns that the rewritten SELECT must return.
            must_have_columns=MYSCALE_TABLES[table_name].must_have_col_names
        )
        # `db_chain` will generate a SQL
        vector_sql_db_chain: VectorSQLDatabaseChain = VectorSQLDatabaseChain.from_llm(
            llm=ChatOpenAI(
                model_name=GLOBAL_CONFIG.query_model,
                base_url=GLOBAL_CONFIG.openai_api_base,
                api_key=GLOBAL_CONFIG.openai_api_key,
                temperature=0
            ),
            prompt=prompt,
            top_k=10,
            return_direct=True,
            db=SQLDatabase(
                engine,
                None,
                metadata,
                include_tables=[MYSCALE_TABLES[table_name].table],
                max_string_length=1024
            ),
            sql_cmd_parser=output_parser,  # TODO needs update `langchain`, fix return type.
            native_format=True
        )
        # `retriever` can search a group of documents with `db_chain`
        vector_sql_db_chain_retriever = VectorSQLDatabaseChainRetriever(
            sql_db_chain=vector_sql_db_chain,
            page_content_key=MYSCALE_TABLES[table_name].text_col_name
        )
    return vector_sql_db_chain_retriever
================================================
FILE: app/backend/retrievers/__init__.py
================================================
================================================
FILE: app/backend/retrievers/self_query.py
================================================
from typing import List
import pandas as pd
import streamlit as st
from langchain.retrievers import SelfQueryRetriever
from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig
from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons
from backend.callbacks.self_query_callbacks import ChatDataSelfAskCallBackHandler, CustomSelfQueryRetrieverCallBackHandler
from ui.utils import display
from logger import logger


def process_self_query(selected_table, query_type):
    place_holder = st.empty()
    logger.info(
        f"button-1: {RetrieverButtons.self_query_from_db}, "
        f"button-2: {RetrieverButtons.self_query_with_llm}, "
        f"content: {st.session_state.query_self}"
    )
    with place_holder.expander('🪵 Chat Log', expanded=True):
        try:
            if query_type == RetrieverButtons.self_query_from_db:
                callback = CustomSelfQueryRetrieverCallBackHandler()
                retriever: SelfQueryRetriever = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["retriever"]
                config: RunnableConfig = {"callbacks": [callback]}

                relevant_docs = retriever.invoke(
                    input=st.session_state.query_self,
                    config=config
                )

                callback.progress_bar.progress(
                    value=1.0, text="[Question -> LLM -> Query filter -> MyScaleDB -> Results] Done!✅")

                st.markdown(f"### Self Query Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
                display(
                    dataframe=pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs]
                    ),
                    columns_=MYSCALE_TABLES[selected_table].must_have_col_names
                )
            elif query_type == RetrieverButtons.self_query_with_llm:
                # callback = CustomSelfQueryRetrieverCallBackHandler()
                callback = ChatDataSelfAskCallBackHandler()
                chain: CustomRetrievalQAWithSourcesChain = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["chain"]
                chain_results = chain(st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> Query filter -> MyScaleDB -> Related Results -> LLM -> LLM Answer] Done!✅"
                )

                documents_reference: List[Document] = chain_results["source_documents"]
                st.markdown(f"### SelfQueryRetriever Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
                display(
                    pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in documents_reference]
                    )
                )
                st.markdown(
                    f"### Answer from LLM \n"
                    f"> The response of the LLM when given the `SelfQueryRetriever` results. \n\n"
                )
                st.write(chain_results['answer'])
                st.markdown(
                    f"### References from `{selected_table}`\n"
                    f"> The documents that the LLM used for its answer \n\n"
                )
                if len(chain_results['sources']) == 0:
                    st.write("No documents were used by the LLM.")
                else:
                    display(
                        dataframe=pd.DataFrame(
                            [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']]
                        ),
                        columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
                        index='ref_id'
                    )
            st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
        except Exception as e:
            st.write('Oops 😵 Something bad happened...')
            raise e
================================================
FILE: app/backend/retrievers/vector_sql_output_parser.py
================================================
from typing import Dict, Any, List
from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
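

class VectorSQLRetrieveOutputParser(VectorSQLOutputParser):
    """Based on VectorSQLOutputParser.
    It also rewrites the SELECT list so that every column in `must_have_columns` is returned.
    """
    must_have_columns: List[str]

    @property
    def _type(self) -> str:
        return "vector_sql_retrieve_custom"

    def parse(self, text: str) -> Dict[str, Any]:
        text = text.strip()
        upper = text.upper()
        start = upper.find("SELECT")
        end = upper.find("FROM", start) if start >= 0 else -1
        if start >= 0 and end > start:
            # Slice around the first SELECT ... FROM pair. `str.replace` would also rewrite
            # any later occurrence of the same column list, and a missing FROM (end == -1)
            # would produce a bogus slice.
            text = f"{text[:start + len('SELECT')]} {', '.join(self.must_have_columns)} {text[end:]}"
        return super().parse(text)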
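
# Illustrative sketch, not part of the original module: the column rewrite performed in
# `parse` can be tried without langchain. `rewrite_select_columns` is a hypothetical
# helper; it swaps whatever sits between the first SELECT and the following FROM for the
# required columns and leaves the rest of the statement untouched.

```python
from typing import List


def rewrite_select_columns(sql: str, must_have_columns: List[str]) -> str:
    """Replace the projection between the first SELECT and the next FROM."""
    sql = sql.strip()
    upper = sql.upper()
    start = upper.find("SELECT")
    end = upper.find("FROM", start + len("SELECT")) if start >= 0 else -1
    if start < 0 or end < 0:
        # Not a SELECT ... FROM statement: return it unchanged.
        return sql
    head = sql[: start + len("SELECT")]
    return f"{head} {', '.join(must_have_columns)} {sql[end:]}"


rewritten = rewrite_select_columns(
    "SELECT ChatPaper.title FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(x)) LIMIT 10",
    ["title", "id", "authors"],
)
print(rewritten)
```

# The plain substring search is deliberately simple; a column whose name contains "from"
# would confuse it, so a real SQL parser would be the safer choice for arbitrary input.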
================================================
FILE: app/backend/retrievers/vector_sql_query.py
================================================
from typing import List
import pandas as pd
import streamlit as st
from langchain.schema import Document
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons
from backend.callbacks.vector_sql_callbacks import VectorSQLSearchDBCallBackHandler, VectorSQLSearchLLMCallBackHandler
from ui.utils import display
from logger import logger


def process_sql_query(selected_table: str, query_type: str):
    place_holder = st.empty()
    logger.info(
        f"button-1: {st.session_state[RetrieverButtons.vector_sql_query_from_db]}, "
        f"button-2: {st.session_state[RetrieverButtons.vector_sql_query_with_llm]}, "
        f"table: {selected_table}, "
        f"content: {st.session_state.query_sql}"
    )
    with place_holder.expander('🪵 Query Log', expanded=True):
        try:
            if query_type == RetrieverButtons.vector_sql_query_from_db:
                callback = VectorSQLSearchDBCallBackHandler()
                vector_sql_retriever: VectorSQLDatabaseChainRetriever = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_retriever"]
                relevant_docs: List[Document] = vector_sql_retriever.get_relevant_documents(
                    query=st.session_state.query_sql,
                    callbacks=[callback]
                )

                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> SQL Statement -> MyScaleDB -> Results] Done! ✅"
                )

                st.markdown(f"### Vector Search Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB with given sql statement \n\n")
                display(
                    pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs]
                    )
                )
            elif query_type == RetrieverButtons.vector_sql_query_with_llm:
                callback = VectorSQLSearchLLMCallBackHandler(table=selected_table)
                vector_sql_chain: CustomRetrievalQAWithSourcesChain = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_chain"]
                chain_results = vector_sql_chain(
                    inputs=st.session_state.query_sql,
                    callbacks=[callback]
                )

                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> SQL Statement -> MyScaleDB -> "
                         "(Question,Results) -> LLM -> Results] Done! ✅"
                )

                documents_reference: List[Document] = chain_results["source_documents"]
                st.markdown(f"### Vector Search Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB with given sql statement \n\n")
                display(
                    pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in documents_reference]
                    )
                )
                st.markdown(
                    f"### Answer from LLM \n"
                    f"> The response of the LLM when given the vector search results. \n\n"
                )
                st.write(chain_results['answer'])
                st.markdown(
                    f"### References from `{selected_table}`\n"
                    f"> The documents that the LLM used for its answer \n\n"
                )
                if len(chain_results['sources']) == 0:
                    st.write("No documents were used by the LLM.")
                else:
                    display(
                        dataframe=pd.DataFrame(
                            [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']]
                        ),
                        columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
                        index='ref_id'
                    )
            else:
                raise NotImplementedError(f"Unsupported query type: {query_type}")
            st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
        except Exception as e:
            st.write('Oops 😵 Something bad happened...')
            raise e
================================================
FILE: app/backend/types/__init__.py
================================================
================================================
FILE: app/backend/types/chains_and_retrievers.py
================================================
from typing import Dict
from dataclasses import dataclass
from typing import List, Any
from langchain.retrievers import SelfQueryRetriever
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain


@dataclass
class MetadataColumn:
    name: str
    desc: str
    type: str


@dataclass
class ChainsAndRetrievers:
    metadata_columns: List[MetadataColumn]
    retriever: SelfQueryRetriever
    chain: CustomRetrievalQAWithSourcesChain
    sql_retriever: VectorSQLDatabaseChainRetriever
    sql_chain: CustomRetrievalQAWithSourcesChain

    def to_dict(self) -> Dict[str, Any]:
        return {
            "metadata_columns": self.metadata_columns,
            "retriever": self.retriever,
            "chain": self.chain,
            "sql_retriever": self.sql_retriever,
            "sql_chain": self.sql_chain
        }
================================================
FILE: app/backend/types/global_config.py
================================================
from dataclasses import dataclass
from typing import Optional


@dataclass
class GlobalConfig:
    openai_api_base: Optional[str] = ""
    openai_api_key: Optional[str] = ""
    auth0_client_id: Optional[str] = ""
    auth0_domain: Optional[str] = ""
    myscale_user: Optional[str] = ""
    myscale_password: Optional[str] = ""
    myscale_host: Optional[str] = ""
    myscale_port: Optional[int] = 443
    query_model: Optional[str] = ""
    chat_model: Optional[str] = ""
    untrusted_api: Optional[str] = ""
    myscale_enable_https: Optional[bool] = True
================================================
FILE: app/backend/types/table_config.py
================================================
from typing import Callable
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.prompts import PromptTemplate
from dataclasses import dataclass
from typing import List


@dataclass
class TableConfig:
    database: str
    table: str
    table_contents: str
    # column names
    must_have_col_names: List[str]
    vector_col_name: str
    text_col_name: str
    metadata_col_name: str
    # hint for UI
    hint: Callable
    hint_sql: Callable
    # for langchain
    doc_prompt: PromptTemplate
    metadata_col_attributes: List[AttributeInfo]
    emb_model: Callable
    tool_desc: tuple
================================================
FILE: app/backend/vector_store/__init__.py
================================================
================================================
FILE: app/backend/vector_store/myscale_without_metadata.py
================================================
from typing import Any, Optional, List
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.myscale import MyScale, MyScaleSettings
from logger import logger


class MyScaleWithoutMetadataJson(MyScale):
    def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None,
                 must_have_cols: Optional[List[str]] = None, **kwargs: Any) -> None:
        try:
            super().__init__(embedding, config, **kwargs)
        except Exception as e:
            logger.error(e)
        # Default to a fresh list per instance rather than a shared mutable default argument.
        self.must_have_cols: List[str] = must_have_cols or []

    def _build_qstr(
            self, q_emb: List[float], topk: int, where_str: Optional[str] = None
    ) -> str:
        q_emb_str = ",".join(map(str, q_emb))
        if where_str:
            where_str = f"PREWHERE {where_str}"
        else:
            where_str = ""

        q_str = f"""
            SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)}
            FROM {self.config.database}.{self.config.table}
            {where_str}
            ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
                AS dist {self.dist_order}
            LIMIT {topk}
            """
        return q_str

    def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None,
                                    **kwargs: Any) -> List[Document]:
        q_str = self._build_qstr(embedding, k, where_str)
        try:
            return [
                Document(
                    page_content=r[self.config.column_map["text"]],
                    metadata={k: r[k] for k in self.must_have_cols},
                )
                for r in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            logger.error(
                f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []
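
# Illustrative sketch, not part of the original module: how the query string from
# `_build_qstr` is put together, written as a free function so it can be run without a
# MyScale connection. `build_vector_query` and the database, table and column names used
# below are hypothetical.

```python
from typing import List, Optional


def build_vector_query(
    database: str,
    table: str,
    text_col: str,
    vector_col: str,
    must_have_cols: List[str],
    q_emb: List[float],
    topk: int,
    where_str: Optional[str] = None,
    dist_order: str = "ASC",
) -> str:
    """Build a top-k vector search statement with an optional PREWHERE filter."""
    prewhere = f"PREWHERE {where_str}" if where_str else ""
    q_emb_str = ",".join(map(str, q_emb))
    return (
        f"SELECT {text_col}, dist, {','.join(must_have_cols)} "
        f"FROM {database}.{table} {prewhere} "
        f"ORDER BY distance({vector_col}, [{q_emb_str}]) AS dist {dist_order} "
        f"LIMIT {topk}"
    )


filtered = build_vector_query(
    "default", "ChatArXiv", "abstract", "vector",
    ["title", "id"], [0.1, 0.2], 4, where_str="length(categories) > 2",
)
unfiltered = build_vector_query(
    "default", "ChatArXiv", "abstract", "vector", ["title", "id"], [0.1, 0.2], 4,
)
print(filtered)
```

# Values are interpolated directly into the SQL text here, as in the class above, so the
# filter string has to come from a trusted source.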
================================================
FILE: app/logger.py
================================================
import logging
def setup_logger():
logger_ = logging.getLogger('chat-data')
logger_.setLevel(logging.INFO)
if not logger_.handlers:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(filename)s - %(funcName)s - %(levelname)s - %(message)s - [Thread ID: %(thread)d]'
)
console_handler.setFormatter(formatter)
logger_.addHandler(console_handler)
return logger_
logger = setup_logger()
================================================
FILE: app/requirements.txt
================================================
langchain==0.2.1
langchain-community==0.2.1
langchain-core==0.2.1
langchain-experimental==0.0.59
langchain-openai==0.1.7
sentence-transformers==2.2.2
InstructorEmbedding
pandas
streamlit==1.36.0
streamlit-extras
streamlit-auth0-component
altair==4.2.2
clickhouse-connect
openai==1.35.3
lark
tiktoken
sql-formatter
sqlalchemy==1.4.48
clickhouse-sqlalchemy
================================================
FILE: app/ui/__init__.py
================================================
================================================
FILE: app/ui/chat_page.py
================================================
import datetime
import json
import pandas as pd
import streamlit as st
from langchain_core.messages import HumanMessage, FunctionMessage
from streamlit.delta_generator import DeltaGenerator
from backend.chat_bot.json_decoder import CustomJSONDecoder
from backend.constants.streamlit_keys import CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, \
EL_UPLOAD_FILES_STATUS, USER_PRIVATE_FILES, EL_BUILD_KB_WITH_FILES, \
EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \
USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \
CHAT_KNOWLEDGE_TABLE, EL_UPLOAD_FILES, EL_SELECTED_KBS
from backend.constants.variables import DIVIDER_HTML, USER_NAME, RETRIEVER_TOOLS
from backend.construct.build_chat_bot import build_chat_knowledge_table, initialize_session_manager
from backend.chat_bot.chat import refresh_sessions, on_session_change_submit, refresh_agent, \
create_private_knowledge_base_as_tool, \
remove_private_knowledge_bases, add_file, clear_files, clear_history, back_to_main, on_chat_submit
def render_session_manager():
with st.expander("🤖 Session Management"):
if CHAT_CURRENT_USER_SESSIONS not in st.session_state:
refresh_sessions()
st.markdown("Here you can update `session_id` and `system_prompt`")
st.markdown("- Click empty row to add a new item")
st.markdown("- To delete an item, click it and press the `DEL` key")
st.markdown("- Don't forget to submit your change.")
st.data_editor(
data=st.session_state[CHAT_CURRENT_USER_SESSIONS],
num_rows="dynamic",
key="session_editor",
use_container_width=True,
)
st.button("⏫ Submit", on_click=on_session_change_submit, type="primary")
def render_session_selection():
with st.expander("✅ Session Selection", expanded=True):
st.selectbox(
"Choose a `session` to chat",
options=st.session_state[CHAT_CURRENT_USER_SESSIONS],
index=None,
key=EL_SESSION_SELECTOR,
format_func=lambda x: x["session_id"],
on_change=refresh_agent,
)
def render_files_manager():
with st.expander("📃 **Upload your personal files**", expanded=False):
st.markdown("- Files will be parsed by [Unstructured API](https://unstructured.io/api-key).")
st.markdown("- All files will be converted into vectors and stored in [MyScaleDB](https://myscale.com/).")
st.file_uploader(label="⏫ **Upload files**", key=EL_UPLOAD_FILES, accept_multiple_files=True)
# st.markdown("### Uploaded Files")
st.dataframe(
data=st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(st.session_state[USER_NAME]),
use_container_width=True,
)
st.session_state[EL_UPLOAD_FILES_STATUS] = st.empty()
col_1, col_2 = st.columns(2)
with col_1:
st.button(label="Upload files", on_click=add_file)
with col_2:
st.button(label="Clear all files and tools", on_click=clear_files)
def _render_create_personal_knowledge_bases(div: DeltaGenerator):
with div:
st.markdown("- If you haven't uploaded your personal files, please upload them first.")
st.markdown("- Select some **files** to build your `personal knowledge base`.")
st.markdown("- Once your `personal knowledge base` is built, "
"it will answer your questions using information from your personal **files**.")
st.multiselect(
label="⚡️Select some files to build a **personal knowledge base**",
options=st.session_state[USER_PRIVATE_FILES],
placeholder="You should upload some files first",
key=EL_BUILD_KB_WITH_FILES,
format_func=lambda x: x["file_name"],
)
st.text_input(
label="⚡️Personal knowledge base name",
value="get_relevant_documents",
key=EL_PERSONAL_KB_NAME
)
st.text_input(
label="⚡️Personal knowledge base description",
value="Searches from some personal files.",
key=EL_PERSONAL_KB_DESCRIPTION,
)
st.button(
label="Build 🔧",
on_click=create_private_knowledge_base_as_tool
)
def _render_remove_personal_knowledge_bases(div: DeltaGenerator):
with div:
st.markdown("> Here are all your personal knowledge bases.")
if USER_PERSONAL_KNOWLEDGE_BASES in st.session_state and len(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]) > 0:
st.dataframe(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES])
else:
st.warning("You don't have any personal knowledge bases, please create a new one.")
st.multiselect(
label="Choose a personal knowledge base to delete",
placeholder="Choose a personal knowledge base to delete",
options=st.session_state[USER_PERSONAL_KNOWLEDGE_BASES],
format_func=lambda x: x["tool_name"],
key=EL_PERSONAL_KB_NEEDS_REMOVE,
)
st.button("Delete", on_click=remove_private_knowledge_bases, type="primary")
def render_personal_tools_build():
with st.expander("🔨 **Build your personal knowledge base**", expanded=True):
create_new_kb, kb_manager = st.tabs(["Create personal knowledge base", "Personal knowledge base management"])
_render_create_personal_knowledge_bases(create_new_kb)
_render_remove_personal_knowledge_bases(kb_manager)
def render_knowledge_base_selector():
with st.expander("🙋 **Select some knowledge bases to query**", expanded=True):
st.markdown("- Knowledge bases come in two types: `public` and `private`.")
st.markdown("- All users can access our `public` knowledge bases.")
st.markdown("- Only you can access your `personal` knowledge bases.")
options = st.session_state[RETRIEVER_TOOLS].keys()
if AVAILABLE_RETRIEVAL_TOOLS in st.session_state:
options = st.session_state[AVAILABLE_RETRIEVAL_TOOLS]
st.multiselect(
label="Select some knowledge base tools",
placeholder="Please select some knowledge bases to query",
options=options,
default=["Wikipedia + Self Querying"],
key=EL_SELECTED_KBS,
on_change=refresh_agent,
)
def chat_page():
# initialize resources
build_chat_knowledge_table()
initialize_session_manager()
# render sidebar
with st.sidebar:
left, middle, right = st.columns([1, 1, 2])
with left:
st.button(label="↩️ Log Out", help="log out and back to main page", on_click=back_to_main)
with right:
st.markdown(f"👤 `{st.session_state[USER_NAME]}`")
st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
render_session_manager()
render_session_selection()
render_files_manager()
render_personal_tools_build()
render_knowledge_base_selector()
# render chat history
if "agent" not in st.session_state:
refresh_agent()
for msg in st.session_state.agent.memory.chat_memory.messages:
speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
if isinstance(msg, FunctionMessage):
with st.chat_message(name="from knowledge base", avatar="📚"):
st.write(
f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
)
st.write("Retrieved from knowledge base:")
try:
st.dataframe(
pd.DataFrame.from_records(
json.loads(msg.content, cls=CustomJSONDecoder)
),
use_container_width=True,
)
except Exception as e:
st.warning(e)
st.write(msg.content)
else:
if len(msg.content) > 0:
with st.chat_message(speaker):
# print(type(msg), msg.dict())
st.write(
f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
)
st.write(f"{msg.content}")
st.session_state["next_round"] = st.empty()
from streamlit import _bottom
with _bottom:
col1, col2 = st.columns([1, 16])
with col1:
st.button("🗑️", help="Clean chat history", on_click=clear_history, type="secondary")
with col2:
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
================================================
FILE: app/ui/home.py
================================================
import base64
from streamlit_extras.add_vertical_space import add_vertical_space
from streamlit_extras.card import card
from streamlit_extras.colored_header import colored_header
from streamlit_extras.mention import mention
from streamlit_extras.tags import tagger_component
from logger import logger
import os
import streamlit as st
from auth0_component import login_button
from backend.constants.variables import JUMP_QUERY_ASK, USER_INFO, USER_NAME, DIVIDER_HTML, DIVIDER_THIN_HTML
from streamlit_extras.let_it_rain import rain
def render_home():
render_home_header()
# st.divider()
# st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
add_vertical_space(5)
render_home_content()
# st.divider()
st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
render_home_footer()
def render_home_header():
logger.info("render home header")
st.header("ChatData - Your Intelligent Assistant")
st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
st.markdown("> [ChatData](https://github.com/myscale/ChatData) \
is developed by [MyScale](https://myscale.com/), \
it's an integration of [LangChain](https://www.langchain.com/) \
and [MyScaleDB](https://github.com/myscale/myscaledb)")
tagger_component(
"Keywords:",
["MyScaleDB", "LangChain", "VectorSearch", "ChatBot", "GPT", "arxiv", "wikipedia", "Personal Knowledge Base 📚"],
color_name=["darkslateblue", "green", "orange", "darkslategrey", "red", "crimson", "darkcyan", "darkgrey"],
)
text, col1, col2, col3, _ = st.columns([1, 1, 1, 1, 4])
with text:
st.markdown("Related:")
with col1.container():
mention(
label="streamlit",
icon="streamlit",
url="https://streamlit.io/",
write=True
)
with col2.container():
mention(
label="langchain",
icon="🦜🔗",
url="https://www.langchain.com/",
write=True
)
with col3.container():
mention(
label="streamlit-extras",
icon="🪢",
url="https://github.com/arnaudmiribel/streamlit-extras",
write=True
)
def _render_self_query_chain_content():
col1, col2 = st.columns([1, 1], gap='large')
with col1.container():
st.image(image='./assets/home_page_background_1.png',
caption=None,
width=None,
use_column_width=True,
clamp=False,
channels="RGB",
output_format="PNG")
with col2.container():
st.header("VectorSearch & SelfQuery with Sources")
st.info("In this sample, you will learn how **LangChain** integrates with **MyScaleDB**.")
st.markdown("""This example demonstrates two methods for integrating MyScale into LangChain: [Vector SQL](https://api.python.langchain.com/en/latest/sql/langchain_experimental.sql.vector_sql.VectorSQLDatabaseChain.html) and [Self-querying retriever](https://python.langchain.com/v0.2/docs/integrations/retrievers/self_query/myscale_self_query/). For each method, you can choose one of the following options:
1. `Retrieve from MyScaleDB ➡️` - The LLM (GPT) converts user queries into SQL statements with vector search, executes these searches in MyScaleDB, and retrieves relevant content.
2. `Retrieve and answer with LLM ➡️` - After retrieving relevant content from MyScaleDB, the user query along with the retrieved content is sent to the LLM (GPT), which then provides a comprehensive answer.""")
add_vertical_space(3)
_, middle, _ = st.columns([2, 1, 2], gap='small')
with middle.container():
st.session_state[JUMP_QUERY_ASK] = st.button("Try sample", use_container_width=False, type="secondary")
def _render_chat_bot_content():
col1, col2 = st.columns(2, gap='large')
with col1.container():
st.image(image='./assets/home_page_background_2.png',
caption=None,
width=None,
use_column_width=True,
clamp=False,
channels="RGB",
output_format="PNG")
with col2.container():
st.header("Chat Bot")
st.info("Now you can try our chatbot, this chatbot is built with MyScale and LangChain.")
st.markdown("- You need to go to [https://myscale-chatdata.hf.space/](https://myscale-chatdata.hf.space/) "
"to log in successfully, otherwise the auth service will not work.")
st.markdown("- You can upload your own PDF files and build your own knowledge base. \
(This is just a sample application. Please do not upload important or confidential files.)")
st.markdown("- A default session will be assigned as your initial chat session. \
You can create and switch to other sessions to jump between different chat conversations.")
add_vertical_space(1)
_, middle, _ = st.columns([1, 2, 1], gap='small')
with middle.container():
if USER_NAME not in st.session_state:
login_button(clientId=os.environ["AUTH0_CLIENT_ID"],
domain=os.environ["AUTH0_DOMAIN"],
key="auth0")
# if user_info:
# user_name = user_info.get("nickname", "default") + "_" + user_info.get("email", "null")
# st.session_state[USER_NAME] = user_name
# print(user_info)
def render_home_content():
logger.info("render home content")
_render_self_query_chain_content()
add_vertical_space(3)
_render_chat_bot_content()
def render_home_footer():
logger.info("render home footer")
st.write(
"Please follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!"
)
st.write(
"For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
st.write("Our [privacy policy](https://myscale.com/privacy/), [terms of service](https://myscale.com/terms/)")
# st.write(
# "Recommended to use the standalone version of Chat-Data, "
# "available [here](https://myscale-chatdata.hf.space/)."
# )
if st.session_state.auth0 is not None:
st.session_state[USER_INFO] = dict(st.session_state.auth0)
if 'email' in st.session_state[USER_INFO]:
email = st.session_state[USER_INFO]["email"]
else:
email = f"{st.session_state[USER_INFO]['nickname']}@{st.session_state[USER_INFO]['sub']}"
st.session_state["user_name"] = email
del st.session_state.auth0
st.rerun()
if st.session_state.jump_query_ask:
st.rerun()
================================================
FILE: app/ui/retrievers.py
================================================
import streamlit as st
from streamlit_extras.add_vertical_space import add_vertical_space
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, RetrieverButtons
from backend.retrievers.self_query import process_self_query
from backend.retrievers.vector_sql_query import process_sql_query
from backend.constants.variables import JUMP_QUERY_ASK, USER_NAME, USER_INFO
def back_to_main():
if USER_INFO in st.session_state:
del st.session_state[USER_INFO]
if USER_NAME in st.session_state:
del st.session_state[USER_NAME]
if JUMP_QUERY_ASK in st.session_state:
del st.session_state[JUMP_QUERY_ASK]
def _render_table_selector() -> str:
col1, col2 = st.columns(2)
with col1:
selected_table = st.selectbox(
label='Each public knowledge base is stored in a MyScaleDB table, which is read-only.',
options=MYSCALE_TABLES.keys(),
)
MYSCALE_TABLES[selected_table].hint()
with col2:
add_vertical_space(1)
st.info("Here is your selected public knowledge base schema in MyScaleDB",
icon='📚')
MYSCALE_TABLES[selected_table].hint_sql()
return selected_table
def render_retrievers():
st.button("⬅️ Back", key="back_sql", on_click=back_to_main)
st.subheader('Please choose a public knowledge base to search.')
selected_table = _render_table_selector()
tab_sql, tab_self_query = st.tabs(
tabs=['Vector SQL', 'Self-querying Retriever']
)
with tab_sql:
render_tab_sql(selected_table)
with tab_self_query:
render_tab_self_query(selected_table)
def render_tab_sql(selected_table: str):
st.warning(
"When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
"the metadata we provide. This table allows filters to be established on the following metadata fields:",
icon="⚠️")
st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])
cols = st.columns([8, 3, 3, 2])
cols[0].text_input("Input your question:", key='query_sql')
with cols[1].container():
add_vertical_space(2)
st.button("Retrieve from MyScaleDB ➡️", key=RetrieverButtons.vector_sql_query_from_db)
with cols[2].container():
add_vertical_space(2)
st.button("Retrieve and answer with LLM ➡️", key=RetrieverButtons.vector_sql_query_with_llm)
if st.session_state[RetrieverButtons.vector_sql_query_from_db]:
process_sql_query(selected_table, RetrieverButtons.vector_sql_query_from_db)
if st.session_state[RetrieverButtons.vector_sql_query_with_llm]:
process_sql_query(selected_table, RetrieverButtons.vector_sql_query_with_llm)
def render_tab_self_query(selected_table):
st.warning(
"When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
"the metadata we provide. This table allows filters to be established on the following metadata fields:",
icon="⚠️")
st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])
cols = st.columns([8, 3, 3, 2])
cols[0].text_input("Input your question:", key='query_self')
with cols[1].container():
add_vertical_space(2)
st.button("Retrieve from MyScaleDB ➡️", key='search_self')
with cols[2].container():
add_vertical_space(2)
st.button("Retrieve and answer with LLM ➡️", key='ask_self')
if st.session_state.search_self:
process_self_query(selected_table, RetrieverButtons.self_query_from_db)
if st.session_state.ask_self:
process_self_query(selected_table, RetrieverButtons.self_query_with_llm)
================================================
FILE: app/ui/utils.py
================================================
import streamlit as st
def display(dataframe, columns_=None, index=None):
if len(dataframe) > 0:
if index:
# set_index returns a new DataFrame, so keep the result
dataframe = dataframe.set_index(index)
if columns_:
st.dataframe(dataframe[columns_])
else:
st.dataframe(dataframe)
else:
st.write(
"Sorry 😵 we didn't find any articles related to your query.\n\n"
"Maybe the LLM is too naughty that does not follow our instruction... \n\n"
"Please try again and use verbs that may match the datatype.",
unsafe_allow_html=True
)
================================================
FILE: docs/self-query.md
================================================
# HOW-TO: Build a ChatPDF App over millions of documents with LangChain and MyScale in 30 Minutes
Chatting with GPT about a single academic paper is relatively straightforward by providing the document as the language model context. Chatting with millions of research papers is also simple... as long as you choose the right vector database.
Large language models (LLM) are powerful NLP tools. One of the most significant benefits of LLMs, such as ChatGPT, is that you can use them to build tools that allow you to interact (or chat) with documents, such as PDF copies of research or academic papers, based on their topics and content.
Many implementations of chat-with-document apps already exist, like [ChatPaper](https://github.com/kaixindelele/ChatPaper), [OpenChatPaper](https://github.com/liuyixin-louis/OpenChatPaper), and [DocsMind](https://github.com/3Alan/DocsMind). However, many of these implementations seem complicated, and their search utilities are simplistic, offering only elementary keyword searches that filter on basic metadata such as year and subject.
Therefore, developing a ChatPDF-like app to interact with millions of academic/research papers makes sense. You can chat with the data in your natural language, combining both semantic and structural attributes, asking questions such as "What is a neural network?" and providing additional qualifiers like "Please use articles published by Geoffrey Hinton after 2018."
The primary purpose of this article is to help you build your own ChatPDF app that allows you to interact (chat) with millions of academic/research papers using LangChain and MyScale.
> This app should take about 30 minutes to create.
But before we start, let’s look at the following diagrammatic workflow of the whole process:
> Even though we describe how to develop this LLM-based chat app, we have a sample app on [GitHub](https://github.com/myscale/ChatData), including access to a [read-only vector database](../app/.streamlit/secrets.example.toml), further simplifying the app-creation process.
## Prepare the Data
As described in this image, the first step is to prepare the data.
> We recommend using our open database for this app. The credentials are in the example configuration: `$PROJECT_DIR/app/.streamlit/secrets.toml`. Alternatively, you can follow the instructions below to create your own database, which takes about 20 minutes.
We have sourced our data: a usable list of abstracts and arXiv IDs from the Alexandria Index through the [Macrocosm website](https://alex.macrocosm.so/). Using this data and interrogating the arXiv Open API, we can significantly enhance the query experience by retrieving a much richer set of metadata, including year, subject, release date, category, and author.
> We have prepared the data using the [arXiv Open API](https://info.arxiv.org/help/api/index.html) to simplify the app-creation process.
Now that we have the data, let's dive into the next steps:
## Create the Table
The first step is to create a table schema. Sign in to [myscale.com](http://myscale.com/) and create a free [cluster](http://console.myscale.com/).
After creating the cluster, the next step is to create a table that is compatible with our MyScale VectorStore.
> There will be a workspace sidebar on your left. This is the place where you can play with pure SQL.
Use the following script to create a MyScale VectorStore table:
```sql
CREATE TABLE default.langchain (
`abstract` String,
`id` String,
`vector` Array(Float32),
`metadata` Object('JSON'),
CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```
You can also create a table initializing a MyScale VectorStore in LangChain with Python, as the following code sample describes:
```python
from langchain.vectorstores import MyScale, MyScaleSettings
# `embedding_function` is any LangChain embeddings object, e.g. HuggingFaceInstructEmbeddings()
config = MyScaleSettings(host="", port=8443, ...)
doc_search = MyScale(embedding_function, config)
```
Both these methods do the same job.
Great. Let’s move onto the next step.
## Insert the Data
> We have appended additional metadata, like the publication date and authors, to each ArXiv entry.
Our data is hosted on [Amazon S3](https://docs.aws.amazon.com/AmazonS3/latest/userguide/Welcome.html) and can be read directly with [ClickHouse table functions](https://clickhouse.com/docs/en/sql-reference/table-functions/s3).
The compressed `jsonl` file is available [here](https://myscale-demo.s3.ap-southeast-1.amazonaws.com/chat_arxiv/full.json.zst). You can also import the data as a partitioned dataset (113 parts) from our AWS S3 bucket (the URL will look like `https://myscale-demo.s3.ap-southeast-1.amazonaws.com/chat_arxiv/data.part*.0.jsonl.zst`) into your MyScale Cloud with the [S3 table function](https://clickhouse.com/docs/en/sql-reference/table-functions/s3).
> You can also upload the data onto Google Cloud Platform and use the same SQL insert query to import this data.
To insert data into this table, you still have the same options as when creating a VectorStore table: MyScale SQL workspace or LangChain.
### MyScale SQL Workspace
Paste the following SQL statement into the MyScale SQL workspace and click **Run**:
```sql
INSERT INTO langchain
SELECT
*
FROM
s3(
'https://myscale-demo.s3.ap-southeast-1.amazonaws.com/chat_arxiv/data.*.jsonl.zst',
'JSONEachRow',
'abstract String, id String, vector Array(Float32), metadata Object(''JSON'')',
'zstd'
)
```
Then you need to build the vector index with this SQL:
```sql
ALTER TABLE langchain ADD VECTOR INDEX vec_idx vector TYPE MSTG('metric_type=Cosine')
```
### LangChain
The second option is to insert the data into the table using LangChain for better control over the data insertion process.
Add the following code snippet to your app’s code:
```python
# ! unzstd data-*.jsonl.zst
import json
from langchain.docstore.document import Document
def str2doc(_str):
j = json.loads(_str)
return Document(page_content=j['abstract'], metadata=j['metadata'])
with open('func_call_data.jsonl') as f:
docs = [str2doc(l) for l in f.readlines()]
```
## Design the Query Pipeline
Most LLM-based applications need an automated pipeline for querying and returning an answer to the query.
> Chat-with-LLM apps must generally retrieve reference documents before querying their models (LLMs).
Let's look at the step-by-step workflow describing how the app answers the user's queries, as illustrated in the following diagram:
1. **Ask for the user's input/questions.**
This input must be as concise as possible. In most cases, it should be at most several sentences.
2. **Construct a DB query from the user's input.**
The query is simple for vector databases. All you need to do is to extract the relevant embedding from the vector database. However, for enhanced accuracy, it is advisable to filter your query.
For instance, suppose the user only wants the latest papers, but a plain vector search returns every research paper that is semantically close to the question, regardless of date. To solve this, add metadata filters to the query so that only the matching documents are returned.
3. **Parse the retrieved documents from VectorStore.**
The data returned from the vector store is not in a native format that the LLM understands. You must parse it and insert it into your prompt templates. Sometimes you need to add more metadata to these templates, like the date created, authors, or document categories. This metadata will help the LLM improve the quality of its answer.
4. **Ask the LLM.**
This process is straightforward as long as you are familiar with the LLM's API and have properly designed prompts.
5. **Fetch the answer.**
Returning the answer is straightforward for simple applications. But, if the question is complex, additional effort is required to provide more information to the user; for example, adding the LLM's source data. Additionally, adding reference numbers to the prompt can help you find the source and reduce your prompt's size by avoiding repeating content, such as the document's title.
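The five steps above can be sketched in plain Python. This is only an illustration under stated assumptions: `Doc`, `build_prompt`, `answer_question`, `fake_retrieve`, and `fake_llm` are hypothetical names invented for this sketch, and the retriever and the LLM are stubbed so the code runs without a database or an API key.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Doc:
    """A retrieved document: text plus metadata (hypothetical stand-in type)."""
    text: str
    metadata: Dict[str, object] = field(default_factory=dict)


def build_prompt(question: str, docs: List[Doc]) -> str:
    """Step 3: format retrieved documents into a numbered, LLM-readable context."""
    lines = []
    for i, doc in enumerate(docs, start=1):
        meta = ", ".join(f"{k}={v}" for k, v in sorted(doc.metadata.items()))
        lines.append(f"[{i}] ({meta}) {doc.text}")
    context = "\n".join(lines)
    return f"Answer using only the sources below.\n{context}\nQuestion: {question}"


def answer_question(
    question: str,                          # step 1: the user's input
    retrieve: Callable[[str], List[Doc]],
    ask_llm: Callable[[str], str],
) -> Dict[str, object]:
    """Run steps 2-5 and return the answer together with its sources."""
    docs = retrieve(question)               # step 2: (filtered) vector query
    prompt = build_prompt(question, docs)   # step 3: parse documents into the prompt
    answer = ask_llm(prompt)                # step 4: ask the LLM
    return {"answer": answer, "source_documents": docs}  # step 5


# Stubs standing in for the vector store and the LLM.
def fake_retrieve(question: str) -> List[Doc]:
    return [Doc("Neural networks are function approximators.", {"year": 2019})]


def fake_llm(prompt: str) -> str:
    return "A neural network is a function approximator [1]."


result = answer_question("What is a neural network?", fake_retrieve, fake_llm)
```

Numbering the sources (`[1]`, `[2]`, ...) lets the answer cite a reference by number instead of repeating the document's title.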
In practice, LangChain has a good framework to work with. We used the following functions to build this pipeline:
* `RetrievalQAWithSourcesChain`
* `SelfQueryRetriever`
### `SelfQueryRetriever`
This function defines the interaction between the VectorStore and your app. Let’s dive deeper into how a self-query retriever works, as illustrated in the following diagram:
LangChain’s `SelfQueryRetriever` defines a universal filter for every VectorStore, including several `comparators` for comparing values and `operators`, combining these conditions to form a filter. The LLM will generate a filter rule based on these `comparators` and `operators`. All VectorStore providers will implement a `FilterTranslator` to translate the given universal filter to the correct arguments that call the VectorStore.
LangChain's universal solution provides a complete package for new operators, comparators, and vector store providers. However, you are limited to the pre-defined elements inside it.
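To make the translation idea concrete, here is a minimal sketch of turning a generic filter tree into a SQL condition. It is not LangChain's actual implementation: the `Comparison` and `Operation` classes, the comparator names, and `to_where` are hypothetical stand-ins, and the output assumes ClickHouse-style SQL where `has(array, value)` checks list membership.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Comparison:
    comparator: str   # one of "eq", "gt", "lt", "contain" (hypothetical names)
    attribute: str
    value: Union[str, int, float]


@dataclass
class Operation:
    operator: str     # "and" or "or"
    arguments: List[Union["Comparison", "Operation"]]


_SQL_COMPARATORS = {"eq": "=", "gt": ">", "lt": "<"}


def _literal(value) -> str:
    """Render a Python value as a SQL literal, escaping single quotes in strings."""
    if isinstance(value, str):
        return "'" + value.replace("'", "''") + "'"
    return str(value)


def to_where(node) -> str:
    """Recursively translate the filter tree into a SQL condition string."""
    if isinstance(node, Comparison):
        if node.comparator == "contain":
            return f"has({node.attribute}, {_literal(node.value)})"
        return f"{node.attribute} {_SQL_COMPARATORS[node.comparator]} {_literal(node.value)}"
    joiner = f" {node.operator.upper()} "
    return "(" + joiner.join(to_where(arg) for arg in node.arguments) + ")"


# "Articles by Geoffrey Hinton after 2018" expressed as a filter rule.
rule = Operation("and", [
    Comparison("contain", "authors", "Geoffrey Hinton"),
    Comparison("gt", "year", 2018),
])
where_clause = to_where(rule)
```

In this sketch the LLM only has to emit the tree; the translator owns the SQL syntax, which is why each vector store can plug in its own translation.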
In the prompt filter context, MyScale includes more powerful and flexible filters. We have added more data types, like lists and timestamps, and more functions, like string pattern matching and `CONTAIN` comparators for lists, offering more options for data storage and query design.
> We contributed to LangChain's Self-Query retrievers to make them more powerful, resulting in self-query retrievers that provide more freedom to the LLM when designing the query.
Look at [what else MyScale can do with metadata filters](https://myscale.com/blog/why-integrated-database-solution-can-boost-your-llm-apps/#filter-on-anything-without-constraints).
Here is the code for it, written using LangChain:
```python
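import streamlit as st
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import OpenAI
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.vectorstores import MyScale
# Assuming your data is ready on MyScale Cloud
embeddings = HuggingFaceInstructEmbeddings()
doc_search = MyScale(embeddings)
# Define metadata fields and their types
# Descriptions are important: they tell the LLM how to use each metadata field.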
metadata_field_info=[
AttributeInfo(
name="pubdate",
description="The year the paper is published",
type="timestamp",
),
AttributeInfo(
name="authors",
description="List of author names",
type="list[string]",
),
AttributeInfo(
name="title",
description="Title of the paper",
type="string",
),
AttributeInfo(
name="categories",
description="arxiv categories to this paper",
type="list[string]"
),
AttributeInfo(
name="length(categories)",
description="length of arxiv categories to this paper",
type="int"
),
]
# Now build a retriever with LLM, a vector store and your metadata info
retriever = SelfQueryRetriever.from_llm(
OpenAI(openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
doc_search, "Scientific papers indexes with abstracts", metadata_field_info,
use_original_query=True)
```
### `RetrievalQAWithSourcesChain`
This chain constructs the prompts containing the documents.
Once the document data has been retrieved from the vector store, it must be formatted into LLM-readable strings, like JSON or Markdown, that contain the document's information.
LangChain uses the following chains to build these LLM-readable strings:
* `MapReduceDocumentsChain`
* `StuffDocumentsChain`
`MapReduceDocumentsChain` gathers all the documents the vector store returns and normalizes them into a standard format. It maps the documents to a prompt template and concatenates them together. `StuffDocumentsChain` works on those formatted documents, inserting them as context with task descriptions as prefixes and examples as suffixes.
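The idea of mapping each document onto a template and concatenating the results can be sketched without LangChain. The template text, the field names (`title`, `pubdate`, `abstract`), and the helper names below are assumptions made for this illustration, not the prompts the chains actually use.

```python
from typing import Dict, List

# Hypothetical per-document template; real chains ship their own prompts.
DOCUMENT_TEMPLATE = "Title: {title}\nPublished: {pubdate}\nAbstract: {abstract}"


def format_document(doc: Dict[str, str]) -> str:
    """Normalize one retrieved document into a fixed textual layout."""
    return DOCUMENT_TEMPLATE.format(**doc)


def stuff_documents(question: str, docs: List[Dict[str, str]]) -> str:
    """Concatenate the formatted documents and wrap them with the task description."""
    context = "\n\n".join(format_document(d) for d in docs)
    return (
        "Use the following papers to answer the question.\n\n"
        f"{context}\n\n"
        f"Question: {question}\nAnswer:"
    )


papers = [
    {"title": "Paper A", "pubdate": "2019-05-01", "abstract": "About neural networks."},
    {"title": "Paper B", "pubdate": "2021-02-10", "abstract": "About transformers."},
]
prompt = stuff_documents("What is a neural network?", papers)
```

Keeping the per-document template separate from the outer prompt makes it easy to add or drop metadata fields without touching the task description.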
Add the following code snippet to your app’s code so that your app will format the vector store data into LLM-readable documents.
```python
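from langchain.chains import RetrievalQAWithSourcesChain
# `retriever` is the SelfQueryRetriever built in the previous section
chain = RetrievalQAWithSourcesChain.from_llm(
llm=OpenAI(openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
retriever=retriever,
return_source_documents=True,)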
```
## Run the Chain
With these components, we can now search and answer the user's questions with a scalable vector store.
Try it yourself!
```python
ret = st.session_state.chain(st.session_state.query, callbacks=[callback])
# You can find the answer from LLM in the field `answer`
st.markdown(f"### Answer from LLM\n{ret['answer']}\n### References")
# and source documents in `sources` and `source_documents`
docs = ret['source_documents']
```
Not responsive?
### Add Callbacks
The chain works just fine, but you might have a complaint: It needs to be faster!
Yes, the chain is slow: it constructs a filtered vector query (one LLM call), retrieves data from the VectorStore, and then asks the LLM (another LLM call). Consequently, the total execution time is about 10 to 20 seconds.
Don't worry; LangChain has your back. It includes [Callbacks](https://python.langchain.com/en/latest/modules/callbacks/getting_started.html?highlight=Callbacks) that you can use to increase your app's responsiveness. In our example, we added several callback functions to update a progress bar:
```python
import streamlit as st
from langchain.callbacks import StreamlitCallbackHandler
class ChatArXivAskCallBackHandler(StreamlitCallbackHandler):
def __init__(self) -> None:
# You will have a progress bar when this callback is initialized
self.progress_bar = st.progress(value=0.0, text='Searching DB...')
self.status_bar = st.empty()
self.prog_value = 0.0
# You can use chain names to control the progress
self.prog_map = {
'langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain': 0.2,
'langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain': 0.4,
'langchain.chains.combine_documents.stuff.StuffDocumentsChain': 0.8
}
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
pass
def on_text(self, text: str, **kwargs) -> None:
pass
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
# the name is in list, so you can join them in strings.
cid = '.'.join(serialized['id'])
if cid != 'langchain.chains.llm.LLMChain':
self.progress_bar.progress(value=self.prog_map[cid], text=f'Running Chain `{cid}`...')
self.prog_value = self.prog_map[cid]
else:
self.prog_value += 0.1
self.progress_bar.progress(value=self.prog_value, text=f'Running Chain `{cid}`...')
def on_chain_end(self, outputs, **kwargs) -> None:
pass
```
Now your app will have a pretty progress bar just like ours.
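Note that the handler looks up every chain id other than `LLMChain` in `prog_map`, so a chain missing from the map would raise a `KeyError`. The following is a minimal sketch of the same progress logic, separated from Streamlit so it can be tested on its own. The `ProgressTracker` class and its `update` method are hypothetical names introduced for illustration; they are not part of LangChain or of our app. Treating unknown chains like `LLMChain` is an assumption of this sketch.

```python
from typing import Dict, List


class ProgressTracker:
    """Hypothetical helper: maps chain ids to a progress value in [0.0, 1.0]."""

    def __init__(self, prog_map: Dict[str, float], step: float = 0.1) -> None:
        self.prog_map = prog_map
        self.step = step
        self.value = 0.0

    def update(self, serialized_id: List[str]) -> float:
        # The chain id arrives as a list of name parts.
        cid = ".".join(serialized_id)
        if cid in self.prog_map:
            # Known chains jump to their milestone.
            self.value = self.prog_map[cid]
        else:
            # Unknown chains (including LLMChain) nudge the bar forward.
            self.value += self.step
        # A progress bar expects a float between 0.0 and 1.0, so clamp.
        self.value = min(self.value, 1.0)
        return self.value


tracker = ProgressTracker({
    "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
    "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
})
```

Inside `on_chain_start`, the value returned by `tracker.update(serialized['id'])` could then be passed to `self.progress_bar.progress(...)`.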
## In Conclusion
This is how an LLM app should be built with LangChain!
Today we provided a brief overview of how to build a simple LLM app that chats with the MyScale VectorStore, and also explained how to use chains in the query pipeline.
We hope this article helps you when you design your LLM-based app architecture from the ground up.
You can also ask for help on our [Discord server](https://discord.gg/D2qpkqc4Jq). We are happy to help, whether on vector databases, LLM apps, or other fantastic stuff. You are also welcome to use our open database to build your own apps! We believe you can build more awesome apps with this self-query retriever and MyScale! Happy Coding!
See you in the following article!
================================================
FILE: docs/vector-sql.md
================================================
# Teach Your LLM to Search Using Vector SQL and Answer With Facts From Database
A vector database that supports Structured Query Language (SQL) can store more than vectors. Common data types like timestamps and arrays can be accessed and filtered within the database, which improves the accuracy and efficiency of vector search queries. Accurate results from the database can teach LLMs to speak with facts, which reduces hallucination and enhances the quality and credibility of LLM answers.
## What is Hallucination?
Large Language Models are advanced AI systems that can answer a wide range of questions. Although they provide informative responses on topics they know, they are not always accurate on unfamiliar topics. This phenomenon is known as **hallucination**.
Before we look at an example of an LLM hallucination, let's consider a definition of the term "hallucination" as described by [Wikipedia](https://en.wikipedia.org/wiki/Hallucination):
> "A hallucination is a perception in the absence of an external stimulus that has the qualities of a real perception."
Moreover:
> "Hallucinations are vivid, substantial, and are perceived to be located in external objective space."
In other words, a hallucination is an error in (or a false) perception of something real or concrete. For example, a Large Language Model was asked what LLM hallucinations are, with the answer being:
![](../assets/hallucination.png)
This raises the question: how do we improve on (or fix) this result? The concise answer is to add facts to your question, such as providing the definition of an LLM before or after you ask the question.
For instance:
> An LLM is a Large Language Model, an artificial neural network that models how humans talk and write. Please tell me, what is LLM hallucination?
ChatGPT's answer to this question is:

**Note:** The reason for the first sentence, "Apologies for the confusion in my earlier response," is that we asked ChatGPT our first question, what LLM hallucinations are, before giving it our second prompt: "An LLM..."
These additions have improved the quality of the answer. At least it no longer thinks an LLM hallucination is a "Late-Life Migraine Accompaniment!" 😆
## External Knowledge Reduces Hallucinations
At this juncture, it is crucial to note that an LLM is neither infallible nor the ultimate authority on all knowledge. LLMs are trained on large amounts of data and learn patterns in language, but they may not always have access to the most up-to-date information or a comprehensive understanding of complex topics.
What now? How do you reduce LLM hallucinations?
The solution to this problem is to include supporting documents with the query (or prompt) to guide the LLM toward a more accurate and informed response. Like a human, the LLM needs to learn from these documents to answer your question accurately and correctly.
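As a minimal sketch of this idea, the snippet below prepends numbered supporting documents to a question before it is sent to an LLM. The function name `build_grounded_prompt` and the exact prompt wording are assumptions made for illustration, not a prescribed format.

```python
from typing import List


def build_grounded_prompt(question: str, documents: List[str]) -> str:
    """Prepend numbered supporting documents to the question (illustrative format)."""
    context = "\n".join(f"[{i}] {doc}" for i, doc in enumerate(documents, start=1))
    return (
        "Answer the question using only the documents below.\n"
        f"Documents:\n{context}\n"
        f"Question: {question}"
    )


prompt = build_grounded_prompt(
    "What is LLM hallucination?",
    ["An LLM is a Large Language Model, an artificial neural network "
     "that models how humans talk and write."],
)
print(prompt)
```

Numbering the documents makes it easier to ask the LLM to cite which document supports each part of its answer.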
Helpful documents can come from many sources, including search engines like Google or Bing and digital libraries like arXiv, all of which provide an interface to search for relevant passages. Using a database is also a good choice, as it provides a more flexible and private query interface.
Knowledge retrieved from sources must be relevant to the question/prompt. There are several ways to retrieve relevant documents, including:
* **Keyword-based:** Searching for keywords in plain text, suitable for an exact match on terms.
* **Vector search-based:** Searching for records whose embeddings are closest to the embedding of the query, helpful in finding paraphrases or generally related documents.
Nowadays, vector searches are popular since they can solve paraphrase problems and calculate paragraph meanings. Vector search is not a one-size-fits-all solution; it should be paired with specific filters to maintain its performance, especially when searching massive volumes of records. For example, should you only want to retrieve knowledge about physics (as a subject), you must filter out all information about any other subjects. Thus, the LLM will not be confused by knowledge from other disciplines.
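The effect of pairing a filter with vector search can be sketched in a few lines of plain Python. The records, the two-dimensional vectors, and the helper names `cosine_distance` and `filtered_search` are all made up for illustration; a real vector database would use an index rather than scanning every record.

```python
import math
from typing import Dict, List, Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def filtered_search(records: List[Dict], query: Sequence[float], subject: str, k: int) -> List[Dict]:
    # Step 1: the structured filter removes records from other subjects.
    candidates = [r for r in records if r["subject"] == subject]
    # Step 2: the vector search ranks only the remaining candidates.
    candidates.sort(key=lambda r: cosine_distance(r["vector"], query))
    return candidates[:k]


# Made-up records: the biology record is the closest vector overall,
# but the physics filter keeps it out of the results.
records = [
    {"title": "Quantum entanglement", "subject": "physics", "vector": [0.9, 0.1]},
    {"title": "Protein folding", "subject": "biology", "vector": [0.95, 0.05]},
    {"title": "Fluid dynamics", "subject": "physics", "vector": [0.2, 0.8]},
]
top = filtered_search(records, query=[1.0, 0.0], subject="physics", k=1)
print(top[0]["title"])
```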
## Automate the Whole Process with SQL... and Vector Search
The LLM should also learn to query data from its data sources before answering the questions, automating the whole process. Actually, LLMs are already capable of writing SQL queries and following instructions.

SQL is powerful and can be used to construct complex search queries. It supports many different data types and functions. And it allows us to write a vector search in SQL with `ORDER BY` and `LIMIT`, treating the similarity score between embeddings as a column `distance`. Pretty straightforward, isn't it?
> See the next section, [What Vector SQL Looks Like](#what-vector-sql-looks-like), for more information on structuring a vector SQL query.
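To make the `ORDER BY` and `LIMIT` pattern concrete, here is a small sketch that uses Python's built-in `sqlite3` module as a stand-in for a vector database. SQLite has no vector type, so this sketch stores vectors as JSON text and registers a hypothetical `distance` function from Python; the table, rows, and vectors are made up for illustration.

```python
import json
import math
import sqlite3


def distance(vector_json: str, query_json: str) -> float:
    """Euclidean distance between two JSON-encoded vectors."""
    a, b = json.loads(vector_json), json.loads(query_json)
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


conn = sqlite3.connect(":memory:")
# Register the Python function as a two-argument SQL function.
conn.create_function("distance", 2, distance)
conn.execute("CREATE TABLE docs (title TEXT, vector TEXT)")
conn.executemany(
    "INSERT INTO docs VALUES (?, ?)",
    [
        ("rose", json.dumps([0.9, 0.1])),
        ("engine", json.dumps([0.1, 0.9])),
        ("tulip", json.dumps([0.8, 0.2])),
    ],
)

# The distance is just another column: sort by it and keep the top rows.
query_vector = json.dumps([1.0, 0.0])
rows = conn.execute(
    "SELECT title, distance(vector, ?) AS dist FROM docs ORDER BY dist LIMIT 2",
    (query_vector,),
).fetchall()
titles = [title for title, _ in rows]
print(titles)
```

This toy version scans every row, whereas a vector database answers the same kind of query with a vector index.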
There are significant benefits to using vector SQL to build complex search queries, including:
* Increased flexibility through support for many data types and functions
* Improved efficiency, because SQL is highly optimized and executed inside the database
* Human-readable and easy to learn, as it is an extension of standard SQL
* LLM-friendly
**Note:** Many SQL examples and tutorials are available on the Internet. LLMs are familiar with standard SQL as well as some of its dialects.
Apart from MyScale, many SQL database solutions like ClickHouse and PostgreSQL are adding vector search to their existing functionality, allowing users to use vector SQL and LLMs to answer questions on complex topics. Similarly, an increasing number of application developers are starting to integrate vector searches with SQL into their applications.
## What Vector SQL Looks Like
Vector Structured Query Language (Vector SQL) is designed to teach LLMs how to query vector SQL databases and contains the following extra functions:
* `DISTANCE(column, query_vector)`: This function computes the distance between each vector in the column and the query vector, either exactly or approximately.
* `NeuralArray(entity)`: This function converts an entity (for example, an image or a piece of text) into an embedding.
With these two functions, we can extend standard SQL for vector search. For example, if you want to search for the 10 records most relevant to the word `flower`, you can use the following SQL statement:
```sql
SELECT * FROM table
ORDER BY DISTANCE(vector, NeuralArray(flower))
LIMIT 10
```
The `DISTANCE` call in this query is resolved in two steps:
* The inner function, `NeuralArray(flower)`, converts the word `flower` into an embedding.
* This embedding is then serialized and injected into the `DISTANCE` function.
Vector SQL is an extended version of SQL that needs further translation based on the vector database used. For instance, implementations use different names for the `DISTANCE` function: it is called `distance` in MyScale, and `L2Distance` or `cosineDistance` in ClickHouse. The generic function name is therefore translated into whichever name the target database uses.
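The translation step can be sketched with two regular-expression substitutions. This is only an illustration under stated assumptions: the `translate` function, the `DISTANCE_FUNCTIONS` mapping, and the `toy_embed` stand-in for an embedding model are hypothetical and are not the parser used in our LangChain integration.

```python
import re
from typing import Callable, List

# Hypothetical mapping from dialect name to its distance function name.
DISTANCE_FUNCTIONS = {"myscale": "distance", "clickhouse": "cosineDistance"}

NEURAL_ARRAY = re.compile(r"NeuralArray\(([^()]*)\)")


def translate(sql: str, dialect: str, embed: Callable[[str], List[float]]) -> str:
    def to_array(match: re.Match) -> str:
        # Serialize the embedding of the entity as a SQL array literal.
        values = embed(match.group(1).strip())
        return "[" + ", ".join(f"{v:g}" for v in values) + "]"

    sql = NEURAL_ARRAY.sub(to_array, sql)
    # Rename the generic DISTANCE function for the target dialect.
    return re.sub(r"\bDISTANCE\(", DISTANCE_FUNCTIONS[dialect] + "(", sql)


def toy_embed(text: str) -> List[float]:
    # Stand-in for a real embedding model: two made-up features.
    return [float(len(text)), float(text.count("o"))]


generic = "SELECT * FROM table ORDER BY DISTANCE(vector, NeuralArray(flower)) LIMIT 10"
translated = translate(generic, "clickhouse", toy_embed)
print(translated)
```

A real embedding has hundreds of dimensions, so the serialized array is much longer, but the structure of the translated query stays the same.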
## How to Teach an LLM to Write Vector SQL
Now that we understand the basic principles of vector SQL and its unique functions, let's use an LLM to help us to write a vector SQL query.
### 1. Teach an LLM What Standard Vector SQL is
First, we need to teach our LLM what standard vector SQL is. We aim to ensure that the LLM will do the following three things spontaneously when writing a vector SQL query:
* Extract the keywords from our question/prompt. It could be an object, a concept, or a topic.
* Decide which column to use to perform the similarity search. It should always choose a vector column for similarity.
* Translate the rest of our question's constraints into valid SQL.
### 2. Design the LLM Prompt
Having determined exactly what information the LLM requires to construct a vector SQL query, we can design the prompt as follows:
```python
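# Here is an example of a vector SQL prompt.
# It is a plain template string (not an f-string), so placeholders such as
# `{dialect}` and `{top_k}` are filled in later by the prompt template.
_prompt = """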
You are a {dialect} expert. Given an input question, first, create a syntactically correct MyScale query to run, look at the query results, and return the answer to the input question.
The {dialect} query has a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by this relevance.
When the query asks for {top_k} closest row, you must use this distance function to calculate the distance to the entity's array on the vector column and order by the distance to retrieve the relevant rows.
*NOTICE*: `DISTANCE(column, array)` only accepts an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user-defined function called `NeuralArray(entity)` to retrieve the entity's array.
"""
```
This prompt should do its job. However, the more examples you add, the better it performs. For instance, you can include the following vector SQL-to-text pair in the prompt:
**The SQL table create statement:**
```sql
------ table schema ------
CREATE TABLE "ChatPaper" (
abstract String,
id String,
vector Array(Float32),
categories Array(String),
pubdate DateTime,
title String,
authors Array(String),
primary_category String
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id
```
**The question and answer:**
```text
Question: What is PaperRank? What is the contribution of these works? Use papers with more than 2 categories.
SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper WHERE length(categories) > 2 ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}
```
The more relevant examples you add to your prompt, the more the LLM's process of building the correct vector SQL query will improve.
Lastly, here are several extra tips to help you when designing your prompt:
* Cover all possible functions that might appear in any questions asked.
* Avoid monotonous, repetitive questions.
* Vary the table schema across examples, for instance by adding, removing, or modifying column names and data types.
* Align the prompt's format.
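One way to keep the examples aligned is to render them all through a single function. The sketch below assumes a hypothetical `Example` dataclass and `render_examples` helper; the second example, including its `News` table, is made up to show a varied schema.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Example:
    schema: str
    question: str
    sql_query: str


def render_examples(examples: List[Example]) -> str:
    """Render every example in the same aligned format."""
    blocks = []
    for ex in examples:
        blocks.append(
            "------ table schema ------\n"
            f"{ex.schema}\n"
            f"Question: {ex.question}\n"
            f"SQLQuery: {ex.sql_query}"
        )
    return "\n\n".join(blocks)


examples = [
    Example(
        schema='CREATE TABLE "ChatPaper" (id String, vector Array(Float32), categories Array(String))',
        question="What is PaperRank? Use papers with more than 2 categories.",
        sql_query=(
            "SELECT ChatPaper.id FROM ChatPaper WHERE length(categories) > 2 "
            "ORDER BY DISTANCE(vector, NeuralArray(PaperRank)) LIMIT {top_k}"
        ),
    ),
    # Made-up example with a different schema and different column names.
    Example(
        schema='CREATE TABLE "News" (headline String, embedding Array(Float32), published DateTime)',
        question="Find news about elections published after 2020.",
        sql_query=(
            "SELECT News.headline FROM News WHERE published > '2020-12-31' "
            "ORDER BY DISTANCE(embedding, NeuralArray(elections)) LIMIT {top_k}"
        ),
    ),
]
few_shot = render_examples(examples)
print(few_shot)
```

The `{top_k}` placeholder is left untouched here so that the prompt template can fill it in later.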
## A Real-World Example: Using MyScale
Let's now build [**a real-world example**](https://huggingface.co/spaces/myscale/ChatData), set out in the following steps:

### Prepare the Database
We have prepared a playground for you with more than 2 million papers ready to query. You can access this data by adding the following Python code to your app.
```python
from sqlalchemy import create_engine
MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com"
MYSCALE_PORT = 443
MYSCALE_USER = "chatdata"
MYSCALE_PASSWORD = "myscale_rocks"
engine = create_engine(f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https')
```
If you like, you can skip the following steps, where we create the table and insert its data using the MyScale console, and jump to where we play with vector SQL and [create the `SQLDatabaseChain`](#create-the-sqldatabasechain) to query the database.
**Create the database table:**
```sql
CREATE TABLE default.ChatArXiv (
`abstract` String,
`id` String,
`vector` Array(Float32),
`metadata` Object('JSON'),
`pubdate` DateTime,
`title` String,
`categories` Array(String),
`authors` Array(String),
`comment` String,
`primary_category` String,
CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id SETTINGS index_granularity = 8192
```
**Insert the data:**
```sql
INSERT INTO ChatArXiv
SELECT
abstract, id, vector, metadata,
parseDateTimeBestEffort(JSONExtractString(toJSONString(metadata), 'pubdate')) AS pubdate,
JSONExtractString(toJSONString(metadata), 'title') AS title,
arrayMap(x->trim(BOTH '"' FROM x), JSONExtractArrayRaw(toJSONString(metadata), 'categories')) AS categories,
arrayMap(x->trim(BOTH '"' FROM x), JSONExtractArrayRaw(toJSONString(metadata), 'authors')) AS authors,
JSONExtractString(toJSONString(metadata), 'comment') AS comment,
JSONExtractString(toJSONString(metadata), 'primary_category') AS primary_category
FROM
s3(
'https://myscale-demo.s3.ap-southeast-1.amazonaws.com/chat_arxiv/data.part*.zst',
'JSONEachRow',
'abstract String, id String, vector Array(Float32), metadata Object(''JSON'')',
'zstd'
);
ALTER TABLE ChatArXiv ADD VECTOR INDEX vec_idx vector TYPE MSTG('metric_type=Cosine');
```
### Create the `SQLDatabaseChain`
This LangChain feature is currently under [MyScale tech preview](https://github.com/myscale/langchain/tree/preview). You can install it by executing the following installation script:
```bash
python3 -m venv .venv
source .venv/bin/activate
# This is a technical preview of langchain from MyScale
pip3 install langchain@git+https://github.com/myscale/langchain.git@preview
```
Once you have installed this feature, the next step is to use it to query the database, as the following Python code demonstrates:
```python
from sqlalchemy import create_engine, MetaData
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains.sql_database.parser import VectorSQLRetrieveAllOutputParser
from langchain.sql_database import SQLDatabase
from langchain.chains.sql_database.prompt import MYSCALE_PROMPT
from langchain.llms import OpenAI
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.prompts import PromptTemplate

MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com"
MYSCALE_PORT = 443
MYSCALE_USER = "chatdata"
MYSCALE_PASSWORD = "myscale_rocks"
OPENAI_API_KEY = "sk-***"  # replace with your own key

# create connection to database
engine = create_engine(f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https')

# this parser converts `NeuralArray()` into embeddings
output_parser = VectorSQLRetrieveAllOutputParser(
    model=HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-xl')
)

# use the prompt above (`_prompt` is the template from the prompt design section)
PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k"],
    template=_prompt,
)

# bind the metadata to the SQLAlchemy engine
metadata = MetaData(bind=engine)

# create SQLDatabaseChain
query_chain = SQLDatabaseChain.from_llm(
    # GPT-3.5 generates valid SQL better
    llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
    # use the predefined prompt, change it to your own prompt
    prompt=MYSCALE_PROMPT,
    # returns top 10 relevant documents
    top_k=10,
    # use result directly from DB
    return_direct=True,
    # use our database for retrieval
    db=SQLDatabase(engine, None, metadata),
    # convert `NeuralArray()` into embeddings
    sql_cmd_parser=output_parser)

# launch the chain!! And trace all chain calls in standard output
query_chain.run("Introduce some papers that uses Generative Adversarial Networks published around 2019.",
                callbacks=[StdOutCallbackHandler()])
```
### Ask with `RetrievalQAWithSourcesChain`
You can also use this `SQLDatabaseChain` as a retriever. You can plug it into retrieval QA chains just like any other retriever in LangChain.
```python
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.retrievers import SQLDatabaseChainRetriever
from langchain.chains.qa_with_sources.map_reduce_prompt import combine_prompt_template

OPENAI_API_KEY = "sk-***"

# define how you serialize the structured data from the database
document_with_metadata_prompt = PromptTemplate(
    input_variables=["page_content", "id", "title", "authors", "pubdate", "categories"],
    template="Content:\n\tTitle: {title}\n\tAbstract: {page_content}\n\t" +
             "Authors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"
)

# define the prompt you use to ask the LLM
COMBINE_PROMPT = PromptTemplate(
    template=combine_prompt_template, input_variables=["summaries", "question"])

# define a retriever with a SQLDatabaseChain (`query_chain` is the chain created in the previous step)
retriever = SQLDatabaseChainRetriever(
    sql_db_chain=query_chain, page_content_key="abstract")

# finally, the ask chain to organize all of these
ask_chain = RetrievalQAWithSourcesChain.from_chain_type(
    ChatOpenAI(model_name='gpt-3.5-turbo-16k',
               openai_api_key=OPENAI_API_KEY, temperature=0.6),
    retriever=retriever,
    chain_type='stuff',
    chain_type_kwargs={
        'prompt': COMBINE_PROMPT,
        'document_prompt': document_with_metadata_prompt,
    }, return_source_documents=True)

# Run the chain! and get the result from LLM
ask_chain("Introduce some papers that uses Generative Adversarial Networks published around 2019.",
          callbacks=[StdOutCallbackHandler()])
```
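To see what the document prompt does with a database row, the following sketch renders one row by hand with plain string formatting. It only mirrors the idea conceptually and is not LangChain's implementation; the `render_row` helper and the row values are made up for illustration.

```python
from typing import Any, Dict

# Same layout as the document prompt template used in the chain.
DOCUMENT_TEMPLATE = (
    "Content:\n\tTitle: {title}\n\tAbstract: {page_content}\n\t"
    "Authors: {authors}\n\tDate of Publication: {pubdate}\n\t"
    "Categories: {categories}\nSOURCE: {id}"
)


def render_row(row: Dict[str, Any], page_content_key: str = "abstract") -> str:
    # The column named by page_content_key becomes the document body;
    # the remaining columns are treated as metadata.
    fields = {k: v for k, v in row.items() if k != page_content_key}
    fields["page_content"] = row[page_content_key]
    return DOCUMENT_TEMPLATE.format(**fields)


row = {  # made-up row for illustration
    "id": "0000.00000",
    "title": "An Example Paper",
    "abstract": "We study an example problem.",
    "authors": ["A. Author", "B. Author"],
    "pubdate": "2019-01-01",
    "categories": ["cs.LG"],
}
rendered = render_row(row)
print(rendered)
```

Because each rendered document ends with a `SOURCE:` line, the LLM can refer back to the paper id when it lists its sources.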
We also provide a live demo on [**Hugging Face**](https://huggingface.co/spaces/myscale/ChatData), and the code is available on [**GitHub**](https://github.com/myscale/ChatData)! We used [**a customized Retrieval QA chain**](https://github.com/myscale/ChatData/blob/main/chains/arxiv_chains.py) to maximize the performance of our search-and-ask pipeline with LangChain!
## In Conclusion
In reality, most LLMs hallucinate. The most practical way to reduce hallucination is to add extra facts (external knowledge) to your question. External knowledge is crucial to improving the performance of LLM systems, allowing for the efficient and accurate retrieval of answers. Every word counts, and you don't want to waste your money on unused information retrieved by inaccurate queries.
How?
Enter Vector SQL, which allows you to execute fine-grained vector searches to target and retrieve the required information.
Vector SQL is powerful and easy to learn for humans and machines. You can use many data types and functions to create complex queries. LLMs also work well with vector SQL, as their training data includes many SQL examples.
Lastly, it is possible to translate Vector SQL for many vector databases using different embedding models. We believe that is the future of vector databases.
Are you interested in what we are doing? Join us on [Discord](https://discord.gg/D2qpkqc4Jq) today!