Repository: fierceX/Document_QA
Branch: main
Commit: a7626faee7f3
Files: 2
Total size: 5.2 KB

Directory structure:
gitextract_t7cvnxfi/
├── Document_QA.py
└── README.md

================================================
FILE CONTENTS
================================================

================================================
FILE: Document_QA.py
================================================
import openai
import faiss
import numpy as np
import pickle
from tqdm import tqdm
import argparse
import os


def create_embeddings(texts):
    """Create embeddings for the provided list of texts, batching the requests."""
    result = []
    # Batch the inputs so that each request stays under roughly 4096 characters.
    lens = [len(text) for text in texts]
    query_len = 0
    start_index = 0
    tokens = 0

    def get_embedding(input_slice):
        embedding = openai.Embedding.create(model="text-embedding-ada-002", input=input_slice)
        return [(text, data.embedding) for text, data in zip(input_slice, embedding.data)], embedding.usage.total_tokens

    for index, l in tqdm(enumerate(lens), total=len(lens)):
        query_len += l
        if query_len > 4096:
            ebd, tk = get_embedding(texts[start_index:index + 1])
            query_len = 0
            start_index = index + 1
            tokens += tk
            result.extend(ebd)

    if query_len > 0:
        ebd, tk = get_embedding(texts[start_index:])
        tokens += tk
        result.extend(ebd)
    return result, tokens


def create_embedding(text):
    """Create an embedding for a single text."""
    embedding = openai.Embedding.create(model="text-embedding-ada-002", input=text)
    return text, embedding.data[0].embedding


class QA():
    def __init__(self, data_embe) -> None:
        d = 1536  # dimension of text-embedding-ada-002 vectors
        index = faiss.IndexFlatL2(d)
        embe = np.array([emm[1] for emm in data_embe])
        data = [emm[0] for emm in data_embe]
        index.add(embe)
        self.index = index
        self.data = data

    def __call__(self, query, limit=10):
        embedding = create_embedding(query)
        context = self.get_texts(embedding[1], limit)
        answer = self.completion(query, context)
        return answer, context

    def get_texts(self, embedding, limit):
        _, text_index = self.index.search(np.array([embedding]), limit)
        context = []
        for i in list(text_index[0]):
            # Include the matched segment plus the four segments that follow it.
            context.extend(self.data[i:i + 5])
        return context

    def completion(self, query, context):
        """Create a chat completion grounded in the retrieved context."""
        lens = [len(text) for text in context]
        maximum = 3000
        for index, l in enumerate(lens):
            maximum -= l
            if maximum < 0:
                context = context[:index + 1]
                print("Context exceeded the maximum length; truncated to the first", index + 1, "segments")
                break

        text = "\n".join(f"{index}. {text}" for index, text in enumerate(context))
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {'role': 'system', 'content': f'You are a helpful AI document assistant. Answer by extracting useful content from the text below; do not answer anything that is not mentioned in it. The passages are ordered by relevance from high to low:\n\n{text}'},
                {'role': 'user', 'content': query},
            ],
        )
        print("Tokens used:", response.usage.total_tokens)
        return response.choices[0].message.content


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Document QA")
    parser.add_argument("--input_file", default="input.txt", dest="input_file", type=str, help="path to the input text file")
    parser.add_argument("--file_embeding", default="input_embed.pkl", dest="file_embeding", type=str, help="path to the embedding pickle file")
    parser.add_argument("--print_context", action='store_true', help="print the retrieved context")
    args = parser.parse_args()

    if os.path.isfile(args.file_embeding):
        with open(args.file_embeding, 'rb') as f:
            data_embe = pickle.load(f)
    else:
        with open(args.input_file, 'r', encoding='utf-8') as f:
            texts = f.readlines()
        texts = [text.strip() for text in texts if text.strip()]
        data_embe, tokens = create_embeddings(texts)
        with open(args.file_embeding, 'wb') as f:
            pickle.dump(data_embe, f)
        print("The text consumed {} tokens".format(tokens))

    qa = QA(data_embe)

    limit = 10
    while True:
        query = input("Enter a query (type help for commands): ")
        if query == "quit":
            break
        elif query.startswith("limit"):
            try:
                limit = int(query.split(" ")[1])
                print("limit set to", limit)
            except Exception as e:
                print("Failed to set limit:", e)
            continue
        elif query == "help":
            print("Type limit [number] to set the number of retrieved segments")
            print("Type quit to exit")
            continue
        answer, context = qa(query, limit)
        if args.print_context:
            print("Related segments found:")
            for text in context:
                print('\t', text)
        print("=====================================")
        print("The answer is as follows\n")
        print(answer.strip())
        print("=====================================")


================================================
FILE: README.md
================================================
# Document_QA
Answers your questions based on a text file you provide.

The core logic comes from chatPDF, automated customer-service AI, and [ChatWeb](https://github.com/SkywalkerDarren/chatWeb).

The original ChatWeb project uses PostgreSQL as its vector storage and similarity engine, which is relatively complex; this project switches to faiss, which is simpler and faster.

# How it works
1. Read the file and split it into segments
2. Generate an embedding vector for each segment with text-embedding-ada-002
3. Store the mapping between vectors and texts in a local pkl file
4. Generate an embedding vector for the user's query
5. Run a nearest-neighbour search over the vector index and return the most similar segments
6. Design a prompt for the gpt-3.5 chat API so that it answers based only on those segments

In other words, the relevant content is first extracted from a large body of text and only then passed to the model, which effectively works around the token limit.
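The sketch below condenses steps 2, 4 and 5 (embedding the segments, embedding the query, and the nearest-neighbour search). It uses the same `openai.Embedding.create` and `faiss.IndexFlatL2` calls as `Document_QA.py`; the segments and the query are placeholders, and `OPENAI_API_KEY` must already be set:

```python
import faiss
import numpy as np
import openai

segments = ["first text segment", "second text segment"]  # placeholder document segments

# Step 2: embed every segment with text-embedding-ada-002.
resp = openai.Embedding.create(model="text-embedding-ada-002", input=segments)
vectors = np.array([d.embedding for d in resp.data], dtype="float32")

# Build the faiss index (ada-002 vectors have 1536 dimensions).
index = faiss.IndexFlatL2(vectors.shape[1])
index.add(vectors)

# Steps 4-5: embed the query and return the indices of the most similar segments.
q = openai.Embedding.create(model="text-embedding-ada-002", input="a placeholder question")
q_vec = np.array([q.data[0].embedding], dtype="float32")
_, hits = index.search(q_vec, 2)
print([segments[i] for i in hits[0]])
```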
A later improvement could replace OpenAI's text embeddings with a custom embedding generator.

# Getting started
- Dependencies

Main dependencies:

```
faiss
numpy
openai
```

- Environment variable

Set `OPENAI_API_KEY` to your OpenAI API key:

```shell
export OPENAI_API_KEY="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
```

- Run

```
python Document_QA.py --input_file test.md --file_embeding test.pkl
```
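- Programmatic use

The `QA` class can also be driven from another script once the embedding pickle exists (for example the `test.pkl` produced by the command above). A minimal sketch; the query string and the `limit` value are only examples, and `OPENAI_API_KEY` must be set as above:

```python
import pickle
from Document_QA import QA

# Load the (text, embedding) pairs written by create_embeddings.
with open("test.pkl", "rb") as f:
    data_embe = pickle.load(f)

qa = QA(data_embe)  # builds the faiss IndexFlatL2 over the stored vectors
answer, context = qa("What is this document about?", limit=5)
print(answer)
```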