Repository: 2g-XzenG/Claim-PT
Branch: main
Commit: 38348ccbe0c1
Files: 18
Total size: 38.7 MB

Directory structure:
gitextract_9ionexop/
├── .gitignore
├── README.md
├── attention/
│   ├── attention.ipynb
│   ├── attention_weights_finetune
│   ├── attention_weights_pretrain
│   ├── fake_data
│   ├── get_att.py
│   └── saveModel/
│       ├── saved_model.pb
│       └── variables/
│           ├── variables.data-00000-of-00002
│           ├── variables.data-00001-of-00002
│           └── variables.index
├── finetune/
│   ├── asthma/
│   │   └── fine-tune.ipynb
│   ├── autoDiag/
│   │   └── fine-tune.ipynb
│   └── suicide/
│       └── fine-tune.ipynb
└── pretraining/
    ├── DataGenerator.py
    ├── cpt.py
    ├── plot.ipynb
    └── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.DS_Store
.AppleDouble
.LSOverride

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

================================================
FILE: README.md
================================================
# Claim-PT: Pretrained transformer framework on pediatric claims data for population specific tasks

This repository contains the tensorflow implementation of the following paper:

Paper Name: Pretrained transformer framework on pediatric claims data for population specific tasks

Authors: Xianlong Zeng, Simon L. Linwood, Chang Liu

Abstract: The adoption of electronic health records (EHR) has become universal during the past decade, which has afforded in-depth data-based research.
By learning from the large amount of healthcare data, various data-driven models have been built to predict future events for different medical tasks, such as auto diagnosis and heart-attack prediction. Although EHR is abundant, the population that satisfies specific criteria for learning population-specific tasks is scarce, making it challenging to train data-hungry deep learning models. This study presents the Claim Pre-Training (Claim-PT) framework, a generic pre-training model that first trains on the entire pediatric claims dataset, followed by a discriminative fine-tuning on each population-specific task. The semantic meaning of medical events can be captured in the pre-training stage, and the effective knowledge transfer is completed through the task-aware fine-tuning stage. The fine-tuning process requires minimal parameter modification without changing the model architecture, which mitigates the data scarcity issue and helps train the deep learning model adequately on small patient cohorts. We conducted experiments on a real-world claims dataset with more than one million patient records. Experimental results on two downstream tasks demonstrated the effectiveness of our method: our general task-agnostic pre-training framework outperformed tailored task-specific models, achieving more than 10% higher in model performance as compared to baselines. In addition, our framework showed a great generalizability potential to transfer learned knowledge from one institution to another, paving the way for future healthcare model pre-training across institutions.

[Paper Link](https://www.nature.com/articles/s41598-022-07545-1): Zeng, Xianlong, Simon L. Linwood, and Chang Liu. "Pretrained transformer framework on pediatric claims data for population specific tasks." Scientific Reports 12, no. 1 (2022): 3651.

# Environment
Ubuntu 16.04, Python 3.7, TensorFlow 2.1

# Model Pretraining
Code is in the `pretraining` folder.

### List of hyper-parameters used in the pre-training stage
- MAX_VISIT=30
- MAX_CODE=10
- MAX_DEMO=2
- PATIENT_DIM=100
- BATCH_SIZE = 100
- TRAIN_RATIO = 0.8
- DATA_SIZE = len(age_seq)
- EPOCHS = 1000

# Downstream finetune
Code is in the `finetune` folder.

# Attention playground
Code is in the `attention` folder.

================================================
FILE: attention/attention.ipynb
================================================
{ "cells": [ { "cell_type": "markdown", "id": "regulation-chemistry", "metadata": {}, "source": [ "### Case Study: Attention Visualization\n", "To illustrate the benefit of applying attention mechanisms and to better understand how attention weights change from the pre-training stage to the fine-tuning stage, we analyze the learned attention weights at the two stages through a case study. " ] }, { "cell_type": "code", "execution_count": 1, "id": "european-allowance", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "wandb: WARNING W&B installed but not logged in. 
Run `wandb login` or set the WANDB_API_KEY env variable.\n" ] } ], "source": [ "import numpy as np\n", "import _pickle as pickle\n", "import pandas as pd\n", "import torch\n", "from bertviz import head_view\n", "from transformers import BertTokenizer, BertModel\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "moderate-translator", "metadata": {}, "outputs": [], "source": [ "fake_data = pickle.load(open(\"./fake_data\", \"rb\"))" ] }, { "cell_type": "code", "execution_count": 3, "id": "juvenile-tennis", "metadata": {}, "outputs": [], "source": [ "attention_weights_pretrain = pickle.load(open(\"./attention_weights_pretrain\", \"rb\"))\n", "attention_weights_finetune = pickle.load(open(\"./attention_weights_finetune\", \"rb\"))" ] }, { "cell_type": "code", "execution_count": 4, "id": "social-affairs", "metadata": {}, "outputs": [], "source": [ "def attention_scores(weights, k):\n", " weight = weights[1][k:k+2]\n", " sentence = [list(x)[1][:7] for x in fake_data[k]]\n", " \n", " attn = torch.tensor(weight[:, :len(sentence), :len(sentence)])\n", " attn = attn.repeat(1, 1, 1, 1)\n", " attn_new = []\n", " for _ in range(1):\n", " attn_new.append(attn)\n", " x = head_view(attn_new, sentence)\n", " return " ] }, { "cell_type": "markdown", "id": "floating-surge", "metadata": {}, "source": [ "# Select Patient Index" ] }, { "cell_type": "code", "execution_count": 5, "id": "specialized-maximum", "metadata": {}, "outputs": [], "source": [ "k=0" ] }, { "cell_type": "markdown", "id": "according-assets", "metadata": {}, "source": [ "# Pretrain" ] }, { "cell_type": "code", "execution_count": 6, "id": "living-gates", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " \n", "
\n", " \n", " Layer: \n", " \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": "/**\n * @fileoverview Transformer Visualization D3 javascript code.\n *\n *\n * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n *\n * Change log:\n *\n * 12/19/18 Jesse Vig Assorted cleanup. Changed orientation of attention matrices.\n * 12/29/20 Jesse Vig Significant refactor.\n * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n * 05/03/21 Jesse Vig Adjust height of visualization dynamically\n * 07/25/21 Jesse Vig Support layer filtering\n **/\n\nrequire.config({\n paths: {\n d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n }\n});\n\nrequirejs(['jquery', 'd3'], function ($, d3) {\n\n const params = {\"attention\": [{\"name\": null, \"attn\": [[[[0.16614454984664917, 0.1271992027759552, 0.11374139040708542, 0.15638020634651184, 0.20374387502670288, 0.11444535851478577, 0.11834541708230972], [0.16203011572360992, 0.1401018649339676, 0.13978800177574158, 0.1404246836900711, 0.13866281509399414, 0.13900551199913025, 0.1399870365858078], [0.14190000295639038, 0.14325059950351715, 0.14268504083156586, 0.14326979219913483, 0.14316092431545258, 0.14276131987571716, 0.14297233521938324], [0.1783653199672699, 0.13668516278266907, 0.13705065846443176, 0.13718663156032562, 0.1375586986541748, 0.1364724338054657, 0.13668116927146912], [0.17943505942821503, 0.13677170872688293, 0.13591942191123962, 0.13615824282169342, 0.13820093870162964, 0.13686884939670563, 0.13664573431015015], [0.1567888855934143, 0.1408628672361374, 0.14035286009311676, 0.1405779868364334, 0.14033928513526917, 0.14011918008327484, 0.14095890522003174], [0.1511572003364563, 0.1412300169467926, 0.14104074239730835, 0.14131779968738556, 
0.14250311255455017, 0.1413128525018692, 0.14143821597099304]], [[0.047577470541000366, 0.04034224525094032, 0.03173127770423889, 0.03325849026441574, 0.05665835365653038, 0.040445178747177124, 0.03266838937997818], [0.03983418643474579, 0.03432993218302727, 0.0343342199921608, 0.034387726336717606, 0.034017566591501236, 0.03415762633085251, 0.03437581658363342], [0.0348617359995842, 0.03448709473013878, 0.034443072974681854, 0.034463874995708466, 0.0343620628118515, 0.03469187021255493, 0.03444177284836769], [0.03609755262732506, 0.03446199744939804, 0.03440416231751442, 0.034431908279657364, 0.03436928614974022, 0.034668270498514175, 0.0344182550907135], [0.04275275021791458, 0.034237828105688095, 0.03420986980199814, 0.03424009680747986, 0.03407548740506172, 0.03425997495651245, 0.03429330885410309], [0.04144877940416336, 0.034321531653404236, 0.034290462732315063, 0.03432182967662811, 0.03395794332027435, 0.03400801867246628, 0.0343194305896759], [0.03561968356370926, 0.034495096653699875, 0.034405313432216644, 0.03443482518196106, 0.034406013786792755, 0.03469153866171837, 0.03441842645406723]]]], \"left_text\": [\"P_99284\", \"D_948.0\", \"D_944.2\", \"M_POLYE\", \"P_272\", \"P_99254\", \"P_270\"], \"right_text\": [\"P_99284\", \"D_948.0\", \"D_944.2\", \"M_POLYE\", \"P_272\", \"P_99254\", \"P_270\"]}], \"default_filter\": \"0\", \"root_div_id\": \"bertviz-973440206fb84c43a608ed88945dabed\", \"layer\": null, \"heads\": null, \"include_layers\": [0]}; // HACK: PYTHON_PARAMS is a template marker that is replaced by actual params.\n const TEXT_SIZE = 15;\n const BOXWIDTH = 110;\n const BOXHEIGHT = 22.5;\n const MATRIX_WIDTH = 115;\n const CHECKBOX_SIZE = 20;\n const TEXT_TOP = 30;\n\n console.log(\"d3 version\", d3.version)\n let headColors;\n try {\n headColors = d3.scaleOrdinal(d3.schemeCategory10);\n } catch (err) {\n console.log('Older d3 version')\n headColors = d3.scale.category10();\n }\n let config = {};\n initialize();\n renderVis();\n\n function initialize() {\n config.attention = params['attention'];\n config.filter = params['default_filter'];\n config.rootDivId = params['root_div_id'];\n config.nLayers = config.attention[config.filter]['attn'].length;\n config.nHeads = config.attention[config.filter]['attn'][0].length;\n config.layers = params['include_layers']\n\n if (params['heads']) {\n config.headVis = new Array(config.nHeads).fill(false);\n params['heads'].forEach(x => config.headVis[x] = true);\n } else {\n config.headVis = new Array(config.nHeads).fill(true);\n }\n config.initialTextLength = config.attention[config.filter].right_text.length;\n config.layer_seq = (params['layer'] == null ? 0 : config.layers.findIndex(layer => params['layer'] === layer));\n config.layer = config.layers[config.layer_seq]\n\n let layerEl = $(`#${config.rootDivId} #layer`);\n for (const layer of config.layers) {\n layerEl.append($(\"