[
  {
    "path": ".gitignore",
    "content": "**/__pycache__\n**/*secret*.*\ndata/**\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Maciej Obarski\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Ask my PDF\n\n\n\nThank you for your interest in my application. Please be aware that this is only a **Proof of Concept system** and may contain bugs or unfinished features. If you like this app you can ❤️ [follow me](https://twitter.com/KerbalFPV) on Twitter for news and updates.\n\n\n\n### Ask my PDF - Question answering system built on top of GPT3\n\n\n\n🎲 The primary use case for this app is to assist users in answering  questions about board game rules based on the instruction manual. While  the app can be used for other tasks, helping users with board game rules is particularly meaningful to me since I'm an avid fan of board games  myself. Additionally, this use case is relatively harmless, even in  cases where the model may experience hallucinations.\n\n\n\n🌐 The app can be accessed on the Streamlit Community Cloud at https://ask-my-pdf.streamlit.app/. 🔑 However, to use the app, you will need your own [OpenAI's API key](https://platform.openai.com/account/api-keys).\n\n\n\n📄 The app implements the following academic papers:\n\n- [In-Context Retrieval-Augmented Language Models](https://arxiv.org/abs/2302.00083) aka **RALM**\n\n- [Precise Zero-Shot Dense Retrieval without Relevance Labels](https://arxiv.org/abs/2212.10496) aka **HyDE** (Hypothetical Document Embeddings)\n\n\n\n### Installation\n\n\n\n1. Clone the repo:\n\n   `git clone https://github.com/mobarski/ask-my-pdf`\n\n2. Install dependencies:\n\n   `pip install -r ask-my-pdf/requirements.txt`\n\n3. 
Run the app:\n\n   `cd ask-my-pdf/src`\n   \n   `run.sh` or `run.bat`\n\n\n\n### High-level documentation\n\n\n\n#### RALM + HyDE\n\n![RALM + HyDE](docs/ralm_hyde.jpg)\n\n\n\n#### RALM + HyDE + context\n\n![RALM + HyDE + context](docs/ralm_hyde_wc.jpg)\n\n\n\n### Environment variables used for configuration\n\n\n\n##### General configuration:\n\n- **STORAGE_SALT** - cryptographic salt used when deriving user/folder name and encryption key from API key, hexadecimal notation, 2-16 characters\n\n- **STORAGE_MODE** - index storage mode: S3, LOCAL, DICT (default)\n\n- **STATS_MODE** - usage stats storage mode: REDIS, DICT (default)\n\n- **FEEDBACK_MODE** - user feedback storage mode: REDIS, NONE (default)\n\n- **CACHE_MODE** - embeddings cache mode: S3, DISK, NONE (default)\n\n  \n\n##### Local filesystem configuration (storage / cache):\n\n- **STORAGE_PATH** - directory path for index storage\n\n- **CACHE_PATH** - directory path for embeddings cache\n\n  \n\n##### S3 configuration (storage / cache):\n\n- **S3_REGION** - region code\n\n- **S3_BUCKET** - bucket name (storage)\n\n- **S3_SECRET** - secret key\n\n- **S3_KEY** - access key\n\n- **S3_URL** - endpoint URL\n\n- **S3_PREFIX** - object name prefix\n\n- **S3_CACHE_BUCKET** - bucket name (cache)\n\n- **S3_CACHE_PREFIX** - object name prefix (cache)\n\n  \n\n##### Redis configuration (for persistent usage statistics / user feedback):\n\n- **REDIS_URL** - Redis DB URL (redis[s]://:password@host:port/[db])\n\n  \n\n##### Community version related options:\n\n- **OPENAI_KEY** - API key used for the default user\n- **COMMUNITY_DAILY_USD** - default user's daily budget\n- **COMMUNITY_USER** - default user's code\n\n"
  },
  {
    "path": "requirements.txt",
    "content": "git+https://github.com/mobarski/ai-bricks.git\nstreamlit\npypdf\nscikit-learn\nnumpy\npycryptodome\nboto3\nredis\nretry\n"
  },
  {
    "path": "src/ai.py",
    "content": "from ai_bricks.api import openai\nimport stats\nimport os\n\nDEFAULT_USER = os.getenv('COMMUNITY_USER','')\n\ndef use_key(key):\n\topenai.use_key(key)\n\nusage_stats = stats.get_stats(user=DEFAULT_USER)\ndef set_user(user):\n\tglobal usage_stats\n\tusage_stats = stats.get_stats(user=user)\n\topenai.set_global('user', user)\n\topenai.add_callback('after', stats_callback)\n\ndef complete(text, **kw):\n\tmodel = kw.get('model','gpt-3.5-turbo')\n\tllm = openai.model(model)\n\tllm.config['pre_prompt'] = 'output only in raw text' # for chat models\n\tresp = llm.complete(text, **kw)\n\tresp['model'] = model\n\treturn resp\n\ndef embedding(text, **kw):\n\tmodel = kw.get('model','text-embedding-ada-002')\n\tllm = openai.model(model)\n\tresp = llm.embed(text, **kw)\n\tresp['model'] = model\n\treturn resp\n\ndef embeddings(texts, **kw):\n\tmodel = kw.get('model','text-embedding-ada-002')\n\tllm = openai.model(model)\n\tresp = llm.embed_many(texts, **kw)\n\tresp['model'] = model\n\treturn resp\n\ntokenizer_model = openai.model('text-davinci-003')\ndef get_token_count(text):\n\treturn tokenizer_model.token_count(text)\n\ndef stats_callback(out, resp, self):\n\tmodel = self.config['model']\n\tusage = resp['usage']\n\tusage['call_cnt'] = 1\n\tif 'text' in out:\n\t\tusage['completion_chars'] = len(out['text'])\n\telif 'texts' in out:\n\t\tusage['completion_chars'] = sum([len(text) for text in out['texts']])\n\t# TODO: prompt_chars\n\t# TODO: total_chars\n\tif 'rtt' in out:\n\t\tusage['rtt'] = out['rtt']\n\t\tusage['rtt_cnt'] = 1\n\tusage_stats.incr(f'usage:v4:[date]:[user]', {f'{k}:{model}':v for k,v in usage.items()})\n\tusage_stats.incr(f'hourly:v4:[date]',       {f'{k}:{model}:[hour]':v for k,v in usage.items()})\n\t#print('STATS_CALLBACK', usage, flush=True) # XXX\n\ndef get_community_usage_cost():\n\tdata = usage_stats.get(f'usage:v4:[date]:{DEFAULT_USER}')\n\tused = 0.0\n\tused += 0.04   * data.get('total_tokens:gpt-4',0) / 1000 # prompt_price=0.03 but 
output_price=0.06\n\tused += 0.02   * data.get('total_tokens:text-davinci-003',0) / 1000\n\tused += 0.002  * data.get('total_tokens:text-curie-001',0) / 1000\n\tused += 0.002  * data.get('total_tokens:gpt-3.5-turbo',0) / 1000\n\tused += 0.0004 * data.get('total_tokens:text-embedding-ada-002',0) / 1000\n\treturn used\n"
  },
  {
    "path": "src/cache.py",
    "content": "from retry import retry\n\nfrom binascii import hexlify,unhexlify\nimport pickle\nimport zlib\nimport io\nimport os\n\n# pip install boto3\nimport boto3\nimport botocore\n\nclass Cache:\n\t\"Dummy / Base Cache\"\n\tdef __init__(self):\n\t\tpass\n\t\n\tdef put(self, key, obj):\n\t\tpass\n\t\n\tdef get(self, key):\n\t\treturn None\n\t\n\tdef has(self, key):\n\t\treturn False\n\t\n\tdef delete(self, key):\n\t\tpass\n\n\tdef serialize(self, obj):\n\t\tpickled = pickle.dumps(obj)\n\t\tcompressed = self.compress(pickled)\n\t\treturn compressed\n\t\n\tdef deserialize(self, data):\n\t\tpickled = self.decompress(data)\n\t\tobj = pickle.loads(pickled)\n\t\treturn obj\n\n\tdef compress(self, data):\n\t\treturn zlib.compress(data)\n\t\n\tdef decompress(self, data):\n\t\treturn zlib.decompress(data)\n\n\tdef encode(self, name):\n\t\treturn hexlify(name.encode('utf8')).decode('utf8')\n\t\n\tdef decode(self, name):\n\t\treturn unhexlify(name).decode('utf8')\n\t\n\tdef call(self, key, fun, *a, **kw):\n\t\tif self.has(key):\n\t\t\treturn self.get(key)\n\t\telse:\n\t\t\tresp = fun(*a, **kw)\n\t\t\tself.put(key, resp)\n\t\t\treturn resp\n\n\nclass DiskCache(Cache):\n\t\"Local disk based cache\"\n\n\tdef __init__(self, root):\n\t\tself.root = root\n\t\n\tdef path(self, key):\n\t\treturn os.path.join(self.root, self.encode(key))\n\t\n\tdef put(self, key, obj):\n\t\tpath = self.path(key)\n\t\tdata = self.serialize(obj)\n\t\twith open(path, 'wb') as f:\n\t\t\tf.write(data)\n\t\n\tdef get(self, key):\n\t\tpath = self.path(key)\n\t\twith open(path, 'rb') as f:\n\t\t\tdata = f.read()\n\t\tobj = self.deserialize(data)\n\t\treturn obj\n\n\tdef has(self, key):\n\t\tpath = self.path(key)\n\t\treturn os.path.exists(path)\n\n\tdef delete(self, key):\n\t\tpath = self.path(key)\n\t\tos.remove(path)\n\n\nclass S3Cache(Cache):\n\t\"S3 based cache\"\n\n\tdef __init__(self, **kw):\n\t\tbucket = kw.get('bucket') or os.getenv('S3_CACHE_BUCKET','ask-my-pdf')\n\t\tprefix = kw.get('prefix') 
or os.getenv('S3_CACHE_PREFIX','cache/x1')\n\t\tregion = kw.get('region') or os.getenv('S3_REGION','sfo3')\n\t\turl    = kw.get('url')    or os.getenv('S3_URL',f'https://{region}.digitaloceanspaces.com')\n\t\tkey    = os.getenv('S3_KEY','')\n\t\tsecret = os.getenv('S3_SECRET','')\n\t\t#\n\t\tif not key or not secret:\n\t\t\traise Exception(\"No S3 credentials in environment variables!\")\n\t\t#\n\t\tself.session = boto3.session.Session()\n\t\tself.s3 = self.session.client('s3',\n\t\t\t\tconfig=botocore.config.Config(s3={'addressing_style': 'virtual'}),\n\t\t\t\tregion_name=region,\n\t\t\t\tendpoint_url=url,\n\t\t\t\taws_access_key_id=key,\n\t\t\t\taws_secret_access_key=secret,\n\t\t\t)\n\t\tself.bucket = bucket\n\t\tself.prefix = prefix\n\n\tdef get_s3_key(self, key):\n\t\treturn f'{self.prefix}/{key}'\n\t\n\tdef put(self, key, obj):\n\t\ts3_key = self.get_s3_key(key)\n\t\tdata = self.serialize(obj)\n\t\tf = io.BytesIO(data)\n\t\tself.s3.upload_fileobj(f, self.bucket, s3_key)\n\n\tdef get(self, key, default=None):\n\t\ts3_key = self.get_s3_key(key)\n\t\tf = io.BytesIO()\n\t\ttry:\n\t\t\tself.s3.download_fileobj(self.bucket, s3_key, f)\n\t\texcept Exception:\n\t\t\tf.close()\n\t\t\treturn default\n\t\tf.seek(0)\n\t\tdata = f.read()\n\t\tobj = self.deserialize(data)\n\t\treturn obj\n\t\n\tdef has(self, key):\n\t\ts3_key = self.get_s3_key(key)\n\t\ttry:\n\t\t\tself.s3.head_object(Bucket=self.bucket, Key=s3_key)\n\t\t\treturn True\n\t\texcept Exception:\n\t\t\treturn False\n\t\n\tdef delete(self, key):\n\t\tself.s3.delete_object(\n\t\t\tBucket = self.bucket,\n\t\t\tKey = self.get_s3_key(key))\n\n\ndef get_cache(**kw):\n\tmode = os.getenv('CACHE_MODE','').upper()\n\tpath = os.getenv('CACHE_PATH','')\n\tif mode == 'DISK':\n\t\treturn DiskCache(path)\n\telif mode == 'S3':\n\t\treturn S3Cache(**kw)\n\telse:\n\t\treturn Cache()\n\n\nif __name__==\"__main__\":\n\t#cache = DiskCache('__pycache__')\n\tcache = S3Cache()\n\tcache.put('xxx',{'a':1,'b':22})\n\tprint('get xxx', 
cache.get('xxx'))\n\tprint('has xxx', cache.has('xxx'))\n\tprint('has yyy', cache.has('yyy'))\n\tprint('delete xxx', cache.delete('xxx'))\n\tprint('has xxx', cache.has('xxx'))\n\tprint('get xxx', cache.get('xxx'))\n\t#\n"
  },
  {
    "path": "src/css.py",
    "content": "v1 = \"\"\"\n/* feedback checkbox */\n.css-18fuwiq {\n position: relative;\n padding-top: 6px;\n}\n.css-949r0i {\n position: relative;\n padding-top: 6px;\n}\n\n\"\"\"\n"
  },
  {
    "path": "src/feedback.py",
    "content": "import datetime\nimport hashlib\nimport redis\nimport os\nfrom retry import retry\n\ndef hexdigest(text):\n\treturn hashlib.md5(text.encode('utf8')).hexdigest()\n\ndef as_int(x):\n\treturn int(x) if x is not None else None\n\nclass Feedback:\n\t\"Dummy feedback adapter\"\n\tdef __init__(self, user):\n\t\t...\n\tdef send(self, score, ctx, details=False):\n\t\t...\n\tdef get_score(self):\n\t\treturn 0\n\nclass RedisFeedback(Feedback):\n\t\"Redis feedback adapter\"\n\tdef __init__(self, user):\n\t\tREDIS_URL = os.getenv('REDIS_URL')\n\t\tif not REDIS_URL:\n\t\t\traise Exception('No Redis configuration in environment variables!')\n\t\tsuper().__init__(user)\n\t\tself.db = redis.Redis.from_url(REDIS_URL)\n\t\tself.user = user\n\n\t@retry(tries=5, delay=0.1)\n\tdef send(self, score, ctx, details=False):\n\t\tp = self.db.pipeline()\n\t\tdist_list = ctx.get('debug',{}).get('model.query.resp',{}).get('dist_list',[])\n\t\t# feedback\n\t\tindex = ctx.get('index',{})\n\t\tdata = {}\n\t\tdata['user'] = self.user\n\t\tdata['task-prompt-version'] = ctx.get('task_name')\n\t\tdata['model'] = ctx.get('model')\n\t\tdata['model-embeddings'] = ctx.get('model_embed')\n\t\tdata['task-prompt'] = ctx.get('task')\n\t\tdata['temperature'] = ctx.get('temperature')\n\t\tdata['frag-size'] = ctx.get('frag_size')\n\t\tdata['frag-cnt'] = ctx.get('max_frags')\n\t\tdata['frag-n-before'] = ctx.get('n_frag_before')\n\t\tdata['frag-n-after'] = ctx.get('n_frag_after')\n\t\tdata['filename'] = ctx.get('filename')\n\t\tdata['filehash'] = index.get('hash') or index.get('filehash')\n\t\tdata['filesize'] = index.get('filesize')\n\t\tdata['n-pages'] = index.get('n_pages')\n\t\tdata['n-texts'] = index.get('n_texts')\n\t\tdata['use-hyde'] = as_int(ctx.get('use_hyde'))\n\t\tdata['use-hyde-summary'] = as_int(ctx.get('use_hyde_summary'))\n\t\tdata['question'] = ctx.get('question')\n\t\tdata['answer'] = ctx.get('answer')\n\t\tdata['hyde-summary'] = index.get('summary')\n\t\tdata['resp-dist-list'] = 
'|'.join([f\"{x:0.3f}\" for x in dist_list])\n\t\tfb_hash = hexdigest(str(list(sorted(data.items()))))\n\t\t#\n\t\tdata['score'] = score\n\t\tdata['datetime'] = str(datetime.datetime.now())\n\t\tkey1 = f'feedback:v2:{fb_hash}'\n\t\tif not details:\n\t\t\tfor k in ['question','answer','hyde-summary']:\n\t\t\t\tdata[k] = ''\n\t\tp.hset(key1, mapping=data)\n\t\t# feedback-daily\n\t\tdate = datetime.date.today()\n\t\tkey2 = f'feedback-daily:v1:{date}:{\"positive\" if score > 0 else \"negative\"}'\n\t\tp.sadd(key2, fb_hash)\n\t\t# feedback-score\n\t\tkey3 = f'feedback-score:v2:{self.user}'\n\t\tp.sadd(key3, fb_hash)\n\t\tp.execute()\n\t\n\t@retry(tries=5, delay=0.1)\n\tdef get_score(self):\n\t\tkey = f'feedback-score:v2:{self.user}'\n\t\treturn self.db.scard(key)\n\n\ndef get_feedback_adapter(user):\n\tMODE = os.getenv('FEEDBACK_MODE','').upper()\n\tif MODE=='REDIS':\n\t\treturn RedisFeedback(user)\n\telse:\n\t\treturn Feedback(user)\n"
  },
  {
    "path": "src/gui.py",
    "content": "__version__ = \"0.4.8.3\"\napp_name = \"Ask my PDF\"\n\n\n# BOILERPLATE\n\nimport streamlit as st\nst.set_page_config(layout='centered', page_title=f'{app_name} {__version__}')\nss = st.session_state\nif 'debug' not in ss: ss['debug'] = {}\nimport css\nst.write(f'<style>{css.v1}</style>', unsafe_allow_html=True)\nheader1 = st.empty() # for errors / messages\nheader2 = st.empty() # for errors / messages\nheader3 = st.empty() # for errors / messages\n\n# IMPORTS\n\nimport prompts\nimport model\nimport storage\nimport feedback\nimport cache\nimport os\n\nfrom time import time as now\n\n# HANDLERS\n\ndef on_api_key_change():\n\tapi_key = ss.get('api_key') or os.getenv('OPENAI_KEY')\n\tmodel.use_key(api_key) # TODO: empty api_key\n\t#\n\tif 'data_dict' not in ss: ss['data_dict'] = {} # used only with DictStorage\n\tss['storage'] = storage.get_storage(api_key, data_dict=ss['data_dict'])\n\tss['cache'] = cache.get_cache()\n\tss['user'] = ss['storage'].folder # TODO: refactor user 'calculation' from get_storage\n\tmodel.set_user(ss['user'])\n\tss['feedback'] = feedback.get_feedback_adapter(ss['user'])\n\tss['feedback_score'] = ss['feedback'].get_score()\n\t#\n\tss['debug']['storage.folder'] = ss['storage'].folder\n\tss['debug']['storage.class'] = ss['storage'].__class__.__name__\n\n\nss['community_user'] = os.getenv('COMMUNITY_USER')\nif 'user' not in ss and ss['community_user']:\n\ton_api_key_change() # use community key\n\n# COMPONENTS\n\n\ndef ui_spacer(n=2, line=False, next_n=0):\n\tfor _ in range(n):\n\t\tst.write('')\n\tif line:\n\t\tst.tabs([' '])\n\tfor _ in range(next_n):\n\t\tst.write('')\n\ndef ui_info():\n\tst.markdown(f\"\"\"\n\t# Ask my PDF\n\tversion {__version__}\n\t\n\tQuestion answering system built on top of GPT3.\n\t\"\"\")\n\tui_spacer(1)\n\tst.write(\"Made by [Maciej Obarski](https://www.linkedin.com/in/mobarski/).\", unsafe_allow_html=True)\n\tui_spacer(1)\n\tst.markdown(\"\"\"\n\t\tThank you for your interest in my 
application.\n\t\tPlease be aware that this is only a Proof of Concept system\n\t\tand may contain bugs or unfinished features.\n\t\tIf you like this app you can ❤️ [follow me](https://twitter.com/KerbalFPV)\n\t\ton Twitter for news and updates.\n\t\t\"\"\")\n\tui_spacer(1)\n\tst.markdown('Source code can be found [here](https://github.com/mobarski/ask-my-pdf).')\n\ndef ui_api_key():\n\tif ss['community_user']:\n\t\tst.write('## 1. Optional - enter your OpenAI API key')\n\t\tt1,t2 = st.tabs(['community version','enter your own API key'])\n\t\twith t1:\n\t\t\tpct = model.community_tokens_available_pct()\n\t\t\tst.write(f'Community tokens available: :{\"green\" if pct else \"red\"}[{int(pct)}%]')\n\t\t\tst.progress(pct/100)\n\t\t\tst.write('Refresh in: ' + model.community_tokens_refresh_in())\n\t\t\tst.write('You can sign up to OpenAI and/or create your API key [here](https://platform.openai.com/account/api-keys)')\n\t\t\tss['community_pct'] = pct\n\t\t\tss['debug']['community_pct'] = pct\n\t\twith t2:\n\t\t\tst.text_input('OpenAI API key', type='password', key='api_key', on_change=on_api_key_change, label_visibility=\"collapsed\")\n\telse:\n\t\tst.write('## 1. 
Enter your OpenAI API key')\n\t\tst.text_input('OpenAI API key', type='password', key='api_key', on_change=on_api_key_change, label_visibility=\"collapsed\")\n\ndef index_pdf_file():\n\tif ss['pdf_file']:\n\t\tss['filename'] = ss['pdf_file'].name\n\t\tif ss['filename'] != ss.get('filename_done'): # UGLY\n\t\t\twith st.spinner(f'indexing {ss[\"filename\"]}'):\n\t\t\t\tindex = model.index_file(ss['pdf_file'], ss['filename'], fix_text=ss['fix_text'], frag_size=ss['frag_size'], cache=ss['cache'])\n\t\t\t\tss['index'] = index\n\t\t\t\tdebug_index()\n\t\t\t\tss['filename_done'] = ss['filename'] # UGLY\n\ndef debug_index():\n\tindex = ss['index']\n\td = {}\n\td['hash'] = index['hash']\n\td['frag_size'] = index['frag_size']\n\td['n_pages'] = len(index['pages'])\n\td['n_texts'] = len(index['texts'])\n\td['summary'] = index['summary']\n\td['pages'] = index['pages']\n\td['texts'] = index['texts']\n\td['time'] = index.get('time',{})\n\tss['debug']['index'] = d\n\ndef ui_pdf_file():\n\tst.write('## 2. 
Upload or select your PDF file')\n\tdisabled = not ss.get('user') or (not ss.get('api_key') and not ss.get('community_pct',0))\n\tt1,t2 = st.tabs(['UPLOAD','SELECT'])\n\twith t1:\n\t\tst.file_uploader('pdf file', type='pdf', key='pdf_file', disabled=disabled, on_change=index_pdf_file, label_visibility=\"collapsed\")\n\t\tb_save()\n\twith t2:\n\t\tfilenames = ['']\n\t\tif ss.get('storage'):\n\t\t\tfilenames += ss['storage'].list()\n\t\tdef on_change():\n\t\t\tname = ss['selected_file']\n\t\t\tif name and ss.get('storage'):\n\t\t\t\twith ss['spin_select_file']:\n\t\t\t\t\twith st.spinner('loading index'):\n\t\t\t\t\t\tt0 = now()\n\t\t\t\t\t\tindex = ss['storage'].get(name)\n\t\t\t\t\t\tss['debug']['storage_get_time'] = now()-t0\n\t\t\t\tss['filename'] = name # XXX\n\t\t\t\tss['index'] = index\n\t\t\t\tdebug_index()\n\t\t\telse:\n\t\t\t\t#ss['index'] = {}\n\t\t\t\tpass\n\t\tst.selectbox('select file', filenames, on_change=on_change, key='selected_file', label_visibility=\"collapsed\", disabled=disabled)\n\t\tb_delete()\n\t\tss['spin_select_file'] = st.empty()\n\ndef ui_show_debug():\n\tst.checkbox('show debug section', key='show_debug')\n\ndef ui_fix_text():\n\tst.checkbox('fix common PDF problems', value=True, key='fix_text')\n\ndef ui_temperature():\n\t#st.slider('temperature', 0.0, 1.0, 0.0, 0.1, key='temperature', format='%0.1f')\n\tss['temperature'] = 0.0\n\ndef ui_fragments():\n\t#st.number_input('fragment size', 0,2000,200, step=100, key='frag_size')\n\tst.selectbox('fragment size (characters)', [0,200,300,400,500,600,700,800,900,1000], index=3, key='frag_size')\n\tb_reindex()\n\tst.number_input('max fragments', 1, 10, 4, key='max_frags')\n\tst.number_input('fragments before', 0, 3, 1, key='n_frag_before') # TODO: pass to model\n\tst.number_input('fragments after',  0, 3, 1, key='n_frag_after')  # TODO: pass to model\n\ndef ui_model():\n\tmodels = ['gpt-3.5-turbo','gpt-4','text-davinci-003','text-curie-001']\n\tst.selectbox('main model', models, key='model', 
disabled=not ss.get('api_key'))\n\tst.selectbox('embedding model', ['text-embedding-ada-002'], key='model_embed') # FOR FUTURE USE\n\ndef ui_hyde():\n\tst.checkbox('use HyDE', value=True, key='use_hyde')\n\ndef ui_hyde_summary():\n\tst.checkbox('use summary in HyDE', value=True, key='use_hyde_summary')\n\ndef ui_task_template():\n\tst.selectbox('task prompt template', prompts.TASK.keys(), key='task_name')\n\ndef ui_task():\n\tx = ss['task_name']\n\tst.text_area('task prompt', prompts.TASK[x], key='task')\n\ndef ui_hyde_prompt():\n\tst.text_area('HyDE prompt', prompts.HYDE, key='hyde_prompt')\n\ndef ui_question():\n\tst.write('## 3. Ask questions'+(f' to {ss[\"filename\"]}' if ss.get('filename') else ''))\n\tdisabled = False\n\tst.text_area('question', key='question', height=100, placeholder='Enter question here', help='', label_visibility=\"collapsed\", disabled=disabled)\n\n# REF: Hypothetical Document Embeddings\ndef ui_hyde_answer():\n\t# TODO: enter or generate\n\tpass\n\ndef ui_output():\n\toutput = ss.get('output','')\n\tst.markdown(output)\n\ndef ui_debug():\n\tif ss.get('show_debug'):\n\t\tst.write('### debug')\n\t\tst.write(ss.get('debug',{}))\n\n\ndef b_ask():\n\tc1,c2,c3,c4,c5 = st.columns([2,1,1,2,2])\n\tif c2.button('👍', use_container_width=True, disabled=not ss.get('output')):\n\t\tss['feedback'].send(+1, ss, details=ss['send_details'])\n\t\tss['feedback_score'] = ss['feedback'].get_score()\n\tif c3.button('👎', use_container_width=True, disabled=not ss.get('output')):\n\t\tss['feedback'].send(-1, ss, details=ss['send_details'])\n\t\tss['feedback_score'] = ss['feedback'].get_score()\n\tscore = ss.get('feedback_score',0)\n\tc5.write(f'feedback score: {score}')\n\tc4.checkbox('send details', True, key='send_details',\n\t\t\thelp='allow question and the answer to be stored in the ask-my-pdf feedback database')\n\t#c1,c2,c3 = st.columns([1,3,1])\n\t#c2.radio('zzz',['👍',r'...',r'👎'],horizontal=True,label_visibility=\"collapsed\")\n\t#\n\tdisabled = (not 
ss.get('api_key') and not ss.get('community_pct',0)) or not ss.get('index')\n\tif c1.button('get answer', disabled=disabled, type='primary', use_container_width=True):\n\t\tquestion = ss.get('question','')\n\t\ttemperature = ss.get('temperature', 0.0)\n\t\thyde = ss.get('use_hyde')\n\t\thyde_prompt = ss.get('hyde_prompt')\n\t\tif ss.get('use_hyde_summary'):\n\t\t\tsummary = ss['index']['summary']\n\t\t\thyde_prompt += f\" Context: {summary}\\n\\n\"\n\t\ttask = ss.get('task')\n\t\tmax_frags = ss.get('max_frags',1)\n\t\tn_before = ss.get('n_frag_before',0)\n\t\tn_after  = ss.get('n_frag_after',0)\n\t\tindex = ss.get('index',{})\n\t\twith st.spinner('preparing answer'):\n\t\t\tresp = model.query(question, index,\n\t\t\t\t\ttask=task,\n\t\t\t\t\ttemperature=temperature,\n\t\t\t\t\thyde=hyde,\n\t\t\t\t\thyde_prompt=hyde_prompt,\n\t\t\t\t\tmax_frags=max_frags,\n\t\t\t\t\tlimit=max_frags+2,\n\t\t\t\t\tn_before=n_before,\n\t\t\t\t\tn_after=n_after,\n\t\t\t\t\tmodel=ss['model'],\n\t\t\t\t)\n\t\tusage = resp.get('usage',{})\n\t\tusage['cnt'] = 1\n\t\tss['debug']['model.query.resp'] = resp\n\t\tss['debug']['resp.usage'] = usage\n\t\tss['debug']['model.vector_query_time'] = resp['vector_query_time']\n\t\t\n\t\tq = question.strip()\n\t\ta = resp['text'].strip()\n\t\tss['answer'] = a\n\t\toutput_add(q,a)\n\t\tst.experimental_rerun() # to enable the feedback buttons\n\ndef b_clear():\n\tif st.button('clear output'):\n\t\tss['output'] = ''\n\ndef b_reindex():\n\t# TODO: disabled\n\tif st.button('reindex'):\n\t\tindex_pdf_file()\n\ndef b_reload():\n\tif st.button('reload prompts'):\n\t\timport importlib\n\t\timportlib.reload(prompts)\n\ndef b_save():\n\tdb = ss.get('storage')\n\tindex = ss.get('index')\n\tname = ss.get('filename')\n\tapi_key = ss.get('api_key')\n\tdisabled = not api_key or not db or not index or not name\n\thelp = \"The file will be stored for about 90 days. 
Available only when using your own API key.\"\n\tif st.button('save encrypted index in ask-my-pdf', disabled=disabled, help=help):\n\t\twith st.spinner('saving to ask-my-pdf'):\n\t\t\tdb.put(name, index)\n\ndef b_delete():\n\tdb = ss.get('storage')\n\tname = ss.get('selected_file')\n\t# TODO: confirm delete\n\tif st.button('delete from ask-my-pdf', disabled=not db or not name):\n\t\twith st.spinner('deleting from ask-my-pdf'):\n\t\t\tdb.delete(name)\n\t\t#st.experimental_rerun()\n\ndef output_add(q,a):\n\tif 'output' not in ss: ss['output'] = ''\n\tq = q.replace('$',r'\\$')\n\ta = a.replace('$',r'\\$')\n\tnew = f'#### {q}\\n{a}\\n\\n'\n\tss['output'] = new + ss['output']\n\n# LAYOUT\n\nwith st.sidebar:\n\tui_info()\n\tui_spacer(2)\n\twith st.expander('advanced'):\n\t\tui_show_debug()\n\t\tb_clear()\n\t\tui_model()\n\t\tui_fragments()\n\t\tui_fix_text()\n\t\tui_hyde()\n\t\tui_hyde_summary()\n\t\tui_temperature()\n\t\tb_reload()\n\t\tui_task_template()\n\t\tui_task()\n\t\tui_hyde_prompt()\n\nui_api_key()\nui_pdf_file()\nui_question()\nui_hyde_answer()\nb_ask()\nui_output()\nui_debug()\n"
  },
  {
    "path": "src/model.py",
    "content": "from sklearn.metrics.pairwise import cosine_distances\n\nimport datetime\nfrom collections import Counter\nfrom time import time as now\nimport hashlib\nimport re\nimport io\nimport os\n\nimport pdf\nimport ai\n\ndef use_key(api_key):\n\tai.use_key(api_key)\n\ndef set_user(user):\n\tai.set_user(user)\n\ndef query_by_vector(vector, index, limit=None):\n\t\"return (ids, distances and texts) sorted by cosine distance\"\n\tvectors = index['vectors']\n\ttexts = index['texts']\n\t#\n\tsim = cosine_distances([vector], vectors)[0]\n\t#\n\tid_dist_list = list(enumerate(sim))\n\tid_dist_list.sort(key=lambda x:x[1])\n\tid_list   = [x[0] for x in id_dist_list][:limit]\n\tdist_list = [x[1] for x in id_dist_list][:limit]\n\ttext_list = [texts[x] for x in id_list] if texts else ['ERROR']*len(id_list)\n\treturn id_list, dist_list, text_list\n\ndef get_vectors(text_list):\n\t\"transform texts into embedding vectors\"\n\tbatch_size = 128\n\tvectors = []\n\tusage = Counter()\n\tfor i,texts in enumerate(batch(text_list, batch_size)):\n\t\tresp = ai.embeddings(texts)\n\t\tv = resp['vectors']\n\t\tu = resp['usage']\n\t\tu['call_cnt'] = 1\n\t\tusage.update(u)\n\t\tvectors.extend(v)\n\treturn {'vectors':vectors, 'usage':dict(usage), 'model':resp['model']}\n\ndef index_file(f, filename, fix_text=False, frag_size=0, cache=None):\n\t\"return vector index (dictionary) for a given PDF file\"\n\t# calc md5\n\th = hashlib.md5()\n\th.update(f.read())\n\tmd5 = h.hexdigest()\n\tfilesize = f.tell()\n\tf.seek(0)\n\t#\n\tt0 = now()\n\tpages = pdf.pdf_to_pages(f)\n\tt1 = now()\n\t\n\tif fix_text:\n\t\tfor i in range(len(pages)):\n\t\t\tpages[i] = fix_text_problems(pages[i])\n\ttexts = split_pages_into_fragments(pages, frag_size)\n\tt2 = now()\n\tif cache:\n\t\tcache_key = f'get_vectors:{md5}:{frag_size}:{fix_text}'\n\t\tresp = cache.call(cache_key, get_vectors, texts)\n\telse:\n\t\tresp = get_vectors(texts)\n\t\n\tt3 = now()\n\tvectors = resp['vectors']\n\tsummary_prompt = 
f\"{texts[0]}\\n\\nDescribe the document from which the fragment is extracted. Omit any details.\\n\\n\" # TODO: move to prompts.py\n\tsummary = ai.complete(summary_prompt)\n\tt4 = now()\n\tusage = resp['usage']\n\tout = {}\n\tout['frag_size'] = frag_size\n\tout['n_pages']   = len(pages)\n\tout['n_texts']   = len(texts)\n\tout['texts']     = texts\n\tout['pages']     = pages\n\tout['vectors']   = vectors\n\tout['summary']   = summary['text']\n\tout['filename']  = filename\n\tout['filehash']  = f'md5:{md5}'\n\tout['filesize']  = filesize\n\tout['usage']     = usage\n\tout['model']     = resp['model']\n\tout['time']      = {'pdf_to_pages':t1-t0, 'split_pages':t2-t1, 'get_vectors':t3-t2, 'summary':t4-t3}\n\tout['size']      = len(texts)   # DEPRECATED -> filesize\n\tout['hash']      = f'md5:{md5}' # DEPRECATED -> filehash\n\treturn out\n\ndef split_pages_into_fragments(pages, frag_size):\n\t\"split pages (list of texts) into smaller fragments (list of texts)\"\n\tpage_offset = [0]\n\tfor p,page in enumerate(pages):\n\t\tpage_offset += [page_offset[-1]+len(page)+1]\n\t# TODO: del page_offset[-1] ???\n\tif frag_size:\n\t\ttext = ' '.join(pages)\n\t\treturn text_to_fragments(text, frag_size, page_offset)\n\telse:\n\t\treturn pages\n\ndef text_to_fragments(text, size, page_offset):\n\t\"split single text into smaller fragments (list of texts)\"\n\tif size and len(text)>size:\n\t\tout = []\n\t\tpos = 0\n\t\tpage = 1\n\t\tp_off = page_offset.copy()[1:]\n\t\teos = find_eos(text)\n\t\tif len(text) not in eos:\n\t\t\teos += [len(text)]\n\t\tfor i in range(len(eos)):\n\t\t\tif eos[i]-pos>size:\n\t\t\t\ttext_fragment = f'PAGE({page}):\\n'+text[pos:eos[i]]\n\t\t\t\tout += [text_fragment]\n\t\t\t\tpos = eos[i]\n\t\t\t\tif eos[i]>p_off[0]:\n\t\t\t\t\tpage += 1\n\t\t\t\t\tdel p_off[0]\n\t\t# ugly: last iter\n\t\ttext_fragment = f'PAGE({page}):\\n'+text[pos:eos[i]]\n\t\tout += [text_fragment]\n\t\t#\n\t\tout = [x for x in out if x]\n\t\treturn out\n\telse:\n\t\treturn [text]\n\ndef 
find_eos(text):\n\t\"return list of all end-of-sentence offsets\"\n\treturn [x.span()[1] for x in re.finditer('[.!?。]\\s+',text)]\n\n###############################################################################\n\ndef fix_text_problems(text):\n\t\"fix common text problems\"\n\ttext = re.sub('\\s+[-]\\s+','',text) # word continuation in the next line\n\treturn text\n\ndef query(text, index, task=None, temperature=0.0, max_frags=1, hyde=False, hyde_prompt=None, limit=None, n_before=1, n_after=1, model=None):\n\t\"get dictionary with the answer for the given question (text).\"\n\tout = {}\n\t\n\tif hyde:\n\t\t# TODO: model param\n\t\tout['hyde'] = hypotetical_answer(text, index, hyde_prompt=hyde_prompt, temperature=temperature)\n\t\t# TODO: usage\n\t\n\t# RANK FRAGMENTS\n\tif hyde:\n\t\tresp = ai.embedding(out['hyde']['text'])\n\t\t# TODO: usage\n\telse:\n\t\tresp = ai.embedding(text)\n\t\t# TODO: usage\n\tv = resp['vector']\n\tt0 = now()\n\tid_list, dist_list, text_list = query_by_vector(v, index, limit=limit)\n\tdt0 = now()-t0\n\t\n\t# BUILD PROMPT\n\t\n\t# select fragments\n\tN_BEFORE = 1 # TODO: param\n\tN_AFTER =  1 # TODO: param\n\tselected = {} # text id -> rank\n\tfor rank,id in enumerate(id_list):\n\t\tfor x in range(id-n_before, id+1+n_after):\n\t\t\tif x not in selected and x>=0 and x<index['size']:\n\t\t\t\tselected[x] = rank\n\tselected2 = [(id,rank) for id,rank in selected.items()]\n\tselected2.sort(key=lambda x:(x[1],x[0]))\n\t\n\t# build context\n\tSEPARATOR = '\\n---\\n'\n\tcontext = ''\n\tcontext_len = 0\n\tfrag_list = []\n\tfor id,rank in selected2:\n\t\tfrag = index['texts'][id]\n\t\tfrag_len = ai.get_token_count(frag)\n\t\tif context_len+frag_len <= 3000: # TODO: remove hardcode\n\t\t\tcontext += SEPARATOR + frag # add separator and text fragment\n\t\t\tfrag_list += [frag]\n\t\t\tcontext_len = ai.get_token_count(context)\n\tout['context_len'] = context_len\n\tprompt = f\"\"\"\n\t\t{task or 'Task: Answer question based on 
context.'}\n\t\t\n\t\tContext:\n\t\t{context}\n\t\t\n\t\tQuestion: {text}\n\t\t\n\t\tAnswer:\"\"\" # TODO: move to prompts.py\n\t\n\t# GET ANSWER\n\tresp2 = ai.complete(prompt, temperature=temperature, model=model)\n\tanswer = resp2['text']\n\tusage = resp2['usage']\n\t\n\t# OUTPUT\n\tout['vector_query_time'] = dt0\n\tout['id_list'] = id_list\n\tout['dist_list'] = dist_list\n\tout['selected'] = selected\n\tout['selected2'] = selected2\n\tout['frag_list'] = frag_list\n\t#out['query.vector'] = resp['vector']\n\tout['usage'] = usage\n\tout['prompt'] = prompt\n\tout['model'] = resp2['model']\n\t# CORE\n\tout['text'] = answer\n\treturn out\n\ndef hypotetical_answer(text, index, hyde_prompt=None, temperature=0.0):\n\t\"get hypothetical answer for the question (text)\"\n\thyde_prompt = hyde_prompt or 'Write document that answers the question.'\n\tprompt = f\"\"\"\n\t{hyde_prompt}\n\tQuestion: \"{text}\"\n\tDocument:\"\"\" # TODO: move to prompts.py\n\tresp = ai.complete(prompt, temperature=temperature)\n\treturn resp\n\n\ndef community_tokens_available_pct():\n\tused = ai.get_community_usage_cost()\n\tlimit = float(os.getenv('COMMUNITY_DAILY_USD',0))\n\tpct = (100.0 * (limit-used) / limit) if limit else 0\n\tpct = max(0, pct)\n\tpct = min(100, pct)\n\treturn pct\n\n\ndef community_tokens_refresh_in():\n\tx = datetime.datetime.now()\n\tdt = (x.replace(hour=23, minute=59, second=59) - x).seconds\n\th = dt // 3600\n\tm = dt  % 3600 // 60\n\treturn f\"{h} h {m} min\"\n\n# util\ndef batch(data, n):\n\tfor i in range(0, len(data), n):\n\t\tyield data[i:i+n]\n\nif __name__==\"__main__\":\n\tprint(text_to_fragments(\"to jest. test tego. programu\", size=3, page_offset=[0,5,10,15,20]))\n"
  },
  {
    "path": "src/pdf.py",
    "content": "\"PDF adapter\"\n\nimport pypdf\n\ndef pdf_to_pages(file):\n\t\"extract text (pages) from pdf file\"\n\tpages = []\n\tpdf = pypdf.PdfReader(file)\n\tfor page in pdf.pages:\n\t\tpages.append(page.extract_text())\n\treturn pages\n"
  },
  {
    "path": "src/prompts.py",
    "content": "# INFO: some prompts are still in model.py\n\n# TODO: Ignore OCR problems in the text below.\n\nTASK = {\n\t'v6': (\n\t\t\t\"Answer the question truthfully based on the text below. \"\n\t\t\t\"Include verbatim quote and a comment where to find it in the text (page number). \"\n\t\t\t#\"After the quote write a step by step explanation in a new paragraph. \"\n\t\t\t\"After the quote write a step by step explanation. \"\n\t\t\t\"Use bullet points. \"\n\t\t\t#\"After that try to rephrase the original question so it might give better results. \" \n\t\t),\n\t'v5': (\n\t\t\t\"Answer the question truthfully based on the text below. \"\n\t\t\t\"Include at least one verbatim quote (marked with quotation marks) and a comment where to find it in the text (ie name of the section and page number). \"\n\t\t\t\"Use ellipsis in the quote to omit irrelevant parts of the quote. \"\n\t\t\t\"After the quote write (in the new paragraph) a step by step explanation to be sure we have the right answer \"\n\t\t\t\"(use bullet-points in separate lines)\" #, adjust the language for a young reader). \"\n\t\t\t\"After the explanation check if the Answer is consistent with the Context and doesn't require external knowledge. \"\n\t\t\t\"In a new line write 'SELF-CHECK OK' if the check was successful and 'SELF-CHECK FAILED' if it failed. \" \n\t\t),\n\t'v4':\n\t\t\"Answer the question truthfully based on the text below. \" \\\n\t\t\"Include verbatim quote and a comment where to find it in the text (ie name of the section and page number). \" \\\n\t\t\"After the quote write an explanation (in the new paragraph) for a young reader.\",\n\t'v3': 'Answer the question truthfully based on the text below. Include verbatim quote and a comment where to find it in the text (ie name of the section and page number).',\n\t'v2': 'Answer question based on context. 
The answers should be elaborate and based only on the context.',\n\t'v1': 'Answer question based on context.',\n\t# 'v5':\n\t\t# \"Generate a comprehensive and informative answer for a given question solely based on the provided document fragments. \" \\\n\t\t# \"You must only use information from the provided fragments. Use an unbiased and journalistic tone. Combine fragments together into coherent answer. \" \\\n\t\t# \"Do not repeat text. Cite fragments using [${number}] notation. Only cite the most relevant fragments that answer the question accurately. \" \\\n\t\t# \"If different fragments refer to different entities with the same name, write separate answer for each entity.\",\n}\n\nHYDE = \"Write an example answer to the following question. Don't write a generic answer; just assume everything that is not known.\"\n\n# TODO\nSUMMARY = {\n\t'v2':'Describe the document from which the fragment is extracted. Omit any details.',\n\t'v1':'Describe the document from which the fragment is extracted. Do not describe the fragment, focus on figuring out what kind of document it is.',\n}\n"
  },
  {
    "path": "src/run.bat",
    "content": "set STORAGE_MODE=LOCAL\nset CACHE_MODE=DISK\n\nset STORAGE_PATH=../data/storage\nset CACHE_PATH=../data/cache\n\nstreamlit run gui.py\n"
  },
  {
    "path": "src/run.sh",
    "content": "# export so the streamlit subprocess inherits the configuration\nexport STORAGE_MODE=LOCAL\nexport CACHE_MODE=DISK\n\nexport STORAGE_PATH=../data/storage\nexport CACHE_PATH=../data/cache\n\nstreamlit run gui.py\n"
  },
  {
    "path": "src/stats.py",
    "content": "import redis\nfrom time import strftime\nimport os\nfrom retry import retry\n\nclass Stats:\n\tdef __init__(self):\n\t\tself.config = {}\n\t\n\tdef render(self, key):\n\t\tvariables = dict(\n\t\t\tdate = strftime('%Y-%m-%d'),\n\t\t\thour = strftime('%H'),\n\t\t)\n\t\tvariables.update(self.config)\n\t\tfor k,v in variables.items():\n\t\t\tkey = key.replace('['+k+']',v)\n\t\treturn key\n\t\n\nclass DictStats(Stats):\n\tdef __init__(self, data_dict):\n\t\tself.data = data_dict\n\t\tself.config = {}\n\t\n\tdef incr(self, key, kv_dict):\n\t\tdata = self.data\n\t\tkey = self.render(key)\n\t\tif key not in data:\n\t\t\tdata[key] = {}\n\t\tfor member,val in kv_dict.items():\n\t\t\tmember = self.render(member)\n\t\t\tdata[key][member] = data[key].get(member,0) + val\n\t\n\tdef get(self, key):\n\t\tkey = self.render(key)\n\t\treturn self.data.get(key, {})\n\n\nclass RedisStats(Stats):\n\tdef __init__(self):\n\t\tREDIS_URL = os.getenv('REDIS_URL')\n\t\tif not REDIS_URL:\n\t\t\traise Exception('No Redis configuration in environment variables!')\n\t\tself.db = redis.Redis.from_url(REDIS_URL)\n\t\tself.config = {}\n\t\n\t@retry(tries=5, delay=0.1)\n\tdef incr(self, key, kv_dict):\n\t\t# TODO: non critical code -> safe exceptions\n\t\tkey = self.render(key)\n\t\tp = self.db.pipeline()\n\t\tfor member,val in kv_dict.items():\n\t\t\tmember = self.render(member)\n\t\t\tp.zincrby(key, val, member) # queue increments in the pipeline\n\t\tp.execute()\n\t\n\t@retry(tries=5, delay=0.1)\n\tdef get(self, key):\n\t\t# TODO: non critical code -> safe exceptions\n\t\tkey = self.render(key)\n\t\titems = self.db.zscan_iter(key)\n\t\treturn {k.decode('utf8'):v for k,v in items}\n\n\nstats_data_dict = {}\ndef get_stats(**kw):\n\tMODE = os.getenv('STATS_MODE','').upper()\n\tif MODE=='REDIS':\n\t\tstats = RedisStats()\n\telse:\n\t\tstats = DictStats(stats_data_dict)\n\tstats.config.update(kw)\n\treturn stats\n\n\n\nif __name__==\"__main__\":\n\ts1 = get_stats(user='maciek')\n\ts1.incr('aaa:[date]:[user]', 
dict(a=1,b=2))\n\ts1.incr('aaa:[date]:[user]', dict(a=1,b=2))\n\tprint(s1.data)\n\tprint(s1.get('aaa:[date]:[user]'))\n\t#\n\ts2 = get_stats(user='kerbal')\n\ts2.incr('aaa:[date]:[user]', dict(a=1,b=2))\n\ts2.incr('aaa:[date]:[user]', dict(a=1,b=2))\n\tprint(s2.data)\n\tprint(s2.get('aaa:[date]:[user]'))\n\n"
  },
  {
    "path": "src/storage.py",
    "content": "\"Storage adapter - one folder for each user / api_key\"\n\n# pip install pycryptodome\n# REF: https://www.pycryptodome.org/src/cipher/aes\nfrom Crypto.Cipher import AES\nfrom Crypto.Util.Padding import pad,unpad\n\nfrom retry import retry\n\nfrom binascii import hexlify,unhexlify\nimport hashlib\nimport pickle\nimport zlib\nimport os\nimport io\n\n# pip install boto3\nimport boto3\nimport botocore\n\nSALT = unhexlify(os.getenv('STORAGE_SALT','00'))\n\nclass Storage:\n\t\"Encrypted object storage (base class)\"\n\t\n\tdef __init__(self, secret_key):\n\t\tk = secret_key.encode()\n\t\tself.folder = hashlib.blake2s(k, salt=SALT, person=b'folder', digest_size=8).hexdigest()\n\t\tself.passwd = hashlib.blake2s(k, salt=SALT, person=b'passwd', digest_size=32).hexdigest()\n\t\tself.AES_MODE = AES.MODE_ECB # TODO: better AES mode ???\n\t\tself.AES_BLOCK_SIZE = 16\n\t\n\tdef get(self, name, default=None):\n\t\t\"get one object from the folder (or default if missing)\"\n\t\tsafe_name = self.encode(name)\n\t\ttry:\n\t\t\tdata = self._get(safe_name)\n\t\texcept Exception:\n\t\t\treturn default\n\t\tobj = self.deserialize(data)\n\t\treturn obj\n\t\n\tdef put(self, name, obj):\n\t\t\"put the object into the folder\"\n\t\tsafe_name = self.encode(name)\n\t\tdata = self.serialize(obj)\n\t\tself._put(safe_name, data)\n\t\treturn data\n\n\tdef list(self):\n\t\t\"list object names from the folder\"\n\t\treturn [self.decode(name) for name in self._list()]\n\n\tdef delete(self, name):\n\t\t\"delete the object from the folder\"\n\t\tsafe_name = self.encode(name)\n\t\tself._delete(safe_name)\n\t\n\t# IMPLEMENTED IN SUBCLASSES\n\tdef _put(self, name, data):\n\t\t...\n\tdef _get(self, name):\n\t\t...\t\n\tdef _delete(self, name):\n\t\tpass\n\tdef _list(self):\n\t\t...\n\t\n\t# # #\n\t\n\tdef serialize(self, obj):\n\t\traw = pickle.dumps(obj)\n\t\tcompressed = self.compress(raw)\n\t\tencrypted = self.encrypt(compressed)\n\t\treturn encrypted\n\t\n\tdef deserialize(self, encrypted):\n\t\tcompressed = self.decrypt(encrypted)\n\t\traw = 
self.decompress(compressed)\n\t\tobj = pickle.loads(raw)\n\t\treturn obj\n\n\tdef encrypt(self, raw):\n\t\tcipher = AES.new(unhexlify(self.passwd), self.AES_MODE)\n\t\treturn cipher.encrypt(pad(raw, self.AES_BLOCK_SIZE))\n\t\n\tdef decrypt(self, encrypted):\n\t\tcipher = AES.new(unhexlify(self.passwd), self.AES_MODE)\n\t\treturn unpad(cipher.decrypt(encrypted), self.AES_BLOCK_SIZE)\n\n\tdef compress(self, data):\n\t\treturn zlib.compress(data)\n\t\n\tdef decompress(self, data):\n\t\treturn zlib.decompress(data)\n\t\n\tdef encode(self, name):\n\t\treturn hexlify(name.encode('utf8')).decode('utf8')\n\t\n\tdef decode(self, name):\n\t\treturn unhexlify(name).decode('utf8')\n\n\nclass DictStorage(Storage):\n\t\"Dictionary based storage\"\n\t\n\tdef __init__(self, secret_key, data_dict):\n\t\tsuper().__init__(secret_key)\n\t\tself.data = data_dict\n\t\t\n\tdef _put(self, name, data):\n\t\tif self.folder not in self.data:\n\t\t\tself.data[self.folder] = {}\n\t\tself.data[self.folder][name] = data\n\t\t\n\tdef _get(self, name):\n\t\treturn self.data[self.folder][name]\n\t\n\tdef _list(self):\n\t\t# TODO: sort by modification time (reverse=True)\n\t\treturn list(self.data.get(self.folder,{}).keys())\n\t\n\tdef _delete(self, name):\n\t\tdel self.data[self.folder][name]\n\n\nclass LocalStorage(Storage):\n\t\"Local filesystem based storage\"\n\t\n\tdef __init__(self, secret_key, path):\n\t\tif not path:\n\t\t\traise Exception('No storage path in environment variables!')\n\t\tsuper().__init__(secret_key)\n\t\tself.path = os.path.join(path, self.folder)\n\t\tif not os.path.exists(self.path):\n\t\t\tos.makedirs(self.path)\n\t\n\tdef _put(self, name, data):\n\t\twith open(os.path.join(self.path, name), 'wb') as f:\n\t\t\tf.write(data)\n\n\tdef _get(self, name):\n\t\twith open(os.path.join(self.path, name), 'rb') as f:\n\t\t\tdata = f.read()\n\t\treturn data\n\t\n\tdef _list(self):\n\t\t# TODO: sort by modification time (reverse=True)\n\t\treturn os.listdir(self.path)\n\t\n\tdef 
_delete(self, name):\n\t\tos.remove(os.path.join(self.path, name))\n\n\nclass S3Storage(Storage):\n\t\"S3 based encrypted storage\"\n\t\n\tdef __init__(self, secret_key, **kw):\n\t\tprefix = kw.get('prefix') or os.getenv('S3_PREFIX','index/x1')\n\t\tregion = kw.get('region') or os.getenv('S3_REGION','sfo3')\n\t\tbucket = kw.get('bucket') or os.getenv('S3_BUCKET','ask-my-pdf')\n\t\turl    = kw.get('url')    or os.getenv('S3_URL',f'https://{region}.digitaloceanspaces.com')\n\t\tkey    = os.getenv('S3_KEY','')\n\t\tsecret = os.getenv('S3_SECRET','')\n\t\t#\n\t\tif not key or not secret:\n\t\t\traise Exception(\"No S3 credentials in environment variables!\")\n\t\t#\n\t\tsuper().__init__(secret_key)\n\t\tself.session = boto3.session.Session()\n\t\tself.s3 = self.session.client('s3',\n\t\t\t\tconfig=botocore.config.Config(s3={'addressing_style': 'virtual'}),\n\t\t\t\tregion_name=region,\n\t\t\t\tendpoint_url=url,\n\t\t\t\taws_access_key_id=key,\n\t\t\t\taws_secret_access_key=secret,\n\t\t\t)\n\t\tself.bucket = bucket\n\t\tself.prefix = prefix\n\t\n\tdef get_key(self, name):\n\t\treturn f'{self.prefix}/{self.folder}/{name}'\n\t\n\tdef _put(self, name, data):\n\t\tkey = self.get_key(name)\n\t\tf = io.BytesIO(data)\n\t\tself.s3.upload_fileobj(f, self.bucket, key)\n\t\n\tdef _get(self, name):\n\t\tkey = self.get_key(name)\n\t\tf = io.BytesIO()\n\t\tself.s3.download_fileobj(self.bucket, key, f)\n\t\tf.seek(0)\n\t\treturn f.read()\n\t\n\tdef _list(self):\n\t\tresp = self.s3.list_objects(\n\t\t\t\tBucket=self.bucket,\n\t\t\t\tPrefix=self.get_key('')\n\t\t\t)\n\t\tcontents = resp.get('Contents',[])\n\t\tcontents.sort(key=lambda x:x['LastModified'], reverse=True)\n\t\tkeys = [x['Key'] for x in contents]\n\t\tnames = [x.split('/')[-1] for x in keys]\n\t\treturn names\n\t\n\tdef _delete(self, name):\n\t\tself.s3.delete_object(\n\t\t\t\tBucket=self.bucket,\n\t\t\t\tKey=self.get_key(name)\n\t\t\t)\n\ndef get_storage(api_key, data_dict):\n\t\"get storage adapter configured in 
environment variables\"\n\tmode = os.getenv('STORAGE_MODE','').upper()\n\tpath = os.getenv('STORAGE_PATH','')\n\tif mode=='S3':\n\t\tstorage = S3Storage(api_key)\n\telif mode=='LOCAL':\n\t\tstorage = LocalStorage(api_key, path)\n\telse:\n\t\tstorage = DictStorage(api_key, data_dict)\n\treturn storage\n"
  }
]
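
For reference, the fragment-selection step inside `query()` (src/model.py) can be isolated into a standalone sketch. `select_fragments` is a hypothetical helper name, not a function in the repo; it restates the same logic: each retrieved fragment id is expanded to its `n_before`/`n_after` neighbors, deduplicated keeping the best (lowest) rank, and ordered by rank first, then document position.

```python
def select_fragments(id_list, size, n_before=1, n_after=1):
	"""Expand ranked fragment ids with their neighbors.

	id_list: fragment ids ordered by relevance (best first)
	size:    total number of fragments in the index
	Returns a list of (fragment_id, rank) pairs sorted by (rank, id).
	"""
	selected = {}  # fragment id -> best (lowest) rank that selected it
	for rank, i in enumerate(id_list):
		for x in range(i - n_before, i + 1 + n_after):
			# keep ids in bounds; first (best) rank wins on duplicates
			if 0 <= x < size and x not in selected:
				selected[x] = rank
	# order by rank first, then by position in the document
	return sorted(selected.items(), key=lambda kv: (kv[1], kv[0]))

print(select_fragments([5, 2], size=10))
# → [(4, 0), (5, 0), (6, 0), (1, 1), (2, 1), (3, 1)]
```

Sorting by `(rank, id)` keeps the most relevant fragment and its neighbors at the front of the context, so when the token budget in `query()` runs out, it is the lower-ranked fragments that get dropped.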