Repository: 2g-XzenG/Claim-PT
Branch: main
Commit: 38348ccbe0c1
Files: 18
Total size: 38.7 MB
Directory structure:
gitextract_9ionexop/
├── .gitignore
├── README.md
├── attention/
│ ├── attention.ipynb
│ ├── attention_weights_finetune
│ ├── attention_weights_pretrain
│ ├── fake_data
│ ├── get_att.py
│ └── saveModel/
│ ├── saved_model.pb
│ └── variables/
│ ├── variables.data-00000-of-00002
│ ├── variables.data-00001-of-00002
│ └── variables.index
├── finetune/
│ ├── asthma/
│ │ └── fine-tune.ipynb
│ ├── autoDiag/
│ │ └── fine-tune.ipynb
│ └── suicide/
│ └── fine-tune.ipynb
└── pretraining/
├── DataGenerator.py
├── cpt.py
├── plot.ipynb
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.DS_Store
.AppleDouble
.LSOverride
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: README.md
================================================
# Claim-PT: Pretrained transformer framework on pediatric claims data for population specific tasks
This repository contains the TensorFlow implementation of the following paper:
Paper Name: Pretrained transformer framework on pediatric claims data for population specific tasks
Authors: Xianlong Zeng, Simon L. Linwood, Chang Liu
Abstract: The adoption of electronic health records (EHR) has become universal during the past decade, which has afforded in-depth data-based research. By learning from the large amount of healthcare data, various data-driven models have been built to predict future events for different medical tasks, such as auto diagnosis and heart-attack prediction. Although EHR is abundant, the population that satisfies specific criteria for learning population-specific tasks is scarce, making it challenging to train data-hungry deep learning models. This study presents the Claim Pre-Training (Claim-PT) framework, a generic pre-training model that first trains on the entire pediatric claims dataset, followed by a discriminative fine-tuning on each population-specific task. The semantic meaning of medical events can be captured in the pre-training stage, and the effective knowledge transfer is completed through the task-aware fine-tuning stage. The fine-tuning process requires minimal parameter modification without changing the model architecture, which mitigates the data scarcity issue and helps train the deep learning model adequately on small patient cohorts. We conducted experiments on a real-world claims dataset with more than one million patient records. Experimental results on two downstream tasks demonstrated the effectiveness of our method: our general task-agnostic pre-training framework outperformed tailored task-specific models, achieving more than 10\% higher in model performance as compared to baselines. In addition, our framework showed a great generalizability potential to transfer learned knowledge from one institution to another, paving the way for future healthcare model pre-training across institutions.
[Paper Link](https://www.nature.com/articles/s41598-022-07545-1): Zeng, Xianlong, Simon L. Linwood, and Chang Liu. "Pretrained transformer framework on pediatric claims data for population specific tasks." Scientific Reports 12, no. 1 (2022): 3651.
# Environment
Ubuntu 16.04, Python 3.7, TensorFlow 2.1
# Model Pretraining
See the `pretraining/` folder.
### Hyper-parameters used in the pre-training stage
- `MAX_VISIT = 30`
- `MAX_CODE = 10`
- `MAX_DEMO = 2`
- `PATIENT_DIM = 100`
- `BATCH_SIZE = 100`
- `TRAIN_RATIO = 0.8`
- `DATA_SIZE = len(age_seq)`
- `EPOCHS = 1000`
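As a rough illustration of how these settings fit together, the sketch below collects them into a single config object and uses `TRAIN_RATIO` to split shuffled patient indices into training and validation sets. This is a hypothetical example, not code from `pretraining/train.py`: `PretrainConfig`, `split_indices`, `steps_per_epoch`, and the placeholder data size of 1,000 patients are all assumptions. In the repository, `DATA_SIZE` comes from `len(age_seq)`.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class PretrainConfig:
    # Values copied from the hyper-parameter list above
    max_visit: int = 30
    max_code: int = 10
    max_demo: int = 2
    patient_dim: int = 100
    batch_size: int = 100
    train_ratio: float = 0.8
    epochs: int = 1000


def split_indices(data_size: int, train_ratio: float, seed: int = 0):
    """Hypothetical helper: shuffle patient indices and split them by train_ratio."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(data_size)
    n_train = int(data_size * train_ratio)
    return order[:n_train], order[n_train:]


def steps_per_epoch(n_examples: int, batch_size: int) -> int:
    # Ceiling division so that a final partial batch is still counted
    return -(-n_examples // batch_size)


cfg = PretrainConfig()
# 1000 is a placeholder for DATA_SIZE, not the real dataset size
train_idx, val_idx = split_indices(1000, cfg.train_ratio)
print(len(train_idx), len(val_idx))
print(steps_per_epoch(len(train_idx), cfg.batch_size))
```

With 1,000 placeholder patients and `TRAIN_RATIO = 0.8`, this yields 800 training and 200 validation indices, i.e. 8 batches of 100 per epoch.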
# Downstream fine-tuning
See the `finetune/` folder, which contains one notebook per task (`asthma`, `autoDiag`, `suicide`).
# Attention playground
See the `attention/` folder.
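The notebook's `attention_scores` helper crops a stack of attention matrices to the number of displayed codes, adds a batch axis, and passes a one-element list (one entry per layer) to bertviz's `head_view`, which expects per-layer tensors of shape `(batch, heads, seq, seq)`. The NumPy sketch below reproduces only that reshaping step on synthetic data, so it runs without PyTorch or bertviz. `to_head_view_input` and the random weights are hypothetical; in the notebook, the array is also wrapped with `torch.tensor` before `head_view` is called.

```python
import numpy as np


def to_head_view_input(weights: np.ndarray, tokens: list) -> list:
    """Hypothetical helper mirroring the reshaping in attention_scores.

    weights: array of shape (heads, max_len, max_len) holding attention matrices.
    Returns a one-element list (a single layer) containing an array of shape
    (1, heads, seq, seq), where seq = len(tokens).
    """
    seq = len(tokens)
    cropped = weights[:, :seq, :seq]  # keep only rows/columns for the displayed tokens
    return [cropped[np.newaxis]]      # add the batch axis; one list entry per layer


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 30, 30))
# Row-wise softmax so that each full row is a distribution over the 30 positions
weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
tokens = ["P_99284", "D_948.0", "D_944.2", "M_POLYE", "P_272", "P_99254", "P_270"]

layers = to_head_view_input(weights, tokens)
print(len(layers), layers[0].shape)
```

Because the columns are cropped after the softmax, each displayed row sums to less than 1 in this sketch; only the full, uncropped rows are normalized.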
================================================
FILE: attention/attention.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "regulation-chemistry",
"metadata": {},
"source": [
"### Case Study: Attention Visualization\n",
"To illustrate the benefit of the attention mechanism and to better understand how the attention weights change from the pre-training stage to the fine-tuning stage, we analyze the learned attention weights at both stages through a case study."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "european-allowance",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pickle\n",
"import pandas as pd\n",
"import torch\n",
"from bertviz import head_view\n",
"from transformers import BertTokenizer, BertModel\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "moderate-translator",
"metadata": {},
"outputs": [],
"source": [
"with open(\"./fake_data\", \"rb\") as f:\n",
"    fake_data = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "juvenile-tennis",
"metadata": {},
"outputs": [],
"source": [
"with open(\"./attention_weights_pretrain\", \"rb\") as f:\n",
"    attention_weights_pretrain = pickle.load(f)\n",
"with open(\"./attention_weights_finetune\", \"rb\") as f:\n",
"    attention_weights_finetune = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "social-affairs",
"metadata": {},
"outputs": [],
"source": [
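"def attention_scores(weights, k):\n",
"    # Slice weights[1] at [k:k+2]; the two slices are rendered as two heads\n",
"    weight = weights[1][k:k+2]\n",
"    # Token labels: second field of each record for patient k, truncated to 7 characters\n",
"    sentence = [list(x)[1][:7] for x in fake_data[k]]\n",
"    \n",
"    # Crop to the sequence length and add a batch axis: (1, heads, seq, seq)\n",
"    attn = torch.tensor(weight[:, :len(sentence), :len(sentence)]).unsqueeze(0)\n",
"    # head_view expects one tensor per layer; a single layer is visualized here\n",
"    head_view([attn], sentence)"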
]
},
{
"cell_type": "markdown",
"id": "floating-surge",
"metadata": {},
"source": [
"# Select Patient Index"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "specialized-maximum",
"metadata": {},
"outputs": [],
"source": [
"k=0"
]
},
{
"cell_type": "markdown",
"id": "according-assets",
"metadata": {},
"source": [
"# Pretrain"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "living-gates",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-973440206fb84c43a608ed88945dabed'>\n",
" <span style=\"user-select:none\">\n",
" Layer: <select id=\"layer\"></select>\n",
" \n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": "/**\n * @fileoverview Transformer Visualization D3 javascript code.\n *\n *\n * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n *\n * Change log:\n *\n * 12/19/18 Jesse Vig Assorted cleanup. Changed orientation of attention matrices.\n * 12/29/20 Jesse Vig Significant refactor.\n * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n * 05/03/21 Jesse Vig Adjust height of visualization dynamically\n * 07/25/21 Jesse Vig Support layer filtering\n **/\n\nrequire.config({\n paths: {\n d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n }\n});\n\nrequirejs(['jquery', 'd3'], function ($, d3) {\n\n const params = {\"attention\": [{\"name\": null, \"attn\": [[[[0.16614454984664917, 0.1271992027759552, 0.11374139040708542, 0.15638020634651184, 0.20374387502670288, 0.11444535851478577, 0.11834541708230972], [0.16203011572360992, 0.1401018649339676, 0.13978800177574158, 0.1404246836900711, 0.13866281509399414, 0.13900551199913025, 0.1399870365858078], [0.14190000295639038, 0.14325059950351715, 0.14268504083156586, 0.14326979219913483, 0.14316092431545258, 0.14276131987571716, 0.14297233521938324], [0.1783653199672699, 0.13668516278266907, 0.13705065846443176, 0.13718663156032562, 0.1375586986541748, 0.1364724338054657, 0.13668116927146912], [0.17943505942821503, 0.13677170872688293, 0.13591942191123962, 0.13615824282169342, 0.13820093870162964, 0.13686884939670563, 0.13664573431015015], [0.1567888855934143, 0.1408628672361374, 0.14035286009311676, 0.1405779868364334, 0.14033928513526917, 0.14011918008327484, 0.14095890522003174], [0.1511572003364563, 0.1412300169467926, 0.14104074239730835, 0.14131779968738556, 0.14250311255455017, 0.1413128525018692, 0.14143821597099304]], [[0.047577470541000366, 0.04034224525094032, 
0.03173127770423889, 0.03325849026441574, 0.05665835365653038, 0.040445178747177124, 0.03266838937997818], [0.03983418643474579, 0.03432993218302727, 0.0343342199921608, 0.034387726336717606, 0.034017566591501236, 0.03415762633085251, 0.03437581658363342], [0.0348617359995842, 0.03448709473013878, 0.034443072974681854, 0.034463874995708466, 0.0343620628118515, 0.03469187021255493, 0.03444177284836769], [0.03609755262732506, 0.03446199744939804, 0.03440416231751442, 0.034431908279657364, 0.03436928614974022, 0.034668270498514175, 0.0344182550907135], [0.04275275021791458, 0.034237828105688095, 0.03420986980199814, 0.03424009680747986, 0.03407548740506172, 0.03425997495651245, 0.03429330885410309], [0.04144877940416336, 0.034321531653404236, 0.034290462732315063, 0.03432182967662811, 0.03395794332027435, 0.03400801867246628, 0.0343194305896759], [0.03561968356370926, 0.034495096653699875, 0.034405313432216644, 0.03443482518196106, 0.034406013786792755, 0.03469153866171837, 0.03441842645406723]]]], \"left_text\": [\"P_99284\", \"D_948.0\", \"D_944.2\", \"M_POLYE\", \"P_272\", \"P_99254\", \"P_270\"], \"right_text\": [\"P_99284\", \"D_948.0\", \"D_944.2\", \"M_POLYE\", \"P_272\", \"P_99254\", \"P_270\"]}], \"default_filter\": \"0\", \"root_div_id\": \"bertviz-973440206fb84c43a608ed88945dabed\", \"layer\": null, \"heads\": null, \"include_layers\": [0]}; // HACK: the params object above was substituted for a template marker.\n const TEXT_SIZE = 15;\n const BOXWIDTH = 110;\n const 
BOXHEIGHT = 22.5;\n const MATRIX_WIDTH = 115;\n const CHECKBOX_SIZE = 20;\n const TEXT_TOP = 30;\n\n console.log(\"d3 version\", d3.version)\n let headColors;\n try {\n headColors = d3.scaleOrdinal(d3.schemeCategory10);\n } catch (err) {\n console.log('Older d3 version')\n headColors = d3.scale.category10();\n }\n let config = {};\n initialize();\n renderVis();\n\n function initialize() {\n config.attention = params['attention'];\n config.filter = params['default_filter'];\n config.rootDivId = params['root_div_id'];\n config.nLayers = config.attention[config.filter]['attn'].length;\n config.nHeads = config.attention[config.filter]['attn'][0].length;\n config.layers = params['include_layers']\n\n if (params['heads']) {\n config.headVis = new Array(config.nHeads).fill(false);\n params['heads'].forEach(x => config.headVis[x] = true);\n } else {\n config.headVis = new Array(config.nHeads).fill(true);\n }\n config.initialTextLength = config.attention[config.filter].right_text.length;\n config.layer_seq = (params['layer'] == null ? 
0 : config.layers.findIndex(layer => params['layer'] === layer));\n config.layer = config.layers[config.layer_seq]\n\n let layerEl = $(`#${config.rootDivId} #layer`);\n for (const layer of config.layers) {\n layerEl.append($(\"<option />\").val(layer).text(layer));\n }\n layerEl.val(config.layer).change();\n layerEl.on('change', function (e) {\n config.layer = +e.currentTarget.value;\n config.layer_seq = config.layers.findIndex(layer => config.layer === layer);\n renderVis();\n });\n\n $(`#${config.rootDivId} #filter`).on('change', function (e) {\n config.filter = e.currentTarget.value;\n renderVis();\n });\n }\n\n function renderVis() {\n\n // Load parameters\n const attnData = config.attention[config.filter];\n const leftText = attnData.left_text;\n const rightText = attnData.right_text;\n\n // Select attention for given layer\n const layerAttention = attnData.attn[config.layer_seq];\n\n // Clear vis\n $(`#${config.rootDivId} #vis`).empty();\n\n // Determine size of visualization\n const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;\n const svg = d3.select(`#${config.rootDivId} #vis`)\n .append('svg')\n .attr(\"width\", \"100%\")\n .attr(\"height\", height + \"px\");\n\n // Display tokens on left and right side of visualization\n renderText(svg, leftText, true, layerAttention, 0);\n renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);\n\n // Render attention arcs\n renderAttention(svg, layerAttention);\n\n // Draw squares at top of visualization, one for each head\n drawCheckboxes(0, svg, layerAttention);\n }\n\n function renderText(svg, text, isLeft, attention, leftPos) {\n\n const textContainer = svg.append(\"svg:g\")\n .attr(\"id\", isLeft ? 
\"left\" : \"right\");\n\n // Add attention highlights superimposed over words\n textContainer.append(\"g\")\n .classed(\"attentionBoxes\", true)\n .selectAll(\"g\")\n .data(attention)\n .enter()\n .append(\"g\")\n .attr(\"head-index\", (d, i) => i)\n .selectAll(\"rect\")\n .data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights\n .enter()\n .append(\"rect\")\n .attr(\"x\", function () {\n var headIndex = +this.parentNode.getAttribute(\"head-index\");\n return leftPos + boxOffsets(headIndex);\n })\n .attr(\"y\", (+1) * BOXHEIGHT)\n .attr(\"width\", BOXWIDTH / activeHeads())\n .attr(\"height\", BOXHEIGHT)\n .attr(\"fill\", function () {\n return headColors(+this.parentNode.getAttribute(\"head-index\"))\n })\n .style(\"opacity\", 0.0);\n\n const tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n .data(text)\n .enter()\n .append(\"g\");\n\n // Add gray background that appears when hovering over text\n tokenContainer.append(\"rect\")\n .classed(\"background\", true)\n .style(\"opacity\", 0.0)\n .attr(\"fill\", \"lightgray\")\n .attr(\"x\", leftPos)\n .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n .attr(\"width\", BOXWIDTH)\n .attr(\"height\", BOXHEIGHT);\n\n // Add token text\n const textEl = tokenContainer.append(\"text\")\n .text(d => d)\n .attr(\"font-size\", TEXT_SIZE + \"px\")\n .style(\"cursor\", \"default\")\n .style(\"-webkit-user-select\", \"none\")\n .attr(\"x\", leftPos)\n .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT);\n\n if (isLeft) {\n textEl.style(\"text-anchor\", \"end\")\n .attr(\"dx\", BOXWIDTH - 0.5 * TEXT_SIZE)\n .attr(\"dy\", TEXT_SIZE);\n } else {\n textEl.style(\"text-anchor\", \"start\")\n .attr(\"dx\", +0.5 * TEXT_SIZE)\n .attr(\"dy\", TEXT_SIZE);\n }\n\n tokenContainer.on(\"mouseover\", function (d, index) {\n\n // Show gray background for moused-over token\n textContainer.selectAll(\".background\")\n .style(\"opacity\", (d, i) => i === index ? 
1.0 : 0.0)\n\n // Reset visibility attribute for any previously highlighted attention arcs\n svg.select(\"#attention\")\n .selectAll(\"line[visibility='visible']\")\n .attr(\"visibility\", null)\n\n // Hide group containing attention arcs\n svg.select(\"#attention\").attr(\"visibility\", \"hidden\");\n\n // Set to visible appropriate attention arcs to be highlighted\n if (isLeft) {\n svg.select(\"#attention\").selectAll(\"line[left-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n } else {\n svg.select(\"#attention\").selectAll(\"line[right-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n }\n\n // Update color boxes superimposed over tokens\n const id = isLeft ? \"right\" : \"left\";\n const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;\n svg.select(\"#\" + id)\n .selectAll(\".attentionBoxes\")\n .selectAll(\"g\")\n .attr(\"head-index\", (d, i) => i)\n .selectAll(\"rect\")\n .attr(\"x\", function () {\n const headIndex = +this.parentNode.getAttribute(\"head-index\");\n return leftPos + boxOffsets(headIndex);\n })\n .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n .attr(\"width\", BOXWIDTH / activeHeads())\n .attr(\"height\", BOXHEIGHT)\n .style(\"opacity\", function (d) {\n const headIndex = +this.parentNode.getAttribute(\"head-index\");\n if (config.headVis[headIndex])\n if (d) {\n return d[index];\n } else {\n return 0.0;\n }\n else\n return 0.0;\n });\n });\n\n textContainer.on(\"mouseleave\", function () {\n\n // Unhighlight selected token\n d3.select(this).selectAll(\".background\")\n .style(\"opacity\", 0.0);\n\n // Reset visibility attributes for previously selected lines\n svg.select(\"#attention\")\n .selectAll(\"line[visibility='visible']\")\n .attr(\"visibility\", null) ;\n svg.select(\"#attention\").attr(\"visibility\", \"visible\");\n\n // Reset highlights superimposed over tokens\n svg.selectAll(\".attentionBoxes\")\n .selectAll(\"g\")\n .selectAll(\"rect\")\n .style(\"opacity\", 0.0);\n });\n }\n\n 
function renderAttention(svg, attention) {\n\n // Remove previous dom elements\n svg.select(\"#attention\").remove();\n\n // Add new elements\n svg.append(\"g\")\n .attr(\"id\", \"attention\") // Container for all attention arcs\n .selectAll(\".headAttention\")\n .data(attention)\n .enter()\n .append(\"g\")\n .classed(\"headAttention\", true) // Group attention arcs by head\n .attr(\"head-index\", (d, i) => i)\n .selectAll(\".tokenAttention\")\n .data(d => d)\n .enter()\n .append(\"g\")\n .classed(\"tokenAttention\", true) // Group attention arcs by left token\n .attr(\"left-token-index\", (d, i) => i)\n .selectAll(\"line\")\n .data(d => d)\n .enter()\n .append(\"line\")\n .attr(\"x1\", BOXWIDTH)\n .attr(\"y1\", function () {\n const leftTokenIndex = +this.parentNode.getAttribute(\"left-token-index\")\n return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2)\n })\n .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n .attr(\"y2\", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))\n .attr(\"stroke-width\", 2)\n .attr(\"stroke\", function () {\n const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n return headColors(headIndex)\n })\n .attr(\"left-token-index\", function () {\n return +this.parentNode.getAttribute(\"left-token-index\")\n })\n .attr(\"right-token-index\", (d, i) => i)\n ;\n updateAttention(svg)\n }\n\n function updateAttention(svg) {\n svg.select(\"#attention\")\n .selectAll(\"line\")\n .attr(\"stroke-opacity\", function (d) {\n const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n // If head is selected\n if (config.headVis[headIndex]) {\n // Set opacity to attention weight divided by number of active heads\n return d / activeHeads()\n } else {\n return 0.0;\n }\n })\n }\n\n function boxOffsets(i) {\n const numHeadsAbove = config.headVis.reduce(\n function (acc, val, cur) {\n return val && cur < i ? 
acc + 1 : acc;\n }, 0);\n return numHeadsAbove * (BOXWIDTH / activeHeads());\n }\n\n function activeHeads() {\n return config.headVis.reduce(function (acc, val) {\n return val ? acc + 1 : acc;\n }, 0);\n }\n\n function drawCheckboxes(top, svg) {\n const checkboxContainer = svg.append(\"g\");\n const checkbox = checkboxContainer.selectAll(\"rect\")\n .data(config.headVis)\n .enter()\n .append(\"rect\")\n .attr(\"fill\", (d, i) => headColors(i))\n .attr(\"x\", (d, i) => i * CHECKBOX_SIZE)\n .attr(\"y\", top)\n .attr(\"width\", CHECKBOX_SIZE)\n .attr(\"height\", CHECKBOX_SIZE);\n\n function updateCheckboxes() {\n checkboxContainer.selectAll(\"rect\")\n .data(config.headVis)\n .attr(\"fill\", (d, i) => d ? headColors(i): lighten(headColors(i)));\n }\n\n updateCheckboxes();\n\n checkbox.on(\"click\", function (d, i) {\n if (config.headVis[i] && activeHeads() === 1) return;\n config.headVis[i] = !config.headVis[i];\n updateCheckboxes();\n updateAttention(svg);\n });\n\n checkbox.on(\"dblclick\", function (d, i) {\n // If we double click on the only active head then reset\n if (config.headVis[i] && activeHeads() === 1) {\n config.headVis = new Array(config.nHeads).fill(true);\n } else {\n config.headVis = new Array(config.nHeads).fill(false);\n config.headVis[i] = true;\n }\n updateCheckboxes();\n updateAttention(svg);\n });\n }\n\n function lighten(color) {\n const c = d3.hsl(color);\n const increment = (1 - c.l) * 0.6;\n c.l += increment;\n c.s -= increment;\n return c;\n }\n\n function transpose(mat) {\n return mat[0].map(function (col, i) {\n return mat.map(function (row) {\n return row[i];\n });\n });\n }\n\n});",
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"attention_scores(attention_weights_pretrain, k)"
]
},
{
"cell_type": "markdown",
"id": "statutory-meditation",
"metadata": {},
"source": [
"# Finetune"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "amino-burner",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-651dd202c575406799ca073a08ef5e54'>\n",
" <span style=\"user-select:none\">\n",
" Layer: <select id=\"layer\"></select>\n",
" \n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": "/**\n * @fileoverview Transformer Visualization D3 javascript code.\n *\n *\n * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n *\n * Change log:\n *\n * 12/19/18 Jesse Vig Assorted cleanup. Changed orientation of attention matrices.\n * 12/29/20 Jesse Vig Significant refactor.\n * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n * 05/03/21 Jesse Vig Adjust height of visualization dynamically\n * 07/25/21 Jesse Vig Support layer filtering\n **/\n\nrequire.config({\n paths: {\n d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n }\n});\n\nrequirejs(['jquery', 'd3'], function ($, d3) {\n\n const params = {\"attention\": [{\"name\": null, \"attn\": [[[[0.16445937752723694, 0.12579870223999023, 0.1133345440030098, 0.15633031725883484, 0.20919007062911987, 0.11322496086359024, 0.11766202002763748], [0.16152320802211761, 0.14000782370567322, 0.1398523598909378, 0.14041535556316376, 0.13919208943843842, 0.13906262814998627, 0.13994649052619934], [0.1411072164773941, 0.14333412051200867, 0.14292435348033905, 0.14334851503372192, 0.14325562119483948, 0.14297249913215637, 0.14305762946605682], [0.17711158096790314, 0.13652077317237854, 0.13727137446403503, 0.137272909283638, 0.1386304348707199, 0.1365974098443985, 0.13659557700157166], [0.17797741293907166, 0.13661958277225494, 0.13598830997943878, 0.13632971048355103, 0.13927258551120758, 0.1371464729309082, 0.13666586577892303], [0.15614137053489685, 0.14076903462409973, 0.14039446413516998, 0.14061200618743896, 0.1409381926059723, 0.14022590219974518, 0.14091895520687103], [0.15094415843486786, 0.1412247121334076, 0.14123418927192688, 0.14132803678512573, 0.14243265986442566, 0.14137880504131317, 0.1414574384689331]], [[0.04643920436501503, 0.03959910571575165, 
0.031600456684827805, 0.03310973942279816, 0.05844629928469658, 0.03954439237713814, 0.03240245208144188], [0.039766352623701096, 0.03439924493432045, 0.034392524510622025, 0.034459188580513, 0.034056250005960464, 0.034218546003103256, 0.034431859850883484], [0.0349244587123394, 0.03453252092003822, 0.03443722426891327, 0.034467242658138275, 0.034302908927202225, 0.0347064845263958, 0.03444622829556465], [0.036057598888874054, 0.034504521638154984, 0.03441145271062851, 0.03444594889879227, 0.03432873636484146, 0.034707941114902496, 0.03443336486816406], [0.04230942204594612, 0.034347303211688995, 0.034300368279218674, 0.0343492291867733, 0.034209031611680984, 0.03433912247419357, 0.03436385095119476], [0.04137994721531868, 0.034371599555015564, 0.034335121512413025, 0.03437613323330879, 0.033850301057100296, 0.034061115235090256, 0.03437604010105133], [0.03564491868019104, 0.034539464861154556, 0.034410227090120316, 0.03444889560341835, 0.03437260538339615, 0.03471061959862709, 0.034431684762239456]]]], \"left_text\": [\"P_99284\", \"D_948.0\", \"D_944.2\", \"M_POLYE\", \"P_272\", \"P_99254\", \"P_270\"], \"right_text\": [\"P_99284\", \"D_948.0\", \"D_944.2\", \"M_POLYE\", \"P_272\", \"P_99254\", \"P_270\"]}], \"default_filter\": \"0\", \"root_div_id\": \"bertviz-651dd202c575406799ca073a08ef5e54\", \"layer\": null, \"heads\": null, \"include_layers\": [0]}; // HACK: the params object above was substituted for a template marker.\n const TEXT_SIZE = 15;\n const 
BOXWIDTH = 110;\n const BOXHEIGHT = 22.5;\n const MATRIX_WIDTH = 115;\n const CHECKBOX_SIZE = 20;\n const TEXT_TOP = 30;\n\n console.log(\"d3 version\", d3.version)\n let headColors;\n try {\n headColors = d3.scaleOrdinal(d3.schemeCategory10);\n } catch (err) {\n console.log('Older d3 version')\n headColors = d3.scale.category10();\n }\n let config = {};\n initialize();\n renderVis();\n\n function initialize() {\n config.attention = params['attention'];\n config.filter = params['default_filter'];\n config.rootDivId = params['root_div_id'];\n config.nLayers = config.attention[config.filter]['attn'].length;\n config.nHeads = config.attention[config.filter]['attn'][0].length;\n config.layers = params['include_layers']\n\n if (params['heads']) {\n config.headVis = new Array(config.nHeads).fill(false);\n params['heads'].forEach(x => config.headVis[x] = true);\n } else {\n config.headVis = new Array(config.nHeads).fill(true);\n }\n config.initialTextLength = config.attention[config.filter].right_text.length;\n config.layer_seq = (params['layer'] == null ? 
0 : config.layers.findIndex(layer => params['layer'] === layer));\n config.layer = config.layers[config.layer_seq]\n\n let layerEl = $(`#${config.rootDivId} #layer`);\n for (const layer of config.layers) {\n layerEl.append($(\"<option />\").val(layer).text(layer));\n }\n layerEl.val(config.layer).change();\n layerEl.on('change', function (e) {\n config.layer = +e.currentTarget.value;\n config.layer_seq = config.layers.findIndex(layer => config.layer === layer);\n renderVis();\n });\n\n $(`#${config.rootDivId} #filter`).on('change', function (e) {\n config.filter = e.currentTarget.value;\n renderVis();\n });\n }\n\n function renderVis() {\n\n // Load parameters\n const attnData = config.attention[config.filter];\n const leftText = attnData.left_text;\n const rightText = attnData.right_text;\n\n // Select attention for given layer\n const layerAttention = attnData.attn[config.layer_seq];\n\n // Clear vis\n $(`#${config.rootDivId} #vis`).empty();\n\n // Determine size of visualization\n const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;\n const svg = d3.select(`#${config.rootDivId} #vis`)\n .append('svg')\n .attr(\"width\", \"100%\")\n .attr(\"height\", height + \"px\");\n\n // Display tokens on left and right side of visualization\n renderText(svg, leftText, true, layerAttention, 0);\n renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);\n\n // Render attention arcs\n renderAttention(svg, layerAttention);\n\n // Draw squares at top of visualization, one for each head\n drawCheckboxes(0, svg, layerAttention);\n }\n\n function renderText(svg, text, isLeft, attention, leftPos) {\n\n const textContainer = svg.append(\"svg:g\")\n .attr(\"id\", isLeft ? 
\"left\" : \"right\");\n\n // Add attention highlights superimposed over words\n textContainer.append(\"g\")\n .classed(\"attentionBoxes\", true)\n .selectAll(\"g\")\n .data(attention)\n .enter()\n .append(\"g\")\n .attr(\"head-index\", (d, i) => i)\n .selectAll(\"rect\")\n .data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights\n .enter()\n .append(\"rect\")\n .attr(\"x\", function () {\n var headIndex = +this.parentNode.getAttribute(\"head-index\");\n return leftPos + boxOffsets(headIndex);\n })\n .attr(\"y\", (+1) * BOXHEIGHT)\n .attr(\"width\", BOXWIDTH / activeHeads())\n .attr(\"height\", BOXHEIGHT)\n .attr(\"fill\", function () {\n return headColors(+this.parentNode.getAttribute(\"head-index\"))\n })\n .style(\"opacity\", 0.0);\n\n const tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n .data(text)\n .enter()\n .append(\"g\");\n\n // Add gray background that appears when hovering over text\n tokenContainer.append(\"rect\")\n .classed(\"background\", true)\n .style(\"opacity\", 0.0)\n .attr(\"fill\", \"lightgray\")\n .attr(\"x\", leftPos)\n .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n .attr(\"width\", BOXWIDTH)\n .attr(\"height\", BOXHEIGHT);\n\n // Add token text\n const textEl = tokenContainer.append(\"text\")\n .text(d => d)\n .attr(\"font-size\", TEXT_SIZE + \"px\")\n .style(\"cursor\", \"default\")\n .style(\"-webkit-user-select\", \"none\")\n .attr(\"x\", leftPos)\n .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT);\n\n if (isLeft) {\n textEl.style(\"text-anchor\", \"end\")\n .attr(\"dx\", BOXWIDTH - 0.5 * TEXT_SIZE)\n .attr(\"dy\", TEXT_SIZE);\n } else {\n textEl.style(\"text-anchor\", \"start\")\n .attr(\"dx\", +0.5 * TEXT_SIZE)\n .attr(\"dy\", TEXT_SIZE);\n }\n\n tokenContainer.on(\"mouseover\", function (d, index) {\n\n // Show gray background for moused-over token\n textContainer.selectAll(\".background\")\n .style(\"opacity\", (d, i) => i === index ? 
1.0 : 0.0)\n\n // Reset visibility attribute for any previously highlighted attention arcs\n svg.select(\"#attention\")\n .selectAll(\"line[visibility='visible']\")\n .attr(\"visibility\", null)\n\n // Hide group containing attention arcs\n svg.select(\"#attention\").attr(\"visibility\", \"hidden\");\n\n // Set to visible appropriate attention arcs to be highlighted\n if (isLeft) {\n svg.select(\"#attention\").selectAll(\"line[left-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n } else {\n svg.select(\"#attention\").selectAll(\"line[right-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n }\n\n // Update color boxes superimposed over tokens\n const id = isLeft ? \"right\" : \"left\";\n const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;\n svg.select(\"#\" + id)\n .selectAll(\".attentionBoxes\")\n .selectAll(\"g\")\n .attr(\"head-index\", (d, i) => i)\n .selectAll(\"rect\")\n .attr(\"x\", function () {\n const headIndex = +this.parentNode.getAttribute(\"head-index\");\n return leftPos + boxOffsets(headIndex);\n })\n .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n .attr(\"width\", BOXWIDTH / activeHeads())\n .attr(\"height\", BOXHEIGHT)\n .style(\"opacity\", function (d) {\n const headIndex = +this.parentNode.getAttribute(\"head-index\");\n if (config.headVis[headIndex])\n if (d) {\n return d[index];\n } else {\n return 0.0;\n }\n else\n return 0.0;\n });\n });\n\n textContainer.on(\"mouseleave\", function () {\n\n // Unhighlight selected token\n d3.select(this).selectAll(\".background\")\n .style(\"opacity\", 0.0);\n\n // Reset visibility attributes for previously selected lines\n svg.select(\"#attention\")\n .selectAll(\"line[visibility='visible']\")\n .attr(\"visibility\", null) ;\n svg.select(\"#attention\").attr(\"visibility\", \"visible\");\n\n // Reset highlights superimposed over tokens\n svg.selectAll(\".attentionBoxes\")\n .selectAll(\"g\")\n .selectAll(\"rect\")\n .style(\"opacity\", 0.0);\n });\n }\n\n 
function renderAttention(svg, attention) {\n\n // Remove previous dom elements\n svg.select(\"#attention\").remove();\n\n // Add new elements\n svg.append(\"g\")\n .attr(\"id\", \"attention\") // Container for all attention arcs\n .selectAll(\".headAttention\")\n .data(attention)\n .enter()\n .append(\"g\")\n .classed(\"headAttention\", true) // Group attention arcs by head\n .attr(\"head-index\", (d, i) => i)\n .selectAll(\".tokenAttention\")\n .data(d => d)\n .enter()\n .append(\"g\")\n .classed(\"tokenAttention\", true) // Group attention arcs by left token\n .attr(\"left-token-index\", (d, i) => i)\n .selectAll(\"line\")\n .data(d => d)\n .enter()\n .append(\"line\")\n .attr(\"x1\", BOXWIDTH)\n .attr(\"y1\", function () {\n const leftTokenIndex = +this.parentNode.getAttribute(\"left-token-index\")\n return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2)\n })\n .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n .attr(\"y2\", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))\n .attr(\"stroke-width\", 2)\n .attr(\"stroke\", function () {\n const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n return headColors(headIndex)\n })\n .attr(\"left-token-index\", function () {\n return +this.parentNode.getAttribute(\"left-token-index\")\n })\n .attr(\"right-token-index\", (d, i) => i)\n ;\n updateAttention(svg)\n }\n\n function updateAttention(svg) {\n svg.select(\"#attention\")\n .selectAll(\"line\")\n .attr(\"stroke-opacity\", function (d) {\n const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n // If head is selected\n if (config.headVis[headIndex]) {\n // Set opacity to attention weight divided by number of active heads\n return d / activeHeads()\n } else {\n return 0.0;\n }\n })\n }\n\n function boxOffsets(i) {\n const numHeadsAbove = config.headVis.reduce(\n function (acc, val, cur) {\n return val && cur < i ? 
acc + 1 : acc;\n }, 0);\n return numHeadsAbove * (BOXWIDTH / activeHeads());\n }\n\n function activeHeads() {\n return config.headVis.reduce(function (acc, val) {\n return val ? acc + 1 : acc;\n }, 0);\n }\n\n function drawCheckboxes(top, svg) {\n const checkboxContainer = svg.append(\"g\");\n const checkbox = checkboxContainer.selectAll(\"rect\")\n .data(config.headVis)\n .enter()\n .append(\"rect\")\n .attr(\"fill\", (d, i) => headColors(i))\n .attr(\"x\", (d, i) => i * CHECKBOX_SIZE)\n .attr(\"y\", top)\n .attr(\"width\", CHECKBOX_SIZE)\n .attr(\"height\", CHECKBOX_SIZE);\n\n function updateCheckboxes() {\n checkboxContainer.selectAll(\"rect\")\n .data(config.headVis)\n .attr(\"fill\", (d, i) => d ? headColors(i): lighten(headColors(i)));\n }\n\n updateCheckboxes();\n\n checkbox.on(\"click\", function (d, i) {\n if (config.headVis[i] && activeHeads() === 1) return;\n config.headVis[i] = !config.headVis[i];\n updateCheckboxes();\n updateAttention(svg);\n });\n\n checkbox.on(\"dblclick\", function (d, i) {\n // If we double click on the only active head then reset\n if (config.headVis[i] && activeHeads() === 1) {\n config.headVis = new Array(config.nHeads).fill(true);\n } else {\n config.headVis = new Array(config.nHeads).fill(false);\n config.headVis[i] = true;\n }\n updateCheckboxes();\n updateAttention(svg);\n });\n }\n\n function lighten(color) {\n const c = d3.hsl(color);\n const increment = (1 - c.l) * 0.6;\n c.l += increment;\n c.s -= increment;\n return c;\n }\n\n function transpose(mat) {\n return mat[0].map(function (col, i) {\n return mat.map(function (row) {\n return row[i];\n });\n });\n }\n\n});",
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"attention_scores(attention_weights_finetune, k)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "personalized-intermediate",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
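The fine-tuning notebooks define `scaled_dot_product_attention`, which masks padded keys with a large negative score (`padding_num = -1e+7`) before the softmax, then zeroes the rows that belong to padded queries. The visualization code above reads `attn` as `[layer][head][query][key]`, so each head contributes one such `T_q x T_k` matrix. The following is a minimal single-head NumPy sketch of that masking scheme, assuming 0/1 masks where 1 marks a real token. The function name `masked_attention_weights` is hypothetical and is not part of the repository.

```python
import numpy as np

def masked_attention_weights(Q, K, q_mask, k_mask, padding_num=-1e7):
    """Single-head masked scaled dot-product attention weights (sketch).

    Q: (T_q, d), K: (T_k, d); q_mask: (T_q,), k_mask: (T_k,), 1 = real, 0 = padding.
    Returns a (T_q, T_k) weight matrix.
    """
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                                   # (T_q, T_k)
    scores = np.where(k_mask[np.newaxis, :] == 0, padding_num, scores)  # hide padded keys
    scores = scores - scores.max(axis=-1, keepdims=True)             # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)                   # softmax over keys
    return weights * q_mask[:, np.newaxis]                           # zero padded query rows

rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 8))
K = rng.normal(size=(4, 8))
q_mask = np.array([1, 1, 1, 0])   # last position is padding
k_mask = np.array([1, 1, 1, 0])
w = masked_attention_weights(Q, K, q_mask, k_mask)
print(w.round(3))
```

In the TensorFlow version, the heads are stacked along the batch axis (`h*N`), so the same computation runs once per head in a single batched call.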
================================================
FILE: attention/get_att.py
================================================
import pickle

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, seqs, vocab_sizes, list_IDs, max_visit, max_code, batch_size=100, shuffle=True):
        self.seqs = seqs
        self.code_vocab = vocab_sizes[0]
        self.cat_vocab = vocab_sizes[1]
        self.list_IDs = list_IDs
        self.max_visit = max_visit
        self.max_code = max_code
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        X, y = self.__data_generation(list_IDs_temp)
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples'
        demo_feature, code_feature, util_feature, date_feature, cls_feature = self.seqs
        batch_demo, batch_code, batch_util, batch_date, batch_cls = [], [], [], [], []
        for i, ID in enumerate(list_IDs_temp):
            batch_demo.append(demo_feature[ID])
            batch_code.append(code_feature[ID])
            batch_util.append(util_feature[ID])
            batch_date.append(date_feature[ID])
            batch_cls.append(cls_feature[ID])
        batch_demo_feature = np.array(batch_demo)
        batch_code_feature = self.code_padding(batch_code)
        batch_util_feature = self.date_padding(batch_util)
        batch_date_feature = self.date_padding(batch_date)
        batch_cls = np.array(batch_cls)
        dic = (
            {
                'demo_feature': batch_demo_feature,
                'code_feature': batch_code_feature,
                'util_feature': batch_util_feature,
                'date_feature': batch_date_feature,
            },
            {
                'cls_label': batch_cls
            })
        return dic

    def date_padding(self, seq):
        seq = [x[:-1] for x in seq]
        pad_seq = np.zeros((len(seq), self.max_visit))
        for i, p in enumerate(seq):
            pad_seq[i][:len(p)] = p[:self.max_visit]
        return pad_seq

    def code_padding(self, seq):
        seq = [x[:-1] for x in seq]
        X = np.zeros((len(seq), self.max_visit, self.max_code))
        for i, p in enumerate(seq):
            if len(p) > self.max_visit:
                p = p[:self.max_visit]
            for j, claim in enumerate(p):
                claim = claim[:self.max_code]
                X[i][j][:len(claim)] = claim
        return X
def process_code(seq, vocab2int):
    unseen = []
    new_seq = []
    for p in seq:
        new_p = []
        for v in p:
            new_v = []
            for c in v:
                if c not in vocab2int:
                    unseen.append(c)
                    continue
                # vocab2int[c] = len(vocab2int)
                new_v.append(vocab2int[c])
            new_p.append(new_v)
        new_seq.append(new_p)
    print("UNSEEN VOCAB:", len(set(unseen)), len(unseen))
    return new_seq

def process_util(seq, util2int):
    new_seq = []
    vocab2int = {"PAD": 0, "IP": 1, "RX": 2, "OP": 3}
    for p in seq:
        new_p = []
        for v in p:
            if "IP" in v:
                new_v = 1
            elif "RX" in v:
                new_v = 2
            else:
                new_v = 3
            new_p.append(new_v)
        new_seq.append(new_p)
    return new_seq

def process_demo(age_seq, sex_seq, vocab2int):
    new_seq = []
    for age, sex in zip(age_seq, sex_seq):
        p = []
        assert age in vocab2int
        assert sex in vocab2int
        p.append(vocab2int[age])
        p.append(vocab2int[sex])
        new_seq.append(p)
    return np.array(new_seq)

def get_cat(seq, code2cat):
    new_seq = []
    for p in seq:
        new_p = []
        for v in p:
            new_v = []
            for c in v:
                new_c = code2cat[c]
                new_v.append(new_c)
            new_p.append(new_v)
        new_seq.append(new_p)
    return new_seq
print("==========LOADING DATA==========")
age_seq = pickle.load(open("../../suicideRisk/data/new_age_seq","rb"))
sex_seq = pickle.load(open("../../suicideRisk/data/new_sex_seq","rb"))
util_seq = pickle.load(open("../../suicideRisk/data/new_util_seq","rb"))
code_seq = pickle.load(open("../../suicideRisk/data/new_code_seq","rb"))
date_seq = pickle.load(open("../../suicideRisk/data/new_date_seq","rb"))
label_seq = pickle.load(open("../../suicideRisk/data/new_label_seq","rb"))
print("------LOADING DIC------")
path = "/Users/xxz005/Desktop/RAW_DATA/code2cat/"
diag2cat = pickle.load(open(path+"diag2cat","rb"))
proc2cat = pickle.load(open(path+"proc2cat","rb"))
drug2cat = pickle.load(open(path+"drug2cat","rb"))
code2cat = {**diag2cat, **proc2cat, **drug2cat}
code2int, util2int, demo2int, cat2int = pickle.load(open("../../pretraining/model/vocabs/vocabs","rb"))
code_feature = process_code(code_seq, code2int)
util_feature = process_util(util_seq, util2int)
demo_feature = process_demo(age_seq, sex_seq, demo2int)
date_feature = date_seq
cls_feature = np.array(label_seq).reshape((-1,1))
MAX_VISIT=30
MAX_CODE=10
MAX_DEMO=2
PATIENT_DIM=100
BATCH_SIZE = 500
TRAIN_RATIO = 0.7
DATA_SIZE = len(age_seq)
EPOCHS = 20
params = {
    'seqs': [demo_feature, code_feature, util_feature, date_feature, cls_feature],
    'vocab_sizes': [len(code2int), len(cat2int)],
    'batch_size': 100,
    'max_visit': MAX_VISIT,
    'max_code': MAX_CODE,
}
generator = DataGenerator(list_IDs=range(DATA_SIZE), shuffle=False, **params)
model_path = "./saveModel"
model = tf.keras.models.load_model(model_path)
model_losses = {
    "cls_label": tf.keras.losses.BinaryCrossentropy(),
}
model_metrics = {
    "cls_label": tf.keras.metrics.AUC(),
}
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(optimizer=opt, loss=model_losses, metrics=model_metrics)
model.summary()  # summary() prints itself and returns None
layer_name = 'multihead_attention-0'
intermediate_layer_model = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer(layer_name).output)
encoded_features = intermediate_layer_model.predict(generator)
with open("attention_weights", "wb") as f:
    pickle.dump(encoded_features, f)
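`DataGenerator` drops each patient's final visit (`x[:-1]`), keeps at most `max_visit` visits and `max_code` codes per visit, and zero-pads the remainder, so index 0 doubles as the padding id. The following standalone NumPy sketch mirrors `code_padding` and `date_padding`, which makes it possible to check batch shapes without loading the saved model. The names `pad_codes` and `pad_visits` are hypothetical re-implementations and are not imported from the repository.

```python
import numpy as np

def pad_visits(seq, max_visit):
    """Mirror of DataGenerator.date_padding: one scalar per visit."""
    seq = [x[:-1] for x in seq]                 # drop each patient's final visit
    out = np.zeros((len(seq), max_visit))
    for i, p in enumerate(seq):
        p = p[:max_visit]                       # keep at most max_visit visits
        out[i, :len(p)] = p
    return out

def pad_codes(seq, max_visit, max_code):
    """Mirror of DataGenerator.code_padding: a list of code ids per visit."""
    seq = [x[:-1] for x in seq]                 # drop each patient's final visit
    out = np.zeros((len(seq), max_visit, max_code))
    for i, p in enumerate(seq):
        for j, claim in enumerate(p[:max_visit]):
            claim = claim[:max_code]            # keep at most max_code codes
            out[i, j, :len(claim)] = claim
    return out

codes = [[[5, 6, 7], [8], [9, 10]],   # patient 0: three visits
         [[11, 12], [13, 14]]]        # patient 1: two visits
dates = [[1, 2, 3], [4, 5]]
X = pad_codes(codes, max_visit=3, max_code=2)
D = pad_visits(dates, max_visit=3)
print(X)
print(D)
```

Here patient 0 keeps two visits, `[5, 6]` (truncated from three codes) and `[8, 0]`, while patient 1 keeps only `[11, 12]`. All remaining slots stay 0.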
================================================
FILE: attention/saveModel/variables/variables.data-00000-of-00002
================================================
[File too large to display: 38.3 MB]
================================================
FILE: finetune/asthma/fine-tune.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "wicked-finder",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 205 µs\n"
]
}
],
"source": [
"%load_ext autotime"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "random-fluid",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 780 ms\n"
]
}
],
"source": [
"import numpy as np\n",
"import _pickle as pickle\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "diverse-vegetation",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========LOADING DATA==========\n",
"time: 185 ms\n"
]
}
],
"source": [
"print(\"==========LOADING DATA==========\")\n",
"age_seq = pickle.load(open(\"../data/new_age_seq\",\"rb\"))\n",
"sex_seq = pickle.load(open(\"../data/new_sex_seq\",\"rb\"))\n",
" \n",
"util_seq = pickle.load(open(\"../data/new_util_seq\",\"rb\"))\n",
"code_seq = pickle.load(open(\"../data/new_code_seq\",\"rb\"))\n",
"date_seq = pickle.load(open(\"../data/new_date_seq\",\"rb\"))\n",
"label_seq = pickle.load(open(\"../data/new_label_seq\",\"rb\"))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "disabled-favor",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 3.28 s\n"
]
}
],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "hollow-environment",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 5.82 ms\n"
]
}
],
"source": [
"class DataGenerator(tf.keras.utils.Sequence):\n",
" def __init__(self, seqs, vocab_sizes, list_IDs, max_visit, max_code, batch_size=100, shuffle=True):\n",
" self.seqs = seqs\n",
" self.code_vocab = vocab_sizes[0]\n",
" self.cat_vocab = vocab_sizes[1]\n",
" self.list_IDs = list_IDs\n",
" self.max_visit = max_visit\n",
" self.max_code = max_code\n",
" self.batch_size = batch_size\n",
" self.shuffle = shuffle\n",
" self.on_epoch_end()\n",
"\n",
" def __len__(self):\n",
" 'Denotes the number of batches per epoch'\n",
" return int(np.ceil(len(self.list_IDs) / self.batch_size))\n",
"\n",
" def __getitem__(self, index):\n",
" 'Generate one batch of data'\n",
" indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]\n",
" list_IDs_temp = [self.list_IDs[k] for k in indexes]\n",
" X, y = self.__data_generation(list_IDs_temp)\n",
" return X, y\n",
"\n",
" def on_epoch_end(self):\n",
" 'Updates indexes after each epoch' \n",
" self.indexes = np.arange(len(self.list_IDs))\n",
" if self.shuffle == True:\n",
" np.random.shuffle(self.indexes)\n",
"\n",
" def __data_generation(self, list_IDs_temp):\n",
" 'Generates data containing batch_size samples' \n",
" demo_feature, code_feature, util_feature, date_feature, cls_feature = self.seqs\n",
" batch_demo, batch_code, batch_util, batch_date, batch_cls = [], [], [], [], []\n",
" for i, ID in enumerate(list_IDs_temp):\n",
" batch_demo.append(demo_feature[ID])\n",
" batch_code.append(code_feature[ID])\n",
" batch_util.append(util_feature[ID])\n",
" batch_date.append(date_feature[ID])\n",
" batch_cls.append(cls_feature[ID])\n",
" \n",
" batch_demo_feature = np.array(batch_demo)\n",
" batch_code_feature = self.code_padding(batch_code)\n",
" batch_util_feature = self.date_padding(batch_util)\n",
" batch_date_feature = self.date_padding(batch_date)\n",
" batch_cls = np.array(batch_cls)\n",
" \n",
" dic = (\n",
" {\n",
" 'demo_feature': batch_demo_feature,\n",
" 'code_feature': batch_code_feature,\n",
" 'util_feature': batch_util_feature,\n",
" 'date_feature': batch_date_feature,\n",
" },\n",
" {\n",
" 'cls_label': batch_cls\n",
" })\n",
" return dic\n",
" \n",
" def date_padding(self, seq):\n",
" seq = [x[:-1] for x in seq]\n",
" \n",
" pad_seq = np.zeros((len(seq), self.max_visit))\n",
" for i, p in enumerate(seq):\n",
" pad_seq[i][:len(p)] = p[:self.max_visit]\n",
" return pad_seq\n",
" \n",
" def code_padding(self, seq):\n",
" seq = [x[:-1] for x in seq]\n",
" \n",
" X = np.zeros((len(seq), self.max_visit, self.max_code))\n",
" for i, p in enumerate(seq):\n",
" if len(p) > self.max_visit: \n",
" p = p[:self.max_visit]\n",
" for j, claim in enumerate(p):\n",
" claim = claim[:self.max_code]\n",
" X[i][j][:len(claim)] = claim\n",
" return X\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "quick-webster",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 18 ms\n"
]
}
],
"source": [
"def create_code_mask(code_seq):\n",
" code_mask = tf.cast(tf.math.not_equal(code_seq, 0), tf.float32)\n",
" return code_mask[:,:,:,tf.newaxis]\n",
"\n",
"def create_visit_mask(seq):\n",
" visit_mask = tf.cast(tf.math.not_equal(seq, 0), tf.float32)\n",
" return visit_mask[:,:]\n",
"\n",
"def scaled_dot_product_attention(Q, K, V, Q_masks, K_masks):\n",
" d_k = K.get_shape().as_list()[-1] # d_model/h\n",
"\n",
" outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1])) # (h*N, T_q, T_k)\n",
" outputs /= d_k ** 0.5\n",
"\n",
" padding_num = -1e+7\n",
" K_masks = tf.expand_dims(K_masks, 1) # (h*N, 1, T_k)\n",
" K_masks = tf.tile(K_masks, [1, tf.shape(Q)[1], 1]) # (h*N, T_q, T_k)\n",
" paddings = tf.ones_like(outputs) * padding_num\n",
" outputs = tf.where(tf.equal(K_masks, 0), paddings, outputs) # (h*N, T_q, T_k)\n",
"\n",
" outputs = tf.nn.softmax(outputs)\n",
" Q_masks = tf.expand_dims(Q_masks, -1) # (h*N, T_q, 1)\n",
" Q_masks = tf.tile(Q_masks, [1, 1, tf.shape(K)[1]]) # (h*N, T_q, T_k)\n",
" outputs = outputs * tf.cast(Q_masks, dtype=tf.float32)\n",
"\n",
" return tf.matmul(outputs, V) # [h*N, T_q, d_model/h]\n",
"\n",
"class multihead_attention(tf.keras.layers.Layer):\n",
" def __init__(self, d_model, num_heads, name=\"multihead_attention\"):\n",
" super(multihead_attention, self).__init__(name=name)\n",
" self.num_heads = num_heads\n",
" self.d_model = d_model\n",
"\n",
" assert d_model % self.num_heads == 0\n",
"\n",
" self.query_dense = layers.Dense(units=d_model, use_bias=False)\n",
" self.key_dense = layers.Dense(units=d_model, use_bias=False)\n",
" self.value_dense = layers.Dense(units=d_model, use_bias=False)\n",
" self.add =layers.Add()\n",
" self.norm = layers.LayerNormalization()\n",
" \n",
" def call(self, queries, keys, values, query_masks, key_masks):\n",
" Q = self.query_dense(queries)\n",
" K = self.key_dense(keys)\n",
" V = self.value_dense(values)\n",
"\n",
" # Split and concat\n",
" Q_ = tf.concat(tf.split(Q, self.num_heads, axis=2), axis=0) # (h*N, T_q, d_model/h)\n",
" K_ = tf.concat(tf.split(K, self.num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h)\n",
" V_ = tf.concat(tf.split(V, self.num_heads, axis=2), axis=0) # (h*N, T_v, d_model/h)\n",
" query_masks = tf.tile(query_masks, [self.num_heads, 1]) # (h*N, T_q)\n",
" key_masks = tf.tile(key_masks, [self.num_heads, 1]) # (h*N, T_k)\n",
"\n",
" # Attention\n",
" outputs = scaled_dot_product_attention(Q_, K_, V_, query_masks, key_masks) # (h*N, T_q, d_model/h)\n",
"\n",
" # Restore shape\n",
" outputs = tf.concat(tf.split(outputs, self.num_heads, axis=0), axis=2) # (N, T_q, d_model)\n",
"\n",
" # Residual connection\n",
" outputs = self.add([queries, outputs])\n",
" outputs = self.norm(outputs)\n",
" \n",
" return outputs\n",
"\n",
"class ffn(tf.keras.layers.Layer):\n",
" def __init__(self, d_model, ffn_dim, name=\"ffn\"):\n",
" super(ffn, self).__init__(name=name)\n",
" self.ffn_dim = ffn_dim\n",
" self.dense1 = layers.Dense(units=ffn_dim, activation=tf.nn.relu, use_bias=False)\n",
" self.dense2 = layers.Dense(units=d_model, use_bias=False)\n",
" self.add =layers.Add()\n",
" self.norm = layers.LayerNormalization()\n",
" \n",
" def call(self, inputs):\n",
" outputs = self.dense1(inputs)\n",
" outputs = self.dense2(outputs)\n",
" outputs = self.add([inputs, outputs])\n",
" outputs = self.norm(outputs)\n",
" return outputs\n",
"\n",
"def cat_recall(y_true, y_pred):\n",
" mask_value = tf.cast(tf.not_equal(tf.reduce_sum(y_true,axis=-1), 0), tf.float32)\n",
" true_positives = tf.cast(tf.reduce_sum(tf.multiply(tf.round(y_pred), y_true), axis=-1), tf.float32)\n",
" possible_positives = tf.cast(tf.reduce_sum(y_true, axis=-1), tf.float32)\n",
" values = true_positives / (possible_positives + 1e-7)\n",
" return tf.reduce_sum(values)/tf.reduce_sum(mask_value)\n",
"\n",
"def cat_loss_fun(y_true, y_pred):\n",
" loss = tf.cast(tf.keras.losses.BinaryCrossentropy(reduction='none')(y_true, y_pred), tf.float32)\n",
" mask = tf.cast(tf.not_equal(tf.reduce_sum(y_true,axis=-1), 0), tf.float32)\n",
" loss = tf.multiply(loss, mask)\n",
" # return tf.reduce_sum(loss)/tf.reduce_sum(mask)\n",
" return loss\n",
"\n",
"def model(\n",
" max_visit,\n",
" max_code,\n",
" max_demo,\n",
" \n",
" demo_vocab,\n",
" code_vocab,\n",
" date_vocab,\n",
" util_vocab,\n",
" cat_vocab,\n",
"\n",
" patient_dim,\n",
" vocab_dim=100,\n",
" model_dim=100,\n",
" ffn_dim=100,\n",
" num_heads=2,\n",
" num_translayer=1,\n",
" \n",
" model_name=\"TransF\"):\n",
" \n",
" demo = layers.Input(shape=(max_demo, ), name=\"demo_feature\") # max_demo = 2, age&sex\n",
" code_seq = layers.Input(shape=(max_visit, max_code), name=\"code_feature\") \n",
" util_seq = layers.Input(shape=(max_visit), name=\"util_feature\")\n",
" date_seq = layers.Input(shape=(max_visit), name=\"date_feature\")\n",
"\n",
" inputs = [demo, code_seq, util_seq, date_seq]\n",
" \n",
" # demo embedding\n",
" demo_emb = layers.Embedding(input_dim=demo_vocab, output_dim=vocab_dim, mask_zero=True, name='demo_embedding')(demo)\n",
" demo_emb = layers.Lambda(lambda x: tf.keras.backend.sum(x, axis=1))(demo_emb) \n",
"\n",
" # code sequence\n",
" code_mask = layers.Lambda(create_code_mask)(code_seq)\n",
" code_emb = layers.Embedding(input_dim=code_vocab, \n",
" output_dim=vocab_dim, \n",
" name='code_embed')(code_seq)\n",
" code_emb = layers.Multiply()([code_emb, code_mask])\n",
" code_emb = tf.reduce_sum(code_emb, axis=2) \n",
"\n",
" \n",
" # visit mask\n",
" visit_mask = layers.Lambda(create_visit_mask)(date_seq)\n",
" \n",
" # util sequence \n",
" util_emb = layers.Embedding(input_dim=util_vocab, output_dim=vocab_dim, mask_zero=True, name='util_embedding')(util_seq)\n",
" util_emb = layers.Multiply()([util_emb, visit_mask[:,:,tf.newaxis]])\n",
" \n",
" # date sequence \n",
" date_emb = layers.Embedding(input_dim=date_vocab, output_dim=vocab_dim, mask_zero=True, name='date_embedding')(date_seq)\n",
" date_emb = layers.Multiply()([date_emb, visit_mask[:,:,tf.newaxis]])\n",
"\n",
" # visit sequence\n",
" visit_emb = layers.Add()([code_emb, date_emb, util_emb]) \n",
"\n",
" demo_emb = tf.expand_dims(demo_emb, 1) # (N, 1, emb_size)\n",
" demo_mask = tf.ones_like(tf.reduce_sum(demo_emb, axis=2), tf.float32) # (N, 1)\n",
" \n",
" for trans_layer in range(num_translayer):\n",
" multihead = multihead_attention(model_dim, num_heads, name=\"multihead_attention-\"+str(trans_layer))(\n",
" queries=tf.concat([demo_emb, visit_emb], 1), \n",
" keys=tf.concat([demo_emb, visit_emb], 1),\n",
" values=tf.concat([demo_emb, visit_emb], 1),\n",
" query_masks=tf.concat([demo_mask, visit_mask], 1),\n",
" key_masks=tf.concat([demo_mask, visit_mask], 1)\n",
" )\n",
" \n",
" visit_emb = ffn(model_dim, ffn_dim, name=\"ffn-\"+str(trans_layer))(multihead) # (N, max_visit, emb_size)\n",
" \n",
" demo_emb = visit_emb[:, :1, :] # (N, 1, emb_size)\n",
" visit_emb = visit_emb[:, 1:max_visit+1, :] # (N, max_visit, emb_size)\n",
" \n",
" patient_embedding = layers.Dense(patient_dim, activation=None, name=\"patient_embedding\")(tf.squeeze(demo_emb, [1]))\n",
" \n",
" \n",
" cls_label = layers.Dense(units=model_dim, activation=tf.nn.relu)(patient_embedding)\n",
" cls_label = layers.Dense(units=1, activation=tf.nn.sigmoid, name=\"cls_label\")(cls_label)\n",
" \n",
" \n",
" outputs = [cls_label]\n",
"\n",
" return tf.keras.Model(inputs=inputs, outputs=outputs, name=model_name)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dependent-movie",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 4.81 ms\n"
]
}
],
"source": [
"def process_code(seq, vocab2int):\n",
" unseen = []\n",
" new_seq = []\n",
" for p in seq:\n",
" new_p = []\n",
" for v in p:\n",
" new_v = []\n",
" for c in v:\n",
" if c not in vocab2int: \n",
" unseen.append(c)\n",
" continue\n",
" # vocab2int[c] = len(vocab2int)\n",
" new_v.append(vocab2int[c])\n",
" new_p.append(new_v)\n",
" new_seq.append(new_p)\n",
" \n",
" print(\"UNSEEN VOCAB:\",len(set(unseen)), len(unseen))\n",
" return new_seq\n",
"\n",
"def process_util(seq, util2int):\n",
" new_seq = []\n",
" vocab2int = {\"PAD\":0,\"IP\":1,\"RX\":2,\"OP\":3}\n",
" for p in seq:\n",
" new_p = []\n",
" for v in p:\n",
" if \"IP\" in v:\n",
" new_v=1\n",
" elif \"RX\" in v:\n",
" new_v=2\n",
" else:\n",
" new_v=3\n",
" new_p.append(new_v)\n",
" new_seq.append(new_p)\n",
" return new_seq\n",
" \n",
"def process_demo(age_seq, sex_seq, vocab2int):\n",
" new_seq = []\n",
" for age, sex in zip(age_seq, sex_seq):\n",
" p = []\n",
" assert age in vocab2int\n",
" assert sex in vocab2int\n",
" \n",
" p.append(vocab2int[age])\n",
" p.append(vocab2int[sex])\n",
" new_seq.append(p)\n",
" return np.array(new_seq)\n",
"\n",
"def get_cat(seq,code2cat):\n",
" new_seq = []\n",
" for p in seq:\n",
" new_p = []\n",
" for v in p:\n",
" new_v = []\n",
" for c in v:\n",
" new_c = code2cat[c]\n",
" new_v.append(new_c)\n",
" new_p.append(new_v)\n",
" new_seq.append(new_p)\n",
" return new_seq"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ongoing-manufacturer",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"------LOADING DIC------\n",
"time: 69.2 ms\n"
]
}
],
"source": [
"print(\"------LOADING DIC------\")\n",
"path = \"/Users/xxz005/Desktop/RAW_DATA/code2cat/\"\n",
"\n",
"diag2cat = pickle.load(open(path+\"diag2cat\",\"rb\"))\n",
"proc2cat = pickle.load(open(path+\"proc2cat\",\"rb\"))\n",
"drug2cat = pickle.load(open(path+\"drug2cat\",\"rb\"))\n",
"\n",
"code2cat = {**diag2cat, **proc2cat, **drug2cat}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "wrong-warrior",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 36.3 ms\n"
]
}
],
"source": [
"code2int, util2int, demo2int, cat2int = pickle.load(open(\"../../pretraining/model/vocabs/vocabs\",\"rb\"))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "invalid-silicon",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"UNSEEN VOCAB: 3 4\n",
"time: 120 ms\n"
]
}
],
"source": [
"code_feature = process_code(code_seq, code2int)\n",
"util_feature = process_util(util_seq, util2int)\n",
"demo_feature = process_demo(age_seq, sex_seq, demo2int)\n",
"date_feature = date_seq\n",
"\n",
"cls_feature = np.array(label_seq).reshape((-1,1))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "adverse-arkansas",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 1.92 ms\n"
]
}
],
"source": [
"MAX_VISIT=30\n",
"MAX_CODE=10\n",
"MAX_DEMO=2\n",
"PATIENT_DIM=100\n",
"\n",
"BATCH_SIZE = 500\n",
"TRAIN_RATIO = 0.5\n",
"DATA_SIZE = len(age_seq)\n",
"EPOCHS = 20\n",
"\n",
"params = {\n",
" 'seqs':[demo_feature, code_feature, util_feature, date_feature, cls_feature],\n",
" 'vocab_sizes': [len(code2int), len(cat2int)],\n",
" 'batch_size':100,\n",
" 'max_visit':MAX_VISIT, \n",
" 'max_code':MAX_CODE,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "sorted-wholesale",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 933 ms\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"train_IDs, valid_IDs = train_test_split(range(DATA_SIZE), train_size=TRAIN_RATIO, random_state=42)\n",
"train_generator = DataGenerator(list_IDs=train_IDs, shuffle=True, **params)\n",
"valid_generator = DataGenerator(list_IDs=valid_IDs, shuffle=False, **params)"
]
},
{
"cell_type": "markdown",
"id": "separated-layout",
"metadata": {},
"source": [
"# Pretrain"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "suspended-settlement",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 393 µs\n"
]
}
],
"source": [
"# model_path = \"../../pretraining/model/saveModel\"\n",
"\n",
"# model = tf.keras.models.load_model(model_path)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "handy-impression",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py:1059: UserWarning: cpt is not loaded, but a Lambda layer uses it. It may cause errors.\n",
" , UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 1.84 s\n"
]
}
],
"source": [
"model_path = \"/Users/xxz005/Desktop/saveModel\"\n",
"\n",
"model = tf.keras.models.load_model(model_path)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "involved-genetics",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 8.37 ms\n"
]
}
],
"source": [
"model_losses = {\n",
" \"cls_label\":tf.keras.losses.BinaryCrossentropy(),\n",
"}\n",
"\n",
"model_metrics = {\n",
" \"cls_label\": tf.keras.metrics.AUC(),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "veterinary-chinese",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 16.6 ms\n"
]
}
],
"source": [
"opt = tf.keras.optimizers.Adam(learning_rate=0.0001)\n",
"model.compile(optimizer=opt, loss=model_losses, metrics=model_metrics)\n",
"# print(model.summary())"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "vanilla-karen",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9/9 [==============================] - 1s 26ms/step - loss: 5.8961 - cls_label_loss: 5.8961 - cls_label_auc: 0.4489\n"
]
},
{
"data": {
"text/plain": [
"[5.853375434875488, 5.853375434875488, 0.4493659734725952]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 946 ms\n"
]
}
],
"source": [
"model.evaluate(valid_generator)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "verbal-kingdom",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"WARNING:tensorflow:Gradients do not exist for variables ['code_label/kernel:0', 'code_label/bias:0', 'cat_label/kernel:0', 'cat_label/bias:0'] when minimizing the loss.\n",
"WARNING:tensorflow:Gradients do not exist for variables ['code_label/kernel:0', 'code_label/bias:0', 'cat_label/kernel:0', 'cat_label/bias:0'] when minimizing the loss.\n",
"9/9 - 3s - loss: 2.6438 - cls_label_loss: 2.6438 - cls_label_auc: 0.5739 - val_loss: 1.8408 - val_cls_label_loss: 1.8408 - val_cls_label_auc: 0.5634\n",
"Epoch 2/20\n",
"9/9 - 1s - loss: 1.1236 - cls_label_loss: 1.1236 - cls_label_auc: 0.6073 - val_loss: 0.7250 - val_cls_label_loss: 0.7250 - val_cls_label_auc: 0.5605\n",
"Epoch 3/20\n",
"9/9 - 1s - loss: 0.7922 - cls_label_loss: 0.7922 - cls_label_auc: 0.5777 - val_loss: 0.7698 - val_cls_label_loss: 0.7698 - val_cls_label_auc: 0.5885\n",
"Epoch 4/20\n",
"9/9 - 1s - loss: 0.6654 - cls_label_loss: 0.6654 - cls_label_auc: 0.6627 - val_loss: 0.6365 - val_cls_label_loss: 0.6365 - val_cls_label_auc: 0.6055\n",
"Epoch 5/20\n",
"9/9 - 1s - loss: 0.5701 - cls_label_loss: 0.5701 - cls_label_auc: 0.7252 - val_loss: 0.6634 - val_cls_label_loss: 0.6634 - val_cls_label_auc: 0.6159\n",
"Epoch 6/20\n",
"9/9 - 1s - loss: 0.5389 - cls_label_loss: 0.5389 - cls_label_auc: 0.7541 - val_loss: 0.6247 - val_cls_label_loss: 0.6247 - val_cls_label_auc: 0.6402\n",
"Epoch 7/20\n",
"9/9 - 1s - loss: 0.5063 - cls_label_loss: 0.5063 - cls_label_auc: 0.7928 - val_loss: 0.6179 - val_cls_label_loss: 0.6179 - val_cls_label_auc: 0.6594\n",
"Epoch 8/20\n",
"9/9 - 1s - loss: 0.4940 - cls_label_loss: 0.4940 - cls_label_auc: 0.8012 - val_loss: 0.6137 - val_cls_label_loss: 0.6137 - val_cls_label_auc: 0.6710\n",
"Epoch 9/20\n",
"9/9 - 1s - loss: 0.4807 - cls_label_loss: 0.4807 - cls_label_auc: 0.8176 - val_loss: 0.6009 - val_cls_label_loss: 0.6009 - val_cls_label_auc: 0.6736\n",
"Epoch 10/20\n",
"9/9 - 1s - loss: 0.4478 - cls_label_loss: 0.4478 - cls_label_auc: 0.8687 - val_loss: 0.5882 - val_cls_label_loss: 0.5882 - val_cls_label_auc: 0.6880\n",
"Epoch 11/20\n",
"9/9 - 1s - loss: 0.4346 - cls_label_loss: 0.4346 - cls_label_auc: 0.8817 - val_loss: 0.5889 - val_cls_label_loss: 0.5889 - val_cls_label_auc: 0.6959\n",
"Epoch 12/20\n",
"9/9 - 1s - loss: 0.4192 - cls_label_loss: 0.4192 - cls_label_auc: 0.8963 - val_loss: 0.5838 - val_cls_label_loss: 0.5838 - val_cls_label_auc: 0.7004\n",
"Epoch 13/20\n",
"9/9 - 1s - loss: 0.4034 - cls_label_loss: 0.4034 - cls_label_auc: 0.9113 - val_loss: 0.5802 - val_cls_label_loss: 0.5802 - val_cls_label_auc: 0.7077\n",
"Epoch 14/20\n",
"9/9 - 1s - loss: 0.3892 - cls_label_loss: 0.3892 - cls_label_auc: 0.9186 - val_loss: 0.5747 - val_cls_label_loss: 0.5747 - val_cls_label_auc: 0.7162\n",
"Epoch 15/20\n",
"9/9 - 1s - loss: 0.3736 - cls_label_loss: 0.3736 - cls_label_auc: 0.9294 - val_loss: 0.5760 - val_cls_label_loss: 0.5760 - val_cls_label_auc: 0.7186\n",
"Epoch 16/20\n",
"9/9 - 1s - loss: 0.3601 - cls_label_loss: 0.3601 - cls_label_auc: 0.9389 - val_loss: 0.5731 - val_cls_label_loss: 0.5731 - val_cls_label_auc: 0.7215\n",
"Epoch 17/20\n",
"9/9 - 1s - loss: 0.3471 - cls_label_loss: 0.3471 - cls_label_auc: 0.9431 - val_loss: 0.5776 - val_cls_label_loss: 0.5776 - val_cls_label_auc: 0.7272\n",
"Epoch 18/20\n",
"9/9 - 1s - loss: 0.3372 - cls_label_loss: 0.3372 - cls_label_auc: 0.9480 - val_loss: 0.5731 - val_cls_label_loss: 0.5731 - val_cls_label_auc: 0.7282\n",
"Epoch 19/20\n",
"9/9 - 1s - loss: 0.3290 - cls_label_loss: 0.3290 - cls_label_auc: 0.9498 - val_loss: 0.5812 - val_cls_label_loss: 0.5812 - val_cls_label_auc: 0.7331\n",
"Epoch 20/20\n",
"9/9 - 1s - loss: 0.3190 - cls_label_loss: 0.3190 - cls_label_auc: 0.9516 - val_loss: 0.5695 - val_cls_label_loss: 0.5695 - val_cls_label_auc: 0.7372\n",
"time: 23.8 s\n"
]
}
],
"source": [
"finetune_history = model.fit(\n",
" train_generator,\n",
" epochs=EPOCHS,\n",
" validation_data=valid_generator,\n",
" verbose=2,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "median-beads",
"metadata": {},
"source": [
"# Cold Start"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "particular-corps",
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "The first argument to `Layer.call` must always be passed.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-18-e604b1f99c15>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mutil_vocab\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mdate_vocab\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m365\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mcat_vocab\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcat2int\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m )\n",
"\u001b[0;32m/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[0;31m# not to any other argument.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 941\u001b[0m \u001b[0;31m# - setting the SavedModel saving spec.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 942\u001b[0;31m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_out_first_arg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 943\u001b[0m \u001b[0minput_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 944\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m_split_out_first_arg\u001b[0;34m(self, args, kwargs)\u001b[0m\n\u001b[1;32m 3046\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3047\u001b[0m raise ValueError(\n\u001b[0;32m-> 3048\u001b[0;31m 'The first argument to `Layer.call` must always be passed.')\n\u001b[0m\u001b[1;32m 3049\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3050\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: The first argument to `Layer.call` must always be passed."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 58.9 ms\n"
]
}
],
"source": [
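"# NOTE: `model` was rebound above to the loaded pretrained Keras model, so calling it with\n",
"# builder keyword arguments raises `ValueError: The first argument to Layer.call must always be passed`.\n",
"# Call the original model-building function (presumably defined in pretraining/cpt.py) here instead.\n",
"m = model(\n",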
" patient_dim=PATIENT_DIM,\n",
" max_visit=MAX_VISIT,\n",
" max_code=MAX_CODE,\n",
" max_demo=MAX_DEMO,\n",
" code_vocab=len(code2int),\n",
" demo_vocab=len(demo2int),\n",
" util_vocab=4,\n",
" date_vocab=365,\n",
" cat_vocab=len(cat2int),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "frozen-borough",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'm' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-19-440ef62b5236>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m m.compile(optimizer=\"RMSprop\", loss=tf.keras.losses.BinaryCrossentropy(), metrics=[tf.keras.metrics.Recall(top_k=5), \n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRecall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_k\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m tf.keras.metrics.Recall(top_k=30)])\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'm' is not defined"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 16.8 ms\n"
]
}
],
"source": [
"m.compile(optimizer=\"RMSprop\", loss=tf.keras.losses.BinaryCrossentropy(), metrics=[tf.keras.metrics.Recall(top_k=5), \n",
" tf.keras.metrics.Recall(top_k=10), \n",
" tf.keras.metrics.Recall(top_k=30)])\n",
"print(m.summary())\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "structural-space",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'm' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-26-ea420610e986>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAdam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.01\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m m.compile(optimizer=opt, loss=tf.keras.losses.BinaryCrossentropy(), metrics=[tf.keras.metrics.Recall(top_k=5), \n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRecall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_k\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m tf.keras.metrics.Recall(top_k=30)])\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'm' is not defined"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 18.9 ms\n"
]
}
],
"source": [
"opt = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
"# Use the same loss and metric as the fine-tuned model so the two runs are comparable\n",
"m.compile(\n",
" optimizer=opt,\n",
" loss={\"cls_label\": tf.keras.losses.BinaryCrossentropy()},\n",
" metrics={\"cls_label\": tf.keras.metrics.AUC()},\n",
")\n",
"m.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cleared-swimming",
"metadata": {},
"outputs": [],
"source": [
"cold_start_his = m.fit(\n",
" train_generator,\n",
" epochs=EPOCHS,\n",
" validation_data=valid_generator,\n",
" verbose=2,\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "japanese-canvas",
"metadata": {},
"source": [
"# Plot"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "comprehensive-playback",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 218 ms\n"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "established-probe",
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'val_code_label_recall_19'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-22-d990d1626ccd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mft_recall_5\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0.2107\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mfinetune_history\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'val_code_label_recall_19'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mft_recall_10\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0.3366\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mfinetune_history\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'val_code_label_recall_20'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mft_recall_30\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0.5424\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mfinetune_history\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'val_code_label_recall_21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mcs_recall_5\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcold_start_his\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'val_recall_41'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyError\u001b[0m: 'val_code_label_recall_19'"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 19.9 ms\n"
]
}
],
"source": [
"# Validation AUC of the fine-tuned model: the pre-fine-tuning value from model.evaluate above, then one value per epoch\n",
"ft_auc = [0.4494] + finetune_history.history['val_cls_label_auc']\n",
"\n",
"# Keras suffixes metric names per instance (e.g. 'val_recall_41'), so look the AUC key up instead of hard-coding it.\n",
"# Assumes the cold-start model was compiled with an AUC metric, as the fine-tuned model was.\n",
"cs_key = [k for k in cold_start_his.history if k.startswith('val_') and 'auc' in k][0]\n",
"cs_auc = cold_start_his.history[cs_key]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ordered-fraud",
"metadata": {},
"outputs": [],
"source": [
"value1 = ft_auc\n",
"value2 = cs_auc\n",
"\n",
"length = len(value1)\n",
"# The fine-tuned curve starts at epoch 0 (before training); the cold-start curve starts at epoch 1\n",
"plt.plot(range(length), value1, \"-s\", label=\"Finetune\")\n",
"plt.plot(range(1, len(value2) + 1), value2, \"-s\", label=\"w/o Pretrain\")\n",
"\n",
"plt.xticks(range(length))\n",
"plt.ylabel('Validation AUC')\n",
"plt.xlabel(\"Epochs\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 151,
"id": "gothic-evaluation",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEGCAYAAAB/+QKOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA8Z0lEQVR4nO3deXxU1fn48c+TnX0N+yoCsoggERRFkEVxC1qxoFhxaa0Lrq1Wf1pLrbZWtN9qtSpuuFBRcUMFARXrAihBQVkkLGUJSwh7Qsj+/P64NzCGmcnMJJOZSZ736zWvzNx7z73PJJP7zLnn3HNEVTHGGGMqiot0AMYYY6KTJQhjjDFeWYIwxhjjlSUIY4wxXlmCMMYY41VCpAOoLi1bttQuXbpEOgxjjIkpy5Yt262qqd7W1ZoE0aVLFzIyMiIdhjHGxBQR2exrnV1iMsYY45UlCGOMMV5ZgjDGGOOVJQhjjDFehTVBiMgYEVkrIutF5G4v668XkR9FZLmIfCUivd3lo0VkmbtumYiMCGecxhhjjhW2XkwiEg88BYwGsoClIjJbVVd7bPYfVX3G3T4d+AcwBtgNXKiq20WkLzAPaB+uWI0xJhLSHlzA7ryiY5a3bJhExn2jIxDRz4WzBjEIWK+qG1W1CJgJjPXcQFUPerxsAKi7/HtV3e4uXwXUE5HkMMZqjDE1zlty8Le8poXzPoj2wFaP11nA4IobichNwB1AEuDtUtIlwHeqWhiOII0xsS+S38SDPbaqcrCghOyDBX73uz+/iCb1EhGRajluKCJ+o5yqPgU8JSKXA/cBk8rXiUgf4O/A2d7Kish1wHUAnTp1Cn+wxpiwqcoJryrfxKt6ovV37Be/+h/ZBwvYebCAnQcKyD5YQPbBQg4Xl1a63/4PLCAxXkhtmExqI89HCqmNkmuk9hHOBLEN6OjxuoO7zJeZwNPlL0SkA/AucKWqbvBWQFWnAdMA0tLSbOYjYyKspk7y+UUl7MkrYs+hIvbk+b+48MriTdRLjKd+UgL1k+JJSYynfpLzqJcU7/e4323Zx8HDxRwsKHF/FnPwcAkHjjwv9nvsBz5cTVJ8HK2bJNOmcQp92zdhZK8U2jROoXWTFG55/XufZf94QW9ycgudR14h2/YXsHzrfvYcKqKm5nkLZ4JYCnQXka44iWECcLnnBiLSXVXXuS/PB9a5y5sCHwF3q+rXYYzRGFNBuE7yqkpRaRmFJWUUFpdRWFL6s+f+XPXSt+w9VOQmhUIKissCfj/3v78q4G0r+sW/Fx2zLCk+jsb1EmlcL4HGKYl+y3/3x9E0q+/7MpG/BHHtGV29Li8pLWPvoSIG/fVTv8euDmFLEKpaIiKTcXogxQMvquoqEXkAyFDV2cBkERkFFAP7OHp5aTJwPHC/iNzvLjtbVXeFK15jjMPfSX7F1v1HvkWXf4P2fO1P13vmVCGmQlo0SOb41Ia0aJhE8wbJtGiQ5D5P4mIvJ/JyGfeN4nBRKflFpeQXlXC4uPTI68NFpdz19g8+y06/+hQnGaQcTQgpifE/26bL3R/5LN+8QZLf99WyYZLPZOxLQnwcrRqn+N1vdQlrG4SqzgHmVFh2v8fzW32UexB4MJyxGVNbhdJouuNAARty8li/K8/vvsc+dWyFPj5OaJySQON6/r9N3zLieJIT40lOiDv6MyGO5IR4khPjuPqlpT7LfnjzUL/79qdlQ/8dIP0liOE9W4V83EBEQ1dWfyLeSG2MqV7+agBrd+ayISePDbvyWJ+Tx4acPDbmHCK/qPJGU4Dnr0yjSf2ff6OunxR/5BKKv2/Td5zdM/g3E6BQvonH+rFr4riWIIyJQqG0A5SUlpFTSYPtOf/84sjz9k3r0a1VQ07p0pxuqQ2dR6sGDHrI97XtUb1bB/gOgleVE15VvolX9UQbqVpATRzXEoQxYRKuxt43M7aSfaCA7NwCdh4oPNKNcndeYaW9Wx6f0J9uqQ05LrUB9ZOq/9
8/Uif5qoj2yzyRZAnCmDCprNtmcWkZe/KK3G6MBUe7NOb6rwXcNcu5Zt60fqLTXbJxCr3bNqZ142RaN0nh3ndX+iw7tn/lI9bE4knehIclCGMiYOBfFrA333t/9iaVNPZ+cedZtGqcfExvmnL+EkQg7CRvylmCMKaalJSWsWr7Qb79316+3bTX77Zj+rY5emesx52yLRs6J35/jb2dWtT3u+9INtia2sUShDF++GtH+OoPI1i+dT9L3YTw3eZ9HHJ7A3Wp5CT+0MUnhiVesBqAqT6WIIzxw187Qr8p8ykqLUMEerZuxCUDOzCoa3MGdWlOq8YpfmsBlbFagIkGliBMrRZsT6JDhSVk7TtM1r58svYd9rvvq0/vwildmpPWpRlN6x974rbGXhPrLEGYWs1fDeCVxZt+lgy27s1nX77/4SI83XNeL7/r7SRvYp0lCBP1Aq0FlJYpu3IL2L6/gO37D7PjgP8awP3vryIpIY4OzerRoVl9+rZvQodm9ejYrP6RZac89Em1vx9jYoUlCBP1/NUCbnn9ezcZODeLlZYFPg7yt/eOpGWDZOLivI+0aUxdZwnCRC1VZV0lg8ct37qftk1SGNy1OW2bptCuaT3aNalHu6b1aNs0hX5T5vss26pR5SNiWmOxqcssQZiosjuvkK/X7+aLzN18tT6H7IP+7yr+4q6zwhqPtSOYuswShKkR/toR/jl+AF+uz+HLzN2s3nEQcIaROOP4lgzt3pI/vP1jyMe1GoAxobMEYWqEv3aEK174hsR4YWDnZtx5Tk+Gdm9Jn3ZNiHfbBqqSIKwGYEzoLEGYsKus4filq05hUNfmNEj2/nG0WoAxkWEJwoRFcWkZSzbuYe7KncxftdPvtmed4H/WLqsFGBMZliBMtSksKeWrdbuZu3Inn6zJZn9+MfWT4jmrZys++nFHpMMzxgTJEoQJmK+G5sYpCZx1Qis+W7OL3MISGqUkMKpXa8b0bcOwHqmkJMbzURXGJTLGREZYE4SIjAEeB+KB51X14QrrrwduAkqBPOA6VV3trrsHuNZdd4uqzgtnrKZyvhqaDxaU8EVmDued2JYxJ7bh9G4tSUqI+9k21o5gTOwJW4IQkXjgKWA0kAUsFZHZ5QnA9R9VfcbdPh34BzBGRHoDE4A+QDvgExHpoaqBzaxuqp1WMpfl0ntHkRAf53O9tSMYE3vCWYMYBKxX1Y0AIjITGAscSRCqetBj+wZA+VloLDBTVQuB/4nIend/i8MYr/Fif34R736/jde/3eJ3O3/JwRgTm8KZINoDWz1eZwGDK24kIjcBdwBJwAiPsksqlD1mMl0RuQ64DqBTp07VErRxagtLN+3j9W+38NGPOygqKeOkjk0jHZYxpoZFvJFaVZ8CnhKRy4H7gElBlJ0GTANIS0sLfJQ249XeQ0W8810Wr3+7hQ05h2iUnMD4tI5MGNSRPu2aVGkCHGNM7AlngtgGdPR43cFd5stM4OkQy5oA+eqJlJwQhyoUlZYxoFNTHhnXjwv6taV+0tGPiDU0G1O3hDNBLAW6i0hXnJP7BOByzw1EpLuqrnNfng+UP58N/EdE/oHTSN0d+DaMsdYZvnoiFZaUcdWQLkwY1JET2jT2uo01NBtTt4QtQahqiYhMBubhdHN9UVVXicgDQIaqzgYmi8gooBjYh3t5yd3uTZwG7RLgJuvBVHWVDXkxJb1PDUVijIkFYW2DUNU5wJwKy+73eH6rn7IPAQ+FL7q6Y+vefN7K2MqsZVmRDsUYE0Mi3khtwqOguJT5q7N5c+lWvt6wG4Ch3VPZfqAgwpEZY2KFJYhaZvX2g7yZsZV3v9/GgcPFtG9aj9tG9mBcWgfaN61nPZGMMQGzBBFjfPVCapgcz3GpDfkh6wBJ8XGc3ac1E07pxJBuLX4257L1RDLGBMoSRIzx1Qspr7CUopIy/nRhby7q355mDbyf8K0nkjEmUJYgapG5tw5FRCrf0BhjAmAJIkZs3ZvPK4s3+d3Gko
MxpjpZgohiqsqiDXuYvmgTn6zJJt4SgDGmBlmCiEL5RSW8+/02Xl60iczsPFo0SGLyWcczcXBnTv3bp5EOzxhTR1iCiABfPZGa109kXFpHZn67hYMFJfRp15hHLz2JC/q1JSUxHrBeSMaYmmMJIgJ89UTam1/MC1/9jzF923D1kC4M7NzsmHYF64VkjKkpliCizFd/OIu2TepFOgxjjMGmAYsylhyMMdHCEoQxxhivLEHUsKKSskiHYIwxAbEEUcOmfLDK5zrriWSMiSbWSF2DXl2ymf98s4Xrh3Xj7nNPiHQ4xhjjl9UgasiSjXv48+xVjDihFXee0zPS4RhjTKUsQdSArXvzuXHGd3RuUZ9/TuhPfJwNmWGMiX6WIMLsUGEJv3klg+LSMp67Mo3GKYmRDskYYwIS1gQhImNEZK2IrBeRu72sv0NEVovIDyLyqYh09lj3iIisEpE1IvKExOBQpWVlyu/fWkFmdi5PXn4yx6U2jHRIxhgTsLAlCBGJB54CzgV6A5eJSO8Km30PpKlqP2AW8IhbdghwOtAP6AucAgwLV6zh8q/P1jN35U7uObcXw3qkRjocY4wJSjhrEIOA9aq6UVWLgJnAWM8NVHWhqua7L5cAHcpXASlAEpAMJALZYYy12n28cif/90kmvxjQnl8P7RrpcIwxJmjhTBDtga0er7PcZb5cC8wFUNXFwEJgh/uYp6prKhYQketEJENEMnJycqot8Kr6aedB7nhzOSd1bMpff3GiTeRjjIlJUdFILSJXAGnAVPf18UAvnBpFe2CEiAytWE5Vp6lqmqqmpaZGxyWcvYeK+M0rGTRMTmDarwYeGabbGGNiTTgTxDago8frDu6ynxGRUcC9QLqqFrqLLwaWqGqequbh1CxOC2Os1aK4tIybZnxH9sFCnv3VQFo3Tol0SMYYE7JwJoilQHcR6SoiScAEYLbnBiIyAHgWJzns8li1BRgmIgkikojTQH3MJaZo85cPV7N44x7+dvGJDOjULNLhGGNMlYQtQahqCTAZmIdzcn9TVVeJyAMiku5uNhVoCLwlIstFpDyBzAI2AD8CK4AVqvpBuGKtDq9/u4VXFm/m12d05ZKBHSovYIwxUU5UNdIxVIu0tDTNyMioueP5mDa0ZcMkm/XNGBMzRGSZqqZ5WxcVjdSxyNe0ob6WG2NMrLEEYYwxxitLEMYYY7yyBGGMMcYrSxDGGGO8sgQRIl/Tg9q0ocaY2sKmHA1Rxn2juf7VZWRm5/LZ74dHOhxjjKl2VoOogsxduXRvbXM8GGNqJ0sQISooLmXT7kP0bN0o0qEYY0xYWIII0cacQ5QpdLcEYYyppSxBhCgzOxeAnm0sQRhjaidLECHKzM4lIU7o0qJBpEMxxpiwsAQRoszsPLq2bEBSgv0KjTG1k53dQpSZnUsPu7xkjKnFLEGEIL+ohK378unRyhKEMab2sgQRgvW78lCFnm3sHghjTO1lCSIEmdl5gHVxNcbUbkEnCHeO6V+IyAnhCCgWZGbnkpQQR+fm9SMdijHGhE2lCUJE3vN4Phb4DLgQeF9ErgpbZFEsMzuXbqkNSYi3CpgxpvYK5AzX2eP5H4ARqno1cDpwu7+CIjJGRNaKyHoRudvL+jtEZLWI/CAin4pIZ491nURkvoiscbfpEthbCr/Mnbn0tDGYjDG1XCAJQj2eJ6jq/wBUdTdQ5quQiMQDTwHnAr2By0Skd4XNvgfSVLUfMAt4xGPdK8BUVe0FDAJ2BRBr2OUWFLP9QIG1Pxhjar1Ahvs+SUQOAgIki0hbVd0hIklAvJ9yg4D1qroRQERmAmOB1eUbqOpCj+2XAFe42/bGSUYL3O3ygnhPYVXeQG2D9BkTvOLiYrKysigoKIh0KHVOSkoKHTp0IDExMeAylSYIVfWVBOoDv/VTtD2w1eN1FjDYz/bXAnPd5z2A/SLyDtAV+AS4W1VLPQuIyHXAdQCdOnXys+vqs84dg6mHJQhjgpaVlUWjRo3o0q
ULIhLpcOoMVWXPnj1kZWXRtWvXgMuF0ouphYjEqep+VV0cbHkf+7wCSAOmuosSgKHA74FTgOOAqyqWU9VpqpqmqmmpqanVEUql1mbnUi8xng7N6tXI8YypTQoKCmjRooUlhxomIrRo0SLomltAM8qJSDPgL8CJwA6gmYhsA25W1UM+im0DOnq87uAuq7jvUcC9wDBVLXQXZwHLPS5PvQecCrwQSLzhtC47j+6tGxIXZx/woEztDoe8NCM1aAV3rqv5eEzEWHKIjFB+74F0c20KzAHeVtVhqjpBVc8BXgUeFpGhItLcS9GlQHf3vokkYAIwu8K+BwDPAumquqtC2aYiUl4tGIFH20Ukrc3OtctLofCWHPwtNyZM4uPj6d+//5HHpk2bGDJkSMj7mz59Otu3b6/GCKNHIDWIPwKPqupCEXkV55v8bqAl8CNO4/V9wB2ehVS1REQmA/NwGrNfVNVVIvIAkKGqs3EuKTUE3nKz2xZVTVfVUhH5PfCpOCuWAc9Vw/utkn2HisjJLaSHdXE1JuzSHlzA7ryiY5a3bJhExn2jQ95vvXr1WL58+c+WLVq0KOT9TZ8+nb59+9KuXbuQ9xGtAmmDOFNV33afFwKXqeppwHhgD/AVcJa3gqo6R1V7qGo3VX3IXXa/mxxQ1VGq2lpV+7uPdI+yC1S1n6qeqKpXqeqxn5QalmkN1OGhWvk2ps7xlhz8La+Khg2dL32ff/45w4cPZ9y4cZxwwglMnDgRdT+fy5YtY9iwYQwcOJBzzjmHHTt2MGvWLDIyMpg4cSL9+/fn8OHDdOnShd27dwOQkZHB8OHDAZgyZQrXXHMNw4cP57jjjuOJJ544cvzXXnuNQYMG0b9/f377299SWlpKNAikBpEiIqLOb+lkYIW7fCVwsqqW1ZVripm7nC6uliCCVFkC+Gc/6HUh9E6HDoMgzu5Qrwv+/MEqVm8/GFLZ8c967x/Tu11j/nRhH79lDx8+TP/+/QHo2rUr77777s/Wf//996xatYp27dpx+umn8/XXXzN48GBuvvlm3n//fVJTU3njjTe49957efHFF3nyySd59NFHSUtLqzTun376iYULF5Kbm0vPnj254YYbWL9+PW+88QZff/01iYmJ3HjjjcyYMYMrr7wysF9GGAWSIL4FRuJ0Nf03MF9EFgOnAc+KyCnAqvCFGD3WZefSKDmBtk1SIh1K7CgphA9u9b9N696w9DlY8hQ0bAO9LoBe6dD5dIhPsAZuU628XWLyNGjQIDp06ABwpI2iadOmrFy5ktGjnUtbpaWltG3bNuhjn3/++SQnJ5OcnEyrVq3Izs7m008/ZdmyZZxyyimAk8BatWoV/BsLg0ASxEPAmyJyvqo+7/YoOg74B84lqtnApPCFGD3W7syle+uG1gsjUId2w8yJsHUJJDaAYi8d3hq0gsvfgIKDsG4+rH4fvp8BS5+H+i2g53nWwF1LVfZNv8vdH/lc98ZvT6vucI5ITk4+8jw+Pp6SkhJUlT59+rB4ceU9+xMSEigrcwaZqNit1Ne+J02axN/+9rdqegfVJ5Ab5TaKyE3AbBGZj3PHcylwnvv4naquDW+YkaeqZGbnMqZvm0iHEhuyV8Pr4yFvF4x7Efpe4n/7lMZw4jjnUZQP6z+BNbNh1Xs1Eq4x/vTs2ZOcnBwWL17MaaedRnFxMZmZmfTp04dGjRqRm5t7ZNsuXbqwbNkyzj33XN5++20/e3WMHDmSsWPHcvvtt9OqVSv27t1Lbm4unTt3rrRsuAV0sVdVv8G5pPQF0Avoi5Mohqjql+ELL3rszitiX34x3W0WucplzocXzoaSIrh6TuXJoaKk+k57xCXPw10bwhOjiXotGyYFtTyckpKSmDVrFn/4wx846aST6N+//5GeT1dddRXXX3/9kUbqP/3pT9x6662kpaURH+9vNCJH7969efDBBzn77LPp168fo0ePZseOHeF+SwERrSU9SNLS0jQjIyNs+1+0fjeXP/8NM3
49mNOPbxm248Q0VVjyb5h/H7TuC5fNhCbtq77fKU38rDtQ9f2bGrNmzRp69eoV6TDqLG+/fxFZpqpeW9grvcQkIrk4I7oKPx/ZVQBV1cahhxs71rpdXLvbPRDelRTBnN/Ddy87PZIufhaSGoT/uGVl1uvJmDCp9D9LVRupamOPn409X9dEkNEgMzuPpvUTSW2YXPnGdU3+XnjtF05yGPo7uPSV6k0ODfz06Hj/Rigtqb5jGWOOCKQG4W0YjSNUdW/1hRO9Mt0hNqwHUwU5mfCfX8LBbXDxNDhpfPUfw1tXVlX44lFY+CAUHIBxL0GidT82pjoF0s11GUcvMVWkOF1ea7XyHkxj+9e+W+mD4ut+BASunQ8dB9VcLCIw7E5IaQJz74QZ4+Cy1yHZOhEYU10C6eYa+ODhtdTOgwXkFpTYJEE+7zvQmk0OngZfB/WawrvXw8vpcMXbUN9vpdcYE6CAhvsu5w773R04UpdX1S+qO6hoUz6LnE0zGqX6/dKpObx1Fbx0LvzqXWhcx2t7xlSDgLt/iMivce6DmAf82f05JTxhRZfMnXV8kD5V2LCw8u0iqee5Tu3hwDZ44RzYY/dPmNA9/PDDzJgxI6Btp0+fTmpqKv3796d3794891xwA0+/9957rF4d/GwGs2fP5uGHHw66XDCCqUHcijO72xJVPUtETgD+Gp6woktmdi4tGybTvEHN36ATUQUHYPnrzrAXe2JgzKMuZ8Ck2fDaJfDiGKcm0aZvpKMyoYrgGFzz5s3jzTffDHj78ePH8+STT7Jr1y769OlDeno6rVu3PrK+pKSEhATvp9v33nuPCy64gN69ex+zzl+59PR00tPTva6rLsF0IC9Q1QIAEUlW1Z+AnuEJK7pkZufSs00duv8hexV8cBs81gs+/oPTEHzxs5GOKjDtT4ZrPob4RJh+Hmz5JtIRmVCFYQyuqVOnHhlm+/bbb2fEiBEAfPbZZ0ycOBGAgwcPUlRURGpqKps2bWLEiBH069ePkSNHsmXLFr/7b9WqFd26dWPz5s1H7rAePHgwd911Fxs2bGDMmDEMHDiQoUOH8tNPP7Fo0SJmz57NnXfeSf/+/dmwYQPDhw/ntttuIy0tjccff5wPPviAwYMHM2DAAEaNGkV2djbg1FwmT54MOHdz33LLLQwZMoTjjjuOWbNmhfw78hRMDSLLnV3uPWCBiOwDNldLFFGsrExZtyuPX6Z1rHzjWFZaDGs+cGoLm7+GhBToOw4G/RraDXC2mf9H39/ooklqTydJvHIRvHoRjH8Vjh8V6ahMRXPvhp0/hlb2pfO9L29zIpzr+7LL0KFDeeyxx7jlllvIyMigsLCQ4uJivvzyS84880wAPvnkE0aOHAnAzTffzKRJk5g0aRIvvvgit9xyC++9957P/W/cuJGNGzdy/PHHA5CVlcWiRYuIj49n5MiRPPPMM3Tv3p1vvvmGG2+8kc8++4z09HQuuOACxo0bd2Q/RUVFlI8MsW/fPpYsWYKI8Pzzz/PII4/w2GOPHXPsHTt28NVXX/HTTz+Rnp7+s/2FKuAEoaoXu0+niMhCoAnwcZUjiHLb9h8mv6iUnm1qSfuDz66qcUAZNO0Mox+AAb86tjdQLA2t3bSTkyT+0cu55FSRDRVeJw0cOJBly5Zx8OBBkpOTOfnkk8nIyODLL788UrP4+OOPufrqqwFYvHgx77zzDgC/+tWvuOuuu7zu94033uCrr74iOTmZZ599lubNnf+dSy+9lPj4ePLy8li0aBGXXnrpkTKFhYU+4xw//uj9RFlZWYwfP54dO3ZQVFRE167eO5ZedNFFxMXF0bt37yO1jKoKOEGIyKnAKlXNVdX/ikhjYABQq+vwR2eRqyWXmHxWz8vg8jedb9pxlQ8wFhMatoIyH3dZ21Dhkefnmz7gfwyuq30PBe5PYmIiXbt2Zfr06QwZMoR+/fqxcOFC1q9ff2SMom+//Zann346qP2Wt0FU1KCBM6JAWVkZTZs29T
sPhbdy4NRi7rjjDtLT0/n888+ZMmWK1zKeQ4lX1xh7wbRBPA3kebzOc5fVanWqi2uPc2pPcjDGh6FDh/Loo49y5plnMnToUJ555hkGDBiAiLBq1SpOOOGEI6OwDhkyhJkzZwIwY8YMhg4dGtIxGzduTNeuXXnrrbcA5wS+YoUzOWfF4cIrOnDgAO3bO4NevvzyyyEdP1TBJIjyaUcBUNUyKqmBiMgYEVkrIutF5G4v6+8QkdUi8oOIfCoinSusbywiWSJybGquIZnZubRtkkLjlMRIhWDC5cC2SEdg/PHVtlXFNq+hQ4eyY8cOTjvtNFq3bk1KSsqRE//cuXMZM2bMkW3/9a9/8dJLL9GvXz9effVVHn/88ZCPO2PGDF544QVOOukk+vTpw/vvvw/AhAkTmDp1KgMGDGDDhmO7Z0+ZMoVLL72UgQMH0rJlzY4kHfBw3yLyDvA5R2sNNwJnqepFPraPBzKB0UAWsBS4TFVXe2xzFvCNquaLyA3AcFUd77H+cSAV2Kuqk/3FF67hvs9/4ktaNkzm5WsidKdwdatrQ2f7e78JKTD4ejjjdudubBN20T7c9+jRo3nllVdCmk40FgQ73HcwNYjrgSHANpwT/mDgOj/bDwLWq+pGVS0CZgJjPTdQ1YWqmu++XAJ08Ah6INAamB9EjNWqtExZvyuv9rQ/rJ0b6QiiS++x8PXj8ER/WPyUM3+2qdMWLFhQa5NDKAJOEKq6S1UnqGorVW2tqperqr+WvvbAVo/XWe4yX64F5gKISBzwGPB7fzGJyHUikiEiGTk5OYG9kSBs2ZtPYUlZ7biD+vB+594G8dHGEG1dVauLv8sUv5gGv/3C6cY77//Bk2nww1vOHBPGmKB6MfXAubzUWlX7ikg/IF1VH6xqECJyBZAGDHMX3QjMUdUsf8Nrq+o0YBo4l5iqGkdFa2vTEBvz7oVDOfCbT4/e11AXVNaVtW0/547rDZ/BgvvhnV/D4n/BqD9Dt7NqJkZjolQwN8o9B9wJPAugqj+IyH8AXwliG+B5d1kHd9nPiMgo4F5gmKqW1/FPA4aKyI1AQyBJRPJU9ZiG7nBaV1tmkVv3CSx/Dc64o24lh2B0GwFdh8PKWfDpX5wb7LqNhO3fw2EvU57YfRQhU1WbVyUCQun6GkyCqK+q31b4w/qbymsp0F1EuuIkhgnA5Z4biMgAnIQzxvNylapO9NjmKiCtppMDONOMdmxej/pJQQ16G10KDsIHt0DLnjDsD5GOJrrFxTkjw/ZKd+4o/2IqFOz3vq3dRxGSlJQU9uzZQ4sWLSxJ1CBVZc+ePaSkBDepVjBnvt0i0g13XmoRGQfs8BNQiYhMxhn1NR54UVVXicgDQIaqzgam4tQQ3nI/LFtUNbyjTwVhXXZe7M8BseB+yN0B18y3GdcClZgCQybDgInw9y6RjqZW6dChA1lZWYSjzdD4l5KSQocOHSrf0EMwCeImnOv9J4jINuB/wER/BVR1DjCnwrL7PZ5XOkCOqk4HpgcRZ7UoLi1j4+48RvSK4cbbjf+FZS/BaZOh4ymRjib21GsW6QhqnfI7mU1sCKYX00b3hJ4KnIDToHxGuAKLtE27D1FcqrFbgyjMg9mToXk3GHFfpKOpnT57yOkdZkwtVWmCcO9mvkdEnhSR0UA+MAlYD/wy3AFGytpYb6D+9AHYvxXGPgWJ9SIdTe30xSPweD/471Qo9D1UgjGxKpAaxKs48z78CPwGWAhcClysqmP9FYxlmdl5xAl0S43BBLF5EXz7LAy6DjqfFuloYpu/+yh++yV0GgILH4R/9nNuuivK9769MTEokDaI41T1RAAReR6nYbpT+eRBtVXmzly6tGhASmKMDV5XlA/v3+QM2z3qT5GOJvZV1pX18pmQtQwWPuR0CFj0JAy9AwZebZ0CTMwLJEEUlz9R1VIRyartyQEgc1dubF5eWvgQ7N0IV86GpAaVb2+qrsNA+NU7sHmx8/v/+G74+gnnslORl0tPdg
+FiRGBXGI6SUQOuo9coF/5cxE5GO4AI6GguJRNuw/FXgP11qWw5N/Ot9fjhlW+valenU+Dqz50knPTjt6TA9g9FCZmVFqDUNUYu8ZSdRtzDlGmMTYHRHEBvH8jNGrnzAhnIue4YdD1TPhz00hHYkyVxPAtwuGzbpfzzS+mphn9799hdyZc8TakNI50NMbuEja1QDDDfdcZa3fmkhAndGkRI9fwt3/v9KAZcIUzZaiJfkufh7LSSEdhjF+WILzIzM7juNQGJCXEwK+npAjeu8mZf/nshyIdjQnUR7+D50fCtmWRjsQYn2LgDFjzMrNzY6f94cvHYNcquOCfNitatPF3D8UlL8DB7fDcSPjwdsj3MmKsMRFmbRAV5BeVsHVfPuMGBjeoVY2Z2t17L5jZN1vXyWhT2d+j+2hY+DfnpsbV78Pov8BJlzmjyhoTBeyTWMH6XXmoEr3TjPrqImldJ2NPShM492FnVrsWxzu90F46F3aujHRkxgBWgzhGZnYeUEtmkTOxoc2JcPXHsHyGczf2s2dCQhIUHz52W7vJztQgq0FUkJmdS1JCHJ1jpQeTqR3i4uDkX8HNy5yf3pIDWE3R1ChLEBVkZudyfGpD4uOsH7uJgPrN4cLHIx2FMYAliGNk7syN3vYHY4ypQZYgPOQWFLP9QAE9ovkOan9dJ40xphpZI7WHIw3UraI4QYx7AV6+0OlHf+K4SEdjImHHCmh7UqSjMHVAWGsQIjJGRNaKyHoRudvL+jtEZLWI/CAin4pIZ3d5fxFZLCKr3HXjwxlnuXXZMTAG04qZkNQITjg/0pGYcPJVI5Q4eOl82Ph5jYZj6qaw1SBEJB54ChgNZAFLRWS2qq722Ox7IE1V80XkBuARYDzOtKZXquo6EWkHLBOReaq6P1zxgjPNaL3EeNo3jdIpOovynRuq+lxk04jWdr66sh7cDq9dAq+Ng4ufsVqkCatw1iAGAetVdaOqFgEzgZ9NUaqqC1W1fI7GJUAHd3mmqq5zn28HdgGpYYwVgHXZefRo3ZC4aO3B9NNHUJTn3G1r6qbG7eDqudBxELx9LSz+d6QjMrVYOBNEe2Crx+ssd5kv1wJzKy4UkUFAErDBy7rrRCRDRDJycnKqGK5Tg4jqMZhWvA5NOjnzIJu6q15TuOId6HUhzLvHubmurCzSUZlaKCoaqUXkCiANGFZheVvgVWCSqh7zH6Cq04BpAGlpaVqVGPbnF5GTWxi9s8jl7oSNC+GMO2ysHuPMd33pyzDnTmeo99xsGPskxCdWfd++xvuyu7jrnHAmiG1AR4/XHdxlPyMio4B7gWGqWuixvDHwEXCvqi4JY5zA0R5MUTsP9Y9vgZbBSRMiHYmJFnHxcP5j0KgtLHwQDuXAL1+B5Cp+hm28L+MK51fRpUB3EekqIknABGC25wYiMgB4FkhX1V0ey5OAd4FXVHVWGGM8Yq3bgylqx2BaMRPap0HL7pGOxEQTERh2J1z4hFPDfPkCyAvxcqsq7NtUreGZ2Ba2GoSqlojIZGAeEA+8qKqrROQBIENVZwNTgYbAW+JM0bhFVdOBXwJnAi1E5Cp3l1ep6vJwxbsuO5dGyQm0bZISrkOEbuePkL0Szns00pGYaDVwkjNp1FtXwYtnO20Uzbv6L1NWBrvXwuavYfNi2LwIcrfXSLgmNoS1DUJV5wBzKiy73+O51/kxVfU14LVwxlbR2p25dG/dEInGuYRXzIS4ROjzi0hHYqJZz3Phytnw4jnwRP9j1zdoBZe/4SSCzYtgy2I47E5U1KgtdB4CnU6DOb+v0bBN9IqKRupIU1Uys3MZ07dNpEM5VmmJ0/7Q4xxo0CLS0Zho12kw4KO/xqFd8NxZzvPmx0HP85yk0Pk0aNbVuVwFliDMEZYggN15RezLL6Z7NA6xsfFzyMu2xmlTPca96HSTbtzW9zYNWvnuxWTqFEsQRPkQGz/MhJSm0P3sSEdiaoO+l1S+TcWurB/eDsumwxVvhyUkE72sQz
1HezBFXRfXgoOw5kPnnzohOdLRmLpq5P1Qrzl89Du7Ia+OqdM1iLQHF7A7r+jI60EPfQpAy4ZJZNw3OlJhHbVmNpQctqE1TGTVawZn/wXeuwGWvwYnXxnpiEwNqdM1CM/kEMjyGrdiJjTvBh3SIh2JiSXhmDPkpMuctosFf4L8vaHvx8SUOl2DiGr7t8CmL+Gse4/2LjEmEOEYDkPEuWv7mTPgkymQ/kT1H8OG+Ig6dboGEdV+eNP52e+XkY3DmHKte8OpN8B3L8PWpdW/fxviI+pYgohGqs7lpU5DoFmXSEdjzFHD74ZG7eCj2517dEytZgkiGm3/Dvass3sfTPRJbgRj/uYM/7L0+UhHY8KsTieIlg2TglpeY1bMhPhkZ+Y4Y6JN77HQbSQsfMgZhr465NllpGhUpxupo6Ira0UlRfDjLGfO6ZQmkY7GmGOJwHlT4d+nwvz74JIq1iT2bYJXL/a/zU9z4ITzqnYcX6xx3Kc6XYOISus/cQZQs8tLJpq16AZn3O6ME7bxv6HvJ3s1vHCO03W2XjPv28QlwMzL4ct/OO1z1c0ax32q0zWIqLTidWiQCt1GRDoSY/w743b44Q3nDusbFkFCkJdmt3wD/7kUEuvDNR9Dq17etys+DO/fBJ/+GXatgfR/OTPqmbCzGkQ0ObwPMj+GEy+tnqkjjQmnxHrOHCV71sHifwVXdt0CeGUs1G8B18zznRzKj3PJCzDiPvjxTZh+XvW1fRi/LEFEk1XvQmkR9Bsf6UiMCUz30dDrQvjvVOfmzkD8OAten+DMjnjNPGjWufIyInDmnTD+Ndj1E0w7C7Z/X7XY92+Fj++p2j5qOUsQ0WTFTEjtBW1PinQkxgRuzMPOCXzu3ZVv+800ePvX0PFUuOpDZxa8YPS6EK6d58zH/eK5sDKEEWZ3roR3roPHT4JvpwVfvg6xBBEt9myArd84jdM2tIaJJU06wLA/wNqPYO1c79uowsK/wdw7nYmKrng79F56bU6E3yx0vkjNugY+e6jyUWZV4X9fwGuXwDOnO6MkD74eblnue4yq5Cgc/r+GWSN1tPjhDUCc9gdjYs2pNzodLObeBV2HQVL9o+vKypzlS5+D/hPhwicgvoqnnoapMGk2fHgHfPEIfP04lBYeu12DVnDu3531O5Y7HUBG/BFOufZor6mKXVnLyuDViyArw/ni1qJb1WKNYWGtQYjIGBFZKyLrReSY+qeI3CEiq0XkBxH5VEQ6e6ybJCLr3MekcMYZceVDaxw3DJq0j3Q0xgQvIckZzG//FvjqH0eXlxTBO79xksOQm2HsU1VPDkeOmQxjn4Rz/uo9OYDTVXXW1VCYCxf8E25bCWf+3neXWoC4OLjoaSfOd39bp4cUEQ1Hv2JAROKBTGA0kAUsBS5T1dUe25wFfKOq+SJyAzBcVceLSHMgA0jDmWB3GTBQVff5Ol5aWppmZGSE5b2E3ebF8NIYuPhZu//BxLYHW0NJwbHLkxrA/9sevuNO8XO56pevOjeexsUHt88fZ8Hb18JZ98GwO6sWXxQTkWWq6nVOgXDWIAYB61V1o6oWATOBsZ4bqOpCVc13Xy4BOrjPzwEWqOpeNyksAMaEMdbIWvE6JDaAEy6IdCTGVI235ABQdKhm4/DUOz345ABw4jjoOw7++zBs+67644oB4UwQ7YGtHq+z3GW+XAuUt3AFVFZErhORDBHJyMnJqWK4EVJcAKvec3pnJEfZlKfG1HXnPwoNWzu9noryK9++lomKXkwicgXO5aSpwZRT1WmqmqaqaampqeEJLtwy50LhAbu0ZEw0qtcMLvq3czPggvsjHU2NC2cvpm1AR4/XHdxlPyMio4B7gWGqWuhRdniFsp+HJcpI8DY42KsX2eBgxoSqQSvfA+5V1XHD4dSbYMlT0GMMdB9V9X3GiHAmiKVAdxHpinPCnwBc7rmBiAwAngXGqKrnX3ce8FcRKe9qcDZQe255tMHBjKle4f5iNfJ+2PCZMybUjY
uhfvPwHi9KhO0Sk6qWAJNxTvZrgDdVdZWIPCAi6e5mU4GGwFsislxEZrtl9wJ/wUkyS4EH3GXGmGjm6xt7dXyTj6TEFPjFNMjfAx/cGp5RZaNQ2Lq51rSY6OZaVgabv4KXL/S9zZQDNRePMSY4X/0ffDIFLnoG+l8W6Wiqhb9urnYndU3IyYQfZsIPb8KBrZVvb4yJTkNugcz5MOdO6DwksIEGY1hU9GKqlQ7tcQYmm3YWPHWK880j9QRn2GJjTGyKi4eLn3Gev3cDlJVGNp4wsxpEqHxNU5jcBLqcAevmQVmJM7DY2Q85Yyw1au1s8/E94etxYYwJr2ad4bxHnASx+Ek4/dZIRxQ2liBC5avHUeEB2JYBp94A/SZAm77HbmNdWY2JbSdd5oxc++lfnNkf25wY6YjCwhJEONy+uvoGJDPGRB8RZ/C/rd84d1n/ZmGtnAbVzmLhYMnBmNqvQQsY+2+YcQk81NrL+ti/8dUaqY0xJlT+7qquBTe+WoIwxhjjlSWIUNXWO0aNMdVn3QIoiN2bX+1ieahi/NqiMaYGzBgHEget+0Ln06HzadBpiDNlKvjuLh8l7ReWIIwxJlyunA1bFsPmr2HZdPjmaWd5yx7Q6bSoH7jTEoQxxlSFv6HGjxvmPMCZn3vHCidZbF7kTBQW5SxBGGNMVQR6KSghCTqe4jzOuM0ZpuOB6B423BqpjTEmEkKZJ7uGWYIwxphotObDSEdgCcIYYyLGV7f4uAR4YyJ88ueIjhhrbRDGGBMpvtovigtg7l3w1T9g+/fONAENWtRsbFgNwhhjok9iCqQ/Aen/cno8TRvmJIoaZgnCGGOi1clXwjUfO89fOAe+e7VGDx/WBCEiY0RkrYisF5G7vaw/U0S+E5ESERlXYd0jIrJKRNaIyBMiIuGM1RhjolL7k+G6/zp3Yc+eDB/cCiWFNXLosCUIEYkHngLOBXoDl4lI7wqbbQGuAv5ToewQ4HSgH9AXOAUYFq5YjTEmqjVoAVe8A2fc7tyR/dK5cCAr7IcNZyP1IGC9qm4EEJGZwFhgdfkGqrrJXVdWoawCKUASIEAikB3GWI0xJrrFxcOoKdB+ILx7A/xfX5xTZQXVOI5TOC8xtQe2erzOcpdVSlUXAwuBHe5jnqquqfYIjTEm1vS6EH7zGV6TA1TrOE5R2UgtIscDvYAOOEllhIgM9bLddSKSISIZOTk5NR2mMcZERmqPGjlMOBPENqCjx+sO7rJAXAwsUdU8Vc0D5gKnVdxIVaepapqqpqWmplY5YGOMMUeFM0EsBbqLSFcRSQImALMDLLsFGCYiCSKSiNNAbZeYjDGmBoUtQahqCTAZmIdzcn9TVVeJyAMikg4gIqeISBZwKfCsiKxyi88CNgA/AiuAFar6QbhiNcYYc6ywDrWhqnOAORWW3e/xfCnOpaeK5UqB34YzNmOMiWn+5qGoJjYWkzHGxKIamJI0KnsxGWOMiTxLEMYYY7yyBGGMMcYrSxDGGGO8sgRhjDHGK1H1MZ5HjBGRHGBzFXbREthdh8pG8tixWDaSx7b3HBtlI3nsqpTtrKreh6JQVXs4STKjLpWN1bjt92XvOVrLxnLcvh52ickYY4xXliCMMcZ4ZQniqGl1rGwkjx2LZSN5bHvPsVE2kseuatxe1ZpGamOMMdXLahDGGGO8sgRhjDHGqzqfIERkjIisFZH1InJ3kGVfFJFdIrIyhON2FJGFIrJaRFaJyK1BlE0RkW9FZIVb9s8hHD9eRL4XkQ+DLLdJRH4UkeUikhHCcZuKyCwR+UlE1ojIMTMF+ijX0z1m+eOgiNwWxHFvd39XK0XkdRFJCaLsrW65VYEc09vnQkSai8gCEVnn/mwWRNlL3WOXiUhakMed6v6ufxCRd0WkaRBl/+KWWy4i80WkXTDH9lj3OxFREWkZxLGniMg2j7/3ecEcV0Rudt/3KhF5JIjjvuFxzE0isjyIsv1FZEn5/4aIDPJW1k/5k0Rksfv/9YGINP
ZSzut5I9DPV9DC0Xc2Vh5APM7ERMcBSTiTE/UOovyZwMnAyhCO3RY42X3eCMgM9NiAAA3d54nAN8CpQR7/DuA/wIdBltsEtKzC7/xl4Nfu8ySgaYh/t504N/gEsn174H9APff1m8BVAZbtC6wE6uMMj/8JcHywnwvgEeBu9/ndwN+DKNsL6Al8DqQFedyzgQT3+d+DPG5jj+e3AM8Ec2x3eUecScM2+/rc+Dj2FOD3Afx9vJU9y/07JbuvWwUTs8f6x4D7gzjufOBc9/l5wOdBxr0UGOY+vwb4i5dyXs8bgX6+gn3U9RrEIGC9qm5U1SJgJjA20MKq+gWwN5QDq+oOVf3OfZ6LM+te+wDLqjpzdYOTIBKBgHsbiEgH4Hzg+aCCriIRaYLzj/ECgKoWqer+EHY1EtigqsHcOZ8A1BORBJyT/fYAy/UCvlHVfHVmSfwv8At/BXx8LsbiJEfcnxcFWlZV16jq2soC9VF2vhs3wBK8TNDlp+xBj5cN8PMZ8/O/8H/AXSGWrZSPsjcAD6tqobuNl1l1/B9XRAT4JfB6EGUVKP/W3wQ/nzEf5XsAX7jPFwCXeCnn67wR0OcrWHU9QbQHtnq8ziLAk3R1EpEuwACcmkCgZeLd6u8uYIGqBlwW+CfOP21ZEGXKKTBfRJaJyHVBlu0K5AAvuZe3nheRBiHEMAEf/7jeqOo24FGcuc53AAdUdX6AxVcCQ0WkhYjUx/lm2DHIeAFaq+oO9/lOoHUI+6iqa4C5wRQQkYdEZCswEbi/su0rlB0LbFPVFcGU8zDZvcT1YpCXTHrg/M2+EZH/isgpIRx7KJCtqsHMynMbMNX9fT0K3BPkMVdx9AvqpVTyOatw3gjL56uuJ4iIE5GGwNvAbRW+sfmlqqWq2h/nG+EgEekb4PEuAHap6rJQ4gXOUNWTgXOBm0TkzCDKJuBUq59W1QHAIZzqcMBEJAlIB94KokwznH+8rkA7oIGIXBFIWVVdg3NpZj7wMbAcKA0mZi/7VIKo8VUHEbkXKAFmBFNOVe9V1Y5uuclBHK8+8P8IMql4eBroBvTHSeqPBVE2AWgOnArcCbzp1giCcRlBfAlx3QDc7v6+bsetKQfhGuBGEVmGc/moyNeG/s4b1fn5qusJYhs/z9Id3GU1QkQScf7IM1T1nVD24V6iWQiMCbDI6UC6iGzCuaQ2QkReC+J429yfu4B3cS7TBSoLyPKo7czCSRjBOBf4TlWzgygzCvifquaoajHwDjAk0MKq+oKqDlTVM4F9ONd9g5UtIm0B3J9eL3uEg4hcBVwATHRPHqGYgZdLHn50w0nIK9zPWgfgOxFpE0hhVc12vwSVAc8R/OfsHfdS7Lc4NWWvDeTeuJchfwG8EcQxASbhfLbA+QITTMyo6k+qeraqDsRJTht8xOftvBGWz1ddTxBLge4i0tX9ZjoBmF0TB3a/0bwArFHVfwRZNrW8N4qI1ANGAz8FUlZV71HVDqraBef9fqaqAX2bFpEGItKo/DlOA2jAPbhUdSewVUR6uotGAqsDLe8K5ZvdFuBUEanv/t5H4ly7DYiItHJ/dsI5cfwnyOOD87ma5D6fBLwfwj6CJiJjcC4npqtqfpBlu3u8HEuAnzEAVf1RVVupahf3s5aF07i6M8Bjt/V4eTFBfM6A93AaqhGRHjidIYIZ6XQU8JOqZgVRBpw2h2Hu8xFAUJNGe3zO4oD7gGe8bOPrvBGez1d1tHTH8gPnmnImTra+N8iyr+NUf4tx/gGuDaLsGTjVwB9wLlssB84LsGw/4Hu37Ep89LQIYD/DCaIXE05vrxXuY1Wwvy93H/2BDDf294BmQZRtAOwBmoRw3D/jnOBWAq/i9nAJsOyXOIlsBTAylM8F0AL4FOek8QnQPIiyF7vPC4FsYF4QZdfjtLOVf8a89kTyUfZt9/f1A/AB0D7U/wX89H7zcexXgR/dY88G2gZRNgl4zY39O2BEMDED04HrQ/gbnwEscz8n3wADgyx/K8
65KBN4GHekiwrlvJ43Av18BfuwoTaMMcZ4VdcvMRljjPHBEoQxxhivLEEYY4zxyhKEMcYYryxBGGOM8coShDE+iEip/HwE2bvd5Z+LMwLwChH5uvy+DhFJEpF/ijMy8DoRed8d96p8f21EZKaIbHCHKpkjIj1EpIscOxrpFBH5vfv8VHfYiOXijIA7pQZ/DaYOS4h0AMZEscPqDGfizURVzXDHo5qKM/zHX3GGSOipqqUicjXwjogMdsu8C7ysqhPAGd4ZZ8ycrcfu/mdeBn6pqitEJB5nZFdjws4ShDFV8wVwmzv20NVAV1UtBVDVl0TkGpy7ahUoVtUjd8eqO4idO+iaP61wbqrC3Xewd58bExJLEMb4Vk9+PmHM31S14vg8F+Lc8Xs8sEWPHXAxA+jjPvc3QGK3CsdqgzMiKDhDZq8Vkc9xBgx8WVULAn0TxoTKEoQxvvm7xDRDRA7jDCFxM1DVGbw2eB7Ls51BVR8QkRk4Y19djjMe1fAqHs+YSlmCMCY0E1X1yJSrIrIX6CQijdSZyKXcQKB8WtdxoR5MVTcAT4vIc0COiLRQ1T2h7s+YQFgvJmOqgaoewmlM/ofbkIyIXIkze91n7iPZc5IlEeknIkMr27eInO8xn0F3nPko9lfvOzDmWJYgjPGtXoVurg9Xsv09QAGQKSLrcGYFu1hdOKOyjnK7ua4C/oYz+1dlfoXTBrEcZ5TTieUN4caEk43maowxxiurQRhjjPHKEoQxxhivLEEYY4zxyhKEMcYYryxBGGOM8coShDHGGK8sQRhjjPHq/wNFIswiUzm/jwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 295 ms\n"
]
}
],
"source": [
"value1 = ft_recall_5  # fine-tuned from the pretrained model\n",
"value2 = cs_recall_5  # trained without pretraining\n",
"name = \"Recall@5\"\n",
"\n",
"length = len(value1)\n",
"# value2 has one fewer entry than value1, so its curve starts at epoch 1\n",
"plt.plot(range(length), value1, \"-s\", label=\"Finetune\")\n",
"plt.plot(range(1, length), value2, \"-s\", label=\"w/o Pretrain\")\n",
"\n",
"plt.xticks(range(length))\n",
"plt.ylabel(name)\n",
"plt.xlabel(\"EPOCHS\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 152,
"id": "confirmed-exhibit",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEGCAYAAABy53LJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA7NElEQVR4nO3dd3wUZf7A8c83nRJAgSBVEEIXEaIURVGK2CKnoCAqlhMLqGc78fQ87uQ8+53tJ4oiFhQ5VOROFFBAUaQkCkpvggaB0Dup398fM+HWsLvZSbLZlO/79dpXZmfmmXkm2cx3nzLPI6qKMcYYE6qoSGfAGGNMxWKBwxhjjCcWOIwxxnhigcMYY4wnFjiMMcZ4EhPpDJSFevXqafPmzSOdDWOMqVDS09N3qmr9wuurROBo3rw5aWlpkc6GMcZUKCKy2d96q6oyxhjjiQUOY4wxnljgMMYY44kFDmOMMZ5Y4DDGGONJWHtVicgA4DkgGnhNVR8PsN8VwFTgDFVNE5F+wONAHJAN3K+qc9x95wENgSNu8v6qmhnO6zDGmLKUMnY2Ow9mH7e+Xs040h7uF7a0oQpb4BCRaOAloB+QASwRkemqurLQfonAXcAin9U7gUtV9VcR6QjMBBr7bB+mqta/1pgqoKQ3wrK4kZY2f/kNtr600oYqnCWOM4H1qroRQEQmA5cBKwvt9yjwBHB/wQpV/d5n+wqgmojEq2pWGPNrTLkXyZtopNKW9EZYkvRlfc0Hs3LJ3H806HFfmrueKBGiBKJEEPdnlEB0lARNW1rCGTgaA7/4vM8AuvnuICJdgKaq+omI3I9/VwDfFQoab4hIHvABMFb9TCoiIiOAEQDNmjUr/lUYU44U5yaoquTkKdl5+UHT5+crUUFuPGX5LTgvXzmcncuR7Lygx527JhNVJS8f8lVRVfLVWc5X59qDOZiVS424aET8X3e4rvnVrzawfX8WmQeyyNx/9NjPQ0VcL8BTM9cUuU+4RezJcRGJAp4Frg+yTwec0kh/n9XDVHWLW8X1AXAt8FbhtKr6KvAqQEpKis1WZSo0VWVbEd9Ez3t6Htm5+eTk5ZOdl09Obv6xgBGKU/40g+goITZaiIuOIi4milifn8EMn7CY6CjnW6/4fBuOEgkajAAuem4+R3PyOJydx5GcPI5k54Wc5xveWBLSfoF0/MtMoqOEWgkx1K4WS+1qsdRyf9auFhs07ZjpKziSncdhN89HcnKda3CvI5jHZqymWmw0DWrFk5SYQIdGtTivTRJJteJpUCueu99fFjDtmrEDUJ/gmK+K5kOeKvmqpIz9vFi/Cy/CGTi2AE193jdx1xVIBDoC89xofxIwXURS3QbyJsBHwHWquqEgkapucX8eEJF3carEjgscxoRTuKowFv+pL7/uO8K6zIOs336QtdsPOMuZBzmYlRv0uKc2ru3e6J0bf2x0FLExUT5BQHhsxuqA6e/u25rsvDwn2BQEIPdnTp6yPvNgwLR7D2cf900/X5W8fKWoSUYb1UkgITaa6nHRVIuNplpcjM9yNA9PWx4w7Ue39/QJUD7ByieAnf/MlwHT/+mituw7kuO+co8tZ+w5wr4jOUHz/eF3GVSPi6GaT15rxsdQr2Y81eOi2bzrcMC0P47pT834mIAlnWCBIz4mOmi+ykI4A8cSIFlEWuAEjCHA1QUbVXUfUK/gvdtb6j43aNQBPgFGq+o3PvvEAHVUdaeIxAKXAOEPr8YUEqwaYvFPu4udtuOYmRz2qa6oVzOe5KSaXN6lMckNEvlzkJvo80NPLzLfwQLHXX2Tg6b9ZPQnAbd9POrsoGmbB0n72vAzgqYNFjhOb3ZC0LRFGXFOy6Dbg+X7hzEXBE378dJfA25LTAhemqlXMy7gl4uilCRtqMIWOFQ1V0RG4fSIigYmqOoKEfkbkKaq04MkHwW0Ah4RkUfcdf2BQ8BMN2hE4wSN8eG6BmP82Xc4+DfRK1/5ttjHvjKlKc
kNapKclEhyUk1OqPHbf/ZggaOyKumNsCxupKWtJL29yqKnWFjbOFR1BjCj0LpHAuzb22d5LDA2wGG7llb+jAnV0Zw8vliVycdLtzBvzY6g+076fbeg24e9tijgtjGpHYKmjeRNNFJpS3ojLEn68v7NP1KkqF4HlUFKSorasOrGq7x85dsNu5i2dAszl2/jQFYuSYnxpJ7WiNe+/ilguk2PXxz0uMGqP4pKa0xZEpF0VU0pvL5KzMdhjD+BGqnrVIvl8i5N+M8Pv7LjQBaJ8TEM6HgSA09vTPdT6hIdJUEDhzGVnQUOU2UFaqTeeySHdxZupneb+gw8vTHnt00iIfa3PVmsCsNUZRY4TIUWarfY3Lx8th/IYuveI2zZe4Rf9wZ/JmLJQ32pXT1wz5fy3nhpTDhZ4DAVWrCuraPe/Y6t+47y694jbN9/lHwPzXnBgoYxVZ0FDhNRXh6kU1V2HsxmXeYB1mceZN32wA+kASzfso+GtavRs2U9GtVJoFGdas6rdgIN61Sj419mluq1GFNVWOAwERWsxPDN+p2s236Ate5T1OsyD7DH5xmKxITgH995959Xqnk1xjgscJhyq+B5h9rVYmndoCYDOjakdcHDcQ1qkpQYT4sHZxRxlMCskdqY4rHAYcrU0Zw8fsjYR9rm3aRv2hN033dv7kZyUiL1asYFHNOnJKyR2pjiscBhSixYO8XMP5xD2uY9pG/eQ9qm3fy4ZR85eU4rdcv6NYIet2fLekG3F5zDSg3GlC0LHKbEgrVTdHWHeI6LjqJTk9rceHYLUk4+ka4nn8CJNeKCPkUdCis1GFP2LHCYEsk8EPx5iAcvbEtK8xPo2Li23+GgrcRgTMVjgcN4cjQnj8U/7Wb+uh3MX7eT1dsOBN3/lnODD1ttJQZjKh4LHAYI3k7x1o3djgWKxZt2k52bT1x0FCnNT+CBAW154rPAczwYYyofCxwGCN5OcdHz8wFo3aAm13Y/mV7J9ejWoi7V4pyqJwscxlQtFjgM+UWMxfH04NM4u1U9Tqqd4He7tVMYU7VY4Kiijubk8e2GXcxetZ0vVm0Puu+grk2Cbrd2CmOqFgsclUhR4z7tPpTNnNWZzF65jfnrdnI4O4/qcdGc27o+ny7fFoEcG2MqIgsclUiwdorB4xaQvnkP+Qon1Urg8i6N6duuAd1PqUtCbHSJn6cwxlQdYQ0cIjIAeA6IBl5T1ccD7HcFMBU4Q1XT3HUPAjcBecCdqjrTyzHNbx3KymPU+cn0a9eAjo1rHTeEh7VTGGNCFbbAISLRwEtAPyADWCIi01V1ZaH9EoG7gEU+69oDQ4AOQCPgcxFp7W4u8pjmeDPu6hV0u7VTGGNCFRXGY58JrFfVjaqaDUwGLvOz36PAE4DvI8iXAZNVNUtVfwLWu8cL9ZhVyqKNuxj66sJIZ8MYU0WEM3A0Bn7xeZ/hrjtGRLoATVW1cAV7oLRFHtPn2CNEJE1E0nbs2FG8KyjnlmzazdXjF3LVqwtZlxl8UiNjjCkt4QwcQYlIFPAscG84jq+qr6pqiqqm1K9fPxyniJj0zbu55rVFDB73LWu3H+Dhi9sx/4/nBWyPsHYKY0xpCmfj+Bagqc/7Ju66AolAR2Ce21B7EjBdRFKLSBvsmJVa+uY9/Ovztcxft5O6NeJ46KJ2XNP95GNPcFs7hTGmLIQzcCwBkkWkBc7NfQhwdcFGVd0HHJtwQUTmAfepapqIHAHeFZFncRrHk4HFgAQ7ZkUX6DmMOtVi6dS0Dl+t3cGJNeJ48MK2XNvjZKrHWW9qY0zZC9udR1VzRWQUMBOn6+wEVV0hIn8D0lR1epC0K0RkCrASyAVGqmoegL9jhusaylqg5zD2Hsnhx4y9PDCgLdf1OJka8RYwjDGRI6rBxymqDFJSUjQtLS3S2ShSsIfwVvz1AgsYxpgyJSLpqppSeH3EGseNNxY0jDHlhQUOY4wxnljgKCcy9hyOdBaMMSYkFjjKgZ93HeaqVxYiAb
bbcxjGmPLEKs4jbNPOQwwdv5DD2Xn8546z6di4dqSzZIwxQVngiKCNOw4ydPxCsnPzee/m7rRvVCvSWTLGmCJZ4IiQ9ZkHGDp+Efn5ynsjutP2JAsaxpiKwQJHBKzZdoBhry0EhMkjupPcIDHSWTLGmJBZ43gZW7V1P0PHLyRKLGgYYyomK3GUoeVb9nHN64uoFhvNuzd3p0W9GpHOkjHGeGYljjLyQ8Zerh6/kBpxMbw/oocFDWNMhWUljjLw/c97uG7CYmpXi+W9m7vT9MTqkc6SMcYUmwWOUhZoaPQogc/+cA6N61SLQK6MMab0WFVVKQs0NHq+YkHDGFMpWOAwxhjjiQUOY4wxnljgMMYY44kFDmOMMZ6ENXCIyAARWSMi60VktJ/tt4rIjyKyVES+FpH27vph7rqCV76IdHa3zXOPWbAtKZzX4FWgIdBtaHRjTGURtjnHRSQaWAv0AzKAJcBQVV3ps08tVd3vLqcCt6vqgELHORWYpqot3ffzgPtUNeRJxMtyznFVpcc/5tC5aR3GXdu1TM5pjDHhEIk5x88E1qvqRlXNBiYDl/nuUBA0XDUAf1FsqJu2Qvhxyz627T9Kv/YNIp0VY4wJi3A+ANgY+MXnfQbQrfBOIjISuAeIA873c5yrKBRwgDdEJA/4ABirfopNIjICGAHQrFmz4uS/WGat2E50lHB+23JVg2aMMaUm4o3jqvqSWw31APCw7zYR6QYcVtXlPquHqeqpQC/3dW2A476qqimqmlK/fv0w5f54s1Zu44zmJ3BCDWvTMMZUTuEMHFuApj7vm7jrApkMDCy0bgjwnu8KVd3i/jwAvItTJVYubNp5iLXbD9K//UmRzooxxoRNOAPHEiBZRFqISBxOEJjuu4OIJPu8vRhY57MtCrgSn/YNEYkRkXrucixwCeBbGomo2Su3A1j7hjGmUgtbG4eq5orIKGAmEA1MUNUVIvI3IE1VpwOjRKQvkAPsAYb7HOIc4BdV3eizLh6Y6QaNaOBzYHy4rsGr2Su3065hLRv91hhTqYV1dFxVnQHMKLTuEZ/lu4KknQd0L7TuEFAu+7juPJhF2ubd3HF+ctE7G2NMBRbxxvHKYs6qTPLVqqmMMZWfBY5SMmvldhrXqUaHRrUinRVjjAkrCxyl4HB2LvPX7aBf+waISKSzY4wxYWWBoxTMX7eTrNx8+ls1lTGmCrDAUQpmrdhO7WqxnNHixEhnxRhjws4CRwnl5uUzZ/V2zm+bRGy0/TqNMZWf3elKKG3zHvYczrFqKmNMlWGBo4RmrdhOXEwU57Quu/GwjDEmkixwlICqMnvVNs5uVY8a8WF9ltIYY8oNCxwlsHrbAX7ZfcSqqYwxVYoFjhKYvXI7ItCnnQUOY0zVYYGjBGat3EaXZidQPzE+0lkxxpgyY4GjmLbsPcLyLfutmsoYU+VY4Cimz23uDWNMFWWBo5hmrdxGq6SanFK/ZqSzYowxZarIwOHOuneLiHwmIj+4r09F5FZ3QqUqZ9/hHBZt3G2lDWNMlRTKwwdvA3uBMUCGu64Jzmx97wBXhSNj5dncNZnk5qu1bxhjqqRQAkdXVW1daF0GsFBE1oYhT+XerJXbSEqM57QmdSKdFWOMKXOhtHHsFpHBInJsXxGJEpGrcOYJr1KO5uTx5Zod9G3fgKgom3vDGFP1hBI4hgCDgO0istYtZWwDLne3BSQiA0RkjYisF5HRfrbfKiI/ishSEflaRNq765uLyBF3/VIRGeeTpqubZr2IPC9lPHPStxt2cSg7z6qpjDFVVpFVVaq6CbcdQ0Tquut2FZVORKKBl4B+OFVbS0Rkuqqu9NntXVUd5+6fCjwLDHC3bVDVzn4O/TJwM7AImOHu/2lR+Skts1Zup2Z8DD1a1i2rUxpjTLniqTuuqu7yDRoi0i/I7mcC61V1o6pmA5OBywodb7/P2xqABju/iDQEaqnqQlVV4C1goJdrKIn8fOXzVds5t0194mOiy+q0xhhTrp
R0SNfXgWYBtjUGfvF5nwF0K7yTiIwE7gHigPN9NrUQke+B/cDDqjrfPWaGzz4Z7rrjiMgIYARAs2aBsujN0oy97DiQZdVUxoRBTk4OGRkZHD16NNJZqXISEhJo0qQJsbGhPWFRZOAQkemBNgElrq9R1ZeAl0TkauBhnG6+W4FmqrpLRLoC00Skg8fjvgq8CpCSkhK0JBOqWSu2ExMl9G6TVBqHM8b4yMjIIDExkebNm1PGTZdVmqqya9cuMjIyaNGiRUhpQilx9AKuAQ4WWi841VGBbAGa+rxv4q4LZDJO+wWqmgVkucvpIrIBaO2mb+LhmKVq1spt9GhZl9rVquRzj8aE1dGjRy1oRICIULduXXbs2BFymlACx0LgsKp+6eeEa4KkWwIki0gLnJv7EODqQumTVXWd+/ZiYJ27vj6wW1XzROQUIBnYqKq7RWS/iHTHaRy/DnghhGsosfWZB9m44xDX92xeFqczpkqyoBEZXn/vRTaOq+qFqjo3wLZzgqTLBUYBM4FVwBRVXSEif3N7UAGMEpEVIrIUp51juLv+HOAHd/1U4FZV3e1uux14DVgPbKCMelTNdgc17GtzbxhTaUVHR9O5c+djr02bNtGzZ89iH2/ixIn8+uuvpZjD8sFz47jbJXePquYXta+qzsDpMuu77hGf5bsCpPsA+CDAtjSgo5c8l4bZK7dxauPaNKpTraxPbYwpJGXsbHYezD5ufb2acaQ9HKyzZ3DVqlVj6dKlv1m3YMGCYh9v4sSJdOzYkUaNGhX7GOVRSN1xReQEEXlRRL7EeTbjUxGZICI1wpu98iFz/1G+/2Wv9aYKh6eSYUzt419PJUc6Z6Yc8xc0gq0viZo1nRGw582bR+/evRk0aBBt27Zl2LBhOE8FQHp6Oueeey5du3blggsuYOvWrUydOpW0tDSGDRtG586dOXLkCM2bN2fnzp0ApKWl0bt3bwDGjBnDjTfeSO/evTnllFN4/vnnj53/nXfe4cwzz6Rz587ccsst5OXllfo1ehVKr6o6OKWGP6nqKJ/15wGPi8gUYIVPVVKl8/mqTFShf4eTIp2VyudQprf1pkr4639WsPLX/UXv6MdVr3zrd337RrX4y6XBO2ceOXKEzp07A9CiRQs++uij32z//vvvWbFiBY0aNeKss87im2++oVu3btxxxx18/PHH1K9fn/fff5+HHnqICRMm8OKLL/L000+TkpJSZL5Xr17N3LlzOXDgAG3atOG2225j/fr1vP/++3zzzTfExsZy++23M2nSJK677rrQfhlhEkpV1Z+Bp1V1roi8DXQHdgL1gB9xelc9jNNGUSnNXrmNZidWp3UDm3vDmMrMX1WVrzPPPJMmTZyOnQVtIHXq1GH58uX06+dUkeXl5dGwYUPP57744ouJj48nPj6epKQktm/fzhdffEF6ejpnnHEG4AS2pKTIPw4QSuA4R1XvdZezgKGqmiYiXYDbgK+B58KVwUg7mJXLN+t3cV2Pk63HR2nIzYJtP0JGGmxJi3RuTDlVVMmg+ehPAm57/5YepZ2dY+Lj448tR0dHk5ubi6rSoUMHvv3Wf0nHV0xMDPn5TvNw4QcdAx17+PDh/OMf/yilKygdoQSOBBERd4iPLsAyd/1yoIuq5lfmG+pXa3eQnZdvkzYF81Sy/6qlGklw42dukEh3AsW2HyHPrYdOLKLBcNEr0GU4xCaUfp6NKSVt2rRhx44dfPvtt/To0YOcnBzWrl1Lhw4dSExM5MCBA8f2bd68Oenp6Vx44YV88IHf/j+/0adPHy677DLuvvtukpKS2L17NwcOHODkk08O5yUVKZTG8cVAH3f5/4BZIvIYTjfbV0TkDGBFmPIXcbNWbOPEGnF0PfmESGel/ArWTvFCF/hoBHz/DsRUg+63wZVvwz2r4N5VwY/76R+d9GkTILf0Gz1NxVWvZpyn9eEUFxfH1KlTeeCBBzjttNPo3LnzsZ5Y119/PbfeeuuxxvG//OUv3HXXXa
SkpBAdXfR4d+3bt2fs2LH079+fTp060a9fP7Zu3RruSyqSFPQKCLiD8wDeFOBiVd0uIvWAU4CNOIFnOjBcVYM9DBhRKSkpmpYWerVIuLr6VVpjagfedulz0DgF6reFaD8F3GCllSvGw5y/Q8ZiqNMMzn0AOg3xfxxT4a1atYp27dpFOhtVlr/fv4ikq+pxLfuhDKu+0R2IcLqIzMJ5kjwPuMh93Vueg0ZxlGVXvwpv67Lg27teH3z7/euCb29xLqz/HOaMhY9Hwvxnofdo6HgFRNkIxcZEQkhf3VR1kYj0wKmyOs1dvRAY6z4hbqqaQ7tgzqOQPjG85xGB5H7Qqi+smQFzH4MPb4b5zzgBZMYfA5dYigpKxphiCbnM7z4pPtt9maoqL9dpc5g7FrIOQrdbYdHL4T+vCLS9GFpfCKs+hrn/gH9fH3h/ew7EmLAJ5QHAAzgTLAm/nWhJAFXVWmHKmylvfvoKPn0AMlc6VUgXPgFJ7WD5B4G/9Ze2qCjo8Dtol+qc98ObS/8cxpigQmnjSCyLjJhybO/PMOvPsHKa00h95dvQ7lKnFACRqRKKioZOV1rgMCYCQilxnBhse2UcaqRezbiAvaoqrUC9m2JrQMF4luc9BD3vgNgKMtCj6v+CmzGm1ITSxpHO/6qqClOcrrmVSpXschuoTSDnkFM11O9RqNPU/z7l1ZuXwsXPQv3Wkc6JqUQef/xxmjZtyrBhw4rcd+LEidx///00btyY7Oxs7r77bm6+OfRS8rRp02jdujXt27f3lMfp06ezcuVKRo8e7SldqEKpqgptLkFTeQ2eGOkcBFYjyX/Qi0uEbT/Ayz3h7D9Ar3srTknJFC3Y8z9hrjqdOXMmU6ZMCXn/q666ihdffJHMzEw6dOhAamoqDRr8bySK3NxcYmL834qnTZvGJZdc4jdwBEuXmppKamqq322lwdOTVCJyAs5sfMfGgFDVr0o7U8aELNhN4uAOmPUQfPUU/DgVLn4GWvUJvL+pOMIwqvJTTz1FfHw8d955J3fffTfLli1jzpw5zJkzh9dff51Jkyaxf/9+srOzqV+/Pps2beLGG29k586d1K9fnzfeeINmzZoFPH5SUhItW7Zk8+bNPPDAAyQkJPD9999z1llnMXLkSEaOHMmOHTuoXr0648ePZ/fu3UyfPp0vv/ySsWPH8sEHH3DTTTfRuXNnvv76a4YOHUrr1q0ZO3Ys2dnZ1K1bl0mTJtGgQQMmTpxIWloaL774Itdffz21atUiLS2Nbdu28eSTTzJo0KBi/57AQ+AQkd8Dd+HM870UZ5Tcb4HzS5QDY8KlZn24/FXoPAw+uQfeudx5cPCCxyDRhsgv1z4d7YxrVhxvXOx//UmnwoWPB0zWq1cvnnnmGe68807S0tLIysoiJyeH+fPnc845zmSnn3/+OX36OF8+7rjjDoYPH87w4cOZMGECd955J9OmTQt4/I0bN7Jx40ZatWoFQEZGBgsWLCA6Opo+ffowbtw4kpOTWbRoEbfffjtz5swhNTWVSy655Dc3+uzsbApGwtizZw8LFy5ERHjttdd48skneeaZZ44799atW/n6669ZvXo1qampZRc4cILGGcBCVT1PRNoCj5Xo7MaUhVPOhdsWwNf/ch4cXDcb+jwCKTfa0+fmmK5du5Kens7+/fuJj4+nS5cupKWlMX/+/GMTK3322WfccMMNAHz77bd8+OGHAFx77bX88Y9/9Hvc999/n6+//pr4+HheeeUVTjzR6W80ePBgoqOjOXjwIAsWLGDw4MHH0mRlZQXM51VXXXVsOSMjg6uuuoqtW7eSnZ1Nixb+WxYGDhxIVFQU7du3Z/v27R5+K/55CRxHVfWoiCAi8aq6WkTalDgHJvIy0gNvC8ezGJEQEw+9H4BTBzmljxn3wWejId/PwAf21HnkBSkZAMHHR7sh8JDrwcTGxtKiRQsmTpxIz5496dSpE3PnzmX9+vXHxnBavHgxL7/s7YHXgjaOwmrUcC
ZQzc/Pp06dOkHnAfGXDpxSzz333ENqairz5s1jzJgxftP4Dtle1PiEoQhp6lhXhjsb4DRgtoh8DGwOlkBEBojIGhFZLyLHNe+LyK0i8qOILBWRr0Wkvbu+n4iku9vSReR8nzTz3GMudV+V5M4WIVkH4IOboHZTeGAzjNn321dlu4HWbQnXToMrXvcfNMCeOq/CevXqxdNPP80555xDr169GDduHKeffjoiwooVK2jbtu2xUW179uzJ5MmTAZg0aRK9evUq1jlr1apFixYt+Pe//w04N/Zly5wx4AoPy17Yvn37aNy4MQBvvvlmsc5fHCEHDlX9naruVdUxOLMCvg4MDLS/iETjzE9+IdAeGFoQGHy8q6qnqmpn4EngWXf9TuBSVT0VGA68XSjdMFXt7L7sv7wkPrkP9m6GK16DanUinZuyIeKUPEzFFagkXMIScq9evdi6dSs9evSgQYMGJCQkHAsIn376KQMGDDi27wsvvMAbb7xBp06dePvtt3nuueLPZzdp0iRef/11TjvtNDp06MDHH38MwJAhQ3jqqac4/fTT2bBhw3HpxowZw+DBg+natSv16tUr9vm9KnJY9WM7inTHmVv8gPu+FtBOVRcF2L8HMEZVL3DfPwigqn6nshKRocB1qnphofUC7AIaqmqWiMwD7lPVkMdJ9zqsepXxwxTnyeveDzoDBlY1wao7eoyC04bCSR3LLj9VXHkfVr1fv3689dZbxZoWtiLwMqy6l6qql4GDPu8PuusCaQz84vM+w11XOGMjRWQDTonjTj/HuQL4TlV9W4vecKup/iwBph8UkREikiYiaTt27AiSzSpq90/w33ugWQ/odV+kc1P+LBoH486CcWfDt//ndO01Vdrs2bMrbdDwykvjeMH0sYAzWq6IlHhGHVV9CXhJRK4GHsapmnJOKNIBeALo75NkmKpuEZFE4APgWuAtP8d9FXgVnBJHSfNZqeTlwAe/B4lyuqvaxEjHu3ctLJ8KS9+FmQ/C7D9Dcn+nFNL6AvhnRxvO3VRZXu4YG0XkTv5XyrgdZxbAQLYAvmNUNHHXBTLZ59iISBPgI5zqq2OVe6q6xf15QETeBc7ET+AwQcx73Jn/e/BEZ9DCqirQU+c1kqBGXeh2i/PKXOUEkB+mOHOCVDsBjuzxf0xrWDdVgJfAcSvwPE6pQIEvgBFB9l8CJItIC5yAMQS42ncHEUlW1YKvZxcD69z1dYBPgNGq+o3P/jFAHVXdKSKxwCXA5x6uwfw033mW4fRrnTGoqrJQSwZJ7aD/o9DnL7BxHiydBCs+DGvWqipVJUDtswkjr110vUzklIlz8w91/1wRGQXMBKKBCaq6QkT+BqSp6nRglIj0BXKAPfyvmmoU0Ap4REQecdf1Bw4BM92gEY0TNMaHmqcq7/Bu+HCE0yX1wicinZuKJzoGkvs6r2CBY+qNzjzrTVLgpE4Qm/Db7REcZ6k8S0hIYNeuXdStW9eCRxlSVXbt2kVCQkLRO7u8DDnSGqcqqYGqdhSRTkCqqo4NkqEZwIxC6x7xWb4rQLqxQKDjdg01z8aHKky/Aw7tgKGfQ1yNotOY4vl5oTPJFEBUrNMzqyCQNE4JyzhLlUGTJk3IyMjAOrOUvYSEBJo0aRLy/l6qqsYD9wOvAKjqD24bQ8DAYcqR9Ddg9X+h/9+hUedI56Zyu2cl7N/qtCNlpMGWdFj2HiyxwnEwBU9um/LPS+CorqqLCxUhAzx6a8qVzNXw2Z+g5fnQ/fZI56ZyCNawDlCrIdS61JkpESA/D3asdgLJf/z1Ojem4vASOHaKSEvcecdFZBCwNSy5MqUn56gzpEh8TRg4zpmz25Sc17aIqGho0MF5WeAwFZyXwDES57mItiKyBfgJKHoKLBNZn/8Fti+Hq/8NiQ2K3t9EVuYqpxeXMeWYl7GqNqpqX6A+0BY4Fzg7XBkzpWDtTOcJ6G63Qev+Re9vykbA8ZSiYHwfWPlxmWbHGK+KLHG4Y1KNxBku5G
OcLrAjgXuBH4BJ4cygKaYD22DabdDgVOj310jnxvgKVM21/1d4/1qYcp0z1e15D9l8IaZcKnKQQ3f49D04s/31AZIAAe5S1aXhzmBpqDKDHNrzARVfbpYzV8h3b0GrflVr1GJT7gQa5DCUNo5T3OHNEZHXcBrEm6nq0VLOoykpez6g4ouJh0ufh4ad4dMHYPx5MORda/cw5UoobRw5BQuqmgdkWNAwJoxE4Iyb4Pr/QtZBt91jeqRzZcwxoQSO00Rkv/s6AHQqWBaR/eHOoDFVVrPucMuXTmljyrXwxaPO8yDGRFiRgUNVo1W1lvtKVNUYn+VaZZFJY6qsWo3ghhnOoJTzn4b3hsCRvZHOlanibCIGY8q7mHhIfcEZKuaTe+GJk4/fxzpAmDJkjxFXFrnZzsRM/pRwHmZTDojAGb8PvN06QJgyZCWOymL5VNB8uOYDaNU30rkxxlRiVuKoDFRhwQuQ1AFa9ol0bkykbPzS+SwYE2YWOCqD9V9A5kroeYdTpWGqprdS4fX+sG62BRATVhY4KoMFz0FiI+h4RaRzYiLpoqedYUsmDYJXe8Oq/0J+fqRzZSohCxwV3a9L4aevoPutEBMX6dyYcAvU0aFGEpx5M9z5vdMD6+heeH8YjDsbln9oz3+YUmWN4xXdghcgLhG6Xh/pnJiyUFSX25g46HIdnHa1M33t/Kdh6g1Qr7UzcOKsP9t4ZqbEwho4RGQA8BwQDbymqo8X2n4rzki7ecBBYISqrnS3PQjc5G67U1VnhnLMKmXvz7DiI+hxOyTUjnRuTHkSHQOnXQWnDnKGaf/qafjolsD7h9Kd1wbRNK6wBQ4RiQZeAvoBGcASEZleEBhc76rqOHf/VOBZYICItAeGAB2ARsDnItLaTVPUMauOhS87jeHdbo10Tkx5FRUNHS+H9gNh7acw+erA+859DOJqQlwNiE/0Wa7plGptEE3jCmeJ40xgvapuBBCRycBlwLGbvKr6jnVVA3daWne/yaqaBfwkIuvd41HUMauMI3sg/U3oOAhqN4l0bkx5FxUFbS8Ovs+XT5RNXkyFF87A0Rj4xed9BtCt8E4iMhK4B4gDzvdJu7BQ2sbucpHHdI87AhgB0KxZM++5L+/SJkDOIacLrjGl4ZE9zmcq6yBkH4LsAz7LB525642hHDSOq+pLwEsicjXwMDC8lI77Ks4c6aSkpFSuTu25WbDoFWh5PpzUMdK5MZVFVJRTRRWf6H+7BQ7jCmd33C1AU5/3Tdx1gUwGBhaR1usxK6cfpsDB7dDzzkjnxFQ0wbrzGhOicJY4lgDJItIC5+Y+BPhNy5yIJKtqQXeMi4GC5enAuyLyLE7jeDKwGGfK2qDHrPTy850uuCedCqf0jnRuTEVTkt5PNZL8N4THVCv+MU2FFLbAoaq5IjIKmInTdXaCqq4Qkb8Baao6HRglIn1xZhncg1tN5e43BafROxcY6c4+iL9jhusayqX1s2HnGrh8vA0vYsqWv6Dz2YOw8P9g1X+g3aVlnycTEaJVYEyblJQUTUtLi3Q2SscbF8OeTXDXUoiOjXRuTFWXmw0T+sPujXDr11CnEnZEqcJEJF1VUwqvtyFHKpIt6bD5a+h+mwUNUz7ExMGgCU4V6tSbIC8n0jkyZcACR0Wy4AWIrw1dS6XjmTGl48RTIPU5yFgMc/8e6dyYMmCBo6LY/ZMzdETK9YG7SxoTKR2vgC7D4et/OsP8m0ot4s9xmBAtfBkkGrrdFumcGOPfgMfhl8XOmFi3fgOJDUrnuDZGVrljJY6K4PBu+P5t6HQl1GoY6dwY419cdRg80Xna/MPfl95Q7jZGVrljgaMiWPI65ByGHqMinRNjgktqCxc96cwRM//ZSOfGhIkFjvIu5ygsfgVa9YMG7SOdG2OKdvq1cOpgmPcYbF4Q6dyYMLDAUd79MBkO7YCzbHgRU0GIwCX/hBOawwe/d6pai2P7Cpg0uFSzZkqHBY7yLD8fFrwIDU+D5r0inRtjQh
efCIPecL70TLsNvDxovG8LTBsJL58FPy8Kvm9uVsnyaYrFAkd5tvYz2LXOGczQhhcxFU2jztDvUedzvPDlovc/shc+HwMvdIEfp0CPkc4ICcEGYJwy3Hl63ZQp645bni14Hmo3c2ZvM6Yi6naL01A++xFo1h0adzl+n9wspwPIV086E5R1ugrOewhOONnZHqjL7eLxMOM+Z071wRNtNIUyZGNVlSfWX91URod3w1MtQfOP3xZfC6rVgb0/wynnQb+/OlWzoVr4Mnw2Gjpc7gz8GW3fhUtToLGq7Ldcnlh/dVMZVT/Rf9AAyNoPdU6Gaz6EVn28H7v7bc74WLP/7JQ4Br7szLNuwsoChzEmsm75ypl9sLjOuhPysmHOoxAVC6kvlOx4pkgWOIwxkVUaN/lz7nNKHl8+7lRXXfKv0ulQYtXHflngMMZUDr1HOyWPr591Sh4XPVXy4GHVx35Z4CgP8vPhizGRzoUxFZsI9HnECR7fvui0eVzwmHVlDwMLHJGWc9R5QGrFh87czblHjt8nWD92YyqCQPOVl/ZnWwT6j4X8XGdK2+hY6PtXb8Ej66DThXjdrNLNWyUS1sAhIgOA53DmB39NVR8vtP0e4Pc484rvAG5U1c0ich7wT59d2wJDVHWaiEwEzgX2uduuV9Wl4byOsDm8GyYPg58XOB/us+6yb0emcirL9gARZ4j3vBz45jmIjofzHwq8vyrs2uAEinWzYPM3TqklrmbZ5bmCCVvgEJFo4CWgH5ABLBGR6aq60me374EUVT0sIrcBTwJXqepcoLN7nBOB9YBv+L9fVaeGK+9lYs8meGcQ7N0MV7wOpw6KdI6MqTxE4KKnnQDw1ZPOq7CEOs7DhutmwZ6fnHX12sCZIyC5PzTrAWPrl2m2K4pwljjOBNar6kYAEZkMXAYcCxxugCiwELjGz3EGAZ+q6uEw5rVsbfkO3r3S+UZ07TRoflakc2RM5RMVBZc+78xl48/RvfDdW9DiHGd4k+R+zsCMvgJVscXVKO3cVijhDByNgV983mcA3YLsfxPwqZ/1Q4DCA/v/XUQeAb4ARqvqcSOdicgIYARAs2bNPGQ7zNZ8ClNvhBr14PoZUL91pHNkTOVVVFffB36C2GqBtxeuYlN1RvxdPhVW/QfaXVryPFZA5eIpGRG5BkgBniq0viFwKjDTZ/WDOG0eZwAnAg/4O6aqvqqqKaqaUr9+OSluLh4Pk6+G+m3g919Y0DAm0oIFDX9E4LIXoXEKfDgCfl0almyVd+EMHFuApj7vm7jrfkNE+gIPAal+Sg5XAh+pak7BClXdqo4s4A2cKrHyLT/fGeRtxn1O3en1n0BN6yllTIUUWw2GvAvV68J7Q2D/1kjnqMyFs6pqCZAsIi1wAsYQ4GrfHUTkdOAVYICq+nuiZihOCcM3TUNV3SoiAgwEloch78UX6ElTgJSb4MInbSA2Yyq6xAYwdDJMuMAJHjd86sy5XkWErcShqrnAKJxqplXAFFVdISJ/E5FUd7engJrAv0VkqYhML0gvIs1xSixfFjr0JBH5EfgRqAeMDdc1FEuwJ0ovfsaChjFlLdCzIiV9huSkjnDFa7B1GUy71alZqCJsWPXSNqZ2kG37Am8zxlRMC16EWQ9Br/ugz58jnZtSZcOqG2NMOPQYCTvXwPynoV5rOO2qkh+znA+uWC56VRljTIUlAhc9A817wfRR8PPCkh+znA+uaIHDGGNKKiYOrnwLajd1hhHasynSOQorCxylLVwNccaY8q36iXD1FGeAxXeHwNH93o+Rlwsb5ha9X4RZG0dpKwf1j8aYCKnXyil5vHO5M0LE0MlF96TMy3FG4105DVb9F47sLpOsloQFDmOMKU2nnOt0vf/PXfBo3eO310iCu1e4weIjWP0JHNnjjMbbegB0GAjv+xu2r/ywwGGMMaWt6/VO4PDnUCY83QqO7oP4WtDmQmh/GbTsA7EJzj6BBldMqBOuHHtigcMYY8pam4ug/UBoeR
7ExB+/vXCVd/YhePks0Dxnoqn4yM4VYo3jxhhT1n43DtoM8B80/ImrAQP/D/b+4ox7F2EWOIwxpiI4uSd0vx3SXo94zysLHMYYU1H0+TPUbQUfjyped99SYoHDGGPCIRzPdMVWg4Hj4MCvzvhYEWKN48YYEw7heqar6RnQ8w745jlol+pMeVvGrMRhjDEVTe8/Qf22MP0O5xmQMmaBwxhjKprYBBj4MhzMhM/+VOant8BhjDEVUeMu0OseWPYurPm0TE9tgcMYYyqqc/4IDTo6T6kfLrsxrixwGGNMRRUT51RZHd4Fn/6xzE5rgcMYYyqyhp2ckseP/4aV08vklGENHCIyQETWiMh6ERntZ/s9IrJSRH4QkS9E5GSfbXkistR9TfdZ30JEFrnHfF9E4sJ5DcYYU+71ugdO6gT/vRsO7Qz76cIWOEQkGngJuBBoDwwVkfaFdvseSFHVTsBU4EmfbUdUtbP7SvVZ/wTwT1VtBewBbgrXNRhjTIUQHeuMf3V0H3xyb9hPF84Sx5nAelXdqKrZwGTgMt8dVHWuqh523y4EmgQ7oIgIcD5OkAF4ExhYmpk2xpgKqUEHOO9BZ0Ko5R+G9VThDByNgV983me46wK5CfDtU5YgImkislBEBrrr6gJ7VTW3qGOKyAg3fdqOHTuKdQHGGFOh9LwLGnVxSh0H/cznUUrKxZAjInINkAKc67P6ZFXdIiKnAHNE5EdgX6jHVNVXgVcBUlJStDTza4wx5VJ0DOzZ5Ew/+3Tyb7fVSCq1YVDCWeLYAjT1ed/EXfcbItIXeAhIVdWsgvWqusX9uRGYB5wO7ALqiEhBwPN7TGOMqbICzVnub0bBYgpn4FgCJLu9oOKAIcBv+oqJyOnAKzhBI9Nn/QkiEu8u1wPOAlaqqgJzgUHursOBj8N4DcYYYwoJW+Bw2yFGATOBVcAUVV0hIn8TkYJeUk8BNYF/F+p22w5IE5FlOIHicVVd6W57ALhHRNbjtHm8Hq5rMMYYc7ywtnGo6gxgRqF1j/gs9w2QbgFwaoBtG3F6bBljjIkAe3LcGGOMJxY4jDGmMgnHzIOFlIvuuMYYY0pJuGYe9GElDmOMMZ5Y4DDGGOOJBQ5jjDGeWOAwxhjjiQUOY4wxnogzikflJiI7gM3FTF4PKO7MKJFKG8lz2zVXjLSRPLddc8VIC85gs/WPW6uq9gryAtIqWtqKmm+7Zvt92TWXn7TBXlZVZYwxxhMLHMYYYzyxwFG0Vytg2kie2665YqSN5LntmitG2oCqROO4McaY0mMlDmOMMZ5Y4DDGGOOJBY4gRGSAiKwRkfUiMtpDugkikikiy4txzqYiMldEVorIChG5y0PaBBFZLCLL3LR/Lcb5o0XkexH5bzHSbhKRH93ZHNM8pq0jIlNFZLWIrBKRHiGma+Oer+C1X0T+4OG8d7u/q+Ui8p6IJHhIe5ebbkUo5/T3uRCRE0Vktoisc3+e4CHtYPfc+SKS4vG8T7m/6x9E5CMRqeMx/aNu2qUiMktEGoWa1mfbvSKi7vTQoZ53jIhs8fl7X+TlvCJyh3vdK0TkSY/X/L7PeTeJyFIPaTuLyMKC/w0R8TsZXYC0p4nIt+7/1n9EpFaAtH7vHaF+xjwJRx/fyvACooENwClAHLAMaB9i2nOALsDyYpy3IdDFXU4E1no4rwA13eVYYBHQ3eP57wHeBf5bjLxvAuoV8/f9JvB7dzkOqFPMv9k2nIeWQtm/MfATUM19PwW4PsS0HYHlQHWc6Qk+B1p5/VwATwKj3eXRwBMe0rYD2gDzgBSP5+0PxLjLTwQ6b5D0tXyW7wTGhZrWXd8UZ1rpzYE+MwHOOwa4L4S/j7+057l/p3j3fZKX9IW2PwM84uHcs4AL3eWLgHke0i4BznWXbwQeDZDW770j1M+Yl5eVOAI7E1ivqhtVNRuYDFwWSkJV/QrYXZyTqupWVf3OXT6AM1974xDTqqoedN/Guq+Qez+ISBPgYuA1T5kuIRGpjf
MP8zqAqmar6t5iHKoPsEFVvYwSEANUE5EYnCDwa4jp2gGLVPWwquYCXwKXB0sQ4HNxGU7QxP05MNS0qrpKVdcUldEAaWe5+QZYCDTxmH6/z9saBPicBflf+Cfwx0DpikhbpABpbwMeV9Usd5/M4pxbRAS4EnjPQ1oFCkoKtQnwOQuQtjXwlbs8G7giQNpA946QPmNeWOAIrDHwi8/7DEK8gZcWEWkOnI5Tcgg1TbRbhM4EZqtqyGmBf+H8M+d7SONLgVkiki4iIzykawHsAN5wq8leE5EaxTj/EAL8M/ujqluAp4Gfga3APlWdFWLy5UAvEakrItVxvkU29ZhfgAaqutVd3gY0KMYxSupG4FOviUTk7yLyCzAMeMRDusuALaq6zOs5XaPcarIJHqtdWuP8zRaJyJcickYxz98L2K6qXmZM+gPwlPv7ehp40EPaFfzvS+tgQvicFbp3lPpnzAJHOSUiNYEPgD8U+nYXlKrmqWpnnG+QZ4pIxxDPdwmQqarpxcmv62xV7QJcCIwUkXNCTBeDUzx/WVVPBw7hFKlDJiJxQCrwbw9pTsD5h2wBNAJqiMg1oaRV1VU4VTyzgM+ApUCelzz7OabioYRYGkTkISAXmOQ1rao+pKpN3bSjQjxfdeBPeAg0hbwMtAQ64wT7ZzykjQFOBLoD9wNT3NKDV0Px8AXFdRtwt/v7uhu3dB2iG4HbRSQdpwoqO9jOwe4dpfUZs8AR2BZ+G9mbuOvCTkRicf7wk1T1w+Icw63qmQsMCDHJWUCqiGzCqZY7X0Te8XjOLe7PTOAjnOq+UGQAGT6lo6k4gcSLC4HvVHW7hzR9gZ9UdYeq5gAfAj1DTayqr6tqV1U9B9iDU6fs1XYRaQjg/gxYfVLaROR64BJgmHtDKa5JBKg+8aMlTqBe5n7WmgDfichJoSRW1e3ul6N8YDyhf8bA+Zx96FbpLsYpWfttmA/ErdK8HHjfSzpgOM7nC5wvNyHnW1VXq2p/Ve2KE7A2BMmfv3tHqX/GLHAEtgRIFpEW7rfZIcD0cJ/U/Qb0OrBKVZ/1mLZ+Qe8YEakG9ANWh5JWVR9U1Saq2hznWueoakjfvt3z1RCRxIJlnMbXkHqVqeo24BcRaeOu6gOsDPXcruJ8C/wZ6C4i1d3fex+ceuGQiEiS+7MZzs3kXY/nB+czNdxdHg58XIxjeCYiA3CqJVNV9XAx0if7vL2M0D9nP6pqkqo2dz9rGTgNuttCPG9Dn7e/I8TPmGsaTgM5ItIapxOG15Fj+wKrVTXDY7pfgXPd5fOBkKu5fD5nUcDDwLgA+wW6d5T+Z6ykreuV+YVTb70WJ8I/5CHdezjF6Bycf4ybPKQ9G6co+QNO9cdS4KIQ03YCvnfTLidAr48QjtMbj72qcHqfLXNfK7z8vtz0nYE0N+/TgBM8pK0B7AJqF+Na/4pz01sOvI3b4ybEtPNxAtwyoE9xPhdAXeALnBvJ58CJHtL+zl3OArYDMz2kXY/ThlfwGfPbKypI+g/c39kPwH+AxsX5XyBIT7wA530b+NE973SgoYe0ccA7br6/A873cs3u+onArcX4O58NpLuflUVAVw9p78K5D60FHscd8cNPWr/3jlA/Y15eNuSIMcYYT6yqyhhjjCcWOIwxxnhigcMYY4wnFjiMMcZ4YoHDGGOMJxY4jPFIRPLktyPyjnbXzxNnNOVlIvJNwXMpIhInIv8SZ5TldSLysTsuWMHxThKRySKywR2uZYaItBaR5nL86K5jROQ+d7m7O3zGUnFGFB5Thr8GU4XFRDoDxlRAR9QZ1sWfYaqa5o7V9RTOMCiP4QwV0UZV80TkBuBDEenmpvkIeFNVh4AzjDbOeEK/HH/433gTuFJVl4lINM5IucaEnQUOY8LjK+AP7thMNwAtVDUPQFXfEJEbcZ4gViBHVY89Dazu4H/uQHXBJOE8LIZ7bK9P2xtTLBY4jPGumvx2Ep9/qGrhsYsuxXnCuRXwsx4/UGUa0MFdDj
awZMtC5zoJZ3RVcIYmXyMi83AGWnxTVY+GehHGFJcFDmO8C1ZVNUlEjuAMpXEHUNLZ1jb4nsu3HUNV/yYik3DGBbsaZ7yu3iU8nzFFssBhTOkapqrHps0Vkd1AMxFJVGdynQJdgYLpeQcV92SqugF4WUTGAztEpK6q7iru8YwJhfWqMiaMVPUQTiP2s24DNiJyHc5sg3PcV7zvxFci0klEehV1bBG52Gc+iWSc+UD2lu4VGHM8CxzGeFetUHfcx4vY/0HgKLBWRNbhzOL2O3XhjHLb1+2OuwL4B85MbUW5FqeNYynOqLHDChrgjQknGx3XGGOMJ1biMMYY44kFDmOMMZ5Y4DDGGOOJBQ5jjDGeWOAwxhjjiQUOY4wxnljgMMYY48n/A69wHWo+JstmAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 229 ms\n"
]
}
],
"source": [
"value1 = ft_recall_10  # fine-tuned from the pretrained model\n",
"value2 = cs_recall_10  # trained without pretraining\n",
"name = \"Recall@10\"\n",
"\n",
"length = len(value1)\n",
"# value2 has one fewer entry than value1, so its curve starts at epoch 1\n",
"plt.plot(range(length), value1, \"-s\", label=\"Finetune\")\n",
"plt.plot(range(1, length), value2, \"-s\", label=\"w/o Pretrain\")\n",
"\n",
"plt.xticks(range(length))\n",
"plt.ylabel(name)\n",
"plt.xlabel(\"EPOCHS\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 154,
"id": "attached-printer",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEGCAYAAAB/+QKOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA1E0lEQVR4nO3deXgUVfbw8e/JHiBhC2HXALIICgEiIAoyIMiog7soqCDOuAsuPxVHR3F5x93RGR13XFFQVMAVccENEQKiEkAFRA3DEhANCNnP+0dVoA3dne5OL1nO53n6SXdV3aqTTqdO17237hVVxRhjjKkqLtYBGGOMqZ0sQRhjjPHKEoQxxhivLEEYY4zxyhKEMcYYrxJiHUC4ZGRkaFZWVqzDMMaYOmXZsmXbVLWVt3X1JkFkZWWRm5sb6zCMMaZOEZEffa2zKiZjjDFeWYIwxhjjlSUIY4wxXlmCMMYY45UlCGOMMV7Vm15MxtQ2ObctYNuukv2WZzRJIveGkfXuuKb+sQRhjB81Odl6K+dveW04rjGeLEGYWi+W34iDOdmWllews6iMwj2lFBaV+t3v5+u2k5IYR0piPMkJzk/P5/6O++vuEvaUlrO7pJw9JeUUVT4vdV7Hkl011S+WIEytF+1v4hUVyi+7SyjYWex33yf/9zMKi8rYWVRK4Z4y9pQGfnI+8/HFAW9bVfYtC0Iu+9dncjmkfTqHtm/KIe2b0jo9Zb9t6uLVS129aqrtic0ShKnVqpvQ6v9e/oomyQmkpSTQJDmBxh7PmyQn+D1xPPv5BrYWFlOws5iCXcVs3VlEwc5itu0qobyi+om0GiUl0Do9hfSURNJTE0hPSSQtJYH01ETSUxL567O+7+x/4W8DKS6toLisnKLSCopKnSuB4rIKikor+Nd73/kse+PxPWmUFE9qUjypic7PRknOFUhqYjzD7/3IZ9kftu3i/TVbqHxbW6Ulc0g7J2H0at+UQ9s3rfZku7ukjO27SijYVcy2ncVs/71k709/PvqugA7NU2nfLJWUxPj91gd7slRVft1dypadRX6PW5vFqhoyUJYgTK3z6+4SPlu7nU++L+CT77f53XbR2m3sLC5jV3EZwU6OeOPcPOIEMpok0yrNefRsm+48b5JMZnoKF89Y7rP8838dGNwBPQzukuF3vb8EMenITiEf9/2rhvF7cRmrNhWycuNvrNzo/PzouwICyIn0vPEddvuoxkpL8X86mTB9yd7nGU2SaN+8ER2apTpJo3mq35PlIx+tY0thEVsLi9lSWMSWnUVsKSympKyi2phPePBTurdJo3ubdA5uk0b3Nmm0bJK8d31NT7Shlq/uy8/vxWU0SopHRLyuj8ZVkyUIExX+/okWTR3B8p928Mn3BXz6/Ta+3vgbqs4JZ3CXlmz8dY/P/S66bgTg/LPtLinn9+IyJ2EUOUlj/BNf+Cy79PqjadE4ifg47/+ANZXRJMnn7xxJ1R23cXICh2W14LCsFnvX7SkpZ/XmQvI2/sY/5ub53PeZAw4go0kyLZsk0cr9Wfk6OSGerKlv+iz70gWHs/HX3WzcsYf8HXvY+OseVm8qZMHqLdWe6O94ew1NkhPITE+mdVoK/Q9oTuv0FDLTU2iTnsIlL/hO5I2TE3h/9VZeys33eC+S6dEmjR5t0mp8ovVXftG6bU5CcxNb1QTnT6+b5hMnTvxpyQk0qXKVHA2WIExU+Psnyr7lXXaXlBMfJ/Tt2IwpI7oypGsr+nRoSkJ8nN+TTiURobH7z5MZYEyt0pKr3aYmJ/maXOZH+7ipSfH0O6A5/Q5o7jdB/OP4nkHvu9KATi2AFvstr6hQtv1ezID/977PsitvPoYmyb5PV5e84Pu4L/xtEAAFO4v5dvNO1mwuZM3mnXy7eSfPLfY5Th0Ao/71EXEiiAgCxMXxx9fVfLcY9/i+LyiNk+Jp3TTlDwnu0Y/X+yz792N7sKvoj194dhWXsbOojE2/RadazR
KEiajS8gp+3P67321O6deBI7tmcHiXlqSnJO63PlbfxKFmJ/m6eNyaCuVvFRcnZKbt31juyV9yCPS4ldWIR3bdV71XXqF0+ftbPvfbOaMJilKhzlWq58+KAOo0X/jbQFqnp9A6PcXr7+AvQZw/tIvffQfyxammLEGYgPmrJvr02uGsL/idtQW7WLtlJ99v3cX3W3exYdvvlFVTuX3riYf4XR+rb+INVayummoi1ONWV734yNn9q92HvxN1dW1Ntf3zaQnCBMxfNVHPG9/Z28gZJ3Bgy8YclNmEkT1b0zWzCVe+9FUUI92nrn4Tj6VYvWe1/WQZCbX9y48lCBOQ6hoRLx3ela6ZTTgoswmdMhrv140xVgnC1B11NTHFKrFF4/2yBGH82lJYxIwvfuKFL37yu92VI7v5Xd8Qvx2auqGmJ9r6fJVqCcLsR1VZumEHz3y+gfkrN1OuyrBurfjw24KQ91mf/4mMqa8sQZi9dpeUMXfF/3hm0QbWbN5JekoC5x6RxVmDDuTAlo2j0mvCGFN7WIJoYHz1REpNjCMxPo7CojIObpvOHScfygnZ7UlN2teWYNVExjQsEU0QIjIaeACIB55Q1Tu8bHM6MA1Q4CtVHecunwDc4G52m6o+E8lYGwpfPZH2lFYw4uDWTBicRc6Bzb3e3m/VRMY0LBFLECISDzwEjATygaUiMk9VV3ls0xW4DjhCVXeISKa7vAVwE5CDkziWuWV3RCrehmDbLv+39j84rl+UIjHG1AWRvIIYAKxV1fUAIjITOAFY5bHN34CHKk/8qrrVXX4MsEBVf3HLLgBGAy9GMN56Z/uuYpb88Aufr9/O4vXb+W7LrliHZIypQyKZINoDP3u8zgeqDn/ZDUBEPsOphpqmqu/4KNu+6gFE5HzgfIADDjggbIHXZv7uZn73iqNY8sN2Fq//hc/XbefbLTsBaJQUT05WC07s25673vk22iEbY+qoWDdSJwBdgWFAB+BjETk00MKq+hjwGEBOTk6Qgz3XTf7uZu53qzORTGpiPDlZzRmT3Y5BnVvSu0NTEuPjACxBGGMCFskEsRHo6PG6g7vMUz7whaqWAj+IyHc4CWMjTtLwLLswYpHWE/83qhuHd2nJoe2bkZQQ53Ub64lkjAlUJBPEUqCriHTCOeGfAYyrss0c4EzgKRHJwKlyWg+sA/4pIs3d7UbhNGYbPy4d3rXabawnkjEmUBFLEKpaJiKXAvNx2hemq2qeiNwC5KrqPHfdKBFZBZQDV6vqdgARuRUnyQDcUtlg3ZD9urt2z69rjKlfItoGoapvAW9VWXajx3MFrnQfVctOB6ZHMr665Kftu5n49JLqNzTGmDDxXlFtapUVP//KyQ9/xvZdJTRN3X9CHbA2BGNM+MW6F5Opxvy8zUyZ+SWZaSk8de5hdGnVJNYhGWMaCEsQtdj0T3/g1jdX0adDM56YkENGk+rnUDbGmHCxBFELlVcot725iqc+28AxvVpz/9i+fxg0zxhjosESRC2zp6ScKTO/5N1VW5h0RCeuP+7gaufNNcaYSLAEUYts21XMec/k8nX+r9z0l56ce0SnWIdkjGnALEHUEusKdjHxqSUU7Czm0bP6M6pXm1iHZIxp4CxBxICvAfcEeO2SI8ju2CzqMRljTFV2H0QM+BpwT8GSgzGm1rAEYYwxxitLEMYYY7yyBGGMMcYrSxBR9uVPNq22MaZusAQRResKdjHp6aX4uu/NBtwzxtQm1s01SrYUFnHOk0uIE+GDq4aRldE41iEZY4xfliCioLColAnTl7Bjdwkzzx9kycEYUydYFVOEFZWW87dnclm7dRePnNWf3h2axTokY4wJiF1BRFB5hXLFrBV88cMv3D82m6HdWsU6JGOMCZhdQUSIqjJtXh5vr9zMDccdzIl928c6JGOMCYoliAj5zwdreW7xj1wwtDN/HdI51uEYY0zQLEFEwItLfuK+Bd9xct/2XDu6R6zDMcaYkFiCCLN38zZz/WvfcF
S3Vtx5am/ibLIfY0wdZQkijHI3/MJlL37Joe2b8t/x/UiMt7fXGFN32RksTL7bspNJTy+lXbNUpk88jMbJ1kHMGFO32VksRL4m/UmIE1o2SY5BRMYYE152BREiX5P+/LK7NMqRGGNMZFiCMMYY45UlCGOMMV5ZgjDGGOOVJYgQbNj2e6xDMMaYiLMEEaSi0nIunrEcX7e/2aQ/xpj6wrq5Bunm1/NYtamQ6RNzGN6jdazDMcaYiLEriCC8siyfF5f8zMXDulhyMMbUe5YgArRmcyHXz/mGgZ1acOXIbrEOxxhjIs4SRAB2FZdx8YzlpKUk8p9xfUmwMZaMMQ2AnemqoapMfeVrNmz7nX+f0ZfMtJRYh2SMMVFhCaIazy3+kTe+3sRVo7pzeJeWsQ7HGGOiJqIJQkRGi8i3IrJWRKZ6WT9RRApEZIX7+KvHunKP5fMiGacvK37+lVvfWMXwHplcdFSXWIRgjDExE7FuriISDzwEjATygaUiMk9VV1XZdJaqXuplF3tUNTtS8VVnx+8lXDJjOZlpKdx3eh+b+McY0+BE8gpiALBWVderagkwEzghgscLm4oK5cqXVrB1ZxH/Hd+PZo3s5jdjTMMTyQTRHvjZ43W+u6yqU0TkaxGZLSIdPZaniEiuiCwWkRO9HUBEzne3yS0oKAhb4A9/tI4Pvy3gH8f3pE/HZmHbrzHG1CWxbqR+HchS1d7AAuAZj3UHqmoOMA64X0T2awRQ1cdUNUdVc1q1ahWWgD5ft5173/2Wv/Rpx9mDDgzLPo0xpi6KZILYCHheEXRwl+2lqttVtdh9+QTQ32PdRvfnemAh0DeCsQKwtbCIy178kqyMxtx+8qGIWLuDMabhimSCWAp0FZFOIpIEnAH8oTeSiLT1eDkGWO0uby4iye7zDOAIoGrjdliVlVdw2Ytf8ntxGY+c1Z8mNqe0MaaBi9hZUFXLRORSYD4QD0xX1TwRuQXIVdV5wGQRGQOUAb8AE93iBwOPikgFThK7w0vvpxrzNa/0uMcXk3vDyHAfzhhj6pSIfk1W1beAt6osu9Hj+XXAdV7KLQIOjWRs4HteaV/LjTGmIYl1I7UxxphayhKEMcYYr6wl1kTH3V3h9637L2+cCVd/H/14jDHVsisIEx3ekoO/5caYmGvQCcLX/NE2r3QYlZXA/76MdRTGmBA06Com68oapOqqicrLoGCNkxD+t9z5uSUPyq1XmDF1UYNOECZI/qqJnhgJm7+Bsj3OsuR0aNsHBl0E7frCyxN977esGBKSwx6uMaZmLEGY8IhLgJxJTjJo1xdadIY4jxpMfwni8RFwyhOQ2SPiYRpjAhdQghCRY4AT2Tca60Zgrqq+E6G4TG1SXgorX/G/zaS3/a9vnOn9CiSlKez8Hzx2FBzz/yDnPLAxsIypFapNECJyP9ANeBZnyG5wBt6bLCJ/VtUpkQvPxFTxLlj+LHz+EBTmV7+9P/66su7cAnMugjevgu/fgxMehMYZNTueMabGAunFdKyqHquqM1X1U/cxEzgOODbC8ZlY2FUAH9wG/+oF86+D5gfCuJcid7y01jB+NhxzO6x7Hx4eDGvfj9zxjDEBCaSKqUhEDlPVpVWWHwYURSAmEyvb18HnD8KKF5yG4x7HwRGXQ8fDnPW+qokaZ9b82HFxcPjF0GkovHIePH8yDLoEjr7JGrCNiZFAEsRE4GERSWNfFVNH4Df2jb5q6gpfXVXjk6Gi1Gls7nMmDL4MMrr+cZto3PHc5hA4fyEsuBEWPwQ/fASnPGkN2MbEQLUJQlWXAwNFpA0ejdSqujmikZnI8NVVtbwYjrwCBl4IaW2iG1NVialw7N1w0NEw52L470Dv29kwHcZEVEB3UrvJAVVdBvwEDBaRnpEMzMTA0dNinxw8dTsGLlrke70N02FMRAXSi+kCYKrzVO7EqVZaCdwuInep6pORDdE0aGmtYx2BCaPS0lLy8/MpKrLmy2hLSUmhQ4cOJCYmBlwmkDaIS4
FeQCrwI3CQqm4WkebAh4AlCBM7W/Kgda9YR2EClJ+fT1paGllZWTbnexSpKtu3byc/P59OnToFXC6QKqZSVd2tqtuBdZVtD6q6A9DQwjUx8W09vK/x4cHw7Imw9j1Q+zjWdkVFRbRs2dKSQ5SJCC1btgz6yi2QBKEiUnlNcpzHAVMCLG9qg19/gtcucHopeROOrqqxMOJG2Loanj/FSRZfPu900TW1liWH2AjlfQ+kiukk3CsFVfW8nbYlcFXQRzTRV1YCL58LFeVwyRJo2SXWEQXH3/0XQ66Cwy+DlbNh0YMw9xJ4/xYY8Ddn2I5GLaIfr6nV4uPjOfTQfVPez5kzh3HjxrFokZ8OEX48/fTTjBo1inbt2oUrxFojkG6uP3m+FpGWwA5V3YgzJpOp7d6bBhtz4bRn6l5ygOq7siYkQfY45/6N9Qudm/0+uA0+vtdZXznCrCfrIlvr5dy2gG279h8qPqNJUo2G6k9NTWXFihV/WBZqcgAnQRxyyCH1MkEE2s21uYg8KCIfAQ8Bb4vIdBFpHNnwTI2tfsO54WzA+dDrxFhHE1ki0OVPcNYrcPFiOPQU78kBrItsHeAtOfhbXhNNmjQBYOHChQwbNoxTTz2VHj16MH78eNRt21q2bBlHHXUU/fv355hjjmHTpk3Mnj2b3Nxcxo8fT3Z2Nnv27CErK4tt27YBkJuby7BhwwCYNm0akyZNYtiwYXTu3Jl///vfe4///PPPM2DAALKzs7ngggsoLy8P++8YikC6uTYD3gL+rqqXeiz/E3CHiLwE5KnqLxGL0oTmlx+cG83a9YVRt8U6mujKPBhOeMhpkzC10s2v57Hqf4UhlR376Odel/dsl85Nf/Hfq23Pnj1kZ2cD0KlTJ1577bU/rP/yyy/Jy8ujXbt2HHHEEXz22WcMHDiQyy67jLlz59KqVStmzZrF9ddfz/Tp03nwwQe55557yMnJqTbuNWvW8OGHH7Jz5066d+/ORRddxNq1a5k1axafffYZiYmJXHzxxcyYMYNzzjknsDcjggJpg/gHcI+qfigizwGDgG1ABvANIMANwJURi9IEr6x43xwMpz1t4xkZ4/JWxeRpwIABdOjQAYDs7Gw2bNhAs2bNWLlyJSNHOlVb5eXltG3bNuhjH3fccSQnJ5OcnExmZiZbtmzh/fffZ9myZRx2mDPm2Z49e8jMrB2dRgJJEENVtbIxuhg4U1VzRaQfcBHwKfBApAI0IXr3Bti0AsbOgOZZsY7GmP1U900/a+qbPtfNuuDwcIezV3Lyvi9T8fHxlJWVoar06tWLzz/3fuXiKSEhgYqKCoD9upX62veECRO4/fbbw/QbhE8gbRApsq9/VD/gK/f5SqCfqlZEJDITurzXYMljzmioBx8f62hqr/LSWEdg6oju3btTUFCwN0GUlpaSl5cHQFpaGjt37ty7bVZWFsuWLQPglVeqmWgLGDFiBLNnz2brVqdd7JdffuHHH38M968QkkASxBJghPv8v8C7IvJPYD7wqIgcBuRFKD4TrO3rYO5l0D7HGVupofN3f8f8v0cvDhO0jCZJQS2PpKSkJGbPns21115Lnz59yM7O3tvzaeLEiVx44YV7G6lvuukmpkyZQk5ODvHx8dXuu2fPntx2222MGjWK3r17M3LkSDZt2hTpXykgotXcfSoinYGXgONUdYuIZACdgfU4CWYeMEFVv410sP7k5ORobm5uLEOIvdIiePJo+PVnuPATaHZArCOqveZf73SHHfMg9Ds71tE0GKtXr+bggw+OdRgNlrf3X0SWqarXFvZA7oNYLyKXAPNE5F1gMVCOM5vcscBVsU4OxjX/Otj8DZw5y5JDdY6+2RnH6c0roVWPfZMiGWP2Cug+CFX9Ajgc+Bg4GDgEJ1EMVtVPIheeCdg3syF3OgyeDN1Hxzqa2i8+AU6dDuntYNZZUFg7LumNqU0C6cUEgNsYvcB9mNpk2/fw+hToONAZm8gEplELOONFeOJoJ0lMfB
MSU8K3f1+z99ld3KaOqPYKQkR2ikihx89Cz9fRCNL4UbrHud8hPglOfQriAx/r3QCte8JJjzhDkbx5VXhHhPV1t7bdxW3qiEDaINKiEYgJkK9vpSnNoGn7/Zeb6vUcA0OvgY/vgrZ9YOD5sY7ImFohkKE2/A6HaUNsRJmvb59Fv0Y1jHpnmNvA/85UZ5iOTkNqtr/ysvDEZUwMBdJIvQzIdX9WfTTwfqWm3oiLg5Mfc0a7fXmCM39GKFRh1Tz476Dwxmei6o477mDGjBkBbfv000/TqlUrsrOz6dmzJ48//nhQx5ozZw6rVq0KOsZ58+Zxxx13BF0uGIFUMQU+P50xdVlKutNo/fhwmDkOJr0LSY0CL7/hU1hwk9OekdE9cnE2FDFs5J8/fz4vvfRSwNuPHTuWBx98kK1bt9KrVy/GjBlD69b75lMvKysjIcH76XbOnDkcf/zx9OzZc791/sqNGTOGMWPGBBxjKIKaEc4d9nuAiAytfFSz/WgR+VZE1orIVC/rJ4pIgYiscB9/9Vg3QUS+dx8TgonTmJBlHASnPAGbVzqTDwXSaL15Jcw4DZ4+Dgr/B2P+Axct8n0Xd3ySM3mT8S8Cjfx333333mG2r7jiCoYPHw7ABx98wPjx4wEoLCykpKSEVq1asWHDBoYPH07v3r0ZMWIEP/3k/8oyMzOTLl268OOPP+69w3rgwIFcc801rFu3jtGjR9O/f3+GDBnCmjVrWLRoEfPmzePqq68mOzubdevWMWzYMC6//HJycnJ44IEHeP311xk4cCB9+/bl6KOPZsuWLYBz5XLppc4A2xMnTmTy5MkMHjyYzp07M3v27JDfI08Bd3N1T95TgA7ACpxRXT8HhvvYPh5n7oiRQD6wVETmqWrVa6lZnsOIu2VbADcBOTiz2S1zy+4INF5jQtZtlNNd+P2boW1vOPIK79vt+BE+/Cd8Pcu5+hh5izPvRmKqs97bt9wvHoO3r3bu5P5zZKsHar23pzrtPqF46jjvy9sc6vd9HTJkCPfeey+TJ08mNzeX4uJiSktL+eSTTxg61Pm++9577zFihDO60GWXXcaECROYMGEC06dPZ/LkycyZM8fn/tevX8/69es56KCDAMjPz2fRokXEx8czYsQIHnnkEbp27coXX3zBxRdfzAcffMCYMWM4/vjjOfXUU/fup6SkhMqRIXbs2MHixYsREZ544gnuuusu7r333v2OvWnTJj799FPWrFnDmDFj/rC/UAWcIHCSw2HAYlX9k4j0AP7pZ/sBwFpVXQ8gIjOBE4BAKtuOARZUNoCLyAJgNPBiEPHWPxUVEJcIFV4Gmaurc0rXVkde4Zy83pvmPKpKTHWuAiQOjpgCR14Oqc2r3+/A82HHD7D4v84ou4MuDG/cxq/+/fuzbNkyCgsLSU5Opl+/fuTm5vLJJ5/svbJ45513OPfccwH4/PPPefXVVwE4++yzueaaa7zud9asWXz66ackJyfz6KOP0qKF07fntNNOIz4+nl27drFo0SJOO+20vWWKi33PnT527Ni9z/Pz8xk7diybNm2ipKSETp281/qfeOKJxMXF0bNnz71XGTUVTIIoUtUiEUFEklV1jYj4q2htD/zs8TofGOhlu1PcqqrvgCtU9WcfZa0P59InnORw4sPOFJsmckTghAch71Xv60v3QL9z4KipwXcvHnWb0wj+zlRnSJQex9Y83rqouiuoaU19rzvX91Dg/iQmJtKpUyeefvppBg8eTO/evfnwww9Zu3bt3jGKlixZwsMPPxzUfivbIKpq3NiZdLOiooJmzZr5nYfCWzlwrmKuvPJKxowZw8KFC5k2bZrXMp5DiVc3xl6ggmmDyHdnl5sDLBCRuUBNx6R9HchS1d44d2g/E0xhETlfRHJFJLegoKCGodRyOzY432QPOtqZe9lEXlI1M+qO+U9o957ExcPJjzsz/b1yHmxcHlp8JiRDhgzhnnvuYejQoQwZMoRHHnmEvn37IiLk5eXRo0
ePvaOwDh48mJkzZwIwY8YMhgwJrftzeno6nTp14uWXXwacE/hXXzkzJ1QdLryq3377jfbtnc/ZM88EdYqssYAThKqepKq/quo0nFnmngRO9FNkI9DR43UHd5nnPrerauV11hNA/0DLuuUfU9UcVc1p1apVoL9K3aMK8yY71RnH3+98uzV1W1IjGDcLGmfAC2ND71Zbn/mqNq1hdeqQIUPYtGkThx9+OK1btyYlJWXvif/tt99m9Oh9Y5n95z//4amnnqJ3794899xzPPBA6HOjzZgxgyeffJI+ffrQq1cv5s6dC8AZZ5zB3XffTd++fVm3bt1+5aZNm8Zpp51G//79ycjICPn4oah2uO+9G4oMwpl7eqf7Oh042B3Iz9v2CTjVRiNwTu5LgXGqmuexTVtV3eQ+Pwm4VlUHuY3Uy3AmKAJYDvT3d1NevR7ue9nTzlhLx/8LcibFOpqGxV81x7Tfar7/rWvgyVGQ3hYmzYfUZjXfZy1W24f7HjlyJM8++2xI04nWBcEO9x1MFdPDwC6P17vcZV6pahlwKc7EQquBl1Q1T0RuEZHKzruTRSRPRL4CJgMT3bK/ALfiJJWlwC0N9o7t3/Jh/g2QNQT6TYx1NCbcMnvA2OeciZ5eOhvKSmIdUYO2YMGCepscQhFMI7Wox+WGqla4Vwk+qepbwFtVlt3o8fw64DofZacD04OIr/5RhTeuAC136rvjgrptxYRD40zfN2uFS+ejYMy/Yc5F8MblcMJDVo1oaoVgEsR6EZnMvquGi3FmlTOR8vUs+P5dGH0ntLAb2mMiWsNyZ49z7qv46A5o3gmOujo6xzXGj2C+kl4IDMZpT6jssmrDXkbKzi3w9rXQcZBz85Wp/4ZNhd5nwIe3wdeBD/NQ14SrC6YJTijvezATBm0Fzgj6CCZ4qs5UmKV7nL74VrXUMIg4VYmFG51hPtLbQdaRsY4qrFJSUti+fTstW7ZErBotalSV7du3k5IS3IRYwQy10Q2neqm1qh4iIr2BMap6W3ChmmrlvQZr3nDmTc7oGutoTDQlJDmN1nd1ccZ2qqqOz0bXoUMH8vPzqff3LdVCKSkpdOjQIagywbRBPA5cDTwKoKpfi8gLgCWIcPp9G7x1NbTrB4dfWv32pv5Jbe50TPCmjs9GV3kns6kbgqm7aKSqS6oss1lRwu3ta6HoN6cnS3ww+dsYY8IrmASxTUS64IyuioicCmyKSFQN1Zo3YeVsOOoaZ65kY4yJoWC+ol4CPAb0EJGNwA/A+IhE1RDt2QFvXAmtD/U9vLQxxkRRML2Y1gNHi0hjnCuP3Ti9mmo6YJ8BZ36A3wuc8XniE2MdjanNSnYHN9OdMSGqtopJRNJF5DoReVBERuIkhgnAWuD0SAfYIHz/HqyY4cwp0C471tGY2sDfndozz3S6QBsTYdUO1ucO670DZ/a4EUAmIMAUVV0R6QADVacG6/M5124ruHpt9OMxdceKF2DOxXDQCBg7AxKD69duTFX+BusLpIqps6oe6u7oCZyG6QNUtSiMMTYsPufatb7hphrZ46CiDOZdBi+d49wzkZBcfTljQhBIL6a981uqajmQb8nBmBjqdw4cdx98Px9ePhfKvUxBa0wYBHIF0UdECt3nAqS6rwVQVU2PWHTGGO8OO8+ZE/vtq2H2JDh1unVuMGFXbYJQ1fhoBGKMCdLA8505yuf/HV4935nG1G6uNGFkn6ZoKyqsfhtjAnX4JU4V03s3QVwCnPSIM+e1MWFgCSKayophlp97C8M5CY1pOI683LmS+OA2p5ppjI0AbMLDEkS0VFTAaxfADx/DSY9Bn7GxjsjUJ0OvhvIyZ8KhuHg4/gFLEqbGLEFEgyq8c60zjPfIWy05mMgYNtXpAvvJPbD82f3X1/Ghwk302VeMaPjkXljymDN89xGTYx2Nqa9EYPgNvtfX8aHCTfRZgoi05c/BB7fCoac7Vw/GRJLN0mbCyKqYIunbd+D1KdBluDO/g9UJm1h7dC
hkdHMfXZ2fLbr8ccgOn0PBWBVVQ2MJIlJ+XgIvT4Q2h8LpzzpTSRoTa40y4Kcv4JuXPRYKND9wX+LwORSMVVE1NJYgIqHgW3jhdEhvC+NnQ3JarCMyxnH2q87Pkt2wfS1s+w62fb/v5w8fxzY+U6tYggi33zbCcydDXCKc9So0aRXriExD0zjTdxVRpaRG0La38/BUUQG3NI9sfKbOsAQRTnt2wPOnOHNKn/smtLDJ2U0M1KSdwNrJjAf7NIRL6R548Uznsv2MGdC2T6wjMib8dv8S6whMFNkVRKh89fRITofOR0U/HmPCxVcVFcD0Y5x2teYHRjcmExOWIELl6x+o2AbjM3WcryqqDZ85050+ORLGv2xXyQ2AVTEZYwKTdQRMehfik+CpY2Hte7GOyESYJQhjTOAye8B5C6B5J5hxOnz5fKwjMhFkCcIYE5z0tnDuW9BpKMy9BBbe6QxIaeodSxDGmOClpDvtEH3OhIX/hNcnO8ONm3rFGqlDFcjNSMbUZ/GJcOLD0LQDfHw37NwMpz4FyU1iHZkJE0sQobJBy4zZN8R4ent480q480BnToqqbKC/OsmqmIwxNZdzLpzxovfkADbQXx1lCcIYEx7dR8c6AhNmEU0QIjJaRL4VkbUiMtXPdqeIiIpIjvs6S0T2iMgK9/FIJOM0xhizv4i1QYhIPPAQMBLIB5aKyDxVXVVluzRgCvBFlV2sU9XsSMVnjImyR4dCvwlw6GlOLyhT60XyCmIAsFZV16tqCTATOMHLdrcCdwJFEYzFGBNrFeVOQ/a93WHOJc6kWnb/RK0WyV5M7YGfPV7nAwM9NxCRfkBHVX1TRK6uUr6TiHwJFAI3qOonVQ8gIucD5wMccMAB4YzdGBMKf92/L/wUNi6H5U/DN6/Aiuchsyf0Owd6j4WHBtpUp7VMzLq5ikgccB8w0cvqTcABqrpdRPoDc0Skl6r+YSQ8VX0MeAwgJyfHvooYE2vVncg79Hcex/wTVr4Cy56Bd6bCgpugvNh7GesBFTORrGLaCHT0eN3BXVYpDTgEWCgiG4BBwDwRyVHVYlXdDqCqy4B1QLcIxmqMiabkNOg/Ec7/0Lmy6HdOrCMyXkQyQSwFuopIJxFJAs4A5lWuVNXfVDVDVbNUNQtYDIxR1VwRaeU2ciMinYGuwPoIxmqMiZU2h8Jx98Q6CuNFxBKEqpYBlwLzgdXAS6qaJyK3iMiYaooPBb4WkRXAbOBCVbWprIxpqGwmu5gQrSe9CHJycjQ3NzfWYRhjQjWtqe91qS3g6GnQ92ybNzvMRGSZquZ4W2fvtDGmdvA10GVqC2jV3Rkxdvoo2PRVdONqwGywPmNM7eCvB5QqfDUTFvwDHhsGh/0V/nQ9pDaLVnQNkl1BGGNqPxHIPhMuzYWc82DpE/BgDqx40W62iyBrgzDG1D3/W+Hclb1xGRwwGArWwB4vDdl2k121rA3CGFO/tMuG896DvzwABau9Jwewm+xqyBKEMaZuiotzbra7dFmsI6m3LEEYY+q2xi39r//8v05VVHlpdOKpR6wXkzGmfpt/nfMzsRG07w8HDIKOg6DjYZDSFO7uaoME+mAJwhhTv125Gn5aDD9/4fz85D7QckCc0WR9tVNY+4UlCGNMPeBvmPH0dnDIyc4DoHgXbMyFn76AnxfD1rzoxlqHWIIwxtR9wVQFJTeBzsOcB/gf4qOBs0ZqY4wxXlmCMMYYXzYuj3UEMWUJwhjTsPkaJFDi4NkTIb/hjtBgbRDGmIbNV/vFrz/DM8fDcyfBWa9AxwHRjasWsCsIY4zxpllHmPgWNM5wksSPn8c6oqizBGGMMb40be8kibQ28PwpsOGzWEcUVVbFZIwx/qS3hYlvwjNjYMapMG4WdBoann3X8ru47QrCGGOqk9YGJr4BzQ6EGafD+oXh2W8tv4vbEoQxxgSiSSZMeB1adIYXxsLa92MdUcRZgjDGmEA1aeUkiZZd4cUz4fsFwZVXhS158OHt8NCgyM
QYRtYGYYwxwWjcEibMg+dOhJnj4PTnoPto39urwqavYNVcWD0Ptq8FBA48IloRh8wShDHGBKtRCzhnrtP99cWx3rdJbQ59z3ISw68/gcRDpyEw6GLocTykta7140BZgjDGmFCkNoez58CdB3pfv2cHLH7EGRRw6DXQ/dj9JzfyNQptfLJz5SES7qiDYgnCGGNCldrM//qr1/rfxltX1k/vh/dugs/uhyOvCD22MLBGamOMiZTqEog3R0yBXifDezfD2vfCHlIwLEEYY0xtIgInPOjMdjf7PPhlfcxCsQRhjDG1TVJjOGOG83zmWVDye0zCsARhjDE14Wu4cF/LA9WiE5w6HQpWw9xLnEbrKLNGamOMqYlIjpl00AgYcZPTaN02G468PHLH8sKuIIwxpjY7Ygr0Ognevznqw3tYgjDGmNpMBE54CFodDLMnRbXR2hKEMcbUdjFqtLYEYYwxdUEMGq0tQRhjTF1R2Wid9xp89kDED2cJwhhj6pIoNlpbN1djjKlLKhutV82D50/ef30YpyuN6BWEiIwWkW9FZK2ITPWz3SkioiKS47HsOrfctyJyTCTjNMaYOiWpMWi593VhnK40YlcQIhIPPASMBPKBpSIyT1VXVdkuDZgCfOGxrCdwBtALaAe8JyLdVH29I8YYY8ItklcQA4C1qrpeVUuAmcAJXra7FbgTKPJYdgIwU1WLVfUHYK27P2OMMVESyQTRHvjZ43W+u2wvEekHdFTVN4Mt65Y/X0RyRSS3oKAgPFEbY4wBYtiLSUTigPuAq0Ldh6o+pqo5qprTqlWr8AVnjDEmor2YNgIdPV53cJdVSgMOARaKM61eG2CeiIwJoKwxxjRsvqYrrekosh4imSCWAl1FpBPOyf0MYFzlSlX9DciofC0iC4H/U9VcEdkDvCAi9+E0UncFlkQwVmOMqVsiOYqsK2IJQlXLRORSYD4QD0xX1TwRuQXIVdV5fsrmichLwCqgDLjEejAZY0x0icZgEopIyMnJ0dzc3FiHYYwxdYqILFPVHG/rbKgNY4wxXlmCMMYY45UlCGOMMV7VmzYIESkAfqzBLjKAbQ2obCyPXRfLxvLY9jvXjbKxPHZNyh6oqt5vJFNVezhJMrchla2rcdv7Zb9zbS1bl+P29bAqJmOMMV5ZgjDGGOOVJYh9HmtgZWN57LpYNpbHtt+5bpSN5bFrGrdX9aaR2hhjTHjZFYQxxhivLEEYY4zxqsEniEDnzfZRdrqIbBWRlSEct6OIfCgiq0QkT0SmBFE2RUSWiMhXbtmbQzh+vIh8KSJvBFlug4h8IyIrRCTowa9EpJmIzBaRNSKyWkQOD7Bcd/eYlY9CEbk8iONe4b5XK0XkRRFJCaLsFLdcXiDH9Pa5EJEWIrJARL53fzYPouxp7rErPOdtD7Ds3e57/bWIvCYizYIoe6tbboWIvCsi7YI5tse6q9w55zMCLSsi00Rko8ff+9hgjisil7m/d56I3BXEcWd5HHODiKwIomy2iCyu/N8QEZ+zYPoo30dEPnf/v14XkXQv5byeNwL9fAUtEn1n68oDZ5TZdUBnIAn4CugZRPmhQD9gZQjHbgv0c5+nAd8FemxAgCbu80Sc+bwHBXn8K4EXgDeCLLcByKjBe/4M8Ff3eRLQLMS/22acG3wC2b498AOQ6r5+CZgYYNlDgJVAI5zRj98DDgr2cwHcBUx1n08F7gyi7MFAd2AhkBPkcUcBCe7zO4M8brrH88nAI8Ec213eEWdE5x99fW58HHsazvD/Qf8PAn9y/07J7uvMYGL2WH8vcGMQx30X+LP7/FhgYZBxLwWOcp9PAm71Us7reSPQz1ewj4Z+BRHovNleqerHwC+hHFhVN6nqcvf5TmA1XqZV9VFWVXWX+zLRfQTc20BEOgDHAU8EFXQNiUhTnH+MJwFUtURVfw1hVyOAdaoazJ3zCUCqiCTgnOz/F2C5g4EvVHW3qpYBHwEn+yvg43NxAk5yxP15YqBlVXW1qn
5bXaA+yr7rxg2wGGfyrUDLFnq8bIyfz5if/4V/AdeEWLZaPspeBNyhqsXuNl5m1fF/XBER4HTgxSDKKlD5rb8pfj5jPsp3Az52ny8ATvFSztd5I6DPV7AaeoIIaO7rSBORLKAvzpVAoGXi3cvfrcACVQ24LHA/zj9tRRBlKinwrogsE5HzgyzbCSgAnnKrt54QkcYhxHAGPv5xvVHVjcA9wE/AJuA3VX03wOIrgSEi0lJEGuF8M+xYTRlvWqvqJvf5ZqB1CPuoqUnA28EUEJH/JyI/A+OBG4MsewKwUVW/Cqach0vdKq7pQVaZdMP5m30hIh+JyGEhHHsIsEVVg5mV53Lgbvf9uge4Lshj5rHvC+ppVPM5q3LeiMjnq6EniJgTkSbAK8DlVb6x+aWq5aqajfONcICIHBLg8Y4HtqrqslDiBY5U1X7An4FLRGRoEGUTcC6rH1bVvsDvOJfDARORJGAM8HIQZZrj/ON1wpmhsLGInBVIWVVdjVM18y7wDrACqNHkVerUA0S1f7mIXI8z+daMYMqp6vWq2tEtd2kQx2sE/J0gk4qHh4EuQDZOUr83iLIJQAtgEHA18JJ7RRCMMwniS4jrIuAK9/26AvdKOQiTgItFZBlO9VGJrw39nTfC+flq6AkipnNfi0gizh95hqq+Gso+3CqaD4HRARY5AhgjIhtwqtSGi8jzQRxvo/tzK/AaTjVdoPKBfI+rndk4CSMYfwaWq+qWIMocDfygqgWqWgq8CgwOtLCqPqmq/VV1KLADp943WFtEpC2A+9NrtUckiMhE4HhgvHvyCMUMvFR5+NEFJyF/5X7WOgDLRaRNIIVVdYv7JagCeJzgP2evulWxS3CulL02kHvjVkOeDMwK4pgAE3A+W+B8gQkmZlR1jaqOUtX+OMlpnY/4vJ03IvL5augJYu+82e430zMAn1OhhpP7jeZJYLWq3hdk2VaVvVFEJBUYCawJpKyqXqeqHVQ1C+f3/UBVA/o2LSKNRSSt8jlOA2jAPbhUdTPws4h0dxeNwJlWNhihfLP7CRgkIo3c930ETt1tQEQk0/15AM6J44Ugjw/O52qC+3wCMDeEfQRNREbjVCeOUdXdQZbt6vHyBAL8jAGo6jeqmqmqWe5nLR+ncXVzgMdu6/HyJIL4nAFzcBqqEZFuOJ0hghnp9GhgjarmB1EGnDaHo9znw4GgJo32+JzFATcAj3jZxtd5IzKfr3C0dNflB06d8nc42fr6IMu+iHP5W4rzD3BeEGWPxLkM/Bqn2mIFcGyAZXsDX7plV+Kjp0UA+xlGEL2YcHp7feU+8oJ9v9x9ZAO5buxzgOZBlG0MbAeahnDcm3FOcCuB53B7uARY9hOcRPYVMCKUzwXQEngf56TxHtAiiLInuc+LgS3A/CDKrsVpZ6v8jHntieSj7Cvu+/U18DrQPtT/Bfz0fvNx7OeAb9xjzwPaBlE2CXjejX05MDyYmIGngQtD+BsfCSxzPydfAP2DLD8F51z0HXAH7kgXVcp5PW8E+vkK9mFDbRhjjPGqoVcxGWOM8cEShDHGGK8sQRhjjPHKEoQxxhivLEEYY4zxyhKEMT6ISLn8cQTZqe7yheKMAPyViHxWeV+HiCSJyP3ijAz8vYjMdce9qtxfGxGZKSLr3KFK3hKRbiKSJfuPRjpNRP7PfT7IHTZihTgj4E6L4ttgGrCEWAdgTC22R53hTLwZr6q57nhUd+MM//FPnCESuqtquYicC7wqIgPdMq8Bz6jqGeAM74wzZs7P++/+D54BTlfVr0QkHmdkV2MizhKEMTXzMXC5O/bQuUAnVS0HUNWnRGQSzl21CpSq6t67Y9UdxM4ddM2fTJybqnD3Hezd58aExBKEMb6lyh8njLldVauOz/MXnDt+DwJ+0v0HXMwFernP/Q2Q2KXKsdrgjAgKzpDZ34rIQpwBA59R1aJAfwljQmUJwhjf/FUxzRCRPThDSFwG1HQGr3Wex/JsZ1DVW0RkBs7YV+Nwxq
MaVsPjGVMtSxDGhGa8qu6dclVEfgEOEJE0dSZyqdQfqJzW9dRQD6aq64CHReRxoEBEWqrq9lD3Z0wgrBeTMWGgqr/jNCbf5zYkIyLn4Mxe94H7SPacZElEeovIkOr2LSLHecxn0BVnPopfw/sbGLM/SxDG+JZapZvrHdVsfx1QBHwnIt/jzAp2krpwRmU92u3mmgfcjjP7V3XOxmmDWIEzyun4yoZwYyLJRnM1xhjjlV1BGGOM8coShDHGGK8sQRhjjPHKEoQxxhivLEEYY4zxyhKEMcYYryxBGGOM8er/AxJQFOw0o8vUAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 208 ms\n"
]
}
],
"source": [
"value1 = ft_recall_30\n",
"value2 = cs_recall_30\n",
"name = \"Recall@30\"\n",
"\n",
"length = len(value1)\n",
"plt.plot(range(length), value1, \"-s\", label=\"Finetune\")\n",
"plt.plot(range(1, length), value2, \"-s\", label=\"w/o Pretrain\")\n",
"\n",
"\n",
"plt.xticks(range(length))\n",
"plt.ylabel(name)\n",
"plt.xlabel(\"EPOCHS\")\n",
"plt.legend(loc='center right')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
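The `finetune/autoDiag/fine-tune.ipynb` notebook that follows defines `scaled_dot_product_attention`. That function gives padded keys a large negative score (`-1e+7`) before the softmax, then zeroes the rows that belong to padded queries. The sketch below reproduces the same masking in NumPy for one head and one sequence, with no batch or head dimension. The function name `masked_attention` and the toy inputs are illustrative assumptions and are not part of the repository.

```python
import numpy as np


def masked_attention(Q, K, V, q_mask, k_mask, padding_num=-1e7):
    """Single-head, unbatched sketch of the notebook's masked attention.

    Q: (T_q, d), K: (T_k, d), V: (T_k, d_v)
    q_mask: (T_q,), k_mask: (T_k,), where 1 = real position and 0 = padding.
    """
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # (T_q, T_k), scaled dot product
    # Padded keys get a large negative score, so their softmax weight is ~0.
    scores = np.where(k_mask[np.newaxis, :] == 0, padding_num, scores)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Rows for padded queries are zeroed after the softmax (like outputs * Q_masks).
    weights = weights * q_mask[:, np.newaxis]
    return weights @ V, weights


# Toy example: 4 positions, the last one is padding; self-attention Q = K = V = X.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 8))
mask = np.array([1.0, 1.0, 1.0, 0.0])
out, w = masked_attention(X, X, X, mask, mask)
print(w.round(3))
```

The padded key scores about `-1e7`, so its exponential underflows to zero and real positions never attend to padding. The padded query row is then multiplied by zero, which matches the `outputs * tf.cast(Q_masks, ...)` step in the notebook.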
================================================
FILE: finetune/autoDiag/fine-tune.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "wicked-finder",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 205 µs\n"
]
}
],
"source": [
"%load_ext autotime"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "random-fluid",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 1.51 s\n"
]
}
],
"source": [
"import numpy as np\n",
"import _pickle as pickle\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "diverse-vegetation",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========LOADING DATA==========\n",
"time: 14.1 s\n"
]
}
],
"source": [
"print(\"==========LOADING DATA==========\")\n",
"age_seq = pickle.load(open(\"../data/new_age_seq\",\"rb\"))\n",
"sex_seq = pickle.load(open(\"../data/new_sex_seq\",\"rb\"))\n",
"\n",
"util_seq = pickle.load(open(\"../data/new_util_seq\",\"rb\"))\n",
"code_seq = pickle.load(open(\"../data/new_code_seq\",\"rb\"))\n",
"date_seq = pickle.load(open(\"../data/new_date_seq\",\"rb\"))\n",
"label_seq = pickle.load(open(\"../data/new_label_seq\",\"rb\"))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "judicial-aircraft",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 851 ms\n"
]
}
],
"source": [
"k = 10000\n",
"age_seq = age_seq[:k]\n",
"sex_seq = sex_seq[:k]\n",
"\n",
"util_seq = util_seq[:k]\n",
"code_seq = code_seq[:k]\n",
"date_seq = date_seq[:k]\n",
"label_seq = label_seq[:k]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "disabled-favor",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 3.5 s\n"
]
}
],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "hollow-environment",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 10.5 ms\n"
]
}
],
"source": [
"class DataGenerator(tf.keras.utils.Sequence):\n",
" def __init__(self, seqs, vocab_sizes, list_IDs, max_visit, max_code, batch_size=100, shuffle=True):\n",
" self.seqs = seqs\n",
" self.code_vocab = vocab_sizes[0]\n",
" self.cat_vocab = vocab_sizes[1]\n",
" self.list_IDs = list_IDs\n",
" self.max_visit = max_visit\n",
" self.max_code = max_code\n",
" self.batch_size = batch_size\n",
" self.shuffle = shuffle\n",
" self.on_epoch_end()\n",
"\n",
" def __len__(self):\n",
" 'Denotes the number of batches per epoch'\n",
" return int(np.ceil(len(self.list_IDs) / self.batch_size))\n",
"\n",
" def __getitem__(self, index):\n",
" 'Generate one batch of data'\n",
" indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]\n",
" list_IDs_temp = [self.list_IDs[k] for k in indexes]\n",
" X, y = self.__data_generation(list_IDs_temp)\n",
" return X, y\n",
"\n",
" def on_epoch_end(self):\n",
" 'Updates indexes after each epoch' \n",
" self.indexes = np.arange(len(self.list_IDs))\n",
" if self.shuffle == True:\n",
" np.random.shuffle(self.indexes)\n",
"\n",
" def __data_generation(self, list_IDs_temp):\n",
" 'Generates data containing batch_size samples' \n",
" demo_feature, code_feature, util_feature, date_feature, cat_feature = self.seqs\n",
" batch_demo, batch_code, batch_util, batch_date, batch_cat = [], [], [], [], [] \n",
" for i, ID in enumerate(list_IDs_temp):\n",
" batch_demo.append(demo_feature[ID])\n",
" batch_code.append(code_feature[ID])\n",
" batch_util.append(util_feature[ID])\n",
" batch_date.append(date_feature[ID])\n",
" batch_cat.append(cat_feature[ID])\n",
" \n",
" batch_demo_feature = np.array(batch_demo)\n",
" batch_code_feature = self.code_padding(batch_code)\n",
" batch_util_feature = self.date_padding(batch_util)\n",
" batch_date_feature = self.date_padding(batch_date)\n",
" \n",
" batch_code_label = self.code_labelling(batch_cat) # predict cat instead\n",
" \n",
" dic = (\n",
" {\n",
" 'demo_feature': batch_demo_feature,\n",
" 'code_feature': batch_code_feature,\n",
" 'util_feature': batch_util_feature,\n",
" 'date_feature': batch_date_feature,\n",
" },\n",
" {\n",
" 'code_label': batch_code_label,\n",
" })\n",
" return dic\n",
" \n",
" def date_padding(self, seq):\n",
" seq = [x[:-1] for x in seq]\n",
" \n",
" pad_seq = np.zeros((len(seq), self.max_visit))\n",
" for i, p in enumerate(seq):\n",
" pad_seq[i][:len(p)] = p[:self.max_visit]\n",
" return pad_seq\n",
" \n",
" def code_padding(self, seq):\n",
" seq = [x[:-1] for x in seq]\n",
" \n",
" X = np.zeros((len(seq), self.max_visit, self.max_code))\n",
" for i, p in enumerate(seq):\n",
" if len(p) > self.max_visit: \n",
" p = p[:self.max_visit]\n",
" for j, claim in enumerate(p):\n",
" claim = claim[:self.max_code]\n",
" X[i][j][:len(claim)] = claim\n",
" return X\n",
" \n",
" def code_labelling(self, seq):\n",
" seq = [x[-1] for x in seq]\n",
" \n",
" X = np.zeros((len(seq), self.cat_vocab))\n",
" for i, claim in enumerate(seq):\n",
" for c in claim:\n",
" X[i][c] = 1\n",
" return X\n",
"\n",
" def cat_labelling(self, seq):\n",
" seq = [x[:-1] for x in seq]\n",
" \n",
" X = np.zeros((len(seq), self.max_visit, self.cat_vocab))\n",
" for i, p in enumerate(seq):\n",
" if len(p) > self.max_visit: \n",
" p = p[:self.max_visit]\n",
" for j, claim in enumerate(p):\n",
" for c in claim:\n",
" X[i][j][c] = 1\n",
" return X"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "quick-webster",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 15.8 ms\n"
]
}
],
"source": [
"def create_code_mask(code_seq):\n",
" code_mask = tf.cast(tf.math.not_equal(code_seq, 0), tf.float32)\n",
" return code_mask[:,:,:,tf.newaxis]\n",
"\n",
"def create_visit_mask(seq):\n",
" visit_mask = tf.cast(tf.math.not_equal(seq, 0), tf.float32)\n",
" return visit_mask[:,:]\n",
"\n",
"def scaled_dot_product_attention(Q, K, V, Q_masks, K_masks):\n",
" d_k = K.get_shape().as_list()[-1] # d_model/h\n",
"\n",
" outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1])) # (h*N, T_q, T_k)\n",
" outputs /= d_k ** 0.5\n",
"\n",
" padding_num = -1e+7\n",
" K_masks = tf.expand_dims(K_masks, 1) # (h*N, 1, T_k)\n",
" K_masks = tf.tile(K_masks, [1, tf.shape(Q)[1], 1]) # (h*N, T_q, T_k)\n",
" paddings = tf.ones_like(outputs) * padding_num\n",
" outputs = tf.where(tf.equal(K_masks, 0), paddings, outputs) # (h*N, T_q, T_k)\n",
"\n",
" outputs = tf.nn.softmax(outputs)\n",
" Q_masks = tf.expand_dims(Q_masks, -1) # (h*N, T_q, 1)\n",
" Q_masks = tf.tile(Q_masks, [1, 1, tf.shape(K)[1]]) # (h*N, T_q, T_k)\n",
" outputs = outputs * tf.cast(Q_masks, dtype=tf.float32)\n",
"\n",
" return tf.matmul(outputs, V) # [h*N, T_q, d_model/h]\n",
"\n",
"class multihead_attention(tf.keras.layers.Layer):\n",
" def __init__(self, d_model, num_heads, name=\"multihead_attention\"):\n",
" super(multihead_attention, self).__init__(name=name)\n",
" self.num_heads = num_heads\n",
" self.d_model = d_model\n",
"\n",
" assert d_model % self.num_heads == 0\n",
"\n",
" self.query_dense = layers.Dense(units=d_model, use_bias=False)\n",
" self.key_dense = layers.Dense(units=d_model, use_bias=False)\n",
" self.value_dense = layers.Dense(units=d_model, use_bias=False)\n",
" self.add = layers.Add()\n",
" self.norm = layers.LayerNormalization()\n",
" \n",
" def call(self, queries, keys, values, query_masks, key_masks):\n",
" Q = self.query_dense(queries)\n",
" K = self.key_dense(keys)\n",
" V = self.value_dense(values)\n",
"\n",
" # Split and concat\n",
" Q_ = tf.concat(tf.split(Q, self.num_heads, axis=2), axis=0) # (h*N, T_q, d_model/h)\n",
" K_ = tf.concat(tf.split(K, self.num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h)\n",
" V_ = tf.concat(tf.split(V, self.num_heads, axis=2), axis=0) # (h*N, T_v, d_model/h)\n",
" query_masks = tf.tile(query_masks, [self.num_heads, 1]) # (h*N, T_q)\n",
" key_masks = tf.tile(key_masks, [self.num_heads, 1]) # (h*N, T_k)\n",
"\n",
" # Attention\n",
" outputs = scaled_dot_product_attention(Q_, K_, V_, query_masks, key_masks) # (h*N, T_q, d_model/h)\n",
"\n",
" # Restore shape\n",
" outputs = tf.concat(tf.split(outputs, self.num_heads, axis=0), axis=2) # (N, T_q, d_model)\n",
"\n",
" # Residual connection\n",
" outputs = self.add([queries, outputs])\n",
" outputs = self.norm(outputs)\n",
" \n",
" return outputs\n",
"\n",
"class ffn(tf.keras.layers.Layer):\n",
" def __init__(self, d_model, ffn_dim, name=\"ffn\"):\n",
" super(ffn, self).__init__(name=name)\n",
" self.ffn_dim = ffn_dim\n",
" self.dense1 = layers.Dense(units=ffn_dim, activation=tf.nn.relu, use_bias=False)\n",
" self.dense2 = layers.Dense(units=d_model, use_bias=False)\n",
" self.add = layers.Add()\n",
" self.norm = layers.LayerNormalization()\n",
" \n",
" def call(self, inputs):\n",
" outputs = self.dense1(inputs)\n",
" outputs = self.dense2(outputs)\n",
" outputs = self.add([inputs, outputs])\n",
" outputs = self.norm(outputs)\n",
" return outputs\n",
"\n",
"def cat_recall(y_true, y_pred):\n",
" mask_value = tf.cast(tf.not_equal(tf.reduce_sum(y_true,axis=-1), 0), tf.float32)\n",
" true_positives = tf.cast(tf.reduce_sum(tf.multiply(tf.round(y_pred), y_true), axis=-1), tf.float32)\n",
" possible_positives = tf.cast(tf.reduce_sum(y_true, axis=-1), tf.float32)\n",
" values = true_positives / (possible_positives + 1e-7)\n",
" return tf.reduce_sum(values)/tf.reduce_sum(mask_value)\n",
"\n",
"def cat_loss_fun(y_true, y_pred):\n",
" loss = tf.cast(tf.keras.losses.BinaryCrossentropy(reduction='none')(y_true, y_pred), tf.float32)\n",
" mask = tf.cast(tf.not_equal(tf.reduce_sum(y_true,axis=-1), 0), tf.float32)\n",
" loss = tf.multiply(loss, mask)\n",
" # return tf.reduce_sum(loss)/tf.reduce_sum(mask)\n",
" return loss\n",
"\n",
"def model(\n",
" max_visit,\n",
" max_code,\n",
" max_demo,\n",
" \n",
" demo_vocab,\n",
" code_vocab,\n",
" date_vocab,\n",
" util_vocab,\n",
" cat_vocab,\n",
"\n",
" patient_dim,\n",
" vocab_dim=100,\n",
" model_dim=100,\n",
" ffn_dim=100,\n",
" num_heads=2,\n",
" num_translayer=1,\n",
" \n",
" model_name=\"TransF\"):\n",
" \n",
" demo = layers.Input(shape=(max_demo, ), name=\"demo_feature\") # max_demo = 2, age&sex\n",
" code_seq = layers.Input(shape=(max_visit, max_code), name=\"code_feature\") \n",
" util_seq = layers.Input(shape=(max_visit,), name=\"util_feature\")\n",
" date_seq = layers.Input(shape=(max_visit,), name=\"date_feature\")\n",
"\n",
" inputs = [demo, code_seq, util_seq, date_seq]\n",
" \n",
" # demo embedding\n",
" demo_emb = layers.Embedding(input_dim=demo_vocab, output_dim=vocab_dim, mask_zero=True, name='demo_embedding')(demo)\n",
" demo_emb = layers.Lambda(lambda x: tf.keras.backend.sum(x, axis=1))(demo_emb) \n",
"\n",
" # code sequence\n",
" code_mask = layers.Lambda(create_code_mask)(code_seq)\n",
" code_emb = layers.Embedding(input_dim=code_vocab, \n",
" output_dim=vocab_dim, \n",
" name='code_embed')(code_seq)\n",
" code_emb = layers.Multiply()([code_emb, code_mask])\n",
" code_emb = tf.reduce_sum(code_emb, axis=2) \n",
"\n",
" \n",
" # visit mask\n",
" visit_mask = layers.Lambda(create_visit_mask)(date_seq)\n",
" \n",
" # util sequence \n",
" util_emb = layers.Embedding(input_dim=util_vocab, output_dim=vocab_dim, mask_zero=True, name='util_embedding')(util_seq)\n",
" util_emb = layers.Multiply()([util_emb, visit_mask[:,:,tf.newaxis]])\n",
" \n",
" # date sequence \n",
" date_emb = layers.Embedding(input_dim=date_vocab, output_dim=vocab_dim, mask_zero=True, name='date_embedding')(date_seq)\n",
" date_emb = layers.Multiply()([date_emb, visit_mask[:,:,tf.newaxis]])\n",
"\n",
" # visit sequence\n",
" visit_emb = layers.Add()([code_emb, date_emb, util_emb]) \n",
"\n",
" demo_emb = tf.expand_dims(demo_emb, 1) # (N, 1, emb_size)\n",
" demo_mask = tf.ones_like(tf.reduce_sum(demo_emb, axis=2), tf.float32) # (N, 1)\n",
" \n",
" for trans_layer in range(num_translayer):\n",
" multihead = multihead_attention(model_dim, num_heads, name=\"multihead_attention-\"+str(trans_layer))(\n",
" queries=tf.concat([demo_emb, visit_emb], 1), \n",
" keys=tf.concat([demo_emb, visit_emb], 1),\n",
" values=tf.concat([demo_emb, visit_emb], 1),\n",
" query_masks=tf.concat([demo_mask, visit_mask], 1),\n",
" key_masks=tf.concat([demo_mask, visit_mask], 1)\n",
" )\n",
" \n",
" visit_emb = ffn(model_dim, ffn_dim, name=\"ffn-\"+str(trans_layer))(multihead) # (N, 1+max_visit, emb_size), demo token first\n",
" \n",
" demo_emb = visit_emb[:, :1, :] # (N, 1, emb_size)\n",
" visit_emb = visit_emb[:, 1:max_visit+1, :] # (N, max_visit, emb_size)\n",
" \n",
" patient_embedding = layers.Dense(patient_dim, activation=None, name=\"patient_embedding\")(tf.squeeze(demo_emb, [1]))\n",
" \n",
" # code_label = layers.Dense(units=model_dim, activation=tf.nn.sigmoid)(patient_embedding)\n",
" code_label = layers.Dense(units=cat_vocab, activation=tf.nn.sigmoid, name='code_label')(patient_embedding)\n",
" \n",
" # cat_label = layers.Dense(units=model_dim, activation=tf.nn.sigmoid)(visit_emb)\n",
" # cat_label = layers.Dense(units=cat_vocab, activation=tf.nn.sigmoid, name='cat_label')(visit_emb)\n",
" \n",
" outputs = [code_label]\n",
" return tf.keras.Model(inputs=inputs, outputs=outputs, name=model_name)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dependent-movie",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 6.54 ms\n"
]
}
],
"source": [
"def process_code(seq, vocab2int):\n",
" unseen = []\n",
" new_seq = []\n",
" for p in seq:\n",
" new_p = []\n",
" for v in p:\n",
" new_v = []\n",
" for c in v:\n",
" if c not in vocab2int: \n",
" unseen.append(c)\n",
" continue\n",
" # vocab2int[c] = len(vocab2int)\n",
" new_v.append(vocab2int[c])\n",
" new_p.append(new_v)\n",
" new_seq.append(new_p)\n",
" \n",
" print(\"UNSEEN VOCAB:\",len(set(unseen)), len(unseen))\n",
" return new_seq\n",
"\n",
"def process_util(seq, util2int):\n",
" new_seq = []\n",
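" # Fixed utilization mapping (util2int is not used): IP=1, RX=2, anything else OP=3; 0 is reserved for padding\n",
" vocab2int = {\"PAD\":0,\"IP\":1,\"RX\":2,\"OP\":3}\n",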
" for p in seq:\n",
" new_p = []\n",
" for v in p:\n",
" if \"IP\" in v:\n",
" new_v=1\n",
" elif \"RX\" in v:\n",
" new_v=2\n",
" else:\n",
" new_v=3\n",
" new_p.append(new_v)\n",
" new_seq.append(new_p)\n",
" return new_seq\n",
" \n",
"def process_demo(age_seq, sex_seq, vocab2int):\n",
" new_seq = []\n",
" for age, sex in zip(age_seq, sex_seq):\n",
" p = []\n",
" assert age in vocab2int\n",
" assert sex in vocab2int\n",
" \n",
" p.append(vocab2int[age])\n",
" p.append(vocab2int[sex])\n",
" new_seq.append(p)\n",
" return np.array(new_seq)\n",
"\n",
"def get_cat(seq,code2cat):\n",
" new_seq = []\n",
" for p in seq:\n",
" new_p = []\n",
" for v in p:\n",
" new_v = []\n",
" for c in v:\n",
" new_c = code2cat[c]\n",
" new_v.append(new_c)\n",
" new_p.append(new_v)\n",
" new_seq.append(new_p)\n",
" return new_seq"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ongoing-manufacturer",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"------LOADING DIC------\n",
"time: 83.4 ms\n"
]
}
],
"source": [
"print(\"------LOADING DIC------\")\n",
"path = \"/Users/xxz005/Desktop/RAW_DATA/code2cat/\"\n",
"\n",
"diag2cat = pickle.load(open(path+\"diag2cat\",\"rb\"))\n",
"proc2cat = pickle.load(open(path+\"proc2cat\",\"rb\"))\n",
"drug2cat = pickle.load(open(path+\"drug2cat\",\"rb\"))\n",
"\n",
"code2cat = {**diag2cat, **proc2cat, **drug2cat}"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "wrong-warrior",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 34 ms\n"
]
}
],
"source": [
"code2int, util2int, demo2int, cat2int = pickle.load(open(\"../../pretraining/model/saveModel/vocabs\",\"rb\"))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "invalid-silicon",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"UNSEEN VOCAB: 0 0\n",
"UNSEEN VOCAB: 0 0\n",
"time: 667 ms\n"
]
}
],
"source": [
"code_feature = process_code(code_seq, code2int)\n",
"util_feature = process_util(util_seq, util2int)\n",
"demo_feature = process_demo(age_seq, sex_seq, demo2int)\n",
"\n",
"date_feature = date_seq\n",
"\n",
"cat_seq = get_cat(code_seq, code2cat)\n",
"cat_feature = process_code(cat_seq, cat2int) "
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "adverse-arkansas",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 1.27 ms\n"
]
}
],
"source": [
"MAX_VISIT=30\n",
"MAX_CODE=10\n",
"MAX_DEMO=2\n",
"PATIENT_DIM=100\n",
"\n",
"BATCH_SIZE = 500\n",
"TRAIN_RATIO = 0.7\n",
"DATA_SIZE = len(age_seq)\n",
"EPOCHS = 20\n",
"\n",
"params = {\n",
"    'seqs': [demo_feature, code_feature, util_feature, date_feature, cat_feature],\n",
"    'vocab_sizes': [len(code2int), len(cat2int)],\n",
"    'batch_size': 100,  # generator batch size, set independently of BATCH_SIZE above\n",
"    'max_visit': MAX_VISIT,\n",
"    'max_code': MAX_CODE,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "sorted-wholesale",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 7.39 ms\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"train_IDs, valid_IDs = train_test_split(range(DATA_SIZE), train_size=TRAIN_RATIO, random_state=42)\n",
"train_generator = DataGenerator(list_IDs=train_IDs, shuffle=True, **params)\n",
"valid_generator = DataGenerator(list_IDs=valid_IDs, shuffle=False, **params)"
]
},
{
"cell_type": "markdown",
"id": "separated-layout",
"metadata": {},
"source": [
"# Pretrain"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "suspended-settlement",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 2.58 s\n"
]
}
],
"source": [
"model_path = \"../../pretraining/model/saveModel\"\n",
"\n",
"model = tf.keras.models.load_model(model_path)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "veterinary-chinese",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"TransF\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"demo_feature (InputLayer) [(None, 2)] 0 \n",
"__________________________________________________________________________________________________\n",
"demo_embedding (Embedding) (None, 2, 100) 2400 demo_feature[0][0] \n",
"__________________________________________________________________________________________________\n",
"lambda (Lambda) (None, 100) 0 demo_embedding[0][0] \n",
"__________________________________________________________________________________________________\n",
"code_feature (InputLayer) [(None, 30, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"date_feature (InputLayer) [(None, 30)] 0 \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_ExpandDims (TensorF (None, 1, 100) 0 lambda[0][0] \n",
"__________________________________________________________________________________________________\n",
"code_embed (Embedding) (None, 30, 10, 100) 4983600 code_feature[0][0] \n",
"__________________________________________________________________________________________________\n",
"lambda_1 (Lambda) (None, 30, 10, 1) 0 code_feature[0][0] \n",
"__________________________________________________________________________________________________\n",
"lambda_2 (Lambda) (None, 30) 0 date_feature[0][0] \n",
"__________________________________________________________________________________________________\n",
"util_feature (InputLayer) [(None, 30)] 0 \n",
"__________________________________________________________________________________________________\n",
"multiply (Multiply) (None, 30, 10, 100) 0 code_embed[0][0] \n",
" lambda_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"date_embedding (Embedding) (None, 30, 100) 36600 date_feature[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_strided_slice_1 (Te (None, 30, 1) 0 lambda_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"util_embedding (Embedding) (None, 30, 100) 400 util_feature[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_strided_slice (Tens (None, 30, 1) 0 lambda_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_Sum_1 (TensorFlowOp (None, 1) 0 tf_op_layer_ExpandDims[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_Sum (TensorFlowOpLa (None, 30, 100) 0 multiply[0][0] \n",
"__________________________________________________________________________________________________\n",
"multiply_2 (Multiply) (None, 30, 100) 0 date_embedding[0][0] \n",
" tf_op_layer_strided_slice_1[0][0]\n",
"__________________________________________________________________________________________________\n",
"multiply_1 (Multiply) (None, 30, 100) 0 util_embedding[0][0] \n",
" tf_op_layer_strided_slice[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_Shape (TensorFlowOp (2,) 0 tf_op_layer_Sum_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"add (Add) (None, 30, 100) 0 tf_op_layer_Sum[0][0] \n",
" multiply_2[0][0] \n",
" multiply_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_Fill (TensorFlowOpL (None, 1) 0 tf_op_layer_Shape[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_concat (TensorFlowO (None, 31, 100) 0 tf_op_layer_ExpandDims[0][0] \n",
" add[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_concat_4 (TensorFlo (None, 31) 0 tf_op_layer_Fill[0][0] \n",
" lambda_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_concat_1 (TensorFlo (None, 31, 100) 0 tf_op_layer_ExpandDims[0][0] \n",
" add[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_concat_3 (TensorFlo (None, 31) 0 tf_op_layer_Fill[0][0] \n",
" lambda_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_concat_2 (TensorFlo (None, 31, 100) 0 tf_op_layer_ExpandDims[0][0] \n",
" add[0][0] \n",
"__________________________________________________________________________________________________\n",
"multihead_attention-0 (multihea (None, 31, 100) 30200 tf_op_layer_concat[0][0] \n",
" tf_op_layer_concat_4[0][0] \n",
" tf_op_layer_concat_1[0][0] \n",
" tf_op_layer_concat_3[0][0] \n",
" tf_op_layer_concat_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"ffn-0 (ffn) (None, 31, 100) 20200 multihead_attention-0[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_strided_slice_2 (Te (None, 1, 100) 0 ffn-0[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_Squeeze (TensorFlow (None, 100) 0 tf_op_layer_strided_slice_2[0][0]\n",
"__________________________________________________________________________________________________\n",
"patient_embedding (Dense) (None, 100) 10100 tf_op_layer_Squeeze[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_strided_slice_3 (Te (None, 30, 100) 0 ffn-0[0][0] \n",
"__________________________________________________________________________________________________\n",
"code_label (Dense) (None, 3357) 339057 patient_embedding[0][0] \n",
"__________________________________________________________________________________________________\n",
"cat_label (Dense) (None, 30, 3357) 339057 tf_op_layer_strided_slice_3[0][0]\n",
"==================================================================================================\n",
"Total params: 5,761,614\n",
"Trainable params: 5,761,614\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"time: 41.9 ms\n"
]
}
],
"source": [
"opt = tf.keras.optimizers.Adam(learning_rate=0.0001)\n",
"# Recall(top_k=k): share of true labels found among the k highest-scoring outputs.\n",
"model.compile(optimizer=opt,\n",
"              loss=tf.keras.losses.BinaryCrossentropy(),\n",
"              metrics=[tf.keras.metrics.Recall(top_k=5),\n",
"                       tf.keras.metrics.Recall(top_k=10),\n",
"                       tf.keras.metrics.Recall(top_k=30)])\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "vanilla-karen",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"30/30 [==============================] - 3s 61ms/step - loss: 0.0046 - code_label_loss: 0.0046 - cat_label_loss: 0.0000e+00 - code_label_recall_19: 0.2107 - code_label_recall_20: 0.3366 - code_label_recall_21: 0.5424 - cat_label_recall_19: 0.0000e+00 - cat_label_recall_20: 0.0000e+00 - cat_label_recall_21: 0.0000e+00\n"
]
},
{
"data": {
"text/plain": [
"[0.004690711852163076,\n",
" 0.004690711852163076,\n",
" 0.0,\n",
" 0.2135988175868988,\n",
" 0.3411019444465637,\n",
" 0.5410114526748657,\n",
" 0.0,\n",
" 0.0,\n",
" 0.0]"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"time: 2.81 s\n"
]
}
],
"source": [
"model.evaluate(valid_generator)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "verbal-kingdom",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"WARNING:tensorflow:Gradients do not exist for variables ['cat_label/kernel:0', 'cat_label/bias:0'] when minimizing the loss.\n",
"WARNING:tensorflow:Gradients do not exist for variables ['cat_label/kernel:0', 'cat_label/bias:0'] when minimizing the loss.\n",
"70/70 - 14s - loss: 0.0044 - code_label_loss: 0.0044 - cat_label_loss: 0.0000e+00 - code_label_recall_19: 0.2476 - code_label_recall_20: 0.3654 - code_label_recall_21: 0.5521 - cat_label_recall_19: 0.0000e+00 - cat_label_recall_20: 0.0000e+00 - cat_label_recall_21: 0.0000e+00 - val_loss: 0.0043 - val_code_label_loss: 0.0043 - val_cat_label_loss: 0.0000e+00 - val_code_label_recall_19: 0.2744 - val_code_label_recall_20: 0.3807 - val_code_label_recall_21: 0.5642 - val_cat_label_recall_19: 0.0000e+00 - val_cat_label_recall_20: 0.0000e+00 - val_cat_label_recall_21: 0.0000e+00\n",
"Epoch 2/20\n",
"70/70 - 10s - loss: 0.0042 - code_label_loss: 0.0042 - cat_label_loss: 0.0000e+00 - code_label_recall_19: 0.2860 - code_label_recall_20: 0.3921 - code_label_recall_21: 0.5814 - cat_label_recall_19: 0.0000e+00 - cat_label_recall_20: 0.0000e+00 - cat_label_recall_21: 0.0000e+00 - val_loss: 0.0042 - val_code_label_loss: 0.0042 - val_cat_label_loss: 0.0000e+00 - val_code_label_recall_19: 0.2876 - val_code_label_recall_20: 0.3942 - val_code_label_recall_21: 0.5826 - val_cat_label_recall_19: 0.0000e+00 - val_cat_label_recall_20: 0.0000e+00 - val_cat_label_recall_21: 0.0000e+00\n",
"Epoch 3/20\n",
"70/70 - 10s - loss: 0.0041 - code_label_loss: 0.0041 - cat_label_loss: 0.0000e+00 - code_label_recall_19: 0.2990 - code_label_recall_20: 0.4049 - code_label_recall_21: 0.6000 - cat_label_recall_19: 0.0000e+00 - cat_label_recall_20: 0.0000e+00 - cat_label_recall_21: 0.0000e+00 - val_loss: 0.0041 - val_code_label_loss: 0.0041 - val_cat_label_loss: 0.0000e+00 - val_code_label_recall_19: 0.2922 - val_code_label_recall_20: 0.4011 - val_code_label_recall_21: 0.5971 - val_cat_label_recall_19: 0.0000e+00 - val_cat_label_recall_20: 0.0000e+00 - val_cat_label_recall_21: 0.0000e+00\n",
"Epoch 4/20\n",
"70/70 - 10s - loss: 0.0040 - code_label_loss: 0.0040 - cat_label_loss: 0.0000e+00 - code_label_recall_19: 0.3072 - code_label_recall_20: 0.4153 - code_label_recall_21: 0.6133 - cat_label_recall_19: 0.0000e+00 - cat_label_recall_20: 0.0000e+00 - cat_label_recall_21: 0.0000e+00 - val_loss: 0.0041 - val_code_label_loss: 0.0041 - val_cat_label_loss: 0.0000e+00 - val_code_label_recall_19: 0.2964 - val_code_label_recall_20: 0.4052 - val_code_label_recall_21: 0.6013 - val_cat_label_recall_19: 0.0000e+00 - val_cat_label_recall_20: 0.0000e+00 - val_cat_label_recall_21: 0.0000e+00\n",
"Epoch 5/20\n",
"70/70 - 10s - loss: 0.0039 - code_label_loss: 0.0039 - cat_label_loss: 0.0000e+00 - code_label_recall_19: 0.3150 - code_label_recall_20: 0.4223 - code_label_recall_21: 0.6242 - cat_label_recall_19: 0.0000e+00 - cat_label_recall_20: 0.0000e+00 - cat_label_recall_21: 0.0000e+00 - val_loss: 0.0040 - val_code_label_loss: 0.0040 - val_cat_label_loss: 0.0000e+00 - val_code_label_recall_19: 0.3007 - val_code_label_recall_20: 0.4090 - val_code_label_recall_21: 0.6083 - val_cat_label_recall_19: 0.0000e+00 - val_cat_label_recall_20: 0.0000e+00 - val_cat_label_recall_21: 0.0000e+00\n",
"Epoch 6/20\n",
"70/70 - 10s - loss: 0.0039 - code_label_loss: 0.0039 - cat_label_loss: 0.0000e+00 - code_label_recall_19: 0.3216 - code_label_recall_20: 0.4308 - code_label_recall_21: 0.6324 - cat_label_recall_19: 0.0000e+00 - cat_label_recall_20: 0.0000e+00 - cat_label_recall_21: 0.0000e+00 - val_loss: 0.0040 - val_code_label_loss: 0.0040 - val_cat_label_loss: 0.0000e+00 - val_code_label_recall_19: 0.3042 - val_code_label_recall_20: 0.4122 - val_code_label_recall_21: 0.6110 - val_cat_label_recall_19: 0.0000e+00 - val_cat_label_recall_20: 0.0000e+00 - val_cat_label_recall_21: 0.0000e+00\n",
"Epoch 7/20\n",
"70/70 - 10s - loss: 0.0038 - code_label_loss: 0.0038 - cat_label_loss: 0.0000e+00 - code_label_recall_19: 0.3261 - code_label_recall_20: 0.4373 - code_label_recall_21: 0.6412 - cat_label_recall_19: 0.0000e+00 - cat_label_recall_20: 0.0000e+00 - cat_label_recall_21: 0.0000e+00 - val_loss: 0.0040 - val_code_label_loss: 0.0040 - val_cat_label_loss: 0.0000e+00 - val_code_label_recall_19: 0.3072 - val_code_label_recall_20: 0.4144 - val_code_label_recall_21: 0.6156 - val_cat_label_recall_19: 0.0000e+00 - val_cat_label_recall_20: 0.0000e+00 - val_cat_label_recall_21: 0.0000e+00\n",
"Epoch 8/20\n",
"70/70 - 11s - loss: 0.0038 - code_label_loss: 0.0038 - cat_label_loss: 0.0000e+00 - code_label_recall_19: 0.3330 - code_label_recall_20: 0.4455 - code_label_recall_21: 0.6494 - cat_label_recall_19: 0.0000e+00 - cat_label_recall_20: 0.0000e+00 - cat_label_recall_21: 0.0000e+00 - val_loss: 0.0040 - val_code_label_loss: 0.0040 - val_cat_label_loss: 0.0000e+00 - val_code_label_recall_19: 0.3084 - val_code_label_recall_20: 0.4160 - val_code_label_recall_21: 0.6147 - val_cat_label_recall_19: 0.0000e+00 - val_cat_label_recall_20: 0.0000e+00 - val_cat_label_recall_21: 0.0000e+00\n",
"Epoch 9/20\n",
"70/70 - 10s - loss: 0.0037 - code_label_loss: 0.0037 - cat_label_loss: 0.0000e+00 - code_label_recall_19: 0.3374 - code_label_
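The compile cell above tracks `tf.keras.metrics.Recall(top_k=5)`, `Recall(top_k=10)` and `Recall(top_k=30)`. These appear in the log under the auto-generated names `recall_19`, `recall_20` and `recall_21`, presumably in that order, since the values rise with k. To our understanding, setting `top_k` makes Keras treat the k highest-scoring classes of each row as predicted positives. It then pools true positives and false negatives over every row seen in the epoch, which gives a micro-averaged recall@k. The NumPy sketch below re-creates that computation for offline sanity checks. `recall_at_k` is a hypothetical helper and is not part of the repository or of Keras.

```python
import numpy as np

def recall_at_k(y_true, y_pred, k):
    """Micro-averaged recall@k for multi-hot targets (hypothetical helper).

    y_true: (n_samples, n_classes) 0/1 labels; y_pred: scores of the same shape.
    The k highest-scoring classes in each row count as predicted positives;
    true positives and false negatives are pooled over all rows before dividing.
    """
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=float)
    # argpartition on the negated scores puts the k largest scores in the first k slots.
    top_k = np.argpartition(-y_pred, k - 1, axis=1)[:, :k]
    predicted = np.zeros_like(y_true)  # boolean array, all False
    np.put_along_axis(predicted, top_k, True, axis=1)
    tp = np.logical_and(predicted, y_true).sum()
    fn = np.logical_and(~predicted, y_true).sum()
    return float(tp / (tp + fn)) if tp + fn > 0 else 0.0


if __name__ == "__main__":
    y_true = [[1, 0, 1, 0], [0, 1, 0, 0]]
    y_pred = [[0.9, 0.8, 0.7, 0.3], [0.2, 0.7, 0.6, 0.1]]
    for k in (1, 2, 3):
        print(k, round(recall_at_k(y_true, y_pred, k), 3))
```

On the toy arrays, recall is about 0.667 at k = 1 and k = 2 and reaches 1.0 at k = 3. A larger k can only add true positives, which is why recall@30 sits above recall@5 in the log.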
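The layer list in the `model.summary()` output above outlines how a patient becomes a token sequence:

- Code embeddings are masked (`lambda_1`) and summed per visit.
- Utilization and date embeddings are masked by a per-visit mask (`lambda_2`) and added.
- A demographic token is prepended.

The result has 31 positions (1 + `MAX_VISIT`) with a matching `(None, 31)` mask. The NumPy sketch below mirrors those shapes only. Three details are assumptions: which ids count as padding, how `lambda` collapses the two demographic embeddings (summing here), and how `lambda_2` derives the visit mask from `date_feature`. `build_visit_sequence` is a hypothetical name, not code from `cpt.py`.

```python
import numpy as np

def build_visit_sequence(demo_ids, code_ids, util_ids, date_ids,
                         demo_emb, code_emb, util_emb, date_emb):
    """Hypothetical re-creation of the input block suggested by model.summary().

    demo_ids: (batch, 2) ints; code_ids: (batch, visits, codes) ints;
    util_ids, date_ids: (batch, visits) ints; *_emb: (vocab, dim) tables.
    Id 0 is assumed to mean padding for codes and dates.
    """
    # code_embed * lambda_1: zero out padded code slots, then sum codes per visit.
    code_mask = (code_ids != 0)[..., None]                    # (b, v, c, 1)
    visits = (code_emb[code_ids] * code_mask).sum(axis=2)     # (b, v, d)

    # lambda_2 (assumed): a visit is real when its date id is nonzero.
    visit_mask = (date_ids != 0).astype(float)                # (b, v)
    visits = visits + (util_emb[util_ids] + date_emb[date_ids]) * visit_mask[..., None]

    # demo_embedding + lambda: collapse age and sex into one patient-level token (sum assumed).
    demo = demo_emb[demo_ids].sum(axis=1, keepdims=True)      # (b, 1, d)

    # concat and Fill + concat: prepend the demo token, which is never masked.
    tokens = np.concatenate([demo, visits], axis=1)           # (b, 1 + v, d)
    mask = np.concatenate([np.ones((date_ids.shape[0], 1)), visit_mask], axis=1)
    return tokens, mask
```

After attention and the feed-forward block, the summary appears to slice out a single position and feed it to the `patient_embedding` Dense layer. The prepended demographic token therefore plays a role similar to a summary token for the whole patient.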
SYMBOL INDEX (38 symbols across 3 files)
FILE: attention/get_att.py
class DataGenerator (line 5) | class DataGenerator(tf.keras.utils.Sequence):
method __init__ (line 6) | def __init__(self, seqs, vocab_sizes, list_IDs, max_visit, max_code, b...
method __len__ (line 17) | def __len__(self):
method __getitem__ (line 21) | def __getitem__(self, index):
method on_epoch_end (line 28) | def on_epoch_end(self):
method __data_generation (line 34) | def __data_generation(self, list_IDs_temp):
method date_padding (line 63) | def date_padding(self, seq):
method code_padding (line 71) | def code_padding(self, seq):
function process_code (line 83) | def process_code(seq, vocab2int):
function process_util (line 102) | def process_util(seq, util2int):
function process_demo (line 118) | def process_demo(age_seq, sex_seq, vocab2int):
function get_cat (line 130) | def get_cat(seq,code2cat):
FILE: pretraining/DataGenerator.py
class DataGenerator (line 5) | class DataGenerator(tf.keras.utils.Sequence):
method __init__ (line 6) | def __init__(self, seqs, vocab_sizes, list_IDs, max_visit, max_code, b...
method __len__ (line 17) | def __len__(self):
method __getitem__ (line 21) | def __getitem__(self, index):
method on_epoch_end (line 28) | def on_epoch_end(self):
method __data_generation (line 34) | def __data_generation(self, list_IDs_temp):
method date_padding (line 67) | def date_padding(self, seq):
method code_padding (line 75) | def code_padding(self, seq):
method code_labelling (line 87) | def code_labelling(self, seq):
method cat_labelling (line 96) | def cat_labelling(self, seq):
function process_code (line 108) | def process_code(seq, PAD=True):
function process_util (line 123) | def process_util(seq):
function process_demo (line 139) | def process_demo(age_seq, sex_seq):
function get_cat (line 151) | def get_cat(seq, diag2cat, proc2cat, drug2cat):
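The `pretraining/DataGenerator.py` entries above list `date_padding`, `code_padding` and `code_labelling`, but their bodies are not part of this excerpt. The model inputs fix the target shapes: `code_feature` is `(None, 30, 10)`, matching `MAX_VISIT=30` and `MAX_CODE=10`. `code_label` is a `(None, 3357)` output trained with `BinaryCrossentropy`. The sketch below rests on three assumptions:

- Id 0 is padding.
- The most recent visits are kept when a patient has too many.
- The label is a multi-hot vector over every code the patient has.

`pad_codes` and `multi_hot` are hypothetical helpers, not the repository's methods, and the real truncation and labelling rules may differ.

```python
import numpy as np

MAX_VISIT, MAX_CODE = 30, 10  # values set in the notebook's params cell

def pad_codes(patient, max_visit=MAX_VISIT, max_code=MAX_CODE):
    """Hypothetical code_padding: visits -> (max_visit, max_code) int array, 0 = padding.

    Assumes the most recent visits are kept when there are more than max_visit,
    and the first max_code codes of each visit are kept.
    """
    out = np.zeros((max_visit, max_code), dtype=np.int64)
    for i, visit in enumerate(patient[-max_visit:]):
        codes = list(visit)[:max_code]
        out[i, :len(codes)] = codes
    return out

def multi_hot(patient, vocab_size):
    """Hypothetical code_labelling: 1.0 for every code id seen in any visit."""
    y = np.zeros(vocab_size, dtype=np.float32)
    for visit in patient:
        y[list(visit)] = 1.0
    y[0] = 0.0  # id 0 is assumed to be padding, never a target
    return y
```

A multi-hot target of this kind pairs naturally with a sigmoid output and binary cross-entropy, because every code is scored as an independent yes/no prediction.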
FILE: pretraining/cpt.py
function create_code_mask (line 6) | def create_code_mask(code_seq):
function create_visit_mask (line 10) | def create_visit_mask(seq):
function scaled_dot_product_attention (line 14) | def scaled_dot_product_attention(Q, K, V, Q_masks, K_masks):
class multihead_attention (line 33) | class multihead_attention(tf.keras.layers.Layer):
method __init__ (line 34) | def __init__(self, d_model, num_heads, name="multihead_attention"):
method call (line 47) | def call(self, queries, keys, values, query_masks, key_masks):
class ffn (line 71) | class ffn(tf.keras.layers.Layer):
method __init__ (line 72) | def __init__(self, d_model, ffn_dim, name="ffn"):
method call (line 80) | def call(self, inputs):
function cat_recall (line 87) | def cat_recall(y_true, y_pred):
function cat_loss_fun (line 94) | def cat_loss_fun(y_true, y_pred):
function model (line 101) | def model(
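`pretraining/cpt.py` defines `scaled_dot_product_attention(Q, K, V, Q_masks, K_masks)`. The summary shows `multihead_attention-0` taking the three `(None, 31, 100)` token tensors together with two `(None, 31)` masks. The function body is not part of this excerpt. The following NumPy code is therefore only a sketch of standard masked scaled dot-product attention, softmax(Q K^T / sqrt(d)) V, assuming masks hold 1 for real positions and 0 for padding. Padded keys get a large negative score before the softmax, and padded query rows are zeroed. `masked_attention` is a hypothetical name, and the repository may handle masks differently.

```python
import numpy as np

def masked_attention(Q, K, V, q_mask, k_mask):
    """Sketch of masked scaled dot-product attention (hypothetical helper).

    Q: (batch, Lq, d); K, V: (batch, Lk, d); q_mask: (batch, Lq); k_mask: (batch, Lk).
    Masks are assumed to be 1 for real positions and 0 for padding.
    """
    d = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d)        # (batch, Lq, Lk)
    # Padded keys get a large negative score, so softmax gives them ~0 weight.
    scores = np.where(k_mask[:, None, :] > 0, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Padded queries attend to nothing, so their output rows are zero.
    weights = weights * q_mask[:, :, None]
    return weights @ V, weights
```

With `q_mask = k_mask = [[1, 1, 1, 0]]`, the fourth column of the weights and the fourth output row are both zero. This is presumably the behaviour the `(None, 31)` masks in the summary are there to enforce.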
Condensed preview of the 18 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 1835,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n.DS_Store\n.AppleDouble\n.LSOverride\n\n# C exten"
},
{
"path": "README.md",
"chars": 2719,
"preview": "# Claim-PT: Pretrained transformer framework on pediatric claims data for population specific tasks\n\nThis repository con"
},
{
"path": "attention/attention.ipynb",
"chars": 44471,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"regulation-chemistry\",\n \"metadata\": {},\n \"source\": [\n \"##"
},
{
"path": "attention/get_att.py",
"chars": 6704,
"preview": "import numpy as np\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nclass DataGenerator(tf.keras.utils.Seque"
},
{
"path": "finetune/asthma/fine-tune.ipynb",
"chars": 103924,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"id\": \"wicked-finder\",\n \"metadata\": {},\n \"outp"
},
{
"path": "finetune/autoDiag/fine-tune.ipynb",
"chars": 127720,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"id\": \"wicked-finder\",\n \"metadata\": {},\n \"outp"
},
{
"path": "finetune/suicide/fine-tune.ipynb",
"chars": 53527,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"id\": \"wicked-finder\",\n \"metadata\": {},\n \"outp"
},
{
"path": "pretraining/DataGenerator.py",
"chars": 5624,
"preview": "import tensorflow as tf\nimport pandas as pd\nimport numpy as np\n\nclass DataGenerator(tf.keras.utils.Sequence):\n def __"
},
{
"path": "pretraining/cpt.py",
"chars": 8370,
"preview": "import numpy as np\nimport tensorflow as tf\n\nfrom tensorflow.keras import layers\n\ndef create_code_mask(code_seq):\n cod"
},
{
"path": "pretraining/plot.ipynb",
"chars": 46291,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"id\": \"boolean-adaptation\",\n \"metadata\": {},\n "
},
{
"path": "pretraining/train.py",
"chars": 2513,
"preview": "import numpy as np\nimport _pickle as pickle\nimport pandas as pd\n\nfrom cpt import *\nfrom DataGenerator import *\n\nprint(\"-"
}
]
// ... and 7 more files
About this extraction
This page contains the full source code of the 2g-XzenG/Claim-PT GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 18 files (38.7 MB), approximately 184.0k tokens, and a symbol index with 38 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.