[
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2016, mp2893\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of RETAIN nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "README.md",
    "content": "RETAIN\n=========================================\n\nRETAIN is an interpretable predictive model for healthcare applications. Given patient records, it can make predictions while explaining how each medical code (diagnosis codes, medication codes, or procedure codes) at each visit contributes to the prediction. The interpretation is possible due to the use of neural attention mechanism.\n\n[![RETAIN Interpretation Demo](http://mp2893.com/images/thumbnail.png)](https://youtu.be/co3lTOSgFlA?t=1m46s \"RETAIN Interpretation Demo - Click to Watch!\")\nUsing RETAIN, you can calculate how positively/negatively each medical code (diagnosis, medication, or procedure code) at different visits contributes to the final score. In this case, we are predicting whether the given patient will be diagnosed with Heart Failure (HF). You can see that the codes that are highly related to HF makes positive contributions. RETAIN also learns to pay more attention to new information than old information. You can see that Cardiac Dysrythmia (CD) makes a bigger contribution as it occurs in the more recent visit.\n\n#### Relevant Publications\n\nRETAIN implements an algorithm introduced in the following [paper](http://papers.nips.cc/paper/6321-retain-an-interpretable-predictive-model-for-healthcare-using-reverse-time-attention-mechanism):\n\n\tRETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism\n\tEdward Choi, Mohammad Taha Bahadori, Joshua A. Kulas, Andy Schuetz, Walter F. Stewart, Jimeng Sun,\n\tNIPS 2016, pp.3504-3512\n\n#### Notice\n\nThe RETAIN paper formulates the model as being able to make prediction at each timestep (e.g. try to predict what diagnoses the patient will receive at each visit), and treats sequence classification (e.g. Given a patient record, will he be diagnosed with heart failure in the future?) as a special case, since sequence classification makes the prediction at the last timestep only.\n\nThis code, however, is implemented to perform the sequence classification task. For example, you can use this code to predict whether the given patient is a heart failure patient or not. Or you can predict whether this patient will be readmitted in the future. The more general version of RETAIN will be released in the future.\n\t\n#### Running RETAIN\n\n**STEP 1: Installation**  \n\n1. Install [python](https://www.python.org/), [Theano](http://deeplearning.net/software/theano/index.html). We use Python 2.7, Theano 0.8. Theano can be easily installed in Ubuntu as suggested [here](http://deeplearning.net/software/theano/install_ubuntu.html#install-ubuntu)\n\n2. If you plan to use GPU computation, install [CUDA](https://developer.nvidia.com/cuda-downloads)\n\n3. Download/clone the RETAIN code  \n\n**STEP 2: Fast way to test RETAIN with MIMIC-III**  \nThis step describes how to train RETAIN, with minimum number of steps using MIMIC-III, to predict patients' mortality using their visit records.\n\n0. You will first need to request access for [MIMIC-III](https://mimic.physionet.org/gettingstarted/access/), a publicly avaiable electronic health records collected from ICU patients over 11 years. \n\n1. You can use \"process_mimic.py\" to process MIMIC-III dataset and generate a suitable training dataset for RETAIN. \nPlace the script to the same location where the MIMIC-III CSV files are located, and run the script.\nThe execution command is `python process_mimic.py ADMISSIONS.csv DIAGNOSES_ICD.csv PATIENTS.csv <output file>`.\n\n2. 
\n\n**STEP 3: How to prepare your own dataset**  \n\n1. RETAIN's training dataset needs to be a Python cPickled list of lists of lists. The outermost list corresponds to patients, the intermediate list to the visit sequence each patient made, and the innermost list to the medical codes (e.g. diagnosis codes, medication codes, procedure codes, etc.) that occurred within each visit.\nFirst, medical codes need to be converted to integers. Then a single visit can be seen as a list of integers, and a patient can be seen as a list of visits.\nFor example, [5,8,15] means the patient was assigned codes 5, 8, and 15 at a certain visit.\nIf a patient made two visits [1,2,3] and [4,5,6,7], they can be converted to a list of lists [[1,2,3], [4,5,6,7]].\nMultiple patients can be represented as [[[1,2,3], [4,5,6,7]], [[2,4], [8,3,1], [3]]], which means there are two patients, where the first patient made two visits and the second made three visits.\nThis list of lists of lists needs to be pickled using cPickle. We will refer to this file as the \"visit file\". (A sketch that builds these files is shown right after this list.)\n\n2. The total number of unique medical codes is required to run RETAIN.\nFor example, if the dataset uses 14,000 diagnosis codes and 11,000 procedure codes, the total number is 25,000.\n\n3. The label dataset (let us call this the \"label file\") needs to be a Python cPickled list. Each element corresponds to the true label of each patient. For example, 1 can denote a case patient and 0 a control patient. If there are two patients and only the first one is a case, we should have [1,0].\n\n4. The \"visit file\" and \"label file\" each need 3 sets: a training set, a validation set, and a test set.\nThe file extensions must be \".train\", \".valid\", and \".test\" respectively.  \nFor example, if you want to use a file named \"my_visit_sequences\" as the \"visit file\", then RETAIN will try to load \"my_visit_sequences.train\", \"my_visit_sequences.valid\", and \"my_visit_sequences.test\".  \nThe same applies to the \"label file\".\n\n5. You can use the time information regarding the visits as an additional source of information. Let us call this the \"time file\".\nNote that the time information could be anything: durations between consecutive visits, cumulative number of days since the first visit, etc.\nThe \"time file\" needs to be prepared as a Python cPickled list of lists. The outermost list corresponds to patients, and the innermost list to the time information of each visit.\nFor example, given a \"visit file\" [[[1,2,3], [4,5,6,7]], [[2,4], [8,3,1], [3]]], its corresponding \"time file\" could look like [[0, 15], [0, 45, 23]] if we are using the durations between consecutive visits. (Of course the numbers are made up, and I've set the duration of the first visit to zero.)\nUse the `--time_file <path to time file>` option to use a \"time file\".\nRemember that the \".train\", \".valid\", \".test\" rule also applies to the \"time file\".
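\n\nThe following is a minimal sketch of preparing the \"visit file\", \"label file\", and \"time file\" described above; \"my_visit_sequences\" comes from the example in item 4, while \"my_labels\" and \"my_times\" are made-up names for illustration.\n\n\timport cPickle as pickle\n\t\n\t# Two patients: the first made two visits, the second made three.\n\tvisits = [[[1,2,3], [4,5,6,7]], [[2,4], [8,3,1], [3]]]\n\tlabels = [1, 0]                 # the first patient is a case, the second a control\n\ttimes = [[0, 15], [0, 45, 23]]  # durations between consecutive visits (0 for the first visit)\n\t\n\t# RETAIN expects a \".train\", \".valid\", and \".test\" file for each of these.\n\t# With real data you would split the patients; here the toy data is simply dumped three times.\n\tfor ext in ['.train', '.valid', '.test']:\n\t\tpickle.dump(visits, open('my_visit_sequences' + ext, 'wb'), -1)\n\t\tpickle.dump(labels, open('my_labels' + ext, 'wb'), -1)\n\t\tpickle.dump(times, open('my_times' + ext, 'wb'), -1)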
\n\n**Additional: Using your own medical code representations**  \nRETAIN internally learns the vector representations of medical codes while training. These vectors are initialized with random values, of course.  \nYou can, however, also use your own medical code representations, if you have them. (They can be trained with Skip-gram-like algorithms. Refer to [Med2Vec](http://www.kdd.org/kdd2016/subtopic/view/multi-layer-representation-learning-for-medical-concepts) or [this](http://arxiv.org/abs/1602.03686) for further details.)\nIf you want to provide your own medical code representations, they must be given as a list of lists (basically a matrix) of N rows and M columns, where N is the number of unique codes in your \"visit file\" and M is the size of the code representations.\nSpecify the path to your code representation file using `--embed_file <path to embedding file>`.\nAdditionally, even if you use your own medical code representations, you can re-train (i.e. fine-tune) them as you train RETAIN.\nUse the `--embed_finetune` option to do this. If you do not provide your own medical code representations, RETAIN will use randomly initialized ones, which obviously require this fine-tuning process. Since fine-tuning is enabled by default, you do not need to worry about this.
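\n\nAs a minimal sketch of the expected format, the snippet below writes a compatible embedding file filled with random values (purely illustrative; in practice you would use vectors trained with Med2Vec or a similar method). The file name \"my_code_embeddings.pkl\" is made up.\n\n\timport cPickle as pickle\n\timport numpy as np\n\t\n\tn_codes = 942    # N: the number of unique codes in your \"visit file\"\n\temb_size = 128   # M: the size of each code representation\n\t\n\t# N x M matrix as a list of lists, one row per medical code\n\tWemb = np.random.uniform(-0.1, 0.1, (n_codes, emb_size)).tolist()\n\tpickle.dump(Wemb, open('my_code_embeddings.pkl', 'wb'), -1)\n\nYou would then pass it with `--embed_file my_code_embeddings.pkl` when running retain.py.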
\n\n**STEP 4: Running RETAIN**  \n\n1. The minimum input you need to run RETAIN is the \"visit file\", the number of unique medical codes in the \"visit file\", the \"label file\", and the output path. The output path is where the learned weights and the log will be saved.  \n`python retain.py <visit file> <# codes in the visit file> <label file> <output path>`  \n\n2. Specifying the `--verbose` option will print the training progress after every 10 mini-batches.\n\n3. You can specify the size of the embedding W_emb, the size of the hidden layer of the GRU that generates alpha, and the size of the hidden layer of the GRU that generates beta.\nThe respective options are `--embed_size <integer>`, `--alpha_hidden_dim_size <integer>`, and `--beta_hidden_dim_size <integer>`.\nFor example, `--alpha_hidden_dim_size 128` tells RETAIN to use a GRU with a 128-dimensional hidden layer for generating alpha.\n\n4. Dropout is applied in two places: 1) to the input embedding, and 2) to the context vector c_i. The respective keep probabilities can be adjusted using `--keep_prob_emb <value between 0.0 and 1.0>` and `--keep_prob_context <value between 0.0 and 1.0>`. Dropout values affect the performance, so it is recommended to tune them for your data.\n\n5. L2 regularization can be applied to W_emb, w_alpha, W_beta, and w_output.\n\n6. Additional options, such as the batch size and the number of epochs, can be specified. Detailed information can be accessed via `python retain.py --help`.\n\n7. My personal recommendation: use mild regularization (0.0001 ~ 0.001) on all four weights, and use moderate dropout on the context vector only. But this entirely depends on your data, so you should always tune the hyperparameters yourself.\n\n**STEP 5: Getting your results**  \n\nRETAIN checks the AUC of the validation set after each epoch, and if it is higher than all previous values, it saves the current model. The model file is generated by [numpy.savez_compressed](http://docs.scipy.org/doc/numpy-1.10.1/reference/generated/numpy.savez_compressed.html).\n\n**STEP 6: Testing your model**  \n\n1. Using the file \"test_retain.py\", you can calculate the contribution of each medical code at each visit. First you need a trained model that was saved by numpy.savez_compressed. Note that you need to know the configuration with which you trained RETAIN (e.g. use of `--time_file`, use of `--use_log_time`).\n\n2. Again, you need the \"visit file\" and \"label file\" prepared in the same way. This time, however, you do not need to follow the \".train\", \".valid\", \".test\" rule. The testing script will try to load the file name exactly as given.\n\n3. You also need the mapping between the actual string medical codes and their integer codes (e.g. \"Hypertension\" is mapped to 24).\nThis file (let's call it the \"mapping file\") needs to be a Python cPickled dictionary where the keys are the string medical codes and the values are the corresponding integers.\n(For example, the mapping file generated by process_mimic.py is the \".types\" file.)\nThis file is required to print the contributions of each medical code in a user-friendly format.\n\n4. For the additional options such as `--time_file` or `--use_log_time`, you should use exactly the same configuration with which you trained the model. For more detailed information, use the `--help` option.\n\n5. The minimum input to run the testing script is the \"model file\", \"visit file\", \"label file\", \"mapping file\", and \"output file\". The \"output file\" is where the contributions will be stored.\n`python test_retain.py <model file> <visit file> <label file> <mapping file> <output file>`\n"
  },
  {
    "path": "process_mimic.py",
    "content": "# This script processes MIMIC-III dataset and builds longitudinal diagnosis records for patients with at least two visits.\n# The output data are cPickled, and suitable for training Doctor AI or RETAIN\n# Written by Edward Choi (mp2893@gatech.edu)\n# Usage: Put this script to the foler where MIMIC-III CSV files are located. Then execute the below command.\n# python process_mimic.py ADMISSIONS.csv DIAGNOSES_ICD.csv PATIENTS.csv <output file> \n\n# Output files\n# <output file>.pids: List of unique Patient IDs. Used for intermediate processing\n# <output file>.morts: List of binary values indicating the mortality of each patient\n# <output file>.dates: List of List of Python datetime objects. The outer List is for each patient. The inner List is for each visit made by each patient\n# <output file>.seqs: List of List of List of integer diagnosis codes. The outer List is for each patient. The middle List contains visits made by each patient. The inner List contains the integer diagnosis codes that occurred in each visit\n# <output file>.types: Python dictionary that maps string diagnosis codes to integer diagnosis codes.\n\nimport sys\nimport cPickle as pickle\nfrom datetime import datetime\n\ndef convert_to_icd9(dxStr):\n\tif dxStr.startswith('E'):\n\t\tif len(dxStr) > 4: return dxStr[:4] + '.' + dxStr[4:]\n\t\telse: return dxStr\n\telse:\n\t\tif len(dxStr) > 3: return dxStr[:3] + '.' + dxStr[3:]\n\t\telse: return dxStr\n\t\ndef convert_to_3digit_icd9(dxStr):\n\tif dxStr.startswith('E'):\n\t\tif len(dxStr) > 4: return dxStr[:4]\n\t\telse: return dxStr\n\telse:\n\t\tif len(dxStr) > 3: return dxStr[:3]\n\t\telse: return dxStr\n\nif __name__ == '__main__':\n\tadmissionFile = sys.argv[1]\n\tdiagnosisFile = sys.argv[2]\n\tpatientsFile = sys.argv[3]\n\toutFile = sys.argv[4]\n\n\tprint 'Collecting mortality information'\n\tpidDodMap = {}\n\tinfd = open(patientsFile, 'r')\n\tinfd.readline()\n\tfor line in infd:\n\t\ttokens = line.strip().split(',')\n\t\tpid = int(tokens[1])\n\t\tdod_hosp = tokens[5]\n\t\tif len(dod_hosp) > 0:\n\t\t\tpidDodMap[pid] = 1\n\t\telse:\n\t\t\tpidDodMap[pid] = 0\n\tinfd.close()\n\n\tprint 'Building pid-admission mapping, admission-date mapping'\n\tpidAdmMap = {}\n\tadmDateMap = {}\n\tinfd = open(admissionFile, 'r')\n\tinfd.readline()\n\tfor line in infd:\n\t\ttokens = line.strip().split(',')\n\t\tpid = int(tokens[1])\n\t\tadmId = int(tokens[2])\n\t\tadmTime = datetime.strptime(tokens[3], '%Y-%m-%d %H:%M:%S')\n\t\tadmDateMap[admId] = admTime\n\t\tif pid in pidAdmMap: pidAdmMap[pid].append(admId)\n\t\telse: pidAdmMap[pid] = [admId]\n\tinfd.close()\n\n\tprint 'Building admission-dxList mapping'\n\tadmDxMap = {}\n\tadmDxMap_3digit = {}\n\tinfd = open(diagnosisFile, 'r')\n\tinfd.readline()\n\tfor line in infd:\n\t\ttokens = line.strip().split(',')\n\t\tadmId = int(tokens[2])\n\t\tdxStr = 'D_' + convert_to_icd9(tokens[4][1:-1]) ############## Uncomment this line and comment the line below, if you want to use the entire ICD9 digits.\n\t\tdxStr_3digit = 'D_' + convert_to_3digit_icd9(tokens[4][1:-1])\n\n\t\tif admId in admDxMap: \n\t\t\tadmDxMap[admId].append(dxStr)\n\t\telse: \n\t\t\tadmDxMap[admId] = [dxStr]\n\n\t\tif admId in admDxMap_3digit: \n\t\t\tadmDxMap_3digit[admId].append(dxStr_3digit)\n\t\telse: \n\t\t\tadmDxMap_3digit[admId] = [dxStr_3digit]\n\tinfd.close()\n\n\tprint 'Building pid-sortedVisits mapping'\n\tpidSeqMap = {}\n\tpidSeqMap_3digit = {}\n\tfor pid, admIdList in pidAdmMap.iteritems():\n\t\tif len(admIdList) < 2: continue\n\n\t\tsortedList = 
sorted([(admDateMap[admId], admDxMap[admId]) for admId in admIdList])\n\t\tpidSeqMap[pid] = sortedList\n\n\t\tsortedList_3digit = sorted([(admDateMap[admId], admDxMap_3digit[admId]) for admId in admIdList])\n\t\tpidSeqMap_3digit[pid] = sortedList_3digit\n\t\n\tprint 'Building pids, dates, mortality_labels, strSeqs'\n\tpids = []\n\tdates = []\n\tseqs = []\n\tmorts = []\n\tfor pid, visits in pidSeqMap.iteritems():\n\t\tpids.append(pid)\n\t\tmorts.append(pidDodMap[pid])\n\t\tseq = []\n\t\tdate = []\n\t\tfor visit in visits:\n\t\t\tdate.append(visit[0])\n\t\t\tseq.append(visit[1])\n\t\tdates.append(date)\n\t\tseqs.append(seq)\n\t\n\tprint 'Building pids, dates, strSeqs for 3digit ICD9 code'\n\tseqs_3digit = []\n\tfor pid, visits in pidSeqMap_3digit.iteritems():\n\t\tseq = []\n\t\tfor visit in visits:\n\t\t\tseq.append(visit[1])\n\t\tseqs_3digit.append(seq)\n\t\n\tprint 'Converting strSeqs to intSeqs, and making types'\n\ttypes = {}\n\tnewSeqs = []\n\tfor patient in seqs:\n\t\tnewPatient = []\n\t\tfor visit in patient:\n\t\t\tnewVisit = []\n\t\t\tfor code in visit:\n\t\t\t\tif code in types:\n\t\t\t\t\tnewVisit.append(types[code])\n\t\t\t\telse:\n\t\t\t\t\ttypes[code] = len(types)\n\t\t\t\t\tnewVisit.append(types[code])\n\t\t\tnewPatient.append(newVisit)\n\t\tnewSeqs.append(newPatient)\n\t\n\tprint 'Converting strSeqs to intSeqs, and making types for 3digit ICD9 code'\n\ttypes_3digit = {}\n\tnewSeqs_3digit = []\n\tfor patient in seqs_3digit:\n\t\tnewPatient = []\n\t\tfor visit in patient:\n\t\t\tnewVisit = []\n\t\t\tfor code in set(visit):\n\t\t\t\tif code in types_3digit:\n\t\t\t\t\tnewVisit.append(types_3digit[code])\n\t\t\t\telse:\n\t\t\t\t\ttypes_3digit[code] = len(types_3digit)\n\t\t\t\t\tnewVisit.append(types_3digit[code])\n\t\t\tnewPatient.append(newVisit)\n\t\tnewSeqs_3digit.append(newPatient)\n\n\tpickle.dump(pids, open(outFile+'.pids', 'wb'), -1)\n\tpickle.dump(dates, open(outFile+'.dates', 'wb'), -1)\n\tpickle.dump(morts, open(outFile+'.morts', 'wb'), -1)\n\tpickle.dump(newSeqs, open(outFile+'.seqs', 'wb'), -1)\n\tpickle.dump(types, open(outFile+'.types', 'wb'), -1)\n\tpickle.dump(newSeqs_3digit, open(outFile+'.3digitICD9.seqs', 'wb'), -1)\n\tpickle.dump(types_3digit, open(outFile+'.3digitICD9.types', 'wb'), -1)\n"
  },
  {
    "path": "retain.py",
    "content": "#################################################################\n# Code written by Edward Choi (mp2893@gatech.edu)\n# For bug report, please contact author using the email address\n#################################################################\n\nimport sys, random\nimport numpy as np\nimport cPickle as pickle\nfrom collections import OrderedDict\nimport argparse\n\nimport theano\nimport theano.tensor as T\nfrom theano import config\nfrom theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams\n\nfrom sklearn.metrics import roc_auc_score\n\n_TEST_RATIO = 0.2\n_VALIDATION_RATIO = 0.1\n\ndef unzip(zipped):\n\tnew_params = OrderedDict()\n\tfor key, value in zipped.iteritems():\n\t\tnew_params[key] = value.get_value()\n\treturn new_params\n\ndef numpy_floatX(data):\n\treturn np.asarray(data, dtype=config.floatX)\n\ndef get_random_weight(dim1, dim2, left=-0.1, right=0.1):\n\treturn np.random.uniform(left, right, (dim1, dim2)).astype(config.floatX)\n\ndef load_embedding(infile):\n\tWemb = np.array(pickle.load(open(infile, 'rb'))).astype(config.floatX)\n\treturn Wemb\n\ndef init_params(options):\n\tparams = OrderedDict()\n\ttimeFile = options['timeFile']\n\tembFile = options['embFile']\n\tembDimSize = options['embDimSize']\n\tinputDimSize = options['inputDimSize']\n\talphaHiddenDimSize= options['alphaHiddenDimSize']\n\tbetaHiddenDimSize= options['betaHiddenDimSize']\n\tnumClass = options['numClass']\n\n\tif len(embFile) > 0: \n\t\tprint 'using external code embedding'\n\t\tparams['W_emb'] = load_embedding(embFile)\n\t\tembDimSize = params['W_emb'].shape[1]\n\telse: \n\t\tprint 'using randomly initialized code embedding'\n\t\tparams['W_emb'] = get_random_weight(inputDimSize, embDimSize)\n\n\tgruInputDimSize = embDimSize\n\tif len(timeFile) > 0: gruInputDimSize = embDimSize + 1\n\n\tparams['W_gru_a'] = get_random_weight(gruInputDimSize, 3*alphaHiddenDimSize)\n\tparams['U_gru_a'] = get_random_weight(alphaHiddenDimSize, 3*alphaHiddenDimSize)\n\tparams['b_gru_a'] = np.zeros(3 * alphaHiddenDimSize).astype(config.floatX)\n\n\tparams['W_gru_b'] = get_random_weight(gruInputDimSize, 3*betaHiddenDimSize)\n\tparams['U_gru_b'] = get_random_weight(betaHiddenDimSize, 3*betaHiddenDimSize)\n\tparams['b_gru_b'] = np.zeros(3 * betaHiddenDimSize).astype(config.floatX)\n\n\tparams['w_alpha'] = get_random_weight(alphaHiddenDimSize, 1)\n\tparams['b_alpha'] = np.zeros(1).astype(config.floatX)\n\tparams['W_beta'] = get_random_weight(betaHiddenDimSize, embDimSize)\n\tparams['b_beta'] = np.zeros(embDimSize).astype(config.floatX)\n\tparams['w_output'] = get_random_weight(embDimSize, numClass)\n\tparams['b_output'] = np.zeros(numClass).astype(config.floatX)\n\n\treturn params\n\ndef load_params(options):\n\treturn np.load(options['modelFile'])\n\ndef init_tparams(params, options):\n\ttparams = OrderedDict()\n\tfor key, value in params.iteritems():\n\t\tif not options['embFineTune'] and key == 'W_emb': continue\n\t\ttparams[key] = theano.shared(value, name=key)\n\treturn tparams\n\ndef dropout_layer(state_before, use_noise, trng, keep_prob=0.5):\n\tproj = T.switch(\n                use_noise,\n                state_before * trng.binomial(state_before.shape, p=keep_prob, n=1, dtype=state_before.dtype) / keep_prob,\n                state_before)\n\treturn proj\n\ndef _slice(_x, n, dim):\n\tif _x.ndim == 3:\n\t\treturn _x[:, :, n*dim:(n+1)*dim]\n\treturn _x[:, n*dim:(n+1)*dim]\n\ndef gru_layer(tparams, emb, name, hiddenDimSize):\n\ttimesteps = emb.shape[0]\n\tif emb.ndim == 3: n_samples = 
emb.shape[1]\n\telse: n_samples = 1\n\n\tdef stepFn(wx, h, U_gru):\n\t\tuh = T.dot(h, U_gru)\n\t\tr = T.nnet.sigmoid(_slice(wx, 0, hiddenDimSize) + _slice(uh, 0, hiddenDimSize))\n\t\tz = T.nnet.sigmoid(_slice(wx, 1, hiddenDimSize) + _slice(uh, 1, hiddenDimSize))\n\t\th_tilde = T.tanh(_slice(wx, 2, hiddenDimSize) + r * _slice(uh, 2, hiddenDimSize))\n\t\th_new = z * h + ((1. - z) * h_tilde)\n\t\treturn h_new\n\n\tWx = T.dot(emb, tparams['W_gru_'+name]) + tparams['b_gru_'+name]\n\tresults, updates = theano.scan(fn=stepFn, sequences=[Wx], outputs_info=T.alloc(numpy_floatX(0.0), n_samples, hiddenDimSize), non_sequences=[tparams['U_gru_'+name]], name='gru_layer', n_steps=timesteps)\n\n\treturn results\n\ndef build_model(tparams, options, W_emb=None):\n\tkeep_prob_emb = options['keepProbEmb']\n\tkeep_prob_context = options['keepProbContext']\n\talphaHiddenDimSize = options['alphaHiddenDimSize']\n\tbetaHiddenDimSize = options['betaHiddenDimSize']\n\n\ttrng = RandomStreams(1234)\n\tuse_noise = theano.shared(numpy_floatX(0.))\n\tuseTime = options['useTime']\n\n\tx = T.tensor3('x', dtype=config.floatX)\n\tt = T.matrix('t', dtype=config.floatX)\n\ty = T.vector('y', dtype=config.floatX)\n\tlengths = T.ivector('lengths')\n\n\tn_timesteps = x.shape[0]\n\tn_samples = x.shape[1]\n\n\tif options['embFineTune']: emb = T.dot(x, tparams['W_emb'])\n\telse: emb = T.dot(x, W_emb)\n\n\tif keep_prob_emb < 1.0: emb = dropout_layer(emb, use_noise, trng, keep_prob_emb)\n\n\tif useTime: temb = T.concatenate([emb, t.reshape([n_timesteps,n_samples,1])], axis=2) # Adding the time element to the embedding\n\telse: temb = emb\n\n\tdef attentionStep(att_timesteps):\n\t\t# Feed the first att_timesteps visits in reverse time order to the two GRUs\n\t\treverse_emb_t = temb[:att_timesteps][::-1]\n\t\treverse_h_a = gru_layer(tparams, reverse_emb_t, 'a', alphaHiddenDimSize)[::-1] * 0.5\n\t\treverse_h_b = gru_layer(tparams, reverse_emb_t, 'b', betaHiddenDimSize)[::-1] * 0.5\n\n\t\t# alpha: scalar attention over visits; beta: per-dimension weights over the embedding\n\t\tpreAlpha = T.dot(reverse_h_a, tparams['w_alpha']) + tparams['b_alpha']\n\t\tpreAlpha = preAlpha.reshape((preAlpha.shape[0], preAlpha.shape[1]))\n\t\talpha = (T.nnet.softmax(preAlpha.T)).T\n\n\t\tbeta = T.tanh(T.dot(reverse_h_b, tparams['W_beta']) + tparams['b_beta'])\n\t\t# Context vector: attention-weighted sum of the gated visit embeddings\n\t\tc_t = (alpha[:,:,None] * beta * emb[:att_timesteps]).sum(axis=0)\n\t\treturn c_t\n\n\tcounts = T.arange(n_timesteps) + 1\n\tc_t, updates = theano.scan(fn=attentionStep, sequences=[counts], outputs_info=None, name='attention_layer', n_steps=n_timesteps)\n\tif keep_prob_context < 1.0: c_t = dropout_layer(c_t, use_noise, trng, keep_prob_context)\n\n\tpreY = T.nnet.sigmoid(T.dot(c_t, tparams['w_output']) + tparams['b_output'])\n\tpreY = preY.reshape((preY.shape[0], preY.shape[1]))\n\tindexRow = T.arange(n_samples)\n\ty_hat = preY.T[indexRow, lengths - 1]\n\n\tlogEps = options['logEps']\n\tcross_entropy = -(y * T.log(y_hat + logEps) + (1. - y) * T.log(1. 
- y_hat + logEps))\n\tcost_noreg = T.mean(cross_entropy)\n\n\tcost = cost_noreg + options['L2_output'] * (tparams['w_output']**2).sum() + options['L2_alpha'] * (tparams['w_alpha']**2).sum() + options['L2_beta'] * (tparams['W_beta']**2).sum()\n\t\n\tif options['embFineTune']: cost += options['L2_emb'] * (tparams['W_emb']**2).sum()\n\n\tif useTime: return use_noise, x, y, t, lengths, cost_noreg, cost, y_hat\n\telse: return use_noise, x, y, lengths, cost_noreg, cost, y_hat\n\ndef adadelta(tparams, grads, x, y, lengths, cost, options, t=None):\n\tzipped_grads = [theano.shared(p.get_value() * numpy_floatX(0.), name='%s_grad' % k) for k, p in tparams.iteritems()]\n\trunning_up2 = [theano.shared(p.get_value() * numpy_floatX(0.), name='%s_rup2' % k) for k, p in tparams.iteritems()]\n\trunning_grads2 = [theano.shared(p.get_value() * numpy_floatX(0.), name='%s_rgrad2' % k) for k, p in tparams.iteritems()]\n\n\tzgup = [(zg, g) for zg, g in zip(zipped_grads, grads)]\n\trg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2)) for rg2, g in zip(running_grads2, grads)]\n\n\tif len(options['timeFile']) > 0:\n\t\tf_grad_shared = theano.function([x, y, t, lengths], cost, updates=zgup + rg2up, name='adadelta_f_grad_shared')\n\telse:\n\t\tf_grad_shared = theano.function([x, y, lengths], cost, updates=zgup + rg2up, name='adadelta_f_grad_shared')\n\n\tupdir = [-T.sqrt(ru2 + 1e-6) / T.sqrt(rg2 + 1e-6) * zg for zg, ru2, rg2 in zip(zipped_grads, running_up2, running_grads2)]\n\tru2up = [(ru2, 0.95 * ru2 + 0.05 * (ud ** 2)) for ru2, ud in zip(running_up2, updir)]\n\tparam_up = [(p, p + ud) for p, ud in zip(tparams.values(), updir)]\n\n\tf_update = theano.function([], [], updates=ru2up + param_up, on_unused_input='ignore', name='adadelta_f_update')\n\n\treturn f_grad_shared, f_update\n\ndef adam(cost, tparams, lr=0.0002, b1=0.1, b2=0.001, e=1e-8):\n\tupdates = []\n\tgrads = T.grad(cost, wrt=tparams.values())\n\ti = theano.shared(numpy_floatX(0.))\n\ti_t = i + 1.\n\tfix1 = 1. - (1. - b1)**i_t\n\tfix2 = 1. - (1. - b2)**i_t\n\tlr_t = lr * (T.sqrt(fix2) / fix1)\n\tfor p, g in zip(tparams.values(), grads):\n\t\tm = theano.shared(p.get_value() * 0.)\n\t\tv = theano.shared(p.get_value() * 0.)\n\t\tm_t = (b1 * g) + ((1. - b1) * m)\n\t\tv_t = (b2 * T.sqr(g)) + ((1. 
- b2) * v)\n\t\tg_t = m_t / (T.sqrt(v_t) + e)\n\t\tp_t = p - (lr_t * g_t)\n\t\tupdates.append((m, m_t))\n\t\tupdates.append((v, v_t))\n\t\tupdates.append((p, p_t))\n\tupdates.append((i, i_t))\n\treturn updates\n\ndef padMatrixWithTime(seqs, times, options):\n\tlengths = np.array([len(seq) for seq in seqs]).astype('int32')\n\tn_samples = len(seqs)\n\tmaxlen = np.max(lengths)\n\n\tx = np.zeros((maxlen, n_samples, options['inputDimSize'])).astype(config.floatX)\n\tt = np.zeros((maxlen, n_samples)).astype(config.floatX)\n\tfor idx, (seq,time) in enumerate(zip(seqs,times)):\n\t\tfor xvec, subseq in zip(x[:,idx,:], seq):\n\t\t\txvec[subseq] = 1.\n\t\tt[:lengths[idx], idx] = time\n\n\tif options['useLogTime']: t = np.log(t + 1.)\n\n\treturn x, t, lengths\n\ndef padMatrixWithoutTime(seqs, options):\n\tlengths = np.array([len(seq) for seq in seqs]).astype('int32')\n\tn_samples = len(seqs)\n\tmaxlen = np.max(lengths)\n\n\tx = np.zeros((maxlen, n_samples, options['inputDimSize'])).astype(config.floatX)\n\tfor idx, seq in enumerate(seqs):\n\t\tfor xvec, subseq in zip(x[:,idx,:], seq):\n\t\t\txvec[subseq] = 1.\n\n\treturn x, lengths\n\ndef load_data_simple(seqFile, labelFile, timeFile=''):\n\tsequences = np.array(pickle.load(open(seqFile, 'rb')))\n\tlabels = np.array(pickle.load(open(labelFile, 'rb')))\n\tif len(timeFile) > 0:\n\t\ttimes = np.array(pickle.load(open(timeFile, 'rb')))\n\n\tdataSize = len(labels)\n\tnp.random.seed(0)\n\tind = np.random.permutation(dataSize)\n\tnTest = int(_TEST_RATIO * dataSize)\n\tnValid = int(_VALIDATION_RATIO * dataSize)\n\n\ttest_indices = ind[:nTest]\n\tvalid_indices = ind[nTest:nTest+nValid]\n\ttrain_indices = ind[nTest+nValid:]\n\n\ttrain_set_x = sequences[train_indices]\n\ttrain_set_y = labels[train_indices]\n\ttest_set_x = sequences[test_indices]\n\ttest_set_y = labels[test_indices]\n\tvalid_set_x = sequences[valid_indices]\n\tvalid_set_y = labels[valid_indices]\n\ttrain_set_t = None\n\ttest_set_t = None\n\tvalid_set_t = None\n\n\tif len(timeFile) > 0:\n\t\ttrain_set_t = times[train_indices]\n\t\ttest_set_t = times[test_indices]\n\t\tvalid_set_t = times[valid_indices]\n\n\tdef len_argsort(seq):\n\t\treturn sorted(range(len(seq)), key=lambda x: len(seq[x]))\n\n\ttrain_sorted_index = len_argsort(train_set_x)\n\ttrain_set_x = [train_set_x[i] for i in train_sorted_index]\n\ttrain_set_y = [train_set_y[i] for i in train_sorted_index]\n\n\tvalid_sorted_index = len_argsort(valid_set_x)\n\tvalid_set_x = [valid_set_x[i] for i in valid_sorted_index]\n\tvalid_set_y = [valid_set_y[i] for i in valid_sorted_index]\n\n\ttest_sorted_index = len_argsort(test_set_x)\n\ttest_set_x = [test_set_x[i] for i in test_sorted_index]\n\ttest_set_y = [test_set_y[i] for i in test_sorted_index]\n\n\tif len(timeFile) > 0:\n\t\ttrain_set_t = [train_set_t[i] for i in train_sorted_index]\n\t\tvalid_set_t = [valid_set_t[i] for i in valid_sorted_index]\n\t\ttest_set_t = [test_set_t[i] for i in test_sorted_index]\n\n\ttrain_set = (train_set_x, train_set_y, train_set_t)\n\tvalid_set = (valid_set_x, valid_set_y, valid_set_t)\n\ttest_set = (test_set_x, test_set_y, test_set_t)\n\n\treturn train_set, valid_set, test_set\n\n\ndef load_data(seqFile, labelFile, timeFile):\n\ttrain_set_x = pickle.load(open(seqFile+'.train', 'rb'))\n\tvalid_set_x = pickle.load(open(seqFile+'.valid', 'rb'))\n\ttest_set_x = pickle.load(open(seqFile+'.test', 'rb'))\n\ttrain_set_y = pickle.load(open(labelFile+'.train', 'rb'))\n\tvalid_set_y = pickle.load(open(labelFile+'.valid', 'rb'))\n\ttest_set_y = 
pickle.load(open(labelFile+'.test', 'rb'))\n\ttrain_set_t = None\n\tvalid_set_t = None\n\ttest_set_t = None\n\n\tif len(timeFile) > 0:\n\t\ttrain_set_t = pickle.load(open(timeFile+'.train', 'rb'))\n\t\tvalid_set_t = pickle.load(open(timeFile+'.valid', 'rb'))\n\t\ttest_set_t = pickle.load(open(timeFile+'.test', 'rb'))\n\n\tdef len_argsort(seq):\n\t\treturn sorted(range(len(seq)), key=lambda x: len(seq[x]))\n\n\ttrain_sorted_index = len_argsort(train_set_x)\n\ttrain_set_x = [train_set_x[i] for i in train_sorted_index]\n\ttrain_set_y = [train_set_y[i] for i in train_sorted_index]\n\n\tvalid_sorted_index = len_argsort(valid_set_x)\n\tvalid_set_x = [valid_set_x[i] for i in valid_sorted_index]\n\tvalid_set_y = [valid_set_y[i] for i in valid_sorted_index]\n\n\ttest_sorted_index = len_argsort(test_set_x)\n\ttest_set_x = [test_set_x[i] for i in test_sorted_index]\n\ttest_set_y = [test_set_y[i] for i in test_sorted_index]\n\n\tif len(timeFile) > 0:\n\t\ttrain_set_t = [train_set_t[i] for i in train_sorted_index]\n\t\tvalid_set_t = [valid_set_t[i] for i in valid_sorted_index]\n\t\ttest_set_t = [test_set_t[i] for i in test_sorted_index]\n\n\ttrain_set = (train_set_x, train_set_y, train_set_t)\n\tvalid_set = (valid_set_x, valid_set_y, valid_set_t)\n\ttest_set = (test_set_x, test_set_y, test_set_t)\n\n\treturn train_set, valid_set, test_set\n\ndef calculate_auc(test_model, dataset, options):\n\tbatchSize = options['batchSize']\n\tuseTime = options['useTime']\n\t\n\tn_batches = int(np.ceil(float(len(dataset[0])) / float(batchSize)))\n\tscoreVec = []\n\tfor index in xrange(n_batches):\n\t\tbatchX = dataset[0][index*batchSize:(index+1)*batchSize]\n\t\tif useTime:\n\t\t\tbatchT = dataset[2][index*batchSize:(index+1)*batchSize]\n\t\t\tx, t, lengths = padMatrixWithTime(batchX, batchT, options)\n\t\t\tscores = test_model(x, t, lengths)\n\t\telse:\n\t\t\tx, lengths = padMatrixWithoutTime(batchX, options)\n\t\t\tscores = test_model(x, lengths)\n\t\tscoreVec.extend(list(scores))\n\tlabels = dataset[1]\n\tauc = roc_auc_score(list(labels), list(scoreVec))\n\treturn auc\n\ndef calculate_cost(test_model, dataset, options):\n\tbatchSize = options['batchSize']\n\tuseTime = options['useTime']\n\t\n\tcostSum = 0.0\n\tdataCount = 0\n\t\n\tn_batches = int(np.ceil(float(len(dataset[0])) / float(batchSize)))\n\tfor index in xrange(n_batches):\n\t\tbatchX = dataset[0][index*batchSize:(index+1)*batchSize]\n\t\tif useTime:\n\t\t\tbatchT = dataset[2][index*batchSize:(index+1)*batchSize]\n\t\t\tx, t, lengths = padMatrixWithTime(batchX, batchT, options)\n\t\t\ty = np.array(dataset[1][index*batchSize:(index+1)*batchSize]).astype(config.floatX)\n\t\t\tscores = test_model(x, y, t, lengths)\n\t\telse:\n\t\t\tx, lengths = padMatrixWithoutTime(batchX, options)\n\t\t\ty = np.array(dataset[1][index*batchSize:(index+1)*batchSize]).astype(config.floatX)\n\t\t\tscores = test_model(x, y, lengths)\n\t\tcostSum += scores * len(batchX)\n\t\tdataCount += len(batchX)\n\treturn costSum / dataCount\n\ndef print2file(buf, outFile):\n\toutfd = open(outFile, 'a')\n\toutfd.write(buf + '\\n')\n\toutfd.close()\n\ndef 
train_RETAIN(\n\tseqFile='seqFile.txt',\n\tinputDimSize=20000,\n\tlabelFile='labelFile.txt',\n\tnumClass=1,\n\toutFile='outFile.txt',\n\ttimeFile='',\n\tmodelFile='',\n\tuseLogTime=True,\n\tembFile='',\n\tembDimSize=128,\n\tembFineTune=True,\n\talphaHiddenDimSize=128,\n\tbetaHiddenDimSize=128,\n\tbatchSize=100,\n\tmax_epochs=10,\n\tL2_output=0.001,\n\tL2_emb=0.001,\n\tL2_alpha=0.001,\n\tL2_beta=0.001,\n\tkeepProbEmb=0.5,\n\tkeepProbContext=0.5,\n\tlogEps=1e-8,\n\tsolver='adadelta',\n\tsimpleLoad=False,\n\tverbose=False\n):\n\toptions = locals().copy()\n\n\tif len(timeFile) > 0: useTime = True\n\telse: useTime = False\n\toptions['useTime'] = useTime\n\t\n\tprint 'Initializing the parameters ... ',\n\tparams = init_params(options)\n\tif len(modelFile) > 0: params = load_params(options) # re-load existing parameters only when a model file is given\n\ttparams = init_tparams(params, options)\n\n\tprint 'Building the model ... ',\n\tif useTime and embFineTune:\n\t\tprint 'using time information, fine-tuning code representations'\n\t\tuse_noise, x, y, t, lengths, cost_noreg, cost, y_hat =  build_model(tparams, options)\n\t\tif solver=='adadelta':\n\t\t\tgrads = T.grad(cost, wrt=tparams.values())\n\t\t\tf_grad_shared, f_update = adadelta(tparams, grads, x, y, lengths, cost, options, t)\n\t\telif solver=='adam':\n\t\t\tupdates = adam(cost, tparams)\n\t\t\tupdate_model = theano.function(inputs=[x, y, t, lengths], outputs=cost, updates=updates, name='update_model')\n\t\tget_prediction = theano.function(inputs=[x, t, lengths], outputs=y_hat, name='get_prediction')\n\t\tget_cost = theano.function(inputs=[x, y, t, lengths], outputs=cost_noreg, name='get_cost')\n\telif useTime and not embFineTune:\n\t\tprint 'using time information, not fine-tuning code representations'\n\t\tW_emb = theano.shared(params['W_emb'], name='W_emb')\n\t\tuse_noise, x, y, t, lengths, cost_noreg, cost, y_hat =  build_model(tparams, options, W_emb)\n\t\tif solver=='adadelta':\n\t\t\tgrads = T.grad(cost, wrt=tparams.values())\n\t\t\tf_grad_shared, f_update = adadelta(tparams, grads, x, y, lengths, cost, options, t)\n\t\telif solver=='adam':\n\t\t\tupdates = adam(cost, tparams)\n\t\t\tupdate_model = theano.function(inputs=[x, y, t, lengths], outputs=cost, updates=updates, name='update_model')\n\t\tget_prediction = theano.function(inputs=[x, t, lengths], outputs=y_hat, name='get_prediction')\n\t\tget_cost = theano.function(inputs=[x, y, t, lengths], outputs=cost_noreg, name='get_cost')\n\telif not useTime and embFineTune:\n\t\tprint 'not using time information, fine-tuning code representations'\n\t\tuse_noise, x, y, lengths, cost_noreg, cost, y_hat =  build_model(tparams, options)\n\t\tif solver=='adadelta':\n\t\t\tgrads = T.grad(cost, wrt=tparams.values())\n\t\t\tf_grad_shared, f_update = adadelta(tparams, grads, x, y, lengths, cost, options)\n\t\telif solver=='adam':\n\t\t\tupdates = adam(cost, tparams)\n\t\t\tupdate_model = theano.function(inputs=[x, y, lengths], outputs=cost, updates=updates, name='update_model')\n\t\tget_prediction = theano.function(inputs=[x, lengths], outputs=y_hat, name='get_prediction')\n\t\tget_cost = theano.function(inputs=[x, y, lengths], outputs=cost_noreg, name='get_cost')\n\telif not useTime and not embFineTune:\n\t\tprint 'not using time information, not fine-tuning code representations'\n\t\tW_emb = theano.shared(params['W_emb'], name='W_emb')\n\t\tuse_noise, x, y, lengths, cost_noreg, cost, y_hat =  build_model(tparams, options, W_emb)\n\t\tif solver=='adadelta':\n\t\t\tgrads = T.grad(cost, wrt=tparams.values())\n\t\t\tf_grad_shared, 
f_update = adadelta(tparams, grads, x, y, lengths, cost, options)\n\t\telif solver=='adam':\n\t\t\tupdates = adam(cost, tparams)\n\t\t\tupdate_model = theano.function(inputs=[x, y, lengths], outputs=cost, updates=updates, name='update_model')\n\t\tget_prediction = theano.function(inputs=[x, lengths], outputs=y_hat, name='get_prediction')\n\t\tget_cost = theano.function(inputs=[x, y, lengths], outputs=cost_noreg, name='get_cost')\n\n\tprint 'Loading data ... ',\n\tif simpleLoad:\n\t\ttrainSet, validSet, testSet = load_data_simple(seqFile, labelFile, timeFile)\n\telse:\n\t\ttrainSet, validSet, testSet = load_data(seqFile, labelFile, timeFile)\n\tn_batches = int(np.ceil(float(len(trainSet[0])) / float(batchSize)))\n\tprint 'done'\n\n\tbestValidAuc = 0.0\n\tbestTestAuc = 0.0\n\tbestValidEpoch = 0\n\tlogFile = outFile + '.log'\n\tprint 'Optimization start !!'\n\tfor epoch in xrange(max_epochs):\n\t\titeration = 0\n\t\tcostVector = []\n\t\tfor index in random.sample(range(n_batches), n_batches):\n\t\t\tuse_noise.set_value(1.)\n\t\t\tbatchX = trainSet[0][index*batchSize:(index+1)*batchSize]\n\t\t\ty = np.array(trainSet[1][index*batchSize:(index+1)*batchSize]).astype(config.floatX)\n\n\t\t\tif useTime:\n\t\t\t\tbatchT = trainSet[2][index*batchSize:(index+1)*batchSize]\n\t\t\t\tx, t, lengths = padMatrixWithTime(batchX, batchT, options)\n\t\t\t\tif solver=='adadelta':\n\t\t\t\t\tcostValue = f_grad_shared(x, y, t, lengths)\n\t\t\t\t\tf_update()\n\t\t\t\telif solver=='adam':\n\t\t\t\t\tcostValue = update_model(x, y, t, lengths)\n\t\t\telse:\n\t\t\t\tx, lengths = padMatrixWithoutTime(batchX, options)\n\t\t\t\tif solver=='adadelta':\n\t\t\t\t\tcostValue = f_grad_shared(x, y, lengths)\n\t\t\t\t\tf_update()\n\t\t\t\telif solver=='adam':\n\t\t\t\t\tcostValue = update_model(x, y, lengths)\n\t\t\tcostVector.append(costValue)\n\t\t\tif (iteration % 10 == 0) and verbose: \n\t\t\t\tprint 'Epoch:%d, Iteration:%d/%d, Train_Cost:%f' % (epoch, iteration, n_batches, costValue)\n\t\t\titeration += 1\n\n\t\tuse_noise.set_value(0.)\n\t\ttrainCost = np.mean(costVector)\n\t\tvalidAuc = calculate_auc(get_prediction, validSet, options)\n\t\tbuf = 'Epoch:%d, Train_cost:%f, Validation_AUC:%f' % (epoch, trainCost, validAuc)\n\t\tprint buf\n\t\tprint2file(buf, logFile)\n\t\tif validAuc > bestValidAuc: \n\t\t\tbestValidAuc = validAuc\n\t\t\tbestValidEpoch = epoch\n\t\t\tbestTestAuc = calculate_auc(get_prediction, testSet, options)\n\t\t\tbuf = 'Currently the best validation AUC found. Test AUC:%f at epoch:%d' % (bestTestAuc, epoch)\n\t\t\tprint buf\n\t\t\tprint2file(buf, logFile)\n\t\t\ttempParams = unzip(tparams)\n\t\t\tnp.savez_compressed(outFile + '.' + str(epoch), **tempParams)\n\tbuf = 'The best validation & test AUC:%f, %f at epoch:%d' % (bestValidAuc, bestTestAuc, bestValidEpoch)\n\tprint buf\n\tprint2file(buf, logFile)\n\t\ndef parse_arguments(parser):\n\tparser.add_argument('seq_file', type=str, metavar='<visit_file>', help='The path to the Pickled file containing visit information of patients')\n\tparser.add_argument('n_input_codes', type=int, metavar='<n_input_codes>', help='The number of unique input medical codes')\n\tparser.add_argument('label_file', type=str, metavar='<label_file>', help='The path to the Pickled file containing label information of patients')\n\t#parser.add_argument('n_output_codes', type=int, metavar='<n_output_codes>', help='The number of unique label medical codes')\n\tparser.add_argument('out_file', metavar='<out_file>', help='The path to the output models. 
A model will be saved whenever the validation AUC improves')\n\tparser.add_argument('--time_file', type=str, default='', help='The path to the Pickled file containing durations between visits of patients. If you are not using duration information, do not use this option')\n\tparser.add_argument('--model_file', type=str, default='', help='The path to the Numpy-compressed file containing the model parameters. Use this option if you want to re-train an existing model')\n\tparser.add_argument('--use_log_time', type=int, default=1, choices=[0,1], help='Use logarithm of time duration to dampen the impact of the outliers (0 for false, 1 for true) (default value: 1)')\n\tparser.add_argument('--embed_file', type=str, default='', help='The path to the Pickled file containing the representation vectors of medical codes. If you are not using medical code representations, do not use this option')\n\tparser.add_argument('--embed_size', type=int, default=128, help='The size of the medical code embedding W_emb. If you are not providing your own medical code vectors, you can specify this value (default value: 128)')\n\tparser.add_argument('--embed_finetune', type=int, default=1, choices=[0,1], help='If you are using randomly initialized code representations, always use this option. If you are using external medical code representations and want to fine-tune them as you train RETAIN, use this option (0 for false, 1 for true) (default value: 1)')\n\tparser.add_argument('--alpha_hidden_dim_size', type=int, default=128, help='The size of the hidden layers of the GRU responsible for generating alpha weights (default value: 128)')\n\tparser.add_argument('--beta_hidden_dim_size', type=int, default=128, help='The size of the hidden layers of the GRU responsible for generating beta weights (default value: 128)')\n\tparser.add_argument('--batch_size', type=int, default=100, help='The size of a single mini-batch (default value: 100)')\n\tparser.add_argument('--n_epochs', type=int, default=10, help='The number of training epochs (default value: 10)')\n\tparser.add_argument('--L2_output', type=float, default=0.001, help='L2 regularization for the final classifier weight w_output (default value: 0.001)')\n\tparser.add_argument('--L2_emb', type=float, default=0.001, help='L2 regularization for the input embedding weight W_emb (default value: 0.001)')\n\tparser.add_argument('--L2_alpha', type=float, default=0.001, help='L2 regularization for the alpha generating weight w_alpha (default value: 0.001)')\n\tparser.add_argument('--L2_beta', type=float, default=0.001, help='L2 regularization for the beta generating weight W_beta (default value: 0.001)')\n\tparser.add_argument('--keep_prob_emb', type=float, default=0.5, help='The keep probability of the dropout applied between the embedded input and the alpha & beta generation process (default value: 0.5)')\n\tparser.add_argument('--keep_prob_context', type=float, default=0.5, help='The keep probability of the dropout applied between the context vector c_i and the final classifier (default value: 0.5)')\n\tparser.add_argument('--log_eps', type=float, default=1e-8, help='A small value to prevent log(0) (default value: 1e-8)')\n\tparser.add_argument('--solver', type=str, default='adadelta', choices=['adadelta','adam'], help='Select which solver to train RETAIN: adadelta or adam (default: adadelta)')\n\tparser.add_argument('--simple_load', action='store_true', help='Use an alternative way to load the dataset. 
Instead of you having to provide separate training, validation, and test sets, this will automatically divide the dataset (default false)')\n\tparser.add_argument('--verbose', action='store_true', help='Print the training progress after every 10 mini-batches (default false)')\n\targs = parser.parse_args()\n\treturn args\n\nif __name__ == '__main__':\n\tparser = argparse.ArgumentParser()\n\targs = parse_arguments(parser)\n\n\ttrain_RETAIN(\n\t\tseqFile=args.seq_file, \n\t\tinputDimSize=args.n_input_codes, \n\t\tlabelFile=args.label_file, \n\t\t#numClass=args.n_output_codes, \n\t\toutFile=args.out_file, \n\t\ttimeFile=args.time_file, \n\t\tmodelFile=args.model_file,\n\t\tuseLogTime=args.use_log_time,\n\t\tembFile=args.embed_file, \n\t\tembDimSize=args.embed_size, \n\t\tembFineTune=args.embed_finetune, \n\t\talphaHiddenDimSize=args.alpha_hidden_dim_size,\n\t\tbetaHiddenDimSize=args.beta_hidden_dim_size,\n\t\tbatchSize=args.batch_size, \n\t\tmax_epochs=args.n_epochs, \n\t\tL2_output=args.L2_output, \n\t\tL2_emb=args.L2_emb, \n\t\tL2_alpha=args.L2_alpha, \n\t\tL2_beta=args.L2_beta, \n\t\tkeepProbEmb=args.keep_prob_emb, \n\t\tkeepProbContext=args.keep_prob_context, \n\t\tlogEps=args.log_eps, \n\t\tsolver=args.solver,\n\t\tsimpleLoad=args.simple_load,\n\t\tverbose=args.verbose\n\t)\n"
  },
  {
    "path": "test_retain.py",
    "content": "#################################################################\n# Code written by Edward Choi (mp2893@gatech.edu)\n# For bug report, please contact author using the email address\n#################################################################\n\nimport sys, random\nimport numpy as np\nimport cPickle as pickle\nfrom collections import OrderedDict\nimport argparse\n\nimport theano\nimport theano.tensor as T\nfrom theano import config\n\ndef sigmoid(x):\n  return 1. / (1. + np.exp(-x))\n\ndef numpy_floatX(data):\n\treturn np.asarray(data, dtype=config.floatX)\n\ndef load_embedding(infile):\n\tWemb = np.array(pickle.load(open(infile, 'rb'))).astype(config.floatX)\n\treturn Wemb\n\ndef load_params(options):\n\tparams = OrderedDict()\n\tweights = np.load(options['modelFile'])\n\tfor k,v in weights.iteritems():\n\t\tparams[k] = v\n\tif len(options['embFile']) > 0: params['W_emb'] = np.array(pickle.load(open(options['embFile'], 'rb'))).astype(config.floatX)\n\treturn params\n\ndef init_tparams(params, options):\n\ttparams = OrderedDict()\n\tfor key, value in params.iteritems():\n\t\ttparams[key] = theano.shared(value, name=key)\n\treturn tparams\n\ndef _slice(_x, n, dim):\n\tif _x.ndim == 3:\n\t\treturn _x[:, :, n*dim:(n+1)*dim]\n\treturn _x[:, n*dim:(n+1)*dim]\n\ndef gru_layer(tparams, emb, name, hiddenDimSize):\n\ttimesteps = emb.shape[0]\n\tif emb.ndim == 3: n_samples = emb.shape[1]\n\telse: n_samples = 1\n\n\tdef stepFn(wx, h, U_gru):\n\t\tuh = T.dot(h, U_gru)\n\t\tr = T.nnet.sigmoid(_slice(wx, 0, hiddenDimSize) + _slice(uh, 0, hiddenDimSize))\n\t\tz = T.nnet.sigmoid(_slice(wx, 1, hiddenDimSize) + _slice(uh, 1, hiddenDimSize))\n\t\th_tilde = T.tanh(_slice(wx, 2, hiddenDimSize) + r * _slice(uh, 2, hiddenDimSize))\n\t\th_new = z * h + ((1. 
- z) * h_tilde)\n\t\treturn h_new\n\n\tWx = T.dot(emb, tparams['W_gru_'+name]) + tparams['b_gru_'+name]\n\tresults, updates = theano.scan(fn=stepFn, sequences=[Wx], outputs_info=T.alloc(numpy_floatX(0.0), n_samples, hiddenDimSize), non_sequences=[tparams['U_gru_'+name]], name='gru_layer', n_steps=timesteps)\n\n\treturn results\n\t\ndef build_model(tparams, options):\n\talphaHiddenDimSize = options['alphaHiddenDimSize']\n\tbetaHiddenDimSize = options['betaHiddenDimSize']\n\n\tx = T.tensor3('x', dtype=config.floatX)\n\n\treverse_emb_t = x[::-1]\n\treverse_h_a = gru_layer(tparams, reverse_emb_t, 'a', alphaHiddenDimSize)[::-1] * 0.5\n\treverse_h_b = gru_layer(tparams, reverse_emb_t, 'b', betaHiddenDimSize)[::-1] * 0.5\n\n\tpreAlpha = T.dot(reverse_h_a, tparams['w_alpha']) + tparams['b_alpha']\n\tpreAlpha = preAlpha.reshape((preAlpha.shape[0], preAlpha.shape[1]))\n\talpha = (T.nnet.softmax(preAlpha.T)).T\n\n\tbeta = T.tanh(T.dot(reverse_h_b, tparams['W_beta']) + tparams['b_beta'])\n\t\n\treturn x, alpha, beta\n\ndef padMatrixWithTime(seqs, times, options):\n\tlengths = np.array([len(seq) for seq in seqs]).astype('int32')\n\tn_samples = len(seqs)\n\tmaxlen = np.max(lengths)\n\n\tx = np.zeros((maxlen, n_samples, options['inputDimSize'])).astype(config.floatX)\n\tt = np.zeros((maxlen, n_samples)).astype(config.floatX)\n\tfor idx, (seq,time) in enumerate(zip(seqs,times)):\n\t\tfor xvec, subseq in zip(x[:,idx,:], seq):\n\t\t\txvec[subseq] = 1.\n\t\tt[:lengths[idx], idx] = time\n\n\tif options['useLogTime']: t = np.log(t + 1.)\n\n\treturn x, t, lengths\n\ndef padMatrixWithoutTime(seqs, options):\n\tlengths = np.array([len(seq) for seq in seqs]).astype('int32')\n\tn_samples = len(seqs)\n\tmaxlen = np.max(lengths)\n\n\tx = np.zeros((maxlen, n_samples, options['inputDimSize'])).astype(config.floatX)\n\tfor idx, seq in enumerate(seqs):\n\t\tfor xvec, subseq in zip(x[:,idx,:], seq):\n\t\t\txvec[subseq] = 1.\n\n\treturn x, lengths\n\ndef load_data_debug(seqFile, labelFile, timeFile=''):\n\tsequences = np.array(pickle.load(open(seqFile, 'rb')))\n\tlabels = np.array(pickle.load(open(labelFile, 'rb')))\n\tif len(timeFile) > 0:\n\t\ttimes = np.array(pickle.load(open(timeFile, 'rb')))\n\n\tdataSize = len(labels)\n\tnp.random.seed(0)\n\tind = np.random.permutation(dataSize)\n\tnTest = int(0.15 * dataSize)\n\tnValid = int(0.10 * dataSize)\n\n\ttest_indices = ind[:nTest]\n\tvalid_indices = ind[nTest:nTest+nValid]\n\ttrain_indices = ind[nTest+nValid:]\n\n\ttrain_set_x = sequences[train_indices]\n\ttrain_set_y = labels[train_indices]\n\ttest_set_x = sequences[test_indices]\n\ttest_set_y = labels[test_indices]\n\tvalid_set_x = sequences[valid_indices]\n\tvalid_set_y = labels[valid_indices]\n\ttrain_set_t = None\n\ttest_set_t = None\n\tvalid_set_t = None\n\n\tif len(timeFile) > 0:\n\t\ttrain_set_t = times[train_indices]\n\t\ttest_set_t = times[test_indices]\n\t\tvalid_set_t = times[valid_indices]\n\n\tdef len_argsort(seq):\n\t\treturn sorted(range(len(seq)), key=lambda x: len(seq[x]))\n\n\ttrain_sorted_index = len_argsort(train_set_x)\n\ttrain_set_x = [train_set_x[i] for i in train_sorted_index]\n\ttrain_set_y = [train_set_y[i] for i in train_sorted_index]\n\n\tvalid_sorted_index = len_argsort(valid_set_x)\n\tvalid_set_x = [valid_set_x[i] for i in valid_sorted_index]\n\tvalid_set_y = [valid_set_y[i] for i in valid_sorted_index]\n\n\ttest_sorted_index = len_argsort(test_set_x)\n\ttest_set_x = [test_set_x[i] for i in test_sorted_index]\n\ttest_set_y = [test_set_y[i] for i in test_sorted_index]\n\n\tif len(timeFile) > 
0:\n\t\ttrain_set_t = [train_set_t[i] for i in train_sorted_index]\n\t\tvalid_set_t = [valid_set_t[i] for i in valid_sorted_index]\n\t\ttest_set_t = [test_set_t[i] for i in test_sorted_index]\n\n\ttrain_set = (train_set_x, train_set_y, train_set_t)\n\tvalid_set = (valid_set_x, valid_set_y, valid_set_t)\n\ttest_set = (test_set_x, test_set_y, test_set_t)\n\n\treturn train_set, valid_set, test_set\n\ndef load_data(dataFile, labelFile, timeFile):\n\ttest_set_x = np.array(pickle.load(open(dataFile, 'rb')))\n\ttest_set_y = np.array(pickle.load(open(labelFile, 'rb')))\n\ttest_set_t = None\n\tif len(timeFile) > 0:\n\t\ttest_set_t = np.array(pickle.load(open(timeFile, 'rb')))\n\n\tdef len_argsort(seq):\n\t\treturn sorted(range(len(seq)), key=lambda x: len(seq[x]))\n\n\tsorted_index = len_argsort(test_set_x)\n\ttest_set_x = [test_set_x[i] for i in sorted_index]\n\ttest_set_y = [test_set_y[i] for i in sorted_index]\n\tif len(timeFile) > 0:\n\t\ttest_set_t = [test_set_t[i] for i in sorted_index]\n\n\ttest_set = (test_set_x, test_set_y, test_set_t)\n\n\treturn test_set\n\ndef print2file(buf, outFile):\n\toutfd = open(outFile, 'a')\n\toutfd.write(buf + '\\n')\n\toutfd.close()\n\ndef test_RETAIN(\n\tmodelFile='model.npz',\n\tseqFile='seqFile.txt',\n\tlabelFile='labelFile.txt',\n\toutFile='outFile.txt',\n\ttimeFile='',\n\ttypeFile='types.txt',\n\tuseLogTime=True,\n\tembFile='',\n\tlogEps=1e-8\n):\n\toptions = locals().copy()\n\n\tif len(timeFile) > 0: useTime = True\n\telse: useTime = False\n\toptions['useTime'] = useTime\n\n\tif len(embFile) > 0: useFixedEmb = True\n\telse: useFixedEmb = False\n\toptions['useFixedEmb'] = useFixedEmb\n\t\n\tprint 'Loading the parameters ... ',\n\tparams = load_params(options)\n\ttparams = init_tparams(params, options)\n\n\toptions['alphaHiddenDimSize'] = params['w_alpha'].shape[0]\n\toptions['betaHiddenDimSize'] = params['W_beta'].shape[0]\n\toptions['inputDimSize'] = params['W_emb'].shape[0]\n\n\tprint 'Building the model ... ',\n\tx, alpha, beta =  build_model(tparams, options)\n\tget_result = theano.function(inputs=[x], outputs=[alpha, beta], name='get_result')\n\n\tprint 'Loading data ... 
',\n\ttestSet = load_data(seqFile, labelFile, timeFile)\n\tprint 'done'\n\n\ttypes = pickle.load(open(typeFile, 'rb'))\n\trtypes = dict([(v,k) for k,v in types.iteritems()])\n\n\tprint 'Contribution calculation start!!'\n\tcount = 0\n\toutfd = open(outFile, 'w')\n\tfor index in range(len(testSet[0])):\n\t\tif count % 100 == 0: print 'processed %d patients' % count\n\t\tcount += 1\n\t\tbatchX = [testSet[0][index]]\n\t\tlabel = testSet[1][index]\n\n\t\tif useTime: \n\t\t\tbatchT = [testSet[2][index]]\n\t\t\tx, t, lengths = padMatrixWithTime(batchX, batchT, options)\n\t\telse:\n\t\t\tx, lengths = padMatrixWithoutTime(batchX, options)\n\n\t\tn_timesteps = x.shape[0]\n\t\tn_samples = x.shape[1]\n\t\temb = np.dot(x, params['W_emb'])\n\n\t\tif useTime:\n\t\t\ttemb = np.concatenate([emb, t.reshape((n_timesteps,n_samples,1))], axis=2)\n\t\telse:\n\t\t\ttemb = emb\n\n\t\talpha, beta = get_result(temb)\n\t\talpha = alpha[:,0]\n\t\tbeta = beta[:,0,:]\n\n\t\tct = (alpha[:,None] * beta * emb[:,0,:]).sum(axis=0)\n\t\ty_t = sigmoid(np.dot(ct, params['w_output']) + params['b_output'])\n\n\t\tbuf = ''\n\t\tpatient = batchX[0]\n\t\tfor i in range(len(patient)):\n\t\t\tvisit = patient[i]\n\t\t\tbuf += '-------------- visit_index:%d ---------------\\n' % i\n\t\t\tfor j in range(len(visit)):\n\t\t\t\tcode = visit[j]\n\t\t\t\tcontribution = np.dot(params['w_output'].flatten(), alpha[i] * beta[i] * params['W_emb'][code])\n\t\t\t\tbuf += '%s:%f  ' % (rtypes[code], contribution)\n\t\t\tbuf += '\\n------------------------------------\\n'\n\t\tbuf += 'patient_index:%d, label:%d, score:%f\\n\\n' % (index, label, y_t)\n\t\toutfd.write(buf + '\\n')\n\toutfd.close()\n\ndef parse_arguments(parser):\n\tparser.add_argument('model_file', type=str, metavar='<model_file>', help='The path to the Numpy-compressed file containing the model parameters')\n\tparser.add_argument('seq_file', type=str, metavar='<visit_file>', help='The path to the cPickled file containing visit information of patients')\n\tparser.add_argument('label_file', type=str, metavar='<label_file>', help='The path to the cPickled file containing label information of patients')\n\tparser.add_argument('type_file', type=str, metavar='<type_file>', help='The path to the cPickled dictionary for mapping medical code strings to integers')\n\tparser.add_argument('out_file', metavar='<out_file>', help='The path to the output file where the contribution of each medical code will be stored')\n\tparser.add_argument('--time_file', type=str, default='', help='The path to the cPickled file containing durations between visits of patients. If you are not using duration information, do not use this option')\n\tparser.add_argument('--use_log_time', type=int, default=1, choices=[0,1], help='Use logarithm of time duration to dampen the impact of the outliers (0 for false, 1 for true) (default value: 1)')\n\tparser.add_argument('--embed_file', type=str, default='', help='The path to the cPickled file containing the representation vectors of medical codes. If you are not using medical code representations, do not use this option')\n\targs = parser.parse_args()\n\treturn args\n\nif __name__ == '__main__':\n\tparser = argparse.ArgumentParser()\n\targs = parse_arguments(parser)\n\n\ttest_RETAIN(\n\t\tmodelFile=args.model_file,\n\t\tseqFile=args.seq_file, \n\t\tlabelFile=args.label_file, \n\t\ttypeFile=args.type_file, \n\t\toutFile=args.out_file, \n\t\ttimeFile=args.time_file, \n\t\tuseLogTime=args.use_log_time,\n\t\tembFile=args.embed_file\n\t)\n"
  }
]