[
  {
    "path": "Document_QA.py",
    "content": "\nimport openai\nimport faiss\nimport numpy as np\nimport pickle\nfrom tqdm import tqdm\nimport argparse\nimport os\n\ndef create_embeddings(input):\n    \"\"\"Create embeddings for the provided input.\"\"\"\n    result = []\n    # limit about 1000 tokens per request\n    lens = [len(text) for text in input]\n    query_len = 0\n    start_index = 0\n    tokens = 0\n\n    def get_embedding(input_slice):\n        embedding = openai.Embedding.create(model=\"text-embedding-ada-002\", input=input_slice)\n        return [(text, data.embedding) for text, data in zip(input_slice, embedding.data)], embedding.usage.total_tokens\n\n    for index, l in tqdm(enumerate(lens)):\n        query_len += l\n        if query_len > 4096:\n            ebd, tk = get_embedding(input[start_index:index + 1])\n            query_len = 0\n            start_index = index + 1\n            tokens += tk\n            result.extend(ebd)\n\n    if query_len > 0:\n        ebd, tk = get_embedding(input[start_index:])\n        tokens += tk\n        result.extend(ebd)\n    return result, tokens\n\ndef create_embedding(text):\n    \"\"\"Create an embedding for the provided text.\"\"\"\n    embedding = openai.Embedding.create(model=\"text-embedding-ada-002\", input=text)\n    return text, embedding.data[0].embedding\n\nclass QA():\n    def __init__(self,data_embe) -> None:\n        d = 1536\n        index = faiss.IndexFlatL2(d)\n        embe = np.array([emm[1] for emm in data_embe])\n        data = [emm[0] for emm in data_embe]\n        index.add(embe)\n        self.index = index\n        self.data = data\n    def __call__(self, query):\n        embedding = create_embedding(query)\n        context = self.get_texts(embedding[1], limit)\n        answer = self.completion(query,context)\n        return answer,context\n    def get_texts(self,embeding,limit):\n        _,text_index = self.index.search(np.array([embeding]),limit)\n        context = []\n        for i in list(text_index[0]):\n            context.extend(self.data[i:i+5])\n        # context = [self.data[i] for i in list(text_index[0])]\n        return context\n    \n    def completion(self,query, context):\n        \"\"\"Create a completion.\"\"\"\n        lens = [len(text) for text in context]\n\n        maximum = 3000\n        for index, l in enumerate(lens):\n            maximum -= l\n            if maximum < 0:\n                context = context[:index + 1]\n                print(\"超过最大长度，截断到前\", index + 1, \"个片段\")\n                break\n\n        text = \"\\n\".join(f\"{index}. 
{text}\" for index, text in enumerate(context))\n        response = openai.ChatCompletion.create(\n            model=\"gpt-3.5-turbo\",\n            messages=[\n                {'role': 'system',\n                'content': f'你是一个有帮助的AI文章助手，从下文中提取有用的内容进行回答，不能回答不在下文提到的内容，相关性从高到底排序：\\n\\n{text}'},\n                {'role': 'user', 'content': query},\n            ],\n        )\n        print(\"使用的tokens：\", response.usage.total_tokens)\n        return response.choices[0].message.content\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description=\"Document QA\")\n    parser.add_argument(\"--input_file\", default=\"input.txt\", dest=\"input_file\", type=str,help=\"输入文件路径\")\n    parser.add_argument(\"--file_embeding\", default=\"input_embed.pkl\", dest=\"file_embeding\", type=str,help=\"文件embeding文件路径\")\n    parser.add_argument(\"--print_context\", action='store_true',help=\"是否打印上下文\")\n    \n\n    args = parser.parse_args()\n\n    if os.path.isfile(args.file_embeding):\n        data_embe = pickle.load(open(args.file_embeding,'rb'))\n    else:\n        with open(args.input_file,'r',encoding='utf-8') as f:\n            texts = f.readlines()\n            texts = [text.strip() for text in texts if text.strip()]\n            data_embe,tokens = create_embeddings(texts)\n            pickle.dump(data_embe,open(args.file_embeding,'wb'))\n            print(\"文本消耗 {} tokens\".format(tokens))\n\n    qa =QA(data_embe)\n\n    limit = 10\n    while True:\n        query = input(\"请输入查询(help可查看指令)：\")\n        if query == \"quit\":\n            break\n        elif query.startswith(\"limit\"):\n            try:\n                limit = int(query.split(\" \")[1])\n                print(\"已设置limit为\", limit)\n            except Exception as e:\n                print(\"设置limit失败\", e)\n            continue\n        elif query == \"help\":\n            print(\"输入limit [数字]设置limit\")\n            print(\"输入quit退出\")\n            continue\n        answer,context = qa(query)\n        if args.print_context:\n            print(\"已找到相关片段：\")\n            for text in context:\n                print('\\t', text)\n            print(\"=====================================\")\n        print(\"回答如下\\n\\n\")\n        print(answer.strip())\n        print(\"=====================================\")\n\n"
  },
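  {
    "path": "example_usage.py",
    "content": "# A minimal sketch, not part of the original project: it shows how the QA class and the\n# create_embeddings helper from Document_QA.py could be driven programmatically instead of\n# through the interactive CLI loop. The file name, sample passages, and query are hypothetical.\nimport pickle\n\nfrom Document_QA import QA, create_embeddings\n\n# Embed a few example passages and cache them, mirroring what the CLI does for --input_file.\ntexts = [\n    \"faiss is a library for efficient similarity search over dense vectors.\",\n    \"text-embedding-ada-002 maps a piece of text to a 1536-dimensional vector.\",\n]\ndata_embe, tokens = create_embeddings(texts)\npickle.dump(data_embe, open(\"example_embed.pkl\", \"wb\"))\nprint(\"Embedding the text consumed {} tokens\".format(tokens))\n\n# Build the faiss index and ask a question grounded in the indexed passages.\nqa = QA(data_embe)\nanswer, context = qa(\"What does faiss do?\", limit=2)\nprint(answer)\n"
  },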
  {
    "path": "README.md",
    "content": "# Document_QA\n\n根据传入的文本文件，回答你的问题。\n\n核心逻辑来自于chatPDF，自动化客服AI，以及：[ChatWeb](https://github.com/SkywalkerDarren/chatWeb)\n\n由于原来的ChatWeb项目使用的是pqsql作为向量存储和计算工具，较为复杂，本项目修改成faiss，更简单快速。\n\n\n# 基本原理\n\n1. 读取文件，并进行分割\n2. 对于每段文本，使用text-embedding-ada-002生成特征向量\n3. 将向量和文本对应关系存入本地pkl文件\n4. 对于用户输入，生成向量\n5. 使用向量数据库进行最近邻搜索，返回最相似的文本列表\n6. 使用gpt3.5的chatAPI，设计prompt，使其基于最相似的文本列表进行回答\n\n就是先把大量文本中提取相关内容，再进行回答，最终可以达到类似突破token限制的效果  \n后续可以考虑将openai的文本向量改成自定义的向量生成工具\n\n# 准备开始\n\n- 项目依赖\n\n主要依赖\n```\nfaiss\nnumpy\nopenai\n```\n\n- 环境变量\n\n设置`OPENAI_API_KEY`为你的openai的api key\n\n```shell\nexport OPENAI_API_KEY=\"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\"\n```\n\n- 运行\n\n```\npython Document_QA.py --input_file test.md --file_embeding test.pkl\n```"
  }
]