Repository: myscale/ChatData Branch: main Commit: fcfc236dc5d4 Files: 52 Total size: 156.1 KB Directory structure: gitextract_zniflvil/ ├── .github/ │ └── workflows/ │ └── sync-with-huggingface.yml ├── .gitignore ├── LICENSE ├── README.md ├── app/ │ ├── app.py │ ├── backend/ │ │ ├── __init__.py │ │ ├── callbacks/ │ │ │ ├── __init__.py │ │ │ ├── arxiv_callbacks.py │ │ │ ├── llm_thought_with_table.py │ │ │ ├── self_query_callbacks.py │ │ │ └── vector_sql_callbacks.py │ │ ├── chains/ │ │ │ ├── __init__.py │ │ │ ├── retrieval_qa_with_sources.py │ │ │ └── stuff_documents.py │ │ ├── chat_bot/ │ │ │ ├── __init__.py │ │ │ ├── chat.py │ │ │ ├── json_decoder.py │ │ │ ├── message_converter.py │ │ │ ├── private_knowledge_base.py │ │ │ ├── session_manager.py │ │ │ └── tools.py │ │ ├── constants/ │ │ │ ├── __init__.py │ │ │ ├── myscale_tables.py │ │ │ ├── prompts.py │ │ │ ├── streamlit_keys.py │ │ │ └── variables.py │ │ ├── construct/ │ │ │ ├── __init__.py │ │ │ ├── build_agents.py │ │ │ ├── build_all.py │ │ │ ├── build_chains.py │ │ │ ├── build_chat_bot.py │ │ │ ├── build_retriever_tool.py │ │ │ └── build_retrievers.py │ │ ├── retrievers/ │ │ │ ├── __init__.py │ │ │ ├── self_query.py │ │ │ ├── vector_sql_output_parser.py │ │ │ └── vector_sql_query.py │ │ ├── types/ │ │ │ ├── __init__.py │ │ │ ├── chains_and_retrievers.py │ │ │ ├── global_config.py │ │ │ └── table_config.py │ │ └── vector_store/ │ │ ├── __init__.py │ │ └── myscale_without_metadata.py │ ├── logger.py │ ├── requirements.txt │ └── ui/ │ ├── __init__.py │ ├── chat_page.py │ ├── home.py │ ├── retrievers.py │ └── utils.py └── docs/ ├── self-query.md └── vector-sql.md ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/sync-with-huggingface.yml ================================================ name: Sync with Hugging Face on: push: branches: - main paths: - 
.github/workflows/sync-with-huggingface.yml - app/** jobs: build: runs-on: ubuntu-latest steps: - name: Sync with Hugging Face uses: nateraw/huggingface-sync-action@v0.0.5 with: # The github repo you are syncing from. Required. github_repo_id: 'myscale/ChatData' # The Hugging Face repo id you want to sync to. (ex. 'username/reponame') # A repo with this name will be created if it doesn't exist. Required. huggingface_repo_id: 'myscale/ChatData' # Hugging Face token with write access. Required. # Here, we provide a token that we called `HF_TOKEN` when we added the secret to our GitHub repo. hf_token: ${{ secrets.HF_TOKEN }} # The type of repo you are syncing to: model, dataset, or space. # Defaults to space. repo_type: 'space' # If true and the Hugging Face repo doesn't already exist, it will be created # as a private repo. # # Note: this param has no effect if the repo already exists. private: false # If repo type is space, specify a space_sdk. One of: streamlit, gradio, or static # # This option is especially important if the repo has not been created yet. # It won't really be used if the repo already exists. space_sdk: 'streamlit' # If provided, subdirectory will determine which directory of the repo will be synced. # By default, this action syncs the entire GitHub repo. 
# # An example using this option can be seen here: # https://github.com/huggingface/fuego/blob/830ed98/.github/workflows/sync-with-huggingface.yml subdirectory: app ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ .idea/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
# This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
#.idea/

# dataset files
data/
.streamlit/
#*.ipynb
.DS_Store

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 MyScale

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: README.md
================================================
# ChatData 🔍 📖

***We are constantly improving LangChain's self-query retriever. Some of the features are not merged yet.***

[![](https://dcbadge.vercel.app/api/server/D2qpkqc4Jq?compact=true&style=flat)](https://discord.gg/D2qpkqc4Jq) [![Twitter](https://img.shields.io/twitter/url/https/twitter.com/myscaledb.svg?style=social&label=Follow%20%40MyScaleDB)](https://twitter.com/myscaledb)
Yet another chat-with-documents app, but supporting queries over millions of files with [MyScale](https://myscale.com) and [LangChain](https://github.com/hwchase17/langchain/).

## Introduction 📖

### Overview

ChatData is a robust chat-with-documents application designed to extract information and provide answers by querying the MyScale free knowledge base or your uploaded documents.

Powered by the Retrieval Augmented Generation (RAG) framework, ChatData leverages millions of Wikipedia pages and arXiv papers as its external knowledge base, with MyScale managing all data hosting tasks. Simply input your questions in natural language, and ChatData takes care of generating SQL, querying the data, and presenting the results.

Enhancing your chat experience, ChatData introduces three key features. Let's delve into each of them in detail.

#### Feature 1: Retriever Type

MyScale works closely with LangChain, providing the easiest interface to build complex queries with LLMs.

**Self-querying retriever:** MyScale augmented LangChain's Self Querying Retriever, where the LLM can use more data types, for instance timestamps and arrays of strings, to build filters for the query.

**VectorSQL:** SQL is powerful and can be used to construct complex search queries. Vector Structured Query Language (Vector SQL) is designed to teach LLMs how to query SQL vector databases. Besides the general data types and functions, Vector SQL contains extra functions like `DISTANCE(column, query_vector)` and `NeuralArray(entity)`, with which we can extend standard SQL for vector search.

#### Feature 2: Session Management

To enhance your experience and seamlessly continue interactions with existing sessions, ChatData has introduced the Session Management feature. You can easily customize your session ID and modify your prompt to guide ChatData in addressing your queries. With just a few clicks, you can enjoy smooth and personalized session interactions.
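
To make the Vector SQL functions from Feature 1 more concrete, here is a minimal sketch. It assumes the LLM emits SQL containing `NeuralArray(entity)` and that a post-processing step swaps each occurrence for a literal embedding before the query reaches MyScale. The helpers `expand_neural_array` and `toy_embed` are hypothetical names used for illustration only; they are not taken from this repository, and the toy embedding merely stands in for a real model.

```python
import re
from typing import Callable, List

# Matches NeuralArray(<entity>) where <entity> contains no nested parentheses.
NEURAL_ARRAY = re.compile(r"NeuralArray\(\s*([^()]*?)\s*\)", re.IGNORECASE)


def expand_neural_array(sql: str, embed: Callable[[str], List[float]]) -> str:
    """Replace every NeuralArray(entity) with a literal vector such as [0.1,0.2]."""

    def to_literal(match: re.Match) -> str:
        # Drop optional surrounding quotes before embedding the entity text.
        entity = match.group(1).strip("'\"")
        vector = embed(entity)
        return "[" + ",".join(f"{x:g}" for x in vector) + "]"

    return NEURAL_ARRAY.sub(to_literal, sql)


def toy_embed(text: str) -> List[float]:
    # Hypothetical stand-in for a real embedding model; carries no semantics.
    return [float(len(text)), float(text.count(" "))]


# SQL as an LLM might write it (illustrative, not captured from the app).
llm_sql = (
    "SELECT title, text FROM wiki.Wikipedia "
    "ORDER BY DISTANCE(emb, NeuralArray(history of jazz)) LIMIT 3"
)
expanded_sql = expand_neural_array(llm_sql, toy_embed)
print(expanded_sql)
```

With a real embedding model in place of `toy_embed`, the expanded statement could then be passed to a database client as ordinary SQL. Keeping the embedding call outside the prompt means the LLM only has to name the entity, not produce the vector itself.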
#### Feature 3: Building Your Own Knowledge Base

In addition to tapping into ChatData's external knowledge base powered by MyScale for answers, you also have the option to upload your own files and establish a personalized knowledge base. We've implemented the Unstructured API for this purpose, ensuring that only processed texts from your documents are stored, prioritizing your data privacy.

In conclusion, with ChatData, you can navigate vast amounts of data and access precisely what you need. Whether you're a researcher, a student, or a knowledge enthusiast, ChatData empowers you to explore academic papers and research documents like never before. Unlock the true potential of information retrieval with ChatData and discover a world of knowledge at your fingertips.

➡️ Dive in and experience ChatData on [Hugging Face](https://huggingface.co/spaces/myscale/ChatData)🤗

![ChatData Homepage](assets/home.png)

### Data schema

Database credentials:

```toml
MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com"
MYSCALE_PORT = 443
MYSCALE_USER = "chatdata"
MYSCALE_PASSWORD = "myscale_rocks"
```

#### *[NEW]* Table `wiki.Wikipedia`

ChatData also provides you access to Wikipedia, a large knowledge base that contains about 36 million paragraphs under 5 million wiki pages. The knowledge base is a snapshot from 2022-12.

You can query this table with the public account [here](#data-schema).
```sql
CREATE TABLE wiki.Wikipedia (
    -- Record ID
    `id` String,
    -- Page title to this paragraph
    `title` String,
    -- Paragraph text
    `text` String,
    -- Page URL
    `url` String,
    -- Wiki page ID
    `wiki_id` UInt64,
    -- View statistics
    `views` Float32,
    -- Paragraph ID
    `paragraph_id` UInt64,
    -- Language ID
    `langs` UInt32,
    -- Feature vector to this paragraph
    `emb` Array(Float32),
    -- Vector Index
    VECTOR INDEX emb_idx emb TYPE MSTG('metric_type=Cosine'),
    CONSTRAINT emb_len CHECK length(emb) = 768
) ENGINE = ReplacingMergeTree ORDER BY id
SETTINGS index_granularity = 8192
```

#### Table `default.ChatArXiv`

ChatData brings millions of papers into your knowledge base. We imported 2.2 million papers with metadata info, which contains:

1. `id`: paper's arxiv id
2. `abstract`: paper's abstract used as ranking criterion (with InstructorXL)
3. `vector`: column that contains the vector array in `Array(Float32)`
4. `metadata`: LangChain VectorStore Compatible Columns
    1. `metadata.authors`: paper's authors in *list of strings*
    2. `metadata.abstract`: paper's abstract used as ranking criterion (with InstructorXL)
    3. `metadata.titles`: paper's titles
    4. `metadata.categories`: paper's categories in *list of strings* like ["cs.CV"]
    5. `metadata.pubdate`: paper's date of publication in *ISO 8601 formatted strings*
    6. `metadata.primary_category`: paper's primary category in *strings* defined by arXiv
    7. `metadata.comment`: some additional comment to the paper

*Columns below are native columns in MyScale and can only be used as SQLDatabase*

5. `authors`: paper's authors in *list of strings*
6. `titles`: paper's titles
7. `categories`: paper's categories in *list of strings* like ["cs.CV"]
8. `pubdate`: paper's date of publication in *Date32 data type* (faster)
9. `primary_category`: paper's primary category in *strings* defined by arXiv
10. `comment`: some additional comment to the paper

For the overall table schema, please refer to the [table creation section in docs/self-query.md](docs/self-query.md#table-creation).

If you want to use this database with `langchain.chains.sql_database.base.SQLDatabaseChain` or `langchain.retrievers.SQLDatabaseRetriever`, please follow the [data preparation section](docs/vector-sql.md#prepare-the-database) and the [chain creation section](docs/vector-sql.md#create-the-sqldatabasechain) in docs/vector-sql.md.

### Where can I get those arXiv data?

- [From parquet files on S3](docs/self-query.md#insert-data)
- Or directly use MyScale database as service... for **FREE** ✨

```python
import clickhouse_connect

client = clickhouse_connect.get_client(
    host='msc-950b9f1f.us-east-1.aws.myscale.com',
    port=443,
    username='chatdata',
    password='myscale_rocks'
)
```

## Monthly Updates 🔥 (November-2023)

- 🚀 Upload your documents and chat with your own knowledge bases with MyScale!
- 💬 Chat with RAG-enabled agents on both ArXiv and Wikipedia knowledge base!
- 📖 Wikipedia is available as knowledge base!! Feel FREE 💰 to ask with 36 million paragraphs under 5 million titles! 💫
- 🤖 LLMs are now capable of writing **Vector SQL** - an extended SQL with vector search! Vector SQL allows you to **access MyScale faster and stronger**! This will **be added to LangChain** soon! ([PR 7454](https://github.com/hwchase17/langchain/pull/7454))
- 🌏 Customized Retrieval QA Chain that gives you **more information** on each PDF and **answers questions in your native language**!
- 🔧 Our contribution to LangChain that helps self-query retrievers [**filter with more types and functions**](https://python.langchain.com/docs/modules/data_connection/retrievers/how_to/self_query/myscale_self_query)
- 🌟 **We just opened a FREE pod hosting data for ArXiv paper.** Anyone can try their own SQL with vector search!!! Feel the power when SQL meets vector search! See how to access the pod [here](#data-service).
- 📚 We collected about **2 million papers on arxiv**! We are collecting more and we need your advice!
- More coming...

## How to build your own app from scratch 🧱

### Quickstart

1. Enter directory `app/`

    ```bash
    cd app/
    ```

2. Create a virtual environment

    ```bash
    python3 -m venv venv
    source venv/bin/activate
    ```

3. Install dependencies

    ```bash
    python3 -m pip install -r requirements.txt
    ```

4. Run the app!

    ```bash
    # fill in your OpenAI key in .streamlit/secrets.toml
    cp .streamlit/secrets.example.toml .streamlit/secrets.toml
    # start the app
    python3 -m streamlit run app.py
    ```

### With LangChain SQLDatabaseRetrievers

[*Read the full article*](https://myscale.com/blog/teach-your-llm-vector-sql/)

- [Why Vector SQL?](https://myscale.com/blog/teach-your-llm-vector-sql/#automate-the-whole-process-with-sql-and-vector-search)
- [How did LangChain and MyScale convert natural language to structured filters?](https://myscale.com/docs/en/advanced-applications/chatdata/#selfqueryretriever)
- [How to make chain execution more responsive in LangChain?](https://myscale.com/docs/en/advanced-applications/chatdata/#add-callbacks)

### With LangChain Self-Query Retrievers

[*Read the full article*](https://myscale.com/docs/en/advanced-applications/chatdata/)

- [How is this app built?](https://docs.myscale.com/en/advanced-applications/chatdata)
- [What does the overall pipeline look like?](https://docs.myscale.com/en/advanced-applications/chatdata/#design-the-query-pipeline)
- [How did LangChain and MyScale convert natural language to structured filters?](https://docs.myscale.com/en/advanced-applications/chatdata/#selfqueryretriever)
- [How to make chain execution more responsive in LangChain?](https://docs.myscale.com/en/advanced-applications/chatdata/#add-callbacks)

## Community 🌍

- You are welcome to join our #ChatData channel in [Discord](https://discord.gg/jGCq2yZH) to discuss anything about ChatData.
- Feel free to file an issue or open a PR against this repository.
## Special Thanks 👏

(Ordered Alphabetically)

- [arXiv API](https://info.arxiv.org/help/api/index.html) for its open access interoperability to pre-printed papers.
- [InstructorXL](https://huggingface.co/hkunlp/instructor-xl) for its promptable embeddings that improve retrieval performance.
- [LangChain🦜️🔗](https://github.com/hwchase17/langchain/) for its easy-to-use and composable API designs and prompts.
- [OpenChatPaper](https://github.com/liuyixin-louis/OpenChatPaper) for prompt design reference.
- [The Alexandria Index](https://alex.macrocosm.so/download) for providing arXiv data index to the public.

================================================
FILE: app/app.py
================================================
import os
import time

import streamlit as st

from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STATED, DATA_INITIALIZE_COMPLETED, \
    DATA_INITIALIZE_STARTED
from backend.constants.variables import DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \
    TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config
from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools
from backend.types.global_config import GlobalConfig
from logger import logger
from ui.chat_page import chat_page
from ui.home import render_home
from ui.retrievers import render_retrievers


# warnings.filterwarnings("ignore", category=UserWarning)


def prepare_environment():
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    os.environ["LANGCHAIN_TRACING_V2"] = "false"
    # os.environ["LANGCHAIN_API_KEY"] = ""
    os.environ["OPENAI_API_BASE"] = st.secrets['OPENAI_API_BASE']
    os.environ["OPENAI_API_KEY"] = st.secrets['OPENAI_API_KEY']
    os.environ["AUTH0_CLIENT_ID"] = st.secrets['AUTH0_CLIENT_ID']
    os.environ["AUTH0_DOMAIN"] = st.secrets['AUTH0_DOMAIN']

    update_global_config(GlobalConfig(
        openai_api_base=st.secrets['OPENAI_API_BASE'],
        openai_api_key=st.secrets['OPENAI_API_KEY'],
        auth0_client_id=st.secrets['AUTH0_CLIENT_ID'],
        auth0_domain=st.secrets['AUTH0_DOMAIN'],
        myscale_user=st.secrets['MYSCALE_USER'],
        myscale_password=st.secrets['MYSCALE_PASSWORD'],
        myscale_host=st.secrets['MYSCALE_HOST'],
        myscale_port=st.secrets['MYSCALE_PORT'],
        query_model="gpt-3.5-turbo-0125",
        chat_model="gpt-3.5-turbo-0125",
        untrusted_api=st.secrets['UNSTRUCTURED_API'],
        myscale_enable_https=st.secrets.get('MYSCALE_ENABLE_HTTPS', True),
    ))


# when refresh browser, all session keys will be cleaned.
def initialize_session_state():
    if DATA_INITIALIZE_STATUS not in st.session_state:
        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_NOT_STATED
        logger.info(f"Initialize session state key: {DATA_INITIALIZE_STATUS}")
    if JUMP_QUERY_ASK not in st.session_state:
        st.session_state[JUMP_QUERY_ASK] = False
        logger.info(f"Initialize session state key: {JUMP_QUERY_ASK}")


def initialize_chat_data():
    if st.session_state[DATA_INITIALIZE_STATUS] != DATA_INITIALIZE_COMPLETED:
        start_time = time.time()
        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_STARTED
        st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models()
        st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers()
        st.session_state[RETRIEVER_TOOLS] = update_retriever_tools()
        # mark data initialization finished.
        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED
        end_time = time.time()
        logger.info(f"ChatData initialized finished in {round(end_time - start_time, 3)} seconds, "
                    f"session state keys: {list(st.session_state.keys())}")


st.set_page_config(
    page_title="ChatData",
    page_icon="https://myscale.com/favicon.ico",
    initial_sidebar_state="expanded",
    layout="wide",
)

prepare_environment()
initialize_session_state()
initialize_chat_data()

if USER_NAME in st.session_state:
    chat_page()
else:
    if st.session_state[JUMP_QUERY_ASK]:
        render_retrievers()
    else:
        render_home()

================================================
FILE: app/backend/__init__.py
================================================

================================================
FILE: app/backend/callbacks/__init__.py
================================================

================================================
FILE: app/backend/callbacks/arxiv_callbacks.py
================================================
import json
import textwrap
from typing import Dict, Any, List

from langchain.callbacks.streamlit.streamlit_callback_handler import (
    LLMThought,
    StreamlitCallbackHandler,
)


class LLMThoughtWithKnowledgeBase(LLMThought):
    def on_tool_end(
        self,
        output: str,
        color=None,
        observation_prefix=None,
        llm_prefix=None,
        **kwargs: Any,
    ) -> None:
        try:
            self._container.markdown(
                "\n\n".join(
                    ["### Retrieved Documents:"]
                    + [
                        f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
                        for i, r in enumerate(json.loads(output))
                    ]
                )
            )
        except Exception as e:
            super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)


class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        if self._current_thought is None:
            self._current_thought = LLMThoughtWithKnowledgeBase(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)

================================================
FILE: app/backend/callbacks/llm_thought_with_table.py
================================================
from typing import Any, Dict, List

import streamlit as st
from langchain_core.outputs import LLMResult
from streamlit.external.langchain import StreamlitCallbackHandler


class ChatDataSelfQueryCallBack(StreamlitCallbackHandler):
    def __init__(self):
        super().__init__(st.container())
        self._current_thought = None
        self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery CallBack...")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        self.progress_bar.progress(value=0.35, text="Communicate with LLM...")
        pass

    def on_chain_end(self, outputs, **kwargs) -> None:
        if len(kwargs['tags']) == 0:
            self.progress_bar.progress(value=0.75, text="Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        st.markdown("### Generate filter by LLM \n"
                    "> Here we get `query_constructor` results \n\n")
        self.progress_bar.progress(value=0.5, text="Generate filter by LLM...")
        for item in response.generations:
            st.markdown(f"{item[0].text}")
        pass

================================================
FILE: app/backend/callbacks/self_query_callbacks.py
================================================
from typing import Dict, Any, List

import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult


class CustomSelfQueryRetrieverCallBackHandler(StreamlitCallbackHandler):
    def __init__(self):
        super().__init__(st.container())
        self._current_thought = None
        self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery...")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str],
        **kwargs: Any
    ) -> None:
        self.progress_bar.progress(value=0.35, text="Communicate with LLM...")
        pass

    def on_chain_end(self, outputs, **kwargs) -> None:
        if len(kwargs['tags']) == 0:
            self.progress_bar.progress(value=0.75, text="Searching in DB...")
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        st.markdown("### Generate filter by LLM \n"
                    "> Here we get `query_constructor` results \n\n")
        self.progress_bar.progress(value=0.5, text="Generate filter by LLM...")
        for item in response.generations:
            st.markdown(f"{item[0].text}")
        pass


class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
    def __init__(self) -> None:
        super().__init__(st.container())
        self.progress_bar = st.progress(value=0.2, text="Executing ChatData SelfQuery Chain...")

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if len(kwargs['tags']) != 0:
            self.progress_bar.progress(value=0.5, text="We got filter info from LLM...")
            st.markdown("### Generate filter by LLM \n"
                        "> Here we get `query_constructor` results \n\n")
            for item in response.generations:
                st.markdown(f"{item[0].text}")
        pass

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join(serialized["id"])
        if cid.endswith(".CustomStuffDocumentChain"):
            self.progress_bar.progress(value=0.7, text="Asking LLM with related documents...")

================================================
FILE: app/backend/callbacks/vector_sql_callbacks.py
================================================
import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult
from sql_formatter.core import format_sql


class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = 0.2

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ):
        text = response.generations[0][0].text
        if text.replace(" ", "").upper().startswith("SELECT"):
            st.markdown("### Generated Vector Search SQL Statement \n"
                        "> This sql statement is generated by LLM \n\n")
            st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
        self.prog_value += self.prog_interval
        self.progress_bar.progress(
            value=self.prog_value, text="Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join(serialized["id"])
        self.prog_value += self.prog_interval
        self.progress_bar.progress(
            value=self.prog_value, text=f"Running Chain `{cid}`..."
        )

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass


class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler):
    def __init__(self, table: str) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = 0.1
        self.table = table

================================================
FILE: app/backend/chains/__init__.py
================================================

================================================
FILE: app/backend/chains/retrieval_qa_with_sources.py
================================================
import inspect
from typing import Dict, Any, Optional, List

from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.docstore.document import Document

from logger import logger


class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
    """QA with source chain for Chat ArXiv app with references

    This chain will automatically assign a reference number to each article,
    then parse it back to titles or anything else.
    """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        logger.info(f"\033[91m\033[1m{self._chain_type}\033[0m")
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs: List[Document] = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs: List[Document] = self._get_docs(inputs)  # type: ignore[call-arg]

        answer = self.combine_documents_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )
        # parse source with ref_id
        sources = []
        ref_cnt = 1
        for d in docs:
            ref_id = d.metadata['ref_id']
            if f"Doc #{ref_id}" in answer:
                answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}")
            if f"#{ref_id}" in answer:
                title = d.metadata['title'].replace('\n', '')
                d.metadata['ref_id'] = ref_cnt
                answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
                sources.append(d)
                ref_cnt += 1

        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        raise NotImplementedError

    @property
    def _chain_type(self) -> str:
        return "custom_retrieval_qa_with_sources_chain"

================================================
FILE: app/backend/chains/stuff_documents.py
================================================
from typing import Any, List, Tuple

from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.docstore.document import Document
from langchain.schema.prompt_template import format_document


class CustomStuffDocumentChain(StuffDocumentsChain):
    """Combine arxiv documents with PDF reference number"""

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Construct inputs from kwargs and docs.

        Format and then join all the documents together into one input with name
        `self.document_variable_name`. Then pluck any additional variables
        from **kwargs.

        Args:
            docs: List of documents to format and then join into single input
            **kwargs: additional inputs to chain, will pluck any other required
                arguments from here.

        Returns:
            dictionary of inputs to LLMChain
        """
        # Format each document according to the prompt
        doc_strings = []
        for doc_id, doc in enumerate(docs):
            # add temp reference number in metadata
            doc.metadata.update({'ref_id': doc_id})
            doc.page_content = doc.page_content.replace('\n', ' ')
            doc_strings.append(format_document(doc, self.document_prompt))
        # Join the documents together to put them in the prompt.
        inputs = {
            k: v
            for k, v in kwargs.items()
            if k in self.llm_chain.prompt.input_variables
        }
        inputs[self.document_variable_name] = self.document_separator.join(
            doc_strings)
        return inputs

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
output = self.llm_chain.predict(callbacks=callbacks, **inputs) return output, {} @property def _chain_type(self) -> str: return "custom_stuff_document_chain" ================================================ FILE: app/backend/chat_bot/__init__.py ================================================ ================================================ FILE: app/backend/chat_bot/chat.py ================================================ import time from os import environ from time import sleep import streamlit as st from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION_MANAGER, \ CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, USER_PRIVATE_FILES, \ EL_BUILD_KB_WITH_FILES, \ EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \ USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \ EL_UPLOAD_FILES_STATUS, EL_SELECTED_KBS, EL_UPLOAD_FILES from backend.constants.variables import USER_INFO, USER_NAME, JUMP_QUERY_ASK, RETRIEVER_TOOLS from backend.construct.build_agents import build_agents from backend.chat_bot.session_manager import SessionManager from backend.callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler from logger import logger environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"] TOOL_NAMES = { "langchain_retriever_tool": "Self-querying retriever", "vecsql_retriever_tool": "Vector SQL", } def on_chat_submit(): with st.session_state.next_round.container(): with st.chat_message("user"): st.write(st.session_state.chat_input) with st.chat_message("assistant"): container = st.container() st_callback = ChatDataAgentCallBackHandler( container, collapse_completed_thoughts=False ) ret = st.session_state.agent( {"input": st.session_state.chat_input}, callbacks=[st_callback] ) logger.info(f"ret:{ret}") def clear_history(): if "agent" in st.session_state: st.session_state.agent.memory.clear() def back_to_main(): if USER_INFO in st.session_state: del 
st.session_state[USER_INFO] if USER_NAME in st.session_state: del st.session_state[USER_NAME] if JUMP_QUERY_ASK in st.session_state: del st.session_state[JUMP_QUERY_ASK] if EL_SESSION_SELECTOR in st.session_state: del st.session_state[EL_SESSION_SELECTOR] if CHAT_CURRENT_USER_SESSIONS in st.session_state: del st.session_state[CHAT_CURRENT_USER_SESSIONS] def refresh_sessions(): chat_session_manager: SessionManager = st.session_state[CHAT_SESSION_MANAGER] current_user_name = st.session_state[USER_NAME] current_user_sessions = chat_session_manager.list_sessions(current_user_name) if not isinstance(current_user_sessions, dict) or not current_user_sessions: # generate a default session for current user. chat_session_manager.add_session( user_id=current_user_name, session_id=f"{current_user_name}?default", system_prompt=DEFAULT_SYSTEM_PROMPT, ) st.session_state[CHAT_CURRENT_USER_SESSIONS] = chat_session_manager.list_sessions(current_user_name) current_user_sessions = st.session_state[CHAT_CURRENT_USER_SESSIONS] else: st.session_state[CHAT_CURRENT_USER_SESSIONS] = current_user_sessions # load current user files. st.session_state[USER_PRIVATE_FILES] = st.session_state[CHAT_KNOWLEDGE_TABLE].list_files( current_user_name ) # load current user private knowledge bases. 
    st.session_state[USER_PERSONAL_KNOWLEDGE_BASES] = \
        st.session_state[CHAT_KNOWLEDGE_TABLE].list_private_knowledge_bases(current_user_name)
    logger.info(f"current user name: {current_user_name}, "
                f"user private knowledge bases: {st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]}, "
                f"user private files: {st.session_state[USER_PRIVATE_FILES]}")
    st.session_state[AVAILABLE_RETRIEVAL_TOOLS] = {
        # public retrieval tools
        **st.session_state[RETRIEVER_TOOLS],
        # private retrieval tools
        **st.session_state[CHAT_KNOWLEDGE_TABLE].as_retrieval_tools(current_user_name),
    }
    logger.info(f"current_user_sessions is {current_user_sessions}")
    st.session_state[EL_SESSION_SELECTOR] = current_user_sessions[0]


# Handle sessions added or deleted through the session editor.
def on_session_change_submit():
    if "session_manager" in st.session_state and "session_editor" in st.session_state:
        try:
            for elem in st.session_state.session_editor["added_rows"]:
                if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem:
                    if elem["session_id"] != "" and "?" not in elem["session_id"]:
                        st.session_state.session_manager.add_session(
                            user_id=st.session_state.user_name,
                            session_id=f"{st.session_state.user_name}?{elem['session_id']}",
                            system_prompt=elem["system_prompt"],
                        )
                    else:
                        st.toast("`session_id` must be non-empty and must not contain `?`.", icon="❌")
                        raise KeyError(
                            "`session_id` must be non-empty and must not contain `?`."
                        )
                else:
                    st.toast("You should fill both `session_id` and `system_prompt` to add a row!", icon="❌")
                    raise KeyError(
                        "You should fill both `session_id` and `system_prompt` to add a row!"
) for elem in st.session_state.session_editor["deleted_rows"]: user_name = st.session_state[USER_NAME] session_id = st.session_state[CHAT_CURRENT_USER_SESSIONS][elem]['session_id'] user_with_session_id = f"{user_name}?{session_id}" st.session_state.session_manager.remove_session(session_id=user_with_session_id) st.toast(f"session `{user_with_session_id}` removed.", icon="✅") refresh_sessions() except Exception as e: sleep(2) st.error(f"{type(e)}: {str(e)}") finally: st.session_state.session_editor["added_rows"] = [] st.session_state.session_editor["deleted_rows"] = [] refresh_agent() def create_private_knowledge_base_as_tool(): current_user_name = st.session_state[USER_NAME] if ( EL_PERSONAL_KB_NAME in st.session_state and EL_PERSONAL_KB_DESCRIPTION in st.session_state and EL_BUILD_KB_WITH_FILES in st.session_state and len(st.session_state[EL_PERSONAL_KB_NAME]) > 0 and len(st.session_state[EL_PERSONAL_KB_DESCRIPTION]) > 0 and len(st.session_state[EL_BUILD_KB_WITH_FILES]) > 0 ): st.session_state[CHAT_KNOWLEDGE_TABLE].create_private_knowledge_base( user_id=current_user_name, tool_name=st.session_state[EL_PERSONAL_KB_NAME], tool_description=st.session_state[EL_PERSONAL_KB_DESCRIPTION], files=[f["file_name"] for f in st.session_state[EL_BUILD_KB_WITH_FILES]], ) refresh_sessions() else: st.session_state[EL_UPLOAD_FILES_STATUS].error( "You should fill all fields to build up a tool!" ) sleep(2) def remove_private_knowledge_bases(): if EL_PERSONAL_KB_NEEDS_REMOVE in st.session_state and st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]: private_knowledge_bases_needs_remove = st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE] private_knowledge_base_names = [item["tool_name"] for item in private_knowledge_bases_needs_remove] # remove these private knowledge bases. 
st.session_state[CHAT_KNOWLEDGE_TABLE].remove_private_knowledge_bases( user_id=st.session_state[USER_NAME], private_knowledge_bases=private_knowledge_base_names ) refresh_sessions() else: st.session_state[EL_UPLOAD_FILES_STATUS].error( "You should specify at least one private knowledge base to delete!" ) time.sleep(2) def refresh_agent(): with st.spinner("Initializing session..."): user_name = st.session_state[USER_NAME] session_id = st.session_state[EL_SESSION_SELECTOR]['session_id'] user_with_session_id = f"{user_name}?{session_id}" if EL_SELECTED_KBS in st.session_state: selected_knowledge_bases = st.session_state[EL_SELECTED_KBS] else: selected_knowledge_bases = ["Wikipedia + Vector SQL"] logger.info(f"selected_knowledge_bases: {selected_knowledge_bases}") if EL_SESSION_SELECTOR in st.session_state: system_prompt = st.session_state[EL_SESSION_SELECTOR]["system_prompt"] else: system_prompt = DEFAULT_SYSTEM_PROMPT st.session_state["agent"] = build_agents( session_id=user_with_session_id, tool_names=selected_knowledge_bases, system_prompt=system_prompt ) def add_file(): user_name = st.session_state[USER_NAME] if EL_UPLOAD_FILES not in st.session_state or len(st.session_state[EL_UPLOAD_FILES]) == 0: st.session_state[EL_UPLOAD_FILES_STATUS].error("Please upload files!", icon="⚠️") sleep(2) return try: st.session_state[EL_UPLOAD_FILES_STATUS].info("Uploading...") st.session_state[CHAT_KNOWLEDGE_TABLE].add_by_file( user_id=user_name, files=st.session_state[EL_UPLOAD_FILES] ) refresh_sessions() except ValueError as e: st.session_state[EL_UPLOAD_FILES_STATUS].error("Failed to upload! 
" + str(e))
        sleep(2)


def clear_files():
    st.session_state[CHAT_KNOWLEDGE_TABLE].clear(user_id=st.session_state[USER_NAME])
    refresh_sessions()


================================================
FILE: app/backend/chat_bot/json_decoder.py
================================================
import json
import datetime


class CustomJSONEncoder(json.JSONEncoder):
    def default(self, obj):
        # Serialize datetimes as ISO 8601 strings; defer everything else to the base class.
        if isinstance(obj, datetime.datetime):
            return obj.isoformat()
        return super().default(obj)


class CustomJSONDecoder(json.JSONDecoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, source):
        # Convert every string value that parses as an ISO 8601 timestamp back to a datetime.
        for k, v in source.items():
            if isinstance(v, str):
                try:
                    source[k] = datetime.datetime.fromisoformat(v)
                except ValueError:
                    # Not a timestamp; keep the original string.
                    pass
        return source


================================================
FILE: app/backend/chat_bot/message_converter.py
================================================
import hashlib
import json
import time
from typing import Any

from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage, ChatMessage, FunctionMessage
from langchain.schema.messages import ToolMessage
from sqlalchemy.orm import declarative_base

from backend.chat_bot.tools import create_message_history_table


def _message_from_dict(message: dict) -> BaseMessage:
    _type = message["type"]
    if _type == "human":
        return HumanMessage(**message["data"])
    elif _type == "ai":
        return AIMessage(**message["data"])
    elif _type == "system":
        return SystemMessage(**message["data"])
    elif _type == "chat":
        return ChatMessage(**message["data"])
    elif _type == "function":
        return FunctionMessage(**message["data"])
    elif _type == "tool":
        return ToolMessage(**message["data"])
    elif _type == "AIMessageChunk":
        # Stored chunks are restored as complete AI messages.
        message["data"]["type"] = "ai"
        return AIMessage(**message["data"])
    else:
        raise ValueError(f"Got unexpected message type: {_type}")


class
DefaultClickhouseMessageConverter(DefaultMessageConverter): """The default message converter for SQLChatMessageHistory.""" def __init__(self, table_name: str): super().__init__(table_name) self.model_class = create_message_history_table(table_name, declarative_base()) def to_sql_model(self, message: BaseMessage, session_id: str) -> Any: time_stamp = time.time() msg_id = hashlib.sha256( f"{session_id}_{message}_{time_stamp}".encode('utf-8')).hexdigest() user_id, _ = session_id.split("?") return self.model_class( id=time_stamp, msg_id=msg_id, user_id=user_id, session_id=session_id, type=message.type, addtionals=json.dumps(message.additional_kwargs), message=json.dumps({ "type": message.type, "additional_kwargs": {"timestamp": time_stamp}, "data": message.dict()}) ) def from_sql_model(self, sql_message: Any) -> BaseMessage: msg_dump = json.loads(sql_message.message) msg = _message_from_dict(msg_dump) msg.additional_kwargs = msg_dump["additional_kwargs"] return msg def get_sql_model_class(self) -> Any: return self.model_class ================================================ FILE: app/backend/chat_bot/private_knowledge_base.py ================================================ import hashlib from datetime import datetime from typing import List, Optional import pandas as pd from clickhouse_connect import get_client from langchain.schema.embeddings import Embeddings from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings from streamlit.runtime.uploaded_file_manager import UploadedFile from backend.chat_bot.tools import parse_files, extract_embedding from backend.construct.build_retriever_tool import create_retriever_tool from logger import logger class ChatBotKnowledgeTable: def __init__(self, host, port, username, password, embedding: Embeddings, parser_api_key: str, db="chat", kb_table="private_kb", tool_table="private_tool") -> None: super().__init__() personal_files_schema_ = f""" CREATE TABLE IF NOT EXISTS {db}.{kb_table}( entity_id String, 
file_name String, text String, user_id String, created_by DateTime, vector Array(Float32), CONSTRAINT cons_vec_len CHECK length(vector) = 768, VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine') ) ENGINE = ReplacingMergeTree ORDER BY entity_id """ # `tool_name` represent private knowledge database name. private_knowledge_base_schema_ = f""" CREATE TABLE IF NOT EXISTS {db}.{tool_table}( tool_id String, tool_name String, file_names Array(String), user_id String, created_by DateTime, tool_description String ) ENGINE = ReplacingMergeTree ORDER BY tool_id """ self.personal_files_table = kb_table self.private_knowledge_base_table = tool_table config = MyScaleSettings( host=host, port=port, username=username, password=password, database=db, table=kb_table, ) self.client = get_client( host=config.host, port=config.port, username=config.username, password=config.password, ) self.client.command("SET allow_experimental_object_type=1") self.client.command(personal_files_schema_) self.client.command(private_knowledge_base_schema_) self.parser_api_key = parser_api_key self.vector_store = MyScaleWithoutJSON( embedding=embedding, config=config, must_have_cols=["file_name", "text", "created_by"], ) # List all files with given `user_id` def list_files(self, user_id: str): query = f""" SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph, arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars FROM {self.vector_store.config.database}.{self.personal_files_table} WHERE user_id = '{user_id}' GROUP BY file_name """ return [r for r in self.vector_store.client.query(query).named_results()] # Parse and embedding files def add_by_file(self, user_id, files: List[UploadedFile]): data = parse_files(self.parser_api_key, user_id, files) data = extract_embedding(self.vector_store.embeddings, data) self.vector_store.client.insert_df( table=self.personal_files_table, df=pd.DataFrame(data), database=self.vector_store.config.database, ) # Remove all files and 
private_knowledge_bases with given `user_id` def clear(self, user_id: str): self.vector_store.client.command( f"DELETE FROM {self.vector_store.config.database}.{self.personal_files_table} " f"WHERE user_id='{user_id}'" ) query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table} WHERE user_id = '{user_id}'""" self.vector_store.client.command(query) def create_private_knowledge_base( self, user_id: str, tool_name: str, tool_description: str, files: Optional[List[str]] = None ): self.vector_store.client.insert_df( self.private_knowledge_base_table, pd.DataFrame( [ { "tool_id": hashlib.sha256( (user_id + tool_name).encode("utf-8") ).hexdigest(), "tool_name": tool_name, # tool_name represent user's private knowledge base. "file_names": files, "user_id": user_id, "created_by": datetime.now(), "tool_description": tool_description, } ] ), database=self.vector_store.config.database, ) # Show all private knowledge bases with given `user_id` def list_private_knowledge_bases(self, user_id: str, private_knowledge_base=None): extended_where = f"AND tool_name = '{private_knowledge_base}'" if private_knowledge_base else "" query = f""" SELECT tool_name, tool_description, length(file_names) FROM {self.vector_store.config.database}.{self.private_knowledge_base_table} WHERE user_id = '{user_id}' {extended_where} """ return [r for r in self.vector_store.client.query(query).named_results()] def remove_private_knowledge_bases(self, user_id: str, private_knowledge_bases: List[str]): unique_list = list(set(private_knowledge_bases)) unique_list = ",".join([f"'{t}'" for t in unique_list]) query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table} WHERE user_id = '{user_id}' AND tool_name IN [{unique_list}]""" self.vector_store.client.command(query) def as_retrieval_tools(self, user_id, tool_name=None): logger.info(f"") private_knowledge_bases = self.list_private_knowledge_bases(user_id=user_id, 
            private_knowledge_base=tool_name)
        retrievers = {}
        for private_kb in private_knowledge_bases:
            # Query the configured database and table instead of a hard-coded `chat.private_tool`.
            file_names_sql = f"""
                SELECT arrayJoin(file_names) FROM (
                    SELECT file_names
                    FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                    WHERE user_id = '{user_id}' AND tool_name = '{private_kb["tool_name"]}'
                )
            """
            logger.info(f"user_id is {user_id}, file_names_sql is {file_names_sql}")
            res = self.client.query(file_names_sql)
            file_names = []
            for line in res.result_rows:
                file_names.append(line[0])
            file_names = ', '.join(f"'{item}'" for item in file_names)
            logger.info(f"user_id is {user_id}, file_names is {file_names}")
            retrievers[private_kb["tool_name"]] = create_retriever_tool(
                self.vector_store.as_retriever(
                    search_kwargs={"where_str": f"user_id='{user_id}' AND file_name IN ({file_names})"},
                ),
                tool_name=private_kb["tool_name"],
                description=private_kb["tool_description"],
            )
        return retrievers


================================================
FILE: app/backend/chat_bot/session_manager.py
================================================
import json

from backend.chat_bot.tools import create_session_table, create_message_history_table
from backend.constants.variables import GLOBAL_CONFIG

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
from sqlalchemy import orm, create_engine
from logger import logger


def get_sessions(engine, model_class, user_id):
    with orm.sessionmaker(engine)() as session:
        result = (
            session.query(model_class)
            # Filter on the owner column; the original compared `session_id` to `user_id`.
            .where(
                model_class.user_id == user_id
            )
            .order_by(model_class.create_by.desc())
        )
        # Materialize the rows before the session closes; a query object cannot be passed to `json.loads`.
        return result.all()


class SessionManager:
    def __init__(
            self,
            session_state,
            host,
            port,
            username,
            password,
            db='chat',
            session_table='sessions',
            msg_table='chat_memory'
    ) -> None:
        if not GLOBAL_CONFIG.myscale_enable_https:
            conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=http'
        else:
            conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
self.engine = create_engine(conn_str, echo=False) self.session_model_class = create_session_table( session_table, declarative_base()) self.session_model_class.metadata.create_all(self.engine) self.msg_model_class = create_message_history_table(msg_table, declarative_base()) self.msg_model_class.metadata.create_all(self.engine) self.session_orm = orm.sessionmaker(self.engine) self.session_state = session_state def list_sessions(self, user_id: str): with self.session_orm() as session: result = ( session.query(self.session_model_class) .where( self.session_model_class.user_id == user_id ) .order_by(self.session_model_class.create_by.desc()) ) sessions = [] for r in result: sessions.append({ "session_id": r.session_id.split("?")[-1], "system_prompt": r.system_prompt, }) return sessions # Update sys_prompt with given session_id def modify_system_prompt(self, session_id, sys_prompt): with self.session_orm() as session: obj = session.query(self.session_model_class).where( self.session_model_class.session_id == session_id).first() if obj: obj.system_prompt = sys_prompt session.commit() else: logger.warning(f"Session {session_id} not found") # Add a session(session_id, sys_prompt) def add_session(self, user_id: str, session_id: str, system_prompt: str, **kwargs): with self.session_orm() as session: elem = self.session_model_class( user_id=user_id, session_id=session_id, system_prompt=system_prompt, create_by=datetime.now(), additionals=json.dumps(kwargs) ) session.add(elem) session.commit() # Remove a session and related chat history. def remove_session(self, session_id: str): with self.session_orm() as session: # remove session session.query(self.session_model_class).where(self.session_model_class.session_id == session_id).delete() # remove related chat history. 
session.query(self.msg_model_class).where(self.msg_model_class.session_id == session_id).delete() ================================================ FILE: app/backend/chat_bot/tools.py ================================================ import hashlib from datetime import datetime from multiprocessing.pool import ThreadPool from typing import List import requests from clickhouse_sqlalchemy import types, engines from langchain.schema.embeddings import Embeddings from sqlalchemy import Column, Text from streamlit.runtime.uploaded_file_manager import UploadedFile def parse_files(api_key, user_id, files: List[UploadedFile]): def parse_file(file: UploadedFile): headers = { "accept": "application/json", "unstructured-api-key": api_key, } data = {"strategy": "auto", "ocr_languages": ["eng"]} file_hash = hashlib.sha256(file.read()).hexdigest() file_data = {"files": (file.name, file.getvalue(), file.type)} response = requests.post( url="https://api.unstructured.io/general/v0/general", headers=headers, data=data, files=file_data ) json_response = response.json() if response.status_code != 200: raise ValueError(str(json_response)) texts = [ { "text": t["text"], "file_name": t["metadata"]["filename"], "entity_id": hashlib.sha256( (file_hash + t["text"]).encode() ).hexdigest(), "user_id": user_id, "created_by": datetime.now(), } for t in json_response if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10 ] return texts with ThreadPool(8) as p: rows = [] for r in p.imap_unordered(parse_file, files): rows.extend(r) return rows def extract_embedding(embeddings: Embeddings, texts): if len(texts) > 0: embeddings = embeddings.embed_documents( [t["text"] for _, t in enumerate(texts)]) for i, _ in enumerate(texts): texts[i]["vector"] = embeddings[i] return texts raise ValueError("No texts extracted!") def create_message_history_table(table_name: str, base_class): class Message(base_class): __tablename__ = table_name id = Column(types.Float64) session_id = Column(Text) user_id = 
Column(Text)
        msg_id = Column(Text, primary_key=True)
        type = Column(Text)
        # Should be `additionals`; a former developer misspelled it.
        addtionals = Column(Text)
        message = Column(Text)
        __table_args__ = (
            engines.MergeTree(
                partition_by='session_id',
                order_by=('id', 'msg_id')
            ),
            {'comment': 'Store Chat History'}
        )

    return Message


def create_session_table(table_name: str, DynamicBase):
    class Session(DynamicBase):
        __tablename__ = table_name
        user_id = Column(Text)
        session_id = Column(Text, primary_key=True)
        system_prompt = Column(Text)
        # Creation time of the session.
        create_by = Column(types.DateTime)
        # Extra session attributes, stored as a JSON string.
        additionals = Column(Text)
        __table_args__ = (
            engines.MergeTree(order_by=session_id),
            {'comment': 'Store Session and Prompts'}
        )

    return Session


================================================
FILE: app/backend/constants/__init__.py
================================================


================================================
FILE: app/backend/constants/myscale_tables.py
================================================
from typing import Dict, List

import streamlit as st
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain_community.embeddings import SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate

from backend.types.table_config import TableConfig


def hint_arxiv():
    st.markdown("Here we provide some query samples.")
    st.markdown("- If you want to search papers with filters")
    st.markdown("1. ```What is a Bayesian network? Please use articles published later than Feb 2018 and with more "
                "than 2 categories and whose title like `computer` and must have `cs.CV` in its category. ```")
    st.markdown("2. ```What is a Bayesian network? Please use articles published later than Feb 2018```")
    st.markdown("- If you want to ask questions based on arxiv papers stored in MyScaleDB")
    st.markdown("1. 
```Did Geoffrey Hinton write a paper about Capsule Neural Networks?```")
    st.markdown("2. ```Introduce some applications of GANs published around 2019.```")
    st.markdown("3. ```请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些```")


def hint_sql_arxiv():
    st.markdown('''```sql
CREATE TABLE default.ChatArXiv
(
    `abstract` String,
    `id` String,
    `vector` Array(Float32),
    `metadata` Object('JSON'),
    `pubdate` DateTime,
    `title` String,
    `categories` Array(String),
    `authors` Array(String),
    `comment` String,
    `primary_category` String,
    VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
    CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')


def hint_wiki():
    st.markdown("Here we provide some query samples.")
    st.markdown("1. ```Which company did Elon Musk found?```")
    st.markdown("2. ```What is Iron Gwazi?```")
    st.markdown("3. ```苹果的发源地是哪里?```")
    st.markdown("4. ```What is a Ring in mathematics?```")
    st.markdown("5. ```The producer of Rick and Morty.```")
    st.markdown("6. ```How low is the temperature on Pluto?```")


def hint_sql_wiki():
    st.markdown('''```sql
CREATE TABLE wiki.Wikipedia
(
    `id` String,
    `title` String,
    `text` String,
    `url` String,
    `wiki_id` UInt64,
    `views` Float32,
    `paragraph_id` UInt64,
    `langs` UInt32,
    `emb` Array(Float32),
    VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
    CONSTRAINT emb_len CHECK length(emb) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')


MYSCALE_TABLES: Dict[str, TableConfig] = {
    'Wikipedia': TableConfig(
        database="wiki",
        table="Wikipedia",
        table_contents="Snapshot from Wikipedia for 2022. 
All in English.",
        hint=hint_wiki,
        hint_sql=hint_sql_wiki,
        # doc_prompt is used by the QA-with-sources chain
        doc_prompt=PromptTemplate(
            input_variables=["page_content", "url", "title", "ref_id", "views"],
            template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"
        ),
        metadata_col_attributes=[
            AttributeInfo(name="title", description="title of the wikipedia page", type="string"),
            AttributeInfo(name="text", description="paragraph from this wiki page", type="string"),
            AttributeInfo(name="views", description="number of views", type="float")
        ],
        must_have_col_names=['id', 'title', 'url', 'text', 'views'],
        vector_col_name="emb",
        text_col_name="text",
        metadata_col_name="metadata",
        emb_model=lambda: SentenceTransformerEmbeddings(
            model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
        ),
        tool_desc=("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages")
    ),
    'ArXiv Papers': TableConfig(
        database="default",
        table="ChatArXiv",
        table_contents="Snapshot from Wikipedia for 2022. 
All in English.", hint=hint_arxiv, hint_sql=hint_sql_arxiv, doc_prompt=PromptTemplate( input_variables=["page_content", "id", "title", "ref_id", "authors", "pubdate", "categories"], template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\t" "Date of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}" ), metadata_col_attributes=[ AttributeInfo(name="pubdate", description="The year the paper is published", type="timestamp"), AttributeInfo(name="authors", description="List of author names", type="list[string]"), AttributeInfo(name="title", description="Title of the paper", type="string"), AttributeInfo(name="categories", description="arxiv categories to this paper", type="list[string]"), AttributeInfo(name="length(categories)", description="length of arxiv categories to this paper", type="int") ], must_have_col_names=['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'], vector_col_name="vector", text_col_name="abstract", metadata_col_name="metadata", emb_model=lambda: HuggingFaceInstructEmbeddings( model_name='hkunlp/instructor-xl', embed_instruction="Represent the question for retrieving supporting scientific papers: " ), tool_desc=( "search_among_scientific_papers", "Searches among scientific papers from ArXiv and returns research papers" ) ) } ALL_TABLE_NAME: List[str] = [config.table for config in MYSCALE_TABLES.values()] ================================================ FILE: app/backend/constants/prompts.py ================================================ from langchain.prompts import ChatPromptTemplate, \ SystemMessagePromptTemplate, HumanMessagePromptTemplate DEFAULT_SYSTEM_PROMPT = ( "Do your best to answer the questions. " "Feel free to use any tools available to look up " "relevant information. Please keep all details in query " "when calling search functions." ) COMBINE_PROMPT_TEMPLATE = ( "You are a helpful document assistant. 
" "Your task is to provide information and answer any questions related to documents given below. " "You should use the sections, title and abstract of the selected documents as your source of information " "and try to provide concise and accurate answers to any questions asked by the user. " "If you are unable to find relevant information in the given sections, " "you will need to let the user know that the source does not contain relevant information but still try to " "provide an answer based on your general knowledge. You must refer to the corresponding section name and page " "that you refer to when answering. " "The following is the related information about the document that will help you answer users' questions, " "you MUST answer it using question's language:\n\n {summaries} " "Now you should answer user's question. Remember you must use `Doc #` to refer papers:\n\n" ) COMBINE_PROMPT = ChatPromptTemplate.from_strings( string_messages=[(SystemMessagePromptTemplate, COMBINE_PROMPT_TEMPLATE), (HumanMessagePromptTemplate, '{question}')]) MYSCALE_PROMPT = """ You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question. MyScale queries has a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance. When the query is asking for {top_k} closest row, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows. *NOTICE*: `DISTANCE(column, array)` only accept an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array. 
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MyScale. You should only order according to the distance function.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema.
Pay attention to the data type when using functions. Always use `AND` to connect conditions in `WHERE` and never use comma.
Make sure you never write an isolated `WHERE` keyword and never use undesired condition to constrain the query.

Use the following format:

======== table info ========

Question: "Question here"
SQLQuery: "SQL Query to run"


Here are some examples:

======== table info ========
CREATE TABLE "ChatPaper" (
  abstract String,
  id String,
  vector Array(Float32),
) ENGINE = ReplicatedReplacingMergeTree()
 ORDER BY id
 PRIMARY KEY id

Question: What is Feature Pyramid Network?
SQLQuery: SELECT ChatPaper.abstract, ChatPaper.id FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(Feature Pyramid Network)) LIMIT {top_k}


======== table info ========
CREATE TABLE "ChatPaper" (
  abstract String,
  id String,
  vector Array(Float32),
  categories Array(String),
  pubdate DateTime,
  title String,
  authors Array(String),
  primary_category String
) ENGINE = ReplicatedReplacingMergeTree()
 ORDER BY id
 PRIMARY KEY id

Question: What is PaperRank? What is the contribution of those works? Use paper with more than 2 categories.
SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper WHERE length(categories) > 2 ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k} ======== table info ======== CREATE TABLE "ChatArXiv" ( primary_category String, categories Array(String), pubdate DateTime, abstract String, title String, paper_id String, vector Array(Float32), authors Array(String), ) ENGINE = MergeTree() ORDER BY paper_id PRIMARY KEY paper_id Question: Did Geoffrey Hinton write about Capsule Neural Networks? Please use articles published later than 2021. SQLQuery: SELECT ChatArXiv.title, ChatArXiv.paper_id, ChatArXiv.authors FROM ChatArXiv WHERE has(authors, 'Geoffrey Hinton') AND pubdate > parseDateTimeBestEffort('2021-01-01') ORDER BY DISTANCE(vector, NeuralArray(Capsule Neural Networks)) LIMIT {top_k} ======== table info ======== CREATE TABLE "PaperDatabase" ( abstract String, categories Array(String), vector Array(Float32), pubdate DateTime, id String, comments String, title String, authors Array(String), primary_category String ) ENGINE = MergeTree() ORDER BY id PRIMARY KEY id Question: Find papers whose abstract has Mutual Information in it. 
SQLQuery: SELECT PaperDatabase.title, PaperDatabase.id FROM PaperDatabase WHERE abstract ILIKE '%Mutual Information%' ORDER BY DISTANCE(vector, NeuralArray(Mutual Information)) LIMIT {top_k} Let's begin: ======== table info ======== {table_info} Question: {input} SQLQuery: """ ================================================ FILE: app/backend/constants/streamlit_keys.py ================================================ DATA_INITIALIZE_NOT_STATED = "data_initialize_not_started" DATA_INITIALIZE_STARTED = "data_initialize_started" DATA_INITIALIZE_COMPLETED = "data_initialize_completed" CHAT_SESSION = "sel_sess" CHAT_KNOWLEDGE_TABLE = "private_kb" CHAT_SESSION_MANAGER = "session_manager" CHAT_CURRENT_USER_SESSIONS = "current_sessions" EL_SESSION_SELECTOR = "el_session_selector" # all personal knowledge bases under a specific user. USER_PERSONAL_KNOWLEDGE_BASES = "user_tools" # all personal files under a specific user. USER_PRIVATE_FILES = "user_files" # public and personal knowledge bases. AVAILABLE_RETRIEVAL_TOOLS = "tools_with_users" EL_PERSONAL_KB_NEEDS_REMOVE = "el_personal_kb_needs_remove" # files needs upload EL_UPLOAD_FILES = "el_upload_files" EL_UPLOAD_FILES_STATUS = "el_upload_files_status" # use these files to build private knowledge base EL_BUILD_KB_WITH_FILES = "el_build_kb_with_files" # build a personal kb, given name. EL_PERSONAL_KB_NAME = "el_personal_kb_name" # build a personal kb, given description. EL_PERSONAL_KB_DESCRIPTION = "el_personal_kb_description" # knowledge bases selected by user. 
EL_SELECTED_KBS = "el_selected_kbs" ================================================ FILE: app/backend/constants/variables.py ================================================ from backend.types.global_config import GlobalConfig # ***** str variables ***** # EMBEDDING_MODEL_PREFIX = "embedding_model" CHAINS_RETRIEVERS_MAPPING = "sel_map_obj" LANGCHAIN_RETRIEVER = "langchain_retriever" VECTOR_SQL_RETRIEVER = "vecsql_retriever" TABLE_EMBEDDINGS_MAPPING = "embeddings" RETRIEVER_TOOLS = "tools" DATA_INITIALIZE_STATUS = "data_initialized" UI_INITIALIZED = "ui_initialized" JUMP_QUERY_ASK = "jump_query_ask" USER_NAME = "user_name" USER_INFO = "user_info" DIVIDER_HTML = """
""" DIVIDER_THIN_HTML = """
""" class RetrieverButtons: vector_sql_query_from_db = "vector_sql_query_from_db" vector_sql_query_with_llm = "vector_sql_query_with_llm" self_query_from_db = "self_query_from_db" self_query_with_llm = "self_query_with_llm" GLOBAL_CONFIG = GlobalConfig() def update_global_config(new_config: GlobalConfig): global GLOBAL_CONFIG GLOBAL_CONFIG.openai_api_base = new_config.openai_api_base GLOBAL_CONFIG.openai_api_key = new_config.openai_api_key GLOBAL_CONFIG.auth0_client_id = new_config.auth0_client_id GLOBAL_CONFIG.auth0_domain = new_config.auth0_domain GLOBAL_CONFIG.myscale_user = new_config.myscale_user GLOBAL_CONFIG.myscale_password = new_config.myscale_password GLOBAL_CONFIG.myscale_host = new_config.myscale_host GLOBAL_CONFIG.myscale_port = new_config.myscale_port GLOBAL_CONFIG.query_model = new_config.query_model GLOBAL_CONFIG.chat_model = new_config.chat_model GLOBAL_CONFIG.untrusted_api = new_config.untrusted_api GLOBAL_CONFIG.myscale_enable_https = new_config.myscale_enable_https ================================================ FILE: app/backend/construct/__init__.py ================================================ ================================================ FILE: app/backend/construct/build_agents.py ================================================ import os from typing import Sequence, List import streamlit as st from langchain.agents import AgentExecutor from langchain.schema.language_model import BaseLanguageModel from langchain.tools import BaseTool from backend.chat_bot.message_converter import DefaultClickhouseMessageConverter from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT from backend.constants.streamlit_keys import AVAILABLE_RETRIEVAL_TOOLS from backend.constants.variables import GLOBAL_CONFIG, RETRIEVER_TOOLS from logger import logger try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base from langchain.chat_models import ChatOpenAI from langchain.prompts.chat 
import MessagesPlaceholder from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent from langchain.schema.messages import SystemMessage from langchain.memory import SQLChatMessageHistory def create_agent_executor( agent_name: str, session_id: str, llm: BaseLanguageModel, tools: Sequence[BaseTool], system_prompt: str, **kwargs ) -> AgentExecutor: agent_name = agent_name.replace(" ", "_") conn_str = f'clickhouse://{os.environ["MYSCALE_USER"]}:{os.environ["MYSCALE_PASSWORD"]}@{os.environ["MYSCALE_HOST"]}:{os.environ["MYSCALE_PORT"]}' chat_memory = SQLChatMessageHistory( session_id, connection_string=f'{conn_str}/chat?protocol=http' if GLOBAL_CONFIG.myscale_enable_https == False else f'{conn_str}/chat?protocol=https', custom_message_converter=DefaultClickhouseMessageConverter(agent_name)) memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory) prompt = OpenAIFunctionsAgent.create_prompt( system_message=SystemMessage(content=system_prompt), extra_prompt_messages=[MessagesPlaceholder(variable_name="history")], ) agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt) return AgentExecutor( agent=agent, tools=tools, memory=memory, verbose=True, return_intermediate_steps=True, **kwargs ) def build_agents( session_id: str, tool_names: List[str], model: str = "gpt-3.5-turbo-0125", temperature: float = 0.6, system_prompt: str = DEFAULT_SYSTEM_PROMPT ): chat_llm = ChatOpenAI( model_name=model, temperature=temperature, base_url=GLOBAL_CONFIG.openai_api_base, api_key=GLOBAL_CONFIG.openai_api_key, streaming=True ) tools = st.session_state.get(AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS)) selected_tools = [tools[k] for k in tool_names] logger.info(f"create agent, use tools: {selected_tools}") agent = create_agent_executor( agent_name="chat_memory", session_id=session_id, llm=chat_llm, tools=selected_tools, 
system_prompt=system_prompt ) return agent ================================================ FILE: app/backend/construct/build_all.py ================================================ from logger import logger from typing import Dict, Any, Union import streamlit as st from backend.constants.myscale_tables import MYSCALE_TABLES from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING from backend.construct.build_chains import build_retrieval_qa_with_sources_chain from backend.construct.build_retriever_tool import create_retriever_tool from backend.construct.build_retrievers import build_self_query_retriever, build_vector_sql_db_chain_retriever from backend.types.chains_and_retrievers import ChainsAndRetrievers, MetadataColumn from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, \ SentenceTransformerEmbeddings @st.cache_resource def load_embedding_model_for_table(table_name: str) -> \ Union[SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings]: with st.spinner(f"Loading embedding models for [{table_name}] ..."): embeddings = MYSCALE_TABLES[table_name].emb_model() return embeddings @st.cache_resource def load_embedding_models() -> Dict[str, Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings]]: embedding_models = {} for table in MYSCALE_TABLES: embedding_models[table] = load_embedding_model_for_table(table) return embedding_models @st.cache_resource def update_retriever_tools(): retrievers_tools = {} for table in MYSCALE_TABLES: logger.info(f"Updating retriever tools [, ] for table {table}") retrievers_tools.update( { f"{table} + Self Querying": create_retriever_tool( st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["retriever"], *MYSCALE_TABLES[table].tool_desc ), f"{table} + Vector SQL": create_retriever_tool( st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["sql_retriever"], *MYSCALE_TABLES[table].tool_desc ), }) return retrievers_tools @st.cache_resource def 
build_chains_retriever_for_table(table_name: str) -> ChainsAndRetrievers: metadata_col_attributes = MYSCALE_TABLES[table_name].metadata_col_attributes self_query_retriever = build_self_query_retriever(table_name) self_query_chain = build_retrieval_qa_with_sources_chain( table_name=table_name, retriever=self_query_retriever, chain_name="Self Query Retriever" ) vector_sql_retriever = build_vector_sql_db_chain_retriever(table_name) vector_sql_chain = build_retrieval_qa_with_sources_chain( table_name=table_name, retriever=vector_sql_retriever, chain_name="Vector SQL DB Retriever" ) metadata_columns = [ MetadataColumn( name=attribute.name, desc=attribute.description, type=attribute.type ) for attribute in metadata_col_attributes ] return ChainsAndRetrievers( metadata_columns=metadata_columns, # for self query retriever=self_query_retriever, chain=self_query_chain, # for vector sql sql_retriever=vector_sql_retriever, sql_chain=vector_sql_chain ) @st.cache_resource def build_chains_and_retrievers() -> Dict[str, Dict[str, Any]]: chains_and_retrievers = {} for table in MYSCALE_TABLES: logger.info(f"Building chains, retrievers for table {table}") chains_and_retrievers[table] = build_chains_retriever_for_table(table).to_dict() return chains_and_retrievers ================================================ FILE: app/backend/construct/build_chains.py ================================================ from langchain.chains import LLMChain from langchain.chat_models import ChatOpenAI from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate from langchain.schema import BaseRetriever import streamlit as st from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain from backend.chains.stuff_documents import CustomStuffDocumentChain from backend.constants.myscale_tables import MYSCALE_TABLES from backend.constants.prompts import COMBINE_PROMPT from backend.constants.variables import GLOBAL_CONFIG def 
build_retrieval_qa_with_sources_chain( table_name: str, retriever: BaseRetriever, chain_name: str = "" ) -> CustomRetrievalQAWithSourcesChain: with st.spinner(f'Building QA source chain named `{chain_name}` for MyScaleDB/{table_name} ...'): # Assign ref_id for documents custom_stuff_document_chain = CustomStuffDocumentChain( llm_chain=LLMChain( prompt=COMBINE_PROMPT, llm=ChatOpenAI( model_name=GLOBAL_CONFIG.chat_model, openai_api_key=GLOBAL_CONFIG.openai_api_key, temperature=0.6 ), ), document_prompt=MYSCALE_TABLES[table_name].doc_prompt, document_variable_name="summaries", ) chain = CustomRetrievalQAWithSourcesChain( retriever=retriever, combine_documents_chain=custom_stuff_document_chain, return_source_documents=True, max_tokens_limit=12000, ) return chain ================================================ FILE: app/backend/construct/build_chat_bot.py ================================================ from backend.chat_bot.private_knowledge_base import ChatBotKnowledgeTable from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION, CHAT_SESSION_MANAGER import streamlit as st from backend.constants.variables import GLOBAL_CONFIG, TABLE_EMBEDDINGS_MAPPING from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT from backend.chat_bot.session_manager import SessionManager def build_chat_knowledge_table(): if CHAT_KNOWLEDGE_TABLE not in st.session_state: st.session_state[CHAT_KNOWLEDGE_TABLE] = ChatBotKnowledgeTable( host=GLOBAL_CONFIG.myscale_host, port=GLOBAL_CONFIG.myscale_port, username=GLOBAL_CONFIG.myscale_user, password=GLOBAL_CONFIG.myscale_password, # embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["Wikipedia"], embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["ArXiv Papers"], parser_api_key=GLOBAL_CONFIG.untrusted_api, ) def initialize_session_manager(): if CHAT_SESSION not in st.session_state: st.session_state[CHAT_SESSION] = { "session_id": "default", "system_prompt": DEFAULT_SYSTEM_PROMPT, } if CHAT_SESSION_MANAGER not in 
st.session_state: st.session_state[CHAT_SESSION_MANAGER] = SessionManager( st.session_state, host=GLOBAL_CONFIG.myscale_host, port=GLOBAL_CONFIG.myscale_port, username=GLOBAL_CONFIG.myscale_user, password=GLOBAL_CONFIG.myscale_password, ) ================================================ FILE: app/backend/construct/build_retriever_tool.py ================================================ import json from typing import List from langchain.pydantic_v1 import BaseModel, Field from langchain.schema import BaseRetriever, Document from langchain.tools import Tool from backend.chat_bot.json_decoder import CustomJSONEncoder class RetrieverInput(BaseModel): query: str = Field(description="query to look up in retriever") def create_retriever_tool( retriever: BaseRetriever, tool_name: str, description: str ) -> Tool: """Create a tool to do retrieval of documents. Args: retriever: The retriever to use for the retrieval tool_name: The name for the tool. This will be passed to the language model, so should be unique and somewhat descriptive. description: The description for the tool. This will be passed to the language model, so should be descriptive. 
Returns: Tool class to pass to an agent """ def wrap(func): def wrapped_retrieve(*args, **kwargs): docs: List[Document] = func(*args, **kwargs) return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder) return wrapped_retrieve return Tool( name=tool_name, description=description, func=wrap(retriever.get_relevant_documents), coroutine=retriever.aget_relevant_documents, args_schema=RetrieverInput, ) ================================================ FILE: app/backend/construct/build_retrievers.py ================================================ import streamlit as st from langchain.chat_models import ChatOpenAI from langchain.prompts.prompt import PromptTemplate from langchain.retrievers.self_query.base import SelfQueryRetriever from langchain.retrievers.self_query.myscale import MyScaleTranslator from langchain.utilities.sql_database import SQLDatabase from langchain.vectorstores import MyScaleSettings from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain from sqlalchemy import create_engine, MetaData from backend.constants.myscale_tables import MYSCALE_TABLES from backend.constants.prompts import MYSCALE_PROMPT from backend.constants.variables import TABLE_EMBEDDINGS_MAPPING, GLOBAL_CONFIG from backend.retrievers.vector_sql_output_parser import VectorSQLRetrieveOutputParser from backend.vector_store.myscale_without_metadata import MyScaleWithoutMetadataJson from logger import logger @st.cache_resource def build_self_query_retriever(table_name: str) -> SelfQueryRetriever: with st.spinner(f"Building VectorStore for MyScaleDB/{table_name} ..."): myscale_connection = { "host": GLOBAL_CONFIG.myscale_host, "port": GLOBAL_CONFIG.myscale_port, "username": GLOBAL_CONFIG.myscale_user, "password": GLOBAL_CONFIG.myscale_password, } myscale_settings = MyScaleSettings( **myscale_connection, database=MYSCALE_TABLES[table_name].database, 
table=MYSCALE_TABLES[table_name].table, column_map={ "id": "id", "text": MYSCALE_TABLES[table_name].text_col_name, "vector": MYSCALE_TABLES[table_name].vector_col_name, # TODO refine MyScaleDB metadata in langchain. "metadata": MYSCALE_TABLES[table_name].metadata_col_name } ) myscale_vector_store = MyScaleWithoutMetadataJson( embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name], config=myscale_settings, must_have_cols=MYSCALE_TABLES[table_name].must_have_col_names ) with st.spinner(f"Building SelfQueryRetriever for MyScaleDB/{table_name} ..."): retriever: SelfQueryRetriever = SelfQueryRetriever.from_llm( llm=ChatOpenAI( model_name=GLOBAL_CONFIG.query_model, base_url=GLOBAL_CONFIG.openai_api_base, api_key=GLOBAL_CONFIG.openai_api_key, temperature=0 ), vectorstore=myscale_vector_store, document_contents=MYSCALE_TABLES[table_name].table_contents, metadata_field_info=MYSCALE_TABLES[table_name].metadata_col_attributes, use_original_query=False, structured_query_translator=MyScaleTranslator() ) return retriever @st.cache_resource def build_vector_sql_db_chain_retriever(table_name: str) -> VectorSQLDatabaseChainRetriever: """Get a group of relative docs from MyScaleDB""" with st.spinner(f'Building Vector SQL Database Retriever for MyScaleDB/{table_name}...'): if GLOBAL_CONFIG.myscale_enable_https == False: engine = create_engine( f'clickhouse://{GLOBAL_CONFIG.myscale_user}:{GLOBAL_CONFIG.myscale_password}@' f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}' f'/{MYSCALE_TABLES[table_name].database}?protocol=http' ) else: engine = create_engine( f'clickhouse://{GLOBAL_CONFIG.myscale_user}:{GLOBAL_CONFIG.myscale_password}@' f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}' f'/{MYSCALE_TABLES[table_name].database}?protocol=https' ) metadata = MetaData(bind=engine) logger.info(f"{table_name} metadata is : {metadata}") prompt = PromptTemplate( input_variables=["input", "table_info", "top_k"], template=MYSCALE_PROMPT, ) # Custom `out_put_parser` 
rewrites the search SQL, making it possible to query custom columns. output_parser = VectorSQLRetrieveOutputParser.from_embeddings( model=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name], # columns that the rewritten query must select. must_have_columns=MYSCALE_TABLES[table_name].must_have_col_names ) # `db_chain` generates the SQL statement vector_sql_db_chain: VectorSQLDatabaseChain = VectorSQLDatabaseChain.from_llm( llm=ChatOpenAI( model_name=GLOBAL_CONFIG.query_model, base_url=GLOBAL_CONFIG.openai_api_base, api_key=GLOBAL_CONFIG.openai_api_key, temperature=0 ), prompt=prompt, top_k=10, return_direct=True, db=SQLDatabase( engine, None, metadata, include_tables=[MYSCALE_TABLES[table_name].table], max_string_length=1024 ), sql_cmd_parser=output_parser, # TODO: update `langchain` to fix the return type. native_format=True ) # `retriever` can search a group of documents with `db_chain` vector_sql_db_chain_retriever = VectorSQLDatabaseChainRetriever( sql_db_chain=vector_sql_db_chain, page_content_key=MYSCALE_TABLES[table_name].text_col_name ) return vector_sql_db_chain_retriever ================================================ FILE: app/backend/retrievers/__init__.py ================================================ ================================================ FILE: app/backend/retrievers/self_query.py ================================================ from typing import List import pandas as pd import streamlit as st from langchain.retrievers import SelfQueryRetriever from langchain_core.documents import Document from langchain_core.runnables import RunnableConfig from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain from backend.constants.myscale_tables import MYSCALE_TABLES from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons from backend.callbacks.self_query_callbacks import ChatDataSelfAskCallBackHandler, CustomSelfQueryRetrieverCallBackHandler from ui.utils import display from logger import logger def 
process_self_query(selected_table, query_type): place_holder = st.empty() logger.info( f"button-1: {RetrieverButtons.self_query_from_db}, " f"button-2: {RetrieverButtons.self_query_with_llm}, " f"content: {st.session_state.query_self}" ) with place_holder.expander('🪵 Chat Log', expanded=True): try: if query_type == RetrieverButtons.self_query_from_db: callback = CustomSelfQueryRetrieverCallBackHandler() retriever: SelfQueryRetriever = \ st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["retriever"] config: RunnableConfig = {"callbacks": [callback]} relevant_docs = retriever.invoke( input=st.session_state.query_self, config=config ) callback.progress_bar.progress( value=1.0, text="[Question -> LLM -> Query filter -> MyScaleDB -> Results] Done!✅") st.markdown(f"### Self Query Results from `{selected_table}` \n" f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n") display( dataframe=pd.DataFrame( [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs] ), columns_=MYSCALE_TABLES[selected_table].must_have_col_names ) elif query_type == RetrieverButtons.self_query_with_llm: # callback = CustomSelfQueryRetrieverCallBackHandler() callback = ChatDataSelfAskCallBackHandler() chain: CustomRetrievalQAWithSourcesChain = \ st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["chain"] chain_results = chain(st.session_state.query_self, callbacks=[callback]) callback.progress_bar.progress( value=1.0, text="[Question -> LLM -> Query filter -> MyScaleDB -> Related Results -> LLM -> LLM Answer] Done!✅" ) documents_reference: List[Document] = chain_results["source_documents"] st.markdown(f"### SelfQueryRetriever Results from `{selected_table}` \n" f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n") display( pd.DataFrame( [{**d.metadata, 'abstract': d.page_content} for d in documents_reference] ) ) st.markdown( f"### Answer from LLM \n" f"> The response of the LLM when given the `SelfQueryRetriever` results. 
\n\n" ) st.write(chain_results['answer']) st.markdown( f"### References from `{selected_table}`\n" f"> The documents used by the LLM \n\n" ) if len(chain_results['sources']) == 0: st.write("No documents were used by the LLM.") else: display( dataframe=pd.DataFrame( [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']] ), columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names, index='ref_id' ) st.markdown(DIVIDER_HTML, unsafe_allow_html=True) except Exception as e: st.write('Oops 😵 Something bad happened...') raise e ================================================ FILE: app/backend/retrievers/vector_sql_output_parser.py ================================================ from typing import Dict, Any, List from langchain_experimental.sql.vector_sql import VectorSQLOutputParser class VectorSQLRetrieveOutputParser(VectorSQLOutputParser): """Based on VectorSQLOutputParser. It also modifies the SQL to select all required columns. """ must_have_columns: List[str] @property def _type(self) -> str: return "vector_sql_retrieve_custom" def parse(self, text: str) -> Dict[str, Any]: text = text.strip() start = text.upper().find("SELECT") if start >= 0: end = text.upper().find("FROM") text = text.replace( text[start + len("SELECT") + 1: end - 1], ", ".join(self.must_have_columns)) return super().parse(text) ================================================ FILE: app/backend/retrievers/vector_sql_query.py ================================================ from typing import List import pandas as pd import streamlit as st from langchain.schema import Document from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain from backend.constants.myscale_tables import MYSCALE_TABLES from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons from backend.callbacks.vector_sql_callbacks import 
VectorSQLSearchDBCallBackHandler, VectorSQLSearchLLMCallBackHandler from ui.utils import display from logger import logger def process_sql_query(selected_table: str, query_type: str): place_holder = st.empty() logger.info( f"button-1: {st.session_state[RetrieverButtons.vector_sql_query_from_db]}, " f"button-2: {st.session_state[RetrieverButtons.vector_sql_query_with_llm]}, " f"table: {selected_table}, " f"content: {st.session_state.query_sql}" ) with place_holder.expander('🪵 Query Log', expanded=True): try: if query_type == RetrieverButtons.vector_sql_query_from_db: callback = VectorSQLSearchDBCallBackHandler() vector_sql_retriever: VectorSQLDatabaseChainRetriever = \ st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_retriever"] relevant_docs: List[Document] = vector_sql_retriever.get_relevant_documents( query=st.session_state.query_sql, callbacks=[callback] ) callback.progress_bar.progress( value=1.0, text="[Question -> LLM -> SQL Statement -> MyScaleDB -> Results] Done! ✅" ) st.markdown(f"### Vector Search Results from `{selected_table}` \n" f"> Here we get documents from MyScaleDB with given sql statement \n\n") display( pd.DataFrame( [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs] ) ) elif query_type == RetrieverButtons.vector_sql_query_with_llm: callback = VectorSQLSearchLLMCallBackHandler(table=selected_table) vector_sql_chain: CustomRetrievalQAWithSourcesChain = \ st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_chain"] chain_results = vector_sql_chain( inputs=st.session_state.query_sql, callbacks=[callback] ) callback.progress_bar.progress( value=1.0, text="[Question -> LLM -> SQL Statement -> MyScaleDB -> " "(Question,Results) -> LLM -> Results] Done! 
✅" ) documents_reference: List[Document] = chain_results["source_documents"] st.markdown(f"### Vector Search Results from `{selected_table}` \n" f"> Here we get documents from MyScaleDB with the given SQL statement \n\n") display( pd.DataFrame( [{**d.metadata, 'abstract': d.page_content} for d in documents_reference] ) ) st.markdown( f"### Answer from LLM \n" f"> The response of the LLM when given the vector search results. \n\n" ) st.write(chain_results['answer']) st.markdown( f"### References from `{selected_table}`\n" f"> The documents used by the LLM \n\n" ) if len(chain_results['sources']) == 0: st.write("No documents were used by the LLM.") else: display( dataframe=pd.DataFrame( [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']] ), columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names, index='ref_id' ) else: raise NotImplementedError(f"Unsupported query type: {query_type}") st.markdown(DIVIDER_HTML, unsafe_allow_html=True) except Exception as e: st.write('Oops 😵 Something bad happened...') raise e ================================================ FILE: app/backend/types/__init__.py ================================================ ================================================ FILE: app/backend/types/chains_and_retrievers.py ================================================ from typing import Dict from dataclasses import dataclass from typing import List, Any from langchain.retrievers import SelfQueryRetriever from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain @dataclass class MetadataColumn: name: str desc: str type: str @dataclass class ChainsAndRetrievers: metadata_columns: List[MetadataColumn] retriever: SelfQueryRetriever chain: CustomRetrievalQAWithSourcesChain sql_retriever: VectorSQLDatabaseChainRetriever sql_chain: CustomRetrievalQAWithSourcesChain def to_dict(self) -> 
Dict[str, Any]: return { "metadata_columns": self.metadata_columns, "retriever": self.retriever, "chain": self.chain, "sql_retriever": self.sql_retriever, "sql_chain": self.sql_chain } ================================================ FILE: app/backend/types/global_config.py ================================================ from dataclasses import dataclass from typing import Optional @dataclass class GlobalConfig: openai_api_base: Optional[str] = "" openai_api_key: Optional[str] = "" auth0_client_id: Optional[str] = "" auth0_domain: Optional[str] = "" myscale_user: Optional[str] = "" myscale_password: Optional[str] = "" myscale_host: Optional[str] = "" myscale_port: Optional[int] = 443 query_model: Optional[str] = "" chat_model: Optional[str] = "" untrusted_api: Optional[str] = "" myscale_enable_https: Optional[bool] = True ================================================ FILE: app/backend/types/table_config.py ================================================ from typing import Callable from langchain.chains.query_constructor.schema import AttributeInfo from langchain.prompts import PromptTemplate from dataclasses import dataclass from typing import List @dataclass class TableConfig: database: str table: str table_contents: str # column names must_have_col_names: List[str] vector_col_name: str text_col_name: str metadata_col_name: str # hint for UI hint: Callable hint_sql: Callable # for langchain doc_prompt: PromptTemplate metadata_col_attributes: List[AttributeInfo] emb_model: Callable tool_desc: tuple ================================================ FILE: app/backend/vector_store/__init__.py ================================================ ================================================ FILE: app/backend/vector_store/myscale_without_metadata.py ================================================ from typing import Any, Optional, List from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores.myscale 
import MyScale, MyScaleSettings from logger import logger class MyScaleWithoutMetadataJson(MyScale): def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: Optional[List[str]] = None, **kwargs: Any) -> None: try: super().__init__(embedding, config, **kwargs) except Exception as e: logger.error(e) # Avoid a shared mutable default: fall back to a fresh empty list. self.must_have_cols: List[str] = must_have_cols or [] def _build_qstr( self, q_emb: List[float], topk: int, where_str: Optional[str] = None ) -> str: q_emb_str = ",".join(map(str, q_emb)) if where_str: where_str = f"PREWHERE {where_str}" else: where_str = "" q_str = f""" SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)} FROM {self.config.database}.{self.config.table} {where_str} ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}]) AS dist {self.dist_order} LIMIT {topk} """ return q_str def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]: q_str = self._build_qstr(embedding, k, where_str) try: return [ Document( page_content=r[self.config.column_map["text"]], metadata={k: r[k] for k in self.must_have_cols}, ) for r in self.client.query(q_str).named_results() ] except Exception as e: logger.error( f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") return [] ================================================ FILE: app/logger.py ================================================ import logging def setup_logger(): logger_ = logging.getLogger('chat-data') logger_.setLevel(logging.INFO) if not logger_.handlers: console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) formatter = logging.Formatter( '%(asctime)s - %(filename)s - %(funcName)s - %(levelname)s - %(message)s - [Thread ID: %(thread)d]' ) console_handler.setFormatter(formatter) logger_.addHandler(console_handler) return logger_ logger = setup_logger() ================================================ FILE: app/requirements.txt 
================================================ langchain==0.2.1 langchain-community==0.2.1 langchain-core==0.2.1 langchain-experimental==0.0.59 langchain-openai==0.1.7 sentence-transformers==2.2.2 InstructorEmbedding pandas streamlit==1.36.0 streamlit-extras streamlit-auth0-component altair==4.2.2 clickhouse-connect openai==1.35.3 lark tiktoken sql-formatter sqlalchemy==1.4.48 clickhouse-sqlalchemy ================================================ FILE: app/ui/__init__.py ================================================ ================================================ FILE: app/ui/chat_page.py ================================================ import datetime import json import pandas as pd import streamlit as st from langchain_core.messages import HumanMessage, FunctionMessage from streamlit.delta_generator import DeltaGenerator from backend.chat_bot.json_decoder import CustomJSONDecoder from backend.constants.streamlit_keys import CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, \ EL_UPLOAD_FILES_STATUS, USER_PRIVATE_FILES, EL_BUILD_KB_WITH_FILES, \ EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \ USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \ CHAT_KNOWLEDGE_TABLE, EL_UPLOAD_FILES, EL_SELECTED_KBS from backend.constants.variables import DIVIDER_HTML, USER_NAME, RETRIEVER_TOOLS from backend.construct.build_chat_bot import build_chat_knowledge_table, initialize_session_manager from backend.chat_bot.chat import refresh_sessions, on_session_change_submit, refresh_agent, \ create_private_knowledge_base_as_tool, \ remove_private_knowledge_bases, add_file, clear_files, clear_history, back_to_main, on_chat_submit def render_session_manager(): with st.expander("🤖 Session Management"): if CHAT_CURRENT_USER_SESSIONS not in st.session_state: refresh_sessions() st.markdown("Here you can update `session_id` and `system_prompt`") st.markdown("- Click empty row to add a new item") st.markdown("- If needs to delete an item, just click it 
and press `DEL` key") st.markdown("- Don't forget to submit your change.") st.data_editor( data=st.session_state[CHAT_CURRENT_USER_SESSIONS], num_rows="dynamic", key="session_editor", use_container_width=True, ) st.button("⏫ Submit", on_click=on_session_change_submit, type="primary") def render_session_selection(): with st.expander("✅ Session Selection", expanded=True): st.selectbox( "Choose a `session` to chat", options=st.session_state[CHAT_CURRENT_USER_SESSIONS], index=None, key=EL_SESSION_SELECTOR, format_func=lambda x: x["session_id"], on_change=refresh_agent, ) def render_files_manager(): with st.expander("📃 **Upload your personal files**", expanded=False): st.markdown("- Files will be parsed by [Unstructured API](https://unstructured.io/api-key).") st.markdown("- All files will be converted into vectors and stored in [MyScaleDB](https://myscale.com/).") st.file_uploader(label="⏫ **Upload files**", key=EL_UPLOAD_FILES, accept_multiple_files=True) # st.markdown("### Uploaded Files") st.dataframe( data=st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(st.session_state[USER_NAME]), use_container_width=True, ) st.session_state[EL_UPLOAD_FILES_STATUS] = st.empty() col_1, col_2 = st.columns(2) with col_1: st.button(label="Upload files", on_click=add_file) with col_2: st.button(label="Clear all files and tools", on_click=clear_files) def _render_create_personal_knowledge_bases(div: DeltaGenerator): with div: st.markdown("- If you haven't upload your personal files, please upload them first.") st.markdown("- Select some **files** to build your `personal knowledge base`.") st.markdown("- Once the your `personal knowledge base` is built, " "it will answer your questions using information from your personal **files**.") st.multiselect( label="⚡️Select some files to build a **personal knowledge base**", options=st.session_state[USER_PRIVATE_FILES], placeholder="You should upload some files first", key=EL_BUILD_KB_WITH_FILES, format_func=lambda x: x["file_name"], ) 
st.text_input( label="⚡️Personal knowledge base name", value="get_relevant_documents", key=EL_PERSONAL_KB_NAME ) st.text_input( label="⚡️Personal knowledge base description", value="Searches from some personal files.", key=EL_PERSONAL_KB_DESCRIPTION, ) st.button( label="Build 🔧", on_click=create_private_knowledge_base_as_tool ) def _render_remove_personal_knowledge_bases(div: DeltaGenerator): with div: st.markdown("> Here is all your personal knowledge bases.") if USER_PERSONAL_KNOWLEDGE_BASES in st.session_state and len(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]) > 0: st.dataframe(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]) else: st.warning("You don't have any personal knowledge bases, please create a new one.") st.multiselect( label="Choose a personal knowledge base to delete", placeholder="Choose a personal knowledge base to delete", options=st.session_state[USER_PERSONAL_KNOWLEDGE_BASES], format_func=lambda x: x["tool_name"], key=EL_PERSONAL_KB_NEEDS_REMOVE, ) st.button("Delete", on_click=remove_private_knowledge_bases, type="primary") def render_personal_tools_build(): with st.expander("🔨 **Build your personal knowledge base**", expanded=True): create_new_kb, kb_manager = st.tabs(["Create personal knowledge base", "Personal knowledge base management"]) _render_create_personal_knowledge_bases(create_new_kb) _render_remove_personal_knowledge_bases(kb_manager) def render_knowledge_base_selector(): with st.expander("🙋 **Select some knowledge bases to query**", expanded=True): st.markdown("- Knowledge bases come in two types: `public` and `private`.") st.markdown("- All users can access our `public` knowledge bases.") st.markdown("- Only you can access your `personal` knowledge bases.") options = st.session_state[RETRIEVER_TOOLS].keys() if AVAILABLE_RETRIEVAL_TOOLS in st.session_state: options = st.session_state[AVAILABLE_RETRIEVAL_TOOLS] st.multiselect( label="Select some knowledge base tool", placeholder="Please select some knowledge bases to query", 
options=options, default=["Wikipedia + Self Querying"], key=EL_SELECTED_KBS, on_change=refresh_agent, ) def chat_page(): # initialize resources build_chat_knowledge_table() initialize_session_manager() # render sidebar with st.sidebar: left, middle, right = st.columns([1, 1, 2]) with left: st.button(label="↩️ Log Out", help="log out and back to main page", on_click=back_to_main) with right: st.markdown(f"👤 `{st.session_state[USER_NAME]}`") st.markdown(DIVIDER_HTML, unsafe_allow_html=True) render_session_manager() render_session_selection() render_files_manager() render_personal_tools_build() render_knowledge_base_selector() # render chat history if "agent" not in st.session_state: refresh_agent() for msg in st.session_state.agent.memory.chat_memory.messages: speaker = "user" if isinstance(msg, HumanMessage) else "assistant" if isinstance(msg, FunctionMessage): with st.chat_message(name="from knowledge base", avatar="📚"): st.write( f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*" ) st.write("Retrieved from knowledge base:") try: st.dataframe( pd.DataFrame.from_records( json.loads(msg.content, cls=CustomJSONDecoder) ), use_container_width=True, ) except Exception as e: st.warning(e) st.write(msg.content) else: if len(msg.content) > 0: with st.chat_message(speaker): # print(type(msg), msg.dict()) st.write( f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*" ) st.write(f"{msg.content}") st.session_state["next_round"] = st.empty() from streamlit import _bottom with _bottom: col1, col2 = st.columns([1, 16]) with col1: st.button("🗑️", help="Clean chat history", on_click=clear_history, type="secondary") with col2: st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input") ================================================ FILE: app/ui/home.py ================================================ import base64 from streamlit_extras.add_vertical_space import add_vertical_space from 
streamlit_extras.card import card from streamlit_extras.colored_header import colored_header from streamlit_extras.mention import mention from streamlit_extras.tags import tagger_component from logger import logger import os import streamlit as st from auth0_component import login_button from backend.constants.variables import JUMP_QUERY_ASK, USER_INFO, USER_NAME, DIVIDER_HTML, DIVIDER_THIN_HTML from streamlit_extras.let_it_rain import rain def render_home(): render_home_header() # st.divider() # st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True) add_vertical_space(5) render_home_content() # st.divider() st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True) render_home_footer() def render_home_header(): logger.info("render home header") st.header("ChatData - Your Intelligent Assistant") st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True) st.markdown("> [ChatData](https://github.com/myscale/ChatData) \ is developed by [MyScale](https://myscale.com/), \ it's an integration of [LangChain](https://www.langchain.com/) \ and [MyScaleDB](https://github.com/myscale/myscaledb)") tagger_component( "Keywords:", ["MyScaleDB", "LangChain", "VectorSearch", "ChatBot", "GPT", "arxiv", "wikipedia", "Personal Knowledge Base 📚"], color_name=["darkslateblue", "green", "orange", "darkslategrey", "red", "crimson", "darkcyan", "darkgrey"], ) text, col1, col2, col3, _ = st.columns([1, 1, 1, 1, 4]) with text: st.markdown("Related:") with col1.container(): mention( label="streamlit", icon="streamlit", url="https://streamlit.io/", write=True ) with col2.container(): mention( label="langchain", icon="🦜🔗", url="https://www.langchain.com/", write=True ) with col3.container(): mention( label="streamlit-extras", icon="🪢", url="https://github.com/arnaudmiribel/streamlit-extras", write=True ) def _render_self_query_chain_content(): col1, col2 = st.columns([1, 1], gap='large') with col1.container(): st.image(image='./assets/home_page_background_1.png', caption=None, width=None, 
use_column_width=True, clamp=False, channels="RGB", output_format="PNG") with col2.container(): st.header("VectorSearch & SelfQuery with Sources") st.info("In this sample, you will learn how **LangChain** integrates with **MyScaleDB**.") st.markdown("""This example demonstrates two methods for integrating MyScale into LangChain: [Vector SQL](https://api.python.langchain.com/en/latest/sql/langchain_experimental.sql.vector_sql.VectorSQLDatabaseChain.html) and [Self-querying retriever](https://python.langchain.com/v0.2/docs/integrations/retrievers/self_query/myscale_self_query/). For each method, you can choose one of the following options: 1. `Retrieve from MyScaleDB ➡️` - The LLM (GPT) converts user queries into SQL statements with vector search, executes these searches in MyScaleDB, and retrieves relevant content. 2. `Retrieve and answer with LLM ➡️` - After retrieving relevant content from MyScaleDB, the user query along with the retrieved content is sent to the LLM (GPT), which then provides a comprehensive answer.""") add_vertical_space(3) _, middle, _ = st.columns([2, 1, 2], gap='small') with middle.container(): st.session_state[JUMP_QUERY_ASK] = st.button("Try sample", use_container_width=False, type="secondary") def _render_chat_bot_content(): col1, col2 = st.columns(2, gap='large') with col1.container(): st.image(image='./assets/home_page_background_2.png', caption=None, width=None, use_column_width=True, clamp=False, channels="RGB", output_format="PNG") with col2.container(): st.header("Chat Bot") st.info("Now you can try our chatbot, this chatbot is built with MyScale and LangChain.") st.markdown("- You need to go to [https://myscale-chatdata.hf.space/](https://myscale-chatdata.hf.space/) " "to log in successfully, otherwise the auth service will not work.") st.markdown("- You can upload your own PDF files and build your own knowledge base. \ (This is just a sample application. 
Please do not upload important or confidential files.)") st.markdown("- A default session will be assigned as your initial chat session. \ You can create and switch to other sessions to jump between different chat conversations.") add_vertical_space(1) _, middle, _ = st.columns([1, 2, 1], gap='small') with middle.container(): if USER_NAME not in st.session_state: login_button(clientId=os.environ["AUTH0_CLIENT_ID"], domain=os.environ["AUTH0_DOMAIN"], key="auth0") # if user_info: # user_name = user_info.get("nickname", "default") + "_" + user_info.get("email", "null") # st.session_state[USER_NAME] = user_name # print(user_info) def render_home_content(): logger.info("render home content") _render_self_query_chain_content() add_vertical_space(3) _render_chat_bot_content() def render_home_footer(): logger.info("render home footer") st.write( "Please follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!" ) st.write( "For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!") st.write("Our [privacy policy](https://myscale.com/privacy/), [terms of service](https://myscale.com/terms/)") # st.write( # "Recommended to use the standalone version of Chat-Data, " # "available [here](https://myscale-chatdata.hf.space/)." 
# ) if st.session_state.auth0 is not None: st.session_state[USER_INFO] = dict(st.session_state.auth0) if 'email' in st.session_state[USER_INFO]: email = st.session_state[USER_INFO]["email"] else: email = f"{st.session_state[USER_INFO]['nickname']}@{st.session_state[USER_INFO]['sub']}" st.session_state["user_name"] = email del st.session_state.auth0 st.rerun() if st.session_state.jump_query_ask: st.rerun() ================================================ FILE: app/ui/retrievers.py ================================================ import streamlit as st from streamlit_extras.add_vertical_space import add_vertical_space from backend.constants.myscale_tables import MYSCALE_TABLES from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, RetrieverButtons from backend.retrievers.self_query import process_self_query from backend.retrievers.vector_sql_query import process_sql_query from backend.constants.variables import JUMP_QUERY_ASK, USER_NAME, USER_INFO def back_to_main(): if USER_INFO in st.session_state: del st.session_state[USER_INFO] if USER_NAME in st.session_state: del st.session_state[USER_NAME] if JUMP_QUERY_ASK in st.session_state: del st.session_state[JUMP_QUERY_ASK] def _render_table_selector() -> str: col1, col2 = st.columns(2) with col1: selected_table = st.selectbox( label='Each public knowledge base is stored in a MyScaleDB table, which is read-only.', options=MYSCALE_TABLES.keys(), ) MYSCALE_TABLES[selected_table].hint() with col2: add_vertical_space(1) st.info(f"Here is your selected public knowledge base schema in MyScaleDB", icon='📚') MYSCALE_TABLES[selected_table].hint_sql() return selected_table def render_retrievers(): st.button("⬅️ Back", key="back_sql", on_click=back_to_main) st.subheader('Please choose a public knowledge base to search.') selected_table = _render_table_selector() tab_sql, tab_self_query = st.tabs( tabs=['Vector SQL', 'Self-querying Retriever'] ) with tab_sql: render_tab_sql(selected_table) with tab_self_query: 
render_tab_self_query(selected_table) def render_tab_sql(selected_table: str): st.warning( "When you input a query with filtering conditions, you need to ensure that your filters are applied only to " "the metadata we provide. This table allows filters to be established on the following metadata fields:", icon="⚠️") st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"]) cols = st.columns([8, 3, 3, 2]) cols[0].text_input("Input your question:", key='query_sql') with cols[1].container(): add_vertical_space(2) st.button("Retrieve from MyScaleDB ➡️", key=RetrieverButtons.vector_sql_query_from_db) with cols[2].container(): add_vertical_space(2) st.button("Retrieve and answer with LLM ➡️", key=RetrieverButtons.vector_sql_query_with_llm) if st.session_state[RetrieverButtons.vector_sql_query_from_db]: process_sql_query(selected_table, RetrieverButtons.vector_sql_query_from_db) if st.session_state[RetrieverButtons.vector_sql_query_with_llm]: process_sql_query(selected_table, RetrieverButtons.vector_sql_query_with_llm) def render_tab_self_query(selected_table): st.warning( "When you input a query with filtering conditions, you need to ensure that your filters are applied only to " "the metadata we provide. 
This table allows filters to be established on the following metadata fields:",
        icon="⚠️")
    st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])
    cols = st.columns([8, 3, 3, 2])
    cols[0].text_input("Input your question:", key='query_self')
    with cols[1].container():
        add_vertical_space(2)
        st.button("Retrieve from MyScaleDB ➡️", key='search_self')
    with cols[2].container():
        add_vertical_space(2)
        st.button("Retrieve and answer with LLM ➡️", key='ask_self')

    if st.session_state.search_self:
        process_self_query(selected_table, RetrieverButtons.self_query_from_db)
    if st.session_state.ask_self:
        process_self_query(selected_table, RetrieverButtons.self_query_with_llm)

================================================
FILE: app/ui/utils.py
================================================
import streamlit as st


def display(dataframe, columns_=None, index=None):
    if len(dataframe) > 0:
        if index:
            # set_index returns a new DataFrame, so keep the result
            dataframe = dataframe.set_index(index)
        if columns_:
            st.dataframe(dataframe[columns_])
        else:
            st.dataframe(dataframe)
    else:
        st.write(
            "Sorry 😵 we didn't find any articles related to your query.\n\n"
            "Maybe the LLM is too naughty that does not follow our instruction... \n\n"
            "Please try again and use verbs that may match the datatype.",
            unsafe_allow_html=True
        )

================================================
FILE: docs/self-query.md
================================================
# HOW-TO: Build a ChatPDF App over millions of documents with LangChain and MyScale in 30 Minutes

Chatting with GPT about a single academic paper is relatively straightforward by providing the document as the language model context. Chatting with millions of research papers is also simple... as long as you choose the right vector database.
Large language models (LLMs) are powerful NLP tools. One of the most significant benefits of LLMs, such as ChatGPT, is that you can use them to build tools that allow you to interact (or chat) with documents, such as PDF copies of research or academic papers, based on their topics and content.

Many implementations of chat-with-document apps already exist, like [ChatPaper](https://github.com/kaixindelele/ChatPaper), [OpenChatPaper](https://github.com/liuyixin-louis/OpenChatPaper), and [DocsMind](https://github.com/3Alan/DocsMind). However, many of these implementations seem complicated, and their search utilities are simplistic, offering only elementary keyword searches that filter on basic metadata such as year and subject.

Therefore, developing a ChatPDF-like app to interact with millions of academic/research papers makes sense. You can chat with the data in natural language, combining both semantic and structural attributes, asking questions such as "What is a neural network?" and providing additional qualifiers like "Please use articles published by Geoffrey Hinton after 2018."

The primary purpose of this article is to help you build your own ChatPDF app that allows you to interact (chat) with millions of academic/research papers using LangChain and MyScale.

> This app should take about 30 minutes to create.

But before we start, let’s look at the following diagrammatic workflow of the whole process:
> Even though we describe how to develop this LLM-based chat app, we have a sample app on [GitHub](https://github.com/myscale/ChatData), including access to a [read-only vector database](../app/.streamlit/secrets.example.toml), further simplifying the app-creation process.

## Prepare the Data

As described in this image, the first step is to prepare the data.

> We recommend using our open database for this app. The credentials are in the example configuration: `$PROJECT_DIR/app/.streamlit/secrets.toml`. Alternatively, you can follow the instructions below to create your own database, which takes about 20 minutes.

We sourced our data, a usable list of abstracts and arXiv IDs, from the Alexandria Index through the [Macrocosm website](https://alex.macrocosm.so/). Using this data and querying the arXiv Open API, we can significantly enhance the query experience by retrieving a much richer set of metadata, including year, subject, release date, category, and author.

> We have prepared the data using the [arXiv Open API](https://info.arxiv.org/help/api/index.html) to simplify the app-creation process.

Now that we have the data, let's dive into the next steps:

## Create the Table

The first step is to create a table schema. Sign into [myscale.com](http://myscale.com/) and create a free [cluster](http://console.myscale.com/). After creating the cluster, the next step is to create a table that is compatible with our MyScale VectorStore.

> There will be a workspace sidebar on your left. This is the place where you can play with pure SQL.
Use the following script to create a MyScale VectorStore table:

```sql
CREATE TABLE default.langchain (
    `abstract` String,
    `id` String,
    `vector` Array(Float32),
    `metadata` Object('JSON'),
    CONSTRAINT vec_len CHECK length(vector) = 768
) ENGINE = ReplacingMergeTree ORDER BY id
```

You can also create a table by initializing a MyScale VectorStore in LangChain with Python, as the following code sample describes:

```python
from langchain.vectorstores import MyScale, MyScaleSettings

config = MyScaleSettings(host="", port=8443, ...)
doc_search = MyScale(embedding_function, config)
```

Both these methods do the same job. Great. Let’s move on to the next step.

## Insert the Data

> We have appended additional metadata, like the publication date and authors, to each arXiv entry.

Our data is hosted on [Amazon S3](https://docs.aws.amazon.com/AmazonS3/latest/userguide/Welcome.html), supported by [ClickHouse table functions](https://clickhouse.com/docs/en/sql-reference/table-functions/s3). The compressed `jsonl` file is available [here](https://myscale-demo.s3.ap-southeast-1.amazonaws.com/chat_arxiv/full.json.zst). You can also import the data as a partitioned dataset (113 parts) from our AWS S3 bucket (the URL will look like `https://myscale-demo.s3.ap-southeast-1.amazonaws.com/chat_arxiv/data.part*.0.jsonl.zst`) into your MyScale Cloud with the [S3 table function](https://clickhouse.com/docs/en/sql-reference/table-functions/s3).

> You can also upload the data onto Google Cloud Platform and use the same SQL insert query to import this data.

To insert data into this table, you still have the same options as when creating a VectorStore table: MyScale SQL workspace or LangChain.
### MyScale SQL Workspace

Paste the following SQL statement into the MyScale SQL workspace and click **Run**:

```sql
INSERT INTO langchain
SELECT * FROM s3(
    'https://myscale-demo.s3.ap-southeast-1.amazonaws.com/chat_arxiv/data.*.jsonl.zst',
    'JSONEachRow',
    'abstract String, id String, vector Array(Float32), metadata Object(''JSON'')',
    'zstd'
)
```

Then you need to build the vector index with this SQL:

```sql
ALTER TABLE langchain ADD VECTOR INDEX vec_idx vector TYPE MSTG('metric_type=Cosine')
```

### LangChain

The second option is to insert the data into the table using LangChain for better control over the data insertion process.

Add the following code snippet to your app’s code:

```python
# ! unzstd data-*.jsonl.zst
import json
from langchain.docstore.document import Document


def str2doc(_str):
    j = json.loads(_str)
    return Document(page_content=j['abstract'], metadata=j['metadata'])


with open('func_call_data.jsonl') as f:
    docs = [str2doc(l) for l in f.readlines()]
```

## Design the Query Pipeline

Most LLM-based applications need an automated pipeline for querying and returning an answer to the query.

> Chat-with-LLM apps must generally retrieve reference documents before querying their models (LLMs).

Let's look at the step-by-step workflow describing how the app answers the user's queries, as illustrated in the following diagram:
1. **Ask for the user's input/questions.**
   This input must be as concise as possible. In most cases, it should be at most several sentences.

2. **Construct a DB query from the user's input.**
   The query is simple for vector databases. All you need to do is to extract the relevant embedding from the vector database. However, for enhanced accuracy, it is advisable to filter your query.

   For instance, let's assume the user only wants the latest papers, but the returned embedding includes all the research papers. To solve this, you can add metadata filters to the query so that only the relevant documents are returned.

3. **Parse the retrieved documents from VectorStore.**
   The data returned from the vector store is not in a native format that the LLM understands. You must parse it and insert it into your prompt templates. Sometimes you need to add more metadata to these templates, like the date created, authors, or document categories. This metadata will help the LLM improve the quality of its answer.

4. **Ask the LLM.**
   This process is straightforward as long as you are familiar with the LLM's API and have properly designed prompts.

5. **Fetch the answer.**
   Returning the answer is straightforward for simple applications. However, if the question is complex, additional effort is required to provide more information to the user; for example, adding the LLM's source data. Additionally, adding reference numbers to the prompt can help you find the source and reduce your prompt's size by avoiding repeating content, such as the document's title.

In practice, LangChain has a good framework to work with. We used the following classes to build this pipeline:

* `RetrievalQAWithSourcesChain`
* `SelfQueryRetriever`

### `SelfQueryRetriever`

This class defines the interaction between the VectorStore and your app. Let’s dive deeper into how a self-query retriever works, as illustrated in the following diagram:
LangChain’s `SelfQueryRetriever` defines a universal filter for every VectorStore, including several `comparators` for comparing values and several `operators` for combining these conditions into a filter. The LLM will generate a filter rule based on these `comparators` and `operators`. All VectorStore providers will implement a `FilterTranslator` to translate the given universal filter to the correct arguments that call the VectorStore.

LangChain's universal solution provides a complete package for new operators, comparators, and vector store providers. However, you are limited to the pre-defined elements inside it.

In the prompt filter context, MyScale includes more powerful and flexible filters. We have added more data types, like lists and timestamps, and more functions, like string pattern matching and `CONTAIN` comparators for lists, offering more options for data storage and query design.

> We contributed to LangChain's Self-Query retrievers to make them more powerful, resulting in self-query retrievers that provide more freedom to the LLM when designing the query.

Look at [what else MyScale can do with metadata filters](https://myscale.com/blog/why-integrated-database-solution-can-boost-your-llm-apps/#filter-on-anything-without-constraints).

Here is the code for it, written using LangChain:

```python
import streamlit as st
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import OpenAI
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.vectorstores import MyScale

# Assuming your data is ready on MyScale Cloud
embeddings = HuggingFaceInstructEmbeddings()
doc_search = MyScale(embeddings)

# Define metadata fields and their types
# Descriptions are important. That's where the LLM learns how to use that metadata.
metadata_field_info=[
    AttributeInfo(
        name="pubdate",
        description="The year the paper is published",
        type="timestamp",
    ),
    AttributeInfo(
        name="authors",
        description="List of author names",
        type="list[string]",
    ),
    AttributeInfo(
        name="title",
        description="Title of the paper",
        type="string",
    ),
    AttributeInfo(
        name="categories",
        description="arxiv categories to this paper",
        type="list[string]"
    ),
    AttributeInfo(
        name="length(categories)",
        description="length of arxiv categories to this paper",
        type="int"
    ),
]

# Now build a retriever with LLM, a vector store and your metadata info
retriever = SelfQueryRetriever.from_llm(
    OpenAI(openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
    doc_search, "Scientific papers indexes with abstracts", metadata_field_info,
    use_original_query=True)
```

### `RetrievalQAWithSourcesChain`

This chain constructs the prompts containing the documents. Data should be formatted into LLM-readable strings, like JSON or Markdown, that contain the document’s information.
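The formatting step can be sketched in plain Python, without LangChain. This is only an illustration under stated assumptions: `Doc`, `DOC_TEMPLATE`, and `format_docs` are hypothetical names invented here, and the template fields are not taken from the sample app.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Doc:
    """Hypothetical stand-in for a retrieved document (not LangChain's Document)."""
    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


# Assumed per-document template; the field names are illustrative only.
DOC_TEMPLATE = (
    "Title: {title}\n"
    "Authors: {authors}\n"
    "Abstract: {page_content}\n"
    "SOURCE: {ref_id}"
)


def format_docs(docs: List[Doc]) -> str:
    """Render every document with the template and join them into one context string."""
    rendered = []
    for i, doc in enumerate(docs, start=1):
        rendered.append(
            DOC_TEMPLATE.format(
                title=doc.metadata.get("title", "unknown"),
                authors=", ".join(doc.metadata.get("authors", [])),
                page_content=doc.page_content,
                # A short reference number keeps the prompt small and traceable.
                ref_id=f"#{i}",
            )
        )
    return "\n\n".join(rendered)


docs = [
    Doc("We study neural networks.", {"title": "Paper A", "authors": ["A. Author"]}),
    Doc("We study vector search.", {"title": "Paper B", "authors": ["B. Author", "C. Author"]}),
]
context = format_docs(docs)
print(context)
```

The resulting `context` string is what would be placed into the final prompt alongside the user's question.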
As highlighted above, once the document data has been retrieved from the vector store, it must be formatted into LLM-readable strings, like JSON or Markdown.

LangChain uses the following chains to build these LLM-readable strings:

* `MapReduceDocumentsChain`
* `StuffDocumentsChain`

`MapReduceDocumentsChain` gathers all the documents the vector store returns and normalizes them into a standard format. It maps the documents to a prompt template and concatenates them together. `StuffDocumentsChain` works on those formatted documents, inserting them as context with task descriptions as prefixes and examples as suffixes.

Add the following code snippet to your app’s code so that your app will format the vector store data into LLM-readable documents.

```python
from langchain.chains import RetrievalQAWithSourcesChain

chain = RetrievalQAWithSourcesChain.from_llm(
    llm=OpenAI(openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0.),
    retriever=retriever,
    return_source_documents=True,)
```

## Run the Chain

With these components, we can now search and answer the user's questions with a scalable vector store.

Try it yourself!

```python
ret = st.session_state.chain(st.session_state.query, callbacks=[callback])
# You can find the answer from LLM in the field `answer`
st.markdown(f"### Answer from LLM\n{ret['answer']}\n### References")
# and source documents in `sources` and `source_documents`
docs = ret['source_documents']
```

Not responsive?

### Add Callbacks

The chain works just fine, but you might have a complaint: It needs to be faster!

Yes, the chain will be slow, as it constructs a filtered vector query (one LLM call), retrieves data from the VectorStore, and asks the LLM (another LLM call). Consequently, the total execution time will be about 10 to 20 seconds.

Don't worry; LangChain has your back. It includes [Callbacks](https://python.langchain.com/en/latest/modules/callbacks/getting_started.html?highlight=Callbacks) that you can use to increase your app's responsiveness.
In our example, we added several callback functions to update a progress bar:

```python
import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler


class ChatArXivAskCallBackHandler(StreamlitCallbackHandler):
    def __init__(self) -> None:
        # You will have a progress bar when this callback is initialized
        self.progress_bar = st.progress(value=0.0, text='Searching DB...')
        self.status_bar = st.empty()
        self.prog_value = 0.0
        # You can use chain names to control the progress
        self.prog_map = {
            'langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain': 0.2,
            'langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain': 0.4,
            'langchain.chains.combine_documents.stuff.StuffDocumentsChain': 0.8
        }

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        pass

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        # the chain id is a list of name parts, so join them into one string
        cid = '.'.join(serialized['id'])
        if cid != 'langchain.chains.llm.LLMChain':
            self.progress_bar.progress(value=self.prog_map[cid], text=f'Running Chain `{cid}`...')
            self.prog_value = self.prog_map[cid]
        else:
            self.prog_value += 0.1
            self.progress_bar.progress(value=self.prog_value, text=f'Running Chain `{cid}`...')

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
```

Now your app will have a pretty progress bar just like ours.
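The progress bookkeeping in such a handler can be tried out without Streamlit. The sketch below is an assumption-laden variant, not the app's actual code: `next_progress` is a hypothetical helper that ignores chain ids missing from the map instead of raising a `KeyError`, and clamps the value to the 0.0 to 1.0 range that Streamlit's progress bar accepts for floats.

```python
from typing import Dict, List

# Same chain-name-to-progress mapping as used in the handler.
PROG_MAP: Dict[str, float] = {
    "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
    "langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4,
    "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
}
LLM_CHAIN_ID = "langchain.chains.llm.LLMChain"


def next_progress(current: float, chain_id: List[str]) -> float:
    """Return the progress value to show when a chain with `chain_id` starts."""
    cid = ".".join(chain_id)
    if cid == LLM_CHAIN_ID:
        # Each inner LLM call nudges the bar forward a little.
        value = current + 0.1
    else:
        # Unknown chains keep the current value instead of raising KeyError.
        value = PROG_MAP.get(cid, current)
    # Keep the result inside the range a progress bar can display.
    return min(max(value, 0.0), 1.0)


progress = 0.0
for cid in [
    "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain",
    LLM_CHAIN_ID,
    "langchain.chains.combine_documents.stuff.StuffDocumentsChain",
]:
    progress = next_progress(progress, cid.split("."))
    print(f"{cid.rsplit('.', 1)[-1]}: {progress:.1f}")
```

Inside a handler, the returned value would simply be passed to the progress bar in `on_chain_start`.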
## In Conclusion

This is how an LLM app should be built with LangChain!

Today we provided a brief overview of how to build a simple LLM app that chats with the MyScale VectorStore, and also explained how to use chains in the query pipeline.

We hope this article helps you when you design your LLM-based app architecture from the ground up.

You can also ask for help on our [Discord server](https://discord.gg/D2qpkqc4Jq). We are happy to help, whether on vector databases, LLM apps, or other fantastic stuff. You are also welcome to use our open database to build your own apps! We believe you can build more awesome apps with this self-query retriever and MyScale! Happy Coding!

See you in the following article!

================================================
FILE: docs/vector-sql.md
================================================
# Teach Your LLM to Search Using Vector SQL and Answer With Facts From Database

A vector database that supports Structured Query Language can store more than vectors. Common data types like timestamps and arrays can be accessed and filtered within the database, which improves the accuracy and efficiency of vector search queries. Accurate results from the database can teach LLMs to speak with facts, which reduces hallucination and enhances the quality and credibility of the LLM's answers.

## What is Hallucination?

Large Language Models are advanced AI systems that can answer a wide range of questions. Although they provide informative responses on topics they know, they are not always accurate on unfamiliar topics. This phenomenon is known as **hallucination**.

Before we look at an example of an LLM hallucination, let's consider a definition of the term "hallucination" as described by [Wikipedia](https://en.wikipedia.org/wiki/Hallucination):

> "A hallucination is a perception in the absence of an external stimulus that has the qualities of a real perception."

Moreover:

> "Hallucinations are vivid, substantial, and are perceived to be located in external objective space."

In other words, a hallucination is an error in (or a false) perception of something real or concrete. For example, a Large Language Model was asked what LLM hallucinations are, with the answer being:

![LLM Hallucinations. source: [aruna-x](https://dev.to/aruna/how-to-minimize-llm-hallucinations-2el7)](../assets/hallucination.png)

This raises the question: how do we improve on (or fix) this result? The concise answer is to add facts to your question, such as providing the definition of an LLM before or after you ask the question.

For instance:

> An LLM is a Large Language Model, an artificial neural network that models how humans talk and write. Please tell me, what is LLM hallucination?

The answer to this question, provided by ChatGPT, is:

![ChatGPT LLM Hallucinations Response](../assets/chatgpt-hallucination-response.png)

**Note:** The reason for the first sentence, "Apologies for the confusion in my earlier response," is that we asked ChatGPT our first question, what LLM hallucinations are, before giving it our second prompt: "An LLM..."

These additions have improved the quality of the answer. At least it no longer thinks an LLM hallucination is a "Late-Life Migraine Accompaniment!" 😆

## External Knowledge Reduces Hallucinations

At this juncture, it is crucial to note that an LLM is neither infallible nor the ultimate authority on all knowledge. LLMs are trained on large amounts of data and learn patterns in language, but they may not always have access to the most up-to-date information or have a comprehensive understanding of complex topics.

What now? How do you reduce LLM hallucinations?

The solution to this problem is to include supporting documents with the query (or prompt) to guide the LLM toward a more accurate and informed response.
Like humans, the LLM needs to learn from these documents to answer your question accurately.

Helpful documents can come from many sources that provide an interface for searching relevant passages, including search engines like Google or Bing and digital libraries like arXiv. Using a database is also a good choice, as it provides a more flexible and private query interface.

Knowledge retrieved from these sources must be relevant to the question/prompt. There are several ways to retrieve relevant documents, including:

* **Keyword-based:** Searching for keywords in plain text, suitable for an exact match on terms.
* **Vector search-based:** Searching for records whose embeddings are closest to the query embedding, helpful in finding paraphrases or generally related documents.

Vector search is popular because it can match paraphrases and capture the meaning of whole paragraphs. However, it is not a one-size-fits-all solution; it should be paired with specific filters to maintain its performance, especially when searching massive volumes of records. For example, should you only want to retrieve knowledge about physics (as a subject), you must filter out all information about any other subjects. Thus, the LLM will not be confused by knowledge from other disciplines.

## Automate the Whole Process with SQL... and Vector Search

The LLM should also learn to query data from its data sources before answering the questions, automating the whole process. In fact, LLMs are already capable of writing SQL queries and following instructions.

![Vector Pipeline](../assets/pipeline.png)

SQL is powerful and can be used to construct complex search queries. It supports many different data types and functions. It also allows us to write a vector search with `ORDER BY` and `LIMIT`, treating the similarity score between embeddings as a column named `distance`. Pretty straightforward, isn't it?
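
To make the filter plus `ORDER BY` plus `LIMIT` idea concrete, here is a minimal pure-Python sketch of what a filtered vector search computes. Everything in it is made up for illustration: the records, the hand-picked 2-dimensional embeddings, and the `filtered_search` helper. A real database would use a vector index instead of scanning every row.

```python
import math

# Toy records with hand-picked 2-D "embeddings" (illustrative only).
RECORDS = [
    {"title": "Quantum entanglement", "subject": "physics", "vector": [0.9, 0.1]},
    {"title": "Black hole entropy", "subject": "physics", "vector": [0.7, 0.3]},
    {"title": "Protein folding", "subject": "biology", "vector": [0.88, 0.12]},
    {"title": "Superconductivity", "subject": "physics", "vector": [0.2, 0.8]},
]


def cosine_distance(a, b):
    """Cosine distance: 0.0 means same direction, larger means less similar."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def filtered_search(query_vector, subject, limit):
    """Roughly: SELECT title ... WHERE subject = ? ORDER BY distance LIMIT ?"""
    # WHERE: keep only rows that pass the structured filter
    candidates = [r for r in RECORDS if r["subject"] == subject]
    # ORDER BY: sort the remaining rows by distance to the query embedding
    candidates.sort(key=lambda r: cosine_distance(r["vector"], query_vector))
    # LIMIT: return the closest rows only
    return [r["title"] for r in candidates[:limit]]


print(filtered_search([1.0, 0.0], subject="physics", limit=2))
```

Note that the biology record is very close to the query vector but never reaches the result, because the subject filter removes it before ranking.
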
> See the next section, [What Vector SQL Looks Like](#what-vector-sql-looks-like), for more information on structuring a vector SQL query.

There are significant benefits to using vector SQL to build complex search queries, including:

* Increased flexibility through support for many data types and functions
* Improved efficiency because SQL is highly optimized and executed inside the database
* Human readability and ease of learning, as it is an extension of standard SQL
* LLM-friendliness

**Note:** Many SQL examples and tutorials are available on the Internet. LLMs are familiar with standard SQL as well as some of its dialects.

Apart from MyScale, many SQL database solutions like ClickHouse and PostgreSQL are adding vector search to their existing functionality, allowing users to use vector SQL and LLMs to answer questions on complex topics. Similarly, an increasing number of application developers are starting to integrate vector searches with SQL into their applications.

## What Vector SQL Looks Like

Vector Structured Query Language (Vector SQL) is designed to teach LLMs how to query vector SQL databases and contains the following extra functions:

* `DISTANCE(column, query_vector)`: This function computes the distance between each vector in the column and the query vector, either exactly or approximately.
* `NeuralArray(entity)`: This function converts an entity (for example, an image or a piece of text) into an embedding.

With these two functions, we can extend standard SQL for vector search. For example, if you want to search for the 10 records most relevant to the word `flower`, you can use the following SQL statement:

```sql
SELECT *
FROM table
ORDER BY DISTANCE(vector, NeuralArray(flower))
LIMIT 10
```

The `DISTANCE` call is evaluated as follows:

* The inner function, `NeuralArray(flower)`, converts the word `flower` into an embedding.
* This embedding is then serialized and injected into the `DISTANCE` function.
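
The two steps above can be sketched in a few lines of Python. This is only an illustration of the substitution idea, not the parser used by any particular library: `toy_embed` is a made-up stand-in for a real embedding model, and `expand_neural_array` is a hypothetical helper name.

```python
import re


def toy_embed(text):
    """Stand-in embedding: NOT a real model, just deterministic toy numbers."""
    return [float(len(text)), float(sum(map(ord, text)) % 7)]


# Matches NeuralArray(<entity>) and captures the entity text.
NEURAL_ARRAY = re.compile(r"NeuralArray\((.*?)\)")


def expand_neural_array(sql, embed=toy_embed):
    """Replace every NeuralArray(entity) with a serialized array literal."""
    def _replace(match):
        vector = embed(match.group(1).strip())          # step 1: entity -> embedding
        return "[" + ",".join(str(v) for v in vector) + "]"  # step 2: serialize
    return NEURAL_ARRAY.sub(_replace, sql)


sql = "SELECT * FROM table ORDER BY DISTANCE(vector, NeuralArray(flower)) LIMIT 10"
print(expand_neural_array(sql))
```

After this step the query contains only a plain array literal, so the remaining work is ordinary SQL that the database can execute.
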

Vector SQL is an extended version of SQL that needs further translation based on the vector database used. For instance, implementations use different names for the `DISTANCE` function: it is called `distance` in MyScale, and `L2Distance` or `cosineDistance` in ClickHouse. The function name is therefore translated according to the target database.

## How to Teach an LLM to Write Vector SQL

Now that we understand the basic principles of vector SQL and its unique functions, let's use an LLM to help us write a vector SQL query.

### 1. Teach an LLM What Standard Vector SQL is

First, we need to teach our LLM what standard vector SQL is. We aim to ensure that the LLM will do the following three things spontaneously when writing a vector SQL query:

* Extract the keywords from our question/prompt. They could be an object, a concept, or a topic.
* Decide which column to use to perform the similarity search. It should always choose a vector column for similarity.
* Translate the rest of our question's constraints into valid SQL.

### 2. Design the LLM Prompt

Having determined exactly what information the LLM requires to construct a vector SQL query, we can design the prompt as follows:

```python
# Here is an example of a vector SQL prompt
_prompt = f"""
You are a {dialect} expert. Given an input question, first, create a syntactically correct MyScale query to run, look at the query results, and return the answer to the input question.
The {dialect} query has a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by this relevance.
When the query asks for {top_k} closest row, you must use this distance function to calculate the distance to the entity's array on the vector column and order by the distance to retrieve the relevant rows.

*NOTICE*: `DISTANCE(column, array)` only accepts an array column as its first argument and a `NeuralArray(entity)` as its second argument.
You also need a user-defined function called `NeuralArray(entity)` to retrieve the entity's array.
"""
```

This prompt should do its job. However, the more examples you add, the better it works. For instance, you can add the following pair of a question and its vector SQL query to the prompt:

**The SQL table create statement:**

```sql
------ table schema ------
CREATE TABLE "ChatPaper" (
  abstract String,
  id String,
  vector Array(Float32),
  categories Array(String),
  pubdate DateTime,
  title String,
  authors Array(String),
  primary_category String
) ENGINE = ReplicatedReplacingMergeTree()
  ORDER BY id
  PRIMARY KEY id
```

**The question and answer:**

```text
Question: What is PaperRank? What is the contribution of these works? Use papers with more than 2 categories.
SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper WHERE length(categories) > 2 ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}
```

The more relevant examples you add to your prompt, the better the LLM becomes at building the correct vector SQL query.

Lastly, here are several extra tips to help you when designing your prompt:

* Cover all possible functions that might appear in any questions asked.
* Avoid monotonous questions; vary their wording and intent.
* Alter the table schema between examples, like adding, removing, or modifying names and data types.
* Keep the format consistent across the prompt's examples.

## A Real-World Example: Using MyScale

Let's now build [**a real-world example**](https://huggingface.co/spaces/myscale/ChatData), set out in the following steps:

![A Real-World Example: Using MyScale](../assets/myscale-example.png)

### Prepare the Database

We have prepared a playground for you with more than 2 million papers ready to query. You can access this data by adding the following Python code to your app.
```python
from sqlalchemy import create_engine

MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com"
MYSCALE_PORT = 443
MYSCALE_USER = "chatdata"
MYSCALE_PASSWORD = "myscale_rocks"

engine = create_engine(f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https')
```

If you like, you can skip the following steps, where we create the table and insert its data using the MyScale console, and jump to where we play with vector SQL and [create the `SQLDatabaseChain`](#create-the-sqldatabasechain) to query the database.

**Create the database table:**

```sql
CREATE TABLE default.ChatArXiv (
    `abstract` String,
    `id` String,
    `vector` Array(Float32),
    `metadata` Object('JSON'),
    `pubdate` DateTime,
    `title` String,
    `categories` Array(String),
    `authors` Array(String),
    `comment` String,
    `primary_category` String,
    CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id SETTINGS index_granularity = 8192
```

**Insert the data:**

```sql
INSERT INTO ChatArXiv
SELECT
  abstract, id, vector, metadata,
  parseDateTimeBestEffort(JSONExtractString(toJSONString(metadata), 'pubdate')) AS pubdate,
  JSONExtractString(toJSONString(metadata), 'title') AS title,
  arrayMap(x->trim(BOTH '"' FROM x),
           JSONExtractArrayRaw(toJSONString(metadata), 'categories')) AS categories,
  arrayMap(x->trim(BOTH '"' FROM x),
           JSONExtractArrayRaw(toJSONString(metadata), 'authors')) AS authors,
  JSONExtractString(toJSONString(metadata), 'comment') AS comment,
  JSONExtractString(toJSONString(metadata), 'primary_category') AS primary_category
FROM
  s3(
    'https://myscale-demo.s3.ap-southeast-1.amazonaws.com/chat_arxiv/data.part*.zst',
    'JSONEachRow',
    'abstract String, id String, vector Array(Float32), metadata Object(''JSON'')',
    'zstd'
  );
ALTER TABLE ChatArXiv ADD VECTOR INDEX vec_idx vector TYPE MSTG('metric_type=Cosine');
```

### Create the `SQLDatabaseChain`

This LangChain feature is currently under [MyScale tech
preview](https://github.com/myscale/langchain/tree/preview). You can install it by executing the following installation script:

```bash
python3 -m venv .venv
source .venv/bin/activate
# This is a technical preview of langchain from MyScale
pip3 install langchain@git+https://github.com/myscale/langchain.git@preview
```

Once you have installed this feature, the next step is to use it to query the database, as the following Python code demonstrates:

```python
from sqlalchemy import create_engine, MetaData

MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com"
MYSCALE_PORT = 443
MYSCALE_USER = "chatdata"
MYSCALE_PASSWORD = "myscale_rocks"
OPENAI_API_KEY = "sk-***"  # replace with your own key

# create connection to database
engine = create_engine(f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https')

from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains.sql_database.parser import VectorSQLRetrieveAllOutputParser
from langchain.sql_database import SQLDatabase
from langchain.chains.sql_database.prompt import MYSCALE_PROMPT
from langchain.llms import OpenAI
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.prompts import PromptTemplate

# this parser converts `NeuralArray()` into embeddings
output_parser = VectorSQLRetrieveAllOutputParser(
    model=HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-xl')
)

# use the prompt above
PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k"],
    template=_prompt,
)

# bind the metadata to SqlAlchemy engine (`bind=` requires SQLAlchemy 1.x)
metadata = MetaData(bind=engine)

# create SQLDatabaseChain
query_chain = SQLDatabaseChain.from_llm(
    # GPT-3.5 generates valid SQL better
    llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
    # use the predefined prompt, change it to your own prompt
    prompt=MYSCALE_PROMPT,
    # returns top 10 relevant documents
    top_k=10,
    # use result directly from DB
    return_direct=True,
    # use our database for retrieval
    db=SQLDatabase(engine, None, metadata),
    # convert
    # `NeuralArray()` into embeddings
    sql_cmd_parser=output_parser)

# launch the chain!! And trace all chain calls in standard output
query_chain.run("Introduce some papers that uses Generative Adversarial Networks published around 2019.",
                callbacks=[StdOutCallbackHandler()])
```

### Ask with `RetrievalQAWithSourcesChain`

You can also use this `SQLDatabaseChain` as a retriever. You can plug it into retrieval QA chains just like other retrievers in LangChain.

```python
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.retrievers import SQLDatabaseChainRetriever
from langchain.chains.qa_with_sources.map_reduce_prompt import combine_prompt_template

OPENAI_API_KEY = "sk-***"

# define how you serialize those structured data from database
document_with_metadata_prompt = PromptTemplate(
    input_variables=["page_content", "id", "title", "authors", "pubdate", "categories"],
    template="Content:\n\tTitle: {title}\n\tAbstract: {page_content}\n\t" +
             "Authors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"
)
# define the prompt you use to ask the LLM
COMBINE_PROMPT = PromptTemplate(
    template=combine_prompt_template, input_variables=["summaries", "question"])

# define a retriever with a SQLDatabaseChain
retriever = SQLDatabaseChainRetriever(
    sql_db_chain=query_chain, page_content_key="abstract")

# finally, the ask chain to organize all of these
ask_chain = RetrievalQAWithSourcesChain.from_chain_type(
    ChatOpenAI(model_name='gpt-3.5-turbo-16k',
               openai_api_key=OPENAI_API_KEY, temperature=0.6),
    retriever=retriever,
    chain_type='stuff',
    chain_type_kwargs={
        'prompt': COMBINE_PROMPT,
        'document_prompt': document_with_metadata_prompt,
    }, return_source_documents=True)

# Run the chain!
# and get the result from LLM
ask_chain("Introduce some papers that uses Generative Adversarial Networks published around 2019.",
          callbacks=[StdOutCallbackHandler()])
```

We also provide a live demo on [**Hugging Face**](https://huggingface.co/spaces/myscale/ChatData) and the code is available on [**GitHub**](https://github.com/myscale/ChatData)! We used [**a customized Retrieval QA chain**](https://github.com/myscale/ChatData/blob/main/chains/arxiv_chains.py) to maximize the performance of our search and ask pipeline with LangChain!

## In Conclusion

In reality, most LLMs hallucinate. The most practical way to reduce hallucination is to add extra facts (external knowledge) to your question. External knowledge is crucial to improving the performance of LLM systems, allowing for the efficient and accurate retrieval of answers. Every word counts, and you don't want to waste your money on unused information retrieved by inaccurate queries.

How?

Enter Vector SQL, which allows you to execute fine-grained vector searches to target and retrieve the required information.

Vector SQL is powerful and easy to learn for humans and machines. You can use many data types and functions to create complex queries. LLMs also work well with vector SQL, as their training data includes many SQL references.

Lastly, it is possible to translate Vector SQL for many vector databases using different embedding models. We believe that is the future of vector databases.

Are you interested in what we are doing? Join us on [Discord](https://discord.gg/D2qpkqc4Jq) today!