Repository: mobarski/ask-my-pdf
Branch: main
Commit: 470d31969331
Files: 16
Total size: 40.5 KB
Directory structure:
gitextract_zi78g5yh/
├── .gitignore
├── LICENSE
├── README.md
├── requirements.txt
└── src/
├── ai.py
├── cache.py
├── css.py
├── feedback.py
├── gui.py
├── model.py
├── pdf.py
├── prompts.py
├── run.bat
├── run.sh
├── stats.py
└── storage.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
**/__pycache__
**/*secret*.*
data/**
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 Maciej Obarski
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Ask my PDF
Thank you for your interest in my application. Please be aware that this is only a **Proof of Concept system** and may contain bugs or unfinished features. If you like this app you can ❤️ [follow me](https://twitter.com/KerbalFPV) on Twitter for news and updates.
### Ask my PDF - Question answering system built on top of GPT-3
🎲 The primary use case for this app is to assist users in answering questions about board game rules based on the instruction manual. While the app can be used for other tasks, helping users with board game rules is particularly meaningful to me since I'm an avid fan of board games myself. Additionally, this use case is relatively harmless, even in cases where the model may experience hallucinations.
🌐 The app can be accessed on the Streamlit Community Cloud at https://ask-my-pdf.streamlit.app/. 🔑 However, to use the app, you will need your own [OpenAI's API key](https://platform.openai.com/account/api-keys).
📄 The app implements the following academic papers:
- [In-Context Retrieval-Augmented Language Models](https://arxiv.org/abs/2302.00083) aka **RALM**
- [Precise Zero-Shot Dense Retrieval without Relevance Labels](https://arxiv.org/abs/2212.10496) aka **HyDE** (Hypothetical Document Embeddings)
### Installation
1. Clone the repo:
`git clone https://github.com/mobarski/ask-my-pdf`
2. Install dependencies:
`pip install -r ask-my-pdf/requirements.txt`
3. Run the app:
`cd ask-my-pdf/src`
`./run.sh` (Linux/macOS) or `run.bat` (Windows)
### High-level documentation
#### RALM + HyDE

#### RALM + HyDE + context

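The two flow diagrams for this section are image files and are not reproduced in this text dump. As a rough substitute, the retrieve-then-read loop can be sketched in a few lines of Python; `embed`, `hypothetical_answer`, and the bag-of-words scoring below are toy stand-ins for illustration, not the app's actual API:

```python
# RALM + HyDE in miniature: instead of embedding the question itself,
# embed a hypothetical answer and rank document fragments against it.

def embed(text):
    # toy stand-in for text-embedding-ada-002: bag-of-words counts
    words = text.lower().split()
    return {w: words.count(w) for w in set(words)}

def cosine_sim(a, b):
    dot = sum(a[w] * b.get(w, 0) for w in a)
    na = sum(v * v for v in a.values()) ** 0.5
    nb = sum(v * v for v in b.values()) ** 0.5
    return dot / (na * nb) if na and nb else 0.0

def hypothetical_answer(question):
    # in the real app this is a GPT completion of the HyDE prompt;
    # hard-coded here to keep the sketch self-contained
    return "each player moves their pawn the number of spaces shown on the die"

def retrieve(question, fragments, use_hyde=True):
    probe = hypothetical_answer(question) if use_hyde else question
    qv = embed(probe)
    # the best-matching fragment becomes the context for the final LLM prompt
    return max(fragments, key=lambda f: cosine_sim(qv, embed(f)))

fragments = [
    "Setup: shuffle the deck and deal five cards to each player.",
    "Movement: each player moves their pawn the number of spaces shown on the die.",
]
print(retrieve("How do I move?", fragments))  # picks the Movement fragment
```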
### Environment variables used for configuration
##### General configuration:
- **STORAGE_SALT** - cryptographic salt used when deriving the user/folder name and encryption key from the API key; hexadecimal notation, 2-16 characters
- **STORAGE_MODE** - index storage mode: S3, LOCAL, DICT (default)
- **STATS_MODE** - usage stats storage mode: REDIS, DICT (default)
- **FEEDBACK_MODE** - user feedback storage mode: REDIS, NONE (default)
- **CACHE_MODE** - embeddings cache mode: S3, DISK, NONE (default)
##### Local filesystem configuration (storage / cache):
- **STORAGE_PATH** - directory path for index storage
- **CACHE_PATH** - directory path for embeddings cache
##### S3 configuration (storage / cache):
- **S3_REGION** - region code
- **S3_BUCKET** - bucket name (storage)
- **S3_SECRET** - secret key
- **S3_KEY** - access key
- **S3_URL** - URL
- **S3_PREFIX** - object name prefix
- **S3_CACHE_BUCKET** - bucket name (cache)
- **S3_CACHE_PREFIX** - object name prefix (cache)
##### Redis configuration (for persistent usage statistics / user feedback):
- **REDIS_URL** - Redis DB URL (redis[s]://:password@host:port/[db])
##### Community version related options:
- **OPENAI_KEY** - API key used for the default user
- **COMMUNITY_DAILY_USD** - default user's daily budget
- **COMMUNITY_USER** - default user's code
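For illustration, a local development setup might combine these as follows (all values below are placeholders, not application defaults):

```shell
# Illustrative local configuration - every value here is a placeholder.
export STORAGE_MODE=LOCAL          # keep indexes on the local filesystem
export STORAGE_PATH=/tmp/ask-my-pdf/index
export STORAGE_SALT=a1b2c3d4       # hex, 2-16 characters
export CACHE_MODE=DISK             # cache embeddings between reindexing runs
export CACHE_PATH=/tmp/ask-my-pdf/cache
export STATS_MODE=DICT             # in-memory usage stats (lost on restart)
export FEEDBACK_MODE=NONE          # disable the feedback adapter
```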
================================================
FILE: requirements.txt
================================================
git+https://github.com/mobarski/ai-bricks.git
streamlit
pypdf
scikit-learn
numpy
pycryptodome
boto3
redis
retry
================================================
FILE: src/ai.py
================================================
from ai_bricks.api import openai
import stats
import os
DEFAULT_USER = os.getenv('COMMUNITY_USER','')
def use_key(key):
openai.use_key(key)
usage_stats = stats.get_stats(user=DEFAULT_USER)
def set_user(user):
global usage_stats
usage_stats = stats.get_stats(user=user)
openai.set_global('user', user)
openai.add_callback('after', stats_callback)
def complete(text, **kw):
model = kw.get('model','gpt-3.5-turbo')
llm = openai.model(model)
llm.config['pre_prompt'] = 'output only in raw text' # for chat models
resp = llm.complete(text, **kw)
resp['model'] = model
return resp
def embedding(text, **kw):
model = kw.get('model','text-embedding-ada-002')
llm = openai.model(model)
resp = llm.embed(text, **kw)
resp['model'] = model
return resp
def embeddings(texts, **kw):
model = kw.get('model','text-embedding-ada-002')
llm = openai.model(model)
resp = llm.embed_many(texts, **kw)
resp['model'] = model
return resp
tokenizer_model = openai.model('text-davinci-003')
def get_token_count(text):
return tokenizer_model.token_count(text)
def stats_callback(out, resp, self):
model = self.config['model']
usage = resp['usage']
usage['call_cnt'] = 1
if 'text' in out:
usage['completion_chars'] = len(out['text'])
elif 'texts' in out:
usage['completion_chars'] = sum([len(text) for text in out['texts']])
# TODO: prompt_chars
# TODO: total_chars
if 'rtt' in out:
usage['rtt'] = out['rtt']
usage['rtt_cnt'] = 1
usage_stats.incr(f'usage:v4:[date]:[user]', {f'{k}:{model}':v for k,v in usage.items()})
usage_stats.incr(f'hourly:v4:[date]', {f'{k}:{model}:[hour]':v for k,v in usage.items()})
#print('STATS_CALLBACK', usage, flush=True) # XXX
def get_community_usage_cost():
data = usage_stats.get(f'usage:v4:[date]:{DEFAULT_USER}')
used = 0.0
used += 0.04 * data.get('total_tokens:gpt-4',0) / 1000 # prompt_price=0.03 but output_price=0.06
used += 0.02 * data.get('total_tokens:text-davinci-003',0) / 1000
used += 0.002 * data.get('total_tokens:text-curie-001',0) / 1000
used += 0.002 * data.get('total_tokens:gpt-3.5-turbo',0) / 1000
used += 0.0004 * data.get('total_tokens:text-embedding-ada-002',0) / 1000
return used
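The pricing arithmetic above (USD per 1000 tokens times total tokens) can be checked in isolation; the token counts in this example are invented:

```python
# Stand-alone version of the cost formula in get_community_usage_cost().
# Prices are USD per 1000 tokens, matching the constants hard-coded above.
PRICE_PER_1K = {
    'gpt-4': 0.04,                    # blend of prompt (0.03) and output (0.06)
    'text-davinci-003': 0.02,
    'text-curie-001': 0.002,
    'gpt-3.5-turbo': 0.002,
    'text-embedding-ada-002': 0.0004,
}

def usage_cost(total_tokens_by_model):
    return sum(PRICE_PER_1K[m] * n / 1000
               for m, n in total_tokens_by_model.items())

cost = usage_cost({'gpt-3.5-turbo': 50_000, 'text-embedding-ada-002': 200_000})
print(round(cost, 4))  # 0.1 + 0.08 -> 0.18
```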
================================================
FILE: src/cache.py
================================================
from retry import retry
from binascii import hexlify,unhexlify
import pickle
import zlib
import io
import os
# pip install boto3
import boto3
import botocore
class Cache:
"Dummy / Base Cache"
def __init__(self):
pass
def put(self, key, obj):
pass
def get(self, key):
return None
def has(self, key):
return False
def delete(self, key):
pass
def serialize(self, obj):
pickled = pickle.dumps(obj)
compressed = self.compress(pickled)
return compressed
def deserialize(self, data):
pickled = self.decompress(data)
obj = pickle.loads(pickled)
return obj
def compress(self, data):
return zlib.compress(data)
def decompress(self, data):
return zlib.decompress(data)
def encode(self, name):
return hexlify(name.encode('utf8')).decode('utf8')
def decode(self, name):
return unhexlify(name).decode('utf8')
def call(self, key, fun, *a, **kw):
if self.has(key):
return self.get(key)
else:
resp = fun(*a, **kw)
self.put(key, resp)
return resp
class DiskCache(Cache):
"Local disk based cache"
def __init__(self, root):
self.root = root
def path(self, key):
return os.path.join(self.root, self.encode(key))
def put(self, key, obj):
path = self.path(key)
data = self.serialize(obj)
with open(path, 'wb') as f:
f.write(data)
def get(self, key):
path = self.path(key)
with open(path, 'rb') as f:
data = f.read()
obj = self.deserialize(data)
return obj
def has(self, key):
path = self.path(key)
return os.path.exists(path)
def delete(self, key):
path = self.path(key)
os.remove(path)
class S3Cache(Cache):
"S3 based cache"
def __init__(self, **kw):
bucket = kw.get('bucket') or os.getenv('S3_CACHE_BUCKET','ask-my-pdf')
prefix = kw.get('prefix') or os.getenv('S3_CACHE_PREFIX','cache/x1')
region = kw.get('region') or os.getenv('S3_REGION','sfo3')
url = kw.get('url') or os.getenv('S3_URL',f'https://{region}.digitaloceanspaces.com')
key = os.getenv('S3_KEY','')
secret = os.getenv('S3_SECRET','')
#
if not key or not secret:
raise Exception("No S3 credentials in environment variables!")
#
self.session = boto3.session.Session()
self.s3 = self.session.client('s3',
config=botocore.config.Config(s3={'addressing_style': 'virtual'}),
region_name=region,
endpoint_url=url,
aws_access_key_id=key,
aws_secret_access_key=secret,
)
self.bucket = bucket
self.prefix = prefix
def get_s3_key(self, key):
return f'{self.prefix}/{key}'
def put(self, key, obj):
s3_key = self.get_s3_key(key)
data = self.serialize(obj)
f = io.BytesIO(data)
self.s3.upload_fileobj(f, self.bucket, s3_key)
def get(self, key, default=None):
s3_key = self.get_s3_key(key)
f = io.BytesIO()
try:
self.s3.download_fileobj(self.bucket, s3_key, f)
except Exception:
f.close()
return default
f.seek(0)
data = f.read()
obj = self.deserialize(data)
return obj
def has(self, key):
s3_key = self.get_s3_key(key)
try:
self.s3.head_object(Bucket=self.bucket, Key=s3_key)
return True
except Exception:
return False
def delete(self, key):
self.s3.delete_object(
Bucket = self.bucket,
Key = self.get_s3_key(key))
def get_cache(**kw):
mode = os.getenv('CACHE_MODE','').upper()
path = os.getenv('CACHE_PATH','')
if mode == 'DISK':
return DiskCache(path)
elif mode == 'S3':
return S3Cache(**kw)
else:
return Cache()
if __name__=="__main__":
#cache = DiskCache('__pycache__')
cache = S3Cache()
cache.put('xxx',{'a':1,'b':22})
print('get xxx', cache.get('xxx'))
print('has xxx', cache.has('xxx'))
print('has yyy', cache.has('yyy'))
print('delete xxx', cache.delete('xxx'))
print('has xxx', cache.has('xxx'))
print('get xxx', cache.get('xxx'))
#
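The `call` method turns any cache subclass into a memoizer: the app uses it to skip re-embedding a PDF whose hash, fragment size, and fix-text flag have been seen before. A self-contained sketch of the same pattern (a trimmed-down stand-in, not an import of this module):

```python
import os
import pickle
import tempfile
from binascii import hexlify

# Minimal DiskCache-style memoizer: compute once, persist the pickled
# result, and return the cached copy on every later call.
class MiniDiskCache:
    def __init__(self, root):
        self.root = root
    def path(self, key):
        return os.path.join(self.root, hexlify(key.encode()).decode())
    def has(self, key):
        return os.path.exists(self.path(key))
    def put(self, key, obj):
        with open(self.path(key), 'wb') as f:
            pickle.dump(obj, f)
    def get(self, key):
        with open(self.path(key), 'rb') as f:
            return pickle.load(f)
    def call(self, key, fun, *a, **kw):
        if self.has(key):
            return self.get(key)
        resp = fun(*a, **kw)
        self.put(key, resp)
        return resp

calls = []
def expensive(x):
    calls.append(x)
    return x * 2

cache = MiniDiskCache(tempfile.mkdtemp())
print(cache.call('k', expensive, 21))  # computes and stores: 42
print(cache.call('k', expensive, 21))  # served from disk:    42
print(len(calls))                      # the function ran only once: 1
```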
================================================
FILE: src/css.py
================================================
v1 = """
/* feedback checkbox */
.css-18fuwiq {
position: relative;
padding-top: 6px;
}
.css-949r0i {
position: relative;
padding-top: 6px;
}
"""
================================================
FILE: src/feedback.py
================================================
import datetime
import hashlib
import redis
import os
from retry import retry
def hexdigest(text):
return hashlib.md5(text.encode('utf8')).hexdigest()
def as_int(x):
return int(x) if x is not None else None
class Feedback:
"Dummy feedback adapter"
def __init__(self, user):
...
def send(self, score, ctx, details=False):
...
def get_score(self):
return 0
class RedisFeedback(Feedback):
"Redis feedback adapter"
def __init__(self, user):
REDIS_URL = os.getenv('REDIS_URL')
if not REDIS_URL:
raise Exception('No Redis configuration in environment variables!')
super().__init__(user)
self.db = redis.Redis.from_url(REDIS_URL)
self.user = user
@retry(tries=5, delay=0.1)
def send(self, score, ctx, details=False):
p = self.db.pipeline()
dist_list = ctx.get('debug',{}).get('model.query.resp',{}).get('dist_list',[])
# feedback
index = ctx.get('index',{})
data = {}
data['user'] = self.user
data['task-prompt-version'] = ctx.get('task_name')
data['model'] = ctx.get('model')
data['model-embeddings'] = ctx.get('model_embed')
data['task-prompt'] = ctx.get('task')
data['temperature'] = ctx.get('temperature')
data['frag-size'] = ctx.get('frag_size')
data['frag-cnt'] = ctx.get('max_frags')
data['frag-n-before'] = ctx.get('n_frag_before')
data['frag-n-after'] = ctx.get('n_frag_after')
data['filename'] = ctx.get('filename')
data['filehash'] = index.get('hash') or index.get('filehash')
data['filesize'] = index.get('filesize')
data['n-pages'] = index.get('n_pages')
data['n-texts'] = index.get('n_texts')
data['use-hyde'] = as_int(ctx.get('use_hyde'))
data['use-hyde-summary'] = as_int(ctx.get('use_hyde_summary'))
data['question'] = ctx.get('question')
data['answer'] = ctx.get('answer')
data['hyde-summary'] = index.get('summary')
data['resp-dist-list'] = '|'.join([f"{x:0.3f}" for x in dist_list])
fb_hash = hexdigest(str(list(sorted(data.items()))))
#
data['score'] = score
data['datetime'] = str(datetime.datetime.now())
key1 = f'feedback:v2:{fb_hash}'
if not details:
for k in ['question','answer','hyde-summary']:
data[k] = ''
p.hset(key1, mapping=data)
# feedback-daily
date = datetime.date.today()
key2 = f'feedback-daily:v1:{date}:{"positive" if score > 0 else "negative"}'
p.sadd(key2, fb_hash)
# feedback-score
key3 = f'feedback-score:v2:{self.user}'
p.sadd(key3, fb_hash)
p.execute()
@retry(tries=5, delay=0.1)
def get_score(self):
key = f'feedback-score:v2:{self.user}'
return self.db.scard(key)
def get_feedback_adapter(user):
MODE = os.getenv('FEEDBACK_MODE','').upper()
if MODE=='REDIS':
return RedisFeedback(user)
else:
return Feedback(user)
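Note that `fb_hash` is computed over the sorted content fields before `score` and `datetime` are added, so the same feedback resubmitted later maps to the same Redis key instead of creating a duplicate. The hashing step in isolation:

```python
import hashlib

def hexdigest(text):
    return hashlib.md5(text.encode('utf8')).hexdigest()

def feedback_hash(data):
    # hash only the content fields, sorted so dict order does not matter
    return hexdigest(str(sorted(data.items())))

a = feedback_hash({'question': 'How many players?', 'answer': '2-4', 'model': 'gpt-3.5-turbo'})
b = feedback_hash({'model': 'gpt-3.5-turbo', 'answer': '2-4', 'question': 'How many players?'})
print(a == b)  # same fields in a different order hash identically: True
```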
================================================
FILE: src/gui.py
================================================
__version__ = "0.4.8.3"
app_name = "Ask my PDF"
# BOILERPLATE
import streamlit as st
st.set_page_config(layout='centered', page_title=f'{app_name} {__version__}')
ss = st.session_state
if 'debug' not in ss: ss['debug'] = {}
import css
st.write(f'<style>{css.v1}</style>', unsafe_allow_html=True)
header1 = st.empty() # for errors / messages
header2 = st.empty() # for errors / messages
header3 = st.empty() # for errors / messages
# IMPORTS
import prompts
import model
import storage
import feedback
import cache
import os
from time import time as now
# HANDLERS
def on_api_key_change():
api_key = ss.get('api_key') or os.getenv('OPENAI_KEY')
model.use_key(api_key) # TODO: empty api_key
#
if 'data_dict' not in ss: ss['data_dict'] = {} # used only with DictStorage
ss['storage'] = storage.get_storage(api_key, data_dict=ss['data_dict'])
ss['cache'] = cache.get_cache()
ss['user'] = ss['storage'].folder # TODO: refactor user 'calculation' from get_storage
model.set_user(ss['user'])
ss['feedback'] = feedback.get_feedback_adapter(ss['user'])
ss['feedback_score'] = ss['feedback'].get_score()
#
ss['debug']['storage.folder'] = ss['storage'].folder
ss['debug']['storage.class'] = ss['storage'].__class__.__name__
ss['community_user'] = os.getenv('COMMUNITY_USER')
if 'user' not in ss and ss['community_user']:
on_api_key_change() # use community key
# COMPONENTS
def ui_spacer(n=2, line=False, next_n=0):
for _ in range(n):
st.write('')
if line:
st.tabs([' '])
for _ in range(next_n):
st.write('')
def ui_info():
st.markdown(f"""
# Ask my PDF
version {__version__}
Question answering system built on top of GPT-3.
""")
ui_spacer(1)
st.write("Made by [Maciej Obarski](https://www.linkedin.com/in/mobarski/).", unsafe_allow_html=True)
ui_spacer(1)
st.markdown("""
Thank you for your interest in my application.
Please be aware that this is only a Proof of Concept system
and may contain bugs or unfinished features.
If you like this app you can ❤️ [follow me](https://twitter.com/KerbalFPV)
on Twitter for news and updates.
""")
ui_spacer(1)
st.markdown('Source code can be found [here](https://github.com/mobarski/ask-my-pdf).')
def ui_api_key():
if ss['community_user']:
st.write('## 1. Optional - enter your OpenAI API key')
t1,t2 = st.tabs(['community version','enter your own API key'])
with t1:
pct = model.community_tokens_available_pct()
st.write(f'Community tokens available: :{"green" if pct else "red"}[{int(pct)}%]')
st.progress(pct/100)
st.write('Refresh in: ' + model.community_tokens_refresh_in())
st.write('You can sign up to OpenAI and/or create your API key [here](https://platform.openai.com/account/api-keys)')
ss['community_pct'] = pct
ss['debug']['community_pct'] = pct
with t2:
st.text_input('OpenAI API key', type='password', key='api_key', on_change=on_api_key_change, label_visibility="collapsed")
else:
st.write('## 1. Enter your OpenAI API key')
st.text_input('OpenAI API key', type='password', key='api_key', on_change=on_api_key_change, label_visibility="collapsed")
def index_pdf_file():
if ss['pdf_file']:
ss['filename'] = ss['pdf_file'].name
if ss['filename'] != ss.get('filename_done'): # UGLY
with st.spinner(f'indexing {ss["filename"]}'):
index = model.index_file(ss['pdf_file'], ss['filename'], fix_text=ss['fix_text'], frag_size=ss['frag_size'], cache=ss['cache'])
ss['index'] = index
debug_index()
ss['filename_done'] = ss['filename'] # UGLY
def debug_index():
index = ss['index']
d = {}
d['hash'] = index['hash']
d['frag_size'] = index['frag_size']
d['n_pages'] = len(index['pages'])
d['n_texts'] = len(index['texts'])
d['summary'] = index['summary']
d['pages'] = index['pages']
d['texts'] = index['texts']
d['time'] = index.get('time',{})
ss['debug']['index'] = d
def ui_pdf_file():
st.write('## 2. Upload or select your PDF file')
disabled = not ss.get('user') or (not ss.get('api_key') and not ss.get('community_pct',0))
t1,t2 = st.tabs(['UPLOAD','SELECT'])
with t1:
st.file_uploader('pdf file', type='pdf', key='pdf_file', disabled=disabled, on_change=index_pdf_file, label_visibility="collapsed")
b_save()
with t2:
filenames = ['']
if ss.get('storage'):
filenames += ss['storage'].list()
def on_change():
name = ss['selected_file']
if name and ss.get('storage'):
with ss['spin_select_file']:
with st.spinner('loading index'):
t0 = now()
index = ss['storage'].get(name)
ss['debug']['storage_get_time'] = now()-t0
ss['filename'] = name # XXX
ss['index'] = index
debug_index()
else:
#ss['index'] = {}
pass
st.selectbox('select file', filenames, on_change=on_change, key='selected_file', label_visibility="collapsed", disabled=disabled)
b_delete()
ss['spin_select_file'] = st.empty()
def ui_show_debug():
st.checkbox('show debug section', key='show_debug')
def ui_fix_text():
st.checkbox('fix common PDF problems', value=True, key='fix_text')
def ui_temperature():
#st.slider('temperature', 0.0, 1.0, 0.0, 0.1, key='temperature', format='%0.1f')
ss['temperature'] = 0.0
def ui_fragments():
#st.number_input('fragment size', 0,2000,200, step=100, key='frag_size')
st.selectbox('fragment size (characters)', [0,200,300,400,500,600,700,800,900,1000], index=3, key='frag_size')
b_reindex()
st.number_input('max fragments', 1, 10, 4, key='max_frags')
st.number_input('fragments before', 0, 3, 1, key='n_frag_before') # TODO: pass to model
st.number_input('fragments after', 0, 3, 1, key='n_frag_after') # TODO: pass to model
def ui_model():
models = ['gpt-3.5-turbo','gpt-4','text-davinci-003','text-curie-001']
st.selectbox('main model', models, key='model', disabled=not ss.get('api_key'))
st.selectbox('embedding model', ['text-embedding-ada-002'], key='model_embed') # FOR FUTURE USE
def ui_hyde():
st.checkbox('use HyDE', value=True, key='use_hyde')
def ui_hyde_summary():
st.checkbox('use summary in HyDE', value=True, key='use_hyde_summary')
def ui_task_template():
st.selectbox('task prompt template', prompts.TASK.keys(), key='task_name')
def ui_task():
x = ss['task_name']
st.text_area('task prompt', prompts.TASK[x], key='task')
def ui_hyde_prompt():
st.text_area('HyDE prompt', prompts.HYDE, key='hyde_prompt')
def ui_question():
st.write('## 3. Ask questions'+(f' to {ss["filename"]}' if ss.get('filename') else ''))
disabled = False
st.text_area('question', key='question', height=100, placeholder='Enter question here', help='', label_visibility="collapsed", disabled=disabled)
# REF: Hypothetical Document Embeddings
def ui_hyde_answer():
# TODO: enter or generate
pass
def ui_output():
output = ss.get('output','')
st.markdown(output)
def ui_debug():
if ss.get('show_debug'):
st.write('### debug')
st.write(ss.get('debug',{}))
def b_ask():
c1,c2,c3,c4,c5 = st.columns([2,1,1,2,2])
if c2.button('👍', use_container_width=True, disabled=not ss.get('output')):
ss['feedback'].send(+1, ss, details=ss['send_details'])
ss['feedback_score'] = ss['feedback'].get_score()
if c3.button('👎', use_container_width=True, disabled=not ss.get('output')):
ss['feedback'].send(-1, ss, details=ss['send_details'])
ss['feedback_score'] = ss['feedback'].get_score()
score = ss.get('feedback_score',0)
c5.write(f'feedback score: {score}')
c4.checkbox('send details', True, key='send_details',
help='allow question and the answer to be stored in the ask-my-pdf feedback database')
#c1,c2,c3 = st.columns([1,3,1])
#c2.radio('zzz',['👍',r'...',r'👎'],horizontal=True,label_visibility="collapsed")
#
disabled = (not ss.get('api_key') and not ss.get('community_pct',0)) or not ss.get('index')
if c1.button('get answer', disabled=disabled, type='primary', use_container_width=True):
question = ss.get('question','')
temperature = ss.get('temperature', 0.0)
hyde = ss.get('use_hyde')
hyde_prompt = ss.get('hyde_prompt')
if ss.get('use_hyde_summary'):
summary = ss['index']['summary']
hyde_prompt += f" Context: {summary}\n\n"
task = ss.get('task')
max_frags = ss.get('max_frags',1)
n_before = ss.get('n_frag_before',0)
n_after = ss.get('n_frag_after',0)
index = ss.get('index',{})
with st.spinner('preparing answer'):
resp = model.query(question, index,
task=task,
temperature=temperature,
hyde=hyde,
hyde_prompt=hyde_prompt,
max_frags=max_frags,
limit=max_frags+2,
n_before=n_before,
n_after=n_after,
model=ss['model'],
)
usage = resp.get('usage',{})
usage['cnt'] = 1
ss['debug']['model.query.resp'] = resp
ss['debug']['resp.usage'] = usage
ss['debug']['model.vector_query_time'] = resp['vector_query_time']
q = question.strip()
a = resp['text'].strip()
ss['answer'] = a
output_add(q,a)
st.experimental_rerun() # to enable the feedback buttons
def b_clear():
if st.button('clear output'):
ss['output'] = ''
def b_reindex():
# TODO: disabled
if st.button('reindex'):
index_pdf_file()
def b_reload():
if st.button('reload prompts'):
import importlib
importlib.reload(prompts)
def b_save():
db = ss.get('storage')
index = ss.get('index')
name = ss.get('filename')
api_key = ss.get('api_key')
disabled = not api_key or not db or not index or not name
help = "The file will be stored for about 90 days. Available only when using your own API key."
if st.button('save encrypted index in ask-my-pdf', disabled=disabled, help=help):
with st.spinner('saving to ask-my-pdf'):
db.put(name, index)
def b_delete():
db = ss.get('storage')
name = ss.get('selected_file')
# TODO: confirm delete
if st.button('delete from ask-my-pdf', disabled=not db or not name):
with st.spinner('deleting from ask-my-pdf'):
db.delete(name)
#st.experimental_rerun()
def output_add(q,a):
if 'output' not in ss: ss['output'] = ''
q = q.replace('$',r'\$')
a = a.replace('$',r'\$')
new = f'#### {q}\n{a}\n\n'
ss['output'] = new + ss['output']
# LAYOUT
with st.sidebar:
ui_info()
ui_spacer(2)
with st.expander('advanced'):
ui_show_debug()
b_clear()
ui_model()
ui_fragments()
ui_fix_text()
ui_hyde()
ui_hyde_summary()
ui_temperature()
b_reload()
ui_task_template()
ui_task()
ui_hyde_prompt()
ui_api_key()
ui_pdf_file()
ui_question()
ui_hyde_answer()
b_ask()
ui_output()
ui_debug()
================================================
FILE: src/model.py
================================================
from sklearn.metrics.pairwise import cosine_distances
import datetime
from collections import Counter
from time import time as now
import hashlib
import re
import io
import os
import pdf
import ai
def use_key(api_key):
ai.use_key(api_key)
def set_user(user):
ai.set_user(user)
def query_by_vector(vector, index, limit=None):
"return (ids, distances and texts) sorted by cosine distance"
vectors = index['vectors']
texts = index['texts']
#
sim = cosine_distances([vector], vectors)[0]
#
id_dist_list = list(enumerate(sim))
id_dist_list.sort(key=lambda x:x[1])
id_list = [x[0] for x in id_dist_list][:limit]
dist_list = [x[1] for x in id_dist_list][:limit]
text_list = [texts[x] for x in id_list] if texts else ['ERROR']*len(id_list)
return id_list, dist_list, text_list
def get_vectors(text_list):
"transform texts into embedding vectors"
batch_size = 128
vectors = []
usage = Counter()
for i,texts in enumerate(batch(text_list, batch_size)):
resp = ai.embeddings(texts)
v = resp['vectors']
u = resp['usage']
u['call_cnt'] = 1
usage.update(u)
vectors.extend(v)
return {'vectors':vectors, 'usage':dict(usage), 'model':resp['model']}
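`get_vectors` relies on a `batch` helper that does not appear in this extract of model.py; a minimal equivalent, shown here only to make the batching step concrete, would be:

```python
# Minimal chunking helper: yield consecutive slices of at most `size` items.
def batch(iterable, size):
    out = []
    for item in iterable:
        out.append(item)
        if len(out) == size:
            yield out
            out = []
    if out:
        yield out  # trailing partial chunk

print(list(batch(range(5), 2)))  # [[0, 1], [2, 3], [4]]
```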
def index_file(f, filename, fix_text=False, frag_size=0, cache=None):
"return vector index (dictionary) for a given PDF file"
# calc md5
h = hashlib.md5()
h.update(f.read())
md5 = h.hexdigest()
filesize = f.tell()
f.seek(0)
#
t0 = now()
pages = pdf.pdf_to_pages(f)
t1 = now()
if fix_text:
for i in range(len(pages)):
pages[i] = fix_text_problems(pages[i])
texts = split_pages_into_fragments(pages, frag_size)
t2 = now()
if cache:
cache_key = f'get_vectors:{md5}:{frag_size}:{fix_text}'
resp = cache.call(cache_key, get_vectors, texts)
else:
resp = get_vectors(texts)
t3 = now()
vectors = resp['vectors']
summary_prompt = f"{texts[0]}\n\nDescribe the document from which the fragment is extracted. Omit any details.\n\n" # TODO: move to prompts.py
summary = ai.complete(summary_prompt)
t4 = now()
usage = resp['usage']
out = {}
out['frag_size'] = frag_size
out['n_pages'] = len(pages)
out['n_texts'] = len(texts)
out['texts'] = texts
out['pages'] = pages
out['vectors'] = vectors
out['summary'] = summary['text']
out['filename'] = filename
out['filehash'] = f'md5:{md5}'
out['filesize'] = filesize
out['usage'] = usage
out['model'] = resp['model']
out['time'] = {'pdf_to_pages':t1-t0, 'split_pages':t2-t1, 'get_vectors':t3-t2, 'summary':t4-t3}
out['size'] = len(texts) # DEPRECATED -> filesize
out['hash'] = f'md5:{md5}' # DEPRECATED -> filehash
return out
def split_pages_into_fragments(pages, frag_size):
"split pages (list of texts) into smaller fragments (list of texts)"
page_offset = [0]
for p,page in enumerate(pages):
page_offset += [page_offset[-1]+len(page)+1]
# TODO: del page_offset[-1] ???
if frag_size:
text = ' '.join(pages)
return text_to_fragments(text, frag_size, page_offset)
else:
return pages
def text_to_fragments(text, size, page_offset):
"split single text into smaller fragments (list of texts)"
if size and len(text)>size:
out = []
pos = 0
page = 1
p_off = page_offset.copy()[1:]
eos = find_eos(text)
if len(text) not in eos:
eos += [len(text)]
for i in range(len(eos)):
if eos[i]-pos>size:
text_fragment = f'PAGE({page}):\n'+text[pos:eos[i]]
out += [text_fragment]
pos = eos[i]
if eos[i]>p_off[0]:
page += 1
del p_off[0]
# ugly: last iter
text_fragment = f'PAGE({page}):\n'+text[pos:eos[i]]
out += [text_fragment]
#
out = [x for x in out if x]
return out
else:
return [text]
def find_eos(text):
"return list of all end-of-sentence offsets"
return [x.span()[1] for x in re.finditer(r'[.!?。]\s+',text)]
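For example, `find_eos` returns the offsets just past each sentence terminator, which `text_to_fragments` then uses as candidate cut points (the function is duplicated below so the example is self-contained):

```python
import re

def find_eos(text):
    "return list of all end-of-sentence offsets"
    return [x.span()[1] for x in re.finditer(r'[.!?。]\s+', text)]

text = "First sentence. Second one! Third?"
print(find_eos(text))  # offsets after ". " and "! ": [16, 28]
# the final "?" is not followed by whitespace, so it is not matched
```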
###############################################################################
def fix_text_problems(text):
"fix common text problems"
text = re.sub(r'\s+[-]\s+','',text) # word continuation in the next line
return text
def query(text, index, task=None, temperature=0.0, max_frags=1, hyde=False, hyde_prompt=None, limit=None, n_before=1, n_after=1, model=None):
"get dictionary with the answer for the given question (text)."
out = {}
if hyde:
# TODO: model param
out['hyde'] = hypotetical_answer(text, index, hyde_prompt=hyde_prompt, temperature=temperature)
# TODO: usage
# RANK FRAGMENTS
if hyde:
resp = ai.embedding(out['hyde']['text'])
# TODO: usage
else:
resp = ai.embedding(text)
# TODO: usage
v = resp['vector']
t0 = now()
id_list, dist_list, text_list = query_by_vector(v, index, limit=limit)
dt0 = now()-t0
# BUILD PROMPT
# select fragments
N_BEFORE = 1 # TODO: param
N_AFTER = 1 # TODO: param
selected = {} # text id -> rank
for rank,id in enumerate(id_list):
for x in range(id-n_before, id+1+n_after):
if x not in selected and x>=0 and x<len(index['texts']):
================================================
FILE: src/stats.py
================================================
@retry(tries=5, delay=0.1)
def incr(self, key, kv_dict):
# TODO: non critical code -> safe exceptions
key = self.render(key)
p = self.db.pipeline()
for member,val in kv_dict.items():
member = self.render(member)
p.zincrby(key, val, member)
p.execute()
@retry(tries=5, delay=0.1)
def get(self, key):
# TODO: non critical code -> safe exceptions
key = self.render(key)
items = self.db.zscan_iter(key)
return {k.decode('utf8'):v for k,v in items}
stats_data_dict = {}
def get_stats(**kw):
MODE = os.getenv('STATS_MODE','').upper()
if MODE=='REDIS':
stats = RedisStats()
else:
stats = DictStats(stats_data_dict)
stats.config.update(kw)
return stats
if __name__=="__main__":
s1 = get_stats(user='maciek')
s1.incr('aaa:[date]:[user]', dict(a=1,b=2))
s1.incr('aaa:[date]:[user]', dict(a=1,b=2))
print(s1.data)
print(s1.get('aaa:[date]:[user]'))
#
s2 = get_stats(user='kerbal')
s2.incr('aaa:[date]:[user]', dict(a=1,b=2))
s2.incr('aaa:[date]:[user]', dict(a=1,b=2))
print(s2.data)
print(s2.get('aaa:[date]:[user]'))
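The `render` method used by `incr` and `get` is not visible in this extract; judging from the key templates, it substitutes the current date, hour, and configured user for the `[date]`, `[hour]`, and `[user]` placeholders, roughly like this sketch:

```python
import datetime

# Rough stand-in for the missing Stats.render(): fill the [date] /
# [hour] / [user] placeholders seen in the key templates above.
def render(template, user=''):
    now = datetime.datetime.now()
    return (template
            .replace('[date]', str(now.date()))
            .replace('[hour]', f'{now.hour:02d}')
            .replace('[user]', user))

key = render('usage:v4:[date]:[user]', user='maciek')
print(key)  # e.g. usage:v4:2023-03-01:maciek
```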
================================================
FILE: src/storage.py
================================================
"Storage adapter - one folder for each user / api_key"
# pip install pycryptodome
# REF: https://www.pycryptodome.org/src/cipher/aes
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad,unpad
from retry import retry
from binascii import hexlify,unhexlify
import hashlib
import pickle
import zlib
import os
import io
# pip install boto3
import boto3
import botocore
SALT = unhexlify(os.getenv('STORAGE_SALT','00'))
class Storage:
"Encrypted object storage (base class)"
def __init__(self, secret_key):
k = secret_key.encode()
self.folder = hashlib.blake2s(k, salt=SALT, person=b'folder', digest_size=8).hexdigest()
self.passwd = hashlib.blake2s(k, salt=SALT, person=b'passwd', digest_size=32).hexdigest()
self.AES_MODE = AES.MODE_ECB # TODO: better AES mode ???
self.AES_BLOCK_SIZE = 16
def get(self, name, default=None):
"get one object from the folder"
safe_name = self.encode(name)
data = self._get(safe_name)
obj = self.deserialize(data)
return obj
def put(self, name, obj):
"put the object into the folder"
safe_name = self.encode(name)
data = self.serialize(obj)
self._put(safe_name, data)
return data
def list(self):
"list object names from the folder"
return [self.decode(name) for name in self._list()]
def delete(self, name):
"delete the object from the folder"
safe_name = self.encode(name)
self._delete(safe_name)
# IMPLEMENTED IN SUBCLASSES
def _put(self, name, data):
...
def _get(self, name):
...
def _delete(self, name):
pass
def _list(self):
...
# # #
def serialize(self, obj):
raw = pickle.dumps(obj)
compressed = self.compress(raw)
encrypted = self.encrypt(compressed)
return encrypted
def deserialize(self, encrypted):
compressed = self.decrypt(encrypted)
raw = self.decompress(compressed)
obj = pickle.loads(raw)
return obj
def encrypt(self, raw):
cipher = AES.new(unhexlify(self.passwd), self.AES_MODE)
return cipher.encrypt(pad(raw, self.AES_BLOCK_SIZE))
def decrypt(self, encrypted):
cipher = AES.new(unhexlify(self.passwd), self.AES_MODE)
return unpad(cipher.decrypt(encrypted), self.AES_BLOCK_SIZE)
def compress(self, data):
return zlib.compress(data)
def decompress(self, data):
return zlib.decompress(data)
def encode(self, name):
return hexlify(name.encode('utf8')).decode('utf8')
def decode(self, name):
return unhexlify(name).decode('utf8')
class DictStorage(Storage):
"Dictionary based storage"
def __init__(self, secret_key, data_dict):
super().__init__(secret_key)
self.data = data_dict
def _put(self, name, data):
if self.folder not in self.data:
self.data[self.folder] = {}
self.data[self.folder][name] = data
def _get(self, name):
return self.data[self.folder][name]
def _list(self):
# TODO: sort by modification time (reverse=True)
return list(self.data.get(self.folder,{}).keys())
def _delete(self, name):
del self.data[self.folder][name]
class LocalStorage(Storage):
"Local filesystem based storage"
def __init__(self, secret_key, path):
if not path:
raise Exception('No storage path in environment variables!')
super().__init__(secret_key)
self.path = os.path.join(path, self.folder)
if not os.path.exists(self.path):
os.makedirs(self.path)
def _put(self, name, data):
with open(os.path.join(self.path, name), 'wb') as f:
f.write(data)
def _get(self, name):
with open(os.path.join(self.path, name), 'rb') as f:
data = f.read()
return data
def _list(self):
# TODO: sort by modification time (reverse=True)
return os.listdir(self.path)
def _delete(self, name):
os.remove(os.path.join(self.path, name))
class S3Storage(Storage):
"S3 based encrypted storage"
def __init__(self, secret_key, **kw):
prefix = kw.get('prefix') or os.getenv('S3_PREFIX','index/x1')
region = kw.get('region') or os.getenv('S3_REGION','sfo3')
bucket = kw.get('bucket') or os.getenv('S3_BUCKET','ask-my-pdf')
url = kw.get('url') or os.getenv('S3_URL',f'https://{region}.digitaloceanspaces.com')
key = os.getenv('S3_KEY','')
secret = os.getenv('S3_SECRET','')
#
if not key or not secret:
raise Exception("No S3 credentials in environment variables!")
#
super().__init__(secret_key)
self.session = boto3.session.Session()
self.s3 = self.session.client('s3',
config=botocore.config.Config(s3={'addressing_style': 'virtual'}),
region_name=region,
endpoint_url=url,
aws_access_key_id=key,
aws_secret_access_key=secret,
)
self.bucket = bucket
self.prefix = prefix
def get_key(self, name):
return f'{self.prefix}/{self.folder}/{name}'
def _put(self, name, data):
key = self.get_key(name)
f = io.BytesIO(data)
self.s3.upload_fileobj(f, self.bucket, key)
def _get(self, name):
key = self.get_key(name)
f = io.BytesIO()
self.s3.download_fileobj(self.bucket, key, f)
f.seek(0)
return f.read()
def _list(self):
resp = self.s3.list_objects(
Bucket=self.bucket,
Prefix=self.get_key('')
)
contents = resp.get('Contents',[])
contents.sort(key=lambda x:x['LastModified'], reverse=True)
keys = [x['Key'] for x in contents]
names = [x.split('/')[-1] for x in keys]
return names
def _delete(self, name):
self.s3.delete_object(
Bucket=self.bucket,
Key=self.get_key(name)
)
def get_storage(api_key, data_dict):
"get storage adapter configured in environment variables"
mode = os.getenv('STORAGE_MODE','').upper()
path = os.getenv('STORAGE_PATH','')
if mode=='S3':
storage = S3Storage(api_key)
elif mode=='LOCAL':
storage = LocalStorage(api_key, path)
else:
storage = DictStorage(api_key, data_dict)
return storage
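The storage classes above derive both the folder name and the encryption password from the user's API key via salted BLAKE2s, so neither the key itself nor any reversible form of it is ever stored. The derivation can be illustrated with hashlib alone:

```python
import hashlib

SALT = bytes.fromhex('00')  # the STORAGE_SALT default from this module

# Same derivation as Storage.__init__: two independent digests from one
# secret, separated by the `person` parameter.
def derive(secret_key):
    k = secret_key.encode()
    folder = hashlib.blake2s(k, salt=SALT, person=b'folder', digest_size=8).hexdigest()
    passwd = hashlib.blake2s(k, salt=SALT, person=b'passwd', digest_size=32).hexdigest()
    return folder, passwd

folder, passwd = derive('sk-example')   # 'sk-example' is a made-up key
print(len(folder), len(passwd))  # 8-byte and 32-byte digests in hex: 16 64
```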