Repository: mobarski/ask-my-pdf
Branch: main
Commit: 470d31969331
Files: 16
Total size: 40.5 KB

Directory structure:
gitextract_zi78g5yh/
├── .gitignore
├── LICENSE
├── README.md
├── requirements.txt
└── src/
    ├── ai.py
    ├── cache.py
    ├── css.py
    ├── feedback.py
    ├── gui.py
    ├── model.py
    ├── pdf.py
    ├── prompts.py
    ├── run.bat
    ├── run.sh
    ├── stats.py
    └── storage.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
**/__pycache__
**/*secret*.*
data/**

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 Maciej Obarski

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================
# Ask my PDF

Thank you for your interest in my application.
Please be aware that this is only a **Proof of Concept system** and may contain bugs or unfinished features.

If you like this app, you can ❤️ [follow me](https://twitter.com/KerbalFPV) on Twitter for news and updates.

### Ask my PDF - Question answering system built on top of GPT3

🎲 The primary use case for this app is to assist users in answering questions about board game rules based on the instruction manual. While the app can be used for other tasks, helping users with board game rules is particularly meaningful to me since I'm an avid fan of board games myself. Additionally, this use case is relatively harmless, even in cases where the model may experience hallucinations.

🌐 The app can be accessed on the Streamlit Community Cloud at https://ask-my-pdf.streamlit.app/.

🔑 However, to use the app, you will need your own [OpenAI API key](https://platform.openai.com/account/api-keys).

📄 The app implements the following academic papers:
- [In-Context Retrieval-Augmented Language Models](https://arxiv.org/abs/2302.00083) aka **RALM**
- [Precise Zero-Shot Dense Retrieval without Relevance Labels](https://arxiv.org/abs/2212.10496) aka **HyDE** (Hypothetical Document Embeddings)

### Installation

1. Clone the repo: `git clone https://github.com/mobarski/ask-my-pdf`
2. Install dependencies: `pip install -r ask-my-pdf/requirements.txt`
3.
Run the app: `cd ask-my-pdf/src`, then `run.sh` or `run.bat`

### High-level documentation

#### RALM + HyDE
![RALM + HyDE](docs/ralm_hyde.jpg)

#### RALM + HyDE + context
![RALM + HyDE + context](docs/ralm_hyde_wc.jpg)

### Environment variables used for configuration

##### General configuration:
- **STORAGE_SALT** - cryptographic salt used when deriving user/folder name and encryption key from API key, hexadecimal notation, 2-16 characters
- **STORAGE_MODE** - index storage mode: S3, LOCAL, DICT (default)
- **STATS_MODE** - usage stats storage mode: REDIS, DICT (default)
- **FEEDBACK_MODE** - user feedback storage mode: REDIS, NONE (default)
- **CACHE_MODE** - embeddings cache mode: S3, DISK, NONE (default)

##### Local filesystem configuration (storage / cache):
- **STORAGE_PATH** - directory path for index storage
- **CACHE_PATH** - directory path for embeddings cache

##### S3 configuration (storage / cache):
- **S3_REGION** - region code
- **S3_BUCKET** - bucket name (storage)
- **S3_SECRET** - secret key
- **S3_KEY** - access key
- **S3_URL** - URL
- **S3_PREFIX** - object name prefix
- **S3_CACHE_BUCKET** - bucket name (cache)
- **S3_CACHE_PREFIX** - object name prefix (cache)

##### Redis configuration (for persistent usage statistics / user feedback):
- **REDIS_URL** - Redis DB URL (redis[s]://:password@host:port/[db])

##### Community version related options:
- **OPENAI_KEY** - API key used for the default user
- **COMMUNITY_DAILY_USD** - default user's daily budget
- **COMMUNITY_USER** - default user's code

================================================
FILE: requirements.txt
================================================
git+https://github.com/mobarski/ai-bricks.git
streamlit
pypdf
scikit-learn
numpy
pycryptodome
boto3
redis
retry

================================================
FILE: src/ai.py
================================================
from ai_bricks.api import openai
import stats
import os

DEFAULT_USER = os.getenv('COMMUNITY_USER','')

def use_key(key):
	openai.use_key(key)

usage_stats = stats.get_stats(user=DEFAULT_USER)

def set_user(user):
	global usage_stats
	usage_stats = stats.get_stats(user=user)
	openai.set_global('user', user)
	openai.add_callback('after', stats_callback)

def complete(text, **kw):
	model = kw.get('model','gpt-3.5-turbo')
	llm = openai.model(model)
	llm.config['pre_prompt'] = 'output only in raw text' # for chat models
	resp = llm.complete(text, **kw)
	resp['model'] = model
	return resp

def embedding(text, **kw):
	model = kw.get('model','text-embedding-ada-002')
	llm = openai.model(model)
	resp = llm.embed(text, **kw)
	resp['model'] = model
	return resp

def embeddings(texts, **kw):
	model = kw.get('model','text-embedding-ada-002')
	llm = openai.model(model)
	resp = llm.embed_many(texts, **kw)
	resp['model'] = model
	return resp

tokenizer_model = openai.model('text-davinci-003')
def get_token_count(text):
	return tokenizer_model.token_count(text)

def stats_callback(out, resp, self):
	model = self.config['model']
	usage = resp['usage']
	usage['call_cnt'] = 1
	if 'text' in out:
		usage['completion_chars'] = len(out['text'])
	elif 'texts' in out:
		usage['completion_chars'] = sum([len(text) for text in out['texts']])
	# TODO: prompt_chars
	# TODO: total_chars
	if 'rtt' in out:
		usage['rtt'] = out['rtt']
		usage['rtt_cnt'] = 1
	usage_stats.incr(f'usage:v4:[date]:[user]', {f'{k}:{model}':v for k,v in usage.items()})
	usage_stats.incr(f'hourly:v4:[date]', {f'{k}:{model}:[hour]':v for k,v in usage.items()})
	#print('STATS_CALLBACK', usage, flush=True) # XXX

def get_community_usage_cost():
	data = usage_stats.get(f'usage:v4:[date]:{DEFAULT_USER}')
	used = 0.0
	used += 0.04   * data.get('total_tokens:gpt-4',0) / 1000 # prompt_price=0.03 but output_price=0.06
	used += 0.02   * data.get('total_tokens:text-davinci-003',0) / 1000
	used += 0.002  * data.get('total_tokens:text-curie-001',0) / 1000
	used += 0.002  * data.get('total_tokens:gpt-3.5-turbo',0) / 1000
	used += 0.0004 * data.get('total_tokens:text-embedding-ada-002',0) / 1000
	return used
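For reference, the cost estimate in `get_community_usage_cost` boils down to a price-per-1K-tokens sum over the per-model token counters kept in the usage stats. A minimal, self-contained sketch of that arithmetic (prices copied from the constants above; `usage_cost` and the `data` dict are illustrative, not part of this module):

```python
# Illustration of the per-model cost arithmetic in get_community_usage_cost().
# Prices are USD per 1K tokens, matching the values hardcoded above.
PRICE_PER_1K = {
    'gpt-4': 0.04,                     # rough blend: prompt 0.03, output 0.06
    'text-davinci-003': 0.02,
    'text-curie-001': 0.002,
    'gpt-3.5-turbo': 0.002,
    'text-embedding-ada-002': 0.0004,
}

def usage_cost(data):
    "sum the estimated USD cost of a dict of 'total_tokens:<model>' counters"
    used = 0.0
    for model, price in PRICE_PER_1K.items():
        used += price * data.get(f'total_tokens:{model}', 0) / 1000
    return used

# hypothetical usage snapshot, same shape as usage_stats.get() returns
data = {'total_tokens:gpt-3.5-turbo': 5000, 'total_tokens:text-embedding-ada-002': 10000}
print(round(usage_cost(data), 6))  # 0.014
```

Missing counters simply contribute nothing, which is why `data.get(..., 0)` is used above instead of direct indexing.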
================================================ FILE: src/cache.py ================================================ from retry import retry from binascii import hexlify,unhexlify import pickle import zlib import io import os # pip install boto3 import boto3 import botocore class Cache: "Dummy / Base Cache" def __init__(self): pass def put(self, key, obj): pass def get(self, key): return None def has(self, key): return False def delete(self, key): pass def serialize(self, obj): pickled = pickle.dumps(obj) compressed = self.compress(pickled) return compressed def deserialize(self, data): pickled = self.decompress(data) obj = pickle.loads(pickled) return obj def compress(self, data): return zlib.compress(data) def decompress(self, data): return zlib.decompress(data) def encode(self, name): return hexlify(name.encode('utf8')).decode('utf8') def decode(self, name): return unhexlify(name).decode('utf8') def call(self, key, fun, *a, **kw): if self.has(key): return self.get(key) else: resp = fun(*a, **kw) self.put(key, resp) return resp class DiskCache(Cache): "Local disk based cache" def __init__(self, root): self.root = root def path(self, key): return os.path.join(self.root, self.encode(key)) def put(self, key, obj): path = self.path(key) data = self.serialize(obj) with open(path, 'wb') as f: f.write(data) def get(self, key): path = self.path(key) with open(path, 'rb') as f: data = f.read() obj = self.deserialize(data) return obj def has(self, key): path = self.path(key) return os.path.exists(path) def delete(self, key): path = self.path(key) os.remove(path) class S3Cache(Cache): "S3 based cache" def __init__(self, **kw): bucket = kw.get('bucket') or os.getenv('S3_CACHE_BUCKET','ask-my-pdf') prefix = kw.get('prefix') or os.getenv('S3_CACHE_PREFIX','cache/x1') region = kw.get('region') or os.getenv('S3_REGION','sfo3') url = kw.get('url') or os.getenv('S3_URL',f'https://{region}.digitaloceanspaces.com') key = os.getenv('S3_KEY','') secret = os.getenv('S3_SECRET','') # if 
not key or not secret: raise Exception("No S3 credentials in environment variables!") # self.session = boto3.session.Session() self.s3 = self.session.client('s3', config=botocore.config.Config(s3={'addressing_style': 'virtual'}), region_name=region, endpoint_url=url, aws_access_key_id=key, aws_secret_access_key=secret, ) self.bucket = bucket self.prefix = prefix def get_s3_key(self, key): return f'{self.prefix}/{key}' def put(self, key, obj): s3_key = self.get_s3_key(key) data = self.serialize(obj) f = io.BytesIO(data) self.s3.upload_fileobj(f, self.bucket, s3_key) def get(self, key, default=None): s3_key = self.get_s3_key(key) f = io.BytesIO() try: self.s3.download_fileobj(self.bucket, s3_key, f) except: f.close() return default f.seek(0) data = f.read() obj = self.deserialize(data) return obj def has(self, key): s3_key = self.get_s3_key(key) try: self.s3.head_object(Bucket=self.bucket, Key=s3_key) return True except: return False def delete(self, key): self.s3.delete_object( Bucket = self.bucket, Key = self.get_s3_key(key)) def get_cache(**kw): mode = os.getenv('CACHE_MODE','').upper() path = os.getenv('CACHE_PATH','') if mode == 'DISK': return DiskCache(path) elif mode == 'S3': return S3Cache(**kw) else: return Cache() if __name__=="__main__": #cache = DiskCache('__pycache__') cache = S3Cache() cache.put('xxx',{'a':1,'b':22}) print('get xxx', cache.get('xxx')) print('has xxx', cache.has('xxx')) print('has yyy', cache.has('yyy')) print('delete xxx', cache.delete('xxx')) print('has xxx', cache.has('xxx')) print('get xxx', cache.get('xxx')) # ================================================ FILE: src/css.py ================================================ v1 = """ /* feedback checkbox */ .css-18fuwiq { position: relative; padding-top: 6px; } .css-949r0i { position: relative; padding-top: 6px; } """ ================================================ FILE: src/feedback.py ================================================ import datetime import hashlib import redis 
import os from retry import retry def hexdigest(text): return hashlib.md5(text.encode('utf8')).hexdigest() def as_int(x): return int(x) if x is not None else None class Feedback: "Dummy feedback adapter" def __init__(self, user): ... def send(self, score, ctx, details=False): ... def get_score(self): return 0 class RedisFeedback(Feedback): "Redis feedback adapter" def __init__(self, user): REDIS_URL = os.getenv('REDIS_URL') if not REDIS_URL: raise Exception('No Redis configuration in environment variables!') super().__init__(user) self.db = redis.Redis.from_url(REDIS_URL) self.user = user @retry(tries=5, delay=0.1) def send(self, score, ctx, details=False): p = self.db.pipeline() dist_list = ctx.get('debug',{}).get('model.query.resp',{}).get('dist_list',[]) # feedback index = ctx.get('index',{}) data = {} data['user'] = self.user data['task-prompt-version'] = ctx.get('task_name') data['model'] = ctx.get('model') data['model-embeddings'] = ctx.get('model_embed') data['task-prompt'] = ctx.get('task') data['temperature'] = ctx.get('temperature') data['frag-size'] = ctx.get('frag_size') data['frag-cnt'] = ctx.get('max_frags') data['frag-n-before'] = ctx.get('n_frag_before') data['frag-n-after'] = ctx.get('n_frag_after') data['filename'] = ctx.get('filename') data['filehash'] = index.get('hash') or index.get('filehash') data['filesize'] = index.get('filesize') data['n-pages'] = index.get('n_pages') data['n-texts'] = index.get('n_texts') data['use-hyde'] = as_int(ctx.get('use_hyde')) data['use-hyde-summary'] = as_int(ctx.get('use_hyde_summary')) data['question'] = ctx.get('question') data['answer'] = ctx.get('answer') data['hyde-summary'] = index.get('summary') data['resp-dist-list'] = '|'.join([f"{x:0.3f}" for x in dist_list]) fb_hash = hexdigest(str(list(sorted(data.items())))) # data['score'] = score data['datetime'] = str(datetime.datetime.now()) key1 = f'feedback:v2:{fb_hash}' if not details: for k in ['question','answer','hyde-summary']: data[k] = '' p.hset(key1, 
mapping=data) # feedback-daily date = datetime.date.today() key2 = f'feedback-daily:v1:{date}:{"positive" if score > 0 else "negative"}' p.sadd(key2, fb_hash) # feedback-score key3 = f'feedback-score:v2:{self.user}' p.sadd(key3, fb_hash) p.execute() @retry(tries=5, delay=0.1) def get_score(self): key = f'feedback-score:v2:{self.user}' return self.db.scard(key) def get_feedback_adapter(user): MODE = os.getenv('FEEDBACK_MODE','').upper() if MODE=='REDIS': return RedisFeedback(user) else: return Feedback(user) ================================================ FILE: src/gui.py ================================================ __version__ = "0.4.8.3" app_name = "Ask my PDF" # BOILERPLATE import streamlit as st st.set_page_config(layout='centered', page_title=f'{app_name} {__version__}') ss = st.session_state if 'debug' not in ss: ss['debug'] = {} import css st.write(f'', unsafe_allow_html=True) header1 = st.empty() # for errors / messages header2 = st.empty() # for errors / messages header3 = st.empty() # for errors / messages # IMPORTS import prompts import model import storage import feedback import cache import os from time import time as now # HANDLERS def on_api_key_change(): api_key = ss.get('api_key') or os.getenv('OPENAI_KEY') model.use_key(api_key) # TODO: empty api_key # if 'data_dict' not in ss: ss['data_dict'] = {} # used only with DictStorage ss['storage'] = storage.get_storage(api_key, data_dict=ss['data_dict']) ss['cache'] = cache.get_cache() ss['user'] = ss['storage'].folder # TODO: refactor user 'calculation' from get_storage model.set_user(ss['user']) ss['feedback'] = feedback.get_feedback_adapter(ss['user']) ss['feedback_score'] = ss['feedback'].get_score() # ss['debug']['storage.folder'] = ss['storage'].folder ss['debug']['storage.class'] = ss['storage'].__class__.__name__ ss['community_user'] = os.getenv('COMMUNITY_USER') if 'user' not in ss and ss['community_user']: on_api_key_change() # use community key # COMPONENTS def ui_spacer(n=2, line=False, 
next_n=0): for _ in range(n): st.write('') if line: st.tabs([' ']) for _ in range(next_n): st.write('') def ui_info(): st.markdown(f""" # Ask my PDF version {__version__} Question answering system built on top of GPT3. """) ui_spacer(1) st.write("Made by [Maciej Obarski](https://www.linkedin.com/in/mobarski/).", unsafe_allow_html=True) ui_spacer(1) st.markdown(""" Thank you for your interest in my application. Please be aware that this is only a Proof of Concept system and may contain bugs or unfinished features. If you like this app you can ❤️ [follow me](https://twitter.com/KerbalFPV) on Twitter for news and updates. """) ui_spacer(1) st.markdown('Source code can be found [here](https://github.com/mobarski/ask-my-pdf).') def ui_api_key(): if ss['community_user']: st.write('## 1. Optional - enter your OpenAI API key') t1,t2 = st.tabs(['community version','enter your own API key']) with t1: pct = model.community_tokens_available_pct() st.write(f'Community tokens available: :{"green" if pct else "red"}[{int(pct)}%]') st.progress(pct/100) st.write('Refresh in: ' + model.community_tokens_refresh_in()) st.write('You can sign up to OpenAI and/or create your API key [here](https://platform.openai.com/account/api-keys)') ss['community_pct'] = pct ss['debug']['community_pct'] = pct with t2: st.text_input('OpenAI API key', type='password', key='api_key', on_change=on_api_key_change, label_visibility="collapsed") else: st.write('## 1. 
Enter your OpenAI API key')
		st.text_input('OpenAI API key', type='password', key='api_key', on_change=on_api_key_change, label_visibility="collapsed")

def index_pdf_file():
	if ss['pdf_file']:
		ss['filename'] = ss['pdf_file'].name
		if ss['filename'] != ss.get('filename_done'): # UGLY
			with st.spinner(f'indexing {ss["filename"]}'):
				index = model.index_file(ss['pdf_file'], ss['filename'], fix_text=ss['fix_text'], frag_size=ss['frag_size'], cache=ss['cache'])
				ss['index'] = index
				debug_index()
				ss['filename_done'] = ss['filename'] # UGLY

def debug_index():
	index = ss['index']
	d = {}
	d['hash'] = index['hash']
	d['frag_size'] = index['frag_size']
	d['n_pages'] = len(index['pages'])
	d['n_texts'] = len(index['texts'])
	d['summary'] = index['summary']
	d['pages'] = index['pages']
	d['texts'] = index['texts']
	d['time'] = index.get('time',{})
	ss['debug']['index'] = d

def ui_pdf_file():
	st.write('## 2. Upload or select your PDF file')
	disabled = not ss.get('user') or (not ss.get('api_key') and not ss.get('community_pct',0))
	t1,t2 = st.tabs(['UPLOAD','SELECT'])
	with t1:
		st.file_uploader('pdf file', type='pdf', key='pdf_file', disabled=disabled, on_change=index_pdf_file, label_visibility="collapsed")
		b_save()
	with t2:
		filenames = ['']
		if ss.get('storage'):
			filenames += ss['storage'].list()
		def on_change():
			name = ss['selected_file']
			if name and ss.get('storage'):
				with ss['spin_select_file']:
					with st.spinner('loading index'):
						t0 = now()
						index = ss['storage'].get(name)
						ss['debug']['storage_get_time'] = now()-t0
				ss['filename'] = name # XXX
				ss['index'] = index
				debug_index()
			else:
				#ss['index'] = {}
				pass
		st.selectbox('select file', filenames, on_change=on_change, key='selected_file', label_visibility="collapsed", disabled=disabled)
		b_delete()
		ss['spin_select_file'] = st.empty()

def ui_show_debug():
	st.checkbox('show debug section', key='show_debug')

def ui_fix_text():
	st.checkbox('fix common PDF problems', value=True, key='fix_text')

def ui_temperature():
	#st.slider('temperature', 0.0, 1.0,
0.0, 0.1, key='temperature', format='%0.1f') ss['temperature'] = 0.0 def ui_fragments(): #st.number_input('fragment size', 0,2000,200, step=100, key='frag_size') st.selectbox('fragment size (characters)', [0,200,300,400,500,600,700,800,900,1000], index=3, key='frag_size') b_reindex() st.number_input('max fragments', 1, 10, 4, key='max_frags') st.number_input('fragments before', 0, 3, 1, key='n_frag_before') # TODO: pass to model st.number_input('fragments after', 0, 3, 1, key='n_frag_after') # TODO: pass to model def ui_model(): models = ['gpt-3.5-turbo','gpt-4','text-davinci-003','text-curie-001'] st.selectbox('main model', models, key='model', disabled=not ss.get('api_key')) st.selectbox('embedding model', ['text-embedding-ada-002'], key='model_embed') # FOR FUTURE USE def ui_hyde(): st.checkbox('use HyDE', value=True, key='use_hyde') def ui_hyde_summary(): st.checkbox('use summary in HyDE', value=True, key='use_hyde_summary') def ui_task_template(): st.selectbox('task prompt template', prompts.TASK.keys(), key='task_name') def ui_task(): x = ss['task_name'] st.text_area('task prompt', prompts.TASK[x], key='task') def ui_hyde_prompt(): st.text_area('HyDE prompt', prompts.HYDE, key='hyde_prompt') def ui_question(): st.write('## 3. 
Ask questions'+(f' to {ss["filename"]}' if ss.get('filename') else '')) disabled = False st.text_area('question', key='question', height=100, placeholder='Enter question here', help='', label_visibility="collapsed", disabled=disabled) # REF: Hypotetical Document Embeddings def ui_hyde_answer(): # TODO: enter or generate pass def ui_output(): output = ss.get('output','') st.markdown(output) def ui_debug(): if ss.get('show_debug'): st.write('### debug') st.write(ss.get('debug',{})) def b_ask(): c1,c2,c3,c4,c5 = st.columns([2,1,1,2,2]) if c2.button('👍', use_container_width=True, disabled=not ss.get('output')): ss['feedback'].send(+1, ss, details=ss['send_details']) ss['feedback_score'] = ss['feedback'].get_score() if c3.button('👎', use_container_width=True, disabled=not ss.get('output')): ss['feedback'].send(-1, ss, details=ss['send_details']) ss['feedback_score'] = ss['feedback'].get_score() score = ss.get('feedback_score',0) c5.write(f'feedback score: {score}') c4.checkbox('send details', True, key='send_details', help='allow question and the answer to be stored in the ask-my-pdf feedback database') #c1,c2,c3 = st.columns([1,3,1]) #c2.radio('zzz',['👍',r'...',r'👎'],horizontal=True,label_visibility="collapsed") # disabled = (not ss.get('api_key') and not ss.get('community_pct',0)) or not ss.get('index') if c1.button('get answer', disabled=disabled, type='primary', use_container_width=True): question = ss.get('question','') temperature = ss.get('temperature', 0.0) hyde = ss.get('use_hyde') hyde_prompt = ss.get('hyde_prompt') if ss.get('use_hyde_summary'): summary = ss['index']['summary'] hyde_prompt += f" Context: {summary}\n\n" task = ss.get('task') max_frags = ss.get('max_frags',1) n_before = ss.get('n_frag_before',0) n_after = ss.get('n_frag_after',0) index = ss.get('index',{}) with st.spinner('preparing answer'): resp = model.query(question, index, task=task, temperature=temperature, hyde=hyde, hyde_prompt=hyde_prompt, max_frags=max_frags, limit=max_frags+2, 
n_before=n_before, n_after=n_after, model=ss['model'], ) usage = resp.get('usage',{}) usage['cnt'] = 1 ss['debug']['model.query.resp'] = resp ss['debug']['resp.usage'] = usage ss['debug']['model.vector_query_time'] = resp['vector_query_time'] q = question.strip() a = resp['text'].strip() ss['answer'] = a output_add(q,a) st.experimental_rerun() # to enable the feedback buttons def b_clear(): if st.button('clear output'): ss['output'] = '' def b_reindex(): # TODO: disabled if st.button('reindex'): index_pdf_file() def b_reload(): if st.button('reload prompts'): import importlib importlib.reload(prompts) def b_save(): db = ss.get('storage') index = ss.get('index') name = ss.get('filename') api_key = ss.get('api_key') disabled = not api_key or not db or not index or not name help = "The file will be stored for about 90 days. Available only when using your own API key." if st.button('save encrypted index in ask-my-pdf', disabled=disabled, help=help): with st.spinner('saving to ask-my-pdf'): db.put(name, index) def b_delete(): db = ss.get('storage') name = ss.get('selected_file') # TODO: confirm delete if st.button('delete from ask-my-pdf', disabled=not db or not name): with st.spinner('deleting from ask-my-pdf'): db.delete(name) #st.experimental_rerun() def output_add(q,a): if 'output' not in ss: ss['output'] = '' q = q.replace('$',r'\$') a = a.replace('$',r'\$') new = f'#### {q}\n{a}\n\n' ss['output'] = new + ss['output'] # LAYOUT with st.sidebar: ui_info() ui_spacer(2) with st.expander('advanced'): ui_show_debug() b_clear() ui_model() ui_fragments() ui_fix_text() ui_hyde() ui_hyde_summary() ui_temperature() b_reload() ui_task_template() ui_task() ui_hyde_prompt() ui_api_key() ui_pdf_file() ui_question() ui_hyde_answer() b_ask() ui_output() ui_debug() ================================================ FILE: src/model.py ================================================ from sklearn.metrics.pairwise import cosine_distances import datetime from collections import Counter 
from time import time as now import hashlib import re import io import os import pdf import ai def use_key(api_key): ai.use_key(api_key) def set_user(user): ai.set_user(user) def query_by_vector(vector, index, limit=None): "return (ids, distances and texts) sorted by cosine distance" vectors = index['vectors'] texts = index['texts'] # sim = cosine_distances([vector], vectors)[0] # id_dist_list = list(enumerate(sim)) id_dist_list.sort(key=lambda x:x[1]) id_list = [x[0] for x in id_dist_list][:limit] dist_list = [x[1] for x in id_dist_list][:limit] text_list = [texts[x] for x in id_list] if texts else ['ERROR']*len(id_list) return id_list, dist_list, text_list def get_vectors(text_list): "transform texts into embedding vectors" batch_size = 128 vectors = [] usage = Counter() for i,texts in enumerate(batch(text_list, batch_size)): resp = ai.embeddings(texts) v = resp['vectors'] u = resp['usage'] u['call_cnt'] = 1 usage.update(u) vectors.extend(v) return {'vectors':vectors, 'usage':dict(usage), 'model':resp['model']} def index_file(f, filename, fix_text=False, frag_size=0, cache=None): "return vector index (dictionary) for a given PDF file" # calc md5 h = hashlib.md5() h.update(f.read()) md5 = h.hexdigest() filesize = f.tell() f.seek(0) # t0 = now() pages = pdf.pdf_to_pages(f) t1 = now() if fix_text: for i in range(len(pages)): pages[i] = fix_text_problems(pages[i]) texts = split_pages_into_fragments(pages, frag_size) t2 = now() if cache: cache_key = f'get_vectors:{md5}:{frag_size}:{fix_text}' resp = cache.call(cache_key, get_vectors, texts) else: resp = get_vectors(texts) t3 = now() vectors = resp['vectors'] summary_prompt = f"{texts[0]}\n\nDescribe the document from which the fragment is extracted. 
Omit any details.\n\n" # TODO: move to prompts.py summary = ai.complete(summary_prompt) t4 = now() usage = resp['usage'] out = {} out['frag_size'] = frag_size out['n_pages'] = len(pages) out['n_texts'] = len(texts) out['texts'] = texts out['pages'] = pages out['vectors'] = vectors out['summary'] = summary['text'] out['filename'] = filename out['filehash'] = f'md5:{md5}' out['filesize'] = filesize out['usage'] = usage out['model'] = resp['model'] out['time'] = {'pdf_to_pages':t1-t0, 'split_pages':t2-t1, 'get_vectors':t3-t2, 'summary':t4-t3} out['size'] = len(texts) # DEPRECATED -> filesize out['hash'] = f'md5:{md5}' # DEPRECATED -> filehash return out def split_pages_into_fragments(pages, frag_size): "split pages (list of texts) into smaller fragments (list of texts)" page_offset = [0] for p,page in enumerate(pages): page_offset += [page_offset[-1]+len(page)+1] # TODO: del page_offset[-1] ??? if frag_size: text = ' '.join(pages) return text_to_fragments(text, frag_size, page_offset) else: return pages def text_to_fragments(text, size, page_offset): "split single text into smaller fragments (list of texts)" if size and len(text)>size: out = [] pos = 0 page = 1 p_off = page_offset.copy()[1:] eos = find_eos(text) if len(text) not in eos: eos += [len(text)] for i in range(len(eos)): if eos[i]-pos>size: text_fragment = f'PAGE({page}):\n'+text[pos:eos[i]] out += [text_fragment] pos = eos[i] if eos[i]>p_off[0]: page += 1 del p_off[0] # ugly: last iter text_fragment = f'PAGE({page}):\n'+text[pos:eos[i]] out += [text_fragment] # out = [x for x in out if x] return out else: return [text] def find_eos(text): "return list of all end-of-sentence offsets" return [x.span()[1] for x in re.finditer('[.!?。]\s+',text)] ############################################################################### def fix_text_problems(text): "fix common text problems" text = re.sub('\s+[-]\s+','',text) # word continuation in the next line return text def query(text, index, task=None, temperature=0.0, 
max_frags=1, hyde=False, hyde_prompt=None, limit=None, n_before=1, n_after=1, model=None): "get dictionary with the answer for the given question (text)." out = {} if hyde: # TODO: model param out['hyde'] = hypotetical_answer(text, index, hyde_prompt=hyde_prompt, temperature=temperature) # TODO: usage # RANK FRAGMENTS if hyde: resp = ai.embedding(out['hyde']['text']) # TODO: usage else: resp = ai.embedding(text) # TODO: usage v = resp['vector'] t0 = now() id_list, dist_list, text_list = query_by_vector(v, index, limit=limit) dt0 = now()-t0 # BUILD PROMPT # select fragments N_BEFORE = 1 # TODO: param N_AFTER = 1 # TODO: param selected = {} # text id -> rank for rank,id in enumerate(id_list): for x in range(id-n_before, id+1+n_after): if x not in selected and x>=0 and x safe exceptions key = self.render(key) p = self.db.pipeline() for member,val in kv_dict.items(): member = self.render(member) self.db.zincrby(key, val, member) p.execute() @retry(tries=5, delay=0.1) def get(self, key): # TODO: non critical code -> safe exceptions key = self.render(key) items = self.db.zscan_iter(key) return {k.decode('utf8'):v for k,v in items} stats_data_dict = {} def get_stats(**kw): MODE = os.getenv('STATS_MODE','').upper() if MODE=='REDIS': stats = RedisStats() else: stats = DictStats(stats_data_dict) stats.config.update(kw) return stats if __name__=="__main__": s1 = get_stats(user='maciek') s1.incr('aaa:[date]:[user]', dict(a=1,b=2)) s1.incr('aaa:[date]:[user]', dict(a=1,b=2)) print(s1.data) print(s1.get('aaa:[date]:[user]')) # s2 = get_stats(user='kerbal') s2.incr('aaa:[date]:[user]', dict(a=1,b=2)) s2.incr('aaa:[date]:[user]', dict(a=1,b=2)) print(s2.data) print(s2.get('aaa:[date]:[user]')) ================================================ FILE: src/storage.py ================================================ "Storage adapter - one folder for each user / api_key" # pip install pycryptodome # REF: https://www.pycryptodome.org/src/cipher/aes from Crypto.Cipher import AES from 
Crypto.Util.Padding import pad,unpad from retry import retry from binascii import hexlify,unhexlify import hashlib import pickle import zlib import os import io # pip install boto3 import boto3 import botocore SALT = unhexlify(os.getenv('STORAGE_SALT','00')) class Storage: "Encrypted object storage (base class)" def __init__(self, secret_key): k = secret_key.encode() self.folder = hashlib.blake2s(k, salt=SALT, person=b'folder', digest_size=8).hexdigest() self.passwd = hashlib.blake2s(k, salt=SALT, person=b'passwd', digest_size=32).hexdigest() self.AES_MODE = AES.MODE_ECB # TODO: better AES mode ??? self.AES_BLOCK_SIZE = 16 def get(self, name, default=None): "get one object from the folder" safe_name = self.encode(name) data = self._get(safe_name) obj = self.deserialize(data) return obj def put(self, name, obj): "put the object into the folder" safe_name = self.encode(name) data = self.serialize(obj) self._put(safe_name, data) return data def list(self): "list object names from the folder" return [self.decode(name) for name in self._list()] def delete(self, name): "delete the object from the folder" safe_name = self.encode(name) self._delete(safe_name) # IMPLEMENTED IN SUBCLASSES def _put(self, name, data): ... def _get(self, name): ... def _delete(self, name): pass def _list(self): ... 
# # # def serialize(self, obj): raw = pickle.dumps(obj) compressed = self.compress(raw) encrypted = self.encrypt(compressed) return encrypted def deserialize(self, encrypted): compressed = self.decrypt(encrypted) raw = self.decompress(compressed) obj = pickle.loads(raw) return obj def encrypt(self, raw): cipher = AES.new(unhexlify(self.passwd), self.AES_MODE) return cipher.encrypt(pad(raw, self.AES_BLOCK_SIZE)) def decrypt(self, encrypted): cipher = AES.new(unhexlify(self.passwd), self.AES_MODE) return unpad(cipher.decrypt(encrypted), self.AES_BLOCK_SIZE) def compress(self, data): return zlib.compress(data) def decompress(self, data): return zlib.decompress(data) def encode(self, name): return hexlify(name.encode('utf8')).decode('utf8') def decode(self, name): return unhexlify(name).decode('utf8') class DictStorage(Storage): "Dictionary based storage" def __init__(self, secret_key, data_dict): super().__init__(secret_key) self.data = data_dict def _put(self, name, data): if self.folder not in self.data: self.data[self.folder] = {} self.data[self.folder][name] = data def _get(self, name): return self.data[self.folder][name] def _list(self): # TODO: sort by modification time (reverse=True) return list(self.data.get(self.folder,{}).keys()) def _delete(self, name): del self.data[self.folder][name] class LocalStorage(Storage): "Local filesystem based storage" def __init__(self, secret_key, path): if not path: raise Exception('No storage path in environment variables!') super().__init__(secret_key) self.path = os.path.join(path, self.folder) if not os.path.exists(self.path): os.makedirs(self.path) def _put(self, name, data): with open(os.path.join(self.path, name), 'wb') as f: f.write(data) def _get(self, name): with open(os.path.join(self.path, name), 'rb') as f: data = f.read() return data def _list(self): # TODO: sort by modification time (reverse=True) return os.listdir(self.path) def _delete(self, name): os.remove(os.path.join(self.path, name)) class 
S3Storage(Storage): "S3 based encrypted storage" def __init__(self, secret_key, **kw): prefix = kw.get('prefix') or os.getenv('S3_PREFIX','index/x1') region = kw.get('region') or os.getenv('S3_REGION','sfo3') bucket = kw.get('bucket') or os.getenv('S3_BUCKET','ask-my-pdf') url = kw.get('url') or os.getenv('S3_URL',f'https://{region}.digitaloceanspaces.com') key = os.getenv('S3_KEY','') secret = os.getenv('S3_SECRET','') # if not key or not secret: raise Exception("No S3 credentials in environment variables!") # super().__init__(secret_key) self.session = boto3.session.Session() self.s3 = self.session.client('s3', config=botocore.config.Config(s3={'addressing_style': 'virtual'}), region_name=region, endpoint_url=url, aws_access_key_id=key, aws_secret_access_key=secret, ) self.bucket = bucket self.prefix = prefix def get_key(self, name): return f'{self.prefix}/{self.folder}/{name}' def _put(self, name, data): key = self.get_key(name) f = io.BytesIO(data) self.s3.upload_fileobj(f, self.bucket, key) def _get(self, name): key = self.get_key(name) f = io.BytesIO() self.s3.download_fileobj(self.bucket, key, f) f.seek(0) return f.read() def _list(self): resp = self.s3.list_objects( Bucket=self.bucket, Prefix=self.get_key('') ) contents = resp.get('Contents',[]) contents.sort(key=lambda x:x['LastModified'], reverse=True) keys = [x['Key'] for x in contents] names = [x.split('/')[-1] for x in keys] return names def _delete(self, name): self.s3.delete_object( Bucket=self.bucket, Key=self.get_key(name) ) def get_storage(api_key, data_dict): "get storage adapter configured in environment variables" mode = os.getenv('STORAGE_MODE','').upper() path = os.getenv('STORAGE_PATH','') if mode=='S3': storage = S3Storage(api_key) elif mode=='LOCAL': storage = LocalStorage(api_key, path) else: storage = DictStorage(api_key, data_dict) return storage
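To make the storage scheme above easier to follow: `Storage` derives a per-user folder name and an AES password from the API key via salted BLAKE2s, and hex-encodes object names so they are safe as file or S3 key components. A stdlib-only sketch under those assumptions (the `'00'` salt mirrors the `STORAGE_SALT` default; the pycryptodome encryption step is omitted here):

```python
import hashlib
from binascii import hexlify, unhexlify

SALT = unhexlify('00')  # example value; the real salt comes from STORAGE_SALT

def derive_folder_and_passwd(secret_key):
    "derive the per-user folder name and AES password, as Storage.__init__ does"
    k = secret_key.encode()
    folder = hashlib.blake2s(k, salt=SALT, person=b'folder', digest_size=8).hexdigest()
    passwd = hashlib.blake2s(k, salt=SALT, person=b'passwd', digest_size=32).hexdigest()
    return folder, passwd

def encode_name(name):
    "hex-encode an object name so it is filesystem/S3 safe (Storage.encode)"
    return hexlify(name.encode('utf8')).decode('utf8')

def decode_name(name):
    "inverse of encode_name (Storage.decode)"
    return unhexlify(name).decode('utf8')

folder, passwd = derive_folder_and_passwd('sk-example')
print(len(folder), len(passwd))               # 16 64 (hex digits for 8 and 32 bytes)
print(decode_name(encode_name('rules.pdf')))  # rules.pdf
```

Because the folder name and password are both deterministic functions of the API key plus the salt, the same key always maps to the same encrypted folder, and no key material needs to be stored server-side.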