[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: \"[BUG]\"\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**To Reproduce**\nReport following things\n1. Input topic name\n2. All output files generated for this topic as a zip file.\n\n**Screenshots**\nIf applicable, add screenshots to help explain your problem.\n\n**Environment:**\n - OS: [e.g. iOS, Windows]\n - Browser [e.g. chrome, safari] if the bug report is UI problem\n"
  },
  {
    "path": ".github/workflows/format-check.yml",
    "content": "name: Check Python formatting with Black\n\non:\n  pull_request:\n    branches:\n      - main\n\njobs:\n  lint:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v2\n      - uses: actions/setup-python@v2\n      - uses: psf/black@stable\n        with:\n          black_args: \"knowledge_storm --check\"\n"
  },
  {
    "path": ".github/workflows/python-package.yml",
    "content": "name: Build and upload Python package\n\non:\n  workflow_dispatch:  # Allows manual triggering of the workflow\n\njobs:\n  build:\n\n    runs-on: ubuntu-latest\n\n    steps:\n      - uses: actions/checkout@master\n      - name: Set up Python 3.11\n        uses: actions/setup-python@v3\n        with:\n          python-version: \"3.11\"\n      - name: Compare versions in setup.py and knowledge_storm/__init__.py\n        run: |\n          VERSION_SETUP=$(grep -oP '(?<=version=\\\").*(?=\\\")' setup.py)\n          VERSION_INIT=$(grep -oP '(?<=__version__ = \\\").*(?=\\\")' knowledge_storm/__init__.py)\n          echo \"Version in setup.py: $VERSION_SETUP\"\n          echo \"Version in __init__.py: $VERSION_INIT\"\n          if [ \"$VERSION_SETUP\" != \"$VERSION_INIT\" ]; then\n            echo \"Error: Version mismatch between setup.py ($VERSION_SETUP) and knowledge_storm/__init__.py ($VERSION_INIT)\"\n            exit 1\n          fi\n        shell: bash\n      - name: Install dependencies\n        run: python3 -m pip install setuptools wheel twine\n      - name: Install dependencies\n        run: |\n          python3 -m pip install --upgrade pip setuptools wheel\n          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi\n      - name: Build a binary wheel\n        run: python3 setup.py sdist bdist_wheel\n      - name: Publish package to PyPI\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          user: __token__\n          password: ${{ secrets.PYPI_API_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# mac\n.DS_Store\n\n# Other\n.vscode\n*.tsv\n*.pt\ngpt*.txt\n*.env\nlocal/\nlocal_*\nbuild/\n*.egg-info/\n.idea\n.venv\n\n# Project-specific\nsecrets.toml\n*.log\n*/assertion.log\n*results/\n.venv/"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/psf/black\n    rev: 24.8.0\n    hooks:\n      - id: black\n        name: Format Python code with black\n        entry: black\n        args: [\"knowledge_storm/\"]\n        language: python\n        pass_filenames: true"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing\n\nThank you for your interest in contributing to STORM! \n\nContributions aren't just about code. Currently (last edit: 7/22/2024), we are accepting the following forms of contribution:\n- Pull requests for additional language model support to `knowledge_storm/lm.py`.\n- Pull requests for additional retrieval model/search engine support to `knowledge_storm/rm.py`.\n- Pull requests for new features to `frontend/demo_light` to assist other developers.\n- Identification and reporting of issues or bugs.\n- Helping each other by responding to issues.\n\nPlease note that we are not accepting code refactoring PRs at this time to avoid conflicts with our team's efforts.\n\n## Development\nThis section contains technical instructions & hints for contributors.\n\n### Setting up\n1. Fork this repository and clone your forked repository.\n2. Install the required packages:\n    ```\n    conda create -n storm python=3.11\n    conda activate storm\n    pip install -r requirements.txt\n    ```\n3. If you want to contribute to `frontend/demo_light`, follow its [Setup guide](https://github.com/stanford-oval/storm/tree/main/frontend/demo_light#setup) to install additional packages.\n\n### PR suggestions\n\nFollowing the suggested format can lead to a faster review process.\n\n**Title:**\n\n[New LM/New RM/Demo Enhancement] xxx\n\n**Description:**\n- For new language model support, (1) describe how to use the new LM class, (2) create an example script following the style of existing example scripts under `examples/`, (3) attach an input-output example of the example script.\n- For new retrieval model/search engine support, (1) describe how to use the new RM class and (2) attach input-output examples of the RM class.\n- For demo light enhancements, (1) describe what's new and (2) attach screenshots to demonstrate the UI change.\n- Please clearly describe the required API keys and provide instructions on how to get them (if applicable). 
This project manages API keys with `secrets.toml`.\n\n**Code Format:**\n\nWe use [`black`](https://github.com/psf/black) to format Python code. To streamline the contribution process, we set up a [pre-commit hook](https://pre-commit.com/) that formats the code under `knowledge_storm/` before committing. To install the pre-commit hook, run:\n```\npip install pre-commit\npre-commit install\n```\nThe hook will automatically format the code before each commit.\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Stanford Open Virtual Assistant Lab\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "MANIFEST.in",
    "content": "include requirements.txt\ninclude LICENSE\ninclude README.md"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n  <img src=\"assets/logo.svg\" style=\"width: 25%; height: auto;\">\n</p>\n\n# STORM: Synthesis of Topic Outlines through Retrieval and Multi-perspective Question Asking\n\n<p align=\"center\">\n| <a href=\"http://storm.genie.stanford.edu\"><b>Research preview</b></a> | <a href=\"https://arxiv.org/abs/2402.14207\"><b>STORM Paper</b></a>| <a href=\"https://www.arxiv.org/abs/2408.15232\"><b>Co-STORM Paper</b></a>  | <a href=\"https://storm-project.stanford.edu/\"><b>Website</b></a> |\n</p>\n**Latest News** 🔥\n\n- [2025/01] We add [litellm](https://github.com/BerriAI/litellm) integration for language models and embedding models in `knowledge-storm` v1.1.0.\n\n- [2024/09] Co-STORM codebase is now released and integrated into `knowledge-storm` python package v1.0.0. Run `pip install knowledge-storm --upgrade` to check it out.\n\n- [2024/09] We introduce collaborative STORM (Co-STORM) to support human-AI collaborative knowledge curation! [Co-STORM Paper](https://www.arxiv.org/abs/2408.15232) has been accepted to EMNLP 2024 main conference.\n\n- [2024/07] You can now install our package with `pip install knowledge-storm`!\n- [2024/07] We add `VectorRM` to support grounding on user-provided documents, complementing existing support of search engines (`YouRM`, `BingSearch`). (check out [#58](https://github.com/stanford-oval/storm/pull/58))\n- [2024/07] We release demo light for developers a minimal user interface built with streamlit framework in Python, handy for local development and demo hosting (checkout [#54](https://github.com/stanford-oval/storm/pull/54))\n- [2024/06] We will present STORM at NAACL 2024! Find us at Poster Session 2 on June 17 or check our [presentation material](assets/storm_naacl2024_slides.pdf). \n- [2024/05] We add Bing Search support in [rm.py](knowledge_storm/rm.py). 
Test STORM with `GPT-4o`: the article generation part of our demo now uses the `GPT-4o` model.\n- [2024/04] We release a refactored version of the STORM codebase! We define an [interface](knowledge_storm/interface.py) for the STORM pipeline and reimplement STORM-wiki (check out [`src/storm_wiki`](knowledge_storm/storm_wiki)) to demonstrate how to instantiate the pipeline. We provide an API to support customization of different language models and retrieval/search integration.\n\n[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)\n\n## Overview [(Try STORM now!)](https://storm.genie.stanford.edu/)\n\n<p align=\"center\">\n  <img src=\"assets/overview.svg\" style=\"width: 90%; height: auto;\">\n</p>\nSTORM is an LLM system that writes Wikipedia-like articles from scratch based on Internet search. Co-STORM extends it by letting humans collaborate with the LLM system, supporting information seeking and knowledge curation that are better aligned with user preferences.\n\nWhile the system cannot produce publication-ready articles, which often require a significant number of edits, experienced Wikipedia editors have found it helpful in their pre-writing stage.\n\n**More than 70,000 people have tried our [live research preview](https://storm.genie.stanford.edu/). Try it out to see how STORM can help your knowledge exploration journey and please provide feedback to help us improve the system 🙏!**\n\n\n\n## How STORM & Co-STORM work\n\n### STORM\n\nSTORM breaks down generating long articles with citations into two steps:\n\n1. **Pre-writing stage**: The system conducts Internet-based research to collect references and generates an outline.\n2. 
**Writing stage**: The system uses the outline and references to generate the full-length article with citations.\n<p align=\"center\">\n  <img src=\"assets/two_stages.jpg\" style=\"width: 60%; height: auto;\">\n</p>\n\nSTORM identifies the core of automating the research process as automatically coming up with good questions to ask. Directly prompting the language model to ask questions does not work well. To improve the depth and breadth of the questions, STORM adopts two strategies:\n1. **Perspective-Guided Question Asking**: Given the input topic, STORM discovers different perspectives by surveying existing articles on similar topics and uses them to control the question-asking process.\n2. **Simulated Conversation**: STORM simulates a conversation between a Wikipedia writer and a topic expert grounded in Internet sources to enable the language model to update its understanding of the topic and ask follow-up questions.\n\n### Co-STORM\n\nCo-STORM proposes **a collaborative discourse protocol** which implements a turn management policy to support smooth collaboration among:\n\n- **Co-STORM LLM experts**: This type of agent generates answers grounded in external knowledge sources and/or raises follow-up questions based on the discourse history.\n- **Moderator**: This agent generates thought-provoking questions inspired by information discovered by the retriever but not directly used in previous turns. 
Question generation can also be grounded!\n- **Human user**: The human user will take the initiative to either (1) observe the discourse to gain a deeper understanding of the topic, or (2) actively engage in the conversation by injecting utterances to steer the discussion focus.\n\n<p align=\"center\">\n  <img src=\"assets/co-storm-workflow.jpg\" style=\"width: 60%; height: auto;\">\n</p>\n\nCo-STORM also maintains a dynamically updated **mind map**, which organizes collected information into a hierarchical concept structure, aiming to **build a shared conceptual space between the human user and the system**. The mind map has been proven to help reduce the mental load when the discourse becomes long and in-depth. \n\nBoth STORM and Co-STORM are implemented in a highly modular way using [dspy](https://github.com/stanfordnlp/dspy).\n\n## Installation\n\n\nTo install the knowledge storm library, use `pip install knowledge-storm`. \n\nYou can also install from source, which allows you to modify the behavior of the STORM engine directly.\n1. Clone the git repository.\n    ```shell\n    git clone https://github.com/stanford-oval/storm.git\n    cd storm\n    ```\n   \n2. 
Install the required packages.\n   ```shell\n   conda create -n storm python=3.11\n   conda activate storm\n   pip install -r requirements.txt\n   ```\n   \n\n## API\n\nCurrently, our package supports:\n\n- Language model components: All language models supported by litellm as listed [here](https://docs.litellm.ai/docs/providers)\n- Embedding model components: All embedding models supported by litellm as listed [here](https://docs.litellm.ai/docs/embedding/supported_embedding)\n- Retrieval module components: `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch`, and `AzureAISearch`\n\n:star2: **PRs for integrating more search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!**\n\nBoth STORM and Co-STORM work in the information curation layer; you need to set up the information retrieval module and the language model module to create their respective `Runner` classes.\n\n### STORM\n\nThe STORM knowledge curation engine is defined as a simple Python `STORMWikiRunner` class. 
Here is an example of using the You.com search engine and OpenAI models.\n\n```python\nimport os\nfrom knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs\nfrom knowledge_storm.lm import LitellmModel\nfrom knowledge_storm.rm import YouRM\n\nlm_configs = STORMWikiLMConfigs()\nopenai_kwargs = {\n    'api_key': os.getenv(\"OPENAI_API_KEY\"),\n    'temperature': 1.0,\n    'top_p': 0.9,\n}\n# STORM is an LM system, so different components can be powered by different models to reach a good balance between cost and quality.\n# As a good practice, choose a cheaper/faster model for `conv_simulator_lm`, which is used to split queries and synthesize answers in the conversation.\n# Choose a more powerful model for `article_gen_lm` to generate verifiable text with citations.\ngpt_35 = LitellmModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)\ngpt_4 = LitellmModel(model='gpt-4o', max_tokens=3000, **openai_kwargs)\nlm_configs.set_conv_simulator_lm(gpt_35)\nlm_configs.set_question_asker_lm(gpt_35)\nlm_configs.set_outline_gen_lm(gpt_4)\nlm_configs.set_article_gen_lm(gpt_4)\nlm_configs.set_article_polish_lm(gpt_4)\n# Check out the STORMWikiRunnerArguments class for more configurations.\nengine_args = STORMWikiRunnerArguments(...)\nrm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)\nrunner = STORMWikiRunner(engine_args, lm_configs, rm)\n```\n\nThe `STORMWikiRunner` instance can be invoked with the simple `run` method:\n```python\ntopic = input('Topic: ')\nrunner.run(\n    topic=topic,\n    do_research=True,\n    do_generate_outline=True,\n    do_generate_article=True,\n    do_polish_article=True,\n)\nrunner.post_run()\nrunner.summary()\n```\n- `do_research`: if True, simulate conversations with different perspectives to collect information about the topic; otherwise, load the results.\n- `do_generate_outline`: if True, generate an outline for the topic; otherwise, load the results.\n- `do_generate_article`: if True, 
generate an article for the topic based on the outline and the collected information; otherwise, load the results.\n- `do_polish_article`: if True, polish the article by adding a summarization section and (optionally) removing duplicate content; otherwise, load the results.\n\n### Co-STORM\n\nThe Co-STORM knowledge curation engine is defined as a simple Python `CoStormRunner` class. Here is an example of using the Bing search engine and OpenAI models.\n\n```python\nimport os\nfrom knowledge_storm.collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, CoStormRunner\nfrom knowledge_storm.lm import LitellmModel\nfrom knowledge_storm.logging_wrapper import LoggingWrapper\nfrom knowledge_storm.rm import BingSearch\n\n# Co-STORM adopts the same multi-LM system paradigm as STORM\nlm_config: CollaborativeStormLMConfigs = CollaborativeStormLMConfigs()\nopenai_kwargs = {\n    \"api_key\": os.getenv(\"OPENAI_API_KEY\"),\n    \"api_provider\": \"openai\",\n    \"temperature\": 1.0,\n    \"top_p\": 0.9,\n    \"api_base\": None,\n}\n# Model name shared by all components in this example\ngpt_4o_model_name = \"gpt-4o\"\nquestion_answering_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)\ndiscourse_manage_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)\nutterance_polishing_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs)\nwarmstart_outline_gen_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)\nquestion_asking_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs)\nknowledge_base_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)\n\nlm_config.set_question_answering_lm(question_answering_lm)\nlm_config.set_discourse_manage_lm(discourse_manage_lm)\nlm_config.set_utterance_polishing_lm(utterance_polishing_lm)\nlm_config.set_warmstart_outline_gen_lm(warmstart_outline_gen_lm)\nlm_config.set_question_asking_lm(question_asking_lm)\nlm_config.set_knowledge_base_lm(knowledge_base_lm)\n\n# Check out the 
Co-STORM RunnerArgument class for more configurations.\ntopic = input('Topic: ')\nrunner_argument = RunnerArgument(topic=topic, ...)\nlogging_wrapper = LoggingWrapper(lm_config)\nbing_rm = BingSearch(bing_search_api_key=os.environ.get(\"BING_SEARCH_API_KEY\"),\n                     k=runner_argument.retrieve_top_k)\ncostorm_runner = CoStormRunner(lm_config=lm_config,\n                               runner_argument=runner_argument,\n                               logging_wrapper=logging_wrapper,\n                               rm=bing_rm)\n```\n\nThe `CoStormRunner` instance can be invoked with the `warm_start()` and `step(...)` methods.\n\n```python\n# Warm start the system to build shared conceptual space between Co-STORM and users\ncostorm_runner.warm_start()\n\n# Step through the collaborative discourse \n# Run either of the code snippets below in any order, as many times as you'd like\n# To observe the conversation:\nconv_turn = costorm_runner.step()\n# To inject your utterance to actively steer the conversation:\ncostorm_runner.step(user_utterance=\"YOUR UTTERANCE HERE\")\n\n# Generate report based on the collaborative discourse\ncostorm_runner.knowledge_base.reorganize()\narticle = costorm_runner.generate_report()\nprint(article)\n```\n\n\n\n## Quick Start with Example Scripts\n\nWe provide scripts in our [examples folder](examples) as a quick start to run STORM and Co-STORM with different configurations.\n\nWe suggest using `secrets.toml` to set up the API keys. 
Create a file `secrets.toml` under the root directory and add the following content:\n\n```shell\n# ============ language model configurations ============ \n# Set up OpenAI API key.\nOPENAI_API_KEY=\"your_openai_api_key\"\n# If you are using the API service provided by OpenAI, include the following line:\nOPENAI_API_TYPE=\"openai\"\n# If you are using the API service provided by Microsoft Azure, include the following lines:\nOPENAI_API_TYPE=\"azure\"\nAZURE_API_BASE=\"your_azure_api_base_url\"\nAZURE_API_VERSION=\"your_azure_api_version\"\n# ============ retriever configurations ============ \nBING_SEARCH_API_KEY=\"your_bing_search_api_key\" # if using bing search\n# ============ encoder configurations ============ \nENCODER_API_TYPE=\"openai\" # if using openai encoder\n```\n\n### STORM examples\n\n**To run STORM with `gpt` family models with default configurations:**\n\nRun the following command.\n```bash\npython examples/storm_examples/run_storm_wiki_gpt.py \\\n    --output-dir $OUTPUT_DIR \\\n    --retriever bing \\\n    --do-research \\\n    --do-generate-outline \\\n    --do-generate-article \\\n    --do-polish-article\n```\n\n**To run STORM using your favorite language models or grounding on your own corpus:** Check out [examples/storm_examples/README.md](examples/storm_examples/README.md).\n\n### Co-STORM examples\n\nTo run Co-STORM with `gpt` family models with default configurations,\n\n1. Add `BING_SEARCH_API_KEY=\"xxx\"` and `ENCODER_API_TYPE=\"xxx\"` to `secrets.toml`\n2. Run the following command\n\n```bash\npython examples/costorm_examples/run_costorm_gpt.py \\\n    --output-dir $OUTPUT_DIR \\\n    --retriever bing\n```\n\n\n## Customization of the Pipeline\n\n### STORM\n\nIf you have installed the source code, you can customize STORM based on your own use case. STORM engine consists of 4 modules:\n\n1. Knowledge Curation Module: Collects a broad coverage of information about the given topic.\n2. 
Outline Generation Module: Organizes the collected information by generating a hierarchical outline for the curated knowledge.\n3. Article Generation Module: Populates the generated outline with the collected information.\n4. Article Polishing Module: Refines and enhances the written article for better presentation.\n\nThe interface for each module is defined in `knowledge_storm/interface.py`, while their implementations are instantiated in `knowledge_storm/storm_wiki/modules/*`. These modules can be customized according to your specific requirements (e.g., generating sections in bullet point format instead of full paragraphs).\n\n### Co-STORM\n\nIf you have installed the source code, you can customize Co-STORM based on your own use case\n\n1. Co-STORM introduces multiple LLM agent types (i.e. Co-STORM experts and Moderator). LLM agent interface is defined in `knowledge_storm/interface.py` , while its implementation is instantiated in `knowledge_storm/collaborative_storm/modules/co_storm_agents.py`. Different LLM agent policies can be customized.\n2. Co-STORM introduces a collaborative discourse protocol, with its core function centered on turn policy management. We provide an example implementation of turn policy management through `DiscourseManager` in `knowledge_storm/collaborative_storm/engine.py`. It can be customized and further improved.\n\n## Datasets\nTo facilitate the study of automatic knowledge curation and complex information seeking, our project releases the following datasets:\n\n### FreshWiki\nThe FreshWiki Dataset is a collection of 100 high-quality Wikipedia articles focusing on the most-edited pages from February 2022 to September 2023. See Section 2.1 in [STORM paper](https://arxiv.org/abs/2402.14207) for more details.\n\nYou can download the dataset from [huggingface](https://huggingface.co/datasets/EchoShao8899/FreshWiki) directly. 
To ease the data contamination issue, we archive the [source code](https://github.com/stanford-oval/storm/tree/NAACL-2024-code-backup/FreshWiki) for the data construction pipeline that can be repeated at future dates.\n\n### WildSeek\nTo study users’ interests in complex information seeking tasks in the wild, we utilized data collected from the web research preview to create the WildSeek dataset. We downsampled the data to ensure the diversity of the topics and the quality of the data. Each data point is a pair comprising a topic and the user’s goal for conducting deep search on the topic.  For more details, please refer to Section 2.2 and Appendix A of [Co-STORM paper](https://www.arxiv.org/abs/2408.15232).\n\nThe WildSeek dataset is available [here](https://huggingface.co/datasets/YuchengJiang/WildSeek).\n\n## Replicate STORM & Co-STORM paper result\n\nFor STORM paper experiments, please switch to the branch `NAACL-2024-code-backup` [here](https://github.com/stanford-oval/storm/tree/NAACL-2024-code-backup).\n\nFor Co-STORM paper experiments, please switch to the branch `EMNLP-2024-code-backup` (placeholder for now, will be updated soon).\n\n## Roadmap & Contributions\nOur team is actively working on:\n1. Human-in-the-Loop Functionalities: Supporting user participation in the knowledge curation process.\n2. Information Abstraction: Developing abstractions for curated information to support presentation formats beyond the Wikipedia-style report.\n\nIf you have any questions or suggestions, please feel free to open an issue or pull request. We welcome contributions to improve the system and the codebase!\n\nContact person: [Yijia Shao](mailto:shaoyj@stanford.edu) and [Yucheng Jiang](mailto:yuchengj@stanford.edu)\n\n## Acknowledgement\nWe would like to thank Wikipedia for its excellent open-source content. 
The FreshWiki dataset is sourced from Wikipedia, licensed under the Creative Commons Attribution-ShareAlike (CC BY-SA) license.\n\nWe are very grateful to [Michelle Lam](https://michelle123lam.github.io/) for designing the logo for this project and [Dekun Ma](https://dekun.me) for leading the UI development.\n\nThanks to Vercel for their support of [open-source software](https://storm.genie.stanford.edu)\n\n## Citation\nPlease cite our paper if you use this code or part of it in your work:\n```bibtex\n@inproceedings{jiang-etal-2024-unknown,\n    title = \"Into the Unknown Unknowns: Engaged Human Learning through Participation in Language Model Agent Conversations\",\n    author = \"Jiang, Yucheng  and\n      Shao, Yijia  and\n      Ma, Dekun  and\n      Semnani, Sina  and\n      Lam, Monica\",\n    editor = \"Al-Onaizan, Yaser  and\n      Bansal, Mohit  and\n      Chen, Yun-Nung\",\n    booktitle = \"Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing\",\n    month = nov,\n    year = \"2024\",\n    address = \"Miami, Florida, USA\",\n    publisher = \"Association for Computational Linguistics\",\n    url = \"https://aclanthology.org/2024.emnlp-main.554/\",\n    doi = \"10.18653/v1/2024.emnlp-main.554\",\n    pages = \"9917--9955\",\n}\n\n@inproceedings{shao-etal-2024-assisting,\n    title = \"Assisting in Writing {W}ikipedia-like Articles From Scratch with Large Language Models\",\n    author = \"Shao, Yijia  and\n      Jiang, Yucheng  and\n      Kanell, Theodore  and\n      Xu, Peter  and\n      Khattab, Omar  and\n      Lam, Monica\",\n    editor = \"Duh, Kevin  and\n      Gomez, Helena  and\n      Bethard, Steven\",\n    booktitle = \"Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)\",\n    month = jun,\n    year = \"2024\",\n    address = \"Mexico City, Mexico\",\n    publisher = \"Association for Computational 
Linguistics\",\n    url = \"https://aclanthology.org/2024.naacl-long.347/\",\n    doi = \"10.18653/v1/2024.naacl-long.347\",\n    pages = \"6252--6278\",\n}\n```\n"
  },
  {
    "path": "examples/costorm_examples/run_costorm_gpt.py",
    "content": "\"\"\"\nCo-STORM pipeline powered by GPT-4o/4o-mini and Bing search engine.\nYou need to set up the following environment variables to run this script:\n    - OPENAI_API_KEY: OpenAI API key\n    - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure')\n    - AZURE_API_BASE: Azure API base URL if using Azure API\n    - AZURE_API_VERSION: Azure API version if using Azure API\n    - BING_SEARCH_API_KEY: Biang search API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key\n\nOutput will be structured as below\nargs.output_dir/\n    log.json           # Log of information-seeking conversation\n    report.txt         # Final article generated\n\"\"\"\n\nimport os\nimport json\nfrom argparse import ArgumentParser\nfrom knowledge_storm.collaborative_storm.engine import (\n    CollaborativeStormLMConfigs,\n    RunnerArgument,\n    CoStormRunner,\n)\nfrom knowledge_storm.collaborative_storm.modules.callback import (\n    LocalConsolePrintCallBackHandler,\n)\nfrom knowledge_storm.lm import OpenAIModel, AzureOpenAIModel\nfrom knowledge_storm.logging_wrapper import LoggingWrapper\nfrom knowledge_storm.rm import (\n    YouRM,\n    BingSearch,\n    BraveRM,\n    SerperRM,\n    DuckDuckGoSearchRM,\n    TavilySearchRM,\n    SearXNG,\n)\nfrom knowledge_storm.utils import load_api_key\n\n\ndef main(args):\n    load_api_key(toml_file_path=\"secrets.toml\")\n    lm_config: CollaborativeStormLMConfigs = CollaborativeStormLMConfigs()\n    openai_kwargs = (\n        {\n            \"api_key\": os.getenv(\"OPENAI_API_KEY\"),\n            \"api_provider\": \"openai\",\n            \"temperature\": 1.0,\n            \"top_p\": 0.9,\n            \"api_base\": None,\n        }\n        if os.getenv(\"OPENAI_API_TYPE\") == \"openai\"\n        else {\n            \"api_key\": os.getenv(\"AZURE_API_KEY\"),\n            \"temperature\": 1.0,\n            \"top_p\": 0.9,\n            
\"api_base\": os.getenv(\"AZURE_API_BASE\"),\n            \"api_version\": os.getenv(\"AZURE_API_VERSION\"),\n        }\n    )\n\n    ModelClass = (\n        OpenAIModel if os.getenv(\"OPENAI_API_TYPE\") == \"openai\" else AzureOpenAIModel\n    )\n    # If you are using Azure service, make sure the model name matches your own deployed model name.\n    # The default name here is only used for demonstration and may not match your case.\n    gpt_4o_mini_model_name = \"gpt-4o-mini\"\n    gpt_4o_model_name = \"gpt-4o\"\n    if os.getenv(\"OPENAI_API_TYPE\") == \"azure\":\n        openai_kwargs[\"api_base\"] = os.getenv(\"AZURE_API_BASE\")\n        openai_kwargs[\"api_version\"] = os.getenv(\"AZURE_API_VERSION\")\n\n    # STORM is a LM system so different components can be powered by different models.\n    # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm\n    # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models\n    # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm\n    # which is responsible for generating sections with citations.\n    question_answering_lm = ModelClass(\n        model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs\n    )\n    discourse_manage_lm = ModelClass(\n        model=gpt_4o_model_name, max_tokens=500, **openai_kwargs\n    )\n    utterance_polishing_lm = ModelClass(\n        model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs\n    )\n    warmstart_outline_gen_lm = ModelClass(\n        model=gpt_4o_model_name, max_tokens=500, **openai_kwargs\n    )\n    question_asking_lm = ModelClass(\n        model=gpt_4o_model_name, max_tokens=300, **openai_kwargs\n    )\n    knowledge_base_lm = ModelClass(\n        model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs\n    )\n\n    lm_config.set_question_answering_lm(question_answering_lm)\n    
lm_config.set_discourse_manage_lm(discourse_manage_lm)\n    lm_config.set_utterance_polishing_lm(utterance_polishing_lm)\n    lm_config.set_warmstart_outline_gen_lm(warmstart_outline_gen_lm)\n    lm_config.set_question_asking_lm(question_asking_lm)\n    lm_config.set_knowledge_base_lm(knowledge_base_lm)\n\n    topic = input(\"Topic: \")\n    runner_argument = RunnerArgument(\n        topic=topic,\n        retrieve_top_k=args.retrieve_top_k,\n        max_search_queries=args.max_search_queries,\n        total_conv_turn=args.total_conv_turn,\n        max_search_thread=args.max_search_thread,\n        max_search_queries_per_turn=args.max_search_queries_per_turn,\n        warmstart_max_num_experts=args.warmstart_max_num_experts,\n        warmstart_max_turn_per_experts=args.warmstart_max_turn_per_experts,\n        warmstart_max_thread=args.warmstart_max_thread,\n        max_thread_num=args.max_thread_num,\n        max_num_round_table_experts=args.max_num_round_table_experts,\n        moderator_override_N_consecutive_answering_turn=args.moderator_override_N_consecutive_answering_turn,\n        node_expansion_trigger_count=args.node_expansion_trigger_count,\n    )\n    logging_wrapper = LoggingWrapper(lm_config)\n    callback_handler = (\n        LocalConsolePrintCallBackHandler() if args.enable_log_print else None\n    )\n\n    # Co-STORM is a knowledge curation system which consumes information from the retrieval module.\n    # Currently, the information source is the Internet and we use a search engine API as the retrieval module.\n    match args.retriever:\n        case \"bing\":\n            rm = BingSearch(\n                bing_search_api_key=os.getenv(\"BING_SEARCH_API_KEY\"),\n                k=runner_argument.retrieve_top_k,\n            )\n        case \"you\":\n            rm = YouRM(\n                ydc_api_key=os.getenv(\"YDC_API_KEY\"), k=runner_argument.retrieve_top_k\n            )\n        case \"brave\":\n            rm = BraveRM(\n                
brave_search_api_key=os.getenv(\"BRAVE_API_KEY\"),\n                k=runner_argument.retrieve_top_k,\n            )\n        case \"duckduckgo\":\n            rm = DuckDuckGoSearchRM(\n                k=runner_argument.retrieve_top_k, safe_search=\"On\", region=\"us-en\"\n            )\n        case \"serper\":\n            rm = SerperRM(\n                serper_search_api_key=os.getenv(\"SERPER_API_KEY\"),\n                query_params={\"autocorrect\": True, \"num\": 10, \"page\": 1},\n            )\n        case \"tavily\":\n            rm = TavilySearchRM(\n                tavily_search_api_key=os.getenv(\"TAVILY_API_KEY\"),\n                k=runner_argument.retrieve_top_k,\n                include_raw_content=True,\n            )\n        case \"searxng\":\n            rm = SearXNG(\n                searxng_api_key=os.getenv(\"SEARXNG_API_KEY\"),\n                k=runner_argument.retrieve_top_k,\n            )\n        case _:\n            raise ValueError(\n                f'Invalid retriever: {args.retriever}. 
Choose either \"bing\", \"you\", \"brave\", \"duckduckgo\", \"serper\", \"tavily\", or \"searxng\"'\n            )\n\n    costorm_runner = CoStormRunner(\n        lm_config=lm_config,\n        runner_argument=runner_argument,\n        logging_wrapper=logging_wrapper,\n        rm=rm,\n        callback_handler=callback_handler,\n    )\n\n    # warm start the system\n    costorm_runner.warm_start()\n\n    # Below is an example of how users may interact with Co-STORM to seek information together\n    # In actual deployment, we suggest allowing the user to decide whether to observe the agent utterance or inject a turn\n\n    # observe the Co-STORM LLM agent utterance for 1 turn (increase the range to observe more turns)\n    for _ in range(1):\n        conv_turn = costorm_runner.step()\n        print(f\"**{conv_turn.role}**: {conv_turn.utterance}\\n\")\n\n    # actively engage by injecting your utterance\n    your_utterance = input(\"Your utterance: \")\n    costorm_runner.step(user_utterance=your_utterance)\n\n    # continue observing\n    conv_turn = costorm_runner.step()\n    print(f\"**{conv_turn.role}**: {conv_turn.utterance}\\n\")\n\n    # generate report\n    costorm_runner.knowledge_base.reorganize()\n    article = costorm_runner.generate_report()\n\n    # save results\n    os.makedirs(args.output_dir, exist_ok=True)\n\n    # Save article\n    with open(os.path.join(args.output_dir, \"report.md\"), \"w\") as f:\n        f.write(article)\n\n    # Save instance dump\n    instance_copy = costorm_runner.to_dict()\n    with open(os.path.join(args.output_dir, \"instance_dump.json\"), \"w\") as f:\n        json.dump(instance_copy, f, indent=2)\n\n    # Save logging\n    log_dump = costorm_runner.dump_logging_and_reset()\n    with open(os.path.join(args.output_dir, \"log.json\"), \"w\") as f:\n        json.dump(log_dump, f, indent=2)\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        
default=\"./results/co-storm\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--retriever\",\n        type=str,\n        choices=[\"bing\", \"you\", \"brave\", \"serper\", \"duckduckgo\", \"tavily\", \"searxng\"],\n        help=\"The search engine API to use for retrieving information.\",\n    )\n    # hyperparameters for co-storm\n    parser.add_argument(\n        \"--retrieve_top_k\",\n        type=int,\n        default=10,\n        help=\"Retrieve top k results for each query in retriever.\",\n    )\n    parser.add_argument(\n        \"--max_search_queries\",\n        type=int,\n        default=2,\n        help=\"Maximum number of search queries to consider for each question.\",\n    )\n    parser.add_argument(\n        \"--total_conv_turn\",\n        type=int,\n        default=20,\n        help=\"Maximum number of turns in conversation.\",\n    )\n    parser.add_argument(\n        \"--max_search_thread\",\n        type=int,\n        default=5,\n        help=\"Maximum number of parallel threads for retriever.\",\n    )\n    parser.add_argument(\n        \"--max_search_queries_per_turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of search queries to consider in each turn.\",\n    )\n    parser.add_argument(\n        \"--warmstart_max_num_experts\",\n        type=int,\n        default=3,\n        help=\"Max number of experts in perspective-guided QA during warm start.\",\n    )\n    parser.add_argument(\n        \"--warmstart_max_turn_per_experts\",\n        type=int,\n        default=2,\n        help=\"Max number of turns per perspective during warm start.\",\n    )\n    parser.add_argument(\n        \"--warmstart_max_thread\",\n        type=int,\n        default=3,\n        help=\"Max number of threads for parallel perspective-guided QA during warm start.\",\n    )\n    parser.add_argument(\n        \"--max_thread_num\",\n        type=int,\n        default=10,\n        help=(\n         
   \"Maximum number of threads to use. \"\n            \"Consider reducing it if you keep getting 'Exceed rate limit' errors when calling the LM API.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_num_round_table_experts\",\n        type=int,\n        default=2,\n        help=\"Max number of active experts in round table discussion.\",\n    )\n    parser.add_argument(\n        \"--moderator_override_N_consecutive_answering_turn\",\n        type=int,\n        default=3,\n        help=(\n            \"Number of consecutive expert answering turns before the moderator overrides the conversation.\"\n        ),\n    )\n    parser.add_argument(\n        \"--node_expansion_trigger_count\",\n        type=int,\n        default=10,\n        help=\"Trigger node expansion for nodes that contain more than N snippets.\",\n    )\n\n    # Boolean flags\n    parser.add_argument(\n        \"--enable_log_print\",\n        action=\"store_true\",\n        help=\"If set, enable console log print.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "examples/storm_examples/README.md",
    "content": "# Examples\n\nWe host a number of example scripts for various customization of STORM (e.g., use your favorite language models, use your own corpus, etc.). These examples can be starting points for your own customizations and you are welcome to contribute your own examples by submitting a pull request to this directory.\n\n## Run STORM with your own language model\n[run_storm_wiki_gpt.py](run_storm_wiki_gpt.py) provides an example of running STORM with GPT models, and [run_storm_wiki_claude.py](run_storm_wiki_claude.py) provides an example of running STORM with Claude models. Besides using close-source models, you can also run STORM with models with open weights.\n\n`run_storm_wiki_mistral.py` provides an example of running STORM with `Mistral-7B-Instruct-v0.2` using [VLLM](https://docs.vllm.ai/en/stable/) server:\n\n1. Set up a VLLM server with the `Mistral-7B-Instruct-v0.2` model running.\n2. Run the following command under the root directory of the repository:\n\n   ```\n    python examples/storm_examples/run_storm_wiki_mistral.py \\\n       --url $URL \\\n       --port $PORT \\\n       --output-dir $OUTPUT_DIR \\\n       --retriever you \\\n       --do-research \\\n       --do-generate-outline \\\n       --do-generate-article \\\n       --do-polish-article\n    ```\n   - `--url` URL of the VLLM server.\n   - `--port` Port of the VLLM server.\n\nBesides VLLM server, STORM is also compatible with [TGI](https://huggingface.co/docs/text-generation-inference/en/index) server or [Together.ai](https://www.together.ai/products#inference) endpoint. \n\n\n## Run STORM with your own corpus\n\nBy default, STORM is grounded on the Internet using the search engine, but it can also be grounded on your own corpus using `VectorRM`. [run_storm_wiki_with_gpt_with_VectorRM.py](run_storm_wiki_gpt_with_VectorRM.py) provides an example of running STORM grounding on your provided data.\n\n1. 
Set up API keys.\n   - Make sure you have set up the OpenAI API key.\n   - `VectorRM` uses [Qdrant](https://github.com/qdrant/qdrant-client) to create a vector store. If you want to set up this vector store online on a [Qdrant cloud server](https://cloud.qdrant.io/login), you need to set up `QDRANT_API_KEY` in `secrets.toml` as well; if you want to save the vector store locally, make sure you provide a location for the vector store.\n2. Prepare your corpus. The documents should be provided as a single CSV file with the following format:\n\n   | content                | title      | url        | description                        |\n   |------------------------|------------|------------|------------------------------------|\n   | I am a document.       | Document 1 | docu-n-112 | A self-explanatory document.       |\n   | I am another document. | Document 2 | docu-l-13  | Another self-explanatory document. |\n   | ...                    | ...        | ...        | ...                                |\n\n   - `url` is used as the unique identifier of a document in the STORM engine, so ensure that different documents have different URLs.\n   - The `title` and `description` columns are optional. If they are not provided, the script will use default empty values.\n   - The `content` column is required and must be provided for every document.\n\n3. 
Run one of the following commands under the root directory of the repository.\n\n   To create the vector store offline, run\n\n   ```\n   python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \\\n       --output-dir $OUTPUT_DIR \\\n       --vector-db-mode offline \\\n       --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \\\n       --csv-file-path $CSV_FILE_PATH \\\n       --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \\\n       --do-research \\\n       --do-generate-outline \\\n       --do-generate-article \\\n       --do-polish-article\n   ```\n\n   To create the vector store online on a Qdrant server, run\n\n   ```\n   python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \\\n       --output-dir $OUTPUT_DIR \\\n       --vector-db-mode online \\\n       --online-vector-db-url $ONLINE_VECTOR_DB_URL \\\n       --csv-file-path $CSV_FILE_PATH \\\n       --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \\\n       --do-research \\\n       --do-generate-outline \\\n       --do-generate-article \\\n       --do-polish-article\n   ```\n\n4. **Quick test with Kaggle arXiv Paper Abstracts dataset**:\n\n   - Download `arxiv_data_210930-054931.csv` from [here](https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts).\n   - Run the following command under the root directory to downsample the dataset, keeping only papers whose `terms` value is exactly `['cs.CV']`, and produce a CSV file that matches the format described above.\n\n     ```\n     python examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV\n     ```\n   - Run the following command to run STORM grounded on the processed dataset. You can input a topic related to computer vision (e.g., \"The progress of multimodal models in computer vision\") to see the generated article. 
(Note that the generated article may not include enough details, since the quick test only uses the abstracts of arXiv papers.)\n\n     ```\n     python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \\\n         --output-dir $OUTPUT_DIR \\\n         --vector-db-mode offline \\\n         --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \\\n         --csv-file-path $PATH_TO_THE_PROCESSED_CSV \\\n         --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \\\n         --do-research \\\n         --do-generate-outline \\\n         --do-generate-article \\\n         --do-polish-article\n     ```\n   - For a quicker run, you can also download the pre-embedded vector store directly from [here](https://drive.google.com/file/d/1bijFkw5BKU7bqcmXMhO-5hg2fdKAL9bf/view?usp=share_link).\n\n     ```\n     python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \\\n         --output-dir $OUTPUT_DIR \\\n         --vector-db-mode offline \\\n         --offline-vector-db-dir $DOWNLOADED_VECTOR_DB_DIR \\\n         --do-research \\\n         --do-generate-outline \\\n         --do-generate-article \\\n         --do-polish-article\n     ```"
  },
  {
    "path": "examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py",
    "content": "\"\"\"Process `arxiv_data_210930-054931.csv` \nfrom https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts\nto a csv file that is compatible with VectorRM.\n\"\"\"\n\nfrom argparse import ArgumentParser\n\nimport pandas as pd\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\n        \"--input-path\", type=str, help=\"Path to arxiv_data_210930-054931.csv.\"\n    )\n    parser.add_argument(\n        \"--output-path\",\n        type=str,\n        help=\"Path to store the csv file that is compatible with VectorRM.\",\n    )\n    args = parser.parse_args()\n\n    df = pd.read_csv(args.input_path)\n    print(f\"The original dataset has {len(df)} samples.\")\n\n    # Downsample the dataset.\n    df = df[df[\"terms\"] == \"['cs.CV']\"]\n\n    # Reformat the dataset to match the VectorRM input format.\n    df.rename(columns={\"abstracts\": \"content\", \"titles\": \"title\"}, inplace=True)\n    df[\"url\"] = [\n        \"uid_\" + str(idx) for idx in range(len(df))\n    ]  # Ensure the url is unique.\n    df[\"description\"] = \"\"\n\n    print(f\"The downsampled dataset has {len(df)} samples.\")\n    df.to_csv(args.output_path, index=False)\n"
  },
  {
    "path": "examples/storm_examples/run_storm_wiki_claude.py",
    "content": "\"\"\"\nSTORM Wiki pipeline powered by Claude family models and You.com search engine.\nYou need to set up the following environment variables to run this script:\n    - ANTHROPIC_API_KEY: Anthropic API key\n    - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key\n\nOutput will be structured as below\nargs.output_dir/\n    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash\n        conversation_log.json           # Log of information-seeking conversation\n        raw_search_results.json         # Raw search results from search engine\n        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge\n        storm_gen_outline.txt           # Outline refined with collected information\n        url_to_info.json                # Sources that are used in the final article\n        storm_gen_article.txt           # Final article generated\n        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)\n\"\"\"\n\nimport os\nfrom argparse import ArgumentParser\n\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\nfrom knowledge_storm.lm import ClaudeModel\nfrom knowledge_storm.rm import (\n    YouRM,\n    BingSearch,\n    BraveRM,\n    SerperRM,\n    DuckDuckGoSearchRM,\n    TavilySearchRM,\n    SearXNG,\n)\nfrom knowledge_storm.utils import load_api_key\n\n\ndef main(args):\n    load_api_key(toml_file_path=\"secrets.toml\")\n    lm_configs = STORMWikiLMConfigs()\n    claude_kwargs = {\n        \"api_key\": os.getenv(\"ANTHROPIC_API_KEY\"),\n        \"temperature\": 1.0,\n        \"top_p\": 0.9,\n    }\n\n    # STORM is a LM system so different components can be powered by different models.\n    # For a good balance between cost and quality, you can choose 
a cheaper/faster model for conv_simulator_lm\n    # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models\n    # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm\n    # which is responsible for generating sections with citations.\n    conv_simulator_lm = ClaudeModel(\n        model=\"claude-3-haiku-20240307\", max_tokens=500, **claude_kwargs\n    )\n    question_asker_lm = ClaudeModel(\n        model=\"claude-3-sonnet-20240229\", max_tokens=500, **claude_kwargs\n    )\n    outline_gen_lm = ClaudeModel(\n        model=\"claude-3-opus-20240229\", max_tokens=400, **claude_kwargs\n    )\n    article_gen_lm = ClaudeModel(\n        model=\"claude-3-opus-20240229\", max_tokens=700, **claude_kwargs\n    )\n    article_polish_lm = ClaudeModel(\n        model=\"claude-3-opus-20240229\", max_tokens=4000, **claude_kwargs\n    )\n\n    lm_configs.set_conv_simulator_lm(conv_simulator_lm)\n    lm_configs.set_question_asker_lm(question_asker_lm)\n    lm_configs.set_outline_gen_lm(outline_gen_lm)\n    lm_configs.set_article_gen_lm(article_gen_lm)\n    lm_configs.set_article_polish_lm(article_polish_lm)\n\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=args.output_dir,\n        max_conv_turn=args.max_conv_turn,\n        max_perspective=args.max_perspective,\n        search_top_k=args.search_top_k,\n        max_thread_num=args.max_thread_num,\n    )\n\n    # STORM is a knowledge curation system which consumes information from the retrieval module.\n    # Currently, the information source is the Internet and we use search engine API as the retrieval module.\n    match args.retriever:\n        case \"bing\":\n            rm = BingSearch(\n                bing_search_api=os.getenv(\"BING_SEARCH_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"you\":\n            rm = YouRM(ydc_api_key=os.getenv(\"YDC_API_KEY\"), 
k=engine_args.search_top_k)\n        case \"brave\":\n            rm = BraveRM(\n                brave_search_api_key=os.getenv(\"BRAVE_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"duckduckgo\":\n            rm = DuckDuckGoSearchRM(\n                k=engine_args.search_top_k, safe_search=\"On\", region=\"us-en\"\n            )\n        case \"serper\":\n            rm = SerperRM(\n                serper_search_api_key=os.getenv(\"SERPER_API_KEY\"),\n                query_params={\"autocorrect\": True, \"num\": 10, \"page\": 1},\n            )\n        case \"tavily\":\n            rm = TavilySearchRM(\n                tavily_search_api_key=os.getenv(\"TAVILY_API_KEY\"),\n                k=engine_args.search_top_k,\n                include_raw_content=True,\n            )\n        case \"searxng\":\n            rm = SearXNG(\n                searxng_api_key=os.getenv(\"SEARXNG_API_KEY\"), k=engine_args.search_top_k\n            )\n        case _:\n            raise ValueError(\n                f'Invalid retriever: {args.retriever}. Choose either \"bing\", \"you\", \"brave\", \"duckduckgo\", \"serper\", \"tavily\", or \"searxng\"'\n            )\n\n    runner = STORMWikiRunner(engine_args, lm_configs, rm)\n\n    topic = input(\"Topic: \")\n    runner.run(\n        topic=topic,\n        do_research=args.do_research,\n        do_generate_outline=args.do_generate_outline,\n        do_generate_article=args.do_generate_article,\n        do_polish_article=args.do_polish_article,\n    )\n    runner.post_run()\n    runner.summary()\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"./results/claude\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--max-thread-num\",\n        type=int,\n        default=3,\n        help=\"Maximum number of threads to use. 
The information seeking part and the article generation \"\n        \"part can be sped up by using multiple threads. Consider reducing it if you keep getting \"\n        '\"Exceed rate limit\" error when calling LM API.',\n    )\n    parser.add_argument(\n        \"--retriever\",\n        type=str,\n        choices=[\"bing\", \"you\", \"brave\", \"serper\", \"duckduckgo\", \"tavily\", \"searxng\"],\n        help=\"The search engine API to use for retrieving information.\",\n    )\n    # stage of the pipeline\n    parser.add_argument(\n        \"--do-research\",\n        action=\"store_true\",\n        help=\"If True, simulate conversation to research the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-outline\",\n        action=\"store_true\",\n        help=\"If True, generate an outline for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-article\",\n        action=\"store_true\",\n        help=\"If True, generate an article for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-polish-article\",\n        action=\"store_true\",\n        help=\"If True, polish the article by adding a summarization section and (optionally) removing \"\n        \"duplicate content.\",\n    )\n    # hyperparameters for the pre-writing stage\n    parser.add_argument(\n        \"--max-conv-turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of questions in conversational question asking.\",\n    )\n    parser.add_argument(\n        \"--max-perspective\",\n        type=int,\n        default=3,\n        help=\"Maximum number of perspectives to consider in perspective-guided question asking.\",\n    )\n    parser.add_argument(\n        \"--search-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k search results to consider for each search query.\",\n    )\n    # hyperparameters for the writing stage\n    
parser.add_argument(\n        \"--retrieve-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k collected references for each section title.\",\n    )\n    parser.add_argument(\n        \"--remove-duplicate\",\n        action=\"store_true\",\n        help=\"If True, remove duplicate content from the article.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "examples/storm_examples/run_storm_wiki_deepseek.py",
    "content": "\"\"\"\nSTORM Wiki pipeline powered by DeepSeek models and You.com or Bing search engine.\nYou need to set up the following environment variables to run this script:\n    - DEEPSEEK_API_KEY: DeepSeek API key\n    - DEEPSEEK_API_BASE: DeepSeek API base URL (default is https://api.deepseek.com)\n    - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key\n\nOutput will be structured as below\nargs.output_dir/\n    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash\n        conversation_log.json           # Log of information-seeking conversation\n        raw_search_results.json         # Raw search results from search engine\n        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge\n        storm_gen_outline.txt           # Outline refined with collected information\n        url_to_info.json                # Sources that are used in the final article\n        storm_gen_article.txt           # Final article generated\n        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)\n\"\"\"\n\nimport os\nimport re\nimport logging\nfrom argparse import ArgumentParser\n\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\nfrom knowledge_storm.lm import DeepSeekModel\nfrom knowledge_storm.rm import (\n    YouRM,\n    BingSearch,\n    BraveRM,\n    SerperRM,\n    DuckDuckGoSearchRM,\n    TavilySearchRM,\n    SearXNG,\n)\nfrom knowledge_storm.utils import load_api_key\n\n\ndef sanitize_topic(topic):\n    \"\"\"\n    Sanitize the topic name for use in file names.\n    Remove or replace characters that are not allowed in file names.\n    \"\"\"\n    # Replace spaces with underscores\n    topic = topic.replace(\" \", \"_\")\n\n    # Remove any character that 
isn't alphanumeric, underscore, or hyphen\n    topic = re.sub(r\"[^a-zA-Z0-9_-]\", \"\", topic)\n\n    # Ensure the topic isn't empty after sanitization\n    if not topic:\n        topic = \"unnamed_topic\"\n\n    return topic\n\n\ndef main(args):\n    load_api_key(toml_file_path=\"secrets.toml\")\n    lm_configs = STORMWikiLMConfigs()\n\n    logger = logging.getLogger(__name__)\n\n    # Ensure DEEPSEEK_API_KEY is set\n    if not os.getenv(\"DEEPSEEK_API_KEY\"):\n        raise ValueError(\n            \"DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file.\"\n        )\n\n    deepseek_kwargs = {\n        \"api_key\": os.getenv(\"DEEPSEEK_API_KEY\"),\n        \"api_base\": os.getenv(\"DEEPSEEK_API_BASE\", \"https://api.deepseek.com\"),\n        \"temperature\": args.temperature,\n        \"top_p\": args.top_p,\n    }\n\n    # DeepSeek offers two main models: 'deepseek-chat' for general tasks and 'deepseek-coder' for coding tasks\n    # Users can choose the appropriate model based on their needs\n    conv_simulator_lm = DeepSeekModel(\n        model=args.model, max_tokens=500, **deepseek_kwargs\n    )\n    question_asker_lm = DeepSeekModel(\n        model=args.model, max_tokens=500, **deepseek_kwargs\n    )\n    outline_gen_lm = DeepSeekModel(model=args.model, max_tokens=400, **deepseek_kwargs)\n    article_gen_lm = DeepSeekModel(model=args.model, max_tokens=700, **deepseek_kwargs)\n    article_polish_lm = DeepSeekModel(\n        model=args.model, max_tokens=4000, **deepseek_kwargs\n    )\n\n    lm_configs.set_conv_simulator_lm(conv_simulator_lm)\n    lm_configs.set_question_asker_lm(question_asker_lm)\n    lm_configs.set_outline_gen_lm(outline_gen_lm)\n    lm_configs.set_article_gen_lm(article_gen_lm)\n    lm_configs.set_article_polish_lm(article_polish_lm)\n\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=args.output_dir,\n        max_conv_turn=args.max_conv_turn,\n        max_perspective=args.max_perspective,\n 
       search_top_k=args.search_top_k,\n        max_thread_num=args.max_thread_num,\n    )\n\n    # STORM is a knowledge curation system which consumes information from the retrieval module.\n    # Currently, the information source is the Internet and we use search engine API as the retrieval module.\n    match args.retriever:\n        case \"bing\":\n            rm = BingSearch(\n                bing_search_api=os.getenv(\"BING_SEARCH_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"you\":\n            rm = YouRM(ydc_api_key=os.getenv(\"YDC_API_KEY\"), k=engine_args.search_top_k)\n        case \"brave\":\n            rm = BraveRM(\n                brave_search_api_key=os.getenv(\"BRAVE_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"duckduckgo\":\n            rm = DuckDuckGoSearchRM(\n                k=engine_args.search_top_k, safe_search=\"On\", region=\"us-en\"\n            )\n        case \"serper\":\n            rm = SerperRM(\n                serper_search_api_key=os.getenv(\"SERPER_API_KEY\"),\n                query_params={\"autocorrect\": True, \"num\": 10, \"page\": 1},\n            )\n        case \"tavily\":\n            rm = TavilySearchRM(\n                tavily_search_api_key=os.getenv(\"TAVILY_API_KEY\"),\n                k=engine_args.search_top_k,\n                include_raw_content=True,\n            )\n        case \"searxng\":\n            rm = SearXNG(\n                searxng_api_key=os.getenv(\"SEARXNG_API_KEY\"), k=engine_args.search_top_k\n            )\n        case _:\n            raise ValueError(\n                f'Invalid retriever: {args.retriever}. 
Choose either \"bing\", \"you\", \"brave\", \"duckduckgo\", \"serper\", \"tavily\", or \"searxng\"'\n            )\n\n    runner = STORMWikiRunner(engine_args, lm_configs, rm)\n\n    topic = input(\"Topic: \")\n    sanitized_topic = sanitize_topic(topic)\n\n    try:\n        runner.run(\n            topic=sanitized_topic,\n            do_research=args.do_research,\n            do_generate_outline=args.do_generate_outline,\n            do_generate_article=args.do_generate_article,\n            do_polish_article=args.do_polish_article,\n            remove_duplicate=args.remove_duplicate,\n        )\n        runner.post_run()\n        runner.summary()\n    except Exception as e:\n        logger.exception(f\"An error occurred: {str(e)}\")\n        raise\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"./results/deepseek\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--max-thread-num\",\n        type=int,\n        default=3,\n        help=\"Maximum number of threads to use. The information seeking part and the article generation \"\n        \"part can be sped up by using multiple threads. Consider reducing it if you keep getting \"\n        '\"Exceed rate limit\" error when calling LM API.',\n    )\n    parser.add_argument(\n        \"--retriever\",\n        type=str,\n        choices=[\"bing\", \"you\", \"brave\", \"serper\", \"duckduckgo\", \"tavily\", \"searxng\"],\n        help=\"The search engine API to use for retrieving information.\",\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        choices=[\"deepseek-chat\", \"deepseek-coder\"],\n        default=\"deepseek-chat\",\n        help='DeepSeek model to use. 
\"deepseek-chat\" for general tasks, \"deepseek-coder\" for coding tasks.',\n    )\n    parser.add_argument(\n        \"--temperature\", type=float, default=1.0, help=\"Sampling temperature to use.\"\n    )\n    parser.add_argument(\n        \"--top_p\", type=float, default=0.9, help=\"Top-p sampling parameter.\"\n    )\n    # stage of the pipeline\n    parser.add_argument(\n        \"--do-research\",\n        action=\"store_true\",\n        help=\"If True, simulate conversation to research the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-outline\",\n        action=\"store_true\",\n        help=\"If True, generate an outline for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-article\",\n        action=\"store_true\",\n        help=\"If True, generate an article for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-polish-article\",\n        action=\"store_true\",\n        help=\"If True, polish the article by adding a summarization section and (optionally) removing \"\n        \"duplicate content.\",\n    )\n    # hyperparameters for the pre-writing stage\n    parser.add_argument(\n        \"--max-conv-turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of questions in conversational question asking.\",\n    )\n    parser.add_argument(\n        \"--max-perspective\",\n        type=int,\n        default=3,\n        help=\"Maximum number of perspectives to consider in perspective-guided question asking.\",\n    )\n    parser.add_argument(\n        \"--search-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k search results to consider for each search query.\",\n    )\n    # hyperparameters for the writing stage\n    parser.add_argument(\n        \"--retrieve-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k collected references for each section title.\",\n   
 )\n    parser.add_argument(\n        \"--remove-duplicate\",\n        action=\"store_true\",\n        help=\"If True, remove duplicate content from the article.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "examples/storm_examples/run_storm_wiki_gemini.py",
    "content": "\"\"\"\nSTORM Wiki pipeline powered by Google Gemini models and search engine.\nYou need to set up the following environment variables to run this script:\n    - GOOGLE_API_KEY: Google API key (Can be obtained from https://ai.google.dev/gemini-api/docs/api-key)\n    - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key\n\nOutput will be structured as below\nargs.output_dir/\n    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash\n        conversation_log.json           # Log of information-seeking conversation\n        raw_search_results.json         # Raw search results from search engine\n        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge\n        storm_gen_outline.txt           # Outline refined with collected information\n        url_to_info.json                # Sources that are used in the final article\n        storm_gen_article.txt           # Final article generated\n        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)\n\"\"\"\n\nimport os\nfrom argparse import ArgumentParser\n\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\nfrom knowledge_storm.lm import GoogleModel\nfrom knowledge_storm.rm import (\n    YouRM,\n    BingSearch,\n    BraveRM,\n    SerperRM,\n    DuckDuckGoSearchRM,\n    TavilySearchRM,\n    SearXNG,\n)\nfrom knowledge_storm.utils import load_api_key\n\n\ndef main(args):\n    load_api_key(toml_file_path=\"secrets.toml\")\n    lm_configs = STORMWikiLMConfigs()\n    gemini_kwargs = {\n        \"api_key\": os.getenv(\"GOOGLE_API_KEY\"),\n        \"temperature\": 1.0,\n        \"top_p\": 0.9,\n    }\n\n    # STORM is a LM system so different components can be powered by different models.\n    # For a 
good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm\n    # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models\n    # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm\n    # which is responsible for generating sections with citations.\n    # To check out available Google models, see:\n    # https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python#list_models\n    conv_simulator_lm = GoogleModel(\n        model=\"models/gemini-1.5-flash\", max_tokens=500, **gemini_kwargs\n    )\n    question_asker_lm = GoogleModel(\n        model=\"models/gemini-1.5-flash\", max_tokens=500, **gemini_kwargs\n    )\n    outline_gen_lm = GoogleModel(\n        model=\"models/gemini-1.5-pro-exp-0801\", max_tokens=400, **gemini_kwargs\n    )\n    article_gen_lm = GoogleModel(\n        model=\"models/gemini-1.5-pro-exp-0801\", max_tokens=700, **gemini_kwargs\n    )\n    article_polish_lm = GoogleModel(\n        model=\"models/gemini-1.5-pro-exp-0801\", max_tokens=4000, **gemini_kwargs\n    )\n\n    lm_configs.set_conv_simulator_lm(conv_simulator_lm)\n    lm_configs.set_question_asker_lm(question_asker_lm)\n    lm_configs.set_outline_gen_lm(outline_gen_lm)\n    lm_configs.set_article_gen_lm(article_gen_lm)\n    lm_configs.set_article_polish_lm(article_polish_lm)\n\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=args.output_dir,\n        max_conv_turn=args.max_conv_turn,\n        max_perspective=args.max_perspective,\n        search_top_k=args.search_top_k,\n        max_thread_num=args.max_thread_num,\n    )\n\n    # STORM is a knowledge curation system which consumes information from the retrieval module.\n    # Currently, the information source is the Internet and we use search engine API as the retrieval module.\n    match args.retriever:\n        case \"bing\":\n            rm = BingSearch(\n               
 bing_search_api=os.getenv(\"BING_SEARCH_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"you\":\n            rm = YouRM(ydc_api_key=os.getenv(\"YDC_API_KEY\"), k=engine_args.search_top_k)\n        case \"brave\":\n            rm = BraveRM(\n                brave_search_api_key=os.getenv(\"BRAVE_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"duckduckgo\":\n            rm = DuckDuckGoSearchRM(\n                k=engine_args.search_top_k, safe_search=\"On\", region=\"us-en\"\n            )\n        case \"serper\":\n            rm = SerperRM(\n                serper_search_api_key=os.getenv(\"SERPER_API_KEY\"),\n                query_params={\"autocorrect\": True, \"num\": 10, \"page\": 1},\n            )\n        case \"tavily\":\n            rm = TavilySearchRM(\n                tavily_search_api_key=os.getenv(\"TAVILY_API_KEY\"),\n                k=engine_args.search_top_k,\n                include_raw_content=True,\n            )\n        case \"searxng\":\n            rm = SearXNG(\n                searxng_api_key=os.getenv(\"SEARXNG_API_KEY\"), k=engine_args.search_top_k\n            )\n        case _:\n            raise ValueError(\n                f'Invalid retriever: {args.retriever}. 
Choose one of \"bing\", \"you\", \"brave\", \"duckduckgo\", \"serper\", \"tavily\", or \"searxng\"'\n            )\n\n    runner = STORMWikiRunner(engine_args, lm_configs, rm)\n\n    topic = input(\"Topic: \")\n    runner.run(\n        topic=topic,\n        do_research=args.do_research,\n        do_generate_outline=args.do_generate_outline,\n        do_generate_article=args.do_generate_article,\n        do_polish_article=args.do_polish_article,\n    )\n    runner.post_run()\n    runner.summary()\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"./results/gemini\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--max-thread-num\",\n        type=int,\n        default=3,\n        help=\"Maximum number of threads to use. The information seeking part and the article generation \"\n        \"part can be sped up by using multiple threads. 
Consider reducing it if keep getting \"\n        '\"Exceed rate limit\" error when calling LM API.',\n    )\n    parser.add_argument(\n        \"--retriever\",\n        type=str,\n        choices=[\"bing\", \"you\", \"brave\", \"serper\", \"duckduckgo\", \"tavily\", \"searxng\"],\n        help=\"The search engine API to use for retrieving information.\",\n    )\n    # stage of the pipeline\n    parser.add_argument(\n        \"--do-research\",\n        action=\"store_true\",\n        help=\"If True, simulate conversation to research the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-outline\",\n        action=\"store_true\",\n        help=\"If True, generate an outline for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-article\",\n        action=\"store_true\",\n        help=\"If True, generate an article for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-polish-article\",\n        action=\"store_true\",\n        help=\"If True, polish the article by adding a summarization section and (optionally) removing \"\n        \"duplicate content.\",\n    )\n    # hyperparameters for the pre-writing stage\n    parser.add_argument(\n        \"--max-conv-turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of questions in conversational question asking.\",\n    )\n    parser.add_argument(\n        \"--max-perspective\",\n        type=int,\n        default=3,\n        help=\"Maximum number of perspectives to consider in perspective-guided question asking.\",\n    )\n    parser.add_argument(\n        \"--search-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k search results to consider for each search query.\",\n    )\n    # hyperparameters for the writing stage\n    parser.add_argument(\n        \"--retrieve-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k collected 
references for each section title.\",\n    )\n    parser.add_argument(\n        \"--remove-duplicate\",\n        action=\"store_true\",\n        help=\"If True, remove duplicate content from the article.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "examples/storm_examples/run_storm_wiki_gpt.py",
    "content": "\"\"\"\nSTORM Wiki pipeline powered by GPT-3.5/4 and You.com search engine.\nYou need to set up the following environment variables to run this script:\n    - OPENAI_API_KEY: OpenAI API key\n    - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure')\n    - AZURE_API_BASE: Azure API base URL if using Azure API\n    - AZURE_API_VERSION: Azure API version if using Azure API\n    - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key\n\nOutput will be structured as below\nargs.output_dir/\n    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash\n        conversation_log.json           # Log of information-seeking conversation\n        raw_search_results.json         # Raw search results from search engine\n        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge\n        storm_gen_outline.txt           # Outline refined with collected information\n        url_to_info.json                # Sources that are used in the final article\n        storm_gen_article.txt           # Final article generated\n        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)\n\"\"\"\n\nimport os\n\nfrom argparse import ArgumentParser\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\nfrom knowledge_storm.lm import OpenAIModel, AzureOpenAIModel\nfrom knowledge_storm.rm import (\n    YouRM,\n    BingSearch,\n    BraveRM,\n    SerperRM,\n    DuckDuckGoSearchRM,\n    TavilySearchRM,\n    SearXNG,\n    AzureAISearch,\n)\nfrom knowledge_storm.utils import load_api_key\n\n\ndef main(args):\n    load_api_key(toml_file_path=\"secrets.toml\")\n    lm_configs = STORMWikiLMConfigs()\n    openai_kwargs = {\n        \"api_key\": os.getenv(\"OPENAI_API_KEY\"),\n        
\"temperature\": 1.0,\n        \"top_p\": 0.9,\n    }\n\n    ModelClass = (\n        OpenAIModel if os.getenv(\"OPENAI_API_TYPE\") == \"openai\" else AzureOpenAIModel\n    )\n    # If you are using Azure service, make sure the model name matches your own deployed model name.\n    # The default name here is only used for demonstration and may not match your case.\n    gpt_35_model_name = (\n        \"gpt-3.5-turbo\" if os.getenv(\"OPENAI_API_TYPE\") == \"openai\" else \"gpt-35-turbo\"\n    )\n    gpt_4_model_name = \"gpt-4o\"\n    if os.getenv(\"OPENAI_API_TYPE\") == \"azure\":\n        openai_kwargs[\"api_base\"] = os.getenv(\"AZURE_API_BASE\")\n        openai_kwargs[\"api_version\"] = os.getenv(\"AZURE_API_VERSION\")\n\n    # STORM is a LM system so different components can be powered by different models.\n    # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm\n    # which is used to split queries, synthesize answers in the conversation. 
We recommend using stronger models\n    # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm\n    # which is responsible for generating sections with citations.\n    conv_simulator_lm = ModelClass(\n        model=gpt_35_model_name, max_tokens=500, **openai_kwargs\n    )\n    question_asker_lm = ModelClass(\n        model=gpt_35_model_name, max_tokens=500, **openai_kwargs\n    )\n    outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs)\n    article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs)\n    article_polish_lm = ModelClass(\n        model=gpt_4_model_name, max_tokens=4000, **openai_kwargs\n    )\n\n    lm_configs.set_conv_simulator_lm(conv_simulator_lm)\n    lm_configs.set_question_asker_lm(question_asker_lm)\n    lm_configs.set_outline_gen_lm(outline_gen_lm)\n    lm_configs.set_article_gen_lm(article_gen_lm)\n    lm_configs.set_article_polish_lm(article_polish_lm)\n\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=args.output_dir,\n        max_conv_turn=args.max_conv_turn,\n        max_perspective=args.max_perspective,\n        search_top_k=args.search_top_k,\n        max_thread_num=args.max_thread_num,\n    )\n\n    # STORM is a knowledge curation system which consumes information from the retrieval module.\n    # Currently, the information source is the Internet and we use search engine API as the retrieval module.\n\n    match args.retriever:\n        case \"bing\":\n            rm = BingSearch(\n                bing_search_api=os.getenv(\"BING_SEARCH_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"you\":\n            rm = YouRM(ydc_api_key=os.getenv(\"YDC_API_KEY\"), k=engine_args.search_top_k)\n        case \"brave\":\n            rm = BraveRM(\n                brave_search_api_key=os.getenv(\"BRAVE_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case 
\"duckduckgo\":\n            rm = DuckDuckGoSearchRM(\n                k=engine_args.search_top_k, safe_search=\"On\", region=\"us-en\"\n            )\n        case \"serper\":\n            rm = SerperRM(\n                serper_search_api_key=os.getenv(\"SERPER_API_KEY\"),\n                query_params={\"autocorrect\": True, \"num\": 10, \"page\": 1},\n            )\n        case \"tavily\":\n            rm = TavilySearchRM(\n                tavily_search_api_key=os.getenv(\"TAVILY_API_KEY\"),\n                k=engine_args.search_top_k,\n                include_raw_content=True,\n            )\n        case \"searxng\":\n            rm = SearXNG(\n                searxng_api_key=os.getenv(\"SEARXNG_API_KEY\"), k=engine_args.search_top_k\n            )\n        case \"azure_ai_search\":\n            rm = AzureAISearch(\n                azure_ai_search_api_key=os.getenv(\"AZURE_AI_SEARCH_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case _:\n            raise ValueError(\n                f'Invalid retriever: {args.retriever}. 
Choose one of \"bing\", \"you\", \"brave\", \"duckduckgo\", \"serper\", \"tavily\", \"searxng\", or \"azure_ai_search\"'\n            )\n\n    runner = STORMWikiRunner(engine_args, lm_configs, rm)\n\n    topic = input(\"Topic: \")\n    runner.run(\n        topic=topic,\n        do_research=args.do_research,\n        do_generate_outline=args.do_generate_outline,\n        do_generate_article=args.do_generate_article,\n        do_polish_article=args.do_polish_article,\n    )\n    runner.post_run()\n    runner.summary()\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"./results/gpt\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--max-thread-num\",\n        type=int,\n        default=3,\n        help=\"Maximum number of threads to use. The information seeking part and the article generation \"\n        \"part can be sped up by using multiple threads. 
Consider reducing it if keep getting \"\n        '\"Exceed rate limit\" error when calling LM API.',\n    )\n    parser.add_argument(\n        \"--retriever\",\n        type=str,\n        choices=[\n            \"bing\",\n            \"you\",\n            \"brave\",\n            \"serper\",\n            \"duckduckgo\",\n            \"tavily\",\n            \"searxng\",\n            \"azure_ai_search\",\n        ],\n        help=\"The search engine API to use for retrieving information.\",\n    )\n    # stage of the pipeline\n    parser.add_argument(\n        \"--do-research\",\n        action=\"store_true\",\n        help=\"If True, simulate conversation to research the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-outline\",\n        action=\"store_true\",\n        help=\"If True, generate an outline for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-article\",\n        action=\"store_true\",\n        help=\"If True, generate an article for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-polish-article\",\n        action=\"store_true\",\n        help=\"If True, polish the article by adding a summarization section and (optionally) removing \"\n        \"duplicate content.\",\n    )\n    # hyperparameters for the pre-writing stage\n    parser.add_argument(\n        \"--max-conv-turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of questions in conversational question asking.\",\n    )\n    parser.add_argument(\n        \"--max-perspective\",\n        type=int,\n        default=3,\n        help=\"Maximum number of perspectives to consider in perspective-guided question asking.\",\n    )\n    parser.add_argument(\n        \"--search-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k search results to consider for each search query.\",\n    )\n    # hyperparameters for the 
writing stage\n    parser.add_argument(\n        \"--retrieve-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k collected references for each section title.\",\n    )\n    parser.add_argument(\n        \"--remove-duplicate\",\n        action=\"store_true\",\n        help=\"If True, remove duplicate content from the article.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py",
    "content": "\"\"\"\nThis STORM Wiki pipeline powered by GPT-3.5/4 and local retrieval model that uses Qdrant.\nYou need to set up the following environment variables to run this script:\n    - OPENAI_API_KEY: OpenAI API key\n    - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure')\n    - QDRANT_API_KEY: Qdrant API key (needed ONLY if online vector store was used)\n\nYou will also need an existing Qdrant vector store either saved in a folder locally offline or in a server online.\nIf not, then you would need a CSV file with documents, and the script is going to create the vector store for you.\nThe CSV should be in the following format:\ncontent  | title  |  url  |  description\nI am a document. | Document 1 | docu-n-112 | A self-explanatory document.\nI am another document. | Document 2 | docu-l-13 | Another self-explanatory document.\n\nNotice that the URL will be a unique identifier for the document so ensure different documents have different urls.\n\nOutput will be structured as below\nargs.output_dir/\n    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash\n        conversation_log.json           # Log of information-seeking conversation\n        raw_search_results.json         # Raw search results from search engine\n        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge\n        storm_gen_outline.txt           # Outline refined with collected information\n        url_to_info.json                # Sources that are used in the final article\n        storm_gen_article.txt           # Final article generated\n        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)\n\"\"\"\n\nimport os\nfrom argparse import ArgumentParser\n\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\nfrom knowledge_storm.rm import VectorRM\nfrom knowledge_storm.lm import 
OpenAIModel, AzureOpenAIModel\nfrom knowledge_storm.utils import load_api_key, QdrantVectorStoreManager\n\n\ndef main(args):\n    # Load API key from the specified toml file path\n    load_api_key(toml_file_path=\"secrets.toml\")\n\n    # Initialize the language model configurations\n    engine_lm_configs = STORMWikiLMConfigs()\n    openai_kwargs = {\n        \"api_key\": os.getenv(\"OPENAI_API_KEY\"),\n        \"temperature\": 1.0,\n        \"top_p\": 0.9,\n    }\n\n    ModelClass = (\n        OpenAIModel if os.getenv(\"OPENAI_API_TYPE\") == \"openai\" else AzureOpenAIModel\n    )\n    # If you are using Azure service, make sure the model name matches your own deployed model name.\n    # The default name here is only used for demonstration and may not match your case.\n    gpt_35_model_name = (\n        \"gpt-3.5-turbo\" if os.getenv(\"OPENAI_API_TYPE\") == \"openai\" else \"gpt-35-turbo\"\n    )\n    gpt_4_model_name = \"gpt-4o\"\n    if os.getenv(\"OPENAI_API_TYPE\") == \"azure\":\n        openai_kwargs[\"api_base\"] = os.getenv(\"AZURE_API_BASE\")\n        openai_kwargs[\"api_version\"] = os.getenv(\"AZURE_API_VERSION\")\n\n    # STORM is a LM system so different components can be powered by different models.\n    # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm\n    # which is used to split queries, synthesize answers in the conversation. 
We recommend using stronger models\n    # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm\n    # which is responsible for generating sections with citations.\n    conv_simulator_lm = ModelClass(\n        model=gpt_35_model_name, max_tokens=500, **openai_kwargs\n    )\n    question_asker_lm = ModelClass(\n        model=gpt_35_model_name, max_tokens=500, **openai_kwargs\n    )\n    outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs)\n    article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs)\n    article_polish_lm = ModelClass(\n        model=gpt_4_model_name, max_tokens=4000, **openai_kwargs\n    )\n\n    engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm)\n    engine_lm_configs.set_question_asker_lm(question_asker_lm)\n    engine_lm_configs.set_outline_gen_lm(outline_gen_lm)\n    engine_lm_configs.set_article_gen_lm(article_gen_lm)\n    engine_lm_configs.set_article_polish_lm(article_polish_lm)\n\n    # Initialize the engine arguments\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=args.output_dir,\n        max_conv_turn=args.max_conv_turn,\n        max_perspective=args.max_perspective,\n        search_top_k=args.search_top_k,\n        max_thread_num=args.max_thread_num,\n    )\n\n    # Create / update the vector store with the documents in the csv file\n    if args.csv_file_path:\n        kwargs = {\n            \"file_path\": args.csv_file_path,\n            \"content_column\": \"content\",\n            \"title_column\": \"title\",\n            \"url_column\": \"url\",\n            \"desc_column\": \"description\",\n            \"batch_size\": args.embed_batch_size,\n            \"vector_db_mode\": args.vector_db_mode,\n            \"collection_name\": args.collection_name,\n            \"embedding_model\": args.embedding_model,\n            \"device\": args.device,\n        }\n        if args.vector_db_mode == \"offline\":\n 
           QdrantVectorStoreManager.create_or_update_vector_store(\n                vector_store_path=args.offline_vector_db_dir, **kwargs\n            )\n        elif args.vector_db_mode == \"online\":\n            QdrantVectorStoreManager.create_or_update_vector_store(\n                url=args.online_vector_db_url,\n                api_key=os.getenv(\"QDRANT_API_KEY\"),\n                **kwargs\n            )\n\n    # Setup VectorRM to retrieve information from your own data\n    rm = VectorRM(\n        collection_name=args.collection_name,\n        embedding_model=args.embedding_model,\n        device=args.device,\n        k=engine_args.search_top_k,\n    )\n\n    # initialize the vector store, either online (store the db on Qdrant server) or offline (store the db locally):\n    if args.vector_db_mode == \"offline\":\n        rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir)\n    elif args.vector_db_mode == \"online\":\n        rm.init_online_vector_db(\n            url=args.online_vector_db_url, api_key=os.getenv(\"QDRANT_API_KEY\")\n        )\n\n    # Initialize the STORM Wiki Runner\n    runner = STORMWikiRunner(engine_args, engine_lm_configs, rm)\n\n    # run the pipeline\n    topic = input(\"Topic: \")\n    runner.run(\n        topic=topic,\n        do_research=args.do_research,\n        do_generate_outline=args.do_generate_outline,\n        do_generate_article=args.do_generate_article,\n        do_polish_article=args.do_polish_article,\n    )\n    runner.post_run()\n    runner.summary()\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"./results/gpt_retrieval\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--max-thread-num\",\n        type=int,\n        default=3,\n        help=\"Maximum number of threads to use. 
The information seeking part and the article generation \"\n        \"part can be sped up by using multiple threads. Consider reducing it if you keep getting \"\n        '\"Exceed rate limit\" error when calling LM API.',\n    )\n    # provide local corpus and set up vector db\n    parser.add_argument(\n        \"--collection-name\",\n        type=str,\n        default=\"my_documents\",\n        help=\"The collection name for vector store.\",\n    )\n    parser.add_argument(\n        \"--embedding_model\",\n        type=str,\n        default=\"BAAI/bge-m3\",\n        help=\"The embedding model used to build and query the vector store.\",\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"mps\",\n        help=\"The device used to run the retrieval model (mps, cuda, cpu, etc.).\",\n    )\n    parser.add_argument(\n        \"--vector-db-mode\",\n        type=str,\n        choices=[\"offline\", \"online\"],\n        help=\"The mode of the Qdrant vector store (offline or online).\",\n    )\n    parser.add_argument(\n        \"--offline-vector-db-dir\",\n        type=str,\n        default=\"./vector_store\",\n        help=\"If using offline mode, provide the directory in which to store the vector store.\",\n    )\n    parser.add_argument(\n        \"--online-vector-db-url\",\n        type=str,\n        help=\"If using online mode, provide the URL of the Qdrant server.\",\n    )\n    parser.add_argument(\n        \"--csv-file-path\",\n        type=str,\n        default=None,\n        help=\"The path of the custom document corpus in CSV format. 
The CSV file should include \"\n        \"content, title, url, and description columns.\",\n    )\n    parser.add_argument(\n        \"--embed-batch-size\",\n        type=int,\n        default=64,\n        help=\"Batch size for embedding the documents in the csv file.\",\n    )\n    # stage of the pipeline\n    parser.add_argument(\n        \"--do-research\",\n        action=\"store_true\",\n        help=\"If True, simulate conversation to research the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-outline\",\n        action=\"store_true\",\n        help=\"If True, generate an outline for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-article\",\n        action=\"store_true\",\n        help=\"If True, generate an article for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-polish-article\",\n        action=\"store_true\",\n        help=\"If True, polish the article by adding a summarization section and (optionally) removing \"\n        \"duplicate content.\",\n    )\n    # hyperparameters for the pre-writing stage\n    parser.add_argument(\n        \"--max-conv-turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of questions in conversational question asking.\",\n    )\n    parser.add_argument(\n        \"--max-perspective\",\n        type=int,\n        default=3,\n        help=\"Maximum number of perspectives to consider in perspective-guided question asking.\",\n    )\n    parser.add_argument(\n        \"--search-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k search results to consider for each search query.\",\n    )\n    # hyperparameters for the writing stage\n    parser.add_argument(\n        \"--retrieve-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k collected references for each section title.\",\n    )\n    parser.add_argument(\n        
\"--remove-duplicate\",\n        action=\"store_true\",\n        help=\"If True, remove duplicate content from the article.\",\n    )\n    main(parser.parse_args())\n"
  },
  {
    "path": "examples/storm_examples/run_storm_wiki_groq.py",
    "content": "\"\"\"\nSTORM Wiki pipeline powered by llama3-70b-8192 hosted by Groq server and You.com search engine.\nYou need to set up the following environment variables to run this script:\n    - GROQ_API_KEY: You can get your Groq API Key at https://console.groq.com/keys\n    - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key\nYou also need to have a VLLM server running with the Mistral-7B-Instruct-v0.2 model. Specify `--url` and `--port` accordingly.\n\nOutput will be structured as below\nargs.output_dir/\n    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash\n        conversation_log.json           # Log of information-seeking conversation\n        raw_search_results.json         # Raw search results from search engine\n        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge\n        storm_gen_outline.txt           # Outline refined with collected information\n        url_to_info.json                # Sources that are used in the final article\n        storm_gen_article.txt           # Final article generated\n        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)\n\"\"\"\n\nimport os\nimport re\nfrom argparse import ArgumentParser\n\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\n\n# Now import lm directly\nimport lm\nfrom lm import GroqModel\nfrom knowledge_storm.rm import (\n    YouRM,\n    BingSearch,\n    BraveRM,\n    SerperRM,\n    DuckDuckGoSearchRM,\n    TavilySearchRM,\n    SearXNG,\n)\nfrom knowledge_storm.utils import load_api_key\n\n\ndef sanitize_topic(topic):\n    \"\"\"\n    Sanitize the topic name for use in file names.\n    Remove or replace characters that are not allowed in file names.\n    \"\"\"\n    # Replace 
spaces with underscores\n    topic = topic.replace(\" \", \"_\")\n\n    # Remove any character that isn't alphanumeric, underscore, or hyphen\n    topic = re.sub(r\"[^a-zA-Z0-9_-]\", \"\", topic)\n\n    # Ensure the topic isn't empty after sanitization\n    if not topic:\n        topic = \"unnamed_topic\"\n\n    return topic\n\n\ndef main(args):\n    load_api_key(toml_file_path=\"secrets.toml\")\n    lm_configs = STORMWikiLMConfigs()\n\n    # Ensure GROQ_API_KEY is set\n    if not os.getenv(\"GROQ_API_KEY\"):\n        raise ValueError(\n            \"GROQ_API_KEY environment variable is not set. Please set it in your secrets.toml file.\"\n        )\n\n    groq_kwargs = {\n        \"api_key\": os.getenv(\"GROQ_API_KEY\"),\n        \"api_base\": \"https://api.groq.com/openai/v1\",\n        \"temperature\": args.temperature,\n        \"top_p\": args.top_p,\n    }\n\n    # Groq currently offers the \"llama3-70b-8192\" model with generous free API credits and the llama3.1 family of models as a preview for paying customers\n    conv_simulator_lm = GroqModel(\n        model=\"llama3-70b-8192\", max_tokens=500, **groq_kwargs\n    )\n    question_asker_lm = GroqModel(\n        model=\"llama3-70b-8192\", max_tokens=500, **groq_kwargs\n    )\n    outline_gen_lm = GroqModel(model=\"llama3-70b-8192\", max_tokens=400, **groq_kwargs)\n    article_gen_lm = GroqModel(model=\"llama3-70b-8192\", max_tokens=700, **groq_kwargs)\n    article_polish_lm = GroqModel(\n        model=\"llama3-70b-8192\", max_tokens=4000, **groq_kwargs\n    )\n\n    lm_configs.set_conv_simulator_lm(conv_simulator_lm)\n    lm_configs.set_question_asker_lm(question_asker_lm)\n    lm_configs.set_outline_gen_lm(outline_gen_lm)\n    lm_configs.set_article_gen_lm(article_gen_lm)\n    lm_configs.set_article_polish_lm(article_polish_lm)\n\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=args.output_dir,\n        max_conv_turn=args.max_conv_turn,\n        max_perspective=args.max_perspective,\n        
search_top_k=args.search_top_k,\n        max_thread_num=args.max_thread_num,\n    )\n\n    # STORM is a knowledge curation system which consumes information from the retrieval module.\n    # Currently, the information source is the Internet and we use search engine API as the retrieval module.\n    match args.retriever:\n        case \"bing\":\n            rm = BingSearch(\n                bing_search_api=os.getenv(\"BING_SEARCH_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"you\":\n            rm = YouRM(ydc_api_key=os.getenv(\"YDC_API_KEY\"), k=engine_args.search_top_k)\n        case \"brave\":\n            rm = BraveRM(\n                brave_search_api_key=os.getenv(\"BRAVE_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"duckduckgo\":\n            rm = DuckDuckGoSearchRM(\n                k=engine_args.search_top_k, safe_search=\"On\", region=\"us-en\"\n            )\n        case \"serper\":\n            rm = SerperRM(\n                serper_search_api_key=os.getenv(\"SERPER_API_KEY\"),\n                query_params={\"autocorrect\": True, \"num\": 10, \"page\": 1},\n            )\n        case \"tavily\":\n            rm = TavilySearchRM(\n                tavily_search_api_key=os.getenv(\"TAVILY_API_KEY\"),\n                k=engine_args.search_top_k,\n                include_raw_content=True,\n            )\n        case \"searxng\":\n            rm = SearXNG(\n                searxng_api_key=os.getenv(\"SEARXNG_API_KEY\"), k=engine_args.search_top_k\n            )\n        case _:\n            raise ValueError(\n                f'Invalid retriever: {args.retriever}. 
Choose one of \"bing\", \"you\", \"brave\", \"duckduckgo\", \"serper\", \"tavily\", or \"searxng\"'\n            )\n\n    runner = STORMWikiRunner(engine_args, lm_configs, rm)\n\n    topic = input(\"Topic: \")\n    sanitized_topic = sanitize_topic(topic)\n\n    try:\n        runner.run(\n            topic=sanitized_topic,\n            do_research=args.do_research,\n            do_generate_outline=args.do_generate_outline,\n            do_generate_article=args.do_generate_article,\n            do_polish_article=args.do_polish_article,\n            remove_duplicate=args.remove_duplicate,\n        )\n        runner.post_run()\n        runner.summary()\n    except Exception as e:\n        logger.exception(f\"An error occurred: {str(e)}\")\n        raise\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"./results/groq\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--max-thread-num\",\n        type=int,\n        default=3,\n        help=\"Maximum number of threads to use. The information seeking part and the article generation \"\n        \"part can speed up by using multiple threads. 
Consider reducing it if you keep getting \"\n        '\"Exceed rate limit\" error when calling LM API.',\n    )\n    parser.add_argument(\n        \"--retriever\",\n        type=str,\n        choices=[\"bing\", \"you\", \"brave\", \"serper\", \"duckduckgo\", \"tavily\", \"searxng\"],\n        help=\"The search engine API to use for retrieving information.\",\n    )\n    parser.add_argument(\n        \"--temperature\", type=float, default=1.0, help=\"Sampling temperature to use.\"\n    )\n    parser.add_argument(\n        \"--top_p\", type=float, default=0.9, help=\"Top-p sampling parameter.\"\n    )\n    # stage of the pipeline\n    parser.add_argument(\n        \"--do-research\",\n        action=\"store_true\",\n        help=\"If True, simulate conversation to research the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-outline\",\n        action=\"store_true\",\n        help=\"If True, generate an outline for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-article\",\n        action=\"store_true\",\n        help=\"If True, generate an article for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-polish-article\",\n        action=\"store_true\",\n        help=\"If True, polish the article by adding a summarization section and (optionally) removing \"\n        \"duplicate content.\",\n    )\n    # hyperparameters for the pre-writing stage\n    parser.add_argument(\n        \"--max-conv-turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of questions in conversational question asking.\",\n    )\n    parser.add_argument(\n        \"--max-perspective\",\n        type=int,\n        default=3,\n        help=\"Maximum number of perspectives to consider in perspective-guided question asking.\",\n    )\n    parser.add_argument(\n        \"--search-top-k\",\n        type=int,\n        default=3,\n        
help=\"Top k search results to consider for each search query.\",\n    )\n    # hyperparameters for the writing stage\n    parser.add_argument(\n        \"--retrieve-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k collected references for each section title.\",\n    )\n    parser.add_argument(\n        \"--remove-duplicate\",\n        action=\"store_true\",\n        help=\"If True, remove duplicate content from the article.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "examples/storm_examples/run_storm_wiki_mistral.py",
    "content": "\"\"\"\nSTORM Wiki pipeline powered by Mistral-7B-Instruct-v0.2 hosted by VLLM server and You.com search engine.\nYou need to set up the following environment variables to run this script:\n    - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key\nYou also need to have a VLLM server running with the Mistral-7B-Instruct-v0.2 model. Specify `--url` and `--port` accordingly.\n\nOutput will be structured as below\nargs.output_dir/\n    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash\n        conversation_log.json           # Log of information-seeking conversation\n        raw_search_results.json         # Raw search results from search engine\n        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge\n        storm_gen_outline.txt           # Outline refined with collected information\n        url_to_info.json                # Sources that are used in the final article\n        storm_gen_article.txt           # Final article generated\n        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)\n\"\"\"\n\nimport os\nfrom argparse import ArgumentParser\n\nfrom dspy import Example\n\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\nfrom knowledge_storm.lm import VLLMClient\nfrom knowledge_storm.rm import (\n    YouRM,\n    BingSearch,\n    BraveRM,\n    SerperRM,\n    DuckDuckGoSearchRM,\n    TavilySearchRM,\n    SearXNG,\n)\nfrom knowledge_storm.utils import load_api_key\n\n\ndef main(args):\n    load_api_key(toml_file_path=\"secrets.toml\")\n    lm_configs = STORMWikiLMConfigs()\n\n    mistral_kwargs = {\n        \"model\": \"mistralai/Mistral-7B-Instruct-v0.2\",\n        \"port\": args.port,\n        \"url\": args.url,\n        \"stop\": 
(\n            \"\\n\\n---\",\n        ),  # dspy uses \"\\n\\n---\" to separate examples. Open models sometimes generate this.\n    }\n\n    conv_simulator_lm = VLLMClient(max_tokens=500, **mistral_kwargs)\n    question_asker_lm = VLLMClient(max_tokens=500, **mistral_kwargs)\n    outline_gen_lm = VLLMClient(max_tokens=400, **mistral_kwargs)\n    article_gen_lm = VLLMClient(max_tokens=700, **mistral_kwargs)\n    article_polish_lm = VLLMClient(max_tokens=4000, **mistral_kwargs)\n\n    lm_configs.set_conv_simulator_lm(conv_simulator_lm)\n    lm_configs.set_question_asker_lm(question_asker_lm)\n    lm_configs.set_outline_gen_lm(outline_gen_lm)\n    lm_configs.set_article_gen_lm(article_gen_lm)\n    lm_configs.set_article_polish_lm(article_polish_lm)\n\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=args.output_dir,\n        max_conv_turn=args.max_conv_turn,\n        max_perspective=args.max_perspective,\n        search_top_k=args.search_top_k,\n        max_thread_num=args.max_thread_num,\n    )\n\n    # STORM is a knowledge curation system which consumes information from the retrieval module.\n    # Currently, the information source is the Internet and we use search engine API as the retrieval module.\n    match args.retriever:\n        case \"bing\":\n            rm = BingSearch(\n                bing_search_api=os.getenv(\"BING_SEARCH_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"you\":\n            rm = YouRM(ydc_api_key=os.getenv(\"YDC_API_KEY\"), k=engine_args.search_top_k)\n        case \"brave\":\n            rm = BraveRM(\n                brave_search_api_key=os.getenv(\"BRAVE_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"duckduckgo\":\n            rm = DuckDuckGoSearchRM(\n                k=engine_args.search_top_k, safe_search=\"On\", region=\"us-en\"\n            )\n        case \"serper\":\n            rm = SerperRM(\n                
serper_search_api_key=os.getenv(\"SERPER_API_KEY\"),\n                query_params={\"autocorrect\": True, \"num\": 10, \"page\": 1},\n            )\n        case \"tavily\":\n            rm = TavilySearchRM(\n                tavily_search_api_key=os.getenv(\"TAVILY_API_KEY\"),\n                k=engine_args.search_top_k,\n                include_raw_content=True,\n            )\n        case \"searxng\":\n            rm = SearXNG(\n                searxng_api_key=os.getenv(\"SEARXNG_API_KEY\"), k=engine_args.search_top_k\n            )\n        case _:\n            raise ValueError(\n                f'Invalid retriever: {args.retriever}. Choose either \"bing\", \"you\", \"brave\", \"duckduckgo\", \"serper\", \"tavily\", or \"searxng\"'\n            )\n\n    runner = STORMWikiRunner(engine_args, lm_configs, rm)\n\n    # Open LMs are generally weaker in following output format.\n    # One way for mitigation is to add one-shot example to the prompt to exemplify the desired output format.\n    # For example, we can add the following examples to the two prompts used in StormPersonaGenerator.\n    # Note that the example should be an object of dspy.Example with fields matching the InputField\n    # and OutputField in the prompt (i.e., dspy.Signature).\n    find_related_topic_example = Example(\n        topic=\"Knowledge Curation\",\n        related_topics=\"https://en.wikipedia.org/wiki/Knowledge_management\\n\"\n        \"https://en.wikipedia.org/wiki/Information_science\\n\"\n        \"https://en.wikipedia.org/wiki/Library_science\\n\",\n    )\n    gen_persona_example = Example(\n        topic=\"Knowledge Curation\",\n        examples=\"Title: Knowledge management\\n\"\n        \"Table of Contents: History\\nResearch\\n  Dimensions\\n  Strategies\\n  Motivations\\nKM technologies\"\n        \"\\nKnowledge barriers\\nKnowledge retention\\nKnowledge audit\\nKnowledge protection\\n\"\n        \"  Knowledge protection methods\\n    Formal methods\\n    Informal 
methods\\n\"\n        \"  Balancing knowledge protection and knowledge sharing\\n  Knowledge protection risks\",\n        personas=\"1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\\n\"\n        \"2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\\n\"\n        \"3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\\n\"\n        \"4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\\n\"\n        \"5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.\",\n    )\n    runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [\n        find_related_topic_example\n    ]\n    runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [\n        gen_persona_example\n    ]\n\n    # A trade-off of adding one-shot example is that it will increase the input length of the prompt. Also, some\n    # examples may be very long (e.g., an example for writing a section based on the given information), which may\n    # confuse the model. 
For these cases, you can create a pseudo-example that is short and easy to understand to steer\n    # the model's output format.\n    # For example, we can add the following pseudo-examples to the prompt used in WritePageOutlineFromConv and\n    # ConvToSection.\n    write_page_outline_example = Example(\n        topic=\"Example Topic\",\n        conv=\"Wikipedia Writer: ...\\nExpert: ...\\nWikipedia Writer: ...\\nExpert: ...\",\n        old_outline=\"# Section 1\\n## Subsection 1\\n## Subsection 2\\n\"\n        \"# Section 2\\n## Subsection 1\\n## Subsection 2\\n\"\n        \"# Section 3\",\n        outline=\"# New Section 1\\n## New Subsection 1\\n## New Subsection 2\\n\"\n        \"# New Section 2\\n\"\n        \"# New Section 3\\n## New Subsection 1\\n## New Subsection 2\\n## New Subsection 3\",\n    )\n    runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [\n        write_page_outline_example\n    ]\n    write_section_example = Example(\n        info=\"[1]\\nInformation in document 1\\n[2]\\nInformation in document 2\\n[3]\\nInformation in document 3\",\n        topic=\"Example Topic\",\n        section=\"Example Section\",\n        output=\"# Example Topic\\n## Subsection 1\\n\"\n        \"This is an example sentence [1]. 
This is another example sentence [2][3].\\n\"\n        \"## Subsection 2\\nThis is one more example sentence [1].\",\n    )\n    runner.storm_article_generation.section_gen.write_section.demos = [\n        write_section_example\n    ]\n\n    topic = input(\"Topic: \")\n    runner.run(\n        topic=topic,\n        do_research=args.do_research,\n        do_generate_outline=args.do_generate_outline,\n        do_generate_article=args.do_generate_article,\n        do_polish_article=args.do_polish_article,\n        remove_duplicate=args.remove_duplicate,\n    )\n    runner.post_run()\n    runner.summary()\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--url\", type=str, default=\"http://localhost\", help=\"URL of the VLLM server.\"\n    )\n    parser.add_argument(\n        \"--port\", type=int, default=8000, help=\"Port of the VLLM server.\"\n    )\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"./results/mistral_7b\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--max-thread-num\",\n        type=int,\n        default=3,\n        help=\"Maximum number of threads to use. The information seeking part and the article generation \"\n        \"part can speed up by using multiple threads. 
Consider reducing it if you keep getting \"\n        '\"Exceed rate limit\" error when calling LM API.',\n    )\n    parser.add_argument(\n        \"--retriever\",\n        type=str,\n        choices=[\"bing\", \"you\", \"brave\", \"serper\", \"duckduckgo\", \"tavily\", \"searxng\"],\n        help=\"The search engine API to use for retrieving information.\",\n    )\n    # stage of the pipeline\n    parser.add_argument(\n        \"--do-research\",\n        action=\"store_true\",\n        help=\"If True, simulate conversation to research the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-outline\",\n        action=\"store_true\",\n        help=\"If True, generate an outline for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-article\",\n        action=\"store_true\",\n        help=\"If True, generate an article for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-polish-article\",\n        action=\"store_true\",\n        help=\"If True, polish the article by adding a summarization section and (optionally) removing \"\n        \"duplicate content.\",\n    )\n    # hyperparameters for the pre-writing stage\n    parser.add_argument(\n        \"--max-conv-turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of questions in conversational question asking.\",\n    )\n    parser.add_argument(\n        \"--max-perspective\",\n        type=int,\n        default=3,\n        help=\"Maximum number of perspectives to consider in perspective-guided question asking.\",\n    )\n    parser.add_argument(\n        \"--search-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k search results to consider for each search query.\",\n    )\n    # hyperparameters for the writing stage\n    parser.add_argument(\n        \"--retrieve-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k collected 
references for each section title.\",\n    )\n    parser.add_argument(\n        \"--remove-duplicate\",\n        action=\"store_true\",\n        help=\"If True, remove duplicate content from the article.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "examples/storm_examples/run_storm_wiki_ollama.py",
    "content": "\"\"\"\nSTORM Wiki pipeline powered by local model hosted by Ollama server and You.com or Bing search engine.\nYou need to set up the following environment variables to run this script:\n    - YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key\nYou also need to have a Ollama server running with the llama3 model or other. Specify `--url`, `--port` and `--model` accordingly.\n\nOutput will be structured as below\nargs.output_dir/\n    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash\n        conversation_log.json           # Log of information-seeking conversation\n        raw_search_results.json         # Raw search results from search engine\n        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge\n        storm_gen_outline.txt           # Outline refined with collected information\n        url_to_info.json                # Sources that are used in the final article\n        storm_gen_article.txt           # Final article generated\n        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)\n\"\"\"\n\nimport os\nimport sys\nfrom argparse import ArgumentParser\n\nfrom dspy import Example\n\nfrom knowledge_storm.lm import OllamaClient\nfrom knowledge_storm.rm import (\n    YouRM,\n    BingSearch,\n    BraveRM,\n    SerperRM,\n    DuckDuckGoSearchRM,\n    TavilySearchRM,\n    SearXNG,\n)\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\nfrom knowledge_storm.utils import load_api_key\n\n\ndef main(args):\n    load_api_key(toml_file_path=\"secrets.toml\")\n    lm_configs = STORMWikiLMConfigs()\n\n    ollama_kwargs = {\n        \"model\": args.model,\n        \"port\": args.port,\n        \"url\": args.url,\n        \"stop\": (\n           
 \"\\n\\n---\",\n        ),  # dspy uses \"\\n\\n---\" to separate examples. Open models sometimes generate this.\n    }\n\n    conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs)\n    question_asker_lm = OllamaClient(max_tokens=500, **ollama_kwargs)\n    outline_gen_lm = OllamaClient(max_tokens=400, **ollama_kwargs)\n    article_gen_lm = OllamaClient(max_tokens=700, **ollama_kwargs)\n    article_polish_lm = OllamaClient(max_tokens=4000, **ollama_kwargs)\n\n    lm_configs.set_conv_simulator_lm(conv_simulator_lm)\n    lm_configs.set_question_asker_lm(question_asker_lm)\n    lm_configs.set_outline_gen_lm(outline_gen_lm)\n    lm_configs.set_article_gen_lm(article_gen_lm)\n    lm_configs.set_article_polish_lm(article_polish_lm)\n\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=args.output_dir,\n        max_conv_turn=args.max_conv_turn,\n        max_perspective=args.max_perspective,\n        search_top_k=args.search_top_k,\n        max_thread_num=args.max_thread_num,\n    )\n\n    # STORM is a knowledge curation system which consumes information from the retrieval module.\n    # Currently, the information source is the Internet and we use search engine API as the retrieval module.\n    match args.retriever:\n        case \"bing\":\n            rm = BingSearch(\n                bing_search_api=os.getenv(\"BING_SEARCH_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"you\":\n            rm = YouRM(ydc_api_key=os.getenv(\"YDC_API_KEY\"), k=engine_args.search_top_k)\n        case \"brave\":\n            rm = BraveRM(\n                brave_search_api_key=os.getenv(\"BRAVE_API_KEY\"),\n                k=engine_args.search_top_k,\n            )\n        case \"duckduckgo\":\n            rm = DuckDuckGoSearchRM(\n                k=engine_args.search_top_k, safe_search=\"On\", region=\"us-en\"\n            )\n        case \"serper\":\n            rm = SerperRM(\n                
serper_search_api_key=os.getenv(\"SERPER_API_KEY\"),\n                query_params={\"autocorrect\": True, \"num\": 10, \"page\": 1},\n            )\n        case \"tavily\":\n            rm = TavilySearchRM(\n                tavily_search_api_key=os.getenv(\"TAVILY_API_KEY\"),\n                k=engine_args.search_top_k,\n                include_raw_content=True,\n            )\n        case \"searxng\":\n            rm = SearXNG(\n                searxng_api_key=os.getenv(\"SEARXNG_API_KEY\"), k=engine_args.search_top_k\n            )\n        case _:\n            raise ValueError(\n                f'Invalid retriever: {args.retriever}. Choose either \"bing\", \"you\", \"brave\", \"duckduckgo\", \"serper\", \"tavily\", or \"searxng\"'\n            )\n\n    runner = STORMWikiRunner(engine_args, lm_configs, rm)\n\n    # Open LMs are generally weaker in following output format.\n    # One way for mitigation is to add one-shot example to the prompt to exemplify the desired output format.\n    # For example, we can add the following examples to the two prompts used in StormPersonaGenerator.\n    # Note that the example should be an object of dspy.Example with fields matching the InputField\n    # and OutputField in the prompt (i.e., dspy.Signature).\n    find_related_topic_example = Example(\n        topic=\"Knowledge Curation\",\n        related_topics=\"https://en.wikipedia.org/wiki/Knowledge_management\\n\"\n        \"https://en.wikipedia.org/wiki/Information_science\\n\"\n        \"https://en.wikipedia.org/wiki/Library_science\\n\",\n    )\n    gen_persona_example = Example(\n        topic=\"Knowledge Curation\",\n        examples=\"Title: Knowledge management\\n\"\n        \"Table of Contents: History\\nResearch\\n  Dimensions\\n  Strategies\\n  Motivations\\nKM technologies\"\n        \"\\nKnowledge barriers\\nKnowledge retention\\nKnowledge audit\\nKnowledge protection\\n\"\n        \"  Knowledge protection methods\\n    Formal methods\\n    Informal 
methods\\n\"\n        \"  Balancing knowledge protection and knowledge sharing\\n  Knowledge protection risks\",\n        personas=\"1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge curation. They will provide context on how knowledge curation has changed over time and its impact on modern practices.\\n\"\n        \"2. Information Science Professional: With insights from 'Information science', this editor will explore the foundational theories, definitions, and philosophy that underpin knowledge curation\\n\"\n        \"3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, including software, metadata, digital preservation.\\n\"\n        \"4. Technical expert: This editor will focus on the technical aspects of knowledge curation, such as common features of content management systems.\\n\"\n        \"5. Museum Curator: The museum curator will contribute expertise on the curation of physical items and the transition of these practices into the digital realm.\",\n    )\n    runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [\n        find_related_topic_example\n    ]\n    runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [\n        gen_persona_example\n    ]\n\n    # A trade-off of adding one-shot example is that it will increase the input length of the prompt. Also, some\n    # examples may be very long (e.g., an example for writing a section based on the given information), which may\n    # confuse the model. 
For these cases, you can create a pseudo-example that is short and easy to understand to steer\n    # the model's output format.\n    # For example, we can add the following pseudo-examples to the prompt used in WritePageOutlineFromConv and\n    # ConvToSection.\n    write_page_outline_example = Example(\n        topic=\"Example Topic\",\n        conv=\"Wikipedia Writer: ...\\nExpert: ...\\nWikipedia Writer: ...\\nExpert: ...\",\n        old_outline=\"# Section 1\\n## Subsection 1\\n## Subsection 2\\n\"\n        \"# Section 2\\n## Subsection 1\\n## Subsection 2\\n\"\n        \"# Section 3\",\n        outline=\"# New Section 1\\n## New Subsection 1\\n## New Subsection 2\\n\"\n        \"# New Section 2\\n\"\n        \"# New Section 3\\n## New Subsection 1\\n## New Subsection 2\\n## New Subsection 3\",\n    )\n    runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [\n        write_page_outline_example\n    ]\n    write_section_example = Example(\n        info=\"[1]\\nInformation in document 1\\n[2]\\nInformation in document 2\\n[3]\\nInformation in document 3\",\n        topic=\"Example Topic\",\n        section=\"Example Section\",\n        output=\"# Example Topic\\n## Subsection 1\\n\"\n        \"This is an example sentence [1]. 
This is another example sentence [2][3].\\n\"\n        \"## Subsection 2\\nThis is one more example sentence [1].\",\n    )\n    runner.storm_article_generation.section_gen.write_section.demos = [\n        write_section_example\n    ]\n\n    topic = input(\"Topic: \")\n    runner.run(\n        topic=topic,\n        do_research=args.do_research,\n        do_generate_outline=args.do_generate_outline,\n        do_generate_article=args.do_generate_article,\n        do_polish_article=args.do_polish_article,\n        remove_duplicate=args.remove_duplicate,\n    )\n    runner.post_run()\n    runner.summary()\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--url\", type=str, default=\"http://localhost\", help=\"URL of the Ollama server.\"\n    )\n    parser.add_argument(\n        \"--port\", type=int, default=11434, help=\"Port of the Ollama server.\"\n    )\n    parser.add_argument(\n        \"--model\", type=str, default=\"llama3:latest\", help=\"Model of the Ollama server.\"\n    )\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"./results/ollama\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--max-thread-num\",\n        type=int,\n        default=3,\n        help=\"Maximum number of threads to use. The information seeking part and the article generation \"\n        \"part can speed up by using multiple threads. 
Consider reducing it if you keep getting \"\n        '\"Exceed rate limit\" error when calling LM API.',\n    )\n    parser.add_argument(\n        \"--retriever\",\n        type=str,\n        choices=[\"bing\", \"you\", \"brave\", \"serper\", \"duckduckgo\", \"tavily\", \"searxng\"],\n        help=\"The search engine API to use for retrieving information.\",\n    )\n    # stage of the pipeline\n    parser.add_argument(\n        \"--do-research\",\n        action=\"store_true\",\n        help=\"If True, simulate conversation to research the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-outline\",\n        action=\"store_true\",\n        help=\"If True, generate an outline for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-article\",\n        action=\"store_true\",\n        help=\"If True, generate an article for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-polish-article\",\n        action=\"store_true\",\n        help=\"If True, polish the article by adding a summarization section and (optionally) removing \"\n        \"duplicate content.\",\n    )\n    # hyperparameters for the pre-writing stage\n    parser.add_argument(\n        \"--max-conv-turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of questions in conversational question asking.\",\n    )\n    parser.add_argument(\n        \"--max-perspective\",\n        type=int,\n        default=3,\n        help=\"Maximum number of perspectives to consider in perspective-guided question asking.\",\n    )\n    parser.add_argument(\n        \"--search-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k search results to consider for each search query.\",\n    )\n    # hyperparameters for the writing stage\n    parser.add_argument(\n        \"--retrieve-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k collected 
references for each section title.\",\n    )\n    parser.add_argument(\n        \"--remove-duplicate\",\n        action=\"store_true\",\n        help=\"If True, remove duplicate content from the article.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "examples/storm_examples/run_storm_wiki_ollama_with_searxng.py",
    "content": "import os\nfrom argparse import ArgumentParser\n\nfrom dspy import Example\n\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\nfrom knowledge_storm.lm import OllamaClient\nfrom knowledge_storm.rm import SearXNG\nfrom knowledge_storm.utils import load_api_key\n\n\ndef main(args):\n    load_api_key(toml_file_path=\"secrets.toml\")\n    lm_configs = STORMWikiLMConfigs()\n\n    ollama_kwargs = {\n        \"model\": args.model,\n        \"port\": args.port,\n        \"url\": args.url,\n        \"stop\": (\"\\n\\n---\",),\n    }\n\n    conv_simulator_lm = OllamaClient(max_tokens=500, **ollama_kwargs)\n    question_asker_lm = OllamaClient(max_tokens=500, **ollama_kwargs)\n    outline_gen_lm = OllamaClient(max_tokens=400, **ollama_kwargs)\n    article_gen_lm = OllamaClient(max_tokens=700, **ollama_kwargs)\n    article_polish_lm = OllamaClient(max_tokens=4000, **ollama_kwargs)\n\n    lm_configs.set_conv_simulator_lm(conv_simulator_lm)\n    lm_configs.set_question_asker_lm(question_asker_lm)\n    lm_configs.set_outline_gen_lm(outline_gen_lm)\n    lm_configs.set_article_gen_lm(article_gen_lm)\n    lm_configs.set_article_polish_lm(article_polish_lm)\n\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=args.output_dir,\n        max_conv_turn=args.max_conv_turn,\n        max_perspective=args.max_perspective,\n        search_top_k=args.search_top_k,\n        max_thread_num=args.max_thread_num,\n    )\n\n    rm = SearXNG(\n        searxng_api_url=args.searxng_api_url,\n        searxng_api_key=os.getenv(\"SEARXNG_API_KEY\"),\n        k=engine_args.search_top_k,\n    )\n\n    runner = STORMWikiRunner(engine_args, lm_configs, rm)\n\n    find_related_topic_example = Example(\n        topic=\"Knowledge Curation\",\n        related_topics=\"https://en.wikipedia.org/wiki/Knowledge_management\\n\"\n        \"https://en.wikipedia.org/wiki/Information_science\\n\"\n        
\"https://en.wikipedia.org/wiki/Library_science\\n\",\n    )\n    gen_persona_example = Example(\n        topic=\"Knowledge Curation\",\n        examples=\"Title: Knowledge management\\n\"\n        \"Table of Contents: History\\nResearch\\n  Dimensions\\n  Strategies\\n  Motivations\\nKM technologies\"\n        \"\\nKnowledge barriers\\nKnowledge retention\\nKnowledge audit\\nKnowledge protection\\n\"\n        \"  Knowledge protection methods\\n    Formal methods\\n    Informal methods\\n\"\n        \"  Balancing knowledge protection and knowledge sharing\\n  Knowledge protection risks\",\n        personas=(\n            \"1. Historian of Knowledge Systems: This editor will focus on the history and evolution of knowledge \"\n            \"curation. They will provide context on how knowledge curation has changed over time and its impact on \"\n            \"modern practices.\\n\"\n            \"2. Information Science Professional: With insights from 'Information science', this editor will \"\n            \"explore the foundational theories, definitions, and philosophy that underpin knowledge curation\\n\"\n            \"3. Digital Librarian: This editor will delve into the specifics of how digital libraries operate, \"\n            \"including software, metadata, digital preservation.\\n\"\n            \"4. Technical expert: This editor will focus on the technical aspects of knowledge curation, \"\n            \"such as common features of content management systems.\\n\"\n            \"5. 
Museum Curator: The museum curator will contribute expertise on the curation of physical items and \"\n            \"the transition of these practices into the digital realm.\"\n        ),\n    )\n    runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.find_related_topic.demos = [\n        find_related_topic_example\n    ]\n    runner.storm_knowledge_curation_module.persona_generator.create_writer_with_persona.gen_persona.demos = [\n        gen_persona_example\n    ]\n\n    write_page_outline_example = Example(\n        topic=\"Example Topic\",\n        conv=\"Wikipedia Writer: ...\\nExpert: ...\\nWikipedia Writer: ...\\nExpert: ...\",\n        old_outline=\"# Section 1\\n## Subsection 1\\n## Subsection 2\\n\"\n        \"# Section 2\\n## Subsection 1\\n## Subsection 2\\n\"\n        \"# Section 3\",\n        outline=\"# New Section 1\\n## New Subsection 1\\n## New Subsection 2\\n\"\n        \"# New Section 2\\n\"\n        \"# New Section 3\\n## New Subsection 1\\n## New Subsection 2\\n## New Subsection 3\",\n    )\n    runner.storm_outline_generation_module.write_outline.write_page_outline.demos = [\n        write_page_outline_example\n    ]\n    write_section_example = Example(\n        info=\"[1]\\nInformation in document 1\\n[2]\\nInformation in document 2\\n[3]\\nInformation in document 3\",\n        topic=\"Example Topic\",\n        section=\"Example Section\",\n        output=\"# Example Topic\\n## Subsection 1\\n\"\n        \"This is an example sentence [1]. 
This is another example sentence [2][3].\\n\"\n        \"## Subsection 2\\nThis is one more example sentence [1].\",\n    )\n    runner.storm_article_generation.section_gen.write_section.demos = [\n        write_section_example\n    ]\n\n    topic = input(\"Topic: \")\n    runner.run(\n        topic=topic,\n        do_research=args.do_research,\n        do_generate_outline=args.do_generate_outline,\n        do_generate_article=args.do_generate_article,\n        do_polish_article=args.do_polish_article,\n    )\n    runner.post_run()\n    runner.summary()\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--url\", type=str, default=\"http://localhost\", help=\"URL of the Ollama server.\"\n    )\n    parser.add_argument(\n        \"--port\", type=int, default=11434, help=\"Port of the Ollama server.\"\n    )\n    parser.add_argument(\n        \"--model\", type=str, default=\"llama3:latest\", help=\"Model of the Ollama server.\"\n    )\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"./results/ollama\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--max-thread-num\",\n        type=int,\n        default=3,\n        help=\"Maximum number of threads to use. The information seeking part and the article generation \"\n        \"part can be sped up by using multiple threads. 
Consider reducing it if you keep getting \"\n        '\"Exceed rate limit\" error when calling LM API.',\n    )\n    parser.add_argument(\n        \"--retriever\",\n        type=str,\n        choices=[\"searxng\"],\n        help=\"The search engine API to use for retrieving information.\",\n    )\n    parser.add_argument(\n        \"--searxng-api-url\", type=str, required=True, help=\"URL of the SearXNG API.\"\n    )\n    # stage of the pipeline\n    parser.add_argument(\n        \"--do-research\",\n        action=\"store_true\",\n        help=\"If True, simulate conversation to research the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-outline\",\n        action=\"store_true\",\n        help=\"If True, generate an outline for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-article\",\n        action=\"store_true\",\n        help=\"If True, generate an article for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-polish-article\",\n        action=\"store_true\",\n        help=\"If True, polish the article by adding a summarization section and (optionally) removing \"\n        \"duplicate content.\",\n    )\n    # hyperparameters for the pre-writing stage\n    parser.add_argument(\n        \"--max-conv-turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of questions in conversational question asking.\",\n    )\n    parser.add_argument(\n        \"--max-perspective\",\n        type=int,\n        default=3,\n        help=\"Maximum number of perspectives to consider in perspective-guided question asking.\",\n    )\n    parser.add_argument(\n        \"--search-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k search results to consider for each search query.\",\n    )\n    # hyperparameters for the writing stage\n    parser.add_argument(\n        \"--retrieve-top-k\",\n        
type=int,\n        default=3,\n        help=\"Top k collected references for each section title.\",\n    )\n    parser.add_argument(\n        \"--remove-duplicate\",\n        action=\"store_true\",\n        help=\"If True, remove duplicate content from the article.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "examples/storm_examples/run_storm_wiki_serper.py",
    "content": "\"\"\"\nSTORM Wiki pipeline powered by Claude family models and serper search engine.\nYou need to set up the following environment variables to run this script:\n    - ANTHROPIC_API_KEY: Anthropic API key\n    - SERPER_API_KEY: Serper.dev api key\n\nOutput will be structured as below\nargs.output_dir/\n    topic_name/  # topic_name will follow convention of underscore-connected topic name w/o space and slash\n        conversation_log.json           # Log of information-seeking conversation\n        raw_search_results.json         # Raw search results from search engine\n        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge\n        storm_gen_outline.txt           # Outline refined with collected information\n        url_to_info.json                # Sources that are used in the final article\n        storm_gen_article.txt           # Final article generated\n        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)\n\"\"\"\n\nimport os\nfrom argparse import ArgumentParser\n\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\nfrom knowledge_storm.lm import ClaudeModel\nfrom knowledge_storm.rm import SerperRM\nfrom knowledge_storm.utils import load_api_key\n\n\ndef main(args):\n    load_api_key(toml_file_path=\"secrets.toml\")\n    lm_configs = STORMWikiLMConfigs()\n    claude_kwargs = {\n        \"api_key\": os.getenv(\"ANTHROPIC_API_KEY\"),\n        \"temperature\": 1.0,\n        \"top_p\": 0.9,\n    }\n\n    # STORM is a LM system so different components can be powered by different models.\n    # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm\n    # which is used to split queries, synthesize answers in the conversation. 
We recommend using stronger models\n    # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm\n    # which is responsible for generating sections with citations.\n    conv_simulator_lm = ClaudeModel(\n        model=\"claude-3-haiku-20240307\", max_tokens=500, **claude_kwargs\n    )\n    question_asker_lm = ClaudeModel(\n        model=\"claude-3-sonnet-20240229\", max_tokens=500, **claude_kwargs\n    )\n    outline_gen_lm = ClaudeModel(\n        model=\"claude-3-opus-20240229\", max_tokens=400, **claude_kwargs\n    )\n    article_gen_lm = ClaudeModel(\n        model=\"claude-3-opus-20240229\", max_tokens=700, **claude_kwargs\n    )\n    article_polish_lm = ClaudeModel(\n        model=\"claude-3-opus-20240229\", max_tokens=4000, **claude_kwargs\n    )\n\n    lm_configs.set_conv_simulator_lm(conv_simulator_lm)\n    lm_configs.set_question_asker_lm(question_asker_lm)\n    lm_configs.set_outline_gen_lm(outline_gen_lm)\n    lm_configs.set_article_gen_lm(article_gen_lm)\n    lm_configs.set_article_polish_lm(article_polish_lm)\n\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=args.output_dir,\n        max_conv_turn=args.max_conv_turn,\n        max_perspective=args.max_perspective,\n        search_top_k=args.search_top_k,\n        max_thread_num=args.max_thread_num,\n    )\n    # Documentation to generate the data is available here:\n    # https://serper.dev/playground\n    # Note that tbs (the date range) takes hardcoded values.\n    # num is the number of results per page; increments of 10 (10, 20, etc.) are recommended.\n    # page is how many pages will be searched.\n    # h1 is where the Google search will originate from.\n    topic = input(\"topic: \")\n    data = {\"autocorrect\": True, \"num\": 10, \"page\": 1}\n    rm = SerperRM(serper_search_api_key=os.getenv(\"SERPER_API_KEY\"), query_params=data)\n\n    runner = STORMWikiRunner(engine_args, lm_configs, rm)\n\n    runner.run(\n        
topic=topic,\n        do_research=args.do_research,\n        do_generate_outline=args.do_generate_outline,\n        do_generate_article=args.do_generate_article,\n        do_polish_article=args.do_polish_article,\n    )\n    runner.post_run()\n    runner.summary()\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    # global arguments\n    parser.add_argument(\n        \"--output-dir\",\n        type=str,\n        default=\"./results/serper\",\n        help=\"Directory to store the outputs.\",\n    )\n    parser.add_argument(\n        \"--max-thread-num\",\n        type=int,\n        default=3,\n        help=\"Maximum number of threads to use. The information seeking part and the article generation \"\n        \"part can be sped up by using multiple threads. Consider reducing it if you keep getting \"\n        '\"Exceed rate limit\" error when calling LM API.',\n    )\n    parser.add_argument(\n        \"--retriever\",\n        type=str,\n        choices=[\"bing\", \"you\", \"serper\"],\n        help=\"The search engine API to use for retrieving information.\",\n    )\n    # stage of the pipeline\n    parser.add_argument(\n        \"--do-research\",\n        action=\"store_true\",\n        help=\"If True, simulate conversation to research the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-outline\",\n        action=\"store_true\",\n        help=\"If True, generate an outline for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-generate-article\",\n        action=\"store_true\",\n        help=\"If True, generate an article for the topic; otherwise, load the results.\",\n    )\n    parser.add_argument(\n        \"--do-polish-article\",\n        action=\"store_true\",\n        help=\"If True, polish the article by adding a summarization section and (optionally) removing \"\n        \"duplicate content.\",\n    )\n    # hyperparameters for the pre-writing stage\n    
parser.add_argument(\n        \"--max-conv-turn\",\n        type=int,\n        default=3,\n        help=\"Maximum number of questions in conversational question asking.\",\n    )\n    parser.add_argument(\n        \"--max-perspective\",\n        type=int,\n        default=3,\n        help=\"Maximum number of perspectives to consider in perspective-guided question asking.\",\n    )\n    parser.add_argument(\n        \"--search-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k search results to consider for each search query.\",\n    )\n    # hyperparameters for the writing stage\n    parser.add_argument(\n        \"--retrieve-top-k\",\n        type=int,\n        default=3,\n        help=\"Top k collected references for each section title.\",\n    )\n    parser.add_argument(\n        \"--remove-duplicate\",\n        action=\"store_true\",\n        help=\"If True, remove duplicate content from the article.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "frontend/demo_light/.streamlit/config.toml",
    "content": "[client]\nshowErrorDetails = false\ntoolbarMode = \"minimal\"\n\n[theme]\nprimaryColor = \"#F63366\"\nbackgroundColor = \"#FFFFFF\"\nsecondaryBackgroundColor = \"#F0F2F6\"\ntextColor = \"#262730\"\nfont = \"sans serif\""
  },
  {
    "path": "frontend/demo_light/README.md",
    "content": "# STORM Minimal User Interface\n\nThis is a minimal user interface for `STORMWikiRunner` which includes the following features:\n1. Allowing user to create a new article through the \"Create New Article\" page.\n2. Showing the intermediate steps of STORMWikiRunner in real-time when creating an article.\n3. Displaying the written article and references side by side.\n4. Allowing user to view previously created articles through the \"My Articles\" page.\n\n<p align=\"center\">\n  <img src=\"assets/create_article.jpg\" style=\"width: 70%; height: auto;\">\n</p>\n\n<p align=\"center\">\n  <img src=\"assets/article_display.jpg\" style=\"width: 70%; height: auto;\">\n</p>\n\n## Setup\n1. Make sure you have installed `knowledge-storm` or set up the source code correctly.\n2. Install additional packages required by the user interface:\n    ```bash\n    pip install -r requirements.txt\n    ```\n2. Make sure you set up the API keys following the instructions in the main README file. Create a copy of `secrets.toml` and place it under `.streamlit/`.\n3. Run the following command to start the user interface:\n    ```bash\n    streamlit run storm.py\n    ```\n   The user interface will create a `DEMO_WORKING_DIR` directory in the current directory to store the outputs.\n\n## Customization\n\nYou can customize the `STORMWikiRunner` powering the user interface according to [the guidelines](https://github.com/stanford-oval/storm?tab=readme-ov-file#customize-storm) in the main README file.\n\nThe `STORMWikiRunner` is initialized in `set_storm_runner()` in [demo_util.py](demo_util.py). You can change `STORMWikiRunnerArguments`, `STORMWikiLMConfigs`, or use a different retrieval model according to your need.\n"
  },
  {
    "path": "frontend/demo_light/demo_util.py",
    "content": "import base64\nimport datetime\nimport json\nimport os\nimport re\nfrom typing import Optional\n\nimport markdown\nimport pytz\nimport streamlit as st\n\n# If you install the source code instead of the `knowledge-storm` package,\n# Uncomment the following lines:\n# import sys\n# sys.path.append('../../')\nfrom knowledge_storm import (\n    STORMWikiRunnerArguments,\n    STORMWikiRunner,\n    STORMWikiLMConfigs,\n)\nfrom knowledge_storm.lm import OpenAIModel\nfrom knowledge_storm.rm import YouRM\nfrom knowledge_storm.storm_wiki.modules.callback import BaseCallbackHandler\nfrom knowledge_storm.utils import truncate_filename\nfrom stoc import stoc\n\n\nclass DemoFileIOHelper:\n    @staticmethod\n    def read_structure_to_dict(articles_root_path):\n        \"\"\"\n        Reads the directory structure of articles stored in the given root path and\n        returns a nested dictionary. The outer dictionary has article names as keys,\n        and each value is another dictionary mapping file names to their absolute paths.\n\n        Args:\n            articles_root_path (str): The root directory path containing article subdirectories.\n\n        Returns:\n            dict: A dictionary where each key is an article name, and each value is a dictionary\n                of file names and their absolute paths within that article's directory.\n        \"\"\"\n        articles_dict = {}\n        for topic_name in os.listdir(articles_root_path):\n            topic_path = os.path.join(articles_root_path, topic_name)\n            if os.path.isdir(topic_path):\n                # Initialize or update the dictionary for the topic\n                articles_dict[topic_name] = {}\n                # Iterate over all files within a topic directory\n                for file_name in os.listdir(topic_path):\n                    file_path = os.path.join(topic_path, file_name)\n                    articles_dict[topic_name][file_name] = os.path.abspath(file_path)\n        return 
articles_dict\n\n    @staticmethod\n    def read_txt_file(file_path):\n        \"\"\"\n        Reads the contents of a text file and returns it as a string.\n\n        Args:\n            file_path (str): The path to the text file to be read.\n\n        Returns:\n            str: The content of the file as a single string.\n        \"\"\"\n        with open(file_path) as f:\n            return f.read()\n\n    @staticmethod\n    def read_json_file(file_path):\n        \"\"\"\n        Reads a JSON file and returns its content as a Python dictionary or list,\n        depending on the JSON structure.\n\n        Args:\n            file_path (str): The path to the JSON file to be read.\n\n        Returns:\n            dict or list: The content of the JSON file. The type depends on the\n                        structure of the JSON file (object or array at the root).\n        \"\"\"\n        with open(file_path) as f:\n            return json.load(f)\n\n    @staticmethod\n    def read_image_as_base64(image_path):\n        \"\"\"\n        Reads an image file and returns its content encoded as a base64 string,\n        suitable for embedding in HTML or transferring over networks where binary\n        data cannot be easily sent.\n\n        Args:\n            image_path (str): The path to the image file to be encoded.\n\n        Returns:\n            str: The base64 encoded string of the image, prefixed with the necessary\n                data URI scheme for images.\n        \"\"\"\n        with open(image_path, \"rb\") as f:\n            data = f.read()\n            encoded = base64.b64encode(data)\n        data = \"data:image/png;base64,\" + encoded.decode(\"utf-8\")\n        return data\n\n    @staticmethod\n    def set_file_modification_time(file_path, modification_time_string):\n        \"\"\"\n        Sets the modification time of a file based on a given time string in the California time zone.\n\n        Args:\n            file_path (str): The path to the file.\n        
    modification_time_string (str): The desired modification time in 'YYYY-MM-DD HH:MM:SS' format.\n        \"\"\"\n        california_tz = pytz.timezone(\"America/Los_Angeles\")\n        modification_time = datetime.datetime.strptime(\n            modification_time_string, \"%Y-%m-%d %H:%M:%S\"\n        )\n        modification_time = california_tz.localize(modification_time)\n        modification_time_utc = modification_time.astimezone(datetime.timezone.utc)\n        modification_timestamp = modification_time_utc.timestamp()\n        os.utime(file_path, (modification_timestamp, modification_timestamp))\n\n    @staticmethod\n    def get_latest_modification_time(path):\n        \"\"\"\n        Returns the latest modification time of all files in a directory in the California time zone as a string.\n\n        Args:\n            path (str): The path to a directory (searched recursively) or to a single file.\n\n        Returns:\n            str: The latest file's modification time in 'YYYY-MM-DD HH:MM:SS' format.\n        \"\"\"\n        california_tz = pytz.timezone(\"America/Los_Angeles\")\n        latest_mod_time = None\n\n        file_paths = []\n        if os.path.isdir(path):\n            for root, dirs, files in os.walk(path):\n                for file in files:\n                    file_paths.append(os.path.join(root, file))\n        else:\n            file_paths = [path]\n\n        for file_path in file_paths:\n            modification_timestamp = os.path.getmtime(file_path)\n            # Build a timezone-aware UTC datetime directly (utcfromtimestamp is deprecated).\n            modification_time_utc = datetime.datetime.fromtimestamp(\n                modification_timestamp, tz=datetime.timezone.utc\n            )\n            modification_time_california = modification_time_utc.astimezone(\n                california_tz\n            )\n\n            if (\n                latest_mod_time is None\n                or modification_time_california > latest_mod_time\n            
):\n                latest_mod_time = modification_time_california\n\n        if latest_mod_time is not None:\n            return latest_mod_time.strftime(\"%Y-%m-%d %H:%M:%S\")\n        else:\n            # No files found: fall back to the current time, also in the California time zone.\n            return datetime.datetime.now(california_tz).strftime(\"%Y-%m-%d %H:%M:%S\")\n\n    @staticmethod\n    def assemble_article_data(article_file_path_dict):\n        \"\"\"\n        Constructs a dictionary containing the content and metadata of an article\n        based on the available files in the article's directory. This includes the\n        main article text, citations from a JSON file, and a conversation log if\n        available. The function prioritizes a polished version of the article if\n        both a raw and polished version exist.\n\n        Args:\n            article_file_path_dict (dict): A dictionary where keys are file names relevant\n                                    to the article (e.g., the article text, citations\n                                    in JSON format, conversation logs) and values\n                                    are their corresponding file paths.\n\n        Returns:\n            dict or None: A dictionary containing the parsed content of the article,\n                        citations, and conversation log if available. 
Returns None\n                        if neither the raw nor polished article text exists in the\n                        provided file paths.\n        \"\"\"\n        if (\n            \"storm_gen_article.txt\" in article_file_path_dict\n            or \"storm_gen_article_polished.txt\" in article_file_path_dict\n        ):\n            full_article_name = (\n                \"storm_gen_article_polished.txt\"\n                if \"storm_gen_article_polished.txt\" in article_file_path_dict\n                else \"storm_gen_article.txt\"\n            )\n            article_data = {\n                \"article\": DemoTextProcessingHelper.parse(\n                    DemoFileIOHelper.read_txt_file(\n                        article_file_path_dict[full_article_name]\n                    )\n                )\n            }\n            if \"url_to_info.json\" in article_file_path_dict:\n                article_data[\"citations\"] = _construct_citation_dict_from_search_result(\n                    DemoFileIOHelper.read_json_file(\n                        article_file_path_dict[\"url_to_info.json\"]\n                    )\n                )\n            if \"conversation_log.json\" in article_file_path_dict:\n                article_data[\"conversation_log\"] = DemoFileIOHelper.read_json_file(\n                    article_file_path_dict[\"conversation_log.json\"]\n                )\n            return article_data\n        return None\n\n\nclass DemoTextProcessingHelper:\n    @staticmethod\n    def remove_citations(sent):\n        return (\n            re.sub(r\"\\[\\d+\", \"\", re.sub(r\" \\[\\d+\", \"\", sent))\n            .replace(\" |\", \"\")\n            .replace(\"]\", \"\")\n        )\n\n    @staticmethod\n    def parse_conversation_history(json_data):\n        \"\"\"\n        Given conversation log data, return list of parsed data of following format\n        (persona_name, persona_description, list of dialogue turn)\n        \"\"\"\n        parsed_data = []\n      
  for persona_conversation_data in json_data:\n            if \": \" in persona_conversation_data[\"perspective\"]:\n                name, description = persona_conversation_data[\"perspective\"].split(\n                    \": \", 1\n                )\n            elif \"- \" in persona_conversation_data[\"perspective\"]:\n                name, description = persona_conversation_data[\"perspective\"].split(\n                    \"- \", 1\n                )\n            else:\n                name, description = \"\", persona_conversation_data[\"perspective\"]\n            cur_conversation = []\n            for dialogue_turn in persona_conversation_data[\"dlg_turns\"]:\n                cur_conversation.append(\n                    {\"role\": \"user\", \"content\": dialogue_turn[\"user_utterance\"]}\n                )\n                cur_conversation.append(\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": DemoTextProcessingHelper.remove_citations(\n                            dialogue_turn[\"agent_utterance\"]\n                        ),\n                    }\n                )\n            parsed_data.append((name, description, cur_conversation))\n        return parsed_data\n\n    @staticmethod\n    def parse(text):\n        regex = re.compile(r']:\\s+\"(.*?)\"\\s+http')\n        text = regex.sub(\"]: http\", text)\n        return text\n\n    @staticmethod\n    def add_markdown_indentation(input_string):\n        lines = input_string.split(\"\\n\")\n        processed_lines = [\"\"]\n        for line in lines:\n            num_hashes = 0\n            for char in line:\n                if char == \"#\":\n                    num_hashes += 1\n                else:\n                    break\n            num_hashes -= 1\n            num_spaces = 4 * num_hashes\n            new_line = \" \" * num_spaces + line\n            processed_lines.append(new_line)\n        return \"\\n\".join(processed_lines)\n\n    
@staticmethod\n    def get_current_time_string():\n        \"\"\"\n        Returns the current time in the California time zone as a string.\n\n        Returns:\n            str: The current California time in 'YYYY-MM-DD HH:MM:SS' format.\n        \"\"\"\n        california_tz = pytz.timezone(\"America/Los_Angeles\")\n        utc_now = datetime.datetime.now(datetime.timezone.utc)\n        california_now = utc_now.astimezone(california_tz)\n        return california_now.strftime(\"%Y-%m-%d %H:%M:%S\")\n\n    @staticmethod\n    def compare_time_strings(\n        time_string1, time_string2, time_format=\"%Y-%m-%d %H:%M:%S\"\n    ):\n        \"\"\"\n        Compares two time strings to determine if they represent the same point in time.\n\n        Args:\n            time_string1 (str): The first time string to compare.\n            time_string2 (str): The second time string to compare.\n            time_format (str): The format of the time strings, defaults to '%Y-%m-%d %H:%M:%S'.\n\n        Returns:\n            bool: True if the time strings represent the same time, False otherwise.\n        \"\"\"\n        # Parse the time strings into datetime objects\n        time1 = datetime.datetime.strptime(time_string1, time_format)\n        time2 = datetime.datetime.strptime(time_string2, time_format)\n\n        # Compare the datetime objects\n        return time1 == time2\n\n    @staticmethod\n    def add_inline_citation_link(article_text, citation_dict):\n        # Regular expression to find citations like [i]\n        pattern = r\"\\[(\\d+)\\]\"\n\n        # Function to replace each citation with its Markdown link\n        def replace_with_link(match):\n            i = match.group(1)\n            url = citation_dict.get(int(i), {}).get(\"url\", \"#\")\n            return f\"[[{i}]]({url})\"\n\n        # Replace all citations in the text with Markdown links\n        return re.sub(pattern, replace_with_link, article_text)\n\n    @staticmethod\n    def 
generate_html_toc(md_text):\n        toc = []\n        for line in md_text.splitlines():\n            if line.startswith(\"#\"):\n                # Count only the leading '#' characters so a '#' inside the title does not change the level.\n                level = len(line) - len(line.lstrip(\"#\"))\n                title = line.strip(\"# \").strip()\n                anchor = title.lower().replace(\" \", \"-\").replace(\".\", \"\")\n                toc.append(\n                    f\"<li style='margin-left: {20 * (level - 1)}px;'><a href='#{anchor}'>{title}</a></li>\"\n                )\n        return \"<ul>\" + \"\".join(toc) + \"</ul>\"\n\n    @staticmethod\n    def construct_bibliography_from_url_to_info(url_to_info):\n        bibliography_list = []\n        sorted_url_to_unified_index = dict(\n            sorted(\n                url_to_info[\"url_to_unified_index\"].items(), key=lambda item: item[1]\n            )\n        )\n        for url, index in sorted_url_to_unified_index.items():\n            title = url_to_info[\"url_to_info\"][url][\"title\"]\n            bibliography_list.append(f\"[{index}]: [{title}]({url})\")\n        bibliography_string = \"\\n\\n\".join(bibliography_list)\n        return f\"# References\\n\\n{bibliography_string}\"\n\n\nclass DemoUIHelper:\n    @staticmethod\n    def st_markdown_adjust_size(content, font_size=20):\n        st.markdown(\n            f\"\"\"\n        <span style='font-size: {font_size}px;'>{content}</span>\n        \"\"\",\n            unsafe_allow_html=True,\n        )\n\n    @staticmethod\n    def get_article_card_UI_style(boarder_color=\"#9AD8E1\"):\n        return {\n            \"card\": {\n                \"width\": \"100%\",\n                \"height\": \"116px\",\n                \"max-width\": \"640px\",\n                \"background-color\": \"#FFFFFF\",\n                \"border\": \"1px solid #CCC\",\n                \"padding\": \"20px\",\n                \"border-radius\": \"5px\",\n                \"border-left\": f\"0.5rem solid {boarder_color}\",\n                \"box-shadow\": \"0 0.15rem 1.75rem 0 rgba(58, 59, 69, 0.15)\",\n               
 \"margin\": \"0px\",\n            },\n            \"title\": {\n                \"white-space\": \"nowrap\",\n                \"overflow\": \"hidden\",\n                \"text-overflow\": \"ellipsis\",\n                \"font-size\": \"17px\",\n                \"color\": \"rgb(49, 51, 63)\",\n                \"text-align\": \"left\",\n                \"width\": \"95%\",\n                \"font-weight\": \"normal\",\n            },\n            \"text\": {\n                \"white-space\": \"nowrap\",\n                \"overflow\": \"hidden\",\n                \"text-overflow\": \"ellipsis\",\n                \"font-size\": \"25px\",\n                \"color\": \"rgb(49, 51, 63)\",\n                \"text-align\": \"left\",\n                \"width\": \"95%\",\n            },\n            \"filter\": {\"background-color\": \"rgba(0, 0, 0, 0)\"},\n        }\n\n    @staticmethod\n    def customize_toast_css_style():\n        # Note padding is top right bottom left\n        st.markdown(\n            \"\"\"\n            <style>\n\n                div[data-testid=stToast] {\n                    padding: 20px 10px 40px 10px;\n                    background-color: #FF0000;   /* red */\n                    width: 40%;\n                }\n\n                [data-testid=toastContainer] [data-testid=stMarkdownContainer] > p {\n                    font-size: 25px;\n                    font-style: normal;\n                    font-weight: 400;\n                    color: #FFFFFF;   /* white */\n                    line-height: 1.5; /* Adjust this value as needed */\n                }\n            </style>\n            \"\"\",\n            unsafe_allow_html=True,\n        )\n\n    @staticmethod\n    def article_markdown_to_html(article_title, article_content):\n        return f\"\"\"\n        <html>\n            <head>\n                <meta charset=\"utf-8\">\n                <title>{article_title}</title>\n                <style>\n                    .title {{\n                
        text-align: center;\n                    }}\n                </style>\n            </head>\n            <body>\n                <div class=\"title\">\n                    <h1>{article_title.replace('_', ' ')}</h1>\n                </div>\n                <h2>Table of Contents</h2>\n                {DemoTextProcessingHelper.generate_html_toc(article_content)}\n                {markdown.markdown(article_content)}\n            </body>\n        </html>\n        \"\"\"\n\n\ndef _construct_citation_dict_from_search_result(search_results):\n    if search_results is None:\n        return None\n    citation_dict = {}\n    for url, index in search_results[\"url_to_unified_index\"].items():\n        citation_dict[index] = {\n            \"url\": url,\n            \"title\": search_results[\"url_to_info\"][url][\"title\"],\n            \"snippets\": search_results[\"url_to_info\"][url][\"snippets\"],\n        }\n    return citation_dict\n\n\ndef _display_main_article_text(article_text, citation_dict, table_content_sidebar):\n    # Post-process the generated article for better display.\n    if \"Write the lead section:\" in article_text:\n        article_text = article_text[\n            article_text.find(\"Write the lead section:\")\n            + len(\"Write the lead section:\") :\n        ]\n    # Drop a leading heading line; startswith() is safe when the article text is empty.\n    if article_text.startswith(\"#\"):\n        article_text = \"\\n\".join(article_text.split(\"\\n\")[1:])\n    article_text = DemoTextProcessingHelper.add_inline_citation_link(\n        article_text, citation_dict\n    )\n    # '$' needs to be changed to '\\$' to avoid being interpreted as LaTeX in st.markdown()\n    article_text = article_text.replace(\"$\", \"\\\\$\")\n    stoc.from_markdown(article_text, table_content_sidebar)\n\n\ndef _display_references(citation_dict):\n    if citation_dict:\n        reference_list = [f\"reference [{i}]\" for i in range(1, len(citation_dict) + 1)]\n        selected_key = st.selectbox(\"Select a reference\", reference_list)\n        citation_val 
= citation_dict[reference_list.index(selected_key) + 1]\n        citation_val[\"title\"] = citation_val[\"title\"].replace(\"$\", \"\\\\$\")\n        st.markdown(f\"**Title:** {citation_val['title']}\")\n        st.markdown(f\"**Url:** {citation_val['url']}\")\n        snippets = \"\\n\\n\".join(citation_val[\"snippets\"]).replace(\"$\", \"\\\\$\")\n        st.markdown(f\"**Highlights:**\\n\\n {snippets}\")\n    else:\n        st.markdown(\"**No references available**\")\n\n\ndef _display_persona_conversations(conversation_log):\n    \"\"\"\n    Display persona conversation in dialogue UI\n    \"\"\"\n    # get personas list as (persona_name, persona_description, dialogue turns list) tuple\n    parsed_conversation_history = DemoTextProcessingHelper.parse_conversation_history(\n        conversation_log\n    )\n    # construct tabs for each persona conversation\n    persona_tabs = st.tabs([name for (name, _, _) in parsed_conversation_history])\n    for idx, persona_tab in enumerate(persona_tabs):\n        with persona_tab:\n            # show persona description\n            st.info(parsed_conversation_history[idx][1])\n            # show user / agent utterance in dialogue UI\n            for message in parsed_conversation_history[idx][2]:\n                message[\"content\"] = message[\"content\"].replace(\"$\", \"\\\\$\")\n                with st.chat_message(message[\"role\"]):\n                    if message[\"role\"] == \"user\":\n                        st.markdown(f\"**{message['content']}**\")\n                    else:\n                        st.markdown(message[\"content\"])\n\n\ndef _display_main_article(\n    selected_article_file_path_dict, show_reference=True, show_conversation=True\n):\n    article_data = DemoFileIOHelper.assemble_article_data(\n        selected_article_file_path_dict\n    )\n\n    with st.container(height=1000, border=True):\n        table_content_sidebar = st.sidebar.expander(\n            \"**Table of contents**\", expanded=True\n 
       )\n        _display_main_article_text(\n            article_text=article_data.get(\"article\", \"\"),\n            citation_dict=article_data.get(\"citations\", {}),\n            table_content_sidebar=table_content_sidebar,\n        )\n\n    # display reference panel\n    if show_reference and \"citations\" in article_data:\n        with st.sidebar.expander(\"**References**\", expanded=True):\n            with st.container(height=800, border=False):\n                _display_references(citation_dict=article_data.get(\"citations\", {}))\n\n    # display conversation history\n    if show_conversation and \"conversation_log\" in article_data:\n        with st.expander(\n            \"**STORM** is powered by a knowledge agent that proactively researches a given topic by asking good questions from different perspectives.\\n\\n\"\n            \":sunglasses: Click here to view the agent's brain**STORM**ing process!\"\n        ):\n            _display_persona_conversations(\n                conversation_log=article_data.get(\"conversation_log\", {})\n            )\n\n\ndef get_demo_dir():\n    return os.path.dirname(os.path.abspath(__file__))\n\n\ndef clear_other_page_session_state(page_index: Optional[int]):\n    if page_index is None:\n        keys_to_delete = [key for key in st.session_state if key.startswith(\"page\")]\n    else:\n        keys_to_delete = [\n            key\n            for key in st.session_state\n            if key.startswith(\"page\") and f\"page{page_index}\" not in key\n        ]\n    for key in set(keys_to_delete):\n        del st.session_state[key]\n\n\ndef set_storm_runner():\n    current_working_dir = os.path.join(get_demo_dir(), \"DEMO_WORKING_DIR\")\n    if not os.path.exists(current_working_dir):\n        os.makedirs(current_working_dir)\n\n    # configure STORM runner\n    llm_configs = STORMWikiLMConfigs()\n    llm_configs.init_openai_model(\n        openai_api_key=st.secrets[\"OPENAI_API_KEY\"], openai_type=\"openai\"\n    
)\n    llm_configs.set_question_asker_lm(\n        OpenAIModel(\n            model=\"gpt-4-1106-preview\",\n            api_key=st.secrets[\"OPENAI_API_KEY\"],\n            api_provider=\"openai\",\n            max_tokens=500,\n            temperature=1.0,\n            top_p=0.9,\n        )\n    )\n    engine_args = STORMWikiRunnerArguments(\n        output_dir=current_working_dir,\n        max_conv_turn=3,\n        max_perspective=3,\n        search_top_k=3,\n        retrieve_top_k=5,\n    )\n\n    rm = YouRM(ydc_api_key=st.secrets[\"YDC_API_KEY\"], k=engine_args.search_top_k)\n\n    runner = STORMWikiRunner(engine_args, llm_configs, rm)\n    st.session_state[\"runner\"] = runner\n\n\ndef display_article_page(\n    selected_article_name,\n    selected_article_file_path_dict,\n    show_title=True,\n    show_main_article=True,\n):\n    if show_title:\n        st.markdown(\n            f\"<h2 style='text-align: center;'>{selected_article_name.replace('_', ' ')}</h2>\",\n            unsafe_allow_html=True,\n        )\n\n    if show_main_article:\n        _display_main_article(selected_article_file_path_dict)\n\n\nclass StreamlitCallbackHandler(BaseCallbackHandler):\n    def __init__(self, status_container):\n        self.status_container = status_container\n\n    def on_identify_perspective_start(self, **kwargs):\n        self.status_container.info(\n            \"Start identifying different perspectives for researching the topic.\"\n        )\n\n    def on_identify_perspective_end(self, perspectives: list[str], **kwargs):\n        perspective_list = \"\\n- \".join(perspectives)\n        self.status_container.success(\n            f\"Finish identifying perspectives. 
Will now start gathering information\"\n            f\" from the following perspectives:\\n- {perspective_list}\"\n        )\n\n    def on_information_gathering_start(self, **kwargs):\n        self.status_container.info(\"Start browsing the Internet.\")\n\n    def on_dialogue_turn_end(self, dlg_turn, **kwargs):\n        urls = list(set([r.url for r in dlg_turn.search_results]))\n        for url in urls:\n            self.status_container.markdown(\n                f\"\"\"\n                    <style>\n                    .small-font {{\n                        font-size: 14px;\n                        margin: 0px;\n                        padding: 0px;\n                    }}\n                    </style>\n                    <div class=\"small-font\">Finish browsing <a href=\"{url}\" class=\"small-font\" target=\"_blank\">{url}</a>.</div>\n                    \"\"\",\n                unsafe_allow_html=True,\n            )\n\n    def on_information_gathering_end(self, **kwargs):\n        self.status_container.success(\"Finish collecting information.\")\n\n    def on_information_organization_start(self, **kwargs):\n        self.status_container.info(\n            \"Start organizing information into a hierarchical outline.\"\n        )\n\n    def on_direct_outline_generation_end(self, outline: str, **kwargs):\n        self.status_container.success(\n            f\"Finish leveraging the internal knowledge of the large language model.\"\n        )\n\n    def on_outline_refinement_end(self, outline: str, **kwargs):\n        self.status_container.success(f\"Finish leveraging the collected information.\")\n"
  },
  {
    "path": "frontend/demo_light/pages_util/CreateNewArticle.py",
    "content": "import os\nimport time\n\nimport demo_util\nimport streamlit as st\nfrom demo_util import (\n    DemoFileIOHelper,\n    DemoTextProcessingHelper,\n    DemoUIHelper,\n    truncate_filename,\n)\n\n\ndef handle_not_started():\n    if st.session_state[\"page3_write_article_state\"] == \"not started\":\n        _, search_form_column, _ = st.columns([2, 5, 2])\n        with search_form_column:\n            with st.form(key=\"search_form\"):\n                # Text input for the search topic\n                DemoUIHelper.st_markdown_adjust_size(\n                    content=\"Enter the topic you want to learn in depth:\", font_size=18\n                )\n                st.session_state[\"page3_topic\"] = st.text_input(\n                    label=\"page3_topic\", label_visibility=\"collapsed\"\n                )\n                pass_appropriateness_check = True\n\n                # Submit button for the form\n                submit_button = st.form_submit_button(label=\"Research\")\n                # only start new search when button is clicked, not started, or already finished previous one\n                if submit_button and st.session_state[\"page3_write_article_state\"] in [\n                    \"not started\",\n                    \"show results\",\n                ]:\n                    if not st.session_state[\"page3_topic\"].strip():\n                        pass_appropriateness_check = False\n                        st.session_state[\"page3_warning_message\"] = (\n                            \"topic could not be empty\"\n                        )\n\n                    st.session_state[\"page3_topic_name_cleaned\"] = (\n                        st.session_state[\"page3_topic\"]\n                        .replace(\" \", \"_\")\n                        .replace(\"/\", \"_\")\n                    )\n                    st.session_state[\"page3_topic_name_truncated\"] = truncate_filename(\n                        
st.session_state[\"page3_topic_name_cleaned\"]\n                    )\n                    if not pass_appropriateness_check:\n                        st.session_state[\"page3_write_article_state\"] = \"not started\"\n                        alert = st.warning(\n                            st.session_state[\"page3_warning_message\"], icon=\"⚠️\"\n                        )\n                        time.sleep(5)\n                        alert.empty()\n                    else:\n                        st.session_state[\"page3_write_article_state\"] = \"initiated\"\n\n\ndef handle_initiated():\n    if st.session_state[\"page3_write_article_state\"] == \"initiated\":\n        current_working_dir = os.path.join(demo_util.get_demo_dir(), \"DEMO_WORKING_DIR\")\n        if not os.path.exists(current_working_dir):\n            os.makedirs(current_working_dir)\n\n        if \"runner\" not in st.session_state:\n            demo_util.set_storm_runner()\n        st.session_state[\"page3_current_working_dir\"] = current_working_dir\n        st.session_state[\"page3_write_article_state\"] = \"pre_writing\"\n\n\ndef handle_pre_writing():\n    if st.session_state[\"page3_write_article_state\"] == \"pre_writing\":\n        status = st.status(\n            \"I am brain**STORM**ing now to research the topic. 
(This may take 2-3 minutes.)\"\n        )\n        st_callback_handler = demo_util.StreamlitCallbackHandler(status)\n        with status:\n            # STORM main gen outline\n            st.session_state[\"runner\"].run(\n                topic=st.session_state[\"page3_topic\"],\n                do_research=True,\n                do_generate_outline=True,\n                do_generate_article=False,\n                do_polish_article=False,\n                callback_handler=st_callback_handler,\n            )\n            conversation_log_path = os.path.join(\n                st.session_state[\"page3_current_working_dir\"],\n                st.session_state[\"page3_topic_name_truncated\"],\n                \"conversation_log.json\",\n            )\n            demo_util._display_persona_conversations(\n                DemoFileIOHelper.read_json_file(conversation_log_path)\n            )\n            st.session_state[\"page3_write_article_state\"] = \"final_writing\"\n            status.update(label=\"brain**STORM**ing complete!\", state=\"complete\")\n\n\ndef handle_final_writing():\n    if st.session_state[\"page3_write_article_state\"] == \"final_writing\":\n        # polish final article\n        with st.status(\n            \"Now I will connect the information I found for your reference. (This may take 4-5 minutes.)\"\n        ) as status:\n            st.info(\n                \"Now I will connect the information I found for your reference. 
(This may take 4-5 minutes.)\"\n            )\n            st.session_state[\"runner\"].run(\n                topic=st.session_state[\"page3_topic\"],\n                do_research=False,\n                do_generate_outline=False,\n                do_generate_article=True,\n                do_polish_article=True,\n                remove_duplicate=False,\n            )\n            # finish the session\n            st.session_state[\"runner\"].post_run()\n\n            # update status bar\n            st.session_state[\"page3_write_article_state\"] = \"prepare_to_show_result\"\n            status.update(label=\"information synthesis complete!\", state=\"complete\")\n\n\ndef handle_prepare_to_show_result():\n    if st.session_state[\"page3_write_article_state\"] == \"prepare_to_show_result\":\n        _, show_result_col, _ = st.columns([4, 3, 4])\n        with show_result_col:\n            if st.button(\"show final article\"):\n                st.session_state[\"page3_write_article_state\"] = \"completed\"\n                st.rerun()\n\n\ndef handle_completed():\n    if st.session_state[\"page3_write_article_state\"] == \"completed\":\n        # display polished article\n        current_working_dir_paths = DemoFileIOHelper.read_structure_to_dict(\n            st.session_state[\"page3_current_working_dir\"]\n        )\n        current_article_file_path_dict = current_working_dir_paths[\n            st.session_state[\"page3_topic_name_truncated\"]\n        ]\n        demo_util.display_article_page(\n            selected_article_name=st.session_state[\"page3_topic_name_cleaned\"],\n            selected_article_file_path_dict=current_article_file_path_dict,\n            show_title=True,\n            show_main_article=True,\n        )\n\n\ndef create_new_article_page():\n    demo_util.clear_other_page_session_state(page_index=3)\n\n    if \"page3_write_article_state\" not in st.session_state:\n        st.session_state[\"page3_write_article_state\"] = \"not started\"\n\n   
 handle_not_started()\n\n    handle_initiated()\n\n    handle_pre_writing()\n\n    handle_final_writing()\n\n    handle_prepare_to_show_result()\n\n    handle_completed()\n"
  },
  {
    "path": "frontend/demo_light/pages_util/MyArticles.py",
    "content": "import os\n\nimport demo_util\nimport streamlit as st\nfrom demo_util import DemoFileIOHelper, DemoUIHelper\nfrom streamlit_card import card\n\n\n# set page config and display title\ndef my_articles_page():\n    with st.sidebar:\n        _, return_button_col = st.columns([2, 5])\n        with return_button_col:\n            if st.button(\n                \"Select another article\",\n                disabled=\"page2_selected_my_article\" not in st.session_state,\n            ):\n                if \"page2_selected_my_article\" in st.session_state:\n                    del st.session_state[\"page2_selected_my_article\"]\n                st.rerun()\n\n    # sync my articles\n    if \"page2_user_articles_file_path_dict\" not in st.session_state:\n        local_dir = os.path.join(demo_util.get_demo_dir(), \"DEMO_WORKING_DIR\")\n        os.makedirs(local_dir, exist_ok=True)\n        st.session_state[\"page2_user_articles_file_path_dict\"] = (\n            DemoFileIOHelper.read_structure_to_dict(local_dir)\n        )\n\n    # if no feature demo selected, display all featured articles as info cards\n    def article_card_setup(column_to_add, card_title, article_name):\n        with column_to_add:\n            cleaned_article_title = article_name.replace(\"_\", \" \")\n            hasClicked = card(\n                title=\" / \".join(card_title),\n                text=article_name.replace(\"_\", \" \"),\n                image=DemoFileIOHelper.read_image_as_base64(\n                    os.path.join(demo_util.get_demo_dir(), \"assets\", \"void.jpg\")\n                ),\n                styles=DemoUIHelper.get_article_card_UI_style(boarder_color=\"#9AD8E1\"),\n            )\n            if hasClicked:\n                st.session_state[\"page2_selected_my_article\"] = article_name\n                st.rerun()\n\n    if \"page2_selected_my_article\" not in st.session_state:\n        # display article cards\n        my_article_columns = st.columns(3)\n        if 
len(st.session_state[\"page2_user_articles_file_path_dict\"]) > 0:\n            # get article names\n            article_names = sorted(\n                list(st.session_state[\"page2_user_articles_file_path_dict\"].keys())\n            )\n            # configure pagination\n            pagination = st.container()\n            bottom_menu = st.columns((1, 4, 1, 1, 1))[1:-1]\n            with bottom_menu[2]:\n                batch_size = st.selectbox(\"Page Size\", options=[24, 48, 72])\n            with bottom_menu[1]:\n                # round up so that a partially filled last page is still reachable\n                total_pages = max(\n                    1, (len(article_names) + batch_size - 1) // batch_size\n                )\n                current_page = st.number_input(\n                    \"Page\", min_value=1, max_value=total_pages, step=1\n                )\n            with bottom_menu[0]:\n                st.markdown(f\"Page **{current_page}** of **{total_pages}** \")\n            # show article cards\n            with pagination:\n                my_article_count = 0\n                start_index = (current_page - 1) * batch_size\n                end_index = min(current_page * batch_size, len(article_names))\n                for article_name in article_names[start_index:end_index]:\n                    column_to_add = my_article_columns[my_article_count % 3]\n                    my_article_count += 1\n                    article_card_setup(\n                        column_to_add=column_to_add,\n                        card_title=[\"My Article\"],\n                        article_name=article_name,\n                    )\n        else:\n            with my_article_columns[0]:\n                hasClicked = card(\n                    title=\"Get started\",\n                    text=\"Start your first research!\",\n                    image=DemoFileIOHelper.read_image_as_base64(\n                        os.path.join(demo_util.get_demo_dir(), \"assets\", 
\"void.jpg\")\n                    ),\n                    styles=DemoUIHelper.get_article_card_UI_style(),\n                )\n                if hasClicked:\n                    st.session_state.selected_page = 1\n                    st.session_state[\"manual_selection_override\"] = True\n                    st.session_state[\"rerun_requested\"] = True\n                    st.rerun()\n    else:\n        selected_article_name = st.session_state[\"page2_selected_my_article\"]\n        selected_article_file_path_dict = st.session_state[\n            \"page2_user_articles_file_path_dict\"\n        ][selected_article_name]\n\n        demo_util.display_article_page(\n            selected_article_name=selected_article_name,\n            selected_article_file_path_dict=selected_article_file_path_dict,\n            show_title=True,\n            show_main_article=True,\n        )\n"
  },
  {
    "path": "frontend/demo_light/requirements.txt",
    "content": "streamlit==1.31.1\nstreamlit-card\nmarkdown\nunidecode\nextra-streamlit-components==0.1.60\nstreamlit_extras\ndeprecation==2.1.0\nst-pages==0.4.5\nstreamlit-float\nstreamlit-option-menu"
  },
  {
    "path": "frontend/demo_light/stoc.py",
    "content": "\"\"\"https://github.com/arnaudmiribel/stoc\"\"\"\n\nimport re\n\nimport streamlit as st\nimport unidecode\n\nDISABLE_LINK_CSS = \"\"\"\n<style>\na.toc {\n    color: inherit;\n    text-decoration: none; /* no underline */\n}\n</style>\"\"\"\n\n\nclass stoc:\n    def __init__(self):\n        self.toc_items = list()\n\n    def h1(self, text: str, write: bool = True):\n        if write:\n            st.write(f\"# {text}\")\n        self.toc_items.append((\"h1\", text))\n\n    def h2(self, text: str, write: bool = True):\n        if write:\n            st.write(f\"## {text}\")\n        self.toc_items.append((\"h2\", text))\n\n    def h3(self, text: str, write: bool = True):\n        if write:\n            st.write(f\"### {text}\")\n        self.toc_items.append((\"h3\", text))\n\n    def toc(self, expander):\n        st.write(DISABLE_LINK_CSS, unsafe_allow_html=True)\n        # st.sidebar.caption(\"Table of contents\")\n        if expander is None:\n            expander = st.sidebar.expander(\"**Table of contents**\", expanded=True)\n        with expander:\n            with st.container(height=600, border=False):\n                markdown_toc = \"\"\n                for title_size, title in self.toc_items:\n                    h = int(title_size.replace(\"h\", \"\"))\n                    markdown_toc += (\n                        \" \" * 2 * h\n                        + \"- \"\n                        + f'<a href=\"#{normalize(title)}\" class=\"toc\"> {title}</a> \\n'\n                    )\n                # st.sidebar.write(markdown_toc, unsafe_allow_html=True)\n                st.write(markdown_toc, unsafe_allow_html=True)\n\n    @classmethod\n    def get_toc(cls, markdown_text: str, topic=\"\"):\n        def increase_heading_depth_and_add_top_heading(markdown_text, new_top_heading):\n            lines = markdown_text.splitlines()\n            # Increase the depth of each heading by adding an extra '#'\n            increased_depth_lines = [\n         
       \"#\" + line if line.startswith(\"#\") else line for line in lines\n            ]\n            # Add the new top-level heading at the beginning\n            increased_depth_lines.insert(0, f\"# {new_top_heading}\")\n            # Re-join the modified lines back into a single string\n            modified_text = \"\\n\".join(increased_depth_lines)\n            return modified_text\n\n        if topic:\n            markdown_text = increase_heading_depth_and_add_top_heading(\n                markdown_text, topic\n            )\n        toc = []\n        for line in markdown_text.splitlines():\n            if line.startswith(\"#\"):\n                # Remove the '#' characters and strip leading/trailing spaces\n                heading_text = line.lstrip(\"#\").strip()\n                # Create slug (lowercase, spaces to hyphens, remove non-alphanumeric characters)\n                slug = (\n                    re.sub(r\"[^a-zA-Z0-9\\s-]\", \"\", heading_text)\n                    .lower()\n                    .replace(\" \", \"-\")\n                )\n                # Determine heading level for indentation\n                level = line.count(\"#\") - 1\n                # Add to the table of contents\n                toc.append(\"  \" * level + f\"- [{heading_text}](#{slug})\")\n        return \"\\n\".join(toc)\n\n    @classmethod\n    def from_markdown(cls, text: str, expander=None):\n        self = cls()\n        for line in text.splitlines():\n            if line.startswith(\"###\"):\n                self.h3(line[3:], write=False)\n            elif line.startswith(\"##\"):\n                self.h2(line[2:], write=False)\n            elif line.startswith(\"#\"):\n                self.h1(line[1:], write=False)\n        # customize markdown font size\n        custom_css = \"\"\"\n        <style>\n            /* Adjust the font size for headings */\n            h1 { font-size: 28px; }\n            h2 { font-size: 24px; }\n            h3 { font-size: 22px; }\n     
       h4 { font-size: 20px; }\n            h5 { font-size: 18px; }\n            /* Adjust the font size for normal text */\n            p { font-size: 18px; }\n        </style>\n        \"\"\"\n        st.markdown(custom_css, unsafe_allow_html=True)\n\n        st.write(text)\n        self.toc(expander=expander)\n\n\ndef normalize(s):\n    \"\"\"\n    Normalize titles as valid HTML ids for anchors\n    >>> normalize(\"it's a test to spot how Things happ3n héhé\")\n    \"it-s-a-test-to-spot-how-things-happ3n-h-h\"\n    \"\"\"\n\n    # Replace accents with \"-\"\n    s_wo_accents = unidecode.unidecode(s)\n    accents = [s for s in s if s not in s_wo_accents]\n    for accent in accents:\n        s = s.replace(accent, \"-\")\n\n    # Lowercase\n    s = s.lower()\n\n    # Keep only alphanum and remove \"-\" suffix if existing\n    normalized = (\n        \"\".join([char if char.isalnum() else \"-\" for char in s]).strip(\"-\").lower()\n    )\n\n    return normalized\n"
  },
  {
    "path": "frontend/demo_light/storm.py",
    "content": "import os\n\nscript_dir = os.path.dirname(os.path.abspath(__file__))\nwiki_root_dir = os.path.dirname(os.path.dirname(script_dir))\n\nimport demo_util\nfrom pages_util import MyArticles, CreateNewArticle\nfrom streamlit_float import *\nfrom streamlit_option_menu import option_menu\n\n\ndef main():\n    global database\n    st.set_page_config(layout=\"wide\")\n\n    if \"first_run\" not in st.session_state:\n        st.session_state[\"first_run\"] = True\n\n    # set api keys from secrets\n    if st.session_state[\"first_run\"]:\n        for key, value in st.secrets.items():\n            if type(value) == str:\n                os.environ[key] = value\n\n    # initialize session_state\n    if \"selected_article_index\" not in st.session_state:\n        st.session_state[\"selected_article_index\"] = 0\n    if \"selected_page\" not in st.session_state:\n        st.session_state[\"selected_page\"] = 0\n    if st.session_state.get(\"rerun_requested\", False):\n        st.session_state[\"rerun_requested\"] = False\n        st.rerun()\n\n    st.write(\n        \"<style>div.block-container{padding-top:2rem;}</style>\", unsafe_allow_html=True\n    )\n    menu_container = st.container()\n    with menu_container:\n        pages = [\"My Articles\", \"Create New Article\"]\n        styles = {\n            \"container\": {\"padding\": \"0.2rem 0\", \"background-color\": \"#22222200\"},\n        }\n        menu_selection = option_menu(\n            None,\n            pages,\n            icons=[\"house\", \"search\"],\n            menu_icon=\"cast\",\n            default_index=0,\n            orientation=\"horizontal\",\n            manual_select=st.session_state.selected_page,\n            styles=styles,\n            key=\"menu_selection\",\n        )\n        if st.session_state.get(\"manual_selection_override\", False):\n            menu_selection = pages[st.session_state[\"selected_page\"]]\n            st.session_state[\"manual_selection_override\"] = False\n   
         st.session_state[\"selected_page\"] = None\n\n        if menu_selection == \"My Articles\":\n            demo_util.clear_other_page_session_state(page_index=2)\n            MyArticles.my_articles_page()\n        elif menu_selection == \"Create New Article\":\n            demo_util.clear_other_page_session_state(page_index=3)\n            CreateNewArticle.create_new_article_page()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "knowledge_storm/__init__.py",
    "content": "from .storm_wiki import *\nfrom .collaborative_storm import *\nfrom .encoder import *\nfrom .interface import *\nfrom .lm import *\nfrom .rm import *\nfrom .utils import *\nfrom .dataclass import *\n\n__version__ = \"1.1.0\"\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/__init__.py",
    "content": "from .modules import *\nfrom .engine import *\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/engine.py",
    "content": "import dspy\nimport os\nfrom dataclasses import dataclass, field, asdict\nfrom typing import List, Union, Literal, Optional, Dict\n\nfrom .modules import collaborative_storm_utils as collaborative_storm_utils\nfrom .modules.callback import BaseCallbackHandler\nfrom .modules.co_storm_agents import (\n    SimulatedUser,\n    PureRAGAgent,\n    Moderator,\n    CoStormExpert,\n)\nfrom .modules.expert_generation import GenerateExpertModule\nfrom .modules.warmstart_hierarchical_chat import WarmStartModule\nfrom ..dataclass import ConversationTurn, KnowledgeBase\nfrom ..encoder import Encoder\nfrom ..interface import LMConfigs, Agent\nfrom ..logging_wrapper import LoggingWrapper\nfrom ..lm import LitellmModel\nfrom ..rm import BingSearch\n\n\nclass CollaborativeStormLMConfigs(LMConfigs):\n    \"\"\"Configurations for LLM used in different parts of Co-STORM.\n\n    Given that different parts in Co-STORM framework have different complexity, we use different LLM configurations\n    to achieve a balance between quality and efficiency. 
If no specific configuration is provided, we use the default\n    setup in the paper.\n    \"\"\"\n\n    def __init__(self):\n        self.question_answering_lm = None\n        self.discourse_manage_lm = None\n        self.utterance_polishing_lm = None\n        self.warmstart_outline_gen_lm = None\n        self.question_asking_lm = None\n        self.knowledge_base_lm = None\n\n    def init(\n        self,\n        lm_type: Literal[\"openai\", \"azure\", \"together\"],\n        temperature: Optional[float] = 1.0,\n        top_p: Optional[float] = 0.9,\n    ):\n        if lm_type and lm_type == \"openai\":\n            openai_kwargs = {\n                \"api_key\": os.getenv(\"OPENAI_API_KEY\"),\n                \"temperature\": temperature,\n                \"top_p\": top_p,\n                \"api_base\": None,\n            }\n            self.question_answering_lm = LitellmModel(\n                model=\"gpt-4o-2024-05-13\", max_tokens=1000, **openai_kwargs\n            )\n            self.discourse_manage_lm = LitellmModel(\n                model=\"gpt-4o-2024-05-13\", max_tokens=500, **openai_kwargs\n            )\n            self.utterance_polishing_lm = LitellmModel(\n                model=\"gpt-4o-2024-05-13\", max_tokens=2000, **openai_kwargs\n            )\n            self.warmstart_outline_gen_lm = LitellmModel(\n                model=\"gpt-4-1106-preview\", max_tokens=500, **openai_kwargs\n            )\n            self.question_asking_lm = LitellmModel(\n                model=\"gpt-4o-2024-05-13\", max_tokens=300, **openai_kwargs\n            )\n            self.knowledge_base_lm = LitellmModel(\n                model=\"gpt-4o-2024-05-13\", max_tokens=1000, **openai_kwargs\n            )\n        elif lm_type and lm_type == \"azure\":\n            azure_kwargs = {\n                \"api_key\": os.getenv(\"AZURE_API_KEY\"),\n                \"temperature\": temperature,\n                \"top_p\": top_p,\n                \"api_base\": 
os.getenv(\"AZURE_API_BASE\"),\n                \"api_version\": os.getenv(\"AZURE_API_VERSION\"),\n            }\n            self.question_answering_lm = LitellmModel(\n                model=\"azure/gpt-4o\", max_tokens=1000, **azure_kwargs, model_type=\"chat\"\n            )\n            self.discourse_manage_lm = LitellmModel(\n                model=\"azure/gpt-4o\", max_tokens=500, **azure_kwargs, model_type=\"chat\"\n            )\n            self.utterance_polishing_lm = LitellmModel(\n                model=\"azure/gpt-4o\", max_tokens=2000, **azure_kwargs, model_type=\"chat\"\n            )\n            self.warmstart_outline_gen_lm = LitellmModel(\n                model=\"azure/gpt-4o\", max_tokens=300, **azure_kwargs, model_type=\"chat\"\n            )\n            self.question_asking_lm = LitellmModel(\n                model=\"azure/gpt-4o\", max_tokens=300, **azure_kwargs, model_type=\"chat\"\n            )\n            self.knowledge_base_lm = LitellmModel(\n                model=\"azure/gpt-4o\", max_tokens=1000, **azure_kwargs, model_type=\"chat\"\n            )\n        elif lm_type and lm_type == \"together\":\n            together_kwargs = {\n                \"api_key\": os.getenv(\"TOGETHER_API_KEY\"),\n                \"temperature\": temperature,\n                \"top_p\": top_p,\n            }\n            self.question_answering_lm = LitellmModel(\n                model=\"together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\",\n                max_tokens=1000,\n                model_type=\"chat\",\n                **together_kwargs,\n            )\n            self.discourse_manage_lm = LitellmModel(\n                model=\"together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\",\n                max_tokens=500,\n                model_type=\"chat\",\n                **together_kwargs,\n            )\n            self.utterance_polishing_lm = LitellmModel(\n                
model=\"together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\",\n                max_tokens=2000,\n                model_type=\"chat\",\n                **together_kwargs,\n            )\n            self.warmstart_outline_gen_lm = LitellmModel(\n                model=\"together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\",\n                max_tokens=500,\n                model_type=\"chat\",\n                **together_kwargs,\n            )\n            self.question_asking_lm = LitellmModel(\n                model=\"together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\",\n                max_tokens=300,\n                model_type=\"chat\",\n                **together_kwargs,\n            )\n            self.knowledge_base_lm = LitellmModel(\n                model=\"together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\",\n                max_tokens=1000,\n                model_type=\"chat\",\n                **together_kwargs,\n            )\n        else:\n            raise Exception(\n                \"No valid OpenAI API provider is provided. 
Cannot use default LLM configurations.\"\n            )\n\n    def set_question_answering_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.question_answering_lm = model\n\n    def set_discourse_manage_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.discourse_manage_lm = model\n\n    def set_utterance_polishing_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.utterance_polishing_lm = model\n\n    def set_warmstart_outline_gen_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.warmstart_outline_gen_lm = model\n\n    def set_question_asking_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.question_asking_lm = model\n\n    def set_knowledge_base_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.knowledge_base_lm = model\n\n    def collect_and_reset_lm_usage(self):\n        lm_usage = {}\n        for attr_name in self.__dict__:\n            if \"_lm\" in attr_name and hasattr(\n                getattr(self, attr_name), \"get_usage_and_reset\"\n            ):\n                usage = getattr(self, attr_name).get_usage_and_reset()\n                if any(\n                    value[\"prompt_tokens\"] != 0 or value[\"completion_tokens\"] != 0\n                    for value in usage.values()\n                ):\n                    lm_usage[attr_name] = usage\n        return lm_usage\n\n    def to_dict(self):\n        \"\"\"\n        Converts the CollaborativeStormLMConfigs instance to a dictionary representation.\n\n        Returns:\n            dict: The dictionary representation of the CollaborativeStormLMConfigs.\n        \"\"\"\n        config_dict = {}\n        for attr_name in self.__dict__:\n            config_dict[attr_name] = getattr(self, attr_name).kwargs\n        return config_dict\n\n\n@dataclass\nclass RunnerArgument:\n    \"\"\"Arguments for controlling the STORM Wiki pipeline.\"\"\"\n\n    topic: str = field(\n        metadata={\"help\": 
\"Topic of discourse\"},\n    )\n    retrieve_top_k: int = field(\n        default=10,\n        metadata={\"help\": \"retrieve top k results for each query in retriever\"},\n    )\n    max_search_queries: int = field(\n        default=2,\n        metadata={\n            \"help\": \"Maximum number of search queries to consider for each question.\"\n        },\n    )\n    total_conv_turn: int = field(\n        default=20,\n        metadata={\"help\": \"Maximum number turn in conversation.\"},\n    )\n    max_search_thread: int = field(\n        default=5,\n        metadata={\"help\": \"Maximum number of parallel thread for retriever\"},\n    )\n    max_search_queries_per_turn: int = field(\n        default=3,\n        metadata={\"help\": \"Maximum number of search queries to consider in each turn.\"},\n    )\n    warmstart_max_num_experts: int = field(\n        default=3,\n        metadata={\n            \"help\": \"Max number of experts in perspective guided QA in warm start process\"\n        },\n    )\n    warmstart_max_turn_per_experts: int = field(\n        default=2,\n        metadata={\"help\": \"Max number of turns per perspective in warm start process\"},\n    )\n    warmstart_max_thread: int = field(\n        default=3,\n        metadata={\n            \"help\": \"Max number thread for parallel perspective guided QA in warm start process\"\n        },\n    )\n    max_thread_num: int = field(\n        default=10,\n        metadata={\n            \"help\": \"Maximum number of threads to use. 
\"\n            \"Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API.\"\n        },\n    )\n    max_num_round_table_experts: int = field(\n        default=2,\n        metadata={\"help\": \"Max number of active experts in round table discussion.\"},\n    )\n    moderator_override_N_consecutive_answering_turn: int = field(\n        default=3,\n        metadata={\n            \"help\": \"Number of consecutive experts answering turn before moderator override the conversation\"\n        },\n    )\n    node_expansion_trigger_count: int = field(\n        default=10,\n        metadata={\n            \"help\": \"Trigger node expansion for node that contain more than N snippets\"\n        },\n    )\n    disable_moderator: bool = field(\n        default=False,\n        metadata={\"help\": \"If True, disable moderator.\"},\n    )\n    disable_multi_experts: bool = field(\n        default=False,\n        metadata={\"help\": \"If True, disable moderator.\"},\n    )\n    rag_only_baseline_mode: bool = field(\n        default=False,\n        metadata={\"help\": \"If True, switch to rag online baseline mode\"},\n    )\n\n    def to_dict(self):\n        \"\"\"\n        Converts the RunnerArgument instance to a dictionary representation.\n\n        Returns:\n            dict: The dictionary representation of the RunnerArgument.\n        \"\"\"\n        return asdict(self)\n\n    @classmethod\n    def from_dict(cls, data):\n        \"\"\"\n        Constructs a RunnerArgument instance from a dictionary representation.\n\n        Args:\n            data (dict): The dictionary representation of the RunnerArgument.\n\n        Returns:\n            RunnerArgument: The constructed RunnerArgument instance.\n        \"\"\"\n        return cls(**data)\n\n\n@dataclass\nclass TurnPolicySpec:\n    \"\"\"\n    Represents the policy specifications for determining the behavior of a conversation turn.\n\n    Attributes:\n        should_reorganize_knowledge_base 
(bool):\n            A flag that indicates whether the knowledge base should be reorganized after the current turn.\n\n        should_update_experts_list (bool):\n            A flag that indicates whether the list of experts should be updated based on the conversation context.\n\n        should_polish_utterance (bool):\n            A flag that indicates whether the generated utterance should be polished (e.g., refined or rephrased) before it is used in the conversation.\n\n        agent (Agent):\n            The `Agent` responsible for generating utterances or responses during the conversation turn.\n            This agent interacts with the knowledge base and the conversation history to produce responses.\n    \"\"\"\n\n    should_reorganize_knowledge_base: bool = False\n    should_update_experts_list: bool = False\n    should_polish_utterance: bool = False\n    agent: Agent = None\n\n\nclass DiscourseManager:\n    def __init__(\n        self,\n        logging_wrapper: LoggingWrapper,\n        lm_config: CollaborativeStormLMConfigs,\n        runner_argument: RunnerArgument,\n        rm: dspy.Retrieve,\n        encoder: Encoder,\n        callback_handler: BaseCallbackHandler,\n    ):\n        # parameter management\n        self.lm_config = lm_config\n        self.runner_argument = runner_argument\n        self.logging_wrapper = logging_wrapper\n        self.callback_handler = callback_handler\n        self.rm = rm\n        self.encoder = encoder\n        # role management\n        self.experts: List[CoStormExpert] = []\n        self.simulated_user: SimulatedUser = SimulatedUser(\n            topic=self.runner_argument.topic,\n            role_name=\"Guest\",\n            role_description=\"\",\n            intent=None,\n            lm_config=self.lm_config,\n            runner_argument=self.runner_argument,\n            logging_wrapper=self.logging_wrapper,\n            callback_handler=self.callback_handler,\n        )\n        self.pure_rag_agent: PureRAGAgent = 
PureRAGAgent(\n            topic=self.runner_argument.topic,\n            role_name=\"PureRAG\",\n            role_description=\"\",\n            lm_config=self.lm_config,\n            runner_argument=self.runner_argument,\n            logging_wrapper=self.logging_wrapper,\n            rm=self.rm,\n            callback_handler=self.callback_handler,\n        )\n        self.moderator: Moderator = Moderator(\n            topic=self.runner_argument.topic,\n            role_name=\"Moderator\",\n            role_description=\"\",\n            lm_config=self.lm_config,\n            runner_argument=self.runner_argument,\n            logging_wrapper=self.logging_wrapper,\n            encoder=self.encoder,\n            callback_handler=self.callback_handler,\n        )\n        self.general_knowledge_provider = CoStormExpert(\n            topic=self.runner_argument.topic,\n            role_name=\"General Knowledge Provider\",\n            role_description=\"Focus on broadly covering the basic facts about the question.\",\n            lm_config=self.lm_config,\n            runner_argument=self.runner_argument,\n            logging_wrapper=self.logging_wrapper,\n            rm=self.rm,\n            callback_handler=self.callback_handler,\n        )\n        self.generate_expert_module = GenerateExpertModule(\n            engine=self.lm_config.discourse_manage_lm\n        )\n        self.next_turn_moderator_override = False\n\n    def serialize_experts(self) -> List[Dict]:\n        return [\n            {\n                \"topic\": expert.topic,\n                \"role_name\": expert.role_name,\n                \"role_description\": expert.role_description,\n            }\n            for expert in self.experts\n        ]\n\n    def deserialize_experts(self, data: List[Dict]):\n        for expert_data in data:\n            self.experts.append(\n                CoStormExpert(\n                    topic=expert_data[\"topic\"],\n                    
role_name=expert_data[\"role_name\"],\n                    role_description=expert_data[\"role_description\"],\n                    lm_config=self.lm_config,\n                    runner_argument=self.runner_argument,\n                    logging_wrapper=self.logging_wrapper,\n                    rm=self.rm,\n                    callback_handler=self.callback_handler,\n                )\n            )\n\n    def _should_generate_question(\n        self, conversation_history: List[ConversationTurn]\n    ) -> bool:\n        consecutive_non_questioning_turn = 0\n        for conv_turn in reversed(conversation_history):\n            if conv_turn.utterance_type not in [\n                \"Original Question\",\n                \"Information Request\",\n            ]:\n                consecutive_non_questioning_turn += 1\n            else:\n                break\n        return (\n            consecutive_non_questioning_turn\n            >= self.runner_argument.moderator_override_N_consecutive_answering_turn\n        )\n\n    def _parse_expert_names_to_agent(self, expert_descriptions: Union[str, List[str]]):\n        if isinstance(expert_descriptions, str):\n            expert_descriptions = [expert_descriptions]\n        agents: List[CoStormExpert] = []\n        for expert_name in expert_descriptions:\n            # Split on the first colon only so that a description may itself contain colons.\n            role_name, role_description = expert_name.split(\":\", 1)\n            role_name = role_name.strip()\n            role_description = role_description.strip()\n            new_costorm_expert = CoStormExpert(\n                topic=self.runner_argument.topic,\n                role_name=role_name,\n                role_description=role_description,\n                lm_config=self.lm_config,\n                runner_argument=self.runner_argument,\n                logging_wrapper=self.logging_wrapper,\n                rm=self.rm,\n                callback_handler=self.callback_handler,\n            )\n            agents.append(new_costorm_expert)\n        return agents\n\n    def 
_update_expert_list_from_utterance(self, focus: str, background_info: str):\n        expert_names = self.generate_expert_module(\n            topic=self.runner_argument.topic,\n            background_info=background_info,\n            focus=focus,\n            num_experts=self.runner_argument.max_num_round_table_experts,\n        ).experts\n        self.experts = self._parse_expert_names_to_agent(expert_names)\n\n    def _is_last_turn_questioning(self, conversation_history: List[ConversationTurn]):\n        return conversation_history and conversation_history[-1].utterance_type in [\n            \"Original Question\",\n            \"Information Request\",\n        ]\n\n    def get_next_turn_policy(\n        self,\n        conversation_history: List[ConversationTurn],\n        dry_run=False,\n        simulate_user=False,\n        simulate_user_intent: str = None,\n    ) -> TurnPolicySpec:\n        next_turn_policy = TurnPolicySpec()\n        if simulate_user:\n            self.simulated_user.intent = simulate_user_intent\n            next_turn_policy.agent = self.simulated_user\n        elif self.runner_argument.rag_only_baseline_mode:\n            # DiscourseManager does not store the history itself; use the argument passed in.\n            assert conversation_history[-1].role == \"Guest\"\n            next_turn_policy.agent = self.pure_rag_agent\n        elif self.next_turn_moderator_override:\n            next_turn_policy.agent = self.moderator\n            if not dry_run:\n                self.next_turn_moderator_override = False\n        elif (\n            not self.runner_argument.disable_moderator\n            and self._should_generate_question(conversation_history)\n        ):\n            next_turn_policy.agent = self.moderator\n            next_turn_policy.should_reorganize_knowledge_base = True\n        # experts RAG gen\n        else:\n            next_turn_policy.agent = self.general_knowledge_provider\n            if (\n                not self._is_last_turn_questioning(conversation_history)\n                and not 
self.runner_argument.disable_multi_experts\n            ):\n                if dry_run:\n                    next_turn_policy.agent = self.experts[0]\n                else:\n                    next_turn_policy.agent = self.experts.pop(0)\n                    self.experts.append(next_turn_policy.agent)\n            next_turn_policy.should_update_experts_list = (\n                self._is_last_turn_questioning(conversation_history)\n                and not self.runner_argument.disable_multi_experts\n            )\n            next_turn_policy.should_polish_utterance = True\n        return next_turn_policy\n\n\nclass CoStormRunner:\n    def __init__(\n        self,\n        lm_config: CollaborativeStormLMConfigs,\n        runner_argument: RunnerArgument,\n        logging_wrapper: LoggingWrapper,\n        rm: Optional[dspy.Retrieve] = None,\n        callback_handler: BaseCallbackHandler = None,\n    ):\n        self.runner_argument = runner_argument\n        self.lm_config = lm_config\n        self.logging_wrapper = logging_wrapper\n        self.callback_handler = callback_handler\n        if rm is None:\n            self.rm = BingSearch(k=runner_argument.retrieve_top_k)\n        else:\n            self.rm = rm\n        self.encoder = Encoder()\n        self.conversation_history = []\n        self.warmstart_conv_archive = []\n        self.knowledge_base = KnowledgeBase(\n            topic=self.runner_argument.topic,\n            knowledge_base_lm=self.lm_config.knowledge_base_lm,\n            node_expansion_trigger_count=self.runner_argument.node_expansion_trigger_count,\n            encoder=self.encoder,\n        )\n        self.discourse_manager = DiscourseManager(\n            lm_config=self.lm_config,\n            runner_argument=self.runner_argument,\n            logging_wrapper=self.logging_wrapper,\n            rm=self.rm,\n            encoder=self.encoder,\n            callback_handler=callback_handler,\n        )\n\n    def to_dict(self):\n        return {\n  
          \"runner_argument\": self.runner_argument.to_dict(),\n            \"lm_config\": self.lm_config.to_dict(),\n            \"conversation_history\": [\n                turn.to_dict() for turn in self.conversation_history\n            ],\n            \"warmstart_conv_archive\": [\n                turn.to_dict() for turn in self.warmstart_conv_archive\n            ],\n            \"experts\": self.discourse_manager.serialize_experts(),\n            \"knowledge_base\": self.knowledge_base.to_dict(),\n        }\n\n    @classmethod\n    def from_dict(cls, data, callback_handler: BaseCallbackHandler = None):\n        # FIXME: does not use the lm_config data but naively use default setting\n        lm_config = CollaborativeStormLMConfigs()\n        lm_config.init(lm_type=os.getenv(\"OPENAI_API_TYPE\"))\n        costorm_runner = cls(\n            lm_config=lm_config,\n            runner_argument=RunnerArgument.from_dict(data[\"runner_argument\"]),\n            logging_wrapper=LoggingWrapper(lm_config),\n            callback_handler=callback_handler,\n        )\n        costorm_runner.encoder = Encoder()\n        costorm_runner.conversation_history = [\n            ConversationTurn.from_dict(turn) for turn in data[\"conversation_history\"]\n        ]\n        costorm_runner.warmstart_conv_archive = [\n            ConversationTurn.from_dict(turn)\n            for turn in data.get(\"warmstart_conv_archive\", [])\n        ]\n        costorm_runner.discourse_manager.deserialize_experts(data[\"experts\"])\n        costorm_runner.knowledge_base = KnowledgeBase.from_dict(\n            data=data[\"knowledge_base\"],\n            knowledge_base_lm=costorm_runner.lm_config.knowledge_base_lm,\n            node_expansion_trigger_count=costorm_runner.runner_argument.node_expansion_trigger_count,\n            encoder=costorm_runner.encoder,\n        )\n        return costorm_runner\n\n    def warm_start(self):\n        \"\"\"\n        Warm start co-storm system to conduct 
background information search in order to build shared conceptual space with user.\n        This stage is a mini-STORM, spawning multiple LLM agent with different perspective and perform multi-round conversation.\n        The knowledge base (i.e. mind map) will be initialize using the collected information.\n\n        It will also generate a first draft of report and use it to produce an engaging and concise conversation presented to the\n        user to catch up with system's knowledge about the topic.\n        \"\"\"\n        with self.logging_wrapper.log_pipeline_stage(\n            pipeline_stage=f\"warm start stage\"\n        ):\n            if not self.runner_argument.rag_only_baseline_mode:\n                warm_start_module = WarmStartModule(\n                    lm_config=self.lm_config,\n                    runner_argument=self.runner_argument,\n                    logging_wrapper=self.logging_wrapper,\n                    rm=self.rm,\n                    callback_handler=self.callback_handler,\n                )\n\n                (\n                    warmstart_conv,\n                    warmstart_revised_conv,\n                    warmstart_experts,\n                ) = warm_start_module.initiate_warm_start(\n                    topic=self.runner_argument.topic,\n                    knowledge_base=self.knowledge_base,\n                )\n                self.discourse_manager.experts = (\n                    self.discourse_manager._parse_expert_names_to_agent(\n                        warmstart_experts\n                    )\n                )\n                self.discourse_manager.next_turn_moderator_override = True\n                self.conversation_history = (\n                    warmstart_revised_conv if warmstart_revised_conv else warmstart_conv\n                )\n                self.warmstart_conv_archive = warmstart_conv\n                self.knowledge_base.reorganize()\n            else:\n                if self.knowledge_base is None:\n   
                 self.knowledge_base = KnowledgeBase(\n                        topic=self.runner_argument.topic,\n                        knowledge_base_lm=self.lm_config.knowledge_base_lm,\n                        node_expansion_trigger_count=self.runner_argument.node_expansion_trigger_count,\n                        encoder=self.encoder,\n                    )\n                if self.conversation_history is None:\n                    self.conversation_history = []\n                conv_turn = (\n                    self.discourse_manager.pure_rag_agent.generate_topic_background()\n                )\n                self.conversation_history.append(conv_turn)\n                self.knowledge_base.update_from_conv_turn(\n                    conv_turn=conv_turn,\n                    allow_create_new_node=True,\n                    insert_under_root=self.runner_argument.rag_only_baseline_mode,\n                )\n\n    def generate_report(self) -> str:\n        \"\"\"\n        Generate report leveraging organized collected information in the knowledge base (i.e. 
mind map).\n        The article generation follows the paradigm in the STORM paper, where it considers mind map nodes as section names and generates the report section by section.\n\n        Returns:\n            str: A string representing the report, with \"#\" \"##\" indicating hierarchical sections and [1][2] indicating references.\n        \"\"\"\n        with self.logging_wrapper.log_pipeline_stage(\n            f\"report generation after conv turn: {len(self.conversation_history)}\"\n        ):\n            with self.logging_wrapper.log_event(\n                \"report generation stage: generate report\"\n            ):\n                return self.knowledge_base.to_report()\n\n    def dump_logging_and_reset(self):\n        return self.logging_wrapper.dump_logging_and_reset()\n\n    def step(\n        self,\n        user_utterance: str = \"\",\n        simulate_user: bool = False,\n        simulate_user_intent: str = \"\",\n    ) -> ConversationTurn:\n        \"\"\"\n        Produces a single turn in the conversation flow.\n\n        This method takes a user input when the user chooses to inject an utterance, or generates the next system utterance based on the current conversation history and defined discourse policies.\n        It handles updating the conversation history, managing expert lists, and interacting with the knowledge base.\n        Additionally, it logs each stage of the conversation for monitoring and debugging purposes.\n\n        Args:\n            user_utterance (str, optional): The input provided by the user. If provided, this utterance is added directly to the conversation history and the method returns with no further action.\n            simulate_user (bool, optional): This is designed for automatic experiments using an LLM agent to simulate user actions. Flag indicating whether to simulate user behavior. When set to `True`, the system will generate user intents based on predefined simulation logic. 
Defaults to `False`.\n            simulate_user_intent (str, optional): This is designed for automatic experiments using an LLM agent to simulate user actions. Specifies the intent to simulate for the user. This is used when `simulate_user` is `True` to guide the simulated user's responses.\n\n        Returns:\n            ConversationTurn: An object representing the latest turn in the conversation.\n\n        Workflow:\n            1. User Utterance Handling\n                - If `user_utterance` is provided, it is appended to the `conversation_history`.\n\n            2. System Utterance Generation\n                - If no `user_utterance` is provided, the method proceeds to generate the next system utterance.\n                - Determines the next turn policy by consulting the `discourse_manager` with the current conversation history.\n                - Generates a new utterance using the agent defined in the turn policy, leveraging the `knowledge_base` and `conversation_history`.\n                - If the turn policy indicates that the experts list should be updated, it updates the expert list based on the latest utterances.\n\n            3. 
Knowledge Base Update\n                - Inserts the new turn into the `knowledge_base`, optionally allowing the creation of new nodes or inserting under the root based on the `rag_only_baseline_mode` flag.\n                - If the turn policy specifies, it reorganizes the `knowledge_base` to maintain optimal structure and relevance.\n        \"\"\"\n        last_conv_turn = self.conversation_history[-1]\n        cur_turn_name = f\"conv turn: {len(self.conversation_history) + 1}\"\n        with self.logging_wrapper.log_pipeline_stage(\n            pipeline_stage=f\"{cur_turn_name} stage\"\n        ):\n            conv_turn = None\n            if user_utterance:\n                self.discourse_manager.next_turn_moderator_override = False\n                conv_turn = ConversationTurn(\n                    role=\"Guest\",\n                    raw_utterance=user_utterance,\n                    utterance_type=\"Original Question\",\n                )\n                self.conversation_history.append(conv_turn)\n            else:\n                with self.logging_wrapper.log_event(\n                    f\"{cur_turn_name}: get turn policy\"\n                ):\n                    if self.callback_handler is not None:\n                        self.callback_handler.on_turn_policy_planning_start()\n                    turn_policy = self.discourse_manager.get_next_turn_policy(\n                        conversation_history=self.conversation_history,\n                        simulate_user=simulate_user,\n                        simulate_user_intent=simulate_user_intent,\n                        dry_run=False,\n                    )\n\n                with self.logging_wrapper.log_event(\n                    f\"{cur_turn_name}: generate utterance\"\n                ):\n                    conv_turn = turn_policy.agent.generate_utterance(\n                        knowledge_base=self.knowledge_base,\n                        conversation_history=self.conversation_history,\n      
              )\n\n                if turn_policy.should_update_experts_list:\n                    with self.logging_wrapper.log_event(\n                        f\"{cur_turn_name}: update experts list\"\n                    ):\n                        self.discourse_manager._update_expert_list_from_utterance(\n                            focus=last_conv_turn.raw_utterance,\n                            background_info=conv_turn.raw_utterance,\n                        )\n\n                if conv_turn is not None:\n                    self.conversation_history.append(conv_turn)\n                    with self.logging_wrapper.log_event(\n                        f\"{cur_turn_name}: insert into knowledge base\"\n                    ):\n                        if self.callback_handler is not None:\n                            self.callback_handler.on_mindmap_insert_start()\n                        self.knowledge_base.update_from_conv_turn(\n                            conv_turn=conv_turn,\n                            allow_create_new_node=True,\n                            insert_under_root=self.runner_argument.rag_only_baseline_mode,\n                        )\n                        if self.callback_handler is not None:\n                            self.callback_handler.on_mindmap_insert_end()\n                if turn_policy.should_reorganize_knowledge_base:\n                    with self.logging_wrapper.log_event(\n                        f\"{cur_turn_name}: reorganize knowledge base\"\n                    ):\n                        if self.callback_handler is not None:\n                            self.callback_handler.on_mindmap_reorg_start()\n                        self.knowledge_base.reorganize()\n        return conv_turn\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/__init__.py",
    "content": "from .article_generation import *\nfrom .grounded_question_answering import *\nfrom .grounded_question_generation import *\nfrom .information_insertion_module import *\nfrom .simulate_user import *\nfrom .warmstart_hierarchical_chat import *\nfrom .knowledge_base_summary import *\nfrom .costorm_expert_utterance_generator import *\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/article_generation.py",
    "content": "import dspy\r\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\r\nfrom typing import Set, Union\r\n\r\nfrom .collaborative_storm_utils import clean_up_section\r\nfrom ...dataclass import KnowledgeBase, KnowledgeNode\r\n\r\n\r\nclass ArticleGenerationModule(dspy.Module):\r\n    \"\"\"Use the information collected from the information-seeking conversation to write a section.\"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],\r\n    ):\r\n        super().__init__()\r\n        self.write_section = dspy.Predict(WriteSection)\r\n        self.engine = engine\r\n\r\n    def _get_cited_information_string(\r\n        self,\r\n        all_citation_index: Set[int],\r\n        knowledge_base: KnowledgeBase,\r\n        max_words: int = 4000,\r\n    ):\r\n        information = []\r\n        cur_word_count = 0\r\n        for index in sorted(list(all_citation_index)):\r\n            info = knowledge_base.info_uuid_to_info_dict[index]\r\n            snippet = info.snippets[0]\r\n            info_text = f\"[{index}]: {snippet} (Question: {info.meta['question']}. 
Query: {info.meta['query']})\"\r\n            cur_snippet_length = len(info_text.split())\r\n            if cur_snippet_length + cur_word_count > max_words:\r\n                break\r\n            cur_word_count += cur_snippet_length\r\n            information.append(info_text)\r\n        return \"\\n\".join(information)\r\n\r\n    def gen_section(\r\n        self, topic: str, node: KnowledgeNode, knowledge_base: KnowledgeBase\r\n    ):\r\n        if node is None or len(node.content) == 0:\r\n            return \"\"\r\n        if (\r\n            node.synthesize_output is not None\r\n            and node.synthesize_output\r\n            and not node.need_regenerate_synthesize_output\r\n        ):\r\n            return node.synthesize_output\r\n        all_citation_index = node.collect_all_content()\r\n        information = self._get_cited_information_string(\r\n            all_citation_index=all_citation_index, knowledge_base=knowledge_base\r\n        )\r\n        with dspy.settings.context(lm=self.engine):\r\n            synthesize_output = clean_up_section(\r\n                self.write_section(\r\n                    topic=topic, info=information, section=node.name\r\n                ).output\r\n            )\r\n        node.synthesize_output = synthesize_output\r\n        node.need_regenerate_synthesize_output = False\r\n        return node.synthesize_output\r\n\r\n    def forward(self, knowledge_base: KnowledgeBase):\r\n        all_nodes = knowledge_base.collect_all_nodes()\r\n        node_to_paragraph = {}\r\n\r\n        # Define a function to generate paragraphs for nodes\r\n        def _node_generate_paragraph(node):\r\n            node_gen_paragraph = self.gen_section(\r\n                topic=knowledge_base.topic, node=node, knowledge_base=knowledge_base\r\n            )\r\n            lines = node_gen_paragraph.split(\"\\n\")\r\n            if lines[0].strip().replace(\"*\", \"\").replace(\"#\", \"\") == node.name:\r\n                lines = 
lines[1:]\r\n            node_gen_paragraph = \"\\n\".join(lines)\r\n            path = \" -> \".join(node.get_path_from_root())\r\n            return path, node_gen_paragraph\r\n\r\n        with ThreadPoolExecutor(max_workers=5) as executor:\r\n            # Submit all tasks\r\n            future_to_node = {\r\n                executor.submit(_node_generate_paragraph, node): node\r\n                for node in all_nodes\r\n            }\r\n\r\n            # Collect the results as they complete\r\n            for future in as_completed(future_to_node):\r\n                path, node_gen_paragraph = future.result()\r\n                node_to_paragraph[path] = node_gen_paragraph\r\n\r\n        def helper(cur_root, level):\r\n            to_return = []\r\n            if cur_root is not None:\r\n                hash_tag = \"#\" * level + \" \"\r\n                cur_path = \" -> \".join(cur_root.get_path_from_root())\r\n                node_gen_paragraph = node_to_paragraph[cur_path]\r\n                to_return.append(f\"{hash_tag}{cur_root.name}\\n{node_gen_paragraph}\")\r\n                for child in cur_root.children:\r\n                    to_return.extend(helper(child, level + 1))\r\n            return to_return\r\n\r\n        to_return = []\r\n        for child in knowledge_base.root.children:\r\n            to_return.extend(helper(child, level=1))\r\n\r\n        return \"\\n\".join(to_return)\r\n\r\n\r\nclass WriteSection(dspy.Signature):\r\n    \"\"\"Write a Wikipedia section based on the collected information. You will be given the topic, the section you are writing and relevant information.\r\n    Each information will be provided with the raw content along with question and query lead to that information.\r\n    Here is the format of your writing:\r\n    Use [1], [2], ..., [n] in line (for example, \"The capital of the United States is Washington, D.C.[1][3].\"). 
You DO NOT need to include a References or Sources section to list the sources at the end.\r\n    \"\"\"\r\n\r\n    info = dspy.InputField(prefix=\"The collected information:\\n\", format=str)\r\n    topic = dspy.InputField(prefix=\"The topic of the page: \", format=str)\r\n    section = dspy.InputField(prefix=\"The section you need to write: \", format=str)\r\n    output = dspy.OutputField(\r\n        prefix=\"Write the section with proper inline citations (Start your writing. Don't include the page title, section name, or try to write other sections. Do not start the section with topic name.):\\n\",\r\n        format=str,\r\n    )\r\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/callback.py",
    "content": "from typing import List\nfrom ...interface import Information\n\n\nclass BaseCallbackHandler:\n    \"\"\"Base callback handler to manage callbacks from the Co-STORM pipeline.\"\"\"\n\n    def on_turn_policy_planning_start(self, **kwargs):\n        \"\"\"Run when the turn policy planning begins, before deciding the direction or goal for the next conversation turn.\"\"\"\n        pass\n\n    def on_expert_action_planning_start(self, **kwargs):\n        \"\"\"Run when the expert action planning begins, preparing to determine the actions that each expert should take.\"\"\"\n        pass\n\n    def on_expert_action_planning_end(self, **kwargs):\n        \"\"\"Run when the expert action planning ends, after deciding the actions that each expert should take.\"\"\"\n        pass\n\n    def on_expert_information_collection_start(self, **kwargs):\n        \"\"\"Run when the expert information collection starts, start gathering all necessary data from selected sources.\"\"\"\n        pass\n\n    def on_expert_information_collection_end(self, info: List[Information], **kwargs):\n        \"\"\"Run when the expert information collection ends, after gathering all necessary data from selected sources.\"\"\"\n        pass\n\n    def on_expert_utterance_generation_end(self, **kwargs):\n        \"\"\"Run when the expert utterance generation ends, before creating responses or statements from each expert.\"\"\"\n        pass\n\n    def on_expert_utterance_polishing_start(self, **kwargs):\n        \"\"\"Run when the expert utterance polishing begins, to refine and improve the clarity and coherence of generated content.\"\"\"\n        pass\n\n    def on_mindmap_insert_start(self, **kwargs):\n        \"\"\"Run when the process of inserting new information into the mindmap starts.\"\"\"\n        pass\n\n    def on_mindmap_insert_end(self, **kwargs):\n        \"\"\"Run when the process of inserting new information into the mindmap ends.\"\"\"\n        pass\n\n    def 
on_mindmap_reorg_start(self, **kwargs):\n        \"\"\"Run when the reorganization of the mindmap begins, to restructure and optimize the flow of information.\"\"\"\n        pass\n\n    def on_expert_list_update_start(self, **kwargs):\n        \"\"\"Run when the expert list update starts, to modify or refresh the list of active experts.\"\"\"\n        pass\n\n    def on_article_generation_start(self, **kwargs):\n        \"\"\"Run when the article generation process begins, to compile and format the final article content.\"\"\"\n        pass\n\n    def on_warmstart_update(self, message, **kwargs):\n        \"\"\"Run when the warm start process has an update.\"\"\"\n        pass\n\n\nclass LocalConsolePrintCallBackHandler(BaseCallbackHandler):\n    def __init__(self):\n        pass\n\n    def on_turn_policy_planning_start(self, **kwargs):\n        \"\"\"Run when the turn policy planning begins, before deciding the direction or goal for the next conversation turn.\"\"\"\n        print(\"Start planning next expert; inspect mind map; inspect system state.\")\n\n    def on_expert_action_planning_start(self, **kwargs):\n        \"\"\"Run when the expert action planning begins, preparing to determine the actions that each expert should take.\"\"\"\n        print(\"Reviewing discourse history; Deciding utterance intent.\")\n\n    def on_expert_information_collection_start(self, **kwargs):\n        \"\"\"Run when the expert information collection starts, before gathering all necessary data from selected sources.\"\"\"\n        print(\"Start searching with the search engine; browsing collected information.\")\n\n    def on_expert_information_collection_end(self, info: List[Information], **kwargs):\n        \"\"\"Run when the expert information collection ends, after gathering all necessary data from selected sources.\"\"\"\n        if info:\n            urls = [i.url for i in info]\n            information_string = \"\\n\".join([f\"Finish browsing {url}\" for url in urls])\n        
    print(information_string)\n\n    def on_expert_utterance_generation_end(self, **kwargs):\n        \"\"\"Run when the expert utterance generation ends, after creating responses or statements from each expert.\"\"\"\n        print(\"Finish generating utterance from collected information.\")\n\n    def on_expert_utterance_polishing_start(self, **kwargs):\n        \"\"\"Run when the expert utterance polishing begins, to refine and improve the clarity and coherence of generated content.\"\"\"\n        print(\"Start polishing utterance.\")\n\n    def on_mindmap_insert_start(self, **kwargs):\n        \"\"\"Run when the process of inserting new information into the mindmap starts.\"\"\"\n        print(\"Start inserting information into mind map.\")\n\n    def on_mindmap_insert_end(self, **kwargs):\n        \"\"\"Run when the process of inserting new information into the mindmap ends.\"\"\"\n        print(\"Finish inserting information into mind map.\")\n\n    def on_mindmap_reorg_start(self, **kwargs):\n        \"\"\"Run when the reorganization of the mindmap begins, to restructure and optimize the flow of information.\"\"\"\n        print(\"Start re-organizing mind map.\")\n\n    def on_expert_list_update_start(self, **kwargs):\n        \"\"\"Run when the expert list update starts, to modify or refresh the list of active experts.\"\"\"\n        print(\"Start updating expert candidates.\")\n\n    def on_warmstart_update(self, message, **kwargs):\n        \"\"\"Run when the warm start process has an update.\"\"\"\n        print(f\"Warm start update: {message}\")\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/co_storm_agents.py",
    "content": "import dspy\nfrom itertools import zip_longest\nimport numpy as np\nfrom sklearn.metrics.pairwise import cosine_similarity\nfrom typing import List, Optional, TYPE_CHECKING\n\nfrom .callback import BaseCallbackHandler\nfrom .collaborative_storm_utils import (\n    extract_storm_info_snippet,\n    _get_answer_question_module_instance,\n)\nfrom .costorm_expert_utterance_generator import CoStormExpertUtteranceGenerationModule\nfrom .grounded_question_generation import GroundedQuestionGenerationModule\nfrom .simulate_user import GenSimulatedUserUtterance\nfrom ...dataclass import ConversationTurn, KnowledgeBase\nfrom ...encoder import Encoder\nfrom ...interface import Agent, Information, LMConfigs\nfrom ...logging_wrapper import LoggingWrapper\n\nif TYPE_CHECKING:\n    from ..engine import RunnerArgument\n\n\nclass CoStormExpert(Agent):\n    \"\"\"\n    Represents an expert agent in the Co-STORM framework.\n    The `CoStormExpert` is a specialized type of `Agent` that is tasked with participating in roundtable discussions within the Co-STORM system.\n    The expert uses language models to generate action plans, answer questions, and polish its utterances based on the current conversation history and knowledge base.\n      It interacts with modules for action planning and question answering grounding on provided retrieval models.\n\n    Args:\n        topic (str): The conversation topic that the expert specializes in.\n        role_name (str): The perspective of the expert's role (e.g. 
AI enthusiast, drug discovery expert, etc.)\n        role_description (str): A description of the perspective of the experts\n        lm_config (LMConfigs): Configuration for the language models\n        runner_argument (RunnerArgument): Co-STORM runner argument\n        logging_wrapper (LoggingWrapper): An instance of `LoggingWrapper` to log events.\n        rm (Optional[dspy.Retrieve], optional): A retrieval module used for fetching external knowledge or context.\n        callback_handler (BaseCallbackHandler, optional): Handles log message printing\n    \"\"\"\n\n    def __init__(\n        self,\n        topic: str,\n        role_name: str,\n        role_description: str,\n        lm_config: LMConfigs,\n        runner_argument: \"RunnerArgument\",\n        logging_wrapper: LoggingWrapper,\n        rm: Optional[dspy.Retrieve] = None,\n        callback_handler: BaseCallbackHandler = None,\n    ):\n        super().__init__(topic, role_name, role_description)\n        self.lm_config = lm_config\n        self.runner_argument = runner_argument\n        self.logging_wrapper = logging_wrapper\n        self.callback_handler = callback_handler\n        self.costorm_agent_utterance_generator = (\n            self._get_costorm_expert_utterance_generator(rm=rm)\n        )\n\n    def _get_costorm_expert_utterance_generator(\n        self, rm: Optional[dspy.Retrieve] = None\n    ):\n        return CoStormExpertUtteranceGenerationModule(\n            action_planning_lm=self.lm_config.discourse_manage_lm,\n            utterance_polishing_lm=self.lm_config.utterance_polishing_lm,\n            answer_question_module=_get_answer_question_module_instance(\n                lm_config=self.lm_config,\n                runner_argument=self.runner_argument,\n                logging_wrapper=self.logging_wrapper,\n                rm=rm,\n            ),\n            logging_wrapper=self.logging_wrapper,\n            callback_handler=self.callback_handler,\n        )\n\n    def 
generate_utterance(\n        self,\n        knowledge_base: KnowledgeBase,\n        conversation_history: List[ConversationTurn],\n    ):\n        with self.logging_wrapper.log_event(\n            \"CoStormExpert generate utterance: get knowledge base summary\"\n        ):\n            if self.callback_handler is not None:\n                self.callback_handler.on_expert_action_planning_start()\n            conversation_summary = knowledge_base.get_knowledge_base_summary()\n        with self.logging_wrapper.log_event(\n            \"CoStormExpert.generate_utterance generate utterance\"\n        ):\n            last_conv_turn = conversation_history[-1]\n            conv_turn = self.costorm_agent_utterance_generator(\n                topic=self.topic,\n                current_expert=self.get_role_description(),\n                conversation_summary=conversation_summary,\n                last_conv_turn=last_conv_turn,\n            ).conversation_turn\n        with self.logging_wrapper.log_event(\n            \"CoStormExpert generate utterance: polish utterance\"\n        ):\n            if self.callback_handler is not None:\n                self.callback_handler.on_expert_utterance_polishing_start()\n            self.costorm_agent_utterance_generator.polish_utterance(\n                conversation_turn=conv_turn, last_conv_turn=last_conv_turn\n            )\n        return conv_turn\n\n\nclass SimulatedUser(Agent):\n    \"\"\"\n    SimulatedUser is a special type of Agent in Co-STORM that simulates real user interaction behavior based on the given intent.\n\n    This class can be used for automatic experiments.\n    For more information, please refer to Section 3.4 of the Co-STORM paper: https://www.arxiv.org/pdf/2408.15232\n    \"\"\"\n\n    def __init__(\n        self,\n        topic: str,\n        role_name: str,\n        role_description: str,\n        intent: str,\n        lm_config: LMConfigs,\n        runner_argument: \"RunnerArgument\",\n        logging_wrapper: 
LoggingWrapper,\n        callback_handler: BaseCallbackHandler = None,\n    ):\n        super().__init__(topic, role_name, role_description)\n        self.intent = intent\n        self.lm_config = lm_config\n        self.runner_argument = runner_argument\n        self.logging_wrapper = logging_wrapper\n        self.gen_simulated_user_utterance = GenSimulatedUserUtterance(\n            engine=self.lm_config.question_answering_lm\n        )\n        self.callback_handler = callback_handler\n\n    def generate_utterance(\n        self,\n        knowledge_base: KnowledgeBase,\n        conversation_history: List[ConversationTurn],\n    ):\n        assert (\n            self.intent is not None and self.intent\n        ), \"Simulated user intent is not initialized.\"\n\n        with self.logging_wrapper.log_event(\n            \"SimulatedUser generate utterance: generate utterance\"\n        ):\n            utterance = self.gen_simulated_user_utterance(\n                topic=self.topic, intent=self.intent, conv_history=conversation_history\n            )\n        return ConversationTurn(\n            role=\"Guest\", raw_utterance=utterance, utterance_type=\"Original Question\"\n        )\n\n\nclass Moderator(Agent):\n    \"\"\"\n    The moderator's role in the Co-STORM framework is to inject new perspectives into the conversation to avoid stagnation, repetition, or overly niche discussions.\n    This is achieved by generating questions based on unused, uncited snippets of information retrieved since the last moderator's turn.\n    The selected information is reranked according to its relevance to the conversation topic and its dissimilarity to the original question.\n    The resulting top-ranked snippets are used to generate an informed question to be presented to the conversation participants.\n\n    For more information, please refer to Section 3.5 of the Co-STORM paper: https://www.arxiv.org/pdf/2408.15232\n    \"\"\"\n\n    def __init__(\n        self,\n        topic: 
str,\n        role_name: str,\n        role_description: str,\n        lm_config: LMConfigs,\n        runner_argument: \"RunnerArgument\",\n        logging_wrapper: LoggingWrapper,\n        encoder: Encoder,\n        callback_handler: BaseCallbackHandler = None,\n    ):\n        super().__init__(topic, role_name, role_description)\n        self.lm_config = lm_config\n        self.runner_argument = runner_argument\n        self.logging_wrapper = logging_wrapper\n        self.grounded_question_generation_module = GroundedQuestionGenerationModule(\n            engine=self.lm_config.question_asking_lm\n        )\n        self.callback_handler = callback_handler\n        self.encoder = encoder\n\n    def _get_conv_turn_unused_information(\n        self, conv_turn: ConversationTurn, knowledge_base: KnowledgeBase\n    ):\n        # extract all snippets from raw retrieved information\n        raw_retrieved_info: List[Information] = conv_turn.raw_retrieved_info\n        raw_retrieved_single_snippet_info: List[Information] = []\n        for info in raw_retrieved_info:\n            for snippet_idx in range(len(info.snippets)):\n                raw_retrieved_single_snippet_info.append(\n                    extract_storm_info_snippet(info, snippet_index=snippet_idx)\n                )\n        # get all cited information\n        cited_info = list(knowledge_base.info_uuid_to_info_dict.values())\n        cited_info_hash_set = set([hash(info) for info in cited_info])\n        cited_snippets = [info.snippets[0] for info in cited_info]\n        # get list of unused information\n        unused_information: List[Information] = [\n            info\n            for info in raw_retrieved_single_snippet_info\n            if hash(info) not in cited_info_hash_set\n        ]\n        if not unused_information:\n            return []\n        # extract snippets to get embeddings\n        unused_information_snippets = [info.snippets[0] for info in unused_information]\n        # get 
embeddings\n        unused_snippets_embeddings = self.encoder.encode(\n            unused_information_snippets, max_workers=20\n        )\n        claim_embedding = self.encoder.encode(conv_turn.claim_to_make)\n        query_embedding = self.encoder.encode(conv_turn.queries)\n        cited_snippets_embedding = self.encoder.encode(cited_snippets)\n        # calculate similarity\n        query_similarities = cosine_similarity(\n            unused_snippets_embeddings, query_embedding\n        )\n        max_query_similarity = np.max(query_similarities, axis=1)\n        cited_snippets_similarity = np.max(\n            cosine_similarity(unused_snippets_embeddings, cited_snippets_embedding),\n            axis=1,\n        )\n        cited_snippets_similarity = np.clip(cited_snippets_similarity, 0, 1)\n        # use claim similarity to filter out \"real\" not useful data\n        claim_similarity = cosine_similarity(\n            unused_snippets_embeddings, claim_embedding.reshape(1, -1)\n        ).flatten()\n        claim_similarity = np.where(claim_similarity >= 0.25, 1.0, 0.0)\n        # calculate score: snippet that is close to topic but far from query\n        query_sim_weight = 0.5\n        cited_snippets_sim_weight = 1 - query_sim_weight\n        combined_scores = (\n            ((1 - max_query_similarity) ** query_sim_weight)\n            * ((1 - cited_snippets_similarity) ** cited_snippets_sim_weight)\n            * claim_similarity\n        )\n        sorted_indices = np.argsort(combined_scores)[::-1]\n        return [unused_information[idx] for idx in sorted_indices]\n\n    def _get_sorted_unused_snippets(\n        self,\n        knowledge_base: KnowledgeBase,\n        conversation_history: List[ConversationTurn],\n        last_n_conv_turn: int = 2,\n    ):\n        # get last N conv turn and batch encode all related strings\n        considered_conv_turn = []\n        batch_snippets = [self.topic]\n        for conv_turn in reversed(conversation_history):\n       
     if len(considered_conv_turn) == last_n_conv_turn:\n                break\n            if conv_turn.utterance_type == \"Questioning\":\n                break\n            considered_conv_turn.append(conv_turn)\n            batch_snippets.extend(\n                sum([info.snippets for info in conv_turn.raw_retrieved_info], [])\n            )\n            batch_snippets.append(conv_turn.claim_to_make)\n            batch_snippets.extend(conv_turn.queries)\n        self.encoder.encode(batch_snippets, max_workers=20)\n\n        # get sorted unused snippets for each turn\n        sorted_snippets = []\n        for conv_turn in considered_conv_turn:\n            sorted_snippets.append(\n                self._get_conv_turn_unused_information(\n                    conv_turn=conv_turn, knowledge_base=knowledge_base\n                )\n            )\n\n        # use round robin rule to merge these snippets\n        merged_snippets = []\n        for elements in zip_longest(*sorted_snippets, fillvalue=None):\n            merged_snippets.extend(e for e in elements if e is not None)\n        return merged_snippets\n\n    def generate_utterance(\n        self,\n        knowledge_base: KnowledgeBase,\n        conversation_history: List[ConversationTurn],\n    ):\n        with self.logging_wrapper.log_event(\n            \"Moderator generate utterance: get unused snippets\"\n        ):\n            unused_snippets: List[Information] = self._get_sorted_unused_snippets(\n                knowledge_base=knowledge_base, conversation_history=conversation_history\n            )\n        with self.logging_wrapper.log_event(\n            \"Moderator generate utterance: QuestionGeneration module\"\n        ):\n            generated_question = self.grounded_question_generation_module(\n                topic=self.topic,\n                knowledge_base=knowledge_base,\n                last_conv_turn=conversation_history[-1],\n                unused_snippets=unused_snippets,\n            )\n  
      return ConversationTurn(\n            role=self.role_name,\n            raw_utterance=generated_question.raw_utterance,\n            utterance_type=\"Original Question\",\n            utterance=generated_question.utterance,\n            cited_info=generated_question.cited_info,\n        )\n\n\nclass PureRAGAgent(Agent):\n    \"\"\"\n    PureRAGAgent only handles grounded question generation by retrieving information from the retriever based on the query.\n    It does not utilize any other information besides the query itself.\n\n    It's designed for Co-STORM paper baseline comparison.\n    \"\"\"\n\n    def __init__(\n        self,\n        topic: str,\n        role_name: str,\n        role_description: str,\n        lm_config: LMConfigs,\n        runner_argument: \"RunnerArgument\",\n        logging_wrapper: LoggingWrapper,\n        rm: Optional[dspy.Retrieve] = None,\n        callback_handler: BaseCallbackHandler = None,\n    ):\n        super().__init__(topic, role_name, role_description)\n        self.lm_config = lm_config\n        self.runner_argument = runner_argument\n        self.logging_wrapper = logging_wrapper\n        self.grounded_question_answering_module = _get_answer_question_module_instance(\n            lm_config=self.lm_config,\n            runner_argument=self.runner_argument,\n            logging_wrapper=self.logging_wrapper,\n            rm=rm,\n        )\n\n    def _gen_utterance_from_question(self, question: str):\n        grounded_answer = self.grounded_question_answering_module(\n            topic=self.topic,\n            question=question,\n            mode=\"brief\",\n            style=\"conversational and concise\",\n        )\n        conversation_turn = ConversationTurn(\n            role=self.role_name, raw_utterance=\"\", utterance_type=\"Potential Answer\"\n        )\n        conversation_turn.claim_to_make = question\n        conversation_turn.raw_utterance = grounded_answer.response\n        conversation_turn.utterance = 
grounded_answer.response\n        conversation_turn.queries = grounded_answer.queries\n        conversation_turn.raw_retrieved_info = grounded_answer.raw_retrieved_info\n        conversation_turn.cited_info = grounded_answer.cited_info\n        return conversation_turn\n\n    def generate_topic_background(self):\n        return self._gen_utterance_from_question(self.topic)\n\n    def generate_utterance(\n        self,\n        knowledge_base: KnowledgeBase,\n        conversation_history: List[ConversationTurn],\n    ):\n        with self.logging_wrapper.log_event(\n            \"PureRAGAgent generate utterance: generate utterance\"\n        ):\n            return self._gen_utterance_from_question(\n                question=conversation_history[-1].utterance\n            )\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/collaborative_storm_utils.py",
    "content": "import dspy\r\nimport os\r\nimport re\r\nimport sys\r\nimport toml\r\nfrom typing import List, Tuple, Dict, Optional, TYPE_CHECKING\r\n\r\nif TYPE_CHECKING:\r\n    from ..engine import RunnerArgument\r\nfrom ...interface import Information, Retriever, LMConfigs\r\nfrom ...logging_wrapper import LoggingWrapper\r\nfrom ...rm import BingSearch\r\n\r\n\r\ndef extract_storm_info_snippet(info: Information, snippet_index: int) -> Information:\r\n    \"\"\"\r\n    Constructs a new Information instance with only the specified snippet index.\r\n\r\n    Args:\r\n        storm_info (Information): The original Information instance.\r\n        snippet_index (int): The index of the snippet to retain.\r\n\r\n    Returns:\r\n        Information: A new Information instance with only the specified snippet.\r\n    \"\"\"\r\n    if snippet_index < 0 or snippet_index >= len(info.snippets):\r\n        raise ValueError(\"Snippet index out of range\")\r\n\r\n    new_snippets = [info.snippets[snippet_index]]\r\n    new_storm_info = Information(\r\n        info.url, info.description, new_snippets, info.title, info.meta\r\n    )\r\n    return new_storm_info\r\n\r\n\r\ndef format_search_results(\r\n    searched_results: List[Information],\r\n    info_max_num_words: int = 1000,\r\n    mode: str = \"brief\",\r\n) -> Tuple[str, Dict[int, Information]]:\r\n    \"\"\"\r\n    Constructs a string from a list of search results with a specified word limit and returns a mapping of indices to Information.\r\n\r\n    Args:\r\n        searched_results (List[Information]): List of Information objects to process.\r\n        info_max_num_words (int, optional): Maximum number of words allowed in the output string. Defaults to 1000.\r\n        mode (str, optional): Mode of summarization. 'brief' takes only the first snippet of each Information.\r\n                                'extensive' adds snippets iteratively until the word limit is reached. 
Defaults to 'brief'.\r\n\r\n    Returns:\r\n        Tuple[str, Dict[int, Information]]:\r\n            - Formatted string with search results, constrained by the word limit.\r\n            - Dictionary mapping indices to the corresponding Information objects.\r\n    \"\"\"\r\n    total_length = 0\r\n\r\n    extracted_snippet_queue = []\r\n    max_snippets = (\r\n        max(len(info.snippets) for info in searched_results) if searched_results else 0\r\n    )\r\n    max_snippets = 1 if mode == \"brief\" else max_snippets\r\n    abort = False\r\n    included_snippets = set()\r\n    for i in range(max_snippets):\r\n        for info in searched_results:\r\n            if i < len(info.snippets) and not abort:\r\n                cur_snippet = info.snippets[i]\r\n                cur_snippet_len = len(info.snippets[i].split())\r\n                if total_length + cur_snippet_len > info_max_num_words:\r\n                    abort = True\r\n                    break\r\n                if cur_snippet not in included_snippets:\r\n                    included_snippets.add(cur_snippet)\r\n                    info = extract_storm_info_snippet(info, snippet_index=i)\r\n                    extracted_snippet_queue.append(info)\r\n                    total_length += cur_snippet_len\r\n    output = []\r\n    index_mapping = {}\r\n    for idx, info in enumerate(extracted_snippet_queue):\r\n        output.append(f\"[{idx + 1}]: {info.snippets[0]}\")\r\n        index_mapping[idx + 1] = info\r\n    assert -1 not in index_mapping\r\n    return \"\\n\".join(output), index_mapping\r\n\r\n\r\ndef extract_cited_storm_info(\r\n    response: str, index_to_storm_info: Dict[int, Information]\r\n) -> Dict[int, Information]:\r\n    \"\"\"\r\n    Extracts a sub-dictionary of Information instances that are cited in the response.\r\n\r\n    Args:\r\n        response (str): The response string containing inline citations like [1], [2], etc.\r\n        index_to_storm_info (Dict[int, Information]): A 
dictionary mapping indices to Information instances.\r\n\r\n    Returns:\r\n        Dict[int, Information]: A sub-dictionary with only the indices that appear in the response.\r\n    \"\"\"\r\n    cited_indices = set(map(int, re.findall(r\"\\[(\\d+)\\]\", response)))\r\n    cited_storm_info = {\r\n        index: info\r\n        for index, info in index_to_storm_info.items()\r\n        if index in cited_indices\r\n    }\r\n    return cited_storm_info\r\n\r\n\r\ndef trim_output_after_hint(response: str, hint: str) -> str:\r\n    \"\"\"\r\n    Trims the output string to only keep the substring after the given hint (not including the hint).\r\n\r\n    Args:\r\n        response (str): The original output string.\r\n        hint (str): The hint string after which the substring should be kept.\r\n\r\n    Returns:\r\n        str: The trimmed output string, or the original string if the hint is not found.\r\n    \"\"\"\r\n    if hint in response:\r\n        start_index = response.find(hint) + len(hint)\r\n        return response[start_index:].strip()\r\n    return response.strip(\"\\n\")\r\n\r\n\r\ndef separate_citations(text: str) -> str:\r\n    \"\"\"\r\n    Separates multiple citations within square brackets into individual citations.\r\n\r\n    Args:\r\n        text (str): The input string containing citations.\r\n\r\n    Returns:\r\n        str: The string with separated citations.\r\n    \"\"\"\r\n\r\n    # Define a function to process each match\r\n    def replace_citations(match):\r\n        citations = match.group(1).split(\",\")\r\n        return \"\".join(f\"[{citation.strip()}]\" for citation in citations)\r\n\r\n    # Use regular expressions to find and replace citations\r\n    pattern = re.compile(r\"\\[(\\d+(?:,\\s*\\d+)*)\\]\")\r\n    return pattern.sub(replace_citations, text)\r\n\r\n\r\ndef extract_and_remove_citations(text: str) -> Tuple[str, List[int]]:\r\n    \"\"\"\r\n    Removes single inline citations from the input string and returns the modified 
string and a list of citation integers.\r\n\r\n    Args:\r\n        text (str): The input string containing citations.\r\n\r\n    Returns:\r\n        Tuple[str, List[int]]: The string after removal of citations and a list of citation integers.\r\n    \"\"\"\r\n    citations = []\r\n\r\n    # Define a function to process each match\r\n    def extract_citation(match):\r\n        citation = int(match.group(1))\r\n        citations.append(citation)\r\n        return \"\"\r\n\r\n    # Use regular expressions to find and replace citations\r\n    pattern = re.compile(r\"\\[(\\d+)\\]\")\r\n    modified_text = pattern.sub(extract_citation, text)\r\n\r\n    return modified_text, citations\r\n\r\n\r\ndef keep_first_and_last_paragraph(text: str) -> str:\r\n    \"\"\"\r\n    Processes the input text to keep the first and last paragraphs and replace\r\n    the middle paragraphs with '[content omitted due to space limit]'.\r\n\r\n    Args:\r\n        text (str): The input text containing paragraphs separated by '\\n\\n'.\r\n\r\n    Returns:\r\n        str: The processed text.\r\n    \"\"\"\r\n    paragraphs = text.split(\"\\n\\n\")\r\n\r\n    if len(paragraphs) <= 3:\r\n        return text\r\n\r\n    first_paragraph = paragraphs[0]\r\n    last_paragraph = \"\\n\\n\".join(paragraphs[-2:])\r\n    return (\r\n        f\"{first_paragraph}\\n\\n[content omitted due to space limit]\\n\\n{last_paragraph}\"\r\n    )\r\n\r\n\r\ndef clean_up_section(text):\r\n    \"\"\"Clean up a section:\r\n    1. Remove uncompleted sentences (usually due to output token limitation).\r\n    2. Deduplicate individual groups of citations.\r\n    3. 
Remove unnecessary summary.\"\"\"\r\n\r\n    paragraphs = text.split(\"\\n\")\r\n    output_paragraphs = []\r\n    summary_sec_flag = False\r\n    for p in paragraphs:\r\n        p = p.strip()\r\n        if len(p) == 0:\r\n            continue\r\n        if not p.startswith(\"#\"):\r\n            p = separate_citations(p)\r\n        if summary_sec_flag:\r\n            if p.startswith(\"#\"):\r\n                summary_sec_flag = False\r\n            else:\r\n                continue\r\n        if (\r\n            p.startswith(\"Overall\")\r\n            or p.startswith(\"In summary\")\r\n            or p.startswith(\"In conclusion\")\r\n        ):\r\n            continue\r\n        if \"# Summary\" in p or \"# Conclusion\" in p:\r\n            summary_sec_flag = True\r\n            continue\r\n        output_paragraphs.append(p)\r\n\r\n    return \"\\n\\n\".join(output_paragraphs)  # Join with '\\n\\n' for markdown format.\r\n\r\n\r\ndef load_api_key(toml_file_path):\r\n    try:\r\n        with open(toml_file_path, \"r\") as file:\r\n            data = toml.load(file)\r\n    except FileNotFoundError:\r\n        print(f\"File not found: {toml_file_path}\", file=sys.stderr)\r\n        return\r\n    except toml.TomlDecodeError:\r\n        print(f\"Error decoding TOML file: {toml_file_path}\", file=sys.stderr)\r\n        return\r\n    # Set environment variables\r\n    for key, value in data.items():\r\n        os.environ[key] = str(value)\r\n\r\n\r\ndef _get_answer_question_module_instance(\r\n    lm_config: LMConfigs,\r\n    runner_argument: \"RunnerArgument\",\r\n    logging_wrapper: LoggingWrapper,\r\n    rm: Optional[dspy.Retrieve] = None,\r\n):\r\n    from .grounded_question_answering import AnswerQuestionModule\r\n\r\n    # configure retriever\r\n    if rm is None:\r\n        rm = BingSearch(k=runner_argument.retrieve_top_k)\r\n    retriever = Retriever(rm=rm, max_thread=runner_argument.max_search_thread)\r\n    # return AnswerQuestionModule instance\r\n    
return AnswerQuestionModule(\r\n        retriever=retriever,\r\n        max_search_queries=runner_argument.max_search_queries,\r\n        question_answering_lm=lm_config.question_answering_lm,\r\n        logging_wrapper=logging_wrapper,\r\n    )\r\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/costorm_expert_utterance_generator.py",
    "content": "import dspy\nfrom typing import Union\n\nfrom .callback import BaseCallbackHandler\nfrom .collaborative_storm_utils import (\n    trim_output_after_hint,\n    extract_and_remove_citations,\n    keep_first_and_last_paragraph,\n)\n\nfrom .grounded_question_answering import AnswerQuestionModule\nfrom .grounded_question_generation import ConvertUtteranceStyle\nfrom ...dataclass import ConversationTurn\nfrom ...logging_wrapper import LoggingWrapper\n\n\nclass GenExpertActionPlanning(dspy.Signature):\n    \"\"\"\n    You are an invited speaker in the round table conversation. Your task is to make a very short note to your assistant to help you prepare for your turn in the conversation.\n    You will be given the topic we are discussing, your expertise, and the conversation history.\n    Take a look at the conversation history, especially the last few turns, then let your assistant prepare the material for you in one of the following ways.\n    1. Original Question: Initiates a new question to other speakers.\n    2. Further Details: Provides additional information.\n    3. Information Request: Requests information from other speakers.\n    4. Potential Answer: Offers a possible solution or answer.\n\n    Strictly follow this format: [type of contribution]: [one sentence description]. For example, Original Question: [description]\n    \"\"\"\n\n    topic = dspy.InputField(prefix=\"topic of discussion: \", format=str)\n    expert = dspy.InputField(prefix=\"You are invited as: \", format=str)\n    summary = dspy.InputField(prefix=\"Discussion history: \\n\", format=str)\n    last_utterance = dspy.InputField(\n        prefix=\"Last utterance in the conversation: \\n\", format=str\n    )\n    resposne = dspy.OutputField(\n        prefix=\"Now give your note. 
Start with one of [Original Question, Further Details, Information Request, Potential Answer] with one sentence description\\n\",\n        format=str,\n    )\n\n\nclass CoStormExpertUtteranceGenerationModule(dspy.Module):\n    def __init__(\n        self,\n        action_planning_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        utterance_polishing_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        answer_question_module: AnswerQuestionModule,\n        logging_wrapper: LoggingWrapper,\n        callback_handler: BaseCallbackHandler = None,\n    ):\n        self.action_planning_lm = action_planning_lm\n        self.utterance_polishing_lm = utterance_polishing_lm\n        self.expert_action = dspy.Predict(GenExpertActionPlanning)\n        self.change_style = dspy.Predict(ConvertUtteranceStyle)\n        self.answer_question_module = answer_question_module\n        self.logging_wrapper = logging_wrapper\n        self.callback_handler = callback_handler\n\n    def parse_action(self, action):\n        action_types = [\n            \"Original Question\",\n            \"Further Details\",\n            \"Information Request\",\n            \"Potential Answer\",\n        ]\n        for action_type in action_types:\n            if f\"{action_type}:\" in action:\n                return action_type, trim_output_after_hint(action, f\"{action_type}:\")\n            elif f\"[{action_type}]:\" in action:\n                return action_type, trim_output_after_hint(action, f\"[{action_type}]:\")\n        return \"Undefined\", \"\"\n\n    def polish_utterance(\n        self, conversation_turn: ConversationTurn, last_conv_turn: ConversationTurn\n    ):\n        # change utterance style\n        action_type = conversation_turn.utterance_type\n        with self.logging_wrapper.log_event(\n            \"RoundTableConversationModule.ConvertUtteranceStyle\"\n        ):\n            with dspy.settings.context(\n                lm=self.utterance_polishing_lm, show_guidelines=False\n           
 ):\n                action_string = (\n                    f\"{action_type} about: {conversation_turn.claim_to_make}\"\n                )\n                if action_type in [\"Original Question\", \"Information Request\"]:\n                    action_string = f\"{action_type}\"\n                last_expert_utterance_wo_citation, _ = extract_and_remove_citations(\n                    last_conv_turn.utterance\n                )\n                trimmed_last_expert_utterance = keep_first_and_last_paragraph(\n                    last_expert_utterance_wo_citation\n                )\n                utterance = self.change_style(\n                    expert=conversation_turn.role,\n                    action=action_string,\n                    prev=trimmed_last_expert_utterance,\n                    content=conversation_turn.raw_utterance,\n                ).utterance\n            conversation_turn.utterance = utterance\n\n    def forward(\n        self,\n        topic: str,\n        current_expert: str,\n        conversation_summary: str,\n        last_conv_turn: ConversationTurn,\n    ):\n        last_utterance, _ = extract_and_remove_citations(last_conv_turn.utterance)\n        if last_conv_turn.utterance_type in [\n            \"Original Question\",\n            \"Information Request\",\n        ]:\n            action_type = \"Potential Answer\"\n            action_content = last_utterance\n        else:\n            with self.logging_wrapper.log_event(\n                \"CoStormExpertUtteranceGenerationModule: GenExpertActionPlanning\"\n            ):\n                with dspy.settings.context(\n                    lm=self.action_planning_lm, show_guidelines=False\n                ):\n                    action = self.expert_action(\n                        topic=topic,\n                        expert=current_expert,\n                        summary=conversation_summary,\n                        last_utterance=last_utterance,\n                    ).resposne\n      
          action_type, action_content = self.parse_action(action)\n\n        if self.callback_handler is not None:\n            self.callback_handler.on_expert_action_planning_end()\n        # get response\n        conversation_turn = ConversationTurn(\n            role=current_expert, raw_utterance=\"\", utterance_type=action_type\n        )\n\n        if action_type == \"Undefined\":\n            raise Exception(f\"unexpected output: {action}\")\n        elif action_type in [\"Further Details\", \"Potential Answer\"]:\n            with self.logging_wrapper.log_event(\n                \"RoundTableConversationModule: QuestionAnswering\"\n            ):\n                grounded_answer = self.answer_question_module(\n                    topic=topic,\n                    question=action_content,\n                    mode=\"brief\",\n                    style=\"conversational and concise\",\n                    callback_handler=self.callback_handler,\n                )\n            conversation_turn.claim_to_make = action_content\n            conversation_turn.raw_utterance = grounded_answer.response\n            conversation_turn.queries = grounded_answer.queries\n            conversation_turn.raw_retrieved_info = grounded_answer.raw_retrieved_info\n            conversation_turn.cited_info = grounded_answer.cited_info\n        elif action_type in [\"Original Question\", \"Information Request\"]:\n            conversation_turn.raw_utterance = action_content\n\n        return dspy.Prediction(conversation_turn=conversation_turn)\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/expert_generation.py",
    "content": "import dspy\r\nimport re\r\nfrom typing import Union\r\n\r\n\r\nclass GenerateExpertGeneral(dspy.Signature):\r\n    \"\"\"You need to select a group of diverse experts who will be suitable to be invited to a roundtable discussion on the given topic.\r\n    Each expert should represent a different perspective, role, or affiliation related to this topic.\r\n    You can use the background information provided about the topic for inspiration. For each expert, add a description of their expertise and what they will focus on during the discussion.\r\n    No need to include speakers name in the output.\r\n    Strictly follow format below:\r\n    1. [speaker 1 role]: [speaker 1 short description]\r\n    2. [speaker 2 role]: [speaker 2 short description]\r\n    \"\"\"\r\n\r\n    topic = dspy.InputField(prefix=\"Topic of interest:\", format=str)\r\n    background_info = dspy.InputField(\r\n        prefix=\"Background information about the topic:\\n\", format=str\r\n    )\r\n    topN = dspy.InputField(prefix=\"Number of speakers needed: \", format=str)\r\n    experts = dspy.OutputField(format=str)\r\n\r\n\r\nclass GenerateExpertWithFocus(dspy.Signature):\r\n    \"\"\"\r\n    You need to select a group of speakers who will be suitable to have roundtable discussion on the [topic] of specific [focus].\r\n    You may consider inviting speakers having opposite stands on the topic; speakers representing different interest parties; Ensure that the selected speakers are directly connected to the specific context and scenario provided.\r\n    For example, if the discussion focus is about a recent event at a specific university, consider inviting students, faculty members, journalists covering the event, university officials, and local community members.\r\n    Use the background information provided about the topic for inspiration. 
For each speaker, add a description of their interests and what they will focus on during the discussion.\r\n    No need to include speakers name in the output.\r\n    Strictly follow format below:\r\n    1. [speaker 1 role]: [speaker 1 short description]\r\n    2. [speaker 2 role]: [speaker 2 short description]\r\n    \"\"\"\r\n\r\n    topic = dspy.InputField(prefix=\"Topic of interest:\", format=str)\r\n    background_info = dspy.InputField(prefix=\"Background information:\\n\", format=str)\r\n    focus = dspy.InputField(prefix=\"Discussion focus: \", format=str)\r\n    topN = dspy.InputField(prefix=\"Number of speakers needed: \", format=str)\r\n    experts = dspy.OutputField(format=str)\r\n\r\n\r\nclass GenerateExpertModule(dspy.Module):\r\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\r\n        self.engine = engine\r\n        self.generate_expert_general = dspy.Predict(GenerateExpertGeneral)\r\n        self.generate_expert_w_focus = dspy.ChainOfThought(GenerateExpertWithFocus)\r\n\r\n    def trim_background(self, background: str, max_words: int = 100):\r\n        words = background.split()\r\n        cur_len = len(words)\r\n        if cur_len <= max_words:\r\n            return background\r\n        trimmed_words = words[: min(cur_len, max_words)]\r\n        trimmed_background = \" \".join(trimmed_words)\r\n        return f\"{trimmed_background} [rest content omitted].\"\r\n\r\n    def forward(\r\n        self, topic: str, num_experts: int, background_info: str = \"\", focus: str = \"\"\r\n    ):\r\n        with dspy.settings.context(lm=self.engine, show_guidelines=False):\r\n            if not focus:\r\n                output = self.generate_expert_general(\r\n                    topic=topic, background_info=background_info, topN=num_experts\r\n                ).experts\r\n            else:\r\n                background_info = self.trim_background(\r\n                    background=background_info, max_words=100\r\n                
)\r\n                output = self.generate_expert_w_focus(\r\n                    topic=topic,\r\n                    background_info=background_info,\r\n                    focus=focus,\r\n                    topN=num_experts,\r\n                ).experts\r\n        output = output.replace(\"*\", \"\").replace(\"[\", \"\").replace(\"]\", \"\")\r\n        expert_list = []\r\n        for s in output.split(\"\\n\"):\r\n            match = re.search(r\"\\d+\\.\\s*(.*)\", s)\r\n            if match:\r\n                expert_list.append(match.group(1))\r\n        expert_list = [expert.strip() for expert in expert_list if expert.strip()]\r\n        return dspy.Prediction(experts=expert_list, raw_output=output)\r\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/grounded_question_answering.py",
    "content": "import dspy\nfrom typing import Union, List\n\nfrom .callback import BaseCallbackHandler\nfrom .collaborative_storm_utils import (\n    trim_output_after_hint,\n    format_search_results,\n    extract_cited_storm_info,\n    separate_citations,\n)\nfrom ...logging_wrapper import LoggingWrapper\nfrom ...utils import ArticleTextProcessing\nfrom ...interface import Information\n\n\nclass QuestionToQuery(dspy.Signature):\n    \"\"\"You want to answer the question or support a claim using Google search. What do you type in the search box?\n    The question is raised in a round table discussion on a topic. The question may or may not focus on the topic itself.\n    Write the queries you will use in the following format:\n    - query 1\n    - query 2\n    ...\n    - query n\"\"\"\n\n    topic = dspy.InputField(prefix=\"Topic context:\", format=str)\n    question = dspy.InputField(\n        prefix=\"I want to collect information about: \", format=str\n    )\n    queries = dspy.OutputField(prefix=\"Queries: \\n\", format=str)\n\n\nclass AnswerQuestion(dspy.Signature):\n    \"\"\"You are an expert who can use information effectively. You have gathered the related information and will now use the information to form a response.\n    Make your response as informative as possible and make sure every sentence is supported by the gathered information.\n    If [Gathered information] is not directly related to the [Topic] and [Question], provide the most relevant answer you can based on the available information, and explain any limitations or gaps.\n    Use [1], [2], ..., [n] in line (for example, \"The capital of the United States is Washington, D.C.[1][3].\").\n    You DO NOT need to include a References or Sources section to list the sources at the end. 
The style of writing should be formal.\n    \"\"\"\n\n    topic = dspy.InputField(prefix=\"Topic you are discussing about:\", format=str)\n    question = dspy.InputField(prefix=\"You want to provide insight on: \", format=str)\n    info = dspy.InputField(prefix=\"Gathered information:\\n\", format=str)\n    style = dspy.InputField(prefix=\"Style of your response should be:\", format=str)\n    answer = dspy.OutputField(\n        prefix=\"Now give your response. (Try to use as many different sources as possible and do not hallucinate.)\",\n        format=str,\n    )\n\n\nclass AnswerQuestionModule(dspy.Module):\n    def __init__(\n        self,\n        retriever: dspy.Retrieve,\n        max_search_queries: int,\n        question_answering_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        logging_wrapper: LoggingWrapper,\n    ):\n        super().__init__()\n        self.question_answering_lm = question_answering_lm\n        self.question_to_query = dspy.Predict(QuestionToQuery)\n        self.answer_question = dspy.Predict(AnswerQuestion)\n        self.retriever = retriever\n        self.max_search_queries = max_search_queries\n        self.logging_wrapper = logging_wrapper\n\n    def retrieve_information(self, topic, question):\n        # decompose question to queries\n        with self.logging_wrapper.log_event(\n            f\"AnswerQuestionModule.question_to_query ({hash(question)})\"\n        ):\n            with dspy.settings.context(lm=self.question_answering_lm):\n                queries = self.question_to_query(topic=topic, question=question).queries\n            queries = trim_output_after_hint(queries, hint=\"Queries:\")\n            queries = [\n                q.replace(\"-\", \"\").strip().strip('\"').strip('\"').strip()\n                for q in queries.split(\"\\n\")\n            ]\n            queries = queries[: self.max_search_queries]\n        self.logging_wrapper.add_query_count(count=len(queries))\n        with self.logging_wrapper.log_event(\n  
          f\"AnswerQuestionModule.retriever.retrieve ({hash(question)})\"\n        ):\n            # retrieve information using retriever\n            searched_results: List[Information] = self.retriever.retrieve(\n                list(set(queries)), exclude_urls=[]\n            )\n        # update storm information meta to include the question\n        for storm_info in searched_results:\n            storm_info.meta[\"question\"] = question\n        return queries, searched_results\n\n    def forward(\n        self,\n        topic: str,\n        question: str,\n        mode: str = \"brief\",\n        style: str = \"conversational\",\n        callback_handler: BaseCallbackHandler = None,\n    ):\n        \"\"\"\n        Processes a topic and question to generate a response with relevant information and citations.\n\n        Args:\n            topic (str): The topic of interest.\n            question (str): The specific question related to the topic.\n            mode (str, optional): Mode of summarization. 'brief' takes only the first snippet of each Information.\n                                'extensive' adds snippets iteratively until the word limit is reached. 
Defaults to 'brief'.\n\n        Returns:\n            dspy.Prediction: An object containing the following:\n                - question (str): the question to answer\n                - queries (List[str]): List of query strings used for information retrieval.\n                - raw_retrieved_info (List[Information]): List of Information instances retrieved.\n                - cited_info (Dict[int, Information]): Dictionary of cited Information instances, indexed by their citation number.\n                - response (str): The generated response string with inline citations.\n        \"\"\"\n        # retrieve information\n        if callback_handler is not None:\n            callback_handler.on_expert_information_collection_start()\n        queries, searched_results = self.retrieve_information(\n            topic=topic, question=question\n        )\n        if callback_handler is not None:\n            callback_handler.on_expert_information_collection_end(searched_results)\n        # format information string for answer generation\n        info_text, index_to_information_mapping = format_search_results(\n            searched_results, mode=mode\n        )\n        answer = \"Sorry, there is insufficient information to answer the question.\"\n        # generate answer to the question\n        if info_text:\n            with self.logging_wrapper.log_event(\n                f\"AnswerQuestionModule.answer_question ({hash(question)})\"\n            ):\n                with dspy.settings.context(\n                    lm=self.question_answering_lm, show_guidelines=False\n                ):\n                    answer = self.answer_question(\n                        topic=topic, question=question, info=info_text, style=style\n                    ).answer\n                    answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(\n                        answer\n                    )\n                    answer = trim_output_after_hint(\n                  
      answer,\n                        hint=\"Now give your response. (Try to use as many different sources as possible and do not hallucinate.)\",\n                    )\n                    # enforce single citation index bracket. [1, 2] -> [1][2]\n                    answer = separate_citations(answer)\n                    if callback_handler is not None:\n                        callback_handler.on_expert_utterance_generation_end()\n        # construct cited search result\n        cited_searched_results = extract_cited_storm_info(\n            response=answer, index_to_storm_info=index_to_information_mapping\n        )\n\n        return dspy.Prediction(\n            question=question,\n            queries=queries,\n            raw_retrieved_info=searched_results,\n            cited_info=cited_searched_results,\n            response=answer,\n        )\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/grounded_question_generation.py",
    "content": "\"\"\"\r\nThis module handles question generation within the Co-STORM framework, specifically designed to support the Moderator role.\r\n\r\nThe Moderator generates insightful, thought-provoking questions that introduce new directions into the conversation.\r\nBy leveraging uncited or unused snippets of information retrieved during the discussion, the Moderator ensures the conversation remains dynamic and avoids repetitive or overly niche topics.\r\n\r\nFor more detailed information, refer to Section 3.5 of the Co-STORM paper: https://www.arxiv.org/pdf/2408.15232.\r\n\"\"\"\r\n\r\nimport dspy\r\nfrom typing import List, Union\r\n\r\nfrom .collaborative_storm_utils import (\r\n    format_search_results,\r\n    extract_and_remove_citations,\r\n    keep_first_and_last_paragraph,\r\n    extract_cited_storm_info,\r\n)\r\nfrom ...dataclass import ConversationTurn, KnowledgeBase\r\nfrom ...interface import Information\r\n\r\n\r\nclass KnowledgeBaseSummmary(dspy.Signature):\r\n    \"\"\"Your job is to give a brief summary of what's been discussed in a roundtable conversation. Contents are thematically organized into hierarchical sections.\r\n    You will be presented with these sections where \"#\" denotes the section level.\r\n    \"\"\"\r\n\r\n    topic = dspy.InputField(prefix=\"topic: \", format=str)\r\n    structure = dspy.InputField(prefix=\"Tree structure: \\n\", format=str)\r\n    output = dspy.OutputField(prefix=\"Now give brief summary:\\n\", format=str)\r\n\r\n\r\nclass ConvertUtteranceStyle(dspy.Signature):\r\n    \"\"\"\r\n    You are an invited speaker in the round table conversation.\r\n    Your task is to make the question or the response more conversational and engaging to facilitate the flow of conversation.\r\n    Note that this is an ongoing conversation, so there is no need for welcoming or concluding words. 
The previous speaker's utterance is provided only to make the conversation more natural.\r\n    Do not hallucinate, and keep citation indices like [1] as they are.\r\n    \"\"\"\r\n\r\n    expert = dspy.InputField(prefix=\"You are invited as: \", format=str)\r\n    action = dspy.InputField(\r\n        prefix=\"You want to contribute to conversation by: \", format=str\r\n    )\r\n    prev = dspy.InputField(prefix=\"Previous speaker said: \", format=str)\r\n    content = dspy.InputField(\r\n        prefix=\"Question or response you want to say: \", format=str\r\n    )\r\n    utterance = dspy.OutputField(\r\n        prefix=\"Your utterance (keep the information as much as you can with citations, prefer shorter answers without loss of information): \",\r\n        format=str,\r\n    )\r\n\r\n\r\nclass GroundedQuestionGeneration(dspy.Signature):\r\n    \"\"\"Your job is to find the next discussion focus in a roundtable conversation. You will be given a summary of the previous conversation and some information that might help you discover a new discussion focus.\r\n    Note that the new discussion focus should bring a new angle and perspective to the discussion and avoid repetition. 
The new discussion focus should be grounded on the available information and push the boundaries of the current discussion for broader exploration.\r\n    The new discussion focus should have natural flow from last utterance in the conversation.\r\n    Use [1][2] in line to ground your question.\r\n    \"\"\"\r\n\r\n    topic = dspy.InputField(prefix=\"topic: \", format=str)\r\n    summary = dspy.InputField(prefix=\"Discussion history: \\n\", format=str)\r\n    information = dspy.InputField(prefix=\"Available information: \\n\", format=str)\r\n    last_utterance = dspy.InputField(\r\n        prefix=\"Last utterance in the conversation: \\n\", format=str\r\n    )\r\n    output = dspy.OutputField(\r\n        prefix=\"Now give next discussion focus in the format of one sentence question:\\n\",\r\n        format=str,\r\n    )\r\n\r\n\r\nclass GroundedQuestionGenerationModule(dspy.Module):\r\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\r\n        self.engine = engine\r\n        self.gen_focus = dspy.Predict(GroundedQuestionGeneration)\r\n        self.polish_style = dspy.Predict(ConvertUtteranceStyle)\r\n        self.gen_summary = dspy.Predict(KnowledgeBaseSummmary)\r\n\r\n    def forward(\r\n        self,\r\n        topic: str,\r\n        knowledge_base: KnowledgeBase,\r\n        last_conv_turn: ConversationTurn,\r\n        unused_snippets: List[Information],\r\n    ):\r\n        information, index_to_information_mapping = format_search_results(\r\n            unused_snippets, info_max_num_words=1000\r\n        )\r\n        summary = knowledge_base.get_knowledge_base_summary()\r\n        last_utterance, _ = extract_and_remove_citations(last_conv_turn.utterance)\r\n        with dspy.settings.context(lm=self.engine, show_guidelines=False):\r\n            raw_utterance = self.gen_focus(\r\n                topic=topic,\r\n                summary=summary,\r\n                information=information,\r\n                
last_utterance=last_utterance,\r\n            ).output\r\n            utterance = self.polish_style(\r\n                expert=\"Roundtable conversation moderator\",\r\n                action=\"Raising a new question by natural transit from previous utterance.\",\r\n                prev=keep_first_and_last_paragraph(last_utterance),\r\n                content=raw_utterance,\r\n            ).utterance\r\n            cited_searched_results = extract_cited_storm_info(\r\n                response=utterance, index_to_storm_info=index_to_information_mapping\r\n            )\r\n            return dspy.Prediction(\r\n                raw_utterance=raw_utterance,\r\n                utterance=utterance,\r\n                cited_info=cited_searched_results,\r\n            )\r\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/information_insertion_module.py",
    "content": "import dspy\r\nimport numpy as np\r\nimport re\r\nimport traceback\r\n\r\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\r\nfrom sklearn.metrics.pairwise import cosine_similarity\r\nfrom typing import List, Union, Dict, Optional\r\n\r\nfrom .collaborative_storm_utils import trim_output_after_hint\r\nfrom ...dataclass import KnowledgeNode, KnowledgeBase\r\nfrom ...encoder import Encoder\r\nfrom ...interface import Information\r\n\r\n\r\nclass InsertInformation(dspy.Signature):\r\n    \"\"\"Your job is to insert the given information into the knowledge base. The knowledge base is a tree-based data structure that organizes the collected information. Each knowledge node contains information derived from thematically similar questions or intents.\r\n    To decide the best placement of the information, you will navigate this tree-based data structure layer by layer.\r\n    You will be presented with the question and query that lead to this information, and the tree structure.\r\n\r\n    Output should strictly follow one of the options presented below with no other information.\r\n    - 'insert': to place the information under the current node.\r\n    - 'step: [child node name]': to step into a specified child node.\r\n    - 'create: [new child node name]': to create a new child node and insert the info under it.\r\n\r\n    Example outputs:\r\n    - insert\r\n    - step: node2\r\n    - create: node3\r\n    \"\"\"\r\n\r\n    intent = dspy.InputField(\r\n        prefix=\"Question and query leads to this info: \", format=str\r\n    )\r\n    structure = dspy.InputField(prefix=\"Tree structure: \\n\", format=str)\r\n    choice = dspy.OutputField(prefix=\"Choice:\\n\", format=str)\r\n\r\n\r\nclass InsertInformationCandidateChoice(dspy.Signature):\r\n    \"\"\"Your job is to insert the given information into the knowledge base. The knowledge base is a tree-based data structure that organizes the collected information. 
Each knowledge node contains information derived from thematically similar questions or intents.\r\n    You will be presented with the question and query that lead to this information, and candidate choices of placement. In these choices, -> denotes a parent-child relationship. Note that a reasonable placement may not be among these choices.\r\n\r\n    If a reasonable choice exists, output \"Best placement: [choice index]\"; otherwise, output \"No reasonable choice\".\r\n    \"\"\"\r\n\r\n    intent = dspy.InputField(\r\n        prefix=\"Question and query leads to this info: \", format=str\r\n    )\r\n    choices = dspy.InputField(prefix=\"Candidate placement:\\n\", format=str)\r\n    decision = dspy.OutputField(prefix=\"Decision:\\n\", format=str)\r\n\r\n\r\nclass InsertInformationModule(dspy.Module):\r\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], encoder: Encoder):\r\n        self.engine = engine\r\n        self.encoder = encoder\r\n        self.insert_info = dspy.ChainOfThought(InsertInformation)\r\n        self.candidate_choosing = dspy.Predict(InsertInformationCandidateChoice)\r\n\r\n    def _construct_intent(self, question: str, query: str):\r\n        intent = \"\"\r\n        if query == \"Not applicable\":\r\n            return question\r\n        if question:\r\n            intent += f\"Question: {question}\\n\"\r\n        if query:\r\n            intent += f\"Query: {query}\\n\"\r\n        if not intent:\r\n            intent = \"Not available.\"\r\n        return intent\r\n\r\n    def _get_navigation_choice(\r\n        self, knowledge_node: KnowledgeNode, question: str, query: str\r\n    ):\r\n        # construct information intent\r\n        intent = self._construct_intent(question, query)\r\n        # construct current kb structure\r\n        structure = f\"Current Node: {knowledge_node.name}\\n\"\r\n        child_names = \", \".join(knowledge_node.get_children_names())\r\n        if child_names:\r\n            structure += f\"Child Nodes: 
{child_names}\"\r\n        navigated_path = \" -> \".join(knowledge_node.get_path_from_root())\r\n        structure += f\"Path you have nagivated: {navigated_path}\"\r\n\r\n        # get predicted action\r\n        with dspy.settings.context(lm=self.engine):\r\n            predicted_action = self.insert_info(\r\n                intent=intent, structure=structure\r\n            ).choice\r\n\r\n        # parse action\r\n        cleaned_predicted_action = trim_output_after_hint(\r\n            predicted_action, \"Choice:\"\r\n        ).strip()\r\n        cleaned_predicted_action = cleaned_predicted_action.strip(\"-\").strip()\r\n        if cleaned_predicted_action.startswith(\"insert\"):\r\n            return \"insert\", \"\"\r\n        elif cleaned_predicted_action.startswith(\"step:\"):\r\n            node_name = trim_output_after_hint(cleaned_predicted_action, \"step:\")\r\n            return \"step\", node_name\r\n        elif cleaned_predicted_action.startswith(\"create:\"):\r\n            node_name = trim_output_after_hint(cleaned_predicted_action, \"create:\")\r\n            return \"create\", node_name\r\n        raise Exception(\r\n            f\"Undefined predicted action in knowledge navigation. 
{predicted_action}\"\r\n        )\r\n\r\n    def layer_by_layer_navigation_placement(\r\n        self,\r\n        knowledge_base: KnowledgeBase,\r\n        question: str,\r\n        query: str,\r\n        allow_create_new_node: bool = False,\r\n        root: Optional[KnowledgeNode] = None,\r\n    ):\r\n        current_node: KnowledgeNode = knowledge_base.root if root is None else root\r\n\r\n        while True:\r\n            action_type, node_name = self._get_navigation_choice(\r\n                knowledge_node=current_node, question=question, query=query\r\n            )\r\n            if action_type == \"insert\":\r\n                return dspy.Prediction(\r\n                    information_placement=\" -> \".join(\r\n                        current_node.get_path_from_root(root)\r\n                    ),\r\n                    note=\"None\",\r\n                )\r\n            elif action_type == \"step\":\r\n                for child in current_node.children:\r\n                    if child.name == node_name:\r\n                        current_node = child\r\n                        break\r\n                else:\r\n                    raise ValueError(f\"Child node with name {node_name} not found.\")\r\n            elif action_type == \"create\":\r\n                placement_path = current_node.get_path_from_root(root)\r\n                if allow_create_new_node:\r\n                    placement_path.append(node_name)\r\n                    note = f\"create new node: {{{node_name}}} under {{{current_node.name}}}\"\r\n                else:\r\n                    note = f\"attempt to create new node: {{{node_name}}} under {{{current_node.name}}}\"\r\n                return dspy.Prediction(\r\n                    information_placement=\" -> \".join(placement_path), note=note\r\n                )\r\n            else:\r\n                raise ValueError(f\"Unknown action type: {action_type}\")\r\n\r\n    def _get_sorted_embed_sim_section(\r\n        self,\r\n       
 encoded_outline: np.ndarray,\r\n        outlines: List[str],\r\n        question: str,\r\n        query: str,\r\n    ):\r\n        if encoded_outline is not None and encoded_outline.size > 0:\r\n            encoded_query = self.encoder.encode(f\"{question}, {query}\")\r\n            sim = cosine_similarity([encoded_query], encoded_outline)[0]\r\n            sorted_indices = np.argsort(sim)\r\n            sorted_outlines = np.array(outlines)[sorted_indices[::-1]]\r\n            return sorted_outlines\r\n        else:\r\n            return outlines\r\n\r\n    def _parse_selected_index(self, string: str):\r\n        match = re.search(r\"\\[(\\d+)\\]\", string)\r\n        if match:\r\n            return int(match.group(1))\r\n        try:\r\n            return int(string.strip())\r\n        except:\r\n            pass\r\n        return None\r\n\r\n    def choose_candidate_from_embedding_ranking(\r\n        self,\r\n        question: str,\r\n        query: str,\r\n        encoded_outlines: np.ndarray,\r\n        outlines: List[str],\r\n        top_N_candidates: int = 5,\r\n    ):\r\n        sorted_candidates = self._get_sorted_embed_sim_section(\r\n            encoded_outlines, outlines, question, query\r\n        )\r\n        considered_candidates = sorted_candidates[\r\n            : min(len(sorted_candidates), top_N_candidates)\r\n        ]\r\n        choices_string = \"\\n\".join(\r\n            [\r\n                f\"{idx + 1}: {candidate}\"\r\n                for idx, candidate in enumerate(considered_candidates)\r\n            ]\r\n        )\r\n        with dspy.settings.context(lm=self.engine, show_guidelines=False):\r\n            decision = self.candidate_choosing(\r\n                intent=self._construct_intent(question=question, query=query),\r\n                choices=choices_string,\r\n            ).decision\r\n            decision = trim_output_after_hint(decision, hint=\"Decision:\")\r\n            if \"Best placement:\" in decision:\r\n               
 decision = trim_output_after_hint(decision, hint=\"Best placement:\")\r\n                selected_index = self._parse_selected_index(decision)\r\n                if selected_index is not None:\r\n                    selected_index = selected_index - 1\r\n                    if selected_index < len(sorted_candidates) and selected_index >= 0:\r\n                        return dspy.Prediction(\r\n                            information_placement=sorted_candidates[selected_index],\r\n                            note=f\"Choosing from:\\n{considered_candidates}\",\r\n                        )\r\n            return None\r\n\r\n    def _info_list_to_intent_mapping(self, information_list: List[Information]):\r\n        intent_to_placement_dict = {}\r\n        for info in information_list:\r\n            intent = (info.meta.get(\"question\", \"\"), info.meta.get(\"query\", \"\"))\r\n            if intent not in intent_to_placement_dict:\r\n                intent_to_placement_dict[intent] = None\r\n        return intent_to_placement_dict\r\n\r\n    def forward(\r\n        self,\r\n        knowledge_base: KnowledgeBase,\r\n        information: Union[Information, List[Information]],\r\n        allow_create_new_node: bool = False,\r\n        max_thread: int = 5,\r\n        insert_root: Optional[KnowledgeNode] = None,\r\n        skip_candidate_from_embedding: bool = False,\r\n    ):\r\n        if not isinstance(information, List):\r\n            information = [information]\r\n        intent_to_placement_dict: Dict = self._info_list_to_intent_mapping(\r\n            information_list=information\r\n        )\r\n\r\n        # process one intent\r\n        def process_intent(question: str, query: str):\r\n            candidate_placement = None\r\n            try:\r\n                if not skip_candidate_from_embedding:\r\n                    candidate_placement = self.choose_candidate_from_embedding_ranking(\r\n                        question=question,\r\n                        
query=query,\r\n                        encoded_outlines=encoded_outlines,\r\n                        outlines=outlines,\r\n                        top_N_candidates=8,\r\n                    )\r\n                if candidate_placement is None:\r\n                    candidate_placement = self.layer_by_layer_navigation_placement(\r\n                        knowledge_base=knowledge_base,\r\n                        question=question,\r\n                        query=query,\r\n                        allow_create_new_node=allow_create_new_node,\r\n                        root=insert_root,\r\n                    )\r\n                return (question, query), candidate_placement\r\n            except Exception as e:\r\n                print(traceback.format_exc())\r\n                return (question, query), None\r\n\r\n        def insert_info_to_kb(info, placement_prediction):\r\n            if placement_prediction is not None:\r\n                missing_node_handling = (\r\n                    \"raise error\" if not allow_create_new_node else \"create\"\r\n                )\r\n                knowledge_base.insert_information(\r\n                    path=placement_prediction.information_placement,\r\n                    information=info,\r\n                    missing_node_handling=missing_node_handling,\r\n                    root=insert_root,\r\n                )\r\n\r\n        (\r\n            encoded_outlines,\r\n            outlines,\r\n        ) = knowledge_base.get_knowledge_base_structure_embedding(root=insert_root)\r\n        to_return = []\r\n        if not allow_create_new_node:\r\n            # use multi thread as knowledge base structure does not change\r\n            with ThreadPoolExecutor(max_workers=max_thread) as executor:\r\n                futures = {\r\n                    executor.submit(process_intent, question, query): (question, query)\r\n                    for (question, query) in intent_to_placement_dict\r\n                }\r\n\r\n          
      for future in as_completed(futures):\r\n                    (question, query), candidate_placement = future.result()\r\n                    intent_to_placement_dict[(question, query)] = candidate_placement\r\n            # back mapping placement to each information\r\n            for info in information:\r\n                intent = (info.meta.get(\"question\", \"\"), info.meta.get(\"query\", \"\"))\r\n                placement_prediction = intent_to_placement_dict.get(intent, None)\r\n                insert_info_to_kb(info, placement_prediction)\r\n                to_return.append((info, placement_prediction))\r\n            return to_return\r\n        else:\r\n            # use sequential insert as knowledge base structure might change\r\n            for question, query in intent_to_placement_dict:\r\n                (\r\n                    encoded_outlines,\r\n                    outlines,\r\n                ) = knowledge_base.get_knowledge_base_structure_embedding(\r\n                    root=insert_root\r\n                )\r\n                _, placement_prediction = process_intent(question=question, query=query)\r\n                intent_to_placement_dict[(question, query)] = placement_prediction\r\n\r\n            for info in information:\r\n                intent = (info.meta.get(\"question\", \"\"), info.meta.get(\"query\", \"\"))\r\n                placement_prediction = intent_to_placement_dict.get(intent, None)\r\n                insert_info_to_kb(info, placement_prediction)\r\n                to_return.append((info, placement_prediction))\r\n            return to_return\r\n\r\n\r\nclass ExpandSection(dspy.Signature):\r\n    \"\"\"Your task is to expand a section in the mind map by creating new subsections under the given section.\r\n    You will be given a list of question and query that are used to collect information.\r\n    Output should be subsection names where each section should serve as a coherent and themantic organization of 
information and corresponding citation numbers. These subsection names are preferred to be concise and precise.\r\n    Output follows the format below:\r\n    subsection 1\r\n    subsection 2\r\n    subsection 3\r\n    \"\"\"\r\n\r\n    section = dspy.InputField(prefix=\"The section you need to expand: \", format=str)\r\n    info = dspy.InputField(prefix=\"The collected information:\\n\", format=str)\r\n    output = dspy.OutputField(\r\n        prefix=\"Now provide the expanded subsection names (If there's no need to expand current section as itself serves good organization, then output None):\\n\",\r\n        format=str,\r\n    )\r\n\r\n\r\nclass ExpandNodeModule(dspy.Module):\r\n    def __init__(\r\n        self,\r\n        engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],\r\n        information_insert_module: dspy.Module,\r\n        node_expansion_trigger_count: int,\r\n    ):\r\n        self.engine = engine\r\n        self.expand_section = dspy.Predict(ExpandSection)\r\n        self.information_insert_module = information_insert_module\r\n        self.node_expansion_trigger_count = node_expansion_trigger_count\r\n\r\n    def _get_cited_info_meta_string(self, node, knowledge_base):\r\n        meta_string = set()\r\n        for index in sorted(list(node.content)):\r\n            info = knowledge_base.info_uuid_to_info_dict[index]\r\n            intent = f\"Question: {info.meta['question']}\\nQuery: {info.meta['query']}\"\r\n            meta_string.add(intent)\r\n\r\n        return \"\\n\\n\".join(meta_string)\r\n\r\n    def _get_expand_subnode_names(self, node, knowledge_base):\r\n        information = self._get_cited_info_meta_string(node, knowledge_base)\r\n        node_path = node.get_path_from_root()\r\n        with dspy.settings.context(lm=self.engine, show_guidelines=False):\r\n            output = self.expand_section(section=node_path, info=information).output\r\n        subsections = []\r\n        if \"\\n\" in output and output != \"None\":\r\n            
subsections = output.split(\"\\n\")\r\n            # remove any integer followed by a dot and a space, a leading dashline,\r\n            # or a specific hint at the start of the string\r\n            subsections = [\r\n                re.sub(r\"^\\d+\\.\\s|-|\" + re.escape(node.name), \"\", text)\r\n                .replace(\"*\", \"\")\r\n                .strip()\r\n                for text in subsections\r\n            ]\r\n        return subsections\r\n\r\n    def _find_first_node_to_expand(\r\n        self, root: KnowledgeNode, expanded_nodes: List[KnowledgeNode]\r\n    ):\r\n        if root is None:\r\n            return None\r\n        if (\r\n            root not in expanded_nodes\r\n            and len(root.content) >= self.node_expansion_trigger_count\r\n        ):\r\n            return root\r\n        for child in root.children:\r\n            to_return = self._find_first_node_to_expand(\r\n                root=child, expanded_nodes=expanded_nodes\r\n            )\r\n            if to_return is not None:\r\n                return to_return\r\n        return None\r\n\r\n    def _expand_node(self, node: KnowledgeNode, knowledge_base: KnowledgeBase):\r\n        subsection_names = self._get_expand_subnode_names(node, knowledge_base)\r\n        if len(subsection_names) <= 1:\r\n            return\r\n        # create new nodes\r\n        for subsection_name in subsection_names:\r\n            # remove citation bracket in the subsection name\r\n            subsection_name = re.sub(r\"\\[.*?\\]\", \"\", subsection_name)\r\n            knowledge_base.insert_node(new_node_name=subsection_name, parent_node=node)\r\n        # reset original information placement\r\n        original_cited_index = node.content\r\n        original_cited_information = [\r\n            knowledge_base.info_uuid_to_info_dict[index]\r\n            for index in original_cited_index\r\n        ]\r\n        node.content = set()\r\n        # re-insert under expanded section\r\n        
self.information_insert_module(\r\n            knowledge_base=knowledge_base,\r\n            information=original_cited_information,\r\n            allow_create_new_node=False,\r\n            insert_root=node,\r\n        )\r\n\r\n    def forward(self, knowledge_base: KnowledgeBase):\r\n        expanded_nodes = []\r\n        while True:\r\n            node_to_expand = self._find_first_node_to_expand(\r\n                root=knowledge_base.root, expanded_nodes=expanded_nodes\r\n            )\r\n            if node_to_expand is None:\r\n                break\r\n            self._expand_node(node=node_to_expand, knowledge_base=knowledge_base)\r\n            expanded_nodes.append(node_to_expand)\r\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/knowledge_base_summary.py",
    "content": "import dspy\nfrom typing import Union\nfrom ...dataclass import KnowledgeBase\n\n\nclass KnowledgeBaseSummmary(dspy.Signature):\n    \"\"\"Your job is to give brief summary of what's been discussed in a roundtable conversation. Contents are themantically organized into hierarchical sections.\n    You will be presented with these sections where \"#\" denotes level of section.\n    \"\"\"\n\n    topic = dspy.InputField(prefix=\"topic: \", format=str)\n    structure = dspy.InputField(prefix=\"Tree structure: \\n\", format=str)\n    output = dspy.OutputField(prefix=\"Now give brief summary:\\n\", format=str)\n\n\nclass KnowledgeBaseSummaryModule(dspy.Module):\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.engine = engine\n        self.gen_summary = dspy.Predict(KnowledgeBaseSummmary)\n\n    def forward(self, knowledge_base: KnowledgeBase):\n        structure = knowledge_base.get_node_hierarchy_string(\n            include_indent=False,\n            include_full_path=False,\n            include_hash_tag=True,\n            include_node_content_count=False,\n        )\n        with dspy.settings.context(lm=self.engine, show_guidelines=False):\n            summary = self.gen_summary(\n                topic=knowledge_base.topic, structure=structure\n            ).output\n        return summary\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/simulate_user.py",
    "content": "import dspy\nfrom typing import List, Union\n\nfrom .collaborative_storm_utils import extract_and_remove_citations\nfrom ...dataclass import ConversationTurn\nfrom ...storm_wiki.modules.knowledge_curation import AskQuestionWithPersona\n\n\nclass GenSimulatedUserUtterance(dspy.Module):\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.engine = engine\n        self.ask_qeustion = dspy.Predict(AskQuestionWithPersona)\n\n    def gen_conv_history_string(self, conversation_turns: List[ConversationTurn]):\n        conv_history = []\n        total_turns = len(conversation_turns)\n\n        for i, turn in enumerate(conversation_turns):\n            utterance, _ = extract_and_remove_citations(turn.utterance)\n            if i >= total_turns - 4:\n                conv_history.append(f\"{turn.role}: {utterance}\")\n            else:\n                if turn.claim_to_make:\n                    conv_history.append(f\"{turn.role}: {turn.claim_to_make}\")\n                else:\n                    conv_history.append(f\"{turn.role}: {utterance}\")\n\n        return \"\\n\".join(conv_history)\n\n    def forward(self, topic: str, intent: str, conv_history: List[ConversationTurn]):\n        conv_history_string = self.gen_conv_history_string(conv_history)\n        with dspy.settings.context(lm=self.engine, show_guidelines=False):\n            return self.ask_qeustion(\n                topic=topic,\n                persona=f\"researcher with interest in {intent}\",\n                conv=conv_history_string,\n            ).question\n"
  },
  {
    "path": "knowledge_storm/collaborative_storm/modules/warmstart_hierarchical_chat.py",
    "content": "\"\"\"\r\nWarm starts the Co-STORM system by conducting a background information search to establish a shared conceptual space with the user.\r\n \r\nThis stage functions as a mini-STORM, where multiple LLM agents are spawned with different perspectives to engage in multi-round conversations. \r\nThe knowledge base (represented as a mind map) is initialized using the information gathered during these exchanges.\r\n\r\nAdditionally, the system generates a first draft of the report, which is then used to create a concise and engaging conversation. \r\nThe synthesized conversation is presented to the user to help them quickly catch up on the system's current knowledge about the topic.\r\n\"\"\"\r\n\r\nimport dspy\r\nimport concurrent.futures\r\nfrom threading import Lock\r\nfrom typing import List, Optional, Union, TYPE_CHECKING\r\n\r\nfrom .callback import BaseCallbackHandler\r\nfrom .collaborative_storm_utils import _get_answer_question_module_instance\r\nfrom .expert_generation import GenerateExpertModule\r\nfrom .grounded_question_answering import AnswerQuestionModule\r\nfrom ...dataclass import ConversationTurn, KnowledgeBase\r\nfrom ...interface import LMConfigs\r\nfrom ...logging_wrapper import LoggingWrapper\r\nfrom ...storm_wiki.modules.outline_generation import WritePageOutline\r\nfrom ...utils import ArticleTextProcessing as AP\r\n\r\n\r\nif TYPE_CHECKING:\r\n    from ..engine import RunnerArgument\r\n\r\n\r\nclass WarmStartModerator(dspy.Signature):\r\n    \"\"\"\r\n    You are a moderator in a roundtable discussion. 
The goal is to chat with multiple experts to discuss the facts and background of the topic to familiarize the audience with the topic.\r\n    You will be presented with the topic, the history of question you have already asked, and the current expert you are discussing with.\r\n    Based on these information, generate the next question for the current expert to further the discussion.\r\n\r\n    The output should only include the next question for the current expert. Do not include any other information or preamble.\r\n    \"\"\"\r\n\r\n    topic = dspy.InputField(prefix=\"Topic for roundtable discussion: \", format=str)\r\n    history = dspy.InputField(\r\n        prefix=\"Experts you have already interacted with: \", format=str\r\n    )\r\n    current_expert = dspy.InputField(prefix=\"Expert you are talking with:\", format=str)\r\n    question = dspy.OutputField(\r\n        prefix=\"Next question for the expert you are talking with: \", format=str\r\n    )\r\n\r\n\r\nclass SectionToConvTranscript(dspy.Signature):\r\n    \"\"\"\r\n    You are given a section of a brief report on a specific topic. Your task is to transform this section into an engaging opening discussion for a roundtable conversation.\r\n    The goal is to help participants and the audience quickly understand the key information.\r\n    Both question and answer should be in the tone of roundtable discussion talking to audiences.\r\n\r\n    Specifically, you need to:\r\n    1. Generate an engaging question that leverages section name and topic that opens discussion of the content.\r\n    2. 
Provide a brief and engaging answer (with all inline citations from original text) derived from the section serving as pointers and avoid too much details.\r\n    \"\"\"\r\n\r\n    topic = dspy.InputField(prefix=\"topic:\", format=str)\r\n    section_name = dspy.InputField(prefix=\"section name:\", format=str)\r\n    section_content = dspy.InputField(prefix=\"section content:\", format=str)\r\n    question = dspy.OutputField(prefix=\"Now give engaging question only.\\nQuestion:\")\r\n    answer = dspy.OutputField(\r\n        prefix=\"Now give engaging answer only with all inline citations from original text.\\nAnswer:\"\r\n    )\r\n\r\n\r\nclass ReportToConversation(dspy.Module):\r\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\r\n        self.engine = engine\r\n        self.section_to_conv_transcript = dspy.Predict(SectionToConvTranscript)\r\n\r\n    def forward(self, knowledge_base: KnowledgeBase):\r\n        def process_node(node, topic):\r\n            with dspy.settings.context(lm=self.engine, show_guidelines=False):\r\n                output = self.section_to_conv_transcript(\r\n                    topic=topic,\r\n                    section_name=node.get_path_from_root(),\r\n                    section_content=node.synthesize_output,\r\n                )\r\n                question = output.question.replace(\"Question:\", \"\").strip()\r\n                answer = output.answer.replace(\"Answer:\", \"\").strip()\r\n                return question, answer\r\n\r\n        conversations = []\r\n        nodes = knowledge_base.collect_all_nodes()\r\n        nodes = [node for node in nodes if node.name != \"root\" and node.content]\r\n        topic = knowledge_base.topic\r\n\r\n        with concurrent.futures.ThreadPoolExecutor() as executor:\r\n            future_to_node = {\r\n                executor.submit(process_node, node, topic): node for node in nodes\r\n            }\r\n            for future in 
concurrent.futures.as_completed(future_to_node):\r\n                node = future_to_node[future]\r\n                question, answer = future.result()\r\n                conversations.append(\r\n                    ConversationTurn(\r\n                        role=\"Background discussion moderator\",\r\n                        raw_utterance=question,\r\n                        utterance_type=\"Original Question\",\r\n                        utterance=question,\r\n                        cited_info=[\r\n                            knowledge_base.info_uuid_to_info_dict[idx]\r\n                            for idx in AP.parse_citation_indices(question)\r\n                        ],\r\n                    )\r\n                )\r\n                conversations.append(\r\n                    ConversationTurn(\r\n                        role=\"Background discussion expert\",\r\n                        raw_utterance=answer,\r\n                        utterance_type=\"Potential Answer\",\r\n                        utterance=answer,\r\n                        cited_info=[\r\n                            knowledge_base.info_uuid_to_info_dict[idx]\r\n                            for idx in AP.parse_citation_indices(answer)\r\n                        ],\r\n                    )\r\n                )\r\n        return conversations\r\n\r\n\r\nclass WarmStartConversation(dspy.Module):\r\n    def __init__(\r\n        self,\r\n        question_asking_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\r\n        generate_expert_module: GenerateExpertModule,\r\n        answer_question_module: AnswerQuestionModule,\r\n        logging_wrapper: LoggingWrapper,\r\n        max_num_experts: int = 3,\r\n        max_turn_per_experts: int = 2,\r\n        max_thread: int = 3,\r\n        callback_handler: BaseCallbackHandler = None,\r\n    ):\r\n        self.ask_question = dspy.Predict(WarmStartModerator)\r\n        self.max_num_experts = max_num_experts\r\n        self.max_turn_per_experts = 
max_turn_per_experts\r\n        self.question_asking_lm = question_asking_lm\r\n        self.answer_question_module = answer_question_module\r\n        self.max_thread = max_thread\r\n        self.generate_experts_module = generate_expert_module\r\n        self.logging_wrapper = logging_wrapper\r\n        self.callback_handler = callback_handler\r\n\r\n    def format_dialogue_question_history_string(\r\n        self, conversation_history: List[ConversationTurn]\r\n    ):\r\n        output = []\r\n        for idx, turn in enumerate(conversation_history):\r\n            info = turn.claim_to_make if turn.claim_to_make else turn.utterance\r\n            output.append(f\"{idx + 1}: {info}\")\r\n        return \"\\n\".join(output)\r\n\r\n    def generate_warmstart_experts(self, topic: str):\r\n        background_seeking_dialogue = self.get_background_info(topic=topic)\r\n        background_info = background_seeking_dialogue.utterance\r\n        gen_expert_output = self.generate_experts_module(\r\n            topic=topic,\r\n            background_info=background_info,\r\n            num_experts=self.max_num_experts,\r\n        )\r\n        return gen_expert_output.experts, background_seeking_dialogue\r\n\r\n    def get_background_info(self, topic: str):\r\n        question = f\"Background information about {topic}\"\r\n        answer = self.answer_question_module(\r\n            topic=topic, question=question, mode=\"extensive\", style=\"conversational\"\r\n        )\r\n\r\n        return ConversationTurn(\r\n            role=\"Default Background Researcher\",\r\n            raw_utterance=answer.response,\r\n            utterance_type=\"Questioning\",\r\n            claim_to_make=question,\r\n            queries=answer.queries,\r\n            raw_retrieved_info=answer.raw_retrieved_info,\r\n            cited_info=answer.cited_info,\r\n        )\r\n\r\n    def forward(self, topic: str):\r\n        with self.logging_wrapper.log_event(\r\n            \"warm start, 
perspective guided QA: identify experts\"\r\n        ):\r\n            # do background research, generate some experts\r\n            experts, background_seeking_dialogue = self.generate_warmstart_experts(\r\n                topic=topic\r\n            )\r\n        # init list to store the dialogue history\r\n        conversation_history: List[ConversationTurn] = []\r\n        lock = Lock()\r\n\r\n        # hierarchical chat: chat with one expert. Generate question, get answer\r\n        def process_expert(expert):\r\n            expert_name, expert_descriptoin = expert.split(\":\")\r\n            for idx in range(self.max_turn_per_experts):\r\n                with self.logging_wrapper.log_event(\r\n                    f\"warm start, perspective guided QA: expert {expert_name}; turn {idx + 1}\"\r\n                ):\r\n                    try:\r\n                        with lock:\r\n                            history = self.format_dialogue_question_history_string(\r\n                                conversation_history\r\n                            )\r\n                        with dspy.settings.context(lm=self.question_asking_lm):\r\n                            question = self.ask_question(\r\n                                topic=topic, history=history, current_expert=expert\r\n                            ).question\r\n                        answer = self.answer_question_module(\r\n                            topic=topic,\r\n                            question=question,\r\n                            mode=\"brief\",\r\n                            style=\"conversational\",\r\n                        )\r\n                        conversation_turn = ConversationTurn(\r\n                            role=expert,\r\n                            claim_to_make=question,\r\n                            raw_utterance=answer.response,\r\n                            utterance_type=\"Support\",\r\n                            queries=answer.queries,\r\n                       
     raw_retrieved_info=answer.raw_retrieved_info,\r\n                            cited_info=answer.cited_info,\r\n                        )\r\n                        if self.callback_handler is not None:\r\n                            self.callback_handler.on_warmstart_update(\r\n                                message=\"\\n\".join(\r\n                                    [\r\n                                        f\"Finish browsing {url}\"\r\n                                        for url in [\r\n                                            i.url for i in answer.raw_retrieved_info\r\n                                        ]\r\n                                    ]\r\n                                )\r\n                            )\r\n                        with lock:\r\n                            conversation_history.append(conversation_turn)\r\n                    except Exception as e:\r\n                        print(f\"Error processing expert {expert}: {e}\")\r\n\r\n        # multi-thread conversation\r\n        with concurrent.futures.ThreadPoolExecutor(\r\n            max_workers=self.max_thread\r\n        ) as executor:\r\n            futures = [\r\n                executor.submit(process_expert, expert)\r\n                for expert in experts[: min(len(experts), self.max_num_experts)]\r\n            ]\r\n            concurrent.futures.wait(futures)\r\n\r\n        conversation_history = [background_seeking_dialogue] + conversation_history\r\n\r\n        return dspy.Prediction(\r\n            conversation_history=conversation_history, experts=experts\r\n        )\r\n\r\n\r\nclass GenerateWarmStartOutline(dspy.Signature):\r\n    \"\"\"Generate a outline of the wikipedia-like report from a roundtable discussion. You will be presented discussion points in the conversation and corresponding queries.\r\n    You will be given a draft outline which you can borrow some inspiration. 
Do not include sections that are not mentioned in the given discussion history.\r\n    Use \"#\" to denote section headings, \"##\" to denote subsection headings, and so on.\r\n     Follow these guidelines:\r\n     1. Use \"#\" for section titles, \"##\" for subsection titles, \"###\" for subsubsection titles, and so on.\r\n     2. Do not include any additional information.\r\n     3. Exclude the topic name from the outline.\r\n     The organization of the outline should adopt Wikipedia style.\r\n    \"\"\"\r\n\r\n    topic = dspy.InputField(prefix=\"The topic discussed: \", format=str)\r\n    draft = dspy.InputField(prefix=\"Draft outline you can refer to: \", format=str)\r\n    conv = dspy.InputField(prefix=\"Discussion history:\\n\", format=str)\r\n    outline = dspy.OutputField(\r\n        prefix='Write the conversation outline (Use \"# Title\" to indicate section title, \"## Title\" to indicate subsection title, ...):\\n',\r\n        format=str,\r\n    )\r\n\r\n\r\nclass GenerateWarmStartOutlineModule(dspy.Module):\r\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\r\n        self.engine = engine\r\n        self.gen_outline = dspy.Predict(GenerateWarmStartOutline)\r\n        self.draft_outline = dspy.Predict(WritePageOutline)\r\n\r\n    def extract_questions_and_queries(self, conv: List[ConversationTurn]):\r\n        context = []\r\n        for turn in conv:\r\n            focus = turn.claim_to_make\r\n            queries = turn.queries\r\n            queries_string = \"\\n\\t\".join(\r\n                f\"Query {idx + 1}: {query}\" for idx, query in enumerate(queries)\r\n            )\r\n            string = f\"Discussion focus {len(context) + 1}: {focus}\\n\\t{queries_string}\"\r\n            context.append(string)\r\n        return \"\\n\".join(context)\r\n\r\n    def get_draft_outline(self, topic: str):\r\n        with dspy.settings.context(lm=self.engine):\r\n            return self.draft_outline(topic=topic).outline\r\n\r\n    
def forward(self, topic: str, conv: List[ConversationTurn]):\r\n        discussion_history = self.extract_questions_and_queries(conv)\r\n        draft_outline = self.get_draft_outline(topic=topic)\r\n        with dspy.settings.context(lm=self.engine):\r\n            outline = self.gen_outline(\r\n                topic=topic, draft=draft_outline, conv=discussion_history\r\n            ).outline\r\n            outline = AP.clean_up_outline(outline)\r\n        return dspy.Prediction(outline=outline, draft_outline=draft_outline)\r\n\r\n\r\nclass WarmStartModule:\r\n    def __init__(\r\n        self,\r\n        lm_config: LMConfigs,\r\n        runner_argument: \"RunnerArgument\",\r\n        logging_wrapper: LoggingWrapper,\r\n        rm: Optional[dspy.Retrieve] = None,\r\n        callback_handler: BaseCallbackHandler = None,\r\n    ):\r\n        generate_expert_module = GenerateExpertModule(\r\n            engine=lm_config.discourse_manage_lm\r\n        )\r\n        self.warmstart_conv = WarmStartConversation(\r\n            question_asking_lm=lm_config.question_asking_lm,\r\n            generate_expert_module=generate_expert_module,\r\n            answer_question_module=_get_answer_question_module_instance(\r\n                lm_config=lm_config,\r\n                runner_argument=runner_argument,\r\n                logging_wrapper=logging_wrapper,\r\n                rm=rm,\r\n            ),\r\n            max_num_experts=runner_argument.warmstart_max_num_experts,\r\n            max_turn_per_experts=runner_argument.warmstart_max_turn_per_experts,\r\n            max_thread=runner_argument.warmstart_max_thread,\r\n            logging_wrapper=logging_wrapper,\r\n            callback_handler=callback_handler,\r\n        )\r\n        self.warmstart_outline_gen_module = GenerateWarmStartOutlineModule(\r\n            engine=lm_config.warmstart_outline_gen_lm\r\n        )\r\n        self.report_to_conversation = ReportToConversation(lm_config.knowledge_base_lm)\r\n        
self.logging_wrapper = logging_wrapper\r\n        self.callback_handler = callback_handler\r\n\r\n    def initiate_warm_start(self, topic: str, knowledge_base: KnowledgeBase):\r\n        \"\"\"\r\n        Initiates a warm start process for the given topic by generating a warm start conversation and inserting the\r\n        resulting information into a knowledge base.\r\n\r\n        Args:\r\n            topic (str): The topic for which to initiate the warm start process.\r\n            knowledge_base (KnowledgeBase): The knowledge base to populate. It is updated in place.\r\n\r\n        Returns:\r\n            Tuple:\r\n                - A list of ConversationTurn instances representing the warm start conversation history.\r\n                - The engaging conversations generated from the knowledge base report.\r\n                - A list of strings representing the experts involved in the conversation.\r\n        \"\"\"\r\n        warm_start_conversation_history: List[ConversationTurn] = []\r\n        warm_start_experts = None\r\n        # get warm start conversations\r\n        with self.logging_wrapper.log_event(\"warm start: perspective guided QA\"):\r\n            if self.callback_handler is not None:\r\n                self.callback_handler.on_warmstart_update(\r\n                    message=\"Start getting familiar with the topic by chatting with multiple LLM experts (Step 1 / 4)\"\r\n                )\r\n            warm_start_result = self.warmstart_conv(topic=topic)\r\n            warm_start_conversation_history = warm_start_result.conversation_history\r\n            warm_start_experts = warm_start_result.experts\r\n\r\n        # get warm start conv outline\r\n        with self.logging_wrapper.log_event(\"warm start: outline generation\"):\r\n            if self.callback_handler is not None:\r\n                self.callback_handler.on_warmstart_update(\r\n                    \"Organizing collected information (Step 2 / 4)\"\r\n                )\r\n            warm_start_outline_output = self.warmstart_outline_gen_module(\r\n 
               topic=topic, conv=warm_start_conversation_history\r\n            )\r\n        # init knowledge base\r\n        with self.logging_wrapper.log_event(\"warm start: insert into knowledge base\"):\r\n            if self.callback_handler is not None:\r\n                self.callback_handler.on_warmstart_update(\r\n                    \"Inserting collected information into knowledge base (Step 3 / 4)\"\r\n                )\r\n            knowledge_base.insert_from_outline_string(\r\n                outline_string=warm_start_outline_output.outline\r\n            )\r\n            # insert information to knowledge base\r\n            for turn in warm_start_conversation_history:\r\n                knowledge_base.update_from_conv_turn(\r\n                    conv_turn=turn, allow_create_new_node=False\r\n                )\r\n        # knowledge base to report\r\n        if self.callback_handler is not None:\r\n            self.callback_handler.on_warmstart_update(\r\n                \"Synthesizing background information discussion utterances (Step 4 / 4)\"\r\n            )\r\n        knowledge_base.to_report()\r\n\r\n        # generate engaging conversations\r\n        engaging_conversations = self.report_to_conversation(knowledge_base)\r\n        return (\r\n            warm_start_conversation_history,\r\n            engaging_conversations,\r\n            warm_start_experts,\r\n        )\r\n"
  },
  {
    "path": "knowledge_storm/dataclass.py",
    "content": "import dspy\nimport numpy as np\nimport re\nimport threading\nfrom typing import Set, Dict, List, Optional, Union, Tuple\n\nfrom .encoder import Encoder\nfrom .interface import Information\n\n\nclass ConversationTurn:\n    \"\"\"\n    A class to represent a turn in a conversation.\n\n    Attributes:\n        role (str): A short phrase of the role of the speaker for the current conversation turn.\n        raw_utterance (str): The response generated by the LM model without polished style and tone.\n        utterance_type (str): The type of utterance (e.g., statement, question).\n        claim_to_make (Optional[str]): The point that this utterance tries to make. Should be empty if the utterance type is questioning.\n        utterance (Optional[str]): The response generated by the model with polished style and tone. Defaults to raw_utterance if not provided.\n        queries (List[str]): The queries used to gather information to have a grounded answer.\n        raw_retrieved_info (List['Information']): A list of Information type that is retrieved.\n        cited_info (Dict[int, 'Information']): A dictionary where the key is the citation index and the value is Information type.\n        role_description (Optional[str]): A few sentences description of the role. 
Defaults to an empty string if not provided.\n    \"\"\"\n\n    def __init__(\n        self,\n        role: str,\n        raw_utterance: str,\n        utterance_type: str,\n        claim_to_make: Optional[str] = None,\n        utterance: Optional[str] = None,\n        queries: Optional[List[str]] = None,\n        raw_retrieved_info: Optional[List[Information]] = None,\n        cited_info: Optional[Dict[int, Information]] = None,\n    ):\n        self.utterance = utterance if utterance is not None else raw_utterance\n        self.raw_utterance = raw_utterance\n        self.role = role if \":\" not in role else role.split(\":\")[0]\n        # split on the first colon only so that colons inside the description are preserved\n        self.role_description = \"\" if \":\" not in role else role.split(\":\", 1)[1]\n        self.queries = queries if queries is not None else []\n        self.raw_retrieved_info = (\n            raw_retrieved_info if raw_retrieved_info is not None else []\n        )\n        self.cited_info = cited_info if cited_info is not None else {}\n        self.utterance_type = utterance_type\n        self.claim_to_make = claim_to_make if claim_to_make is not None else \"\"\n\n    def get_all_citation_index(self):\n        citation_pattern = re.compile(r\"\\[(\\d+)\\]\")\n        return list(map(int, citation_pattern.findall(self.utterance)))\n\n    def to_dict(self):\n        raw_retrieved_info = [info.to_dict() for info in self.raw_retrieved_info]\n        return {\n            \"utterance\": self.utterance,\n            \"raw_utterance\": self.raw_utterance,\n            \"role\": self.role,\n            \"role_description\": self.role_description,\n            \"queries\": self.queries,\n            \"utterance_type\": self.utterance_type,\n            \"claim_to_make\": self.claim_to_make,\n            \"raw_retrieved_info\": raw_retrieved_info,\n            \"cited_info\": None,\n        }\n\n    @classmethod\n    def from_dict(cls, conv_turn_dict: Dict):\n        raw_retrieved_info = [\n            Information.from_dict(info) for info in 
conv_turn_dict[\"raw_retrieved_info\"]\n        ]\n\n        return cls(\n            utterance=conv_turn_dict[\"utterance\"],\n            raw_utterance=conv_turn_dict[\"raw_utterance\"],\n            role=f\"{conv_turn_dict['role']}: {conv_turn_dict['role_description']}\",\n            queries=conv_turn_dict[\"queries\"],\n            raw_retrieved_info=raw_retrieved_info,\n            cited_info=None,\n            utterance_type=conv_turn_dict[\"utterance_type\"],\n            claim_to_make=conv_turn_dict[\"claim_to_make\"],\n        )\n\n\nclass KnowledgeNode:\n    \"\"\"\n    Class representing a node in the knowledge base.\n\n    Attributes:\n        name (str): The name of the node.\n        content (list): A list of Information instances.\n        children (list): A list of child KnowledgeNode instances.\n        parent (KnowledgeNode): The parent node of the current node.\n    \"\"\"\n\n    def __init__(\n        self,\n        name: str,\n        content: Optional[str] = None,\n        parent: Optional[\"KnowledgeNode\"] = None,\n        children: Optional[List[\"KnowledgeNode\"]] = None,\n        synthesize_output: Optional[str] = None,\n        need_regenerate_synthesize_output: bool = True,\n    ):\n        \"\"\"\n        Initializes a KnowledgeNode instance.\n\n        Args:\n            name (str): The name of the node.\n            content (list, optional): A list of information uuid. Defaults to None.\n            parent (KnowledgeNode, optional): The parent node of the current node. 
Defaults to None.\n        \"\"\"\n        self.name = name\n        self.content: Set[int] = set(content) if content is not None else set()\n        self.children = [] if children is None else children\n        self.parent = parent\n        self.synthesize_output = synthesize_output\n        self.need_regenerate_synthesize_output = need_regenerate_synthesize_output\n\n    def collect_all_content(self):\n        \"\"\"\n        Collects all content from the current node and its descendants.\n\n        Returns:\n            Set[int]: A set containing all content from the current node and its descendants.\n        \"\"\"\n        all_content = set(self.content)\n        for child in self.children:\n            all_content.update(child.collect_all_content())\n        return all_content\n\n    def has_child(self, child_node_name: str):\n        \"\"\"\n        Check if the node has the child of given name.\n        \"\"\"\n        return child_node_name in [child.name for child in self.children]\n\n    def add_child(self, child_node_name: str, duplicate_handling: str = \"skip\"):\n        \"\"\"\n        Adds a child node to the current node.\n        duplicate_handling (str): How to handle duplicate nodes. Options are \"skip\", \"none\", and \"raise error\".\n        \"\"\"\n        if self.has_child(child_node_name):\n            if duplicate_handling == \"skip\":\n                for child in self.children:\n                    if child.name == child_node_name:\n                        return child\n            elif duplicate_handling == \"raise error\":\n                raise Exception(\n                    f\"Insert node error. 
Node {child_node_name} already exists under its parent node {self.name}.\"\n                )\n        child_node = KnowledgeNode(name=child_node_name, parent=self)\n        self.children.append(child_node)\n        return child_node\n\n    def get_parent(self):\n        \"\"\"\n        Returns the parent node of the current node.\n\n        Returns:\n            KnowledgeNode: The parent node of the current node.\n        \"\"\"\n        return self.parent\n\n    def get_children(self):\n        \"\"\"\n        Returns the children of the current node.\n\n        Returns:\n            list: A list of child KnowledgeNode instances.\n        \"\"\"\n        return self.children\n\n    def get_children_names(self):\n        \"\"\"\n        Returns a list of children names.\n        \"\"\"\n        return [child.name for child in self.children]\n\n    def __repr__(self):\n        \"\"\"\n        Returns a string representation of the KnowledgeNode instance.\n\n        Returns:\n            str: String representation of the KnowledgeNode instance.\n        \"\"\"\n        return f\"KnowledgeNode(name={self.name}, content={self.content}, children={len(self.children)})\"\n\n    def get_path_from_root(self, root: Optional[\"KnowledgeNode\"] = None):\n        \"\"\"\n        Get a list of names from the root to this node.\n\n        Returns:\n            List[str]: A list of node names from the root to this node.\n        \"\"\"\n        path = []\n        current_node = self\n        while current_node:\n            path.append(current_node.name)\n            if root is not None and current_node.name == root.name:\n                break\n            current_node = current_node.parent\n        return path[::-1]\n\n    def insert_information(self, information_index: int):\n        if information_index not in self.content:\n            self.need_regenerate_synthesize_output = True\n            self.content.add(information_index)\n\n    def get_all_descendents(self) -> 
List[\"KnowledgeNode\"]:\n        \"\"\"\n        Get a list of all descendant nodes.\n\n        Returns:\n            List[KnowledgeNode]: A list of all descendant nodes.\n        \"\"\"\n        descendents = []\n\n        def collect_descendents(node):\n            for child in node.children:\n                descendents.append(child)\n                collect_descendents(child)\n\n        collect_descendents(self)\n        return descendents\n\n    def get_all_predecessors(self) -> List[\"KnowledgeNode\"]:\n        \"\"\"\n        Get a list of all predecessor nodes (from current node to root).\n\n        Returns:\n            List[KnowledgeNode]: A list of all predecessor nodes.\n        \"\"\"\n        predecessors = []\n        current_node = self.parent\n        while current_node is not None:\n            predecessors.append(current_node)\n            current_node = current_node.parent\n        return predecessors\n\n    def to_dict(self):\n        \"\"\"\n        Converts the KnowledgeNode instance to a dictionary representation.\n\n        Returns:\n            dict: The dictionary representation of the KnowledgeNode.\n        \"\"\"\n        return {\n            \"name\": self.name,\n            \"content\": list(self.content),\n            \"children\": [child.to_dict() for child in self.children],\n            \"parent\": self.parent.name if self.parent else None,\n            \"synthesize_output\": self.synthesize_output,\n            \"need_regenerate_synthesize_output\": self.need_regenerate_synthesize_output,\n        }\n\n    @classmethod\n    def from_dict(cls, data):\n        \"\"\"\n        Constructs a KnowledgeNode instance from a dictionary representation.\n\n        Args:\n            data (dict): The dictionary representation of the KnowledgeNode.\n\n        Returns:\n            KnowledgeNode: The constructed KnowledgeNode instance.\n        \"\"\"\n\n        def helper(cls, data, parent_node=None):\n            if parent_node is not 
None:\n                assert data[\"parent\"] is not None and data[\"parent\"] == parent_node.name\n            node = cls(\n                name=data[\"name\"],\n                content=data[\"content\"],\n                parent=parent_node,\n                children=None,\n                synthesize_output=data.get(\"synthesize_output\", None),\n                need_regenerate_synthesize_output=data.get(\n                    \"need_regenerate_synthesize_output\", True\n                ),\n            )\n            for child_data in data[\"children\"]:\n                child_node = helper(cls, child_data, parent_node=node)\n                node.children.append(child_node)\n            return node\n\n        return helper(cls, data)\n\n\nclass KnowledgeBase:\n    \"\"\"\n    Represents the dynamic, hierarchical mind map used in Co-STORM to track and organize discourse.\n\n    The knowledge base serves as a shared conceptual space between the user and the system, allowing for effective collaboration by reducing the user's cognitive load and ensuring that the discourse is easy to follow.\n\n    The knowledge base is structured as a tree (or mind map) that dynamically organizes collected information and concepts as the conversation progresses.\n\n    The mind map consists of concepts (nodes) and edges that represent parent-child relationships among topics. 
Each concept is linked to retrieved information,\n    which is placed under the most appropriate concept based on its associated question and semantic similarity.\n\n    For more details, please refer to Section 3.2 of Co-STORM paper: https://www.arxiv.org/pdf/2408.15232\n    Attributes:\n        root (KnowledgeNode): The root node of the hierarchical knowledge base, representing the top-level concept.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        topic: str,\n        knowledge_base_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        node_expansion_trigger_count: int,\n        encoder: Encoder,\n    ):\n        \"\"\"\n        Initializes a KnowledgeBase instance.\n\n        Args:\n            topic (str): The topic of the knowledge base\n            expand_node_module (dspy.Module): The module that organize knowledge base in place.\n                The module should accept knowledge base as param. E.g. expand_node_module(self)\n            article_generation_module (dspy.Module): The module that generate report from knowledge base.\n                The module should return string. E.g. 
report = article_generation_module(self)\n        \"\"\"\n        from .collaborative_storm.modules.article_generation import (\n            ArticleGenerationModule,\n        )\n        from .collaborative_storm.modules.information_insertion_module import (\n            InsertInformationModule,\n            ExpandNodeModule,\n        )\n        from .collaborative_storm.modules.knowledge_base_summary import (\n            KnowledgeBaseSummaryModule,\n        )\n\n        self.topic: str = topic\n        self.encoder: Encoder = encoder\n\n        self.information_insert_module = InsertInformationModule(\n            engine=knowledge_base_lm, encoder=self.encoder\n        )\n        self.expand_node_module = ExpandNodeModule(\n            engine=knowledge_base_lm,\n            information_insert_module=self.information_insert_module,\n            node_expansion_trigger_count=node_expansion_trigger_count,\n        )\n        self.article_generation_module = ArticleGenerationModule(\n            engine=knowledge_base_lm\n        )\n        self.gen_summary_module = KnowledgeBaseSummaryModule(engine=knowledge_base_lm)\n\n        self.root: KnowledgeNode = KnowledgeNode(name=\"root\")\n        self.kb_embedding = {\n            \"hash\": hash(\"\"),\n            \"encoded_structure\": np.array([[]]),\n            \"structure_string\": \"\",\n        }\n        self.info_uuid_to_info_dict: Dict[int, Information] = {}\n        self.info_hash_to_uuid_dict: Dict[int, int] = {}\n        self._lock = threading.Lock()\n\n    def to_dict(self):\n        info_uuid_to_info_dict = {\n            key: value.to_dict() for key, value in self.info_uuid_to_info_dict.items()\n        }\n        return {\n            \"topic\": self.topic,\n            \"tree\": self.root.to_dict(),\n            \"info_uuid_to_info_dict\": info_uuid_to_info_dict,\n            \"info_hash_to_uuid_dict\": self.info_hash_to_uuid_dict,\n        }\n\n    @classmethod\n    def from_dict(\n        cls,\n        
data: Dict,\n        knowledge_base_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        node_expansion_trigger_count: int,\n        encoder: Encoder,\n    ):\n        knowledge_base = cls(\n            topic=data[\"topic\"],\n            knowledge_base_lm=knowledge_base_lm,\n            node_expansion_trigger_count=node_expansion_trigger_count,\n            encoder=encoder,\n        )\n        knowledge_base.root = KnowledgeNode.from_dict(data[\"tree\"])\n        knowledge_base.info_hash_to_uuid_dict = {\n            int(key): int(value)\n            for key, value in data[\"info_hash_to_uuid_dict\"].items()\n        }\n        info_uuid_to_info_dict = {\n            int(key): Information.from_dict(value)\n            for key, value in data[\"info_uuid_to_info_dict\"].items()\n        }\n        knowledge_base.info_uuid_to_info_dict = info_uuid_to_info_dict\n        return knowledge_base\n\n    def get_knowledge_base_structure_embedding(\n        self, root: Optional[KnowledgeNode] = None\n    ) -> Tuple[np.ndarray, List[str]]:\n        outline_string = self.get_node_hierarchy_string(\n            include_indent=False,\n            include_full_path=True,\n            include_hash_tag=False,\n            root=root,\n        )\n        outline_string_hash = hash(outline_string)\n        if outline_string_hash != self.kb_embedding[\"hash\"]:\n            outline_strings: List[str] = outline_string.split(\"\\n\")\n            cleaned_outline_strings = [\n                outline.replace(\" -> \", \", \") for outline in outline_strings\n            ]\n            encoded_outline = self.encoder.encode(cleaned_outline_strings)\n            self.kb_embedding = {\n                \"hash\": outline_string_hash,\n                \"encoded_structure\": encoded_outline,\n                \"structure_string\": outline_strings,\n            }\n        return (\n            self.kb_embedding[\"encoded_structure\"],\n            self.kb_embedding[\"structure_string\"],\n        )\n\n 
   def traverse_down(self, node):\n        \"\"\"\n        Traverses the tree downward from the given node.\n\n        Args:\n            node (KnowledgeNode): The node to start the traversal from.\n\n        Returns:\n            list: A list of KnowledgeNode instances in the order they were visited.\n        \"\"\"\n        nodes = []\n\n        def _traverse(current_node):\n            nodes.append(current_node)\n            for child in current_node.get_children():\n                _traverse(child)\n\n        _traverse(node)\n        return nodes\n\n    def traverse_up(self, node):\n        \"\"\"\n        Traverses the tree upward from the given node.\n\n        Args:\n            node (KnowledgeNode): The node to start the traversal from.\n\n        Returns:\n            list: A list of KnowledgeNode instances in the order they were visited.\n        \"\"\"\n        nodes = []\n        while node is not None:\n            nodes.append(node)\n            node = node.get_parent()\n        return nodes\n\n    def collect_all_nodes(self):\n        nodes = []\n\n        def _collect(node):\n            nodes.append(node)\n            for child in node.children:\n                _collect(child)\n\n        _collect(self.root)\n        return nodes\n\n    def insert_node(\n        self,\n        new_node_name,\n        parent_node: Optional[KnowledgeNode] = None,\n        duplicate_handling=\"skip\",\n    ):\n        \"\"\"\n        Inserts a new node into the knowledge base under the specified parent node.\n\n        Args:\n            new_node_name (str): The name of the new node.\n            parent_node_name (str): The name of the parent node. If None, the new node is inserted under the root.\n            duplicate_handling (str): How to handle duplicate nodes. 
Options are \"skip\", \"none\", and \"raise error\".\n        \"\"\"\n        if parent_node is None:\n            return self.root.add_child(\n                new_node_name, duplicate_handling=duplicate_handling\n            )\n        else:\n            return parent_node.add_child(\n                new_node_name, duplicate_handling=duplicate_handling\n            )\n\n    def find_node(self, current_node, node_name):\n        \"\"\"\n        Finds a node by name in the knowledge base.\n\n        Args:\n            current_node (KnowledgeNode): The node to start the search from.\n            node_name (str): The name of the node to find.\n\n        Returns:\n            KnowledgeNode: The node with the specified name, or None if not found.\n        \"\"\"\n        if current_node.name == node_name:\n            return current_node\n        for child in current_node.get_children():\n            result = self.find_node(child, node_name)\n            if result is not None:\n                return result\n        return None\n\n    def insert_from_outline_string(self, outline_string, duplicate_handling=\"skip\"):\n        \"\"\"\n        Creates and inserts nodes into the knowledge base from a string outline.\n\n        Args:\n            outline_string (str): The outline string where each line starts with '#' denoting the level.\n            duplicate_handling (str): How to handle duplicate nodes. 
Options are \"skip\", \"none\", and \"raise error\".\n        \"\"\"\n        last_node_at_level = {}\n        for line in outline_string.split(\"\\n\"):\n            level = line.count(\"#\")\n            if level > 0:\n                title = line.strip(\"# \").strip()\n                if title.lower() in [\"overview\", \"summary\", \"introduction\"]:\n                    continue\n                parent_node = None if level == 1 else last_node_at_level.get(level - 1)\n                new_node = self.insert_node(\n                    new_node_name=title,\n                    parent_node=parent_node,\n                    duplicate_handling=duplicate_handling,\n                )\n                last_node_at_level[level] = new_node\n                for deeper_level in list(last_node_at_level.keys()):\n                    if deeper_level > level:\n                        del last_node_at_level[deeper_level]\n\n    def get_node_hierarchy_string(\n        self,\n        include_indent=False,\n        include_full_path=False,\n        include_hash_tag=True,\n        include_node_content_count=False,\n        cited_indices: Optional[List[int]] = None,\n        root: Optional[KnowledgeNode] = None,\n    ) -> str:\n        def find_node_contain_index(node, index):\n            \"\"\"\n            Finds all nodes at or below the given node whose content contains the given information index.\n\n            Args:\n                node (KnowledgeNode): The node to start the traversal from.\n                index (int): The information index to look for in each node's content.\n\n            Returns:\n                list: The matching KnowledgeNode instances in the order they were visited.\n            \"\"\"\n            nodes = []\n\n            def _traverse(current_node):\n                if current_node is not None and index in current_node.content:\n                    nodes.append(current_node)\n                for child in current_node.get_children():\n                    _traverse(child)\n\n            _traverse(node)\n            return nodes\n\n        paths_to_highlight = set()\n        
nodes_to_include = set()\n        if cited_indices is not None:\n            for index in cited_indices:\n                for cur_node in find_node_contain_index(self.root, index):\n                    paths_to_highlight.add(\" -> \".join(cur_node.get_path_from_root()))\n                    nodes_to_include.add(cur_node)\n                    nodes_to_include.update(cur_node.get_all_descendents())\n                    predecessors = cur_node.get_all_predecessors()\n                    for predecessor in predecessors:\n                        nodes_to_include.update(predecessor.children)\n                    nodes_to_include.update(predecessors)\n\n        def should_include_node(node):\n            if cited_indices is None:\n                return True\n            return node in nodes_to_include\n\n        def should_omit_child_nodes(node):\n            if cited_indices is None:\n                return False\n            for child in node.children:\n                if should_include_node(child):\n                    return False\n            return True\n\n        def helper(cur_root, level):\n            to_return = []\n            if cur_root is not None:\n                should_include_current_node = should_include_node(cur_root)\n\n                indent = \"\" if not include_indent else \"\\t\" * (level - 1)\n                full_path = \" -> \".join(cur_root.get_path_from_root(root=root))\n                node_info = cur_root.name if not include_full_path else full_path\n                hash_tag = \"#\" * level + \" \" if include_hash_tag else \"\"\n                content_count = (\n                    f\" ({len(cur_root.content)})\" if include_node_content_count else \"\"\n                )\n                special_note = (\n                    \"\"\n                    if cited_indices is None or full_path not in paths_to_highlight\n                    else \" ⭐\"\n                )\n\n                if should_include_current_node:\n                    
to_return.append(\n                        f\"{indent}{hash_tag}{node_info}{content_count}{special_note}\"\n                    )\n                    if should_omit_child_nodes(cur_root):\n                        if len(cur_root.children) > 0:\n                            child_indent = (\n                                \"\" if not include_indent else \"\\t\" * level\n                            )\n                            to_return.append(f\"{child_indent}...\")\n                    else:\n                        for child in cur_root.children:\n                            to_return.extend(helper(child, level + 1))\n            return to_return\n\n        to_return = []\n        if root is None and self.root is not None:\n            for child in self.root.children:\n                to_return.extend(helper(child, level=1))\n        else:\n            to_return.extend(helper(root, level=1))\n\n        return \"\\n\".join(to_return)\n\n    def find_node_by_path(\n        self,\n        path: str,\n        missing_node_handling=\"abort\",\n        root: Optional[KnowledgeNode] = None,\n    ):\n        \"\"\"\n        Returns the target node given a path string.\n\n        Args:\n            path (str): The path to the node, with node names connected by \" -> \".\n            missing_node_handling (str): How to handle missing nodes. 
Options are \"abort\", \"create\", and \"raise error\".\n\n        Returns:\n            KnowledgeNode: The target node.\n        \"\"\"\n        node_names = path.split(\" -> \")\n        current_node = self.root if root is None else root\n\n        for name in node_names[1:]:\n            found_node = next(\n                (child for child in current_node.children if child.name == name), None\n            )\n            if found_node is None:\n                if missing_node_handling == \"abort\":\n                    return\n                elif missing_node_handling == \"create\":\n                    new_node = current_node.add_child(child_node_name=name)\n                    current_node = new_node\n                elif missing_node_handling == \"raise error\":\n                    structure = self.get_node_hierarchy_string(\n                        include_indent=True,\n                        include_full_path=False,\n                        include_hash_tag=True,\n                    )\n                    raise Exception(\n                        f\"Insert information error. Unable to find node {{{name}}} under {{{current_node.name}}}\\n{structure}\"\n                    )\n            else:\n                current_node = found_node\n        return current_node\n\n    def insert_information(\n        self,\n        path: str,\n        information: Information,\n        missing_node_handling=\"abort\",\n        root: Optional[KnowledgeNode] = None,\n    ):\n        \"\"\"\n        Inserts information into the knowledge base at the specified path.\n\n        Args:\n            path (str): The placement path string, connected by \" -> \" linking the name of nodes.\n            information (Information): The information to insert.\n            missing_node_handling (str): How to handle missing nodes. 
Options are \"abort\", \"create\", and \"raise error\".\n        Return:\n            uuid of insertion information\n        \"\"\"\n        with self._lock:\n            target_node: KnowledgeNode = self.find_node_by_path(\n                path=path, missing_node_handling=missing_node_handling, root=root\n            )\n            information_hash = hash(information)\n            if information.citation_uuid == -1:\n                info_citation_uuid = self.info_hash_to_uuid_dict.get(\n                    information_hash, len(self.info_hash_to_uuid_dict) + 1\n                )\n                information.citation_uuid = info_citation_uuid\n                self.info_hash_to_uuid_dict[information_hash] = info_citation_uuid\n                self.info_uuid_to_info_dict[info_citation_uuid] = information\n            if target_node is not None:\n                self.info_uuid_to_info_dict[information.citation_uuid].meta[\n                    \"placement\"\n                ] = \" -> \".join(target_node.get_path_from_root())\n                target_node.insert_information(information.citation_uuid)\n\n    def trim_empty_leaf_nodes(self):\n        \"\"\"\n        Trims all leaf nodes that do not have any content. 
Iteratively does it until all leaf nodes have at least one content.\n        \"\"\"\n\n        def trim_node(node):\n            if not node.children and not node.content:\n                return True\n            node.children = [child for child in node.children if not trim_node(child)]\n            return not node.children and not node.content\n\n        # Start the trimming process from the root\n        while True:\n            before_trim = len(self.get_all_leaf_nodes())\n            trim_node(self.root)\n            after_trim = len(self.get_all_leaf_nodes())\n            if before_trim == after_trim:\n                break\n\n    def get_all_leaf_nodes(self):\n        \"\"\"\n        Helper function to get all leaf nodes.\n\n        Returns:\n            List[KnowledgeNode]: A list of all leaf nodes in the knowledge base.\n        \"\"\"\n        leaf_nodes = []\n\n        def find_leaf_nodes(node):\n            if not node.children:\n                leaf_nodes.append(node)\n            for child in node.children:\n                find_leaf_nodes(child)\n\n        find_leaf_nodes(self.root)\n        return leaf_nodes\n\n    def merge_single_child_nodes(self):\n        \"\"\"\n        Merges content of a node with its single child and removes the child node.\n        Iteratively does this from leaf nodes back to the root.\n        \"\"\"\n\n        def merge_node(node):\n            # Recursively merge children first\n            for child in node.children:\n                merge_node(child)\n\n            # If the node has exactly one child, merge its content with the child and remove the child\n            if len(node.children) == 1:\n                single_child = node.children[0]\n                node.content.update(single_child.content)\n                node.children = single_child.children\n                for grandchild in node.children:\n                    grandchild.parent = node\n\n        merge_node(self.root)\n\n    def 
update_all_info_path(self):\n        def _helper(node):\n            for citation_idx in node.content:\n                self.info_uuid_to_info_dict[citation_idx].meta[\"placement\"] = (\n                    \" -> \".join(node.get_path_from_root())\n                )\n            for child in node.children:\n                _helper(child)\n\n        _helper(self.root)\n\n    def update_from_conv_turn(\n        self,\n        conv_turn: ConversationTurn,\n        allow_create_new_node: bool = False,\n        insert_under_root: bool = False,\n    ):\n        if conv_turn is None:\n            return\n        info_to_insert = list(conv_turn.cited_info.values())\n        if insert_under_root:\n            for info in info_to_insert:\n                self.insert_information(path=self.root.name, information=info)\n        else:\n            self.information_insert_module(\n                knowledge_base=self,\n                information=info_to_insert,\n                allow_create_new_node=allow_create_new_node,\n            )\n        old_to_new_citation_idx_mapping = {\n            old_idx: info.citation_uuid\n            for old_idx, info in conv_turn.cited_info.items()\n        }\n\n        for old_idx, new_idx in old_to_new_citation_idx_mapping.items():\n            conv_turn.utterance = conv_turn.utterance.replace(\n                f\"[{old_idx}]\", f\"[_{new_idx}_]\"\n            )\n            conv_turn.raw_utterance = conv_turn.raw_utterance.replace(\n                f\"[{old_idx}]\", f\"[_{new_idx}_]\"\n            )\n        for _, new_idx in old_to_new_citation_idx_mapping.items():\n            conv_turn.utterance = conv_turn.utterance.replace(\n                f\"[_{new_idx}_]\", f\"[{new_idx}]\"\n            )\n            conv_turn.utterance.replace(\"[-1]\", \"\")\n            conv_turn.raw_utterance = conv_turn.raw_utterance.replace(\n                f\"[_{new_idx}_]\", f\"[{new_idx}]\"\n            )\n            
conv_turn.raw_utterance = conv_turn.raw_utterance.replace(\"[-1]\", \"\")\n            conv_turn.utterance = conv_turn.utterance.replace(\"[-1]\", \"\")\n        conv_turn.cited_info = None\n\n    def get_knowledge_base_summary(self):\n        return self.gen_summary_module(self)\n\n    def reorganize(self):\n        \"\"\"\n        Reorganizes the knowledge base through two main processes: top-down expansion and bottom-up cleaning.\n\n        The reorganization process ensures that the knowledge base remains well-structured and relevant as new information is added. It consists of the following steps:\n        1. Top-Down Expansion: Expands nodes that have accumulated significant amounts of information by creating subtopics,\n           ensuring that each concept remains specific and manageable.\n        2. Bottom-Up Cleaning: Cleans the knowledge base by removing empty leaf nodes (nodes with no supporting information)\n           and merging nodes that have only a single child, simplifying the structure and maintaining clarity.\n        \"\"\"\n        # pre-processing\n        self.trim_empty_leaf_nodes()\n        self.merge_single_child_nodes()\n        # expand nodes\n        self.expand_node_module(knowledge_base=self)\n        # clean up\n        self.trim_empty_leaf_nodes()\n        self.merge_single_child_nodes()\n        self.update_all_info_path()\n\n    def to_report(self):\n        return self.article_generation_module(knowledge_base=self)\n"
  },
  {
    "path": "knowledge_storm/encoder.py",
    "content": "import os\r\nimport numpy as np\r\n\r\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\r\nfrom typing import List, Tuple, Union, Optional, Dict, Literal\r\nfrom pathlib import Path\r\n\r\ntry:\r\n    import warnings\r\n\r\n    with warnings.catch_warnings():\r\n        warnings.filterwarnings(\"ignore\", category=UserWarning)\r\n        if \"LITELLM_LOCAL_MODEL_COST_MAP\" not in os.environ:\r\n            os.environ[\"LITELLM_LOCAL_MODEL_COST_MAP\"] = \"True\"\r\n        import litellm\r\n\r\n        litellm.drop_params = True\r\n        litellm.telemetry = False\r\n\r\n    from litellm.caching.caching import Cache\r\n\r\n    disk_cache_dir = os.path.join(Path.home(), \".storm_local_cache\")\r\n    litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type=\"disk\")\r\n\r\nexcept ImportError:\r\n\r\n    class LitellmPlaceholder:\r\n        def __getattr__(self, _):\r\n            raise ImportError(\r\n                \"The LiteLLM package is not installed. Run `pip install litellm`.\"\r\n            )\r\n\r\n    litellm = LitellmPlaceholder()\r\n\r\n\r\nclass Encoder:\r\n    \"\"\"\r\n    A wrapper class for the LiteLLM embedding model, designed to handle embedding\r\n    generation tasks efficiently. It supports parallel processing and local caching of\r\n    embedding results for improved performance.\r\n\r\n    The Encoder utilizes the LiteLLM library to interact with various embedding models,\r\n    such as OpenAI and Azure embeddings. 
Users can specify the desired encoder type and\r\n    provide relevant API credentials during initialization.\r\n\r\n    Features:\r\n        - Support for multiple embedding models (e.g., OpenAI, Azure).\r\n        - Parallel processing for faster embedding generation.\r\n        - Local disk caching to store and reuse embedding results.\r\n        - Total token usage tracking for cost monitoring.\r\n\r\n    Note:\r\n        Refer to the LiteLLM documentation for details on supported embedding models:\r\n        https://docs.litellm.ai/docs/embedding/supported_embedding\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        encoder_type: Optional[str] = None,\r\n        api_key: Optional[str] = None,\r\n        api_base: Optional[str] = None,\r\n        api_version: Optional[str] = None,\r\n    ):\r\n        \"\"\"\r\n        Initializes the Encoder with the appropriate embedding model.\r\n\r\n        Args:\r\n            encoder_type (Optional[str]): Type of encoder ('openai', 'azure', etc.).\r\n            api_key (Optional[str]): API key for the encoder service.\r\n            api_base (Optional[str]): API base URL for the encoder service.\r\n            api_version (Optional[str]): API version for the encoder service.\r\n        \"\"\"\r\n        self.embedding_model_name = None\r\n        self.kargs = {}\r\n        self.total_token_usage = 0\r\n\r\n        # Initialize the appropriate embedding model\r\n        encoder_type = encoder_type or os.getenv(\"ENCODER_API_TYPE\")\r\n        if not encoder_type:\r\n            raise ValueError(\"ENCODER_API_TYPE environment variable is not set.\")\r\n\r\n        if encoder_type.lower() == \"openai\":\r\n            self.embedding_model_name = \"text-embedding-3-small\"\r\n            self.kargs = {\"api_key\": api_key or os.getenv(\"OPENAI_API_KEY\")}\r\n        elif encoder_type.lower() == \"azure\":\r\n            self.embedding_model_name = \"azure/text-embedding-3-small\"\r\n            self.kargs = 
{\r\n                \"api_key\": api_key or os.getenv(\"AZURE_API_KEY\"),\r\n                \"api_base\": api_base or os.getenv(\"AZURE_API_BASE\"),\r\n                \"api_version\": api_version or os.getenv(\"AZURE_API_VERSION\"),\r\n            }\r\n        else:\r\n            raise ValueError(\r\n                f\"Unsupported ENCODER_API_TYPE '{encoder_type}'. Supported types are 'openai' and 'azure'.\"\r\n            )\r\n\r\n    def get_total_token_usage(self, reset: bool = False) -> int:\r\n        \"\"\"\r\n        Retrieves the total token usage.\r\n\r\n        Args:\r\n            reset (bool): If True, resets the total token usage counter after retrieval.\r\n\r\n        Returns:\r\n            int: The total number of tokens used.\r\n        \"\"\"\r\n        token_usage = self.total_token_usage\r\n        if reset:\r\n            self.total_token_usage = 0\r\n        return token_usage\r\n\r\n    def encode(self, texts: Union[str, List[str]], max_workers: int = 5) -> np.ndarray:\r\n        \"\"\"\r\n        Public method to get embeddings for the given texts.\r\n\r\n        Args:\r\n            texts (Union[str, List[str]]): A single text string or a list of text strings to embed.\r\n            max_workers (int): The maximum number of workers for parallel processing.\r\n\r\n        Returns:\r\n            np.ndarray: The array of embeddings.\r\n        \"\"\"\r\n        return self._get_text_embeddings(texts, max_workers=max_workers)\r\n\r\n    def _get_single_text_embedding(self, text):\r\n        response = litellm.embedding(\r\n            model=self.embedding_model_name, input=text, caching=True, **self.kargs\r\n        )\r\n        embedding = response.data[0][\"embedding\"]\r\n        token_usage = response.get(\"usage\", {}).get(\"total_tokens\", 0)\r\n        return text, embedding, token_usage\r\n\r\n    def _get_text_embeddings(\r\n        self,\r\n        texts: Union[str, List[str]],\r\n        max_workers: int = 5,\r\n    ) -> np.ndarray:\r\n        \"\"\"\r\n        Get text embeddings using OpenAI's 
text-embedding-3-small model.\r\n\r\n        Args:\r\n            texts (Union[str, List[str]]): A single text string or a list of text strings to embed.\r\n            max_workers (int): The maximum number of workers for parallel processing.\r\n\r\n        Returns:\r\n            np.ndarray: A 1D embedding for a single string, or a 2D array of embeddings for a list.\r\n            Token usage is added to self.total_token_usage rather than returned.\r\n        \"\"\"\r\n\r\n        if isinstance(texts, str):\r\n            _, embedding, tokens = self._get_single_text_embedding(texts)\r\n            self.total_token_usage += tokens\r\n            return np.array(embedding)\r\n\r\n        embeddings = []\r\n        total_tokens = 0\r\n\r\n        with ThreadPoolExecutor(max_workers=max_workers) as executor:\r\n            futures = {\r\n                executor.submit(self._get_single_text_embedding, text): text\r\n                for text in texts\r\n            }\r\n\r\n            for future in as_completed(futures):\r\n                try:\r\n                    text, embedding, tokens = future.result()\r\n                    embeddings.append((text, embedding, tokens))\r\n                    total_tokens += tokens\r\n                except Exception as e:\r\n                    print(f\"An error occurred for text: {futures[future]}\")\r\n                    print(e)\r\n\r\n        # Sort results to match the order of the input texts\r\n        embeddings.sort(key=lambda x: texts.index(x[0]))\r\n        embeddings = [result[1] for result in embeddings]\r\n        self.total_token_usage += total_tokens\r\n\r\n        return np.array(embeddings)\r\n"
  },
  {
    "path": "knowledge_storm/interface.py",
    "content": "import concurrent.futures\nimport dspy\nimport functools\nimport hashlib\nimport json\nimport logging\nimport time\nfrom abc import ABC, abstractmethod\nfrom collections import OrderedDict\nfrom typing import Dict, List, Optional, Union, TYPE_CHECKING\n\nfrom .utils import ArticleTextProcessing\n\nlogging.basicConfig(\n    level=logging.INFO, format=\"%(name)s : %(levelname)-8s : %(message)s\"\n)\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n    from .logging_wrapper import LoggingWrapper\n\n\nclass InformationTable(ABC):\n    \"\"\"\n    The InformationTable class serves as data class to store the information\n    collected during KnowledgeCuration stage.\n\n    Create subclass to incorporate more information as needed. For example,\n    in STORM paper https://arxiv.org/pdf/2402.14207.pdf, additional information\n    would be perspective guided dialogue history.\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    @abstractmethod\n    def retrieve_information(**kwargs):\n        pass\n\n\nclass Information:\n    \"\"\"Class to represent detailed information.\n\n    Inherits from Information to include a unique identifier (URL), and extends\n    it with a description, snippets, and title of the storm information.\n\n    Attributes:\n        description (str): Brief description.\n        snippets (list): List of brief excerpts or snippets.\n        title (str): The title or headline of the information.\n        url (str): The unique URL (serving as UUID) of the information.\n    \"\"\"\n\n    def __init__(self, url, description, snippets, title, meta=None):\n        \"\"\"Initialize the Information object with detailed attributes.\n\n        Args:\n            url (str): The unique URL serving as the identifier for the information.\n            description (str): Detailed description.\n            snippets (list): List of brief excerpts or snippet.\n            title (str): The title or headline of the information.\n        
\"\"\"\n        self.description = description\n        self.snippets = snippets\n        self.title = title\n        self.url = url\n        self.meta = meta if meta is not None else {}\n        self.citation_uuid = -1\n\n    def __hash__(self):\n        return hash(\n            (\n                self.url,\n                tuple(sorted(self.snippets)),\n            )\n        )\n\n    def __eq__(self, other):\n        if not isinstance(other, Information):\n            return False\n        return (\n            self.url == other.url\n            and set(self.snippets) == set(other.snippets)\n            and self._meta_str() == other._meta_str()\n        )\n\n    def __hash__(self):\n        return int(\n            self._md5_hash((self.url, tuple(sorted(self.snippets)), self._meta_str())),\n            16,\n        )\n\n    def _meta_str(self):\n        \"\"\"Generate a string representation of relevant meta information.\"\"\"\n        return f\"Question: {self.meta.get('question', '')}, Query: {self.meta.get('query', '')}\"\n\n    def _md5_hash(self, value):\n        \"\"\"Generate an MD5 hash for a given value.\"\"\"\n        if isinstance(value, (dict, list, tuple)):\n            value = json.dumps(value, sort_keys=True)\n        return hashlib.md5(str(value).encode(\"utf-8\")).hexdigest()\n\n    @classmethod\n    def from_dict(cls, info_dict):\n        \"\"\"Create a Information object from a dictionary.\n           Usage: info = Information.from_dict(storm_info_dict)\n\n        Args:\n            info_dict (dict): A dictionary containing keys 'url', 'description',\n                              'snippets', and 'title' corresponding to the object's attributes.\n\n        Returns:\n            Information: An instance of Information.\n        \"\"\"\n        info = cls(\n            url=info_dict[\"url\"],\n            description=info_dict[\"description\"],\n            snippets=info_dict[\"snippets\"],\n            title=info_dict[\"title\"],\n            
meta=info_dict.get(\"meta\", None),\n        )\n        info.citation_uuid = int(info_dict.get(\"citation_uuid\", -1))\n        return info\n\n    def to_dict(self):\n        return {\n            \"url\": self.url,\n            \"description\": self.description,\n            \"snippets\": self.snippets,\n            \"title\": self.title,\n            \"meta\": self.meta,\n            \"citation_uuid\": self.citation_uuid,\n        }\n\n\nclass ArticleSectionNode:\n    \"\"\"\n    The ArticleSectionNode is the dataclass for handling the section of the article.\n    The content storage, section writing preferences are defined in this node.\n    \"\"\"\n\n    def __init__(self, section_name: str, content=None):\n        \"\"\"\n        section_name: section heading in string format. E.g. Introduction, History, etc.\n        content: content of the section. Up to you for design choice of the data structure.\n        \"\"\"\n        self.section_name = section_name\n        self.content = content\n        self.children = []\n        self.preference = None\n\n    def add_child(self, new_child_node, insert_to_front=False):\n        if insert_to_front:\n            self.children.insert(0, new_child_node)\n        else:\n            self.children.append(new_child_node)\n\n    def remove_child(self, child):\n        self.children.remove(child)\n\n\nclass Article(ABC):\n    def __init__(self, topic_name):\n        self.root = ArticleSectionNode(topic_name)\n\n    def find_section(\n        self, node: ArticleSectionNode, name: str\n    ) -> Optional[ArticleSectionNode]:\n        \"\"\"\n        Return the node of the section given the section name.\n\n        Args:\n            node: the node as the root to find.\n            name: the name of node as section name\n\n        Return:\n            reference of the node or None if section name has no match\n        \"\"\"\n        if node.section_name == name:\n            return node\n        for child in node.children:\n     
       result = self.find_section(child, name)\n            if result:\n                return result\n        return None\n\n    @abstractmethod\n    def to_string(self) -> str:\n        \"\"\"\n        Export Article object into string representation.\n        \"\"\"\n\n    def get_outline_tree(self):\n        \"\"\"\n        Generates a hierarchical tree structure representing the outline of the document.\n\n        Returns:\n            Dict[str, Dict]: A nested dictionary representing the hierarchical structure of the document's outline.\n                             Each key is a section name, and the value is another dictionary representing the child sections,\n                             recursively forming the tree structure of the document's outline. If a section has no subsections,\n                             its value is an empty dictionary.\n\n        Example:\n            Assuming a document with a structure like:\n            - Introduction\n                - Background\n                - Objective\n            - Methods\n                - Data Collection\n                - Analysis\n            The method would return:\n            {\n                'Introduction': {\n                    'Background': {},\n                    'Objective': {}\n                },\n                'Methods': {\n                    'Data Collection': {},\n                    'Analysis': {}\n                }\n            }\n        \"\"\"\n\n        def build_tree(node) -> Dict[str, Dict]:\n            tree = {}\n            for child in node.children:\n                tree[child.section_name] = build_tree(child)\n            return tree if tree else {}\n\n        return build_tree(self.root)\n\n    def get_first_level_section_names(self) -> List[str]:\n        \"\"\"\n        Get first level section names\n        \"\"\"\n        return [i.section_name for i in self.root.children]\n\n    @classmethod\n    @abstractmethod\n    def from_string(cls, topic_name: str, 
article_text: str):\n        \"\"\"\n        Create an instance of the Article object from a string\n        \"\"\"\n        pass\n\n    def prune_empty_nodes(self, node=None):\n        if node is None:\n            node = self.root\n\n        node.children[:] = [\n            child for child in node.children if self.prune_empty_nodes(child)\n        ]\n\n        if (node.content is None or node.content == \"\") and not node.children:\n            return None\n        else:\n            return node\n\n\nclass Retriever:\n    \"\"\"\n    An abstract base class for retriever modules. It provides a template for retrieving information based on a query.\n\n    This class should be extended to implement specific retrieval functionalities.\n    Users can design their retriever modules as needed by implementing the retrieve method.\n    The retrieval model/search engine used for each part should be declared with a suffix '_rm' in the attribute name.\n    \"\"\"\n\n    def __init__(self, rm: dspy.Retrieve, max_thread: int = 1):\n        self.max_thread = max_thread\n        self.rm = rm\n\n    def collect_and_reset_rm_usage(self):\n        combined_usage = []\n        if hasattr(getattr(self, \"rm\"), \"get_usage_and_reset\"):\n            combined_usage.append(getattr(self, \"rm\").get_usage_and_reset())\n\n        name_to_usage = {}\n        for usage in combined_usage:\n            for model_name, query_cnt in usage.items():\n                if model_name not in name_to_usage:\n                    name_to_usage[model_name] = query_cnt\n                else:\n                    name_to_usage[model_name] += query_cnt\n\n        return name_to_usage\n\n    def retrieve(\n        self, query: Union[str, List[str]], exclude_urls: List[str] = []\n    ) -> List[Information]:\n        queries = query if isinstance(query, list) else [query]\n        to_return = []\n\n        def process_query(q):\n            retrieved_data_list = self.rm(\n                query_or_queries=[q], 
exclude_urls=exclude_urls\n            )\n            local_to_return = []\n            for data in retrieved_data_list:\n                for i in range(len(data[\"snippets\"])):\n                    # STORM generate the article with citations. We do not consider multi-hop citations.\n                    # Remove citations in the source to avoid confusion.\n                    data[\"snippets\"][i] = ArticleTextProcessing.remove_citations(\n                        data[\"snippets\"][i]\n                    )\n                storm_info = Information.from_dict(data)\n                storm_info.meta[\"query\"] = q\n                local_to_return.append(storm_info)\n            return local_to_return\n\n        with concurrent.futures.ThreadPoolExecutor(\n            max_workers=self.max_thread\n        ) as executor:\n            results = list(executor.map(process_query, queries))\n\n        for result in results:\n            to_return.extend(result)\n\n        return to_return\n\n\nclass KnowledgeCurationModule(ABC):\n    \"\"\"\n    The interface for knowledge curation stage. Given topic, return collected information.\n    \"\"\"\n\n    def __init__(self, retriever: Retriever):\n        \"\"\"\n        Store args and finish initialization.\n        \"\"\"\n        self.retriever = retriever\n\n    @abstractmethod\n    def research(self, topic) -> InformationTable:\n        \"\"\"\n        Curate information and knowledge for the given topic\n\n        Args:\n            topic: topic of interest in natural language.\n\n        Returns:\n            collected_information: collected information in InformationTable type.\n        \"\"\"\n        pass\n\n\nclass OutlineGenerationModule(ABC):\n    \"\"\"\n    The interface for outline generation stage. 
Given topic, collected information from knowledge\n    curation stage, generate outline for the article.\n    \"\"\"\n\n    @abstractmethod\n    def generate_outline(\n        self, topic: str, information_table: InformationTable, **kwargs\n    ) -> Article:\n        \"\"\"\n        Generate outline for the article. Required arguments include:\n            topic: the topic of interest\n            information_table: knowledge curation data generated from KnowledgeCurationModule\n\n        More arguments could be\n            1. draft outline\n            2. user provided outline\n\n        Returns:\n            article_outline of type Article\n        \"\"\"\n        pass\n\n\nclass ArticleGenerationModule(ABC):\n    \"\"\"\n    The interface for article generation stage. Given topic, collected information from\n    knowledge curation stage, generated outline from outline generation stage, generate the article.\n    \"\"\"\n\n    @abstractmethod\n    def generate_article(\n        self,\n        topic: str,\n        information_table: InformationTable,\n        article_with_outline: Article,\n        **kwargs,\n    ) -> Article:\n        \"\"\"\n        Generate article. Required arguments include:\n            topic: the topic of interest\n            information_table: knowledge curation data generated from KnowledgeCurationModule\n            article_with_outline: article with specified outline from OutlineGenerationModule\n        \"\"\"\n        pass\n\n\nclass ArticlePolishingModule(ABC):\n    \"\"\"\n    The interface for article polishing stage. Given topic and the draft article from\n    the article generation stage, polish the article.\n    \"\"\"\n\n    @abstractmethod\n    def polish_article(self, topic: str, draft_article: Article, **kwargs) -> Article:\n        \"\"\"\n        Polish article. 
Required arguments include:\n            topic: the topic of interest\n            draft_article: draft article from ArticleGenerationModule.\n        \"\"\"\n        pass\n\n\ndef log_execution_time(func):\n    \"\"\"Decorator to log the execution time of a function.\"\"\"\n\n    @functools.wraps(func)\n    def wrapper(self, *args, **kwargs):\n        start_time = time.time()\n        result = func(self, *args, **kwargs)\n        end_time = time.time()\n        execution_time = end_time - start_time\n        logger.info(f\"{func.__name__} executed in {execution_time:.4f} seconds\")\n        self.time[func.__name__] = execution_time\n        return result\n\n    return wrapper\n\n\nclass LMConfigs(ABC):\n    \"\"\"Abstract base class for language model configurations of the knowledge curation engine.\n\n    The language model used for each part should be declared with a suffix '_lm' in the attribute name.\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def init_check(self):\n        for attr_name in self.__dict__:\n            if \"_lm\" in attr_name and getattr(self, attr_name) is None:\n                logging.warning(\n                    f\"Language model for {attr_name} is not initialized. 
Please call set_{attr_name}()\"\n                )\n\n    def collect_and_reset_lm_history(self):\n        history = []\n        for attr_name in self.__dict__:\n            if \"_lm\" in attr_name and hasattr(getattr(self, attr_name), \"history\"):\n                history.extend(getattr(self, attr_name).history)\n                getattr(self, attr_name).history = []\n\n        return history\n\n    def collect_and_reset_lm_usage(self):\n        combined_usage = []\n        for attr_name in self.__dict__:\n            if \"_lm\" in attr_name and hasattr(\n                getattr(self, attr_name), \"get_usage_and_reset\"\n            ):\n                combined_usage.append(getattr(self, attr_name).get_usage_and_reset())\n\n        model_name_to_usage = {}\n        for usage in combined_usage:\n            for model_name, tokens in usage.items():\n                if model_name not in model_name_to_usage:\n                    model_name_to_usage[model_name] = tokens\n                else:\n                    model_name_to_usage[model_name][\"prompt_tokens\"] += tokens[\n                        \"prompt_tokens\"\n                    ]\n                    model_name_to_usage[model_name][\"completion_tokens\"] += tokens[\n                        \"completion_tokens\"\n                    ]\n\n        return model_name_to_usage\n\n    def log(self):\n        return OrderedDict(\n            {\n                attr_name: getattr(self, attr_name).kwargs\n                for attr_name in self.__dict__\n                if \"_lm\" in attr_name and hasattr(getattr(self, attr_name), \"kwargs\")\n            }\n        )\n\n\nclass Engine(ABC):\n    def __init__(self, lm_configs: LMConfigs):\n        self.lm_configs = lm_configs\n        self.time = {}\n        self.lm_cost = {}  # Cost of language models measured by in/out tokens.\n        self.rm_cost = {}  # Cost of retrievers measured by number of queries.\n\n    def log_execution_time_and_lm_rm_usage(self, func):\n      
  \"\"\"Decorator to log the execution time, language model usage, and retrieval model usage of a function.\"\"\"\n\n        @functools.wraps(func)\n        def wrapper(*args, **kwargs):\n            start_time = time.time()\n            result = func(*args, **kwargs)\n            end_time = time.time()\n            execution_time = end_time - start_time\n            self.time[func.__name__] = execution_time\n            logger.info(f\"{func.__name__} executed in {execution_time:.4f} seconds\")\n            self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage()\n            if hasattr(self, \"retriever\"):\n                self.rm_cost[func.__name__] = (\n                    self.retriever.collect_and_reset_rm_usage()\n                )\n            return result\n\n        return wrapper\n\n    def apply_decorators(self):\n        \"\"\"Apply decorators to methods that need them.\"\"\"\n        methods_to_decorate = [\n            method_name\n            for method_name in dir(self)\n            if callable(getattr(self, method_name)) and method_name.startswith(\"run_\")\n        ]\n        for method_name in methods_to_decorate:\n            original_method = getattr(self, method_name)\n            decorated_method = self.log_execution_time_and_lm_rm_usage(original_method)\n            setattr(self, method_name, decorated_method)\n\n    @abstractmethod\n    def run_knowledge_curation_module(self, **kwargs) -> Optional[InformationTable]:\n        pass\n\n    @abstractmethod\n    def run_outline_generation_module(self, **kwarg) -> Article:\n        pass\n\n    @abstractmethod\n    def run_article_generation_module(self, **kwarg) -> Article:\n        pass\n\n    @abstractmethod\n    def run_article_polishing_module(self, **kwarg) -> Article:\n        pass\n\n    @abstractmethod\n    def run(self, **kwargs):\n        pass\n\n    def summary(self):\n        print(\"***** Execution time *****\")\n        for k, v in self.time.items():\n            
print(f\"{k}: {v:.4f} seconds\")\n\n        print(\"***** Token usage of language models: *****\")\n        for k, v in self.lm_cost.items():\n            print(f\"{k}\")\n            for model_name, tokens in v.items():\n                print(f\"    {model_name}: {tokens}\")\n\n        print(\"***** Number of queries of retrieval models: *****\")\n        for k, v in self.rm_cost.items():\n            print(f\"{k}: {v}\")\n\n    def reset(self):\n        self.time = {}\n        self.lm_cost = {}\n        self.rm_cost = {}\n\n\nclass Agent(ABC):\n    \"\"\"\n    Interface for STORM and Co-STORM LLM agents.\n\n    Every subclass of `Agent` must implement `generate_utterance` to define how the agent generates an utterance.\n    The generated utterance can be influenced by the conversation history, knowledge base, and any additional parameters passed via `kwargs`.\n    The implementation should align with the specific role and perspective of the agent, as defined by the agent's topic, role name, and role description.\n    The arguments and return value documented below are those of `generate_utterance`.\n\n    Args:\n        knowledge_base (KnowledgeBase): The current knowledge base (e.g., the mind map in Co-STORM) that contains the accumulated information relevant to the conversation.\n        conversation_history (List[ConversationTurn]): A list of past conversation turns, providing context for generating the next utterance.\n                                                       The agent can refer to this history to maintain continuity and relevance in the conversation.\n        logging_wrapper (LoggingWrapper): A wrapper used for logging important events during the utterance generation process.\n        **kwargs: Additional arguments that can be passed to the method for more specialized utterance generation behavior depending on the agent's specific implementation.\n\n    Returns:\n        ConversationTurn: A new conversation turn generated by the agent, containing the agent's response, including the role, utterance type, and relevant information from the 
knowledge base.\n\n    Notes:\n        - Subclasses of `Agent` should define the exact strategy for generating the utterance, which could involve interacting with a language model, retrieving relevant knowledge, or following specific conversational policies.\n        - The agent's role, perspective, and the knowledge base content will influence how the utterance is formulated.\n    \"\"\"\n\n    from .dataclass import KnowledgeBase, ConversationTurn\n\n    def __init__(self, topic: str, role_name: str, role_description: str):\n        self.topic = topic\n        self.role_name = role_name\n        self.role_description = role_description\n\n    def get_role_description(self):\n        if self.role_description:\n            return f\"{self.role_name}: {self.role_description}\"\n        return self.role_name\n\n    @abstractmethod\n    def generate_utterance(\n        self,\n        knowledge_base: KnowledgeBase,\n        conversation_history: List[ConversationTurn],\n        logging_wrapper: \"LoggingWrapper\",\n        **kwargs,\n    ):\n        pass\n"
  },
  {
    "path": "knowledge_storm/lm.py",
    "content": "import backoff\nimport dspy\nimport functools\nimport logging\nimport os\nimport random\nimport requests\nimport threading\nfrom typing import Optional, Literal, Any\nimport ujson\nfrom pathlib import Path\n\n\nfrom dsp import ERRORS, backoff_hdlr, giveup_hdlr\nfrom dsp.modules.hf import openai_to_hf\nfrom dsp.modules.hf_client import send_hftgi_request_v01_wrapped\nfrom openai import OpenAI, AzureOpenAI\nfrom transformers import AutoTokenizer\n\ntry:\n    from anthropic import RateLimitError\nexcept ImportError:\n    RateLimitError = None\n\n############################\n# Code copied from https://github.com/stanfordnlp/dspy/blob/main/dspy/clients/lm.py on Sep 29, 2024\n\n# try:\nimport warnings\n\nwith warnings.catch_warnings():\n    warnings.filterwarnings(\"ignore\", category=UserWarning)\n    if \"LITELLM_LOCAL_MODEL_COST_MAP\" not in os.environ:\n        os.environ[\"LITELLM_LOCAL_MODEL_COST_MAP\"] = \"True\"\n    import litellm\n\n    litellm.drop_params = True\n    litellm.telemetry = False\n\nfrom litellm.caching.caching import Cache\n\ndisk_cache_dir = os.path.join(Path.home(), \".storm_local_cache\")\nlitellm.cache = Cache(disk_cache_dir=disk_cache_dir, type=\"disk\")\n\n# except ImportError:\n\n#     class LitellmPlaceholder:\n#         def __getattr__(self, _):\n#             raise ImportError(\n#                 \"The LiteLLM package is not installed. 
Run `pip install litellm`.\"\n#             )\n\n# litellm = LitellmPlaceholder()\nLM_LRU_CACHE_MAX_SIZE = 3000\n\n\nclass LM:\n    def __init__(\n        self,\n        model,\n        model_type=\"chat\",\n        temperature=0.0,\n        max_tokens=1000,\n        cache=True,\n        **kwargs,\n    ):\n        self.model = model\n        self.model_type = model_type\n        self.cache = cache\n        self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)\n        self.history = []\n\n        if \"o1-\" in model:\n            assert (\n                max_tokens >= 5000 and temperature == 1.0\n            ), \"OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`\"\n\n    def __call__(self, prompt=None, messages=None, **kwargs):\n        # Build the request.\n        cache = kwargs.pop(\"cache\", self.cache)\n        messages = messages or [{\"role\": \"user\", \"content\": prompt}]\n        kwargs = {**self.kwargs, **kwargs}\n\n        # Make the request and handle LRU & disk caching.\n        if self.model_type == \"chat\":\n            completion = cached_litellm_completion if cache else litellm_completion\n        else:\n            completion = (\n                cached_litellm_text_completion if cache else litellm_text_completion\n            )\n\n        response = completion(\n            ujson.dumps(dict(model=self.model, messages=messages, **kwargs))\n        )\n        outputs = [\n            c.message.content if hasattr(c, \"message\") else c[\"text\"]\n            for c in response[\"choices\"]\n        ]\n\n        # Logging, with removed api key & where `cost` is None on cache hit.\n        kwargs = {k: v for k, v in kwargs.items() if not k.startswith(\"api_\")}\n        entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)\n        entry = dict(**entry, outputs=outputs, usage=dict(response[\"usage\"]))\n        entry = dict(\n            **entry, 
cost=response.get(\"_hidden_params\", {}).get(\"response_cost\")\n        )\n        self.history.append(entry)\n\n        return outputs\n\n    def inspect_history(self, n: int = 1):\n        _inspect_history(self, n)\n\n\n@functools.lru_cache(maxsize=LM_LRU_CACHE_MAX_SIZE)\ndef cached_litellm_completion(request):\n    return litellm_completion(request, cache={\"no-cache\": False, \"no-store\": False})\n\n\ndef litellm_completion(request, cache={\"no-cache\": True, \"no-store\": True}):\n    kwargs = ujson.loads(request)\n    return litellm.completion(cache=cache, **kwargs)\n\n\n@functools.lru_cache(maxsize=LM_LRU_CACHE_MAX_SIZE)\ndef cached_litellm_text_completion(request):\n    return litellm_text_completion(\n        request, cache={\"no-cache\": False, \"no-store\": False}\n    )\n\n\ndef litellm_text_completion(request, cache={\"no-cache\": True, \"no-store\": True}):\n    kwargs = ujson.loads(request)\n\n    # Extract the provider and model from the model string.\n    model = kwargs.pop(\"model\").split(\"/\", 1)\n    provider, model = model[0] if len(model) > 1 else \"openai\", model[-1]\n\n    # Use the API key and base from the kwargs, or from the environment.\n    api_key = kwargs.pop(\"api_key\", None) or os.getenv(f\"{provider}_API_KEY\")\n    api_base = kwargs.pop(\"api_base\", None) or os.getenv(f\"{provider}_API_BASE\")\n\n    # Build the prompt from the messages.\n    prompt = \"\\n\\n\".join(\n        [x[\"content\"] for x in kwargs.pop(\"messages\")] + [\"BEGIN RESPONSE:\"]\n    )\n\n    return litellm.text_completion(\n        cache=cache,\n        model=f\"text-completion-openai/{model}\",\n        api_key=api_key,\n        api_base=api_base,\n        prompt=prompt,\n        **kwargs,\n    )\n\n\ndef _green(text: str, end: str = \"\\n\"):\n    return \"\\x1b[32m\" + str(text).lstrip() + \"\\x1b[0m\" + end\n\n\ndef _red(text: str, end: str = \"\\n\"):\n    return \"\\x1b[31m\" + str(text) + \"\\x1b[0m\" + end\n\n\ndef _inspect_history(lm, n: int 
= 1):\n    \"\"\"Prints the last n prompts and their completions.\"\"\"\n\n    for item in lm.history[-n:]:\n        messages = item[\"messages\"] or [{\"role\": \"user\", \"content\": item[\"prompt\"]}]\n        outputs = item[\"outputs\"]\n\n        print(\"\\n\\n\\n\")\n        for msg in messages:\n            print(_red(f\"{msg['role'].capitalize()} message:\"))\n            print(msg[\"content\"].strip())\n            print(\"\\n\")\n\n        print(_red(\"Response:\"))\n        print(_green(outputs[0].strip()))\n\n        if len(outputs) > 1:\n            choices_text = f\" \\t (and {len(outputs)-1} other completions)\"\n            print(_red(choices_text, end=\"\"))\n\n    print(\"\\n\\n\\n\")\n\n\n############################\n\n\nclass LitellmModel(LM):\n    \"\"\"A wrapper class for LiteLLM.\n\n    Check out https://docs.litellm.ai/docs/providers for usage details.\n    \"\"\"\n\n    def __init__(\n        self,\n        model: str = \"openai/gpt-4o-mini\",\n        api_key: Optional[str] = None,\n        model_type: Literal[\"chat\", \"text\"] = \"chat\",\n        **kwargs,\n    ):\n        super().__init__(model=model, api_key=api_key, model_type=model_type, **kwargs)\n        self._token_usage_lock = threading.Lock()\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n    def log_usage(self, response):\n        \"\"\"Log the total tokens from the OpenAI API response.\"\"\"\n        usage_data = response.get(\"usage\")\n        if usage_data:\n            with self._token_usage_lock:\n                self.prompt_tokens += usage_data.get(\"prompt_tokens\", 0)\n                self.completion_tokens += usage_data.get(\"completion_tokens\", 0)\n\n    def get_usage_and_reset(self):\n        \"\"\"Get the total tokens used and reset the token usage.\"\"\"\n        usage = {\n            self.model\n            or self.kwargs.get(\"model\")\n            or self.kwargs.get(\"engine\"): {\n                \"prompt_tokens\": 
self.prompt_tokens,\n                \"completion_tokens\": self.completion_tokens,\n            }\n        }\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n        return usage\n\n    def __call__(self, prompt=None, messages=None, **kwargs):\n        # Build the request.\n        cache = kwargs.pop(\"cache\", self.cache)\n        messages = messages or [{\"role\": \"user\", \"content\": prompt}]\n        kwargs = {**self.kwargs, **kwargs}\n\n        # Make the request and handle LRU & disk caching.\n        if self.model_type == \"chat\":\n            completion = cached_litellm_completion if cache else litellm_completion\n        else:\n            completion = (\n                cached_litellm_text_completion if cache else litellm_text_completion\n            )\n\n        response = completion(\n            ujson.dumps(dict(model=self.model, messages=messages, **kwargs))\n        )\n        response_dict = response.json()\n        self.log_usage(response_dict)\n        outputs = [\n            c.message.content if hasattr(c, \"message\") else c[\"text\"]\n            for c in response[\"choices\"]\n        ]\n\n        # Logging, with removed api key & where `cost` is None on cache hit.\n        kwargs = {k: v for k, v in kwargs.items() if not k.startswith(\"api_\")}\n        entry = dict(\n            prompt=prompt, messages=messages, kwargs=kwargs, response=response_dict\n        )\n        entry = dict(**entry, outputs=outputs, usage=dict(response_dict[\"usage\"]))\n        entry = dict(\n            **entry, cost=response.get(\"_hidden_params\", {}).get(\"response_cost\")\n        )\n        self.history.append(entry)\n\n        return outputs\n\n\n# ========================================================================\n# The following language model classes were deprecated after v1.1.0.\n# They remain in this file for backward compatibility but will no longer be maintained.\n\n\nclass OpenAIModel(dspy.OpenAI):\n    \"\"\"A wrapper 
class for dspy.OpenAI.\"\"\"\n\n    def __init__(\n        self,\n        model: str = \"gpt-4o-mini\",\n        api_key: Optional[str] = None,\n        model_type: Literal[\"chat\", \"text\"] = None,\n        **kwargs,\n    ):\n        super().__init__(model=model, api_key=api_key, model_type=model_type, **kwargs)\n        self._token_usage_lock = threading.Lock()\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n    def log_usage(self, response):\n        \"\"\"Log the total tokens from the OpenAI API response.\"\"\"\n        usage_data = response.get(\"usage\")\n        if usage_data:\n            with self._token_usage_lock:\n                self.prompt_tokens += usage_data.get(\"prompt_tokens\", 0)\n                self.completion_tokens += usage_data.get(\"completion_tokens\", 0)\n\n    def get_usage_and_reset(self):\n        \"\"\"Get the total tokens used and reset the token usage.\"\"\"\n        usage = {\n            self.kwargs.get(\"model\")\n            or self.kwargs.get(\"engine\"): {\n                \"prompt_tokens\": self.prompt_tokens,\n                \"completion_tokens\": self.completion_tokens,\n            }\n        }\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n        return usage\n\n    def __call__(\n        self,\n        prompt: str,\n        only_completed: bool = True,\n        return_sorted: bool = False,\n        **kwargs,\n    ) -> list[dict[str, Any]]:\n        \"\"\"Copied from dspy/dsp/modules/gpt3.py with the addition of tracking token usage.\"\"\"\n\n        assert only_completed, \"for now\"\n        assert return_sorted is False, \"for now\"\n\n        # if kwargs.get(\"n\", 1) > 1:\n        #     if self.model_type == \"chat\":\n        #         kwargs = {**kwargs}\n        #     else:\n        #         kwargs = {**kwargs, \"logprobs\": 5}\n\n        response = self.request(prompt, **kwargs)\n\n        # Log the token usage from the OpenAI API response.\n        
self.log_usage(response)\n\n        choices = response[\"choices\"]\n\n        completed_choices = [c for c in choices if c[\"finish_reason\"] != \"length\"]\n\n        if only_completed and len(completed_choices):\n            choices = completed_choices\n\n        completions = [self._get_choice_text(c) for c in choices]\n        if return_sorted and kwargs.get(\"n\", 1) > 1:\n            scored_completions = []\n\n            for c in choices:\n                tokens, logprobs = (\n                    c[\"logprobs\"][\"tokens\"],\n                    c[\"logprobs\"][\"token_logprobs\"],\n                )\n\n                if \"<|endoftext|>\" in tokens:\n                    index = tokens.index(\"<|endoftext|>\") + 1\n                    tokens, logprobs = tokens[:index], logprobs[:index]\n\n                avglog = sum(logprobs) / len(logprobs)\n                scored_completions.append((avglog, self._get_choice_text(c)))\n\n            scored_completions = sorted(scored_completions, reverse=True)\n            completions = [c for _, c in scored_completions]\n\n        return completions\n\n\nclass DeepSeekModel(dspy.OpenAI):\n    \"\"\"A wrapper class for DeepSeek API, compatible with dspy.OpenAI.\"\"\"\n\n    def __init__(\n        self,\n        model: str = \"deepseek-chat\",\n        api_key: Optional[str] = None,\n        api_base: str = \"https://api.deepseek.com\",\n        **kwargs,\n    ):\n        super().__init__(model=model, api_key=api_key, api_base=api_base, **kwargs)\n        self._token_usage_lock = threading.Lock()\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n        self.model = model\n        self.api_key = api_key or os.getenv(\"DEEPSEEK_API_KEY\")\n        self.api_base = api_base\n        if not self.api_key:\n            raise ValueError(\n                \"DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY\"\n            )\n\n    def log_usage(self, 
response):\n        \"\"\"Log the total tokens from the DeepSeek API response.\"\"\"\n        usage_data = response.get(\"usage\")\n        if usage_data:\n            with self._token_usage_lock:\n                self.prompt_tokens += usage_data.get(\"prompt_tokens\", 0)\n                self.completion_tokens += usage_data.get(\"completion_tokens\", 0)\n\n    def get_usage_and_reset(self):\n        \"\"\"Get the total tokens used and reset the token usage.\"\"\"\n        usage = {\n            self.model: {\n                \"prompt_tokens\": self.prompt_tokens,\n                \"completion_tokens\": self.completion_tokens,\n            }\n        }\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n        return usage\n\n    @backoff.on_exception(\n        backoff.expo,\n        ERRORS,\n        max_time=1000,\n        on_backoff=backoff_hdlr,\n        giveup=giveup_hdlr,\n    )\n    def _create_completion(self, prompt: str, **kwargs):\n        \"\"\"Create a completion using the DeepSeek API.\"\"\"\n        headers = {\n            \"Content-Type\": \"application/json\",\n            \"Authorization\": f\"Bearer {self.api_key}\",\n        }\n        data = {\n            \"model\": self.model,\n            \"messages\": [{\"role\": \"user\", \"content\": prompt}],\n            **kwargs,\n        }\n        response = requests.post(\n            f\"{self.api_base}/v1/chat/completions\", headers=headers, json=data\n        )\n        response.raise_for_status()\n        return response.json()\n\n    def __call__(\n        self,\n        prompt: str,\n        only_completed: bool = True,\n        return_sorted: bool = False,\n        **kwargs,\n    ) -> list[str]:\n        \"\"\"Call the DeepSeek API to generate completions.\"\"\"\n        assert only_completed, \"for now\"\n        assert return_sorted is False, \"for now\"\n\n        response = self._create_completion(prompt, **kwargs)\n\n        # Log the token usage from the 
DeepSeek API response.\n        self.log_usage(response)\n\n        choices = response[\"choices\"]\n        completions = [choice[\"message\"][\"content\"] for choice in choices]\n\n        history = {\n            \"prompt\": prompt,\n            \"response\": response,\n            \"kwargs\": kwargs,\n        }\n        self.history.append(history)\n\n        return completions\n\n\nclass AzureOpenAIModel(dspy.LM):\n    \"\"\"A wrapper class for the Azure OpenAI endpoint.\n\n    Note: the `model` parameter should match the deployment ID on your Azure platform.\n    \"\"\"\n\n    def __init__(\n        self,\n        azure_endpoint: str,\n        api_version: str,\n        model: str,\n        api_key: str,\n        model_type: Literal[\"chat\", \"text\"] = \"chat\",\n        **kwargs,\n    ):\n        super().__init__(model=model)\n        self._token_usage_lock = threading.Lock()\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n        self.model = model\n        self.provider = \"azure\"\n        self.model_type = model_type\n\n        self.client = AzureOpenAI(\n            azure_endpoint=azure_endpoint,\n            api_key=api_key,\n            api_version=api_version,\n        )\n\n        self.kwargs = {\n            \"model\": model,\n            \"temperature\": 0.0,\n            \"max_tokens\": 150,\n            \"top_p\": 1,\n            \"frequency_penalty\": 0,\n            \"presence_penalty\": 0,\n            \"n\": 1,\n            **kwargs,\n        }\n\n    @backoff.on_exception(\n        backoff.expo,\n        ERRORS,\n        max_time=1000,\n        on_backoff=backoff_hdlr,\n        giveup=giveup_hdlr,\n    )\n    def basic_request(self, prompt: str, **kwargs) -> Any:\n        kwargs = {**self.kwargs, **kwargs}\n\n        try:\n            if self.model_type == \"chat\":\n                messages = [{\"role\": \"user\", \"content\": prompt}]\n\n                response 
= self.client.chat.completions.create(\n                    messages=messages, **kwargs\n                )\n            else:\n                response = self.client.completions.create(prompt=prompt, **kwargs)\n\n            self.log_usage(response)\n\n            history_entry = {\n                \"prompt\": prompt,\n                \"response\": dict(response),\n                \"kwargs\": kwargs,\n            }\n            self.history.append(history_entry)\n\n            return response\n\n        except Exception as e:\n            logging.error(f\"Error making request to Azure OpenAI: {str(e)}\")\n            raise\n\n    def _get_choice_text(self, choice: Any) -> str:\n        \"\"\"Extract text from a choice object based on model type.\"\"\"\n        if self.model_type == \"chat\":\n            return choice.message.content\n        return choice.text\n\n    def log_usage(self, response):\n        \"\"\"Log the total tokens from response.\"\"\"\n        usage_data = response.usage\n        if usage_data:\n            with self._token_usage_lock:\n                self.prompt_tokens += usage_data.prompt_tokens\n                self.completion_tokens += usage_data.completion_tokens\n\n    def get_usage_and_reset(self):\n        \"\"\"Get the total tokens used and reset the token usage.\"\"\"\n        usage = {\n            self.model: {\n                \"prompt_tokens\": self.prompt_tokens,\n                \"completion_tokens\": self.completion_tokens,\n            }\n        }\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n        return usage\n\n    def __call__(\n        self,\n        prompt: str,\n        only_completed: bool = True,\n        return_sorted: bool = False,\n        **kwargs,\n    ) -> list[str]:\n        \"\"\"Get completions from Azure OpenAI.\n\n        Args:\n            prompt: The prompt to send to the model\n            only_completed: Only return completed responses\n            return_sorted: Sort 
completions by probability (not implemented)\n            **kwargs: Additional arguments to pass to the API\n\n        Returns:\n            List of completion strings\n        \"\"\"\n        response = self.basic_request(prompt, **kwargs)\n\n        choices = response.choices\n        completed_choices = [c for c in choices if c.finish_reason != \"length\"]\n\n        if only_completed and completed_choices:\n            choices = completed_choices\n\n        completions = [self._get_choice_text(c) for c in choices]\n\n        return completions\n\n\nclass GroqModel(dspy.OpenAI):\n    \"\"\"A wrapper class for Groq API (https://console.groq.com/), compatible with dspy.OpenAI.\"\"\"\n\n    def __init__(\n        self,\n        model: str = \"llama3-70b-8192\",\n        api_key: Optional[str] = None,\n        api_base: str = \"https://api.groq.com/openai/v1\",\n        **kwargs,\n    ):\n        super().__init__(model=model, api_key=api_key, api_base=api_base, **kwargs)\n        self._token_usage_lock = threading.Lock()\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n        self.model = model\n        self.api_key = api_key or os.getenv(\"GROQ_API_KEY\")\n        self.api_base = api_base\n        if not self.api_key:\n            raise ValueError(\n                \"Groq API key must be provided either as an argument or as an environment variable GROQ_API_KEY\"\n            )\n\n    def log_usage(self, response):\n        \"\"\"Log the total tokens from the Groq API response.\"\"\"\n        usage_data = response.get(\"usage\")\n        if usage_data:\n            with self._token_usage_lock:\n                self.prompt_tokens += usage_data.get(\"prompt_tokens\", 0)\n                self.completion_tokens += usage_data.get(\"completion_tokens\", 0)\n\n    def get_usage_and_reset(self):\n        \"\"\"Get the total tokens used and reset the token usage.\"\"\"\n        usage = {\n            self.model: {\n                \"prompt_tokens\": 
self.prompt_tokens,\n                \"completion_tokens\": self.completion_tokens,\n            }\n        }\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n        return usage\n\n    @backoff.on_exception(\n        backoff.expo,\n        ERRORS,\n        max_time=1000,\n        on_backoff=backoff_hdlr,\n        giveup=giveup_hdlr,\n    )\n    def _create_completion(self, prompt: str, **kwargs):\n        \"\"\"Create a completion using the Groq API.\"\"\"\n        headers = {\n            \"Content-Type\": \"application/json\",\n            \"Authorization\": f\"Bearer {self.api_key}\",\n        }\n\n        # Remove unsupported fields\n        kwargs.pop(\"logprobs\", None)\n        kwargs.pop(\"logit_bias\", None)\n        kwargs.pop(\"top_logprobs\", None)\n\n        # Ensure N is 1 if supplied\n        if \"n\" in kwargs and kwargs[\"n\"] != 1:\n            raise ValueError(\"Groq API only supports N=1\")\n\n        # Adjust temperature if it's 0\n        if kwargs.get(\"temperature\", 1) == 0:\n            kwargs[\"temperature\"] = 1e-8\n\n        data = {\n            \"model\": self.model,\n            \"messages\": [{\"role\": \"user\", \"content\": prompt}],\n            **kwargs,\n        }\n\n        # Remove 'name' field from messages if present\n        for message in data[\"messages\"]:\n            message.pop(\"name\", None)\n\n        response = requests.post(\n            f\"{self.api_base}/chat/completions\", headers=headers, json=data\n        )\n        response.raise_for_status()\n        return response.json()\n\n    def __call__(\n        self,\n        prompt: str,\n        only_completed: bool = True,\n        return_sorted: bool = False,\n        **kwargs,\n    ) -> list[dict[str, Any]]:\n        \"\"\"Call the Groq API to generate completions.\"\"\"\n        assert only_completed, \"for now\"\n        assert return_sorted is False, \"for now\"\n\n        response = self._create_completion(prompt, **kwargs)\n\n      
  # Log the token usage from the Groq API response.\n        self.log_usage(response)\n\n        choices = response[\"choices\"]\n        completions = [choice[\"message\"][\"content\"] for choice in choices]\n\n        history = {\n            \"prompt\": prompt,\n            \"response\": response,\n            \"kwargs\": kwargs,\n        }\n        self.history.append(history)\n\n        return completions\n\n\nclass ClaudeModel(dspy.dsp.modules.lm.LM):\n    \"\"\"Copied from dspy/dsp/modules/anthropic.py with the addition of tracking token usage.\"\"\"\n\n    def __init__(\n        self,\n        model: str,\n        api_key: Optional[str] = None,\n        api_base: Optional[str] = None,\n        **kwargs,\n    ):\n        super().__init__(model)\n        try:\n            from anthropic import Anthropic\n        except ImportError as err:\n            raise ImportError(\"Claude requires `pip install anthropic`.\") from err\n\n        self.provider = \"anthropic\"\n        self.api_key = api_key = (\n            os.environ.get(\"ANTHROPIC_API_KEY\") if api_key is None else api_key\n        )\n        self.api_base = (\n            \"https://api.anthropic.com/v1/messages\" if api_base is None else api_base\n        )\n        self.kwargs = {\n            \"temperature\": kwargs.get(\"temperature\", 0.0),\n            \"max_tokens\": min(kwargs.get(\"max_tokens\", 4096), 4096),\n            \"top_p\": kwargs.get(\"top_p\", 1.0),\n            \"top_k\": kwargs.get(\"top_k\", 1),\n            \"n\": kwargs.pop(\"n\", kwargs.pop(\"num_generations\", 1)),\n            **kwargs,\n            \"model\": model,\n        }\n        self.history: list[dict[str, Any]] = []\n        self.client = Anthropic(api_key=api_key)\n        self.model = model\n\n        self._token_usage_lock = threading.Lock()\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n    def log_usage(self, response):\n        \"\"\"Log the total tokens from the Anthropic API 
response.\"\"\"\n        usage_data = response.usage\n        if usage_data:\n            with self._token_usage_lock:\n                self.prompt_tokens += usage_data.input_tokens\n                self.completion_tokens += usage_data.output_tokens\n\n    def get_usage_and_reset(self):\n        \"\"\"Get the total tokens used and reset the token usage.\"\"\"\n        usage = {\n            self.model: {\n                \"prompt_tokens\": self.prompt_tokens,\n                \"completion_tokens\": self.completion_tokens,\n            }\n        }\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n        return usage\n\n    def basic_request(self, prompt: str, **kwargs):\n        raw_kwargs = kwargs\n        kwargs = {**self.kwargs, **kwargs}\n        # caching mechanism requires hashable kwargs\n        kwargs[\"messages\"] = [{\"role\": \"user\", \"content\": prompt}]\n        kwargs.pop(\"n\")\n        response = self.client.messages.create(**kwargs)\n        # history = {\n        #     \"prompt\": prompt,\n        #     \"response\": response,\n        #     \"kwargs\": kwargs,\n        #     \"raw_kwargs\": raw_kwargs,\n        # }\n        json_serializable_history = {\n            \"prompt\": prompt,\n            \"response\": {\n                \"content\": response.content[0].text,\n                \"model\": response.model,\n                \"role\": response.role,\n                \"stop_reason\": response.stop_reason,\n                \"stop_sequence\": response.stop_sequence,\n                \"type\": response.type,\n                \"usage\": {\n                    \"input_tokens\": response.usage.input_tokens,\n                    \"output_tokens\": response.usage.output_tokens,\n                },\n            },\n            \"kwargs\": kwargs,\n            \"raw_kwargs\": raw_kwargs,\n        }\n        self.history.append(json_serializable_history)\n        return response\n\n    @backoff.on_exception(\n        
backoff.expo,\n        (RateLimitError,),\n        max_time=1000,\n        max_tries=8,\n        on_backoff=backoff_hdlr,\n        giveup=giveup_hdlr,\n    )\n    def request(self, prompt: str, **kwargs):\n        \"\"\"Handles retrieval of completions from Anthropic whilst handling API errors.\"\"\"\n        return self.basic_request(prompt, **kwargs)\n\n    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):\n        \"\"\"Retrieves completions from Anthropic.\n\n        Args:\n            prompt (str): prompt to send to Anthropic\n            only_completed (bool, optional): return only completed responses, ignoring completions truncated due to length. Defaults to True.\n            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.\n\n        Returns:\n            list[str]: list of completion choices\n        \"\"\"\n        assert only_completed, \"for now\"\n        assert return_sorted is False, \"for now\"\n        # per eg here: https://docs.anthropic.com/claude/reference/messages-examples\n        # max tokens can be used as a proxy to return smaller responses\n        # so this cannot be a proper indicator for incomplete response unless it is not the user-intent.\n        n = kwargs.pop(\"n\", 1)\n        completions = []\n        for _ in range(n):\n            response = self.request(prompt, **kwargs)\n            self.log_usage(response)\n            # This is the original behavior in dspy/dsp/modules/anthropic.py.\n            # It is commented out because it can cause \"IndexError: list index out of range\" silently\n            # which is not transparent to developers.\n            # if only_completed and response.stop_reason == \"max_tokens\":\n            #     continue\n            # Accumulate across all n requests instead of keeping only the last response.\n            completions.extend(c.text for c in response.content)\n        return completions\n\n\nclass VLLMClient(dspy.dsp.LM):\n    \"\"\"A client compatible with vLLM HTTP server.\n\n    vLLM HTTP server is 
designed to be compatible with the OpenAI API. Use OpenAI client to interact with the server.\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        port,\n        model_type: Literal[\"chat\", \"text\"] = \"text\",\n        url=\"http://localhost\",\n        api_key=\"null\",\n        **kwargs,\n    ):\n        \"\"\"Check out https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html for more information.\"\"\"\n        super().__init__(model=model)\n        # Store additional kwargs for the generate method.\n        self.kwargs = {**self.kwargs, **kwargs}\n        self.model = model\n        self.base_url = f\"{url}:{port}/v1/\"\n        if model_type == \"chat\":\n            self.base_url += \"chat/\"\n        self.client = OpenAI(base_url=self.base_url, api_key=api_key)\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n        self._token_usage_lock = threading.Lock()\n\n    def basic_request(self, prompt, **kwargs):\n        completion = self.client.chat.completions.create(\n            **kwargs,\n            messages=[{\"role\": \"user\", \"content\": prompt}],\n        )\n        return completion\n\n    @backoff.on_exception(\n        backoff.expo,\n        ERRORS,\n        max_time=1000,\n        on_backoff=backoff_hdlr,\n    )\n    def request(self, prompt: str, **kwargs):\n        return self.basic_request(prompt, **kwargs)\n\n    def log_usage(self, response):\n        \"\"\"Log the total tokens from the response.\"\"\"\n        usage_data = response.usage\n        if usage_data:\n            with self._token_usage_lock:\n                self.prompt_tokens += usage_data.prompt_tokens\n                self.completion_tokens += usage_data.completion_tokens\n\n    def get_usage_and_reset(self):\n        \"\"\"Get the total tokens used and reset the token usage.\"\"\"\n        usage = {\n            self.kwargs.get(\"model\")\n            or self.kwargs.get(\"engine\"): {\n                \"prompt_tokens\": 
self.prompt_tokens,\n                \"completion_tokens\": self.completion_tokens,\n            }\n        }\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n        return usage\n\n    def __call__(self, prompt: str, **kwargs):\n        kwargs = {**self.kwargs, **kwargs}\n\n        try:\n            response = self.request(prompt, **kwargs)\n        except Exception as e:\n            print(f\"Failed to generate completion: {e}\")\n            # Re-raise the original exception to keep its type and traceback.\n            raise\n\n        self.log_usage(response)\n\n        choices = response.choices\n        completions = [choice.message.content for choice in choices]\n\n        history = {\n            \"prompt\": prompt,\n            \"response\": response,\n            \"kwargs\": kwargs,\n        }\n        self.history.append(history)\n\n        return completions\n\n\nclass OllamaClient(dspy.OllamaLocal):\n    \"\"\"A wrapper class for dspy.OllamaLocal.\"\"\"\n\n    def __init__(self, model, port, url=\"http://localhost\", **kwargs):\n        \"\"\"Copied from dspy/dsp/modules/hf_client.py with the addition of storing additional kwargs.\"\"\"\n        # Check if the URL has 'http://' or 'https://'\n        if not url.startswith(\"http://\") and not url.startswith(\"https://\"):\n            url = \"http://\" + url\n        super().__init__(model=model, base_url=f\"{url}:{port}\", **kwargs)\n        # Store additional kwargs for the generate method.\n        self.kwargs = {**self.kwargs, **kwargs}\n\n\nclass TGIClient(dspy.HFClientTGI):\n    def __init__(self, model, port, url, http_request_kwargs=None, **kwargs):\n        super().__init__(\n            model=model,\n            port=port,\n            url=url,\n            http_request_kwargs=http_request_kwargs,\n            **kwargs,\n        )\n\n    def _generate(self, prompt, **kwargs):\n        \"\"\"Copied from dspy/dsp/modules/hf_client.py with the hard-coded parameters removed.\"\"\"\n        kwargs = {**self.kwargs, 
**kwargs}\n\n        payload = {\n            \"inputs\": prompt,\n            \"parameters\": {\n                \"do_sample\": kwargs[\"n\"] > 1,\n                \"best_of\": kwargs[\"n\"],\n                \"details\": kwargs[\"n\"] > 1,\n                **kwargs,\n            },\n        }\n\n        payload[\"parameters\"] = openai_to_hf(**payload[\"parameters\"])\n\n        # Comment out the following lines to remove the hard-coded parameters.\n        # payload[\"parameters\"][\"temperature\"] = max(\n        #     0.1, payload[\"parameters\"][\"temperature\"],\n        # )\n\n        response = send_hftgi_request_v01_wrapped(\n            f\"{self.url}:{random.Random().choice(self.ports)}\" + \"/generate\",\n            url=self.url,\n            ports=tuple(self.ports),\n            json=payload,\n            headers=self.headers,\n            **self.http_request_kwargs,\n        )\n\n        try:\n            json_response = response.json()\n            # completions = json_response[\"generated_text\"]\n\n            completions = [json_response[\"generated_text\"]]\n\n            if (\n                \"details\" in json_response\n                and \"best_of_sequences\" in json_response[\"details\"]\n            ):\n                completions += [\n                    x[\"generated_text\"]\n                    for x in json_response[\"details\"][\"best_of_sequences\"]\n                ]\n\n            response = {\"prompt\": prompt, \"choices\": [{\"text\": c} for c in completions]}\n            return response\n        except Exception:\n            print(\"Failed to parse JSON response:\", response.text)\n            raise Exception(\"Received invalid JSON response from server\")\n\n\nclass TogetherClient(dspy.HFModel):\n    \"\"\"A wrapper class for dspy.Together.\"\"\"\n\n    def __init__(\n        self,\n        model,\n        api_key: Optional[str] = None,\n        apply_tokenizer_chat_template=False,\n        hf_tokenizer_name=None,\n        
model_type: Literal[\"chat\", \"text\"] = \"chat\",\n        **kwargs,\n    ):\n        \"\"\"Copied from dspy/dsp/modules/hf_client.py with the support of applying tokenizer chat template.\"\"\"\n\n        super().__init__(model=model, is_client=True)\n        self.session = requests.Session()\n        self.api_key = api_key = (\n            os.environ.get(\"TOGETHER_API_KEY\") if api_key is None else api_key\n        )\n        self.model = model\n        self.model_type = model_type\n        if os.getenv(\"TOGETHER_API_BASE\") is None:\n            if self.model_type == \"chat\":\n                self.api_base = \"https://api.together.xyz/v1/chat/completions\"\n            else:\n                self.api_base = \"https://api.together.xyz/v1/completions\"\n        else:\n            self.api_base = os.getenv(\"TOGETHER_API_BASE\")\n\n        # self.use_inst_template = False\n        # if any(keyword in self.model.lower() for keyword in [\"inst\", \"instruct\"]):\n        #     self.use_inst_template = True\n        self.apply_tokenizer_chat_template = apply_tokenizer_chat_template\n        if self.apply_tokenizer_chat_template:\n            logging.info(\"Loading huggingface tokenizer.\")\n            if hf_tokenizer_name is None:\n                hf_tokenizer_name = self.model\n            self.tokenizer = AutoTokenizer.from_pretrained(\n                hf_tokenizer_name, cache_dir=kwargs.get(\"cache_dir\", None)\n            )\n\n        stop_default = \"\\n\\n---\"\n\n        self.kwargs = {\n            \"temperature\": kwargs.get(\"temperature\", 0.0),\n            \"max_tokens\": min(kwargs.get(\"max_tokens\", 4096), 4096),\n            \"top_p\": kwargs.get(\"top_p\", 1.0),\n            \"top_k\": kwargs.get(\"top_k\", 1),\n            \"repetition_penalty\": 1,\n            \"n\": kwargs.pop(\"n\", kwargs.pop(\"num_generations\", 1)),\n            \"stop\": stop_default if \"stop\" not in kwargs else kwargs[\"stop\"],\n            **kwargs,\n        }\n   
     self._token_usage_lock = threading.Lock()\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n    def log_usage(self, response):\n        \"\"\"Log the total tokens from the OpenAI API response.\"\"\"\n        usage_data = response.get(\"usage\")\n        if usage_data:\n            with self._token_usage_lock:\n                self.prompt_tokens += usage_data.get(\"prompt_tokens\", 0)\n                self.completion_tokens += usage_data.get(\"completion_tokens\", 0)\n\n    def get_usage_and_reset(self):\n        \"\"\"Get the total tokens used and reset the token usage.\"\"\"\n        usage = {\n            self.model: {\n                \"prompt_tokens\": self.prompt_tokens,\n                \"completion_tokens\": self.completion_tokens,\n            }\n        }\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n        return usage\n\n    @backoff.on_exception(\n        backoff.expo,\n        ERRORS,\n        max_time=1000,\n        on_backoff=backoff_hdlr,\n    )\n    def _generate(self, prompt, **kwargs):\n        kwargs = {**self.kwargs, **kwargs}\n\n        stop = kwargs.get(\"stop\")\n        temperature = kwargs.get(\"temperature\")\n        max_tokens = kwargs.get(\"max_tokens\", 150)\n        top_p = kwargs.get(\"top_p\", 0.7)\n        top_k = kwargs.get(\"top_k\", 50)\n        repetition_penalty = kwargs.get(\"repetition_penalty\", 1)\n        if self.apply_tokenizer_chat_template:\n            prompt = self.tokenizer.apply_chat_template(\n                [{\"role\": \"user\", \"content\": prompt}], tokenize=False\n            )\n        # prompt = f\"[INST]{prompt}[/INST]\" if self.use_inst_template else prompt\n\n        if self.model_type == \"chat\":\n            messages = [\n                {\n                    \"role\": \"system\",\n                    \"content\": \"You are a helpful assistant. 
You must continue the user text directly without *any* additional interjections.\",\n                },\n                {\"role\": \"user\", \"content\": prompt},\n            ]\n            body = {\n                \"model\": self.model,\n                \"messages\": messages,\n                \"temperature\": temperature,\n                \"max_tokens\": max_tokens,\n                \"top_p\": top_p,\n                \"top_k\": top_k,\n                \"repetition_penalty\": repetition_penalty,\n                \"stop\": stop,\n            }\n        else:\n            body = {\n                \"model\": self.model,\n                \"prompt\": prompt,\n                \"temperature\": temperature,\n                \"max_tokens\": max_tokens,\n                \"top_p\": top_p,\n                \"top_k\": top_k,\n                \"repetition_penalty\": repetition_penalty,\n                \"stop\": stop,\n            }\n\n        headers = {\"Authorization\": f\"Bearer {self.api_key}\"}\n\n        with self.session.post(self.api_base, headers=headers, json=body) as resp:\n            resp_json = resp.json()\n            # Log the token usage from the Together API response.\n            self.log_usage(resp_json)\n            if self.model_type == \"chat\":\n                # completions = [resp_json['output'].get('choices', [])[0].get('message', {}).get('content', \"\")]\n                completions = [\n                    resp_json.get(\"choices\", [])[0]\n                    .get(\"message\", {})\n                    .get(\"content\", \"\")\n                ]\n            else:\n                # completions = [resp_json['output'].get('choices', [])[0].get('text', \"\")]\n                completions = [resp_json.get(\"choices\", [])[0].get(\"text\", \"\")]\n            response = {\"prompt\": prompt, \"choices\": [{\"text\": c} for c in completions]}\n            return response\n\n\nclass GoogleModel(dspy.dsp.modules.lm.LM):\n    \"\"\"A wrapper class for 
Google Gemini API.\"\"\"\n\n    def __init__(\n        self,\n        model: str,\n        api_key: Optional[str] = None,\n        **kwargs,\n    ):\n        \"\"\"You can use `genai.list_models()` to get a list of available models.\"\"\"\n        super().__init__(model)\n        try:\n            import google.generativeai as genai\n        except ImportError as err:\n            raise ImportError(\n                \"GoogleModel requires `pip install google-generativeai`.\"\n            ) from err\n\n        api_key = os.environ.get(\"GOOGLE_API_KEY\") if api_key is None else api_key\n        genai.configure(api_key=api_key)\n\n        kwargs = {\n            \"candidate_count\": 1,  # Caveat: Gemini API supports only one candidate for now.\n            \"temperature\": (\n                0.0 if \"temperature\" not in kwargs else kwargs[\"temperature\"]\n            ),\n            \"max_output_tokens\": kwargs[\"max_tokens\"],\n            \"top_p\": 1,\n            \"top_k\": 1,\n            **kwargs,\n        }\n\n        kwargs.pop(\"max_tokens\", None)  # GenerationConfig cannot accept max_tokens\n\n        self.model = model\n        self.config = genai.GenerationConfig(**kwargs)\n        self.llm = genai.GenerativeModel(\n            model_name=model, generation_config=self.config\n        )\n\n        self.kwargs = {\n            \"n\": 1,\n            **kwargs,\n        }\n\n        self.history: list[dict[str, Any]] = []\n\n        self._token_usage_lock = threading.Lock()\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n    def log_usage(self, response):\n        \"\"\"Log the total tokens from the Google API response.\"\"\"\n        usage_data = response.usage_metadata\n        if usage_data:\n            with self._token_usage_lock:\n                self.prompt_tokens += usage_data.prompt_token_count\n                self.completion_tokens += usage_data.candidates_token_count\n\n    def get_usage_and_reset(self):\n        
\"\"\"Get the total tokens used and reset the token usage.\"\"\"\n        usage = {\n            self.model: {\n                \"prompt_tokens\": self.prompt_tokens,\n                \"completion_tokens\": self.completion_tokens,\n            }\n        }\n        self.prompt_tokens = 0\n        self.completion_tokens = 0\n\n        return usage\n\n    def basic_request(self, prompt: str, **kwargs):\n        raw_kwargs = kwargs\n        kwargs = {\n            **self.kwargs,\n            **kwargs,\n        }\n\n        # Google disallows \"n\" arguments.\n        n = kwargs.pop(\"n\", None)\n\n        response = self.llm.generate_content(prompt, generation_config=kwargs)\n\n        history = {\n            \"prompt\": prompt,\n            \"response\": [response.to_dict()],\n            \"kwargs\": kwargs,\n            \"raw_kwargs\": raw_kwargs,\n        }\n        self.history.append(history)\n\n        return response\n\n    @backoff.on_exception(\n        backoff.expo,\n        (Exception,),\n        max_time=1000,\n        max_tries=8,\n        on_backoff=backoff_hdlr,\n        giveup=giveup_hdlr,\n    )\n    def request(self, prompt: str, **kwargs):\n        \"\"\"Handles retrieval of completions from Google whilst handling API errors\"\"\"\n        return self.basic_request(prompt, **kwargs)\n\n    def __call__(\n        self,\n        prompt: str,\n        only_completed: bool = True,\n        return_sorted: bool = False,\n        **kwargs,\n    ):\n        assert only_completed, \"for now\"\n        assert return_sorted is False, \"for now\"\n\n        n = kwargs.pop(\"n\", 1)\n\n        completions = []\n        for _ in range(n):\n            response = self.request(prompt, **kwargs)\n            self.log_usage(response)\n            completions.append(response.parts[0].text)\n\n        return completions\n\n\n# ========================================================================\n"
  },
  {
    "path": "knowledge_storm/logging_wrapper.py",
    "content": "from contextlib import contextmanager\nimport time\nimport pytz\nfrom datetime import datetime\n\n# Define California timezone\nCALIFORNIA_TZ = pytz.timezone(\"America/Los_Angeles\")\n\n\nclass EventLog:\n    def __init__(self, event_name):\n        self.event_name = event_name\n        self.start_time = None\n        self.end_time = None\n        self.child_events = {}\n\n    def record_start_time(self):\n        self.start_time = datetime.now(\n            pytz.utc\n        )  # Store in UTC for consistent timezone conversion\n\n    def record_end_time(self):\n        self.end_time = datetime.now(\n            pytz.utc\n        )  # Store in UTC for consistent timezone conversion\n\n    def get_total_time(self):\n        if self.start_time and self.end_time:\n            return (self.end_time - self.start_time).total_seconds()\n        return 0\n\n    def get_start_time(self):\n        if self.start_time:\n            # Format to milliseconds\n            return self.start_time.astimezone(CALIFORNIA_TZ).strftime(\n                \"%Y-%m-%d %H:%M:%S.%f\"\n            )[:-3]\n        return None\n\n    def get_end_time(self):\n        if self.end_time:\n            # Format to milliseconds\n            return self.end_time.astimezone(CALIFORNIA_TZ).strftime(\n                \"%Y-%m-%d %H:%M:%S.%f\"\n            )[:-3]\n        return None\n\n    def add_child_event(self, child_event):\n        self.child_events[child_event.event_name] = child_event\n\n    def get_child_events(self):\n        return self.child_events\n\n\nclass LoggingWrapper:\n    def __init__(self, lm_config):\n        self.logging_dict = {}\n        self.lm_config = lm_config\n        self.current_pipeline_stage = None\n        self.event_stack = []\n        self.pipeline_stage_active = False\n\n    def _pipeline_stage_start(self, pipeline_stage: str):\n        if self.pipeline_stage_active:\n            raise RuntimeError(\n                \"A pipeline stage is already active. 
End the current stage before starting a new one.\"\n            )\n\n        self.current_pipeline_stage = pipeline_stage\n        self.logging_dict[pipeline_stage] = {\n            \"time_usage\": {},\n            \"lm_usage\": {},\n            \"lm_history\": [],\n            \"query_count\": 0,\n        }\n        self.pipeline_stage_active = True\n\n    def _event_start(self, event_name: str):\n        if not self.pipeline_stage_active:\n            raise RuntimeError(\"No pipeline stage is currently active.\")\n\n        if not self.event_stack and self.current_pipeline_stage:\n            # Top-level event (directly under the pipeline stage)\n            if (\n                event_name\n                not in self.logging_dict[self.current_pipeline_stage][\"time_usage\"]\n            ):\n                event = EventLog(event_name=event_name)\n                event.record_start_time()\n                self.logging_dict[self.current_pipeline_stage][\"time_usage\"][\n                    event_name\n                ] = event\n                self.event_stack.append(event)\n            else:\n                self.logging_dict[self.current_pipeline_stage][\"time_usage\"][\n                    event_name\n                ].record_start_time()\n        elif self.event_stack:\n            # Nested event (under another event)\n            parent_event = self.event_stack[-1]\n            if event_name not in parent_event.get_child_events():\n                event = EventLog(event_name=event_name)\n                event.record_start_time()\n                parent_event.add_child_event(event)\n                self.logging_dict[self.current_pipeline_stage][\"time_usage\"][\n                    event_name\n                ] = event\n                self.event_stack.append(event)\n            else:\n                parent_event.get_child_events()[event_name].record_start_time()\n        else:\n            raise RuntimeError(\n                \"Cannot start an event without 
an active pipeline stage or parent event.\"\n            )\n\n    def _event_end(self, event_name: str):\n        if not self.pipeline_stage_active:\n            raise RuntimeError(\"No pipeline stage is currently active.\")\n\n        if not self.event_stack:\n            raise RuntimeError(\"No parent event is currently active.\")\n\n        if self.event_stack:\n            current_event_log = self.event_stack[-1]\n            if event_name in current_event_log.get_child_events():\n                current_event_log.get_child_events()[event_name].record_end_time()\n            elif (\n                event_name\n                in self.logging_dict[self.current_pipeline_stage][\"time_usage\"]\n            ):\n                self.logging_dict[self.current_pipeline_stage][\"time_usage\"][\n                    event_name\n                ].record_end_time()\n            else:\n                raise AssertionError(\n                    f\"Failure to record end time for event {event_name}. 
Start time is not recorded.\"\n                )\n            if current_event_log.event_name == event_name:\n                self.event_stack.pop()\n        else:\n            raise RuntimeError(\"Cannot end an event without an active parent event.\")\n\n    def _pipeline_stage_end(self):\n        if not self.pipeline_stage_active:\n            raise RuntimeError(\"No pipeline stage is currently active to end.\")\n\n        self.logging_dict[self.current_pipeline_stage][\n            \"lm_usage\"\n        ] = self.lm_config.collect_and_reset_lm_usage()\n        self.logging_dict[self.current_pipeline_stage][\n            \"lm_history\"\n        ] = self.lm_config.collect_and_reset_lm_history()\n        self.pipeline_stage_active = False\n\n    def add_query_count(self, count):\n        if not self.pipeline_stage_active:\n            raise RuntimeError(\n                \"No pipeline stage is currently active to add query count.\"\n            )\n\n        self.logging_dict[self.current_pipeline_stage][\"query_count\"] += count\n\n    @contextmanager\n    def log_event(self, event_name):\n        if not self.pipeline_stage_active:\n            raise RuntimeError(\"No pipeline stage is currently active.\")\n\n        self._event_start(event_name)\n        yield\n        self._event_end(event_name)\n\n    @contextmanager\n    def log_pipeline_stage(self, pipeline_stage):\n        if self.pipeline_stage_active:\n            print(\n                \"A pipeline stage is already active, ending the current stage safely.\"\n            )\n            self._pipeline_stage_end()\n\n        start_time = time.time()\n        try:\n            self._pipeline_stage_start(pipeline_stage)\n            yield\n        except Exception as e:\n            print(f\"Error occurred during pipeline stage '{pipeline_stage}': {e}\")\n        finally:\n            self.logging_dict[self.current_pipeline_stage][\"total_wall_time\"] = (\n                time.time() - start_time\n            
)\n            self._pipeline_stage_end()\n\n    def dump_logging_and_reset(self, reset_logging=True):\n        log_dump = {}\n        for pipeline_stage, pipeline_log in self.logging_dict.items():\n            time_stamp_log = {\n                event_name: {\n                    \"total_time_seconds\": event.get_total_time(),\n                    \"start_time\": event.get_start_time(),\n                    \"end_time\": event.get_end_time(),\n                }\n                for event_name, event in pipeline_log[\"time_usage\"].items()\n            }\n            log_dump[pipeline_stage] = {\n                \"time_usage\": time_stamp_log,\n                \"lm_usage\": pipeline_log[\"lm_usage\"],\n                \"lm_history\": pipeline_log[\"lm_history\"],\n                \"query_count\": pipeline_log[\"query_count\"],\n                \"total_wall_time\": pipeline_log[\"total_wall_time\"],\n            }\n        if reset_logging:\n            self.logging_dict.clear()\n        return log_dump\n"
  },
  {
    "path": "knowledge_storm/rm.py",
    "content": "import logging\nimport os\nfrom typing import Callable, Union, List\n\nimport backoff\nimport dspy\nimport requests\nfrom dsp import backoff_hdlr, giveup_hdlr\n\nfrom .utils import WebPageHelper\n\n\nclass YouRM(dspy.Retrieve):\n    def __init__(self, ydc_api_key=None, k=3, is_valid_source: Callable = None):\n        super().__init__(k=k)\n        if not ydc_api_key and not os.environ.get(\"YDC_API_KEY\"):\n            raise RuntimeError(\n                \"You must supply ydc_api_key or set environment variable YDC_API_KEY\"\n            )\n        elif ydc_api_key:\n            self.ydc_api_key = ydc_api_key\n        else:\n            self.ydc_api_key = os.environ[\"YDC_API_KEY\"]\n        self.usage = 0\n\n        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.\n        if is_valid_source:\n            self.is_valid_source = is_valid_source\n        else:\n            self.is_valid_source = lambda x: True\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n\n        return {\"YouRM\": usage}\n\n    def forward(\n        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []\n    ):\n        \"\"\"Search with You.com for self.k top passages for query or queries\n\n        Args:\n            query_or_queries (Union[str, List[str]]): The query or queries to search for.\n            exclude_urls (List[str]): A list of urls to exclude from the search results.\n\n        Returns:\n            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'\n        \"\"\"\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n        self.usage += len(queries)\n        collected_results = []\n        for query in queries:\n            try:\n                headers = {\"X-API-Key\": self.ydc_api_key}\n                results 
= requests.get(\n                    f\"https://api.ydc-index.io/search?query={query}\",\n                    headers=headers,\n                ).json()\n\n                authoritative_results = []\n                for r in results[\"hits\"]:\n                    if self.is_valid_source(r[\"url\"]) and r[\"url\"] not in exclude_urls:\n                        authoritative_results.append(r)\n                if \"hits\" in results:\n                    collected_results.extend(authoritative_results[: self.k])\n            except Exception as e:\n                logging.error(f\"Error occurs when searching query {query}: {e}\")\n\n        return collected_results\n\n\nclass BingSearch(dspy.Retrieve):\n    def __init__(\n        self,\n        bing_search_api_key=None,\n        k=3,\n        is_valid_source: Callable = None,\n        min_char_count: int = 150,\n        snippet_chunk_size: int = 1000,\n        webpage_helper_max_threads=10,\n        mkt=\"en-US\",\n        language=\"en\",\n        **kwargs,\n    ):\n        \"\"\"\n        Params:\n            min_char_count: Minimum character count for the article to be considered valid.\n            snippet_chunk_size: Maximum character count for each snippet.\n            webpage_helper_max_threads: Maximum number of threads to use for webpage helper.\n            mkt, language, **kwargs: Bing search API parameters.\n            - Reference: https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/reference/query-parameters\n        \"\"\"\n        super().__init__(k=k)\n        if not bing_search_api_key and not os.environ.get(\"BING_SEARCH_API_KEY\"):\n            raise RuntimeError(\n                \"You must supply bing_search_api_key or set environment variable BING_SEARCH_API_KEY\"\n            )\n        elif bing_search_api_key:\n            self.bing_api_key = bing_search_api_key\n        else:\n            self.bing_api_key = os.environ[\"BING_SEARCH_API_KEY\"]\n        self.endpoint = 
\"https://api.bing.microsoft.com/v7.0/search\"\n        self.params = {\"mkt\": mkt, \"setLang\": language, \"count\": k, **kwargs}\n        self.webpage_helper = WebPageHelper(\n            min_char_count=min_char_count,\n            snippet_chunk_size=snippet_chunk_size,\n            max_thread_num=webpage_helper_max_threads,\n        )\n        self.usage = 0\n\n        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.\n        if is_valid_source:\n            self.is_valid_source = is_valid_source\n        else:\n            self.is_valid_source = lambda x: True\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n\n        return {\"BingSearch\": usage}\n\n    def forward(\n        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []\n    ):\n        \"\"\"Search with Bing for self.k top passages for query or queries\n\n        Args:\n            query_or_queries (Union[str, List[str]]): The query or queries to search for.\n            exclude_urls (List[str]): A list of urls to exclude from the search results.\n\n        Returns:\n            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'\n        \"\"\"\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n        self.usage += len(queries)\n\n        url_to_results = {}\n\n        headers = {\"Ocp-Apim-Subscription-Key\": self.bing_api_key}\n\n        for query in queries:\n            try:\n                results = requests.get(\n                    self.endpoint, headers=headers, params={**self.params, \"q\": query}\n                ).json()\n\n                for d in results[\"webPages\"][\"value\"]:\n                    if self.is_valid_source(d[\"url\"]) and d[\"url\"] not in exclude_urls:\n                        url_to_results[d[\"url\"]] = {\n          
                  \"url\": d[\"url\"],\n                            \"title\": d[\"name\"],\n                            \"description\": d[\"snippet\"],\n                        }\n            except Exception as e:\n                logging.error(f\"Error occurs when searching query {query}: {e}\")\n\n        valid_url_to_snippets = self.webpage_helper.urls_to_snippets(\n            list(url_to_results.keys())\n        )\n        collected_results = []\n        for url in valid_url_to_snippets:\n            r = url_to_results[url]\n            r[\"snippets\"] = valid_url_to_snippets[url][\"snippets\"]\n            collected_results.append(r)\n\n        return collected_results\n\n\nclass VectorRM(dspy.Retrieve):\n    \"\"\"Retrieve information from custom documents using Qdrant.\n\n    To be compatible with STORM, the custom documents should have the following fields:\n        - content: The main text content of the document.\n        - title: The title of the document.\n        - url: The URL of the document. 
STORM uses the url as the unique identifier of the document, so ensure different\n            documents have different urls.\n        - description (optional): The description of the document.\n    The documents should be stored in a CSV file.\n    \"\"\"\n\n    def __init__(\n        self,\n        collection_name: str,\n        embedding_model: str,\n        device: str = \"mps\",\n        k: int = 3,\n    ):\n        \"\"\"\n        Params:\n            collection_name: Name of the Qdrant collection.\n            embedding_model: Name of the Hugging Face embedding model.\n            device: Device to run the embeddings model on, can be \"mps\", \"cuda\", \"cpu\".\n            k: Number of top chunks to retrieve.\n        \"\"\"\n        from langchain_huggingface import HuggingFaceEmbeddings\n\n        super().__init__(k=k)\n        self.usage = 0\n        # check if the collection is provided\n        if not collection_name:\n            raise ValueError(\"Please provide a collection name.\")\n        # check if the embedding model is provided\n        if not embedding_model:\n            raise ValueError(\"Please provide an embedding model.\")\n\n        model_kwargs = {\"device\": device}\n        encode_kwargs = {\"normalize_embeddings\": True}\n        self.model = HuggingFaceEmbeddings(\n            model_name=embedding_model,\n            model_kwargs=model_kwargs,\n            encode_kwargs=encode_kwargs,\n        )\n\n        self.collection_name = collection_name\n        self.client = None\n        self.qdrant = None\n\n    def _check_collection(self):\n        \"\"\"\n        Check that the Qdrant collection exists and load it; raise a ValueError if it does not.\n        \"\"\"\n        from langchain_qdrant import Qdrant\n\n        if self.client is None:\n            raise ValueError(\"Qdrant client is not initialized.\")\n        if self.client.collection_exists(collection_name=f\"{self.collection_name}\"):\n            print(\n                f\"Collection 
{self.collection_name} exists. Loading the collection...\"\n            )\n            self.qdrant = Qdrant(\n                client=self.client,\n                collection_name=self.collection_name,\n                embeddings=self.model,\n            )\n        else:\n            raise ValueError(\n                f\"Collection {self.collection_name} does not exist. Please create the collection first.\"\n            )\n\n    def init_online_vector_db(self, url: str, api_key: str):\n        \"\"\"\n        Initialize the Qdrant client that is connected to an online vector store with the given URL and API key.\n\n        Args:\n            url (str): URL of the Qdrant server.\n            api_key (str): API key for the Qdrant server. Falls back to the QDRANT_API_KEY environment variable if None.\n        \"\"\"\n        from qdrant_client import QdrantClient\n\n        if api_key is None:\n            if not os.getenv(\"QDRANT_API_KEY\"):\n                raise ValueError(\"Please provide an api key.\")\n            api_key = os.getenv(\"QDRANT_API_KEY\")\n        if url is None:\n            raise ValueError(\"Please provide a url for the Qdrant server.\")\n\n        try:\n            self.client = QdrantClient(url=url, api_key=api_key)\n            self._check_collection()\n        except Exception as e:\n            raise ValueError(f\"Error occurs when connecting to the server: {e}\")\n\n    def init_offline_vector_db(self, vector_store_path: str):\n        \"\"\"\n        Initialize the Qdrant client that is connected to an offline vector store with the given vector store folder path.\n\n        Args:\n            vector_store_path (str): Path to the vector store.\n        \"\"\"\n        from qdrant_client import QdrantClient\n\n        if vector_store_path is None:\n            raise ValueError(\"Please provide a folder path.\")\n\n        try:\n            self.client = QdrantClient(path=vector_store_path)\n            self._check_collection()\n        except Exception as e:\n            raise ValueError(f\"Error 
occurs when loading the vector store: {e}\")\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n\n        return {\"VectorRM\": usage}\n\n    def get_vector_count(self):\n        \"\"\"\n        Get the count of vectors in the collection.\n\n        Returns:\n            int: Number of vectors in the collection.\n        \"\"\"\n        # QdrantClient.count returns a CountResult; its `count` field holds the integer.\n        return self.qdrant.client.count(collection_name=self.collection_name).count\n\n    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):\n        \"\"\"\n        Search in your data for self.k top passages for query or queries.\n\n        Args:\n            query_or_queries (Union[str, List[str]]): The query or queries to search for.\n            exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect.\n\n        Returns:\n            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'\n        \"\"\"\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n        self.usage += len(queries)\n        collected_results = []\n        for query in queries:\n            related_docs = self.qdrant.similarity_search_with_score(query, k=self.k)\n            # Each item is a (document, score) pair; only the document is used.\n            for doc, _score in related_docs:\n                collected_results.append(\n                    {\n                        # description is an optional field, so fall back to an empty string\n                        \"description\": doc.metadata.get(\"description\", \"\"),\n                        \"snippets\": [doc.page_content],\n                        \"title\": doc.metadata[\"title\"],\n                        \"url\": doc.metadata[\"url\"],\n                    }\n                )\n\n        return collected_results\n\n\nclass StanfordOvalArxivRM(dspy.Retrieve):\n    \"\"\"[Alpha] This retrieval class is for internal use only, not intended for the public.\"\"\"\n\n    def __init__(self, endpoint, k=3, 
rerank=True):\n        super().__init__(k=k)\n        self.endpoint = endpoint\n        self.usage = 0\n        self.rerank = rerank\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n\n        return {\"StanfordOvalArxivRM\": usage}\n\n    def _retrieve(self, query: str):\n        payload = {\"query\": query, \"num_blocks\": self.k, \"rerank\": self.rerank}\n\n        response = requests.post(\n            self.endpoint, json=payload, headers={\"Content-Type\": \"application/json\"}\n        )\n\n        # Check if the request was successful\n        if response.status_code == 200:\n            response_data_list = response.json()[0][\"results\"]\n            results = []\n            for response_data in response_data_list:\n                result = {\n                    \"title\": response_data[\"document_title\"],\n                    \"url\": response_data[\"url\"],\n                    \"snippets\": [response_data[\"content\"]],\n                    \"description\": response_data.get(\"description\", \"N/A\"),\n                    \"meta\": {\n                        key: value\n                        for key, value in response_data.items()\n                        if key not in [\"document_title\", \"url\", \"content\"]\n                    },\n                }\n\n                results.append(result)\n\n            return results\n        else:\n            raise Exception(\n                f\"Error: Unable to retrieve results. 
Status code: {response.status_code}\"\n            )\n\n    def forward(\n        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []\n    ):\n        collected_results = []\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n        # Count the queries so that get_usage_and_reset reports actual usage.\n        self.usage += len(queries)\n\n        for query in queries:\n            try:\n                results = self._retrieve(query)\n                collected_results.extend(results)\n            except Exception as e:\n                logging.error(f\"Error occurs when searching query {query}: {e}\")\n        return collected_results\n\n\nclass SerperRM(dspy.Retrieve):\n    \"\"\"Retrieve information from custom queries using Serper.dev.\"\"\"\n\n    def __init__(\n        self,\n        serper_search_api_key=None,\n        k=3,\n        query_params=None,\n        ENABLE_EXTRA_SNIPPET_EXTRACTION=False,\n        min_char_count: int = 150,\n        snippet_chunk_size: int = 1000,\n        webpage_helper_max_threads=10,\n    ):\n        \"\"\"Args:\n        serper_search_api_key str: API key to run serper, can be found by creating an account on https://serper.dev/\n        query_params (dict or list of dict): parameters in dictionary or list of dictionaries that has a max size of 100 that will be used to query.\n            Commonly used fields are as follows (see more information in https://serper.dev/playground):\n                q str: query that will be used with google search\n                type str: type that will be used for browsing google. Types are search, images, video, maps, places, etc.\n                gl str: Country that will be focused on for the search\n                location str: Country where the search will originate from. 
All locates can be found here: https://api.serper.dev/locations.\n                autocorrect bool: Enable autocorrect on the queries while searching, if query is misspelled, will be updated.\n                results int: Max number of results per page.\n                page int: Max number of pages per call.\n                tbs str: date time range, automatically set to any time by default.\n                qdr:h str: Date time range for the past hour.\n                qdr:d str: Date time range for the past 24 hours.\n                qdr:w str: Date time range for past week.\n                qdr:m str: Date time range for past month.\n                qdr:y str: Date time range for past year.\n        \"\"\"\n        super().__init__(k=k)\n        self.usage = 0\n        self.query_params = None\n        self.ENABLE_EXTRA_SNIPPET_EXTRACTION = ENABLE_EXTRA_SNIPPET_EXTRACTION\n        self.webpage_helper = WebPageHelper(\n            min_char_count=min_char_count,\n            snippet_chunk_size=snippet_chunk_size,\n            max_thread_num=webpage_helper_max_threads,\n        )\n\n        if query_params is None:\n            self.query_params = {\"num\": k, \"autocorrect\": True, \"page\": 1}\n        else:\n            self.query_params = query_params\n            self.query_params.update({\"num\": k})\n        self.serper_search_api_key = serper_search_api_key\n        if not self.serper_search_api_key and not os.environ.get(\"SERPER_API_KEY\"):\n            raise RuntimeError(\n                \"You must supply a serper_search_api_key param or set environment variable SERPER_API_KEY\"\n            )\n\n        elif self.serper_search_api_key:\n            self.serper_search_api_key = serper_search_api_key\n\n        else:\n            self.serper_search_api_key = os.environ[\"SERPER_API_KEY\"]\n\n        self.base_url = \"https://google.serper.dev\"\n\n    def serper_runner(self, query_params):\n        self.search_url = f\"{self.base_url}/search\"\n\n       
 headers = {\n            \"X-API-KEY\": self.serper_search_api_key,\n            \"Content-Type\": \"application/json\",\n        }\n\n        response = requests.request(\n            \"POST\", self.search_url, headers=headers, json=query_params\n        )\n\n        # requests never returns None; check the HTTP status to detect a failed call.\n        if not response.ok:\n            raise RuntimeError(\n                f\"An error occurred while running the search process.\\n Error is {response.reason}; the request failed with status code {response.status_code}\"\n            )\n\n        return response.json()\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n        return {\"SerperRM\": usage}\n\n    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):\n        \"\"\"\n        Calls the API and searches for the query passed in.\n\n\n        Args:\n            query_or_queries (Union[str, List[str]]): The query or queries to search for.\n            exclude_urls (List[str]): Dummy parameter to match the interface. 
Does not have any effect.\n\n        Returns:\n            a list of dictionaries, each dictionary has keys of 'description', 'snippets' (list of strings), 'title', 'url'\n        \"\"\"\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n\n        self.usage += len(queries)\n        self.results = []\n        collected_results = []\n        for query in queries:\n            if query == \"Queries:\":\n                continue\n            query_params = self.query_params\n\n            # All available parameters can be found in the playground: https://serper.dev/playground\n            # Sets the json value for query to be the query that is being parsed.\n            query_params[\"q\"] = query\n\n            # Sets the type to be search, can be images, video, places, maps etc that Google provides.\n            query_params[\"type\"] = \"search\"\n\n            self.result = self.serper_runner(query_params)\n            self.results.append(self.result)\n\n        # Array of dictionaries that will be used by Storm to create the jsons\n        collected_results = []\n\n        if self.ENABLE_EXTRA_SNIPPET_EXTRACTION:\n            urls = []\n            for result in self.results:\n                organic_results = result.get(\"organic\", [])\n                for organic in organic_results:\n                    url = organic.get(\"link\")\n                    if url:\n                        urls.append(url)\n            valid_url_to_snippets = self.webpage_helper.urls_to_snippets(urls)\n        else:\n            valid_url_to_snippets = {}\n\n        for result in self.results:\n            try:\n                # An array of dictionaries that contains the snippets, title of the document and url that will be used.\n                organic_results = result.get(\"organic\")\n                knowledge_graph = result.get(\"knowledgeGraph\")\n                for organic in 
organic_results:\n                    # Look up extra snippets by this result's own link.\n                    url = organic.get(\"link\")\n                    snippets = [organic.get(\"snippet\")]\n                    if self.ENABLE_EXTRA_SNIPPET_EXTRACTION:\n                        snippets.extend(\n                            valid_url_to_snippets.get(url, {}).get(\"snippets\", [])\n                        )\n                    collected_results.append(\n                        {\n                            \"snippets\": snippets,\n                            \"title\": organic.get(\"title\"),\n                            \"url\": url,\n                            \"description\": (\n                                knowledge_graph.get(\"description\")\n                                if knowledge_graph is not None\n                                else \"\"\n                            ),\n                        }\n                    )\n            except Exception as e:\n                logging.error(f\"Error occurs when parsing Serper results: {e}\")\n                continue\n\n        return collected_results\n\n\nclass BraveRM(dspy.Retrieve):\n    def __init__(\n        self, brave_search_api_key=None, k=3, is_valid_source: Callable = None\n    ):\n        super().__init__(k=k)\n        if not brave_search_api_key and not os.environ.get(\"BRAVE_API_KEY\"):\n            raise RuntimeError(\n                \"You must supply brave_search_api_key or set environment variable BRAVE_API_KEY\"\n            )\n        elif brave_search_api_key:\n            self.brave_search_api_key = brave_search_api_key\n        else:\n            self.brave_search_api_key = os.environ[\"BRAVE_API_KEY\"]\n        self.usage = 0\n\n        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.\n        if is_valid_source:\n            self.is_valid_source = is_valid_source\n        else:\n            self.is_valid_source = lambda x: True\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n\n        return {\"BraveRM\": usage}\n\n    def forward(\n        self, query_or_queries: Union[str, List[str]], 
exclude_urls: List[str] = []\n    ):\n        \"\"\"Search with api.search.brave.com for self.k top passages for query or queries\n\n        Args:\n            query_or_queries (Union[str, List[str]]): The query or queries to search for.\n            exclude_urls (List[str]): A list of urls to exclude from the search results.\n\n        Returns:\n            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'\n        \"\"\"\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n        self.usage += len(queries)\n        collected_results = []\n        for query in queries:\n            try:\n                headers = {\n                    \"Accept\": \"application/json\",\n                    \"Accept-Encoding\": \"gzip\",\n                    \"X-Subscription-Token\": self.brave_search_api_key,\n                }\n                response = requests.get(\n                    f\"https://api.search.brave.com/res/v1/web/search?result_filter=web&q={query}\",\n                    headers=headers,\n                ).json()\n                results = response.get(\"web\", {}).get(\"results\", [])\n\n                for result in results:\n                    collected_results.append(\n                        {\n                            \"snippets\": result.get(\"extra_snippets\", []),\n                            \"title\": result.get(\"title\"),\n                            \"url\": result.get(\"url\"),\n                            \"description\": result.get(\"description\"),\n                        }\n                    )\n            except Exception as e:\n                logging.error(f\"Error occurs when searching query {query}: {e}\")\n\n        return collected_results\n\n\nclass SearXNG(dspy.Retrieve):\n    def __init__(\n        self,\n        searxng_api_url,\n        searxng_api_key=None,\n        k=3,\n        
is_valid_source: Callable = None,\n    ):\n        \"\"\"Initialize the SearXNG search retriever.\n        Please set up SearXNG according to https://docs.searxng.org/index.html.\n\n        Args:\n            searxng_api_url (str): The URL of the SearXNG API. Consult SearXNG documentation for details.\n            searxng_api_key (str, optional): The API key for the SearXNG API. Defaults to None. Consult SearXNG documentation for details.\n            k (int, optional): The number of top passages to retrieve. Defaults to 3.\n            is_valid_source (Callable, optional): A function that takes a URL and returns a boolean indicating if the\n            source is valid. Defaults to None.\n        \"\"\"\n        super().__init__(k=k)\n        if not searxng_api_url:\n            raise RuntimeError(\"You must supply searxng_api_url\")\n        self.searxng_api_url = searxng_api_url\n        self.searxng_api_key = searxng_api_key\n        self.usage = 0\n\n        if is_valid_source:\n            self.is_valid_source = is_valid_source\n        else:\n            self.is_valid_source = lambda x: True\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n        return {\"SearXNG\": usage}\n\n    def forward(\n        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []\n    ):\n        \"\"\"Search with SearxNG for self.k top passages for query or queries\n\n        Args:\n            query_or_queries (Union[str, List[str]]): The query or queries to search for.\n            exclude_urls (List[str]): A list of urls to exclude from the search results.\n\n        Returns:\n            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'\n        \"\"\"\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n        self.usage += len(queries)\n        collected_results = []\n 
       headers = (\n            {\"Authorization\": f\"Bearer {self.searxng_api_key}\"}\n            if self.searxng_api_key\n            else {}\n        )\n\n        for query in queries:\n            try:\n                params = {\"q\": query, \"format\": \"json\"}\n                response = requests.get(\n                    self.searxng_api_url, headers=headers, params=params\n                )\n                results = response.json()\n\n                for r in results[\"results\"]:\n                    if self.is_valid_source(r[\"url\"]) and r[\"url\"] not in exclude_urls:\n                        collected_results.append(\n                            {\n                                \"description\": r.get(\"content\", \"\"),\n                                \"snippets\": [r.get(\"content\", \"\")],\n                                \"title\": r.get(\"title\", \"\"),\n                                \"url\": r[\"url\"],\n                            }\n                        )\n            except Exception as e:\n                logging.error(f\"Error occurs when searching query {query}: {e}\")\n\n        return collected_results\n\n\nclass DuckDuckGoSearchRM(dspy.Retrieve):\n    \"\"\"Retrieve information from custom queries using DuckDuckGo.\"\"\"\n\n    def __init__(\n        self,\n        k: int = 3,\n        is_valid_source: Callable = None,\n        min_char_count: int = 150,\n        snippet_chunk_size: int = 1000,\n        webpage_helper_max_threads=10,\n        safe_search: str = \"On\",\n        region: str = \"us-en\",\n    ):\n        \"\"\"\n        Params:\n            min_char_count: Minimum character count for the article to be considered valid.\n            snippet_chunk_size: Maximum character count for each snippet.\n            webpage_helper_max_threads: Maximum number of threads to use for webpage helper.\n            **kwargs: Additional parameters for the OpenAI API.\n        \"\"\"\n        super().__init__(k=k)\n        
try:\n            from duckduckgo_search import DDGS\n        except ImportError as err:\n            raise ImportError(\n                \"Duckduckgo requires `pip install duckduckgo_search`.\"\n            ) from err\n        self.k = k\n        self.webpage_helper = WebPageHelper(\n            min_char_count=min_char_count,\n            snippet_chunk_size=snippet_chunk_size,\n            max_thread_num=webpage_helper_max_threads,\n        )\n        self.usage = 0\n        # All params for search can be found here:\n        #   https://duckduckgo.com/duckduckgo-help-pages/settings/params/\n\n        # Sets the backend to be api\n        self.duck_duck_go_backend = \"api\"\n\n        # Only gets safe search results\n        self.duck_duck_go_safe_search = safe_search\n\n        # Specifies the region that the search will use\n        self.duck_duck_go_region = region\n\n        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.\n        if is_valid_source:\n            self.is_valid_source = is_valid_source\n        else:\n            self.is_valid_source = lambda x: True\n\n        # Import the duckduckgo search library found here: https://github.com/deedy5/duckduckgo_search\n        self.ddgs = DDGS()\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n        return {\"DuckDuckGoRM\": usage}\n\n    @backoff.on_exception(\n        backoff.expo,\n        (Exception,),\n        max_time=1000,\n        max_tries=8,\n        on_backoff=backoff_hdlr,\n        giveup=giveup_hdlr,\n    )\n    def request(self, query: str):\n        results = self.ddgs.text(\n            query, max_results=self.k, backend=self.duck_duck_go_backend\n        )\n        return results\n\n    def forward(\n        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []\n    ):\n        \"\"\"Search with DuckDuckGoSearch for self.k top passages for query or queries\n        Args:\n            
query_or_queries (Union[str, List[str]]): The query or queries to search for.\n            exclude_urls (List[str]): A list of urls to exclude from the search results.\n        Returns:\n            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'\n        \"\"\"\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n        self.usage += len(queries)\n\n        collected_results = []\n\n        for query in queries:\n            #  list of dicts that will be parsed to return\n            results = self.request(query)\n\n            for d in results:\n                # assert d is dict\n                if not isinstance(d, dict):\n                    print(f\"Invalid result: {d}\\n\")\n                    continue\n\n                try:\n                    # ensure keys are present\n                    url = d.get(\"href\", None)\n                    title = d.get(\"title\", None)\n                    description = d.get(\"description\", title)\n                    snippets = [d.get(\"body\", None)]\n\n                    # raise exception of missing key(s)\n                    if not all([url, title, description, snippets]):\n                        raise ValueError(f\"Missing key(s) in result: {d}\")\n                    if self.is_valid_source(url) and url not in exclude_urls:\n                        result = {\n                            \"url\": url,\n                            \"title\": title,\n                            \"description\": description,\n                            \"snippets\": snippets,\n                        }\n                        collected_results.append(result)\n                    else:\n                        print(f\"invalid source {url} or url in exclude_urls\")\n                except Exception as e:\n                    print(f\"Error occurs when processing {d=}: 
{e}\\n\")\n                    print(f\"Error occurs when searching query {query}: {e}\")\n\n        return collected_results\n\n\nclass TavilySearchRM(dspy.Retrieve):\n    \"\"\"Retrieve information from custom queries using Tavily. Documentation and examples can be found at https://docs.tavily.com/docs/python-sdk/tavily-search/examples\"\"\"\n\n    def __init__(\n        self,\n        tavily_search_api_key=None,\n        k: int = 3,\n        is_valid_source: Callable = None,\n        min_char_count: int = 150,\n        snippet_chunk_size: int = 1000,\n        webpage_helper_max_threads=10,\n        include_raw_content=False,\n    ):\n        \"\"\"\n        Params:\n            tavily_search_api_key str: API key for tavily that can be retrieved from https://tavily.com/\n            min_char_count: Minimum character count for the article to be considered valid.\n            snippet_chunk_size: Maximum character count for each snippet.\n            webpage_helper_max_threads: Maximum number of threads to use for webpage helper.\n            include_raw_content bool: Boolean that is used to determine if the full text should be returned.\n        \"\"\"\n        super().__init__(k=k)\n        try:\n            from tavily import TavilyClient\n        except ImportError as err:\n            raise ImportError(\"Tavily requires `pip install tavily-python`.\") from err\n\n        if not tavily_search_api_key and not os.environ.get(\"TAVILY_API_KEY\"):\n            raise RuntimeError(\n                \"You must supply tavily_search_api_key or set environment variable TAVILY_API_KEY\"\n            )\n        elif tavily_search_api_key:\n            self.tavily_search_api_key = tavily_search_api_key\n        else:\n            self.tavily_search_api_key = os.environ[\"TAVILY_API_KEY\"]\n\n        self.k = k\n        self.webpage_helper = WebPageHelper(\n            min_char_count=min_char_count,\n            snippet_chunk_size=snippet_chunk_size,\n            
max_thread_num=webpage_helper_max_threads,\n        )\n\n        self.usage = 0\n\n        # Creates client instance that will use search. Full search params are here:\n        # https://docs.tavily.com/docs/python-sdk/tavily-search/examples\n        self.tavily_client = TavilyClient(api_key=self.tavily_search_api_key)\n\n        self.include_raw_content = include_raw_content\n\n        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.\n        if is_valid_source:\n            self.is_valid_source = is_valid_source\n        else:\n            self.is_valid_source = lambda x: True\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n        return {\"TavilySearchRM\": usage}\n\n    def forward(\n        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []\n    ):\n        \"\"\"Search with TavilySearch for self.k top passages for query or queries\n        Args:\n            query_or_queries (Union[str, List[str]]): The query or queries to search for.\n            exclude_urls (List[str]): A list of urls to exclude from the search results.\n        Returns:\n            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'\n        \"\"\"\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n        self.usage += len(queries)\n\n        collected_results = []\n\n        for query in queries:\n            # Search options forwarded to the Tavily client.\n            args = {\n                \"max_results\": self.k,\n                \"include_raw_content\": self.include_raw_content,\n            }\n            #  list of dicts that will be parsed to return\n            responseData = self.tavily_client.search(query, **args)\n            results = responseData.get(\"results\", [])\n            for d in results:\n                # assert d is dict\n                if not isinstance(d, dict):\n           
         print(f\"Invalid result: {d}\\n\")\n                    continue\n\n                try:\n                    # ensure keys are present\n                    url = d.get(\"url\", None)\n                    title = d.get(\"title\", None)\n                    description = d.get(\"content\", None)\n                    snippets = []\n                    # Prefer the full page text when Tavily returns it; otherwise use the short content.\n                    if d.get(\"raw_content\"):\n                        snippets.append(d.get(\"raw_content\"))\n                    else:\n                        snippets.append(d.get(\"content\"))\n\n                    # raise exception of missing key(s)\n                    if not all([url, title, description, snippets]):\n                        raise ValueError(f\"Missing key(s) in result: {d}\")\n                    if self.is_valid_source(url) and url not in exclude_urls:\n                        result = {\n                            \"url\": url,\n                            \"title\": title,\n                            \"description\": description,\n                            \"snippets\": snippets,\n                        }\n                        collected_results.append(result)\n                    else:\n                        print(f\"invalid source {url} or url in exclude_urls\")\n                except Exception as e:\n                    print(f\"Error occurs when processing {d=}: {e}\\n\")\n                    print(f\"Error occurs when searching query {query}: {e}\")\n\n        return collected_results\n\n\nclass GoogleSearch(dspy.Retrieve):\n    def __init__(\n        self,\n        google_search_api_key=None,\n        google_cse_id=None,\n        k=3,\n        is_valid_source: Callable = None,\n        min_char_count: int = 150,\n        snippet_chunk_size: int = 1000,\n        webpage_helper_max_threads=10,\n    ):\n        \"\"\"\n        Params:\n            google_search_api_key: Google API key. 
Check out https://developers.google.com/custom-search/v1/overview\n                \"API key\" section\n            google_cse_id: Custom search engine ID. Check out https://developers.google.com/custom-search/v1/overview\n                \"Search engine ID\" section\n            k: Number of top results to retrieve.\n            is_valid_source: Optional function to filter valid sources.\n            min_char_count: Minimum character count for the article to be considered valid.\n            snippet_chunk_size: Maximum character count for each snippet.\n            webpage_helper_max_threads: Maximum number of threads to use for webpage helper.\n        \"\"\"\n        super().__init__(k=k)\n        try:\n            from googleapiclient.discovery import build\n        except ImportError as err:\n            raise ImportError(\n                \"GoogleSearch requires `pip install google-api-python-client`.\"\n            ) from err\n        if not google_search_api_key and not os.environ.get(\"GOOGLE_SEARCH_API_KEY\"):\n            raise RuntimeError(\n                \"You must supply google_search_api_key or set the GOOGLE_SEARCH_API_KEY environment variable\"\n            )\n        if not google_cse_id and not os.environ.get(\"GOOGLE_CSE_ID\"):\n            raise RuntimeError(\n                \"You must supply google_cse_id or set the GOOGLE_CSE_ID environment variable\"\n            )\n\n        self.google_search_api_key = (\n            google_search_api_key or os.environ[\"GOOGLE_SEARCH_API_KEY\"]\n        )\n        self.google_cse_id = google_cse_id or os.environ[\"GOOGLE_CSE_ID\"]\n\n        if is_valid_source:\n            self.is_valid_source = is_valid_source\n        else:\n            self.is_valid_source = lambda x: True\n\n        self.service = build(\n            \"customsearch\", \"v1\", developerKey=self.google_search_api_key\n        )\n        self.webpage_helper = WebPageHelper(\n            min_char_count=min_char_count,\n            
snippet_chunk_size=snippet_chunk_size,\n            max_thread_num=webpage_helper_max_threads,\n        )\n        self.usage = 0\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n        return {\"GoogleSearch\": usage}\n\n    def forward(\n        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []\n    ):\n        \"\"\"Search using Google Custom Search API for self.k top results for query or queries.\n\n        Args:\n            query_or_queries (Union[str, List[str]]): The query or queries to search for.\n            exclude_urls (List[str]): A list of URLs to exclude from the search results.\n\n        Returns:\n            A list of dicts, each dict has keys: 'title', 'url', 'description', 'snippets' (list of strings).\n        \"\"\"\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n        self.usage += len(queries)\n\n        url_to_results = {}\n\n        for query in queries:\n            try:\n                response = (\n                    self.service.cse()\n                    .list(\n                        q=query,\n                        cx=self.google_cse_id,\n                        num=self.k,\n                    )\n                    .execute()\n                )\n\n                for item in response.get(\"items\", []):\n                    if (\n                        self.is_valid_source(item[\"link\"])\n                        and item[\"link\"] not in exclude_urls\n                    ):\n                        url_to_results[item[\"link\"]] = {\n                            \"title\": item[\"title\"],\n                            \"url\": item[\"link\"],\n                            # \"snippet\": item.get(\"snippet\", \"\"),  # Google search snippet is very short.\n                            \"description\": item.get(\"snippet\", \"\"),\n                        }\n\n            
except Exception as e:\n                logging.error(f\"Error occurred while searching query {query}: {e}\")\n\n        valid_url_to_snippets = self.webpage_helper.urls_to_snippets(\n            list(url_to_results.keys())\n        )\n        collected_results = []\n        for url in valid_url_to_snippets:\n            r = url_to_results[url]\n            r[\"snippets\"] = valid_url_to_snippets[url][\"snippets\"]\n            collected_results.append(r)\n\n        return collected_results\n\n\nclass AzureAISearch(dspy.Retrieve):\n    \"\"\"Retrieve information from custom queries using Azure AI Search.\n\n    General Documentation: https://learn.microsoft.com/en-us/azure/search/search-create-service-portal.\n    Python Documentation: https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python.\n    \"\"\"\n\n    def __init__(\n        self,\n        azure_ai_search_api_key=None,\n        azure_ai_search_url=None,\n        azure_ai_search_index_name=None,\n        k=3,\n        is_valid_source: Callable = None,\n    ):\n        \"\"\"\n        Params:\n            azure_ai_search_api_key: Azure AI Search API key. Check out https://learn.microsoft.com/en-us/azure/search/search-security-api-keys?tabs=rest-use%2Cportal-find%2Cportal-query\n                \"API key\" section\n            azure_ai_search_url: Custom Azure AI Search Endpoint URL. Check out https://learn.microsoft.com/en-us/azure/search/search-create-service-portal#name-the-service\n            azure_ai_search_index_name: Custom Azure AI Search Index Name. 
Check out https://learn.microsoft.com/en-us/azure/search/search-how-to-create-search-index?tabs=portal\n            k: Number of top results to retrieve.\n            is_valid_source: Optional function to filter valid sources.\n        \"\"\"\n        super().__init__(k=k)\n\n        try:\n            from azure.core.credentials import AzureKeyCredential\n            from azure.search.documents import SearchClient\n        except ImportError as err:\n            raise ImportError(\n                \"AzureAISearch requires `pip install azure-search-documents`.\"\n            ) from err\n\n        if not azure_ai_search_api_key and not os.environ.get(\n            \"AZURE_AI_SEARCH_API_KEY\"\n        ):\n            raise RuntimeError(\n                \"You must supply azure_ai_search_api_key or set environment variable AZURE_AI_SEARCH_API_KEY\"\n            )\n        elif azure_ai_search_api_key:\n            self.azure_ai_search_api_key = azure_ai_search_api_key\n        else:\n            self.azure_ai_search_api_key = os.environ[\"AZURE_AI_SEARCH_API_KEY\"]\n\n        if not azure_ai_search_url and not os.environ.get(\"AZURE_AI_SEARCH_URL\"):\n            raise RuntimeError(\n                \"You must supply azure_ai_search_url or set environment variable AZURE_AI_SEARCH_URL\"\n            )\n        elif azure_ai_search_url:\n            self.azure_ai_search_url = azure_ai_search_url\n        else:\n            self.azure_ai_search_url = os.environ[\"AZURE_AI_SEARCH_URL\"]\n\n        if not azure_ai_search_index_name and not os.environ.get(\n            \"AZURE_AI_SEARCH_INDEX_NAME\"\n        ):\n            raise RuntimeError(\n                \"You must supply azure_ai_search_index_name or set 
environment variable AZURE_AI_SEARCH_INDEX_NAME\"\n            )\n        elif azure_ai_search_index_name:\n            self.azure_ai_search_index_name = azure_ai_search_index_name\n        else:\n            self.azure_ai_search_index_name = os.environ[\"AZURE_AI_SEARCH_INDEX_NAME\"]\n\n        self.usage = 0\n\n        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.\n        if is_valid_source:\n            self.is_valid_source = is_valid_source\n        else:\n            self.is_valid_source = lambda x: True\n\n    def get_usage_and_reset(self):\n        usage = self.usage\n        self.usage = 0\n\n        return {\"AzureAISearch\": usage}\n\n    def forward(\n        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []\n    ):\n        \"\"\"Search with Azure AI Search for self.k top passages for query or queries\n\n        Args:\n            query_or_queries (Union[str, List[str]]): The query or queries to search for.\n            exclude_urls (List[str]): A list of urls to exclude from the search results.\n\n        Returns:\n            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'\n        \"\"\"\n        try:\n            from azure.core.credentials import AzureKeyCredential\n            from azure.search.documents import SearchClient\n        except ImportError as err:\n            raise ImportError(\n                \"AzureAISearch requires `pip install azure-search-documents`.\"\n            ) from err\n        queries = (\n            [query_or_queries]\n            if isinstance(query_or_queries, str)\n            else query_or_queries\n        )\n        self.usage += len(queries)\n        collected_results = []\n\n        client = SearchClient(\n            self.azure_ai_search_url,\n            self.azure_ai_search_index_name,\n            AzureKeyCredential(self.azure_ai_search_api_key),\n        )\n        for query in queries:\n 
           try:\n                # https://learn.microsoft.com/en-us/python/api/azure-search-documents/azure.search.documents.searchclient?view=azure-python#azure-search-documents-searchclient-search\n                results = client.search(search_text=query, top=self.k)\n\n                for result in results:\n                    url = result[\"metadata_storage_path\"]\n                    # Honor the source filter and the caller's exclusion list.\n                    if not self.is_valid_source(url) or url in exclude_urls:\n                        continue\n                    document = {\n                        \"url\": url,\n                        \"title\": result[\"title\"],\n                        \"description\": \"N/A\",\n                        \"snippets\": [result[\"chunk\"]],\n                    }\n                    collected_results.append(document)\n            except Exception as e:\n                logging.error(f\"Error occurred when searching query {query}: {e}\")\n\n        return collected_results\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/__init__.py",
    "content": "from .engine import *\nfrom .modules import *\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/engine.py",
    "content": "import json\nimport logging\nimport os\nfrom dataclasses import dataclass, field\nfrom typing import Union, Literal, Optional\n\nimport dspy\n\nfrom .modules.article_generation import StormArticleGenerationModule\nfrom .modules.article_polish import StormArticlePolishingModule\nfrom .modules.callback import BaseCallbackHandler\nfrom .modules.knowledge_curation import StormKnowledgeCurationModule\nfrom .modules.outline_generation import StormOutlineGenerationModule\nfrom .modules.persona_generator import StormPersonaGenerator\nfrom .modules.storm_dataclass import StormInformationTable, StormArticle\nfrom ..interface import Engine, LMConfigs, Retriever\nfrom ..lm import LitellmModel\nfrom ..utils import FileIOHelper, makeStringRed, truncate_filename\n\n\nclass STORMWikiLMConfigs(LMConfigs):\n    \"\"\"Configurations for LLM used in different parts of STORM.\n\n    Given that different parts in STORM framework have different complexity, we use different LLM configurations\n    to achieve a balance between quality and efficiency. 
If no specific configuration is provided, we use the default\n    setup in the paper.\n    \"\"\"\n\n    def __init__(self):\n        self.conv_simulator_lm = (\n            None  # LLM used in conversation simulator except for question asking.\n        )\n        self.question_asker_lm = None  # LLM used in question asking.\n        self.outline_gen_lm = None  # LLM used in outline generation.\n        self.article_gen_lm = None  # LLM used in article generation.\n        self.article_polish_lm = None  # LLM used in article polishing.\n\n    def init_openai_model(\n        self,\n        openai_api_key: str,\n        azure_api_key: str,\n        openai_type: Literal[\"openai\", \"azure\"],\n        api_base: Optional[str] = None,\n        api_version: Optional[str] = None,\n        temperature: Optional[float] = 1.0,\n        top_p: Optional[float] = 0.9,\n    ):\n        \"\"\"Legacy: Corresponding to the original setup in the NAACL'24 paper.\"\"\"\n        azure_kwargs = {\n            \"api_key\": azure_api_key,\n            \"temperature\": temperature,\n            \"top_p\": top_p,\n            \"api_base\": api_base,\n            \"api_version\": api_version,\n        }\n\n        openai_kwargs = {\n            \"api_key\": openai_api_key,\n            \"temperature\": temperature,\n            \"top_p\": top_p,\n            \"api_base\": None,\n        }\n        if openai_type and openai_type == \"openai\":\n            self.conv_simulator_lm = LitellmModel(\n                model=\"gpt-4o-mini-2024-07-18\", max_tokens=500, **openai_kwargs\n            )\n            self.question_asker_lm = LitellmModel(\n                model=\"gpt-4o-mini-2024-07-18\", max_tokens=500, **openai_kwargs\n            )\n            # 1/12/2024: Update gpt-4 to gpt-4-1106-preview. 
(Currently keep the original setup when using azure.)\n            self.outline_gen_lm = LitellmModel(\n                model=\"gpt-4-0125-preview\", max_tokens=400, **openai_kwargs\n            )\n            self.article_gen_lm = LitellmModel(\n                model=\"gpt-4o-2024-05-13\", max_tokens=700, **openai_kwargs\n            )\n            self.article_polish_lm = LitellmModel(\n                model=\"gpt-4o-2024-05-13\", max_tokens=4000, **openai_kwargs\n            )\n        elif openai_type and openai_type == \"azure\":\n            self.conv_simulator_lm = LitellmModel(\n                model=\"azure/gpt-4o-mini-2024-07-18\", max_tokens=500, **openai_kwargs\n            )\n            self.question_asker_lm = LitellmModel(\n                model=\"azure/gpt-4o-mini-2024-07-18\",\n                max_tokens=500,\n                **azure_kwargs,\n                model_type=\"chat\",\n            )\n            # use combination of openai and azure-openai as azure-openai does not support gpt-4 in standard deployment\n            self.outline_gen_lm = LitellmModel(\n                model=\"azure/gpt-4o\", max_tokens=400, **azure_kwargs, model_type=\"chat\"\n            )\n            self.article_gen_lm = LitellmModel(\n                model=\"azure/gpt-4o-mini-2024-07-18\",\n                max_tokens=700,\n                **azure_kwargs,\n                model_type=\"chat\",\n            )\n            self.article_polish_lm = LitellmModel(\n                model=\"azure/gpt-4o-mini-2024-07-18\",\n                max_tokens=4000,\n                **azure_kwargs,\n                model_type=\"chat\",\n            )\n        else:\n            logging.warning(\n                \"No valid OpenAI API provider is provided. 
Cannot use default LLM configurations.\"\n            )\n\n    def set_conv_simulator_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.conv_simulator_lm = model\n\n    def set_question_asker_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.question_asker_lm = model\n\n    def set_outline_gen_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.outline_gen_lm = model\n\n    def set_article_gen_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.article_gen_lm = model\n\n    def set_article_polish_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.article_polish_lm = model\n\n\n@dataclass\nclass STORMWikiRunnerArguments:\n    \"\"\"Arguments for controlling the STORM Wiki pipeline.\"\"\"\n\n    output_dir: str = field(\n        metadata={\"help\": \"Output directory for the results.\"},\n    )\n    max_conv_turn: int = field(\n        default=3,\n        metadata={\n            \"help\": \"Maximum number of questions in conversational question asking.\"\n        },\n    )\n    max_perspective: int = field(\n        default=3,\n        metadata={\n            \"help\": \"Maximum number of perspectives to consider in perspective-guided question asking.\"\n        },\n    )\n    max_search_queries_per_turn: int = field(\n        default=3,\n        metadata={\"help\": \"Maximum number of search queries to consider in each turn.\"},\n    )\n    disable_perspective: bool = field(\n        default=False,\n        metadata={\"help\": \"If True, disable perspective-guided question asking.\"},\n    )\n    search_top_k: int = field(\n        default=3,\n        metadata={\"help\": \"Top k search results to consider for each search query.\"},\n    )\n    retrieve_top_k: int = field(\n        default=3,\n        metadata={\"help\": \"Top k collected references for each section title.\"},\n    )\n    max_thread_num: int = field(\n        default=10,\n        metadata={\n            
\"help\": \"Maximum number of threads to use. \"\n            \"Consider reducing it if you keep getting 'Exceed rate limit' error when calling LM API.\"\n        },\n    )\n\n\nclass STORMWikiRunner(Engine):\n    \"\"\"STORM Wiki pipeline runner.\"\"\"\n\n    def __init__(\n        self, args: STORMWikiRunnerArguments, lm_configs: STORMWikiLMConfigs, rm\n    ):\n        super().__init__(lm_configs=lm_configs)\n        self.args = args\n        self.lm_configs = lm_configs\n\n        self.retriever = Retriever(rm=rm, max_thread=self.args.max_thread_num)\n        storm_persona_generator = StormPersonaGenerator(\n            self.lm_configs.question_asker_lm\n        )\n        self.storm_knowledge_curation_module = StormKnowledgeCurationModule(\n            retriever=self.retriever,\n            persona_generator=storm_persona_generator,\n            conv_simulator_lm=self.lm_configs.conv_simulator_lm,\n            question_asker_lm=self.lm_configs.question_asker_lm,\n            max_search_queries_per_turn=self.args.max_search_queries_per_turn,\n            search_top_k=self.args.search_top_k,\n            max_conv_turn=self.args.max_conv_turn,\n            max_thread_num=self.args.max_thread_num,\n        )\n        self.storm_outline_generation_module = StormOutlineGenerationModule(\n            outline_gen_lm=self.lm_configs.outline_gen_lm\n        )\n        self.storm_article_generation = StormArticleGenerationModule(\n            article_gen_lm=self.lm_configs.article_gen_lm,\n            retrieve_top_k=self.args.retrieve_top_k,\n            max_thread_num=self.args.max_thread_num,\n        )\n        self.storm_article_polishing_module = StormArticlePolishingModule(\n            article_gen_lm=self.lm_configs.article_gen_lm,\n            article_polish_lm=self.lm_configs.article_polish_lm,\n        )\n\n        self.lm_configs.init_check()\n        self.apply_decorators()\n\n    def run_knowledge_curation_module(\n        self,\n        ground_truth_url: str = 
\"None\",\n        callback_handler: BaseCallbackHandler = None,\n    ) -> StormInformationTable:\n        (\n            information_table,\n            conversation_log,\n        ) = self.storm_knowledge_curation_module.research(\n            topic=self.topic,\n            ground_truth_url=ground_truth_url,\n            callback_handler=callback_handler,\n            max_perspective=self.args.max_perspective,\n            disable_perspective=self.args.disable_perspective,\n            return_conversation_log=True,\n        )\n\n        FileIOHelper.dump_json(\n            conversation_log,\n            os.path.join(self.article_output_dir, \"conversation_log.json\"),\n        )\n        information_table.dump_url_to_info(\n            os.path.join(self.article_output_dir, \"raw_search_results.json\")\n        )\n        return information_table\n\n    def run_outline_generation_module(\n        self,\n        information_table: StormInformationTable,\n        callback_handler: BaseCallbackHandler = None,\n    ) -> StormArticle:\n        outline, draft_outline = self.storm_outline_generation_module.generate_outline(\n            topic=self.topic,\n            information_table=information_table,\n            return_draft_outline=True,\n            callback_handler=callback_handler,\n        )\n        outline.dump_outline_to_file(\n            os.path.join(self.article_output_dir, \"storm_gen_outline.txt\")\n        )\n        draft_outline.dump_outline_to_file(\n            os.path.join(self.article_output_dir, \"direct_gen_outline.txt\")\n        )\n        return outline\n\n    def run_article_generation_module(\n        self,\n        outline: StormArticle,\n        information_table: StormInformationTable,\n        callback_handler: BaseCallbackHandler = None,\n    ) -> StormArticle:\n        draft_article = self.storm_article_generation.generate_article(\n            topic=self.topic,\n            information_table=information_table,\n            article_with_outline=outline,\n      
      callback_handler=callback_handler,\n        )\n        draft_article.dump_article_as_plain_text(\n            os.path.join(self.article_output_dir, \"storm_gen_article.txt\")\n        )\n        draft_article.dump_reference_to_file(\n            os.path.join(self.article_output_dir, \"url_to_info.json\")\n        )\n        return draft_article\n\n    def run_article_polishing_module(\n        self, draft_article: StormArticle, remove_duplicate: bool = False\n    ) -> StormArticle:\n        polished_article = self.storm_article_polishing_module.polish_article(\n            topic=self.topic,\n            draft_article=draft_article,\n            remove_duplicate=remove_duplicate,\n        )\n        FileIOHelper.write_str(\n            polished_article.to_string(),\n            os.path.join(self.article_output_dir, \"storm_gen_article_polished.txt\"),\n        )\n        return polished_article\n\n    def post_run(self):\n        \"\"\"\n        Post-run operations, including:\n        1. Dumping the run configuration.\n        2. Dumping the LLM call history.\n        \"\"\"\n        config_log = self.lm_configs.log()\n        FileIOHelper.dump_json(\n            config_log, os.path.join(self.article_output_dir, \"run_config.json\")\n        )\n\n        llm_call_history = self.lm_configs.collect_and_reset_lm_history()\n        with open(\n            os.path.join(self.article_output_dir, \"llm_call_history.jsonl\"), \"w\"\n        ) as f:\n            for call in llm_call_history:\n                if \"kwargs\" in call:\n                    call.pop(\n                        \"kwargs\"\n                    )  # All kwargs are dumped together to run_config.json.\n                f.write(json.dumps(call) + \"\\n\")\n\n    def _load_information_table_from_local_fs(self, information_table_local_path):\n        assert os.path.exists(information_table_local_path), makeStringRed(\n            f\"{information_table_local_path} does not exist. 
Please set --do-research argument to prepare the conversation_log.json for this topic.\"\n        )\n        return StormInformationTable.from_conversation_log_file(\n            information_table_local_path\n        )\n\n    def _load_outline_from_local_fs(self, topic, outline_local_path):\n        assert os.path.exists(outline_local_path), makeStringRed(\n            f\"{outline_local_path} does not exist. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic.\"\n        )\n        return StormArticle.from_outline_file(topic=topic, file_path=outline_local_path)\n\n    def _load_draft_article_from_local_fs(\n        self, topic, draft_article_path, url_to_info_path\n    ):\n        assert os.path.exists(draft_article_path), makeStringRed(\n            f\"{draft_article_path} does not exist. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic.\"\n        )\n        assert os.path.exists(url_to_info_path), makeStringRed(\n            f\"{url_to_info_path} does not exist. Please set --do-generate-article argument to prepare the url_to_info.json for this topic.\"\n        )\n        article_text = FileIOHelper.load_str(draft_article_path)\n        references = FileIOHelper.load_json(url_to_info_path)\n        return StormArticle.from_string(\n            topic_name=topic, article_text=article_text, references=references\n        )\n\n    def run(\n        self,\n        topic: str,\n        ground_truth_url: str = \"\",\n        do_research: bool = True,\n        do_generate_outline: bool = True,\n        do_generate_article: bool = True,\n        do_polish_article: bool = True,\n        remove_duplicate: bool = False,\n        callback_handler: BaseCallbackHandler = BaseCallbackHandler(),\n    ):\n        \"\"\"\n        Run the STORM pipeline.\n\n        Args:\n            topic: The topic to research.\n            ground_truth_url: A ground truth URL including a curated article about the topic. 
The URL will be excluded.\n            do_research: If True, research the topic through information-seeking conversation;\n             if False, expect conversation_log.json and raw_search_results.json to exist in the output directory.\n            do_generate_outline: If True, generate an outline for the topic;\n             if False, expect storm_gen_outline.txt to exist in the output directory.\n            do_generate_article: If True, generate a curated article for the topic;\n             if False, expect storm_gen_article.txt to exist in the output directory.\n            do_polish_article: If True, polish the article by adding a summarization section and (optionally) removing\n             duplicated content.\n            remove_duplicate: If True, remove duplicated content.\n            callback_handler: A callback handler to handle the intermediate results.\n        \"\"\"\n        assert (\n            do_research\n            or do_generate_outline\n            or do_generate_article\n            or do_polish_article\n        ), makeStringRed(\n            \"No action is specified. 
Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article\"\n        )\n\n        self.topic = topic\n        self.article_dir_name = truncate_filename(\n            topic.replace(\" \", \"_\").replace(\"/\", \"_\")\n        )\n        self.article_output_dir = os.path.join(\n            self.args.output_dir, self.article_dir_name\n        )\n        os.makedirs(self.article_output_dir, exist_ok=True)\n\n        # research module\n        information_table: StormInformationTable = None\n        if do_research:\n            information_table = self.run_knowledge_curation_module(\n                ground_truth_url=ground_truth_url, callback_handler=callback_handler\n            )\n        # outline generation module\n        outline: StormArticle = None\n        if do_generate_outline:\n            # load information table if it's not initialized\n            if information_table is None:\n                information_table = self._load_information_table_from_local_fs(\n                    os.path.join(self.article_output_dir, \"conversation_log.json\")\n                )\n            outline = self.run_outline_generation_module(\n                information_table=information_table, callback_handler=callback_handler\n            )\n\n        # article generation module\n        draft_article: StormArticle = None\n        if do_generate_article:\n            if information_table is None:\n                information_table = self._load_information_table_from_local_fs(\n                    os.path.join(self.article_output_dir, \"conversation_log.json\")\n                )\n            if outline is None:\n                outline = self._load_outline_from_local_fs(\n                    topic=topic,\n                    outline_local_path=os.path.join(\n                        self.article_output_dir, \"storm_gen_outline.txt\"\n                    ),\n                )\n            draft_article = 
self.run_article_generation_module(\n                outline=outline,\n                information_table=information_table,\n                callback_handler=callback_handler,\n            )\n\n        # article polishing module\n        if do_polish_article:\n            if draft_article is None:\n                draft_article_path = os.path.join(\n                    self.article_output_dir, \"storm_gen_article.txt\"\n                )\n                url_to_info_path = os.path.join(\n                    self.article_output_dir, \"url_to_info.json\"\n                )\n                draft_article = self._load_draft_article_from_local_fs(\n                    topic=topic,\n                    draft_article_path=draft_article_path,\n                    url_to_info_path=url_to_info_path,\n                )\n            self.run_article_polishing_module(\n                draft_article=draft_article, remove_duplicate=remove_duplicate\n            )\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/modules/__init__.py",
    "content": "from .knowledge_curation import *\nfrom .persona_generator import *\nfrom .retriever import *\nfrom .storm_dataclass import *\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/modules/article_generation.py",
    "content": "import concurrent.futures\nimport copy\nimport logging\nfrom concurrent.futures import as_completed\nfrom typing import List, Union\n\nimport dspy\n\nfrom .callback import BaseCallbackHandler\nfrom .storm_dataclass import StormInformationTable, StormArticle\nfrom ...interface import ArticleGenerationModule, Information\nfrom ...utils import ArticleTextProcessing\n\n\nclass StormArticleGenerationModule(ArticleGenerationModule):\n    \"\"\"\n    The interface for article generation stage. Given topic, collected information from\n    knowledge curation stage, generated outline from outline generation stage,\n    generate the article section by section.\n    \"\"\"\n\n    def __init__(\n        self,\n        article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        retrieve_top_k: int = 5,\n        max_thread_num: int = 10,\n    ):\n        super().__init__()\n        self.retrieve_top_k = retrieve_top_k\n        self.article_gen_lm = article_gen_lm\n        self.max_thread_num = max_thread_num\n        self.section_gen = ConvToSection(engine=self.article_gen_lm)\n\n    def generate_section(\n        self, topic, section_name, information_table, section_outline, section_query\n    ):\n        collected_info: List[Information] = []\n        if information_table is not None:\n            collected_info = information_table.retrieve_information(\n                queries=section_query, search_top_k=self.retrieve_top_k\n            )\n        output = self.section_gen(\n            topic=topic,\n            outline=section_outline,\n            section=section_name,\n            collected_info=collected_info,\n        )\n        return {\n            \"section_name\": section_name,\n            \"section_content\": output.section,\n            \"collected_info\": collected_info,\n        }\n\n    def generate_article(\n        self,\n        topic: str,\n        information_table: StormInformationTable,\n        article_with_outline: StormArticle,\n        callback_handler: BaseCallbackHandler = 
None,\n    ) -> StormArticle:\n        \"\"\"\n        Generate article for the topic based on the information table and article outline.\n\n        Args:\n            topic (str): The topic of the article.\n            information_table (StormInformationTable): The information table containing the collected information.\n            article_with_outline (StormArticle): The article with specified outline.\n            callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger\n                custom callbacks at various stages of the article generation process. Defaults to None.\n        \"\"\"\n        information_table.prepare_table_for_retrieval()\n\n        if article_with_outline is None:\n            article_with_outline = StormArticle(topic_name=topic)\n\n        sections_to_write = article_with_outline.get_first_level_section_names()\n\n        section_output_dict_collection = []\n        if len(sections_to_write) == 0:\n            logging.error(\n                f\"No outline for {topic}. 
Will directly search with the topic.\"\n            )\n            section_output_dict = self.generate_section(\n                topic=topic,\n                section_name=topic,\n                information_table=information_table,\n                section_outline=\"\",\n                section_query=[topic],\n            )\n            section_output_dict_collection = [section_output_dict]\n        else:\n            with concurrent.futures.ThreadPoolExecutor(\n                max_workers=self.max_thread_num\n            ) as executor:\n                future_to_sec_title = {}\n                for section_title in sections_to_write:\n                    # We don't want to write a separate introduction section.\n                    if section_title.lower().strip() == \"introduction\":\n                        continue\n                    # We don't want to write a separate conclusion section.\n                    if section_title.lower().strip().startswith(\n                        \"conclusion\"\n                    ) or section_title.lower().strip().startswith(\"summary\"):\n                        continue\n                    section_query = article_with_outline.get_outline_as_list(\n                        root_section_name=section_title, add_hashtags=False\n                    )\n                    queries_with_hashtags = article_with_outline.get_outline_as_list(\n                        root_section_name=section_title, add_hashtags=True\n                    )\n                    section_outline = \"\\n\".join(queries_with_hashtags)\n                    future_to_sec_title[\n                        executor.submit(\n                            self.generate_section,\n                            topic,\n                            section_title,\n                            information_table,\n                            section_outline,\n                            section_query,\n                        )\n                    ] = section_title\n\n    
            for future in as_completed(future_to_sec_title):\n                    section_output_dict_collection.append(future.result())\n\n        article = copy.deepcopy(article_with_outline)\n        for section_output_dict in section_output_dict_collection:\n            article.update_section(\n                parent_section_name=topic,\n                current_section_content=section_output_dict[\"section_content\"],\n                current_section_info_list=section_output_dict[\"collected_info\"],\n            )\n        article.post_processing()\n        return article\n\n\nclass ConvToSection(dspy.Module):\n    \"\"\"Use the information collected from the information-seeking conversation to write a section.\"\"\"\n\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        super().__init__()\n        self.write_section = dspy.Predict(WriteSection)\n        self.engine = engine\n\n    def forward(\n        self, topic: str, outline: str, section: str, collected_info: List[Information]\n    ):\n        info = \"\"\n        for idx, storm_info in enumerate(collected_info):\n            info += f\"[{idx + 1}]\\n\" + \"\\n\".join(storm_info.snippets)\n            info += \"\\n\\n\"\n\n        info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1500)\n\n        with dspy.settings.context(lm=self.engine):\n            section = ArticleTextProcessing.clean_up_section(\n                self.write_section(topic=topic, info=info, section=section).output\n            )\n\n        return dspy.Prediction(section=section)\n\n\nclass WriteSection(dspy.Signature):\n    \"\"\"Write a Wikipedia section based on the collected information.\n\n    Here is the format of your writing:\n        1. Use \"#\" Title\" to indicate section title, \"##\" Title\" to indicate subsection title, \"###\" Title\" to indicate subsubsection title, and so on.\n        2. 
Use [1], [2], ..., [n] in line (for example, \"The capital of the United States is Washington, D.C.[1][3].\"). You DO NOT need to include a References or Sources section to list the sources at the end.\n    \"\"\"\n\n    info = dspy.InputField(prefix=\"The collected information:\\n\", format=str)\n    topic = dspy.InputField(prefix=\"The topic of the page: \", format=str)\n    section = dspy.InputField(prefix=\"The section you need to write: \", format=str)\n    output = dspy.OutputField(\n        prefix=\"Write the section with proper inline citations (Start your writing with # section title. Don't include the page title or try to write other sections):\\n\",\n        format=str,\n    )\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/modules/article_polish.py",
    "content": "import copy\nfrom typing import Union\n\nimport dspy\n\nfrom .storm_dataclass import StormArticle\nfrom ...interface import ArticlePolishingModule\nfrom ...utils import ArticleTextProcessing\n\n\nclass StormArticlePolishingModule(ArticlePolishingModule):\n    \"\"\"\n    The interface for article generation stage. Given topic, collected information from\n    knowledge curation stage, generated outline from outline generation stage.\n    \"\"\"\n\n    def __init__(\n        self,\n        article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n    ):\n        self.article_gen_lm = article_gen_lm\n        self.article_polish_lm = article_polish_lm\n\n        self.polish_page = PolishPageModule(\n            write_lead_engine=self.article_gen_lm, polish_engine=self.article_polish_lm\n        )\n\n    def polish_article(\n        self, topic: str, draft_article: StormArticle, remove_duplicate: bool = False\n    ) -> StormArticle:\n        \"\"\"\n        Polish article.\n\n        Args:\n            topic (str): The topic of the article.\n            draft_article (StormArticle): The draft article.\n            remove_duplicate (bool): Whether to use one additional LM call to remove duplicates from the article.\n        \"\"\"\n\n        article_text = draft_article.to_string()\n        polish_result = self.polish_page(\n            topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate\n        )\n        lead_section = f\"# summary\\n{polish_result.lead_section}\"\n        polished_article = \"\\n\\n\".join([lead_section, polish_result.page])\n        polished_article_dict = ArticleTextProcessing.parse_article_into_dict(\n            polished_article\n        )\n        polished_article = copy.deepcopy(draft_article)\n        polished_article.insert_or_create_section(article_dict=polished_article_dict)\n        polished_article.post_processing()\n        return 
polished_article\n\n\nclass WriteLeadSection(dspy.Signature):\n    \"\"\"Write a lead section for the given Wikipedia page with the following guidelines:\n    1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies.\n    2. The lead section should be concise and contain no more than four well-composed paragraphs.\n    3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., \"Washington, D.C., is the capital of the United States.[1][3].\") where necessary.\n    \"\"\"\n\n    topic = dspy.InputField(prefix=\"The topic of the page: \", format=str)\n    draft_page = dspy.InputField(prefix=\"The draft page:\\n\", format=str)\n    lead_section = dspy.OutputField(prefix=\"Write the lead section:\\n\", format=str)\n\n\nclass PolishPage(dspy.Signature):\n    \"\"\"You are a faithful text editor that is good at finding repeated information in the article and deleting them to make sure there is no repetition in the article. You won't delete any non-repeated part in the article. You will keep the inline citations and article structure (indicated by \"#\", \"##\", etc.) appropriately. 
Do your job for the following article.\"\"\"\n\n    draft_page = dspy.InputField(prefix=\"The draft article:\\n\", format=str)\n    page = dspy.OutputField(prefix=\"Your revised article:\\n\", format=str)\n\n\nclass PolishPageModule(dspy.Module):\n    def __init__(\n        self,\n        write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n    ):\n        super().__init__()\n        self.write_lead_engine = write_lead_engine\n        self.polish_engine = polish_engine\n        self.write_lead = dspy.Predict(WriteLeadSection)\n        self.polish_page = dspy.Predict(PolishPage)\n\n    def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True):\n        # NOTE: Change show_guidelines to false to make the generation more robust to different LM families.\n        with dspy.settings.context(lm=self.write_lead_engine, show_guidelines=False):\n            lead_section = self.write_lead(\n                topic=topic, draft_page=draft_page\n            ).lead_section\n            if \"The lead section:\" in lead_section:\n                lead_section = lead_section.split(\"The lead section:\")[1].strip()\n        if polish_whole_page:\n            # NOTE: Change show_guidelines to false to make the generation more robust to different LM families.\n            with dspy.settings.context(lm=self.polish_engine, show_guidelines=False):\n                page = self.polish_page(draft_page=draft_page).page\n        else:\n            page = draft_page\n\n        return dspy.Prediction(lead_section=lead_section, page=page)\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/modules/callback.py",
    "content": "class BaseCallbackHandler:\n    \"\"\"Base callback handler that can be used to handle callbacks from the STORM pipeline.\"\"\"\n\n    def on_identify_perspective_start(self, **kwargs):\n        \"\"\"Run when the perspective identification starts.\"\"\"\n        pass\n\n    def on_identify_perspective_end(self, perspectives: list[str], **kwargs):\n        \"\"\"Run when the perspective identification finishes.\"\"\"\n        pass\n\n    def on_information_gathering_start(self, **kwargs):\n        \"\"\"Run when the information gathering starts.\"\"\"\n        pass\n\n    def on_dialogue_turn_end(self, dlg_turn, **kwargs):\n        \"\"\"Run when a question asking and answering turn finishes.\"\"\"\n        pass\n\n    def on_information_gathering_end(self, **kwargs):\n        \"\"\"Run when the information gathering finishes.\"\"\"\n        pass\n\n    def on_information_organization_start(self, **kwargs):\n        \"\"\"Run when the information organization starts.\"\"\"\n        pass\n\n    def on_direct_outline_generation_end(self, outline: str, **kwargs):\n        \"\"\"Run when the direct outline generation finishes.\"\"\"\n        pass\n\n    def on_outline_refinement_end(self, outline: str, **kwargs):\n        \"\"\"Run when the outline refinement finishes.\"\"\"\n        pass\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/modules/knowledge_curation.py",
    "content": "import concurrent.futures\nimport logging\nimport os\nfrom concurrent.futures import as_completed\nfrom typing import Union, List, Tuple, Optional, Dict\n\nimport dspy\n\nfrom .callback import BaseCallbackHandler\nfrom .persona_generator import StormPersonaGenerator\nfrom .storm_dataclass import DialogueTurn, StormInformationTable\nfrom ...interface import KnowledgeCurationModule, Retriever, Information\nfrom ...utils import ArticleTextProcessing\n\ntry:\n    from streamlit.runtime.scriptrunner import add_script_run_ctx\n\n    streamlit_connection = True\nexcept ImportError as err:\n    streamlit_connection = False\n\nscript_dir = os.path.dirname(os.path.abspath(__file__))\n\n\nclass ConvSimulator(dspy.Module):\n    \"\"\"Simulate a conversation between a Wikipedia writer with specific persona and an expert.\"\"\"\n\n    def __init__(\n        self,\n        topic_expert_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        question_asker_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        retriever: Retriever,\n        max_search_queries_per_turn: int,\n        search_top_k: int,\n        max_turn: int,\n    ):\n        super().__init__()\n        self.wiki_writer = WikiWriter(engine=question_asker_engine)\n        self.topic_expert = TopicExpert(\n            engine=topic_expert_engine,\n            max_search_queries=max_search_queries_per_turn,\n            search_top_k=search_top_k,\n            retriever=retriever,\n        )\n        self.max_turn = max_turn\n\n    def forward(\n        self,\n        topic: str,\n        persona: str,\n        ground_truth_url: str,\n        callback_handler: BaseCallbackHandler,\n    ):\n        \"\"\"\n        topic: The topic to research.\n        persona: The persona of the Wikipedia writer.\n        ground_truth_url: The ground_truth_url will be excluded from search to avoid ground truth leakage in evaluation.\n        \"\"\"\n        dlg_history: List[DialogueTurn] = []\n        for _ in 
range(self.max_turn):\n            user_utterance = self.wiki_writer(\n                topic=topic, persona=persona, dialogue_turns=dlg_history\n            ).question\n            if user_utterance == \"\":\n                logging.error(\"Simulated Wikipedia writer utterance is empty.\")\n                break\n            if user_utterance.startswith(\"Thank you so much for your help!\"):\n                break\n            expert_output = self.topic_expert(\n                topic=topic, question=user_utterance, ground_truth_url=ground_truth_url\n            )\n            dlg_turn = DialogueTurn(\n                agent_utterance=expert_output.answer,\n                user_utterance=user_utterance,\n                search_queries=expert_output.queries,\n                search_results=expert_output.searched_results,\n            )\n            dlg_history.append(dlg_turn)\n            callback_handler.on_dialogue_turn_end(dlg_turn=dlg_turn)\n\n        return dspy.Prediction(dlg_history=dlg_history)\n\n\nclass WikiWriter(dspy.Module):\n    \"\"\"Perspective-guided question asking in conversational setup.\n\n    The asked question will be used to start a next round of information seeking.\"\"\"\n\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        super().__init__()\n        self.ask_question_with_persona = dspy.ChainOfThought(AskQuestionWithPersona)\n        self.ask_question = dspy.ChainOfThought(AskQuestion)\n        self.engine = engine\n\n    def forward(\n        self,\n        topic: str,\n        persona: str,\n        dialogue_turns: List[DialogueTurn],\n        draft_page=None,\n    ):\n        conv = []\n        for turn in dialogue_turns[:-4]:\n            conv.append(\n                f\"You: {turn.user_utterance}\\nExpert: Omit the answer here due to space limit.\"\n            )\n        for turn in dialogue_turns[-4:]:\n            conv.append(\n                f\"You: {turn.user_utterance}\\nExpert: 
{ArticleTextProcessing.remove_citations(turn.agent_utterance)}\"\n            )\n        conv = \"\\n\".join(conv)\n        conv = conv.strip() or \"N/A\"\n        conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 2500)\n\n        with dspy.settings.context(lm=self.engine):\n            if persona is not None and len(persona.strip()) > 0:\n                question = self.ask_question_with_persona(\n                    topic=topic, persona=persona, conv=conv\n                ).question\n            else:\n                question = self.ask_question(\n                    topic=topic, persona=persona, conv=conv\n                ).question\n\n        return dspy.Prediction(question=question)\n\n\nclass AskQuestion(dspy.Signature):\n    \"\"\"You are an experienced Wikipedia writer. You are chatting with an expert to get information for the topic you want to contribute. Ask good questions to get more useful information relevant to the topic.\n    When you have no more question to ask, say \"Thank you so much for your help!\" to end the conversation.\n    Please only ask a question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write.\n    \"\"\"\n\n    topic = dspy.InputField(prefix=\"Topic you want to write: \", format=str)\n    conv = dspy.InputField(prefix=\"Conversation history:\\n\", format=str)\n    question = dspy.OutputField(format=str)\n\n\nclass AskQuestionWithPersona(dspy.Signature):\n    \"\"\"You are an experienced Wikipedia writer and want to edit a specific page. Besides your identity as a Wikipedia writer, you have specific focus when researching the topic.\n    Now, you are chatting with an expert to get information. Ask good questions to get more useful information.\n    When you have no more question to ask, say \"Thank you so much for your help!\" to end the conversation.\n    Please only ask a question at a time and don't ask what you have asked before. 
Your questions should be related to the topic you want to write.\n    \"\"\"\n\n    topic = dspy.InputField(prefix=\"Topic you want to write: \", format=str)\n    persona = dspy.InputField(\n        prefix=\"Your persona besides being a Wikipedia writer: \", format=str\n    )\n    conv = dspy.InputField(prefix=\"Conversation history:\\n\", format=str)\n    question = dspy.OutputField(format=str)\n\n\nclass QuestionToQuery(dspy.Signature):\n    \"\"\"You want to answer the question using Google search. What do you type in the search box?\n    Write the queries you will use in the following format:\n    - query 1\n    - query 2\n    ...\n    - query n\"\"\"\n\n    topic = dspy.InputField(prefix=\"Topic you are discussing about: \", format=str)\n    question = dspy.InputField(prefix=\"Question you want to answer: \", format=str)\n    queries = dspy.OutputField(format=str)\n\n\nclass AnswerQuestion(dspy.Signature):\n    \"\"\"You are an expert who can use information effectively. You are chatting with a Wikipedia writer who wants to write a Wikipedia page on topic you know. You have gathered the related information and will now use the information to form a response.\n    Make your response as informative as possible, ensuring that every sentence is supported by the gathered information. If the [gathered information] is not directly related to the [topic] or [question], provide the most relevant answer based on the available information. If no appropriate answer can be formulated, respond with, “I cannot answer this question based on the available information,” and explain any limitations or gaps.\n    \"\"\"\n\n    topic = dspy.InputField(prefix=\"Topic you are discussing about:\", format=str)\n    conv = dspy.InputField(prefix=\"Question:\\n\", format=str)\n    info = dspy.InputField(prefix=\"Gathered information:\\n\", format=str)\n    answer = dspy.OutputField(\n        prefix=\"Now give your response. 
(Try to use as many different sources as possible and do not hallucinate.)\\n\",\n        format=str,\n    )\n\n\nclass TopicExpert(dspy.Module):\n    \"\"\"Answer questions using search-based retrieval and answer generation. This module conducts the following steps:\n    1. Generate queries from the question.\n    2. Search for information using the queries.\n    3. Filter out unreliable sources.\n    4. Generate an answer using the retrieved information.\n    \"\"\"\n\n    def __init__(\n        self,\n        engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        max_search_queries: int,\n        search_top_k: int,\n        retriever: Retriever,\n    ):\n        super().__init__()\n        self.generate_queries = dspy.Predict(QuestionToQuery)\n        self.retriever = retriever\n        self.answer_question = dspy.Predict(AnswerQuestion)\n        self.engine = engine\n        self.max_search_queries = max_search_queries\n        self.search_top_k = search_top_k\n\n    def forward(self, topic: str, question: str, ground_truth_url: str):\n        with dspy.settings.context(lm=self.engine, show_guidelines=False):\n            # Identify: Break down question into queries.\n            queries = self.generate_queries(topic=topic, question=question).queries\n            queries = [\n                q.replace(\"-\", \"\").strip().strip('\"').strip()\n                for q in queries.split(\"\\n\")\n            ]\n            queries = queries[: self.max_search_queries]\n            # Search\n            searched_results: List[Information] = self.retriever.retrieve(\n                list(set(queries)), exclude_urls=[ground_truth_url]\n            )\n            if len(searched_results) > 0:\n                # Evaluate: Simplify this part by directly using the top 1 snippet.\n                info = \"\"\n                for n, r in enumerate(searched_results):\n                    info += \"\\n\".join(f\"[{n + 1}]: {s}\" for s in r.snippets[:1])\n         
           info += \"\\n\\n\"\n\n                info = ArticleTextProcessing.limit_word_count_preserve_newline(\n                    info, 1000\n                )\n\n                try:\n                    answer = self.answer_question(\n                        topic=topic, conv=question, info=info\n                    ).answer\n                    answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(\n                        answer\n                    )\n                except Exception as e:\n                    logging.error(f\"Error occurs when generating answer: {e}\")\n                    answer = \"Sorry, I cannot answer this question. Please ask another question.\"\n            else:\n                # When no information is found, the expert shouldn't hallucinate.\n                answer = \"Sorry, I cannot find information for this question. Please ask another question.\"\n\n        return dspy.Prediction(\n            queries=queries, searched_results=searched_results, answer=answer\n        )\n\n\nclass StormKnowledgeCurationModule(KnowledgeCurationModule):\n    \"\"\"\n    The interface for knowledge curation stage. 
Given topic, return collected information.\n    \"\"\"\n\n    def __init__(\n        self,\n        retriever: Retriever,\n        persona_generator: Optional[StormPersonaGenerator],\n        conv_simulator_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        question_asker_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],\n        max_search_queries_per_turn: int,\n        search_top_k: int,\n        max_conv_turn: int,\n        max_thread_num: int,\n    ):\n        \"\"\"\n        Store args and finish initialization.\n        \"\"\"\n        self.retriever = retriever\n        self.persona_generator = persona_generator\n        self.conv_simulator_lm = conv_simulator_lm\n        self.search_top_k = search_top_k\n        self.max_thread_num = max_thread_num\n        self.conv_simulator = ConvSimulator(\n            topic_expert_engine=conv_simulator_lm,\n            question_asker_engine=question_asker_lm,\n            retriever=retriever,\n            max_search_queries_per_turn=max_search_queries_per_turn,\n            search_top_k=search_top_k,\n            max_turn=max_conv_turn,\n        )\n\n    def _get_considered_personas(self, topic: str, max_num_persona) -> List[str]:\n        return self.persona_generator.generate_persona(\n            topic=topic, max_num_persona=max_num_persona\n        )\n\n    def _run_conversation(\n        self,\n        conv_simulator,\n        topic,\n        ground_truth_url,\n        considered_personas,\n        callback_handler: BaseCallbackHandler,\n    ) -> List[Tuple[str, List[DialogueTurn]]]:\n        \"\"\"\n        Executes multiple conversation simulations concurrently, each with a different persona,\n        and collects their dialog histories. The dialog history of each conversation is cleaned\n        up before being stored.\n\n        Parameters:\n            conv_simulator (callable): The function to simulate conversations. 
It must accept four\n                parameters: `topic`, `ground_truth_url`, `persona`, and `callback_handler`, and return\n                an object that has a `dlg_history` attribute.\n            topic (str): The topic of conversation for the simulations.\n            ground_truth_url (str): The URL to the ground truth data related to the conversation topic.\n            considered_personas (list): A list of personas under which the conversation simulations\n                will be conducted. Each persona is passed to `conv_simulator` individually.\n            callback_handler (callable): A callback function that is passed to `conv_simulator`. It\n                should handle any callbacks or events during the simulation.\n\n        Returns:\n            list of tuples: A list where each tuple contains a persona and its corresponding cleaned\n            dialog history (`dlg_history`) from the conversation simulation.\n        \"\"\"\n\n        conversations = []\n\n        def run_conv(persona):\n            return conv_simulator(\n                topic=topic,\n                ground_truth_url=ground_truth_url,\n                persona=persona,\n                callback_handler=callback_handler,\n            )\n\n        max_workers = min(self.max_thread_num, len(considered_personas))\n\n        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n            future_to_persona = {\n                executor.submit(run_conv, persona): persona\n                for persona in considered_personas\n            }\n\n            if streamlit_connection:\n                # Ensure the logging context is correct when connecting with Streamlit frontend.\n                for t in executor._threads:\n                    add_script_run_ctx(t)\n\n            for future in as_completed(future_to_persona):\n                persona = future_to_persona[future]\n                conv = future.result()\n                conversations.append(\n           
         (persona, ArticleTextProcessing.clean_up_citation(conv).dlg_history)\n                )\n\n        return conversations\n\n    def research(\n        self,\n        topic: str,\n        ground_truth_url: str,\n        callback_handler: BaseCallbackHandler,\n        max_perspective: int = 0,\n        disable_perspective: bool = True,\n        return_conversation_log=False,\n    ) -> Union[StormInformationTable, Tuple[StormInformationTable, Dict]]:\n        \"\"\"\n        Curate information and knowledge for the given topic\n\n        Args:\n            topic: topic of interest in natural language.\n\n        Returns:\n            collected_information: collected information in InformationTable type.\n        \"\"\"\n\n        # identify personas\n        callback_handler.on_identify_perspective_start()\n        considered_personas = []\n        if disable_perspective:\n            considered_personas = [\"\"]\n        else:\n            considered_personas = self._get_considered_personas(\n                topic=topic, max_num_persona=max_perspective\n            )\n        callback_handler.on_identify_perspective_end(perspectives=considered_personas)\n\n        # run conversation\n        callback_handler.on_information_gathering_start()\n        conversations = self._run_conversation(\n            conv_simulator=self.conv_simulator,\n            topic=topic,\n            ground_truth_url=ground_truth_url,\n            considered_personas=considered_personas,\n            callback_handler=callback_handler,\n        )\n\n        information_table = StormInformationTable(conversations)\n        callback_handler.on_information_gathering_end()\n        if return_conversation_log:\n            return information_table, StormInformationTable.construct_log_dict(\n                conversations\n            )\n        return information_table\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/modules/outline_generation.py",
    "content": "from typing import Union, Optional, Tuple\n\nimport dspy\n\nfrom .callback import BaseCallbackHandler\nfrom .storm_dataclass import StormInformationTable, StormArticle\nfrom ...interface import OutlineGenerationModule\nfrom ...utils import ArticleTextProcessing\n\n\nclass StormOutlineGenerationModule(OutlineGenerationModule):\n    \"\"\"\n    The interface for outline generation stage. Given topic, collected information from knowledge\n    curation stage, generate outline for the article.\n    \"\"\"\n\n    def __init__(self, outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        super().__init__()\n        self.outline_gen_lm = outline_gen_lm\n        self.write_outline = WriteOutline(engine=self.outline_gen_lm)\n\n    def generate_outline(\n        self,\n        topic: str,\n        information_table: StormInformationTable,\n        old_outline: Optional[StormArticle] = None,\n        callback_handler: BaseCallbackHandler = None,\n        return_draft_outline=False,\n    ) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]:\n        \"\"\"\n        Generates an outline for an article based on the specified topic and the information\n        gathered during the knowledge curation stage. This method can optionally return both the\n        final article outline and a draft outline if required.\n\n        Args:\n            topic (str): The topic of the article.\n            information_table (StormInformationTable): The information table containing the collected information.\n            old_outline (Optional[StormArticle]): An optional previous version of the article outline that can\n                be used for reference or comparison. Defaults to None.\n            callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger\n                custom callbacks at various stages of the outline generation process, such as when the information\n                organization starts. 
Defaults to None.\n            return_draft_outline (bool): A flag indicating whether the method should return both the final article\n                outline and a draft version of the outline. If False, only the final article outline is returned.\n                Defaults to False.\n\n        Returns:\n            Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`,\n                this method returns either a single `StormArticle` object containing the final outline or a tuple of\n                two  `StormArticle` objects, the first containing the final outline and the second containing the\n                draft outline.\n        \"\"\"\n        if callback_handler is not None:\n            callback_handler.on_information_organization_start()\n\n        concatenated_dialogue_turns = sum(\n            [conv for (_, conv) in information_table.conversations], []\n        )\n        result = self.write_outline(\n            topic=topic,\n            dlg_history=concatenated_dialogue_turns,\n            callback_handler=callback_handler,\n        )\n        article_with_outline_only = StormArticle.from_outline_str(\n            topic=topic, outline_str=result.outline\n        )\n        article_with_draft_outline_only = StormArticle.from_outline_str(\n            topic=topic, outline_str=result.old_outline\n        )\n        if not return_draft_outline:\n            return article_with_outline_only\n        return article_with_outline_only, article_with_draft_outline_only\n\n\nclass WriteOutline(dspy.Module):\n    \"\"\"Generate the outline for the Wikipedia page.\"\"\"\n\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        super().__init__()\n        self.draft_page_outline = dspy.Predict(WritePageOutline)\n        self.write_page_outline = dspy.Predict(WritePageOutlineFromConv)\n        self.engine = engine\n\n    def forward(\n        self,\n        topic: str,\n        
dlg_history,\n        old_outline: Optional[str] = None,\n        callback_handler: BaseCallbackHandler = None,\n    ):\n        trimmed_dlg_history = []\n        for turn in dlg_history:\n            if (\n                \"topic you\" in turn.agent_utterance.lower()\n                or \"topic you\" in turn.user_utterance.lower()\n            ):\n                continue\n            trimmed_dlg_history.append(turn)\n        conv = \"\\n\".join(\n            [\n                f\"Wikipedia Writer: {turn.user_utterance}\\nExpert: {turn.agent_utterance}\"\n                for turn in trimmed_dlg_history\n            ]\n        )\n        conv = ArticleTextProcessing.remove_citations(conv)\n        conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 5000)\n\n        with dspy.settings.context(lm=self.engine):\n            if old_outline is None:\n                old_outline = ArticleTextProcessing.clean_up_outline(\n                    self.draft_page_outline(topic=topic).outline\n                )\n                if callback_handler:\n                    callback_handler.on_direct_outline_generation_end(\n                        outline=old_outline\n                    )\n            outline = ArticleTextProcessing.clean_up_outline(\n                self.write_page_outline(\n                    topic=topic, old_outline=old_outline, conv=conv\n                ).outline\n            )\n            if callback_handler:\n                callback_handler.on_outline_refinement_end(outline=outline)\n\n        return dspy.Prediction(outline=outline, old_outline=old_outline)\n\n\nclass WritePageOutline(dspy.Signature):\n    \"\"\"Write an outline for a Wikipedia page.\n    Here is the format of your writing:\n    1. Use \"#\" Title\" to indicate section title, \"##\" Title\" to indicate subsection title, \"###\" Title\" to indicate subsubsection title, and so on.\n    2. Do not include other information.\n    3. 
Do not include topic name itself in the outline.\n    \"\"\"\n\n    topic = dspy.InputField(prefix=\"The topic you want to write: \", format=str)\n    outline = dspy.OutputField(prefix=\"Write the Wikipedia page outline:\\n\", format=str)\n\n\nclass NaiveOutlineGen(dspy.Module):\n    \"\"\"Generate the outline with LLM's parametric knowledge directly.\"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.write_outline = dspy.Predict(WritePageOutline)\n\n    def forward(self, topic: str):\n        outline = self.write_outline(topic=topic).outline\n\n        return dspy.Prediction(outline=outline)\n\n\nclass WritePageOutlineFromConv(dspy.Signature):\n    \"\"\"Improve an outline for a Wikipedia page. You already have a draft outline that covers the general information. Now you want to improve it based on the information learned from an information-seeking conversation to make it more informative.\n    Here is the format of your writing:\n    1. Use \"#\" Title\" to indicate section title, \"##\" Title\" to indicate subsection title, \"###\" Title\" to indicate subsubsection title, and so on.\n    2. Do not include other information.\n    3. Do not include topic name itself in the outline.\n    \"\"\"\n\n    topic = dspy.InputField(prefix=\"The topic you want to write: \", format=str)\n    conv = dspy.InputField(prefix=\"Conversation history:\\n\", format=str)\n    old_outline = dspy.OutputField(prefix=\"Current outline:\\n\", format=str)\n    outline = dspy.OutputField(\n        prefix='Write the Wikipedia page outline (Use \"#\" Title\" to indicate section title, \"##\" Title\" to indicate subsection title, ...):\\n',\n        format=str,\n    )\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/modules/persona_generator.py",
    "content": "import logging\nimport re\nfrom typing import Union, List\n\nimport dspy\nimport requests\nfrom bs4 import BeautifulSoup\n\n\ndef get_wiki_page_title_and_toc(url):\n    \"\"\"Get the main title and table of contents from an url of a Wikipedia page.\"\"\"\n\n    response = requests.get(url)\n    soup = BeautifulSoup(response.content, \"html.parser\")\n\n    # Get the main title from the first h1 tag\n    main_title = soup.find(\"h1\").text.replace(\"[edit]\", \"\").strip().replace(\"\\xa0\", \" \")\n\n    toc = \"\"\n    levels = []\n    excluded_sections = {\n        \"Contents\",\n        \"See also\",\n        \"Notes\",\n        \"References\",\n        \"External links\",\n    }\n\n    # Start processing from h2 to exclude the main title from TOC\n    for header in soup.find_all([\"h2\", \"h3\", \"h4\", \"h5\", \"h6\"]):\n        level = int(\n            header.name[1]\n        )  # Extract the numeric part of the header tag (e.g., '2' from 'h2')\n        section_title = header.text.replace(\"[edit]\", \"\").strip().replace(\"\\xa0\", \" \")\n        if section_title in excluded_sections:\n            continue\n\n        while levels and level <= levels[-1]:\n            levels.pop()\n        levels.append(level)\n\n        indentation = \"  \" * (len(levels) - 1)\n        toc += f\"{indentation}{section_title}\\n\"\n\n    return main_title, toc.strip()\n\n\nclass FindRelatedTopic(dspy.Signature):\n    \"\"\"I'm writing a Wikipedia page for a topic mentioned below. Please identify and recommend some Wikipedia pages on closely related subjects. 
I'm looking for examples that provide insights into interesting aspects commonly associated with this topic, or examples that help me understand the typical content and structure included in Wikipedia pages for similar topics.\n    Please list the urls in separate lines.\"\"\"\n\n    topic = dspy.InputField(prefix=\"Topic of interest:\", format=str)\n    related_topics = dspy.OutputField(format=str)\n\n\nclass GenPersona(dspy.Signature):\n    \"\"\"You need to select a group of Wikipedia editors who will work together to create a comprehensive article on the topic. Each of them represents a different perspective, role, or affiliation related to this topic. You can use other Wikipedia pages of related topics for inspiration. For each editor, add a description of what they will focus on.\n    Give your answer in the following format: 1. short summary of editor 1: description\\n2. short summary of editor 2: description\\n...\n    \"\"\"\n\n    topic = dspy.InputField(prefix=\"Topic of interest:\", format=str)\n    examples = dspy.InputField(\n        prefix=\"Wiki page outlines of related topics for inspiration:\\n\", format=str\n    )\n    personas = dspy.OutputField(format=str)\n\n\nclass CreateWriterWithPersona(dspy.Module):\n    \"\"\"Discover different perspectives of researching the topic by reading Wikipedia pages of related topics.\"\"\"\n\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        super().__init__()\n        self.find_related_topic = dspy.ChainOfThought(FindRelatedTopic)\n        self.gen_persona = dspy.ChainOfThought(GenPersona)\n        self.engine = engine\n\n    def forward(self, topic: str, draft=None):\n        with dspy.settings.context(lm=self.engine):\n            # Get section names from wiki pages of relevant topics for inspiration.\n            related_topics = self.find_related_topic(topic=topic).related_topics\n            urls = []\n            for s in related_topics.split(\"\\n\"):\n                if 
\"http\" in s:\n                    urls.append(s[s.find(\"http\") :])\n            examples = []\n            for url in urls:\n                try:\n                    title, toc = get_wiki_page_title_and_toc(url)\n                    examples.append(f\"Title: {title}\\nTable of Contents: {toc}\")\n                except Exception as e:\n                    logging.error(f\"Error occurs when processing {url}: {e}\")\n                    continue\n            if len(examples) == 0:\n                examples.append(\"N/A\")\n            gen_persona_output = self.gen_persona(\n                topic=topic, examples=\"\\n----------\\n\".join(examples)\n            ).personas\n\n        personas = []\n        for s in gen_persona_output.split(\"\\n\"):\n            match = re.search(r\"\\d+\\.\\s*(.*)\", s)\n            if match:\n                personas.append(match.group(1))\n\n        sorted_personas = personas\n\n        return dspy.Prediction(\n            personas=personas,\n            raw_personas_output=sorted_personas,\n            related_topics=related_topics,\n        )\n\n\nclass StormPersonaGenerator:\n    \"\"\"\n    A generator class for creating personas based on a given topic.\n\n    This class uses an underlying engine to generate personas tailored to the specified topic.\n    The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas,\n    including a default 'Basic fact writer' persona.\n\n    Attributes:\n        create_writer_with_persona (CreateWriterWithPersona): An instance responsible for\n            generating personas based on the provided engine and topic.\n\n    Args:\n        engine (Union[dspy.dsp.LM, dspy.dsp.HFModel]): The underlying engine used for generating\n            personas. 
It must be an instance of either `dspy.dsp.LM` or `dspy.dsp.HFModel`.\n    \"\"\"\n\n    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):\n        self.create_writer_with_persona = CreateWriterWithPersona(engine=engine)\n\n    def generate_persona(self, topic: str, max_num_persona: int = 3) -> List[str]:\n        \"\"\"\n        Generates a list of personas based on the provided topic, up to a maximum number specified.\n\n        This method first creates personas using the underlying `create_writer_with_persona` instance\n        and then prepends a default 'Basic fact writer' persona to the list before returning it.\n        The number of personas returned is limited to `max_num_persona`, excluding the default persona.\n\n        Args:\n            topic (str): The topic for which personas are to be generated.\n            max_num_persona (int): The maximum number of personas to generate, excluding the\n                default 'Basic fact writer' persona.\n\n        Returns:\n            List[str]: A list of persona descriptions, including the default 'Basic fact writer' persona\n                and up to `max_num_persona` additional personas generated based on the topic.\n        \"\"\"\n        personas = self.create_writer_with_persona(topic=topic)\n        default_persona = \"Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic.\"\n        considered_personas = [default_persona] + personas.personas[:max_num_persona]\n        return considered_personas\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/modules/retriever.py",
    "content": "from typing import Union, List\nfrom urllib.parse import urlparse\n\nimport dspy\n\nfrom ...interface import Retriever, Information\nfrom ...utils import ArticleTextProcessing\n\n# Internet source restrictions according to Wikipedia standard:\n# https://en.wikipedia.org/wiki/Wikipedia:Reliable_sources/Perennial_sources\nGENERALLY_UNRELIABLE = {\n    \"112_Ukraine\",\n    \"Ad_Fontes_Media\",\n    \"AlterNet\",\n    \"Amazon\",\n    \"Anadolu_Agency_(controversial_topics)\",\n    \"Ancestry.com\",\n    \"Answers.com\",\n    \"Antiwar.com\",\n    \"Anti-Defamation_League\",\n    \"arXiv\",\n    \"Atlas_Obscura_places\",\n    \"Bild\",\n    \"Blaze_Media\",\n    \"Blogger\",\n    \"BroadwayWorld\",\n    \"California_Globe\",\n    \"The_Canary\",\n    \"CelebrityNetWorth\",\n    \"CESNUR\",\n    \"ChatGPT\",\n    \"CNET_(November_2022\\u2013present)\",\n    \"CoinDesk\",\n    \"Consortium_News\",\n    \"CounterPunch\",\n    \"Correo_del_Orinoco\",\n    \"Cracked.com\",\n    \"Daily_Express\",\n    \"Daily_Kos\",\n    \"Daily_Sabah\",\n    \"The_Daily_Wire\",\n    \"Discogs\",\n    \"Distractify\",\n    \"The_Electronic_Intifada\",\n    \"Encyclopaedia_Metallum\",\n    \"Ethnicity_of_Celebs\",\n    \"Facebook\",\n    \"FamilySearch\",\n    \"Fandom\",\n    \"The_Federalist\",\n    \"Find_a_Grave\",\n    \"Findmypast\",\n    \"Flags_of_the_World\",\n    \"Flickr\",\n    \"Forbes.com_contributors\",\n    \"Fox_News_(politics_and_science)\",\n    \"Fox_News_(talk_shows)\",\n    \"Gawker\",\n    \"GB_News\",\n    \"Geni.com\",\n    \"gnis-class\",\n    \"gns-class\",\n    \"GlobalSecurity.org\",\n    \"Goodreads\",\n    \"Guido_Fawkes\",\n    \"Heat_Street\",\n    \"History\",\n    \"HuffPost_contributors\",\n    \"IMDb\",\n    \"Independent_Media_Center\",\n    \"Inquisitr\",\n    \"International_Business_Times\",\n    \"Investopedia\",\n    \"Jewish_Virtual_Library\",\n    \"Joshua_Project\",\n    \"Know_Your_Meme\",\n    \"Land_Transport_Guru\",\n    
\"LinkedIn\",\n    \"LiveJournal\",\n    \"Marquis_Who's_Who\",\n    \"Mashable_sponsored_content\",\n    \"MEAWW\",\n    \"Media_Bias/Fact_Check\",\n    \"Media_Research_Center\",\n    \"Medium\",\n    \"metal-experience\",\n    \"Metro\",\n    \"The_New_American\",\n    \"New_York_Post\",\n    \"NGO_Monitor\",\n    \"The_Onion\",\n    \"Our_Campaigns\",\n    \"PanAm_Post\",\n    \"Patheos\",\n    \"An_Phoblacht\",\n    \"The_Post_Millennial\",\n    \"arXiv\",\n    \"bioRxiv\",\n    \"medRxiv\",\n    \"PeerJ Preprints\",\n    \"Preprints.org\",\n    \"SSRN\",\n    \"PR_Newswire\",\n    \"Quadrant\",\n    \"Quillette\",\n    \"Quora\",\n    \"Raw_Story\",\n    \"Reddit\",\n    \"RedState\",\n    \"ResearchGate\",\n    \"Rolling_Stone_(politics_and_society,_2011\\u2013present)\",\n    \"Rolling_Stone_(Culture_Council)\",\n    \"Scribd\",\n    \"Scriptural_texts\",\n    \"Simple_Flying\",\n    \"Sixth_Tone_(politics)\",\n    \"The_Skwawkbox\",\n    \"SourceWatch\",\n    \"Spirit_of_Metal\",\n    \"Sportskeeda\",\n    \"Stack_Exchange\",\n    \"Stack_Overflow\",\n    \"MathOverflow\",\n    \"Ask_Ubuntu\",\n    \"starsunfolded.com\",\n    \"Statista\",\n    \"TASS\",\n    \"The_Truth_About_Guns\",\n    \"TV.com\",\n    \"TV_Tropes\",\n    \"Twitter\",\n    \"X.com\",\n    \"Urban_Dictionary\",\n    \"Venezuelanalysis\",\n    \"VGChartz\",\n    \"VoC\",\n    \"Washington_Free_Beacon\",\n    \"Weather2Travel\",\n    \"The_Western_Journal\",\n    \"We_Got_This_Covered\",\n    \"WhatCulture\",\n    \"Who's_Who_(UK)\",\n    \"WhoSampled\",\n    \"Wikidata\",\n    \"WikiLeaks\",\n    \"Wikinews\",\n    \"Wikipedia\",\n    \"WordPress.com\",\n    \"Worldometer\",\n    \"YouTube\",\n    \"ZDNet\",\n}\nDEPRECATED = {\n    \"Al_Mayadeen\",\n    \"ANNA_News\",\n    \"Baidu_Baike\",\n    \"China_Global_Television_Network\",\n    \"The_Cradle\",\n    \"Crunchbase\",\n    \"The_Daily_Caller\",\n    \"Daily_Mail\",\n    \"Daily_Star\",\n    \"The_Epoch_Times\",\n    
\"FrontPage_Magazine\",\n    \"The_Gateway_Pundit\",\n    \"Global_Times\",\n    \"The_Grayzone\",\n    \"HispanTV\",\n    \"Jihad_Watch\",\n    \"Last.fm\",\n    \"LifeSiteNews\",\n    \"The_Mail_on_Sunday\",\n    \"MintPress_News\",\n    \"National_Enquirer\",\n    \"New_Eastern_Outlook\",\n    \"News_Break\",\n    \"NewsBlaze\",\n    \"News_of_the_World\",\n    \"Newsmax\",\n    \"NNDB\",\n    \"Occupy_Democrats\",\n    \"Office_of_Cuba_Broadcasting\",\n    \"One_America_News_Network\",\n    \"Peerage_websites\",\n    \"Press_TV\",\n    \"Project_Veritas\",\n    \"Rate_Your_Music\",\n    \"Republic_TV\",\n    \"Royal_Central\",\n    \"RT\",\n    \"Sputnik\",\n    \"The_Sun\",\n    \"Taki's_Magazine\",\n    \"Tasnim_News_Agency\",\n    \"Telesur\",\n    \"The_Unz_Review\",\n    \"VDARE\",\n    \"Voltaire_Network\",\n    \"WorldNetDaily\",\n    \"Zero_Hedge\",\n}\nBLACKLISTED = {\n    \"Advameg\",\n    \"bestgore.com\",\n    \"Breitbart_News\",\n    \"Centre_for_Research_on_Globalization\",\n    \"Examiner.com\",\n    \"Famous_Birthdays\",\n    \"Healthline\",\n    \"InfoWars\",\n    \"Lenta.ru\",\n    \"LiveLeak\",\n    \"Lulu.com\",\n    \"MyLife\",\n    \"Natural_News\",\n    \"OpIndia\",\n    \"The_Points_Guy\",\n    \"The_Points_Guy_(sponsored_content)\",\n    \"Swarajya\",\n    \"Veterans_Today\",\n    \"ZoomInfo\",\n}\n\n\ndef is_valid_wikipedia_source(url):\n    parsed_url = urlparse(url)\n    # Check if the URL is from a reliable domain\n    combined_set = GENERALLY_UNRELIABLE | DEPRECATED | BLACKLISTED\n    for domain in combined_set:\n        if domain in parsed_url.netloc:\n            return False\n\n    return True\n"
  },
  {
    "path": "knowledge_storm/storm_wiki/modules/storm_dataclass.py",
    "content": "import copy\nimport re\nfrom collections import OrderedDict\nfrom typing import Union, Optional, Any, List, Tuple, Dict\n\nimport numpy as np\nfrom sentence_transformers import SentenceTransformer\nfrom sklearn.metrics.pairwise import cosine_similarity\n\nfrom ...interface import Information, InformationTable, Article, ArticleSectionNode\nfrom ...utils import ArticleTextProcessing, FileIOHelper\n\n\nclass DialogueTurn:\n    def __init__(\n        self,\n        agent_utterance: str = None,\n        user_utterance: str = None,\n        search_queries: Optional[List[str]] = None,\n        search_results: Optional[List[Union[Information, Dict]]] = None,\n    ):\n        self.agent_utterance = agent_utterance\n        self.user_utterance = user_utterance\n        self.search_queries = search_queries\n        self.search_results = search_results\n\n        if self.search_results:\n            for idx in range(len(self.search_results)):\n                if type(self.search_results[idx]) == dict:\n                    self.search_results[idx] = Information.from_dict(\n                        self.search_results[idx]\n                    )\n\n    def log(self):\n        \"\"\"\n        Returns a json object that contains all information inside `self`\n        \"\"\"\n        return OrderedDict(\n            {\n                \"agent_utterance\": self.agent_utterance,\n                \"user_utterance\": self.user_utterance,\n                \"search_queries\": self.search_queries,\n                \"search_results\": [data.to_dict() for data in self.search_results],\n            }\n        )\n\n\nclass StormInformationTable(InformationTable):\n    \"\"\"\n    The InformationTable class serves as data class to store the information\n    collected during KnowledgeCuration stage.\n\n    Create subclass to incorporate more information as needed. 
For example,\n    in STORM paper https://arxiv.org/pdf/2402.14207.pdf, additional information\n    would be perspective guided dialogue history.\n    \"\"\"\n\n    def __init__(self, conversations: List[Tuple[str, List[DialogueTurn]]]):\n        super().__init__()\n        self.conversations = conversations\n        self.url_to_info: Dict[str, Information] = (\n            StormInformationTable.construct_url_to_info(self.conversations)\n        )\n\n    @staticmethod\n    def construct_url_to_info(\n        conversations: List[Tuple[str, List[DialogueTurn]]]\n    ) -> Dict[str, Information]:\n        url_to_info = {}\n\n        for persona, conv in conversations:\n            for turn in conv:\n                for storm_info in turn.search_results:\n                    if storm_info.url in url_to_info:\n                        url_to_info[storm_info.url].snippets.extend(storm_info.snippets)\n                    else:\n                        url_to_info[storm_info.url] = storm_info\n        for url in url_to_info:\n            url_to_info[url].snippets = list(set(url_to_info[url].snippets))\n        return url_to_info\n\n    @staticmethod\n    def construct_log_dict(\n        conversations: List[Tuple[str, List[DialogueTurn]]]\n    ) -> List[Dict[str, Union[str, Any]]]:\n        conversation_log = []\n        for persona, conv in conversations:\n            conversation_log.append(\n                {\"perspective\": persona, \"dlg_turns\": [turn.log() for turn in conv]}\n            )\n        return conversation_log\n\n    def dump_url_to_info(self, path):\n        url_to_info = copy.deepcopy(self.url_to_info)\n        for url in url_to_info:\n            url_to_info[url] = url_to_info[url].to_dict()\n        FileIOHelper.dump_json(url_to_info, path)\n\n    @classmethod\n    def from_conversation_log_file(cls, path):\n        conversation_log_data = FileIOHelper.load_json(path)\n        conversations = []\n        for item in conversation_log_data:\n            
dialogue_turns = [DialogueTurn(**turn) for turn in item[\"dlg_turns\"]]\n            persona = item[\"perspective\"]\n            conversations.append((persona, dialogue_turns))\n        return cls(conversations)\n\n    def prepare_table_for_retrieval(self):\n        self.encoder = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\")\n        self.collected_urls = []\n        self.collected_snippets = []\n        for url, information in self.url_to_info.items():\n            for snippet in information.snippets:\n                self.collected_urls.append(url)\n                self.collected_snippets.append(snippet)\n        self.encoded_snippets = self.encoder.encode(self.collected_snippets)\n\n    def retrieve_information(\n        self, queries: Union[List[str], str], search_top_k\n    ) -> List[Information]:\n        selected_urls = []\n        selected_snippets = []\n        if type(queries) is str:\n            queries = [queries]\n        for query in queries:\n            encoded_query = self.encoder.encode(query)\n            sim = cosine_similarity([encoded_query], self.encoded_snippets)[0]\n            sorted_indices = np.argsort(sim)\n            for i in sorted_indices[-search_top_k:][::-1]:\n                selected_urls.append(self.collected_urls[i])\n                selected_snippets.append(self.collected_snippets[i])\n\n        url_to_snippets = {}\n        for url, snippet in zip(selected_urls, selected_snippets):\n            if url not in url_to_snippets:\n                url_to_snippets[url] = set()\n            url_to_snippets[url].add(snippet)\n\n        selected_url_to_info = {}\n        for url in url_to_snippets:\n            selected_url_to_info[url] = copy.deepcopy(self.url_to_info[url])\n            selected_url_to_info[url].snippets = list(url_to_snippets[url])\n\n        return list(selected_url_to_info.values())\n\n\nclass StormArticle(Article):\n    def __init__(self, topic_name):\n        super().__init__(topic_name=topic_name)\n        
self.reference = {\"url_to_unified_index\": {}, \"url_to_info\": {}}\n\n    def find_section(\n        self, node: ArticleSectionNode, name: str\n    ) -> Optional[ArticleSectionNode]:\n        \"\"\"\n        Return the node of the section given the section name.\n\n        Args:\n            node: the node as the root to find.\n            name: the name of node as section name\n\n        Return:\n            reference of the node or None if section name has no match\n        \"\"\"\n        if node.section_name == name:\n            return node\n        for child in node.children:\n            result = self.find_section(child, name)\n            if result:\n                return result\n        return None\n\n    def _merge_new_info_to_references(\n        self, new_info_list: List[Information], index_to_keep=None\n    ) -> Dict[int, int]:\n        \"\"\"\n        Merges new storm information into existing references and updates the citation index mapping.\n\n        Args:\n        new_info_list (List[Information]): A list of dictionaries representing new storm information.\n        index_to_keep (List[int]): A list of index of the new_info_list to keep. 
If none, keep all.\n\n        Returns:\n        Dict[int, int]: A dictionary mapping the index of each storm information piece in the input list\n                        to its unified citation index in the references.\n        \"\"\"\n        citation_idx_mapping = {}\n        for idx, storm_info in enumerate(new_info_list):\n            if index_to_keep is not None and idx not in index_to_keep:\n                continue\n            url = storm_info.url\n            if url not in self.reference[\"url_to_unified_index\"]:\n                self.reference[\"url_to_unified_index\"][url] = (\n                    len(self.reference[\"url_to_unified_index\"]) + 1\n                )  # The citation index starts from 1.\n                self.reference[\"url_to_info\"][url] = storm_info\n            else:\n                existing_snippets = self.reference[\"url_to_info\"][url].snippets\n                existing_snippets.extend(storm_info.snippets)\n                self.reference[\"url_to_info\"][url].snippets = list(\n                    set(existing_snippets)\n                )\n            citation_idx_mapping[idx + 1] = self.reference[\"url_to_unified_index\"][\n                url\n            ]  # The citation index starts from 1.\n        return citation_idx_mapping\n\n    def insert_or_create_section(\n        self,\n        article_dict: Dict[str, Dict],\n        parent_section_name: str = None,\n        trim_children=False,\n    ):\n        parent_node = (\n            self.root\n            if parent_section_name is None\n            else self.find_section(self.root, parent_section_name)\n        )\n\n        if trim_children:\n            section_names = set(article_dict.keys())\n            for child in parent_node.children[:]:\n                if child.section_name not in section_names:\n                    parent_node.remove_child(child)\n\n        for section_name, content_dict in article_dict.items():\n            current_section_node = 
self.find_section(parent_node, section_name)\n            if current_section_node is None:\n                current_section_node = ArticleSectionNode(\n                    section_name=section_name, content=content_dict[\"content\"].strip()\n                )\n                insert_to_front = (\n                    parent_node.section_name == self.root.section_name\n                    and current_section_node.section_name == \"summary\"\n                )\n                parent_node.add_child(\n                    current_section_node, insert_to_front=insert_to_front\n                )\n            else:\n                current_section_node.content = content_dict[\"content\"].strip()\n\n            self.insert_or_create_section(\n                article_dict=content_dict[\"subsections\"],\n                parent_section_name=section_name,\n                trim_children=True,\n            )\n\n    def update_section(\n        self,\n        current_section_content: str,\n        current_section_info_list: List[Information],\n        parent_section_name: Optional[str] = None,\n    ) -> Optional[ArticleSectionNode]:\n        \"\"\"\n        Add new section to the article.\n\n        Args:\n            current_section_name: new section heading name in string format.\n            parent_section_name: under which parent section to add the new one. Default to root.\n            current_section_content: optional section content.\n\n        Returns:\n            the ArticleSectionNode for current section if successfully created / updated. 
Otherwise none.\n        \"\"\"\n\n        if current_section_info_list is not None:\n            references = set(\n                [int(x) for x in re.findall(r\"\\[(\\d+)\\]\", current_section_content)]\n            )\n            # for any reference number greater than max number of references, delete the reference\n            if len(references) > 0:\n                max_ref_num = max(references)\n                if max_ref_num > len(current_section_info_list):\n                    for i in range(len(current_section_info_list), max_ref_num + 1):\n                        current_section_content = current_section_content.replace(\n                            f\"[{i}]\", \"\"\n                        )\n                        if i in references:\n                            references.remove(i)\n            # for any reference that is not used, trim it from current_section_info_list\n            index_to_keep = [i - 1 for i in references]\n            citation_mapping = self._merge_new_info_to_references(\n                current_section_info_list, index_to_keep\n            )\n            current_section_content = ArticleTextProcessing.update_citation_index(\n                current_section_content, citation_mapping\n            )\n\n        if parent_section_name is None:\n            parent_section_name = self.root.section_name\n        article_dict = ArticleTextProcessing.parse_article_into_dict(\n            current_section_content\n        )\n        self.insert_or_create_section(\n            article_dict=article_dict,\n            parent_section_name=parent_section_name,\n            trim_children=False,\n        )\n\n    def get_outline_as_list(\n        self,\n        root_section_name: Optional[str] = None,\n        add_hashtags: bool = False,\n        include_root: bool = True,\n    ) -> List[str]:\n        \"\"\"\n        Get outline of the article as a list.\n\n        Args:\n            section_name: get all section names in pre-order travel 
ordering in the subtree of section_name.\n                          For example:\n                            #root\n                            ##section1\n                            ###section1.1\n                            ###section1.2\n                            ##section2\n                          article.get_outline_as_list(\"section1\") returns [section1, section1.1, section1.2]\n\n        Returns:\n            list of section and subsection names.\n        \"\"\"\n        if root_section_name is None:\n            section_node = self.root\n        else:\n            section_node = self.find_section(self.root, root_section_name)\n            include_root = include_root or section_node != self.root.section_name\n        if section_node is None:\n            return []\n        result = []\n\n        def preorder_traverse(node, level):\n            prefix = (\n                \"#\" * level if add_hashtags else \"\"\n            )  # Adjust level if excluding root\n            result.append(\n                f\"{prefix} {node.section_name}\".strip()\n                if add_hashtags\n                else node.section_name\n            )\n            for child in node.children:\n                preorder_traverse(child, level + 1)\n\n        # Adjust the initial level based on whether root is included and hashtags are added\n        if include_root:\n            preorder_traverse(section_node, level=1)\n        else:\n            for child in section_node.children:\n                preorder_traverse(child, level=1)\n        return result\n\n    def to_string(self) -> str:\n        \"\"\"\n        Get the full article text, with each section heading prefixed by hash marks and followed by its content.\n\n        Returns:\n            the article as a single string; the root node itself is not included.\n        \"\"\"\n        result = []\n\n        def preorder_traverse(node, level):\n            prefix = \"#\" * level\n            result.append(f\"{prefix} {node.section_name}\".strip())\n            result.append(node.content)\n            for 
child in node.children:\n                preorder_traverse(child, level + 1)\n\n        # Skip the root node so that top-level sections start at level 1\n        for child in self.root.children:\n            preorder_traverse(child, level=1)\n        result = [i.strip() for i in result if i is not None and i.strip()]\n        return \"\\n\\n\".join(result)\n\n    def reorder_reference_index(self):\n        # pre-order traversal to get the order in which references appear in the article\n        ref_indices = []\n\n        def pre_order_find_index(node):\n            if node is not None:\n                if node.content is not None and node.content:\n                    ref_indices.extend(\n                        ArticleTextProcessing.parse_citation_indices(node.content)\n                    )\n                for child in node.children:\n                    pre_order_find_index(child)\n\n        pre_order_find_index(self.root)\n        # construct index mapping\n        ref_index_mapping = {}\n        for ref_index in ref_indices:\n            if ref_index not in ref_index_mapping:\n                ref_index_mapping[ref_index] = len(ref_index_mapping) + 1\n\n        # update content\n        def pre_order_update_index(node):\n            if node is not None:\n                if node.content is not None and node.content:\n                    node.content = ArticleTextProcessing.update_citation_index(\n                        node.content, ref_index_mapping\n                    )\n                for child in node.children:\n                    pre_order_update_index(child)\n\n        pre_order_update_index(self.root)\n        # update reference\n        for url in list(self.reference[\"url_to_unified_index\"]):\n            pre_index = self.reference[\"url_to_unified_index\"][url]\n            if pre_index not in ref_index_mapping:\n                del self.reference[\"url_to_unified_index\"][url]\n            else:\n                new_index = 
ref_index_mapping[pre_index]\n                self.reference[\"url_to_unified_index\"][url] = new_index\n\n    def get_outline_tree(self):\n        def build_tree(node) -> Dict[str, Dict]:\n            tree = {}\n            for child in node.children:\n                tree[child.section_name] = build_tree(child)\n            return tree\n\n        return build_tree(self.root)\n\n    def get_first_level_section_names(self) -> List[str]:\n        \"\"\"\n        Get first level section names\n        \"\"\"\n        return [i.section_name for i in self.root.children]\n\n    @classmethod\n    def from_outline_file(cls, topic: str, file_path: str):\n        \"\"\"\n        Create StormArticle class instance from outline file.\n        \"\"\"\n        outline_str = FileIOHelper.load_str(file_path)\n        return cls.from_outline_str(topic=topic, outline_str=outline_str)\n\n    @classmethod\n    def from_outline_str(cls, topic: str, outline_str: str):\n        \"\"\"\n        Create StormArticle class instance from an outline-only string.\n        \"\"\"\n        lines = []\n        try:\n            lines = outline_str.split(\"\\n\")\n            lines = [line.strip() for line in lines if line.strip()]\n        except Exception:\n            # outline_str is not a usable string (e.g. None); fall back to an empty outline\n            pass\n\n        instance = cls(topic)\n        if lines:\n            # If the first header repeats the topic, drop it and shift every level up by one\n            adjust_level = lines[0].startswith(\"#\") and lines[0].replace(\n                \"#\", \"\"\n            ).strip().lower() == topic.lower().replace(\"_\", \" \")\n            if adjust_level:\n                lines = lines[1:]\n            node_stack = [(0, instance.root)]  # Stack to keep track of (level, node)\n\n            for line in lines:\n                level = line.count(\"#\") - adjust_level\n                section_name = line.replace(\"#\", \"\").strip()\n\n                if section_name == 
topic:\n                    continue\n\n                new_node = ArticleSectionNode(section_name)\n\n                while node_stack and level <= node_stack[-1][0]:\n                    node_stack.pop()\n\n                node_stack[-1][1].add_child(new_node)\n                node_stack.append((level, new_node))\n        return instance\n\n    def dump_outline_to_file(self, file_path):\n        outline = self.get_outline_as_list(add_hashtags=True, include_root=False)\n        FileIOHelper.write_str(\"\\n\".join(outline), file_path)\n\n    def dump_reference_to_file(self, file_path):\n        reference = copy.deepcopy(self.reference)\n        for url in reference[\"url_to_info\"]:\n            reference[\"url_to_info\"][url] = reference[\"url_to_info\"][url].to_dict()\n        FileIOHelper.dump_json(reference, file_path)\n\n    def dump_article_as_plain_text(self, file_path):\n        text = self.to_string()\n        FileIOHelper.write_str(text, file_path)\n\n    @classmethod\n    def from_string(cls, topic_name: str, article_text: str, references: dict):\n        article_dict = ArticleTextProcessing.parse_article_into_dict(article_text)\n        article = cls(topic_name=topic_name)\n        article.insert_or_create_section(article_dict=article_dict)\n        for url in list(references[\"url_to_info\"]):\n            references[\"url_to_info\"][url] = Information.from_dict(\n                references[\"url_to_info\"][url]\n            )\n        article.reference = references\n        return article\n\n    def post_processing(self):\n        self.prune_empty_nodes()\n        self.reorder_reference_index()\n"
  },
  {
    "path": "knowledge_storm/utils.py",
    "content": "import concurrent.futures\nimport dspy\nimport httpx\nimport json\nimport logging\nimport os\nimport pickle\nimport re\nimport regex\nimport sys\nimport toml\nfrom typing import List, Dict\nfrom tqdm import tqdm\n\nfrom langchain_text_splitters import RecursiveCharacterTextSplitter\nfrom trafilatura import extract\n\nfrom .lm import LitellmModel\n\nlogging.getLogger(\"httpx\").setLevel(logging.WARNING)  # Disable INFO logging for httpx.\n\n\ndef truncate_filename(filename, max_length=125):\n    \"\"\"Truncate filename to max_length to ensure the filename won't exceed the file system limit.\n\n    Args:\n        filename: str\n        max_length: int, default to 125 (usual path length limit is 255 chars)\n    \"\"\"\n\n    if len(filename) > max_length:\n        truncated_filename = filename[:max_length]\n        logging.warning(\n            f\"Filename is too long. Filename is truncated to {truncated_filename}.\"\n        )\n        return truncated_filename\n\n    return filename\n\n\ndef load_api_key(toml_file_path):\n    try:\n        with open(toml_file_path, \"r\") as file:\n            data = toml.load(file)\n    except FileNotFoundError:\n        print(f\"File not found: {toml_file_path}\", file=sys.stderr)\n        return\n    except toml.TomlDecodeError:\n        print(f\"Error decoding TOML file: {toml_file_path}\", file=sys.stderr)\n        return\n    # Set environment variables\n    for key, value in data.items():\n        os.environ[key] = str(value)\n\n\ndef makeStringRed(message):\n    return f\"\\033[91m {message}\\033[00m\"\n\n\nclass QdrantVectorStoreManager:\n    \"\"\"\n    Helper class for managing the Qdrant vector store, can be used with `VectorRM` in rm.py.\n\n    Before you initialize `VectorRM`, call `create_or_update_vector_store` to create or update the vector store.\n    Once you have the vector store, you can initialize `VectorRM` with the vector store path or the Qdrant server URL.\n    \"\"\"\n\n    @staticmethod\n  
  def _check_create_collection(\n        client: \"QdrantClient\", collection_name: str, model: \"HuggingFaceEmbeddings\"\n    ):\n        from langchain_qdrant import Qdrant\n        from qdrant_client import models\n\n        \"\"\"Check if the Qdrant collection exists and create it if it does not.\"\"\"\n        if client is None:\n            raise ValueError(\"Qdrant client is not initialized.\")\n        if client.collection_exists(collection_name=f\"{collection_name}\"):\n            print(f\"Collection {collection_name} exists. Loading the collection...\")\n            return Qdrant(\n                client=client,\n                collection_name=collection_name,\n                embeddings=model,\n            )\n        else:\n            print(\n                f\"Collection {collection_name} does not exist. Creating the collection...\"\n            )\n            # create the collection\n            client.create_collection(\n                collection_name=f\"{collection_name}\",\n                vectors_config=models.VectorParams(\n                    size=1024, distance=models.Distance.COSINE\n                ),\n            )\n            return Qdrant(\n                client=client,\n                collection_name=collection_name,\n                embeddings=model,\n            )\n\n    @staticmethod\n    def _init_online_vector_db(\n        url: str, api_key: str, collection_name: str, model: \"HuggingFaceEmbeddings\"\n    ):\n        from qdrant_client import QdrantClient\n\n        \"\"\"Initialize the Qdrant client that is connected to an online vector store with the given URL and API key.\n\n        Args:\n            url (str): URL of the Qdrant server.\n            api_key (str): API key for the Qdrant server.\n        \"\"\"\n        if api_key is None:\n            if not os.getenv(\"QDRANT_API_KEY\"):\n                raise ValueError(\"Please provide an api key.\")\n            api_key = os.getenv(\"QDRANT_API_KEY\")\n        if url is 
None:\n            raise ValueError(\"Please provide a url for the Qdrant server.\")\n\n        try:\n            client = QdrantClient(url=url, api_key=api_key)\n            return QdrantVectorStoreManager._check_create_collection(\n                client=client, collection_name=collection_name, model=model\n            )\n        except Exception as e:\n            raise ValueError(f\"Error occurs when connecting to the server: {e}\")\n\n    @staticmethod\n    def _init_offline_vector_db(\n        vector_store_path: str, collection_name: str, model: \"HuggingFaceEmbeddings\"\n    ):\n        \"\"\"Initialize the Qdrant client that is connected to an offline vector store with the given vector store folder path.\n\n        Args:\n            vector_store_path (str): Path to the vector store.\n        \"\"\"\n        from qdrant_client import QdrantClient\n\n        if vector_store_path is None:\n            raise ValueError(\"Please provide a folder path.\")\n\n        try:\n            client = QdrantClient(path=vector_store_path)\n            return QdrantVectorStoreManager._check_create_collection(\n                client=client, collection_name=collection_name, model=model\n            )\n        except Exception as e:\n            raise ValueError(f\"Error occurs when loading the vector store: {e}\")\n\n    @staticmethod\n    def create_or_update_vector_store(\n        collection_name: str,\n        vector_db_mode: str,\n        file_path: str,\n        content_column: str,\n        title_column: str = \"title\",\n        url_column: str = \"url\",\n        desc_column: str = \"description\",\n        batch_size: int = 64,\n        chunk_size: int = 500,\n        chunk_overlap: int = 100,\n        vector_store_path: str = None,\n        url: str = None,\n        qdrant_api_key: str = None,\n        embedding_model: str = \"BAAI/bge-m3\",\n        device: str = \"mps\",\n    ):\n        \"\"\"\n        Takes a CSV file and adds each row in the CSV file to the Qdrant collection.\n\n        This function treats each row of the CSV file as one document.\n        The CSV file must contain the content and URL columns; the title and description columns are optional.\n\n        Args:\n            collection_name: Name of the Qdrant collection.\n            vector_store_path (str): Path to the directory where the vector store is stored or will be stored (used in offline mode).\n            vector_db_mode (str): Mode of the Qdrant vector store (offline or online).\n            file_path (str): Path to the CSV file.\n            content_column (str): Name of the column containing the content.\n            title_column (str): Name of the column containing the title. Default is \"title\".\n            url_column (str): Name of the column containing the URL. Default is \"url\".\n            desc_column (str): Name of the column containing the description. Default is \"description\".\n            batch_size (int): Batch size for adding documents to the collection.\n            chunk_size: Size of each chunk if you need to build the vector store from documents.\n            chunk_overlap: Overlap between chunks if you need to build the vector store from documents.\n            embedding_model: Name of the Hugging Face embedding model.\n            device: Device to run the embeddings model on, can be \"mps\", \"cuda\", \"cpu\".\n            url (str): URL of the Qdrant server (used in online mode).\n            qdrant_api_key: API key for the Qdrant server (Only required if the Qdrant server is online).\n        \"\"\"\n        # Document is LangChain's document class (page_content + metadata), as used below\n        from langchain_core.documents import Document\n\n        # check if the collection name is provided\n        if collection_name is None:\n            raise ValueError(\"Please provide a collection name.\")\n\n        model_kwargs = {\"device\": device}\n        encode_kwargs = {\"normalize_embeddings\": True}\n        from langchain_huggingface import HuggingFaceEmbeddings\n\n        model = HuggingFaceEmbeddings(\n            model_name=embedding_model,\n            model_kwargs=model_kwargs,\n            
encode_kwargs=encode_kwargs,\n        )\n\n        if file_path is None:\n            raise ValueError(\"Please provide a file path.\")\n        # check if the file is a csv file\n        if not file_path.endswith(\".csv\"):\n            raise ValueError(f\"Not valid file format. Please provide a csv file.\")\n        if content_column is None:\n            raise ValueError(\"Please provide the name of the content column.\")\n        if url_column is None:\n            raise ValueError(\"Please provide the name of the url column.\")\n\n        # try to initialize the Qdrant client\n        qdrant = None\n        if vector_db_mode == \"online\":\n            qdrant = QdrantVectorStoreManager._init_online_vector_db(\n                url=url,\n                api_key=qdrant_api_key,\n                collection_name=collection_name,\n                model=model,\n            )\n        elif vector_db_mode == \"offline\":\n            qdrant = QdrantVectorStoreManager._init_offline_vector_db(\n                vector_store_path=vector_store_path,\n                collection_name=collection_name,\n                model=model,\n            )\n        else:\n            raise ValueError(\n                \"Invalid vector_db_mode. 
Please provide either 'online' or 'offline'.\"\n            )\n        if qdrant is None:\n            raise ValueError(\"Qdrant client is not initialized.\")\n\n        # read the csv file\n        import pandas as pd\n\n        df = pd.read_csv(file_path)\n        # check that content column exists and url column exists\n        if content_column not in df.columns:\n            raise ValueError(\n                f\"Content column {content_column} not found in the csv file.\"\n            )\n        if url_column not in df.columns:\n            raise ValueError(f\"URL column {url_column} not found in the csv file.\")\n\n        documents = [\n            Document(\n                page_content=row[content_column],\n                metadata={\n                    \"title\": row.get(title_column, \"\"),\n                    \"url\": row[url_column],\n                    \"description\": row.get(desc_column, \"\"),\n                },\n            )\n            for row in df.to_dict(orient=\"records\")\n        ]\n\n        # split the documents\n        from langchain_text_splitters import RecursiveCharacterTextSplitter\n\n        text_splitter = RecursiveCharacterTextSplitter(\n            chunk_size=chunk_size,\n            chunk_overlap=chunk_overlap,\n            length_function=len,\n            add_start_index=True,\n            separators=[\n                \"\\n\\n\",\n                \"\\n\",\n                \".\",\n                \"\\uff0e\",  # Fullwidth full stop\n                \"\\u3002\",  # Ideographic full stop\n                \",\",\n                \"\\uff0c\",  # Fullwidth comma\n                \"\\u3001\",  # Ideographic comma\n                \" \",\n                \"\\u200B\",  # Zero-width space\n                \"\",\n            ],\n        )\n        split_documents = text_splitter.split_documents(documents)\n\n        # update and save the vector store\n        num_batches = (len(split_documents) + batch_size - 1) // batch_size\n   
     for i in tqdm(range(num_batches)):\n            start_idx = i * batch_size\n            end_idx = min((i + 1) * batch_size, len(split_documents))\n            qdrant.add_documents(\n                documents=split_documents[start_idx:end_idx],\n                batch_size=batch_size,\n            )\n\n        # close the qdrant client\n        qdrant.client.close()\n\n\nclass ArticleTextProcessing:\n    @staticmethod\n    def limit_word_count_preserve_newline(input_string, max_word_count):\n        \"\"\"\n        Limit the word count of an input string to a specified maximum, while preserving the integrity of complete lines.\n\n        The function truncates the input string at the nearest word that does not exceed the maximum word count,\n        ensuring that no partial lines are included in the output. Words are defined as text separated by spaces,\n        and lines are defined as text separated by newline characters.\n\n        Args:\n            input_string (str): The string to be truncated. This string may contain multiple lines.\n            max_word_count (int): The maximum number of words allowed in the truncated string.\n\n        Returns:\n            str: The truncated string with word count limited to `max_word_count`, preserving complete lines.\n        \"\"\"\n\n        word_count = 0\n        limited_string = \"\"\n\n        for word in input_string.split(\"\\n\"):\n            line_words = word.split()\n            for lw in line_words:\n                if word_count < max_word_count:\n                    limited_string += lw + \" \"\n                    word_count += 1\n                else:\n                    break\n            if word_count >= max_word_count:\n                break\n            limited_string = limited_string.strip() + \"\\n\"\n\n        return limited_string.strip()\n\n    @staticmethod\n    def remove_citations(s):\n        \"\"\"\n        Removes all citations from a given string. 
Citations are assumed to be in the format\n        of numbers enclosed in square brackets, such as [1], [2], or [1, 2], etc. This function searches\n        for all occurrences of such patterns and removes them, returning the cleaned string.\n\n        Args:\n            s (str): The string from which citations are to be removed.\n\n        Returns:\n            str: The string with all citation patterns removed.\n        \"\"\"\n\n        return re.sub(r\"\\[\\d+(?:,\\s*\\d+)*\\]\", \"\", s)\n\n    @staticmethod\n    def parse_citation_indices(s):\n        \"\"\"\n        Extracts citation indexes from the provided string and returns them as a list of integers.\n\n        Args:\n            s (str): The string containing citations in the format [number].\n\n        Returns:\n            List[int]: The citation indexes extracted from the string, in the order they appear.\n            Repeated citations are not deduplicated.\n        \"\"\"\n        matches = re.findall(r\"\\[\\d+\\]\", s)\n        return [int(index[1:-1]) for index in matches]\n\n    @staticmethod\n    def remove_uncompleted_sentences_with_citations(text):\n        \"\"\"\n        Removes uncompleted sentences and standalone citations from the input text. Sentences are identified\n        by their ending punctuation (.!?), optionally followed by a citation in square brackets (e.g., \"[1]\").\n        Grouped citations (e.g., \"[1, 2]\") are split into individual ones (e.g., \"[1] [2]\"). 
Only text up to\n        and including the last complete sentence and its citation is retained.\n\n        Args:\n            text (str): The input text from which uncompleted sentences and their citations are to be removed.\n\n        Returns:\n            str: The processed string with uncompleted sentences and standalone citations removed, leaving only\n            complete sentences and their associated citations if present.\n        \"\"\"\n\n        # Convert citations like [1, 2, 3] to [1][2][3].\n        def replace_with_individual_brackets(match):\n            numbers = match.group(1).split(\", \")\n            return \" \".join(f\"[{n}]\" for n in numbers)\n\n        # Deduplicate and sort individual groups of citations.\n        def deduplicate_group(match):\n            citations = match.group(0)\n            unique_citations = list(set(re.findall(r\"\\[\\d+\\]\", citations)))\n            sorted_citations = sorted(\n                unique_citations, key=lambda x: int(x.strip(\"[]\"))\n            )\n            # Return the sorted unique citations as a string\n            return \"\".join(sorted_citations)\n\n        text = re.sub(r\"\\[([0-9, ]+)\\]\", replace_with_individual_brackets, text)\n        text = re.sub(r\"(\\[\\d+\\])+\", deduplicate_group, text)\n\n        # Deprecated: Remove sentence without proper ending punctuation and citations.\n        # Split the text into sentences (including citations).\n        # sentences_with_trailing = re.findall(r'([^.!?]*[.!?].*?)(?=[^.!?]*[.!?]|$)', text)\n\n        # Filter sentences to ensure they end with a punctuation mark and properly formatted citations\n        # complete_sentences = []\n        # for sentence in sentences_with_trailing:\n        #     # Check if the sentence ends with properly formatted citations\n        #     if re.search(r'[.!?]( \\[\\d+\\])*$|^[^.!?]*[.!?]$', sentence.strip()):\n        #         complete_sentences.append(sentence.strip())\n\n        # combined_sentences = ' 
'.join(complete_sentences)\n\n        # Check for and append any complete citations that follow the last sentence\n        # trailing_citations = re.findall(r'(\\[\\d+\\]) ', text[text.rfind(combined_sentences) + len(combined_sentences):])\n        # if trailing_citations:\n        #     combined_sentences += ' '.join(trailing_citations)\n\n        # Regex pattern to match sentence endings, including optional citation markers.\n        eos_pattern = r\"([.!?])\\s*(\\[\\d+\\])?\\s*\"\n        matches = list(re.finditer(eos_pattern, text))\n        if matches:\n            last_match = matches[-1]\n            text = text[: last_match.end()].strip()\n\n        return text\n\n    @staticmethod\n    def clean_up_citation(conv):\n        for turn in conv.dlg_history:\n            if \"References:\" in turn.agent_utterance:\n                turn.agent_utterance = turn.agent_utterance[\n                    : turn.agent_utterance.find(\"References:\")\n                ]\n            if \"Sources:\" in turn.agent_utterance:\n                turn.agent_utterance = turn.agent_utterance[\n                    : turn.agent_utterance.find(\"Sources:\")\n                ]\n            turn.agent_utterance = turn.agent_utterance.replace(\"Answer:\", \"\").strip()\n            try:\n                max_ref_num = max(\n                    [int(x) for x in re.findall(r\"\\[(\\d+)\\]\", turn.agent_utterance)]\n                )\n            except Exception as e:\n                max_ref_num = 0\n            if max_ref_num > len(turn.search_results):\n                for i in range(len(turn.search_results), max_ref_num + 1):\n                    turn.agent_utterance = turn.agent_utterance.replace(f\"[{i}]\", \"\")\n            turn.agent_utterance = (\n                ArticleTextProcessing.remove_uncompleted_sentences_with_citations(\n                    turn.agent_utterance\n                )\n            )\n\n        return conv\n\n    @staticmethod\n    def clean_up_outline(outline, 
topic=\"\"):\n        output_lines = []\n        current_level = 0  # To track the current section level\n\n        for line in outline.split(\"\\n\"):\n            stripped_line = line.strip()\n\n            if topic != \"\" and f\"# {topic.lower()}\" in stripped_line.lower():\n                output_lines = []\n\n            # Check if the line is a section header\n            if stripped_line.startswith(\"#\"):\n                current_level = stripped_line.count(\"#\")\n                output_lines.append(stripped_line)\n            # Check if the line is a bullet point\n            elif stripped_line.startswith(\"-\"):\n                subsection_header = (\n                    \"#\" * (current_level + 1) + \" \" + stripped_line[1:].strip()\n                )\n                output_lines.append(subsection_header)\n\n        outline = \"\\n\".join(output_lines)\n\n        # Remove references.\n        outline = re.sub(r\"#[#]? See also.*?(?=##|$)\", \"\", outline, flags=re.DOTALL)\n        outline = re.sub(r\"#[#]? See Also.*?(?=##|$)\", \"\", outline, flags=re.DOTALL)\n        outline = re.sub(r\"#[#]? Notes.*?(?=##|$)\", \"\", outline, flags=re.DOTALL)\n        outline = re.sub(r\"#[#]? References.*?(?=##|$)\", \"\", outline, flags=re.DOTALL)\n        outline = re.sub(\n            r\"#[#]? External links.*?(?=##|$)\", \"\", outline, flags=re.DOTALL\n        )\n        outline = re.sub(\n            r\"#[#]? External Links.*?(?=##|$)\", \"\", outline, flags=re.DOTALL\n        )\n        outline = re.sub(r\"#[#]? Bibliography.*?(?=##|$)\", \"\", outline, flags=re.DOTALL)\n        outline = re.sub(\n            r\"#[#]? Further reading*?(?=##|$)\", \"\", outline, flags=re.DOTALL\n        )\n        outline = re.sub(\n            r\"#[#]? Further Reading*?(?=##|$)\", \"\", outline, flags=re.DOTALL\n        )\n        outline = re.sub(r\"#[#]? Summary.*?(?=##|$)\", \"\", outline, flags=re.DOTALL)\n        outline = re.sub(r\"#[#]? 
Appendices.*?(?=##|$)\", \"\", outline, flags=re.DOTALL)\n        outline = re.sub(r\"#[#]? Appendix.*?(?=##|$)\", \"\", outline, flags=re.DOTALL)\n        # clean up citation in outline\n        outline = re.sub(r\"\\[.*?\\]\", \"\", outline)\n        return outline\n\n    @staticmethod\n    def clean_up_section(text):\n        \"\"\"Clean up a section:\n        1. Remove uncompleted sentences (usually due to output token limitation).\n        2. Deduplicate individual groups of citations.\n        3. Remove unnecessary summary.\"\"\"\n\n        paragraphs = text.split(\"\\n\")\n        output_paragraphs = []\n        summary_sec_flag = False\n        for p in paragraphs:\n            p = p.strip()\n            if len(p) == 0:\n                continue\n            if not p.startswith(\"#\"):\n                p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p)\n            if summary_sec_flag:\n                if p.startswith(\"#\"):\n                    summary_sec_flag = False\n                else:\n                    continue\n            if (\n                p.startswith(\"Overall\")\n                or p.startswith(\"In summary\")\n                or p.startswith(\"In conclusion\")\n            ):\n                continue\n            if \"# Summary\" in p or \"# Conclusion\" in p:\n                summary_sec_flag = True\n                continue\n            output_paragraphs.append(p)\n\n        # Join with '\\n\\n' for markdown format.\n        return \"\\n\\n\".join(output_paragraphs)\n\n    @staticmethod\n    def update_citation_index(s, citation_map):\n        \"\"\"Update citation index in the string based on the citation map.\"\"\"\n        for original_citation in citation_map:\n            s = s.replace(\n                f\"[{original_citation}]\", f\"__PLACEHOLDER_{original_citation}__\"\n            )\n        for original_citation, unify_citation in citation_map.items():\n            s = 
s.replace(f\"__PLACEHOLDER_{original_citation}__\", f\"[{unify_citation}]\")\n\n        return s\n\n    @staticmethod\n    def parse_article_into_dict(input_string):\n        \"\"\"\n        Parses a structured text into a nested dictionary. The structure of the text\n        is defined by markdown-like headers (using '#' symbols) to denote sections\n        and subsections. Each section can contain content and further nested subsections.\n\n        The resulting dictionary captures the hierarchical structure of sections, where\n        each section is represented as a key (the section's title) mapping to a value\n        that is another dictionary. This dictionary contains two keys:\n        - 'content': content of the section\n        - 'subsections': a list of dictionaries, each representing a nested subsection\n        following the same structure.\n\n        Args:\n            input_string (str): A string containing the structured text to parse.\n\n        Returns:\n            A dictionary representing contains the section title as the key, and another dictionary\n        as the value, which includes the 'content' and 'subsections' keys as described above.\n        \"\"\"\n        lines = input_string.split(\"\\n\")\n        lines = [line for line in lines if line.strip()]\n        root = {\"content\": \"\", \"subsections\": {}}\n        current_path = [(root, -1)]  # (current_dict, level)\n\n        for line in lines:\n            if line.startswith(\"#\"):\n                level = line.count(\"#\")\n                title = line.strip(\"# \").strip()\n                new_section = {\"content\": \"\", \"subsections\": {}}\n\n                # Pop from stack until find the parent level\n                while current_path and current_path[-1][1] >= level:\n                    current_path.pop()\n\n                # Append new section to the nearest upper level's subsections\n                current_path[-1][0][\"subsections\"][title] = new_section\n             
   current_path.append((new_section, level))\n            else:\n                current_path[-1][0][\"content\"] += line + \"\\n\"\n\n        return root[\"subsections\"]\n\n\nclass FileIOHelper:\n    @staticmethod\n    def dump_json(obj, file_name, encoding=\"utf-8\"):\n        with open(file_name, \"w\", encoding=encoding) as fw:\n            json.dump(obj, fw, default=FileIOHelper.handle_non_serializable)\n\n    @staticmethod\n    def handle_non_serializable(obj):\n        return \"non-serializable contents\"  # mark the non-serializable part\n\n    @staticmethod\n    def load_json(file_name, encoding=\"utf-8\"):\n        with open(file_name, \"r\", encoding=encoding) as fr:\n            return json.load(fr)\n\n    @staticmethod\n    def write_str(s, path):\n        with open(path, \"w\") as f:\n            f.write(s)\n\n    @staticmethod\n    def load_str(path):\n        with open(path, \"r\") as f:\n            return \"\\n\".join(f.readlines())\n\n    @staticmethod\n    def dump_pickle(obj, path):\n        with open(path, \"wb\") as f:\n            pickle.dump(obj, f)\n\n    @staticmethod\n    def load_pickle(path):\n        with open(path, \"rb\") as f:\n            return pickle.load(f)\n\n\nclass WebPageHelper:\n    \"\"\"Helper class to process web pages.\n\n    Acknowledgement: Part of the code is adapted from https://github.com/stanford-oval/WikiChat project.\n    \"\"\"\n\n    def __init__(\n        self,\n        min_char_count: int = 150,\n        snippet_chunk_size: int = 1000,\n        max_thread_num: int = 10,\n    ):\n        \"\"\"\n        Args:\n            min_char_count: Minimum character count for the article to be considered valid.\n            snippet_chunk_size: Maximum character count for each snippet.\n            max_thread_num: Maximum number of threads to use for concurrent requests (e.g., downloading webpages).\n        \"\"\"\n        self.httpx_client = httpx.Client(verify=False)\n        self.min_char_count = min_char_count\n   
     self.max_thread_num = max_thread_num\n        self.text_splitter = RecursiveCharacterTextSplitter(\n            chunk_size=snippet_chunk_size,\n            chunk_overlap=0,\n            length_function=len,\n            is_separator_regex=False,\n            separators=[\n                \"\\n\\n\",\n                \"\\n\",\n                \".\",\n                \"\\uff0e\",  # Fullwidth full stop\n                \"\\u3002\",  # Ideographic full stop\n                \",\",\n                \"\\uff0c\",  # Fullwidth comma\n                \"\\u3001\",  # Ideographic comma\n                \" \",\n                \"\\u200B\",  # Zero-width space\n                \"\",\n            ],\n        )\n\n    def download_webpage(self, url: str):\n        try:\n            res = self.httpx_client.get(url, timeout=4)\n            if res.status_code >= 400:\n                res.raise_for_status()\n            return res.content\n        except httpx.HTTPError as exc:\n            print(f\"Error while requesting {exc.request.url!r} - {exc!r}\")\n            return None\n\n    def urls_to_articles(self, urls: List[str]) -> Dict:\n        with concurrent.futures.ThreadPoolExecutor(\n            max_workers=self.max_thread_num\n        ) as executor:\n            htmls = list(executor.map(self.download_webpage, urls))\n\n        articles = {}\n\n        for h, u in zip(htmls, urls):\n            if h is None:\n                continue\n            article_text = extract(\n                h,\n                include_tables=False,\n                include_comments=False,\n                output_format=\"txt\",\n            )\n            if article_text is not None and len(article_text) > self.min_char_count:\n                articles[u] = {\"text\": article_text}\n\n        return articles\n\n    def urls_to_snippets(self, urls: List[str]) -> Dict:\n        articles = self.urls_to_articles(urls)\n        for u in articles:\n            articles[u][\"snippets\"] = 
self.text_splitter.split_text(articles[u][\"text\"])\n\n        return articles\n\n\ndef user_input_appropriateness_check(user_input):\n    my_openai_model = LitellmModel(\n        model=\"azure/gpt-4o-mini\",\n        max_tokens=10,\n        temperature=0.0,\n        top_p=0.9,\n    )\n\n    if len(user_input.split()) > 20:\n        return \"The input is too long. Please make your input topic more concise!\"\n\n    if not re.match(r'^[a-zA-Z0-9\\s\\-\"\\,\\.?\\']*$', user_input):\n        return \"The input contains invalid characters. The input should only contain a-z, A-Z, 0-9, space, -/\\\"/,./?/'.\"\n\n    prompt = f\"\"\"Here is a topic input into a knowledge curation engine that can write a Wikipedia-like article for the topic. Please judge whether it is appropriate or not for the engine to curate information for this topic based on English search engine. The following types of inputs are inappropriate:\n1. Inputs that may be related to illegal, harmful, violent, racist, or sexual purposes.\n2. Inputs that are given using languages other than English. Currently, the engine can only support English.\n3. Inputs that are related to personal experience or personal information. Currently, the engine can only use information from the search engine.\n4. Inputs that are not aimed at topic research or inquiry. For example, asks requiring detailed execution, such as calculations, programming, or specific service searches fall outside the engine's scope of capabilities.\nIf the topic is appropriate for the engine to process, output \"Yes.\"; otherwise, output \"No. The input violates reason [1/2/3/4]\".\nUser input: {user_input}\"\"\"\n    reject_reason_info = {\n        1: \"Sorry, this input may be related to sensitive topics. Please try another topic. \"\n        \"(Our input filtering uses OpenAI GPT-4o-mini, which may result in false positives. \"\n        \"We apologize for any inconvenience.)\",\n        2: \"Sorry, the current engine can only support English. 
Please try another topic. \"\n        \"(Our input filtering uses OpenAI GPT-4o-mini, which may result in false positives. \"\n        \"We apologize for any inconvenience.)\",\n        3: \"Sorry, the current engine cannot process topics related to personal experience. Please try another topic. \"\n        \"(Our input filtering uses OpenAI GPT-4o-mini, which may result in false positives. \"\n        \"We apologize for any inconvenience.)\",\n        4: \"Sorry, STORM cannot follow arbitrary instructions. Please input a topic you want to learn about. \"\n        \"(Our input filtering uses OpenAI GPT-4o-mini, which may result in false positives. \"\n        \"We apologize for any inconvenience.)\",\n    }\n\n    try:\n        response = my_openai_model(prompt)[0].replace(\"[\", \"\").replace(\"]\", \"\")\n        if response.startswith(\"No\"):\n            match = regex.search(r\"reason\\s(\\d+)\", response)\n            if match:\n                reject_reason = int(match.group(1))\n                if reject_reason in reject_reason_info:\n                    return reject_reason_info[reject_reason]\n                else:\n                    return (\n                        \"Sorry, the input is inappropriate. Please try another topic!\"\n                    )\n            return \"Sorry, the input is inappropriate. Please try another topic!\"\n\n    except Exception:\n        return \"Sorry, the input is inappropriate. Please try another topic!\"\n    return \"Approved\"\n\n\ndef purpose_appropriateness_check(user_input):\n    my_openai_model = LitellmModel(\n        model=\"azure/gpt-4o-mini\",\n        max_tokens=10,\n        temperature=0.0,\n        top_p=0.9,\n    )\n\n    prompt = f\"\"\"\n    Here is a purpose input into a report generation engine that can create a long-form report on any topic of interest. \n    Please judge whether the provided purpose is valid for using this service. 
\n    Judge whether the given purpose is nonsense, such as random words, or merely an attempt to get around the sanity check.\n    You should not make the rule too strict.\n    \n    If the purpose is valid, output \"Yes.\"; otherwise, output \"No\" followed by the reason.\n    User input: {user_input}\n    \"\"\"\n    try:\n        response = my_openai_model(prompt)[0].replace(\"[\", \"\").replace(\"]\", \"\")\n        if response.startswith(\"No\"):\n            return \"Please provide a more detailed explanation of your purpose for requesting this article.\"\n\n    except Exception:\n        return \"Please provide a more detailed explanation of your purpose for requesting this article.\"\n    return \"Approved\"\n"
  },
  {
    "path": "requirements.txt",
    "content": "dspy_ai==2.4.9\nwikipedia==1.4.0\nsentence-transformers\ntoml\nlangchain-text-splitters\ntrafilatura\nlangchain-huggingface\nqdrant-client\nlangchain-qdrant\nnumpy\nlitellm\ndiskcache"
  },
  {
    "path": "setup.py",
    "content": "import re\n\nfrom setuptools import setup, find_packages\n\n# Read the content of the README file\nwith open(\"README.md\", encoding=\"utf-8\") as f:\n    long_description = f.read()\n    # Remove p tags.\n    pattern = re.compile(r\"<p.*?>.*?</p>\", re.DOTALL)\n    long_description = re.sub(pattern, \"\", long_description)\n\n# Read the content of the requirements.txt file\nwith open(\"requirements.txt\", encoding=\"utf-8\") as f:\n    requirements = f.read().splitlines()\n\n\nsetup(\n    name=\"knowledge-storm\",\n    version=\"1.1.1\",\n    author=\"Yijia Shao, Yucheng Jiang\",\n    author_email=\"shaoyj@stanford.edu, yuchengj@stanford.edu\",\n    description=\"STORM: A language model-powered knowledge curation engine.\",\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    url=\"https://github.com/stanford-oval/storm\",\n    license=\"MIT License\",\n    packages=find_packages(),\n    classifiers=[\n        \"Development Status :: 3 - Alpha\",\n        \"License :: OSI Approved :: MIT License\",\n        \"Operating System :: OS Independent\",\n        \"Programming Language :: Python :: 3\",\n        \"Programming Language :: Python :: 3.10\",\n        \"Programming Language :: Python :: 3.11\",\n    ],\n    python_requires=\">=3.10\",\n    install_requires=requirements,\n)\n"
  }
]