[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n"
  },
  {
    "path": "LICENSE",
    "content": "\n"
  },
  {
    "path": "README.md",
    "content": "# Inter- and intra- patient ECG heartbeat classification for arrhythmia detection: a sequence to sequence deep learning approach\n\n## Paper\nOur paper can be downloaded from [arXiv](https://arxiv.org/pdf/1812.07421v2).\n\n* The network architecture:\n\n  ![Network architecture](/images/seq2seq_b.jpg)\n\n## Requirements\n* Python 2.7\n* tensorflow/tensorflow-gpu\n* numpy\n* scipy\n* scikit-learn\n* matplotlib\n* imbalanced-learn (0.4.3)\n\n## Dataset\nWe evaluated our model on [the PhysioNet MIT-BIH Arrhythmia database](https://www.physionet.org/physiobank/database/mitdb/).\n* To download our pre-processed datasets, use [this link](https://drive.google.com/drive/folders/19bDrAqlSGQuNLRmA-7pQRU9R81gSuY70?usp=sharing) and put them into the \"data\" folder.\n* Alternatively, follow the instructions in the readme file in the \"data preprocessing_Matlab\" folder to download the MIT-BIH database and pre-process it yourself. Then put the pre-processed datasets into the \"data\" folder.\n\n## Train\n\n* Modify the args settings in seq_seq_annot_aami.py for intra-patient ECG heartbeat classification.\n* Modify the args settings in seq_seq_annot_DS1DS2.py for inter-patient ECG heartbeat classification.\n\n* To reproduce the models described in the paper, run:\n\n```\npython seq_seq_annot_aami.py --data_dir data/s2s_mitbih_aami --epochs 500\n```\n```\npython seq_seq_annot_DS1DS2.py --data_dir data/s2s_mitbih_aami_DS1DS2 --epochs 500\n```\n\n## Results\n![Results](/images/results.jpg)\n\n## Citation\nIf you find this work useful, please cite our paper:\n\n```\n@article{mousavi2018inter,\n  title={Inter-and intra-patient ECG heartbeat classification for arrhythmia detection: a sequence to sequence deep learning approach},\n  author={Mousavi, Sajad and Afghah, Fatemeh},\n  journal={arXiv preprint arXiv:1812.07421},\n  year={2018}\n}\n```\n\n## References\n[deepschool.io](https://github.com/sachinruk/deepschool.io/blob/master/DL-Keras_Tensorflow)\n\n## Licence\nFor academic and non-commercial use only.\n\n"
  },
  {
    "path": "data/readme",
    "content": "Download the datasets using the link below and put them into the \"data\" folder:\n\n[this link](https://drive.google.com/drive/folders/19bDrAqlSGQuNLRmA-7pQRU9R81gSuY70?usp=sharing)\n"
  },
  {
    "path": "data preprocessing_Matlab/download_MITBIHDB.m",
    "content": "\r\n% Create the MIT-BIH datasets: download all records and convert them into .info, .mat and .txt (annotation) files.\r\n% Before running this script, install the open-source WFDB Software Package available at\r\n% http://physionet.org/physiotools/wfdb.shtml\r\n\r\n% To download the whole database instead (the files will have .dat and .atr extensions):\r\n% rc=physionetdb('mitdb',1);\r\n\r\n% input\r\n% output: .info, .hea, .mat and .txt files\r\n\r\nrecord_list = physionetdb('mitdb');\r\n\r\n\r\npath_to_exes = 'path_to_WFDB_Toolbox\\mcode\\nativelibs\\windows\\bin';\r\npath_to_save_records = 'path_to_downloaded_database';\r\n\r\n% path_to_exes = 'C:\\my_files\\ECG_research\\mcode\\nativelibs\\windows\\bin';\r\n% path_to_save_records = 'C:\\my_files\\ECG_dataset\\MIT-BIH\\mitbihdb';\r\n\r\nmkdir(path_to_save_records);\r\ncd(path_to_save_records);\r\ntic\r\nfor i=1:length(record_list)\r\n    \r\n    % rdann: write the reference (atr) annotations of the record to RECORDm.txt\r\n    command_annot = char(strcat(path_to_exes, filesep, 'rdann.exe -r mitdb/', record_list(i), ' -a atr -v >', record_list(i), 'm.txt'));\r\n    system (command_annot);\r\n    % wfdb2mat: convert the record to RECORDm.mat and save its signal information to RECORDm.info\r\n    command_mat_info_ = char(strcat(path_to_exes, filesep, 'wfdb2mat.exe -r mitdb/', record_list(i), ' >', record_list(i), 'm.info'));\r\n    system (command_mat_info_);\r\n    \r\n%     system('C:\\my_files\\ECG_research\\mcode\\nativelibs\\windows\\bin\\wfdb2mat.exe -r 100s -f 0 -t 10 >100sm.info')\r\n%     system('C:\\my_files\\ECG_research\\mcode\\nativelibs\\windows\\bin\\rdann.exe -r mitdb/100 -a atr >100.txt')\r\n    \r\nend\r\ntoc\r\ndisp('Successfully generated :)')\r\n\r\n"
  },
  {
    "path": "data preprocessing_Matlab/loadEcgSig.m",
    "content": "function [tm,ecgsig,ann,Fs,sizeEcgSig,timeEcgSig] = loadEcgSig(Name)\r\n% USAGE: [tm,ecgsig,ann,Fs,sizeEcgSig,timeEcgSig] = loadEcgSig('../data/200m')\r\n% This function reads a pair of files (RECORDm.mat and RECORDm.info) generated\r\n% by 'wfdb2mat' from a PhysioBank record, baseline-corrects and scales the time\r\n% series contained in the .mat file, returning time, amplitude and frequency.\r\n% The baseline-corrected and scaled time series are the rows of matrix 'val', and each\r\n% column contains simultaneous samples of each time series.\r\n% 'wfdb2mat' is part of the open-source WFDB Software Package available at\r\n%    http://physionet.org/physiotools/wfdb.shtml\r\n% If you have installed a working copy of 'wfdb2mat', run a shell command\r\n% such as wfdb2mat -r 100s -f 0 -t 10 >100sm.info\r\n% to create a pair of files ('100sm.mat', '100sm.info') that can be read\r\n% by this function.\r\n% The files needed by this function can also be produced by the\r\n% PhysioBank ATM, at http://physionet.org/cgi-bin/ATM\r\n\r\n% Adapted from\r\n% loadEcgSignal.m           O. Abdala\t\t\t16 March 2009\r\n% \t\t      James Hislop\t       27 January 2014\tversion 1.1\r\n\r\n% Last version\r\n% loadEcgSignal.m           D. 
Kawasaki\t\t\t15 June 2017\r\n% \t\t      Davi Kawasaki\t       15 June 2017 version 1.0\r\n\r\ninfoName = strcat(Name, '.info');\r\nmatName = strcat(Name, '.mat');\r\nload(matName);\r\necgsig = val;\r\nfid = fopen(infoName, 'rt');\r\nfgetl(fid);\r\nfgetl(fid);\r\nfgetl(fid);\r\n[freqint] = sscanf(fgetl(fid), 'Sampling frequency: %f Hz  Sampling interval: %f sec');\r\nFs = freqint(1);\r\ninterval = freqint(2);\r\nfgetl(fid);\r\n\r\nfor i = 1:size(ecgsig, 1)\r\n  [row(i), signal(i), gain(i), base(i), units(i)]=strread(fgetl(fid),'%d%s%f%f%s','delimiter','\\t');\r\nend\r\n\r\nfclose(fid);\r\necgsig(ecgsig==-32768) = NaN;\r\n\r\nfor i = 1:size(ecgsig, 1)\r\n    ecgsig(i, :) = (ecgsig(i, :) - base(i)) / gain(i);\r\nend\r\nN = size(ecgsig, 2);\r\ntm1 = 1/Fs:1/Fs:N/Fs;\r\ntm = (1:size(ecgsig, 2)) * interval;\r\nsizeEcgSig = size(ecgsig, 2);\r\ntimeEcgSig = sizeEcgSig*interval;\r\n%plot(tm', val');\r\n\r\n%for i = 1:length(signal)\r\n%    labels{i} = strcat(signal{i}, ' (', units{i}, ')'); \r\n%end\r\n\r\n%legend(labels);\r\n%xlabel('Time (sec)');\r\n% grid on\r\n\r\n% load annotations\r\n\r\nannotationName = strcat(Name, '.txt');\r\nfid = fopen(annotationName, 'rt');\r\n% was annotationsEcg = textscan(fid, '%d:%f %d %*c %*d %*d %*d %s', 'HeaderLines', 1, 'CollectOutput', 1);\r\nann = textscan(fid, '%d:%f %d %c %d %d %d %s', 'HeaderLines', 1, 'CollectOutput', 1);\r\nfclose(fid);\r\n\r\n\r\nend"
  },
  {
    "path": "data preprocessing_Matlab/onoffset.m",
    "content": "function [ ind ] = onoffset( interval,mode )\r\n% ONOFFSET Locate the onset ('on') or offset ('off') of a QRS complex\r\n% within a signal interval, using the central-difference slope.\r\n% slope(k) = interval(k+2) - interval(k)\r\nslope = interval(3:end) - interval(1:end-2);\r\nif strcmp(mode,'on')\r\n    % onset: the point with the minimum absolute slope\r\n    [~,ind] = min(abs(slope));\r\nelseif strcmp(mode,'off')\r\n    % offset: the first point whose absolute slope reaches 20% of the maximum absolute slope\r\n    slope_th = 0.2*max(abs(slope));\r\n    slope_s = find(abs(slope)>=slope_th);\r\n    ind = slope_s(1);\r\nelse\r\n    error('wrong input, please select ''on'' or ''off''');\r\nend\r\n\r\nend\r\n\r\n"
  },
  {
    "path": "data preprocessing_Matlab/qsPeaks.m",
    "content": "function [ECGpeaks] = qsPeaks( ECG,Rposition,fs )\r\n% QSPEAKS Detect the P, QRS-onset, Q, R, S, QRS-offset and T points of each beat.\r\n% The time between two consecutive samples is determined by the sampling frequency: 1/fs\r\n% The duration of a QRS complex varies from 0.1 s to 0.2 s,\r\n% i.e. it spans at most fs*0.2 samples\r\naveHB = length(ECG)/length(Rposition);\r\nfid_pks = zeros(length(Rposition),7);\r\n% columns of fid_pks: P, QRSon, Q, R, S, QRSoff, T\r\n%% set up searching windows\r\nwindowS = round(fs*0.1); windowQ = round(fs*0.05);\r\nwindowP = round(aveHB/3); windowT = round(aveHB*2/3);\r\nwindowOF = round(fs*0.04);\r\n% initialization\r\nfor i = 1:length(Rposition)\r\n    thisR = Rposition(i);\r\n    if i==1\r\n        fid_pks(i,4) = thisR;\r\n        fid_pks(i,6) = thisR+windowS;\r\n    elseif i==length(Rposition)\r\n        fid_pks(i,4) = thisR;\r\n        %(thisR+windowT) < length(ECG) && (thisR - windowP) >=1\r\n        fid_pks(i,2) = thisR-windowQ;\r\n    else\r\n        \r\n        if (thisR+windowT) < length(ECG) && (thisR - windowP) >=1\r\n        % Q S peaks \r\n        fid_pks(i,4) = thisR;\r\n        [Sv,Sp] = min(ECG(thisR:thisR+windowS));\r\n        thisS = Sp + thisR-1;\r\n        fid_pks(i,5) = thisS;\r\n        [Qv,Qp] = min(ECG(thisR-windowQ:thisR));\r\n        thisQ = thisR-(windowQ+1) + Qp;\r\n        fid_pks(i,3)=thisQ;      \r\n        % onset and offset detection\r\n        interval_q = ECG(thisQ-windowOF:thisQ);\r\n        [ ind ] = onoffset(interval_q,'on' );\r\n        thisON = thisQ - (windowOF+1) + ind;\r\n        interval_s = ECG(thisS:thisS+windowOF);\r\n        [ ind ] = onoffset( interval_s,'off' );\r\n        thisOFF = thisS + ind-1;\r\n        fid_pks(i,2) = thisON;\r\n        fid_pks(i,6) = thisOFF;       \r\n%         % T and P waves detection\r\n%         lastOFF = fid_pks(i-1,6);\r\n%         nextON = \r\n        end\r\n    end       \r\nend\r\n%%\r\n% P,T detection\r\n%   Detect T waves first and distinguish the type of T waves\r\nfor i = 2:length(Rposition)-1\r\n    lastOFF = fid_pks(i-1,6);\r\n    thisON = 
fid_pks(i,2);\r\n    thisOFF = fid_pks(i,6);\r\n    nextON = fid_pks(i+1,2);\r\n    if thisON>lastOFF && thisOFF<nextON\r\n       Tzone = (thisOFF:(nextON-round((nextON-thisOFF)/3)));\r\n       Pzone = ((lastOFF+round(2*(thisON-lastOFF)/3)):thisON);\r\n       [Tpks,thisT] = max(ECG(Tzone));\r\n       [Ppks,thisP] = max(ECG(Pzone));\r\n       fid_pks(i,1) = (lastOFF+round(2*(thisON-lastOFF)/3)) + thisP -1;\r\n       fid_pks(i,7) = thisOFF + thisT -1;\r\n    end\r\nend\r\nECGpeaks = [];\r\nInd = zeros(length(Rposition),1);\r\nfor i = 1:length(Rposition)\r\n    Ind(i) = prod(fid_pks(i,:));\r\n    if Ind(i) ~= 0\r\n        ECGpeaks = [ECGpeaks;fid_pks(i,:)];\r\n    end\r\nend\r\n\r\n\r\nend  \r\n% Tpks = [];\r\n% Tposition =[];\r\n% Ttype = [];\r\n% Ppks = [];\r\n% Pposition =[];\r\n% if length(QRS_OFF) <= length(QRS_ON)\r\n%     for i = 1:length(QRS_OFF)-1\r\n%         lengthTzone = int64((QRS_ON(i+1)-QRS_OFF(i))*2/3);\r\n%         if lengthTzone <= 0 \r\n%             lengthTzone = abs(lengthTzone);\r\n%             disp('abnormal heart beat');\r\n%         end\r\n%         if QRS_OFF(i)+lengthTzone-1 > length(ECG)\r\n%             Tzone = ECG(QRS_OFF(i):length(ECG));\r\n%         else\r\n%             Tzone = ECG(QRS_OFF(i):QRS_OFF(i)+lengthTzone-1);\r\n%         end\r\n% %         if length(max(Tzone))~=1\r\n% %             deb = 1;\r\n% %         end\r\n%         [Tpks(end+1),Tind] = max(Tzone);\r\n%         Tposition(end+1) = QRS_OFF(i)+Tind-1;\r\n%         lengthPzone = int64((-QRS_OFF(i)+QRS_ON(i+1))/3);\r\n%         if lengthPzone <= 0 \r\n%             lengthPzone = abs(lengthPzone);\r\n%             disp('abnormal heart beat');\r\n%         end\r\n%         Pzone = ECG(QRS_ON(i+1)-lengthPzone-1:QRS_ON(i+1));\r\n%         [Ppks(end+1),Pind] = max(Pzone);\r\n%         Pposition(end+1) = QRS_ON(i+1)+Pind-lengthPzone;\r\n%     end\r\n% else\r\n%     for i = 1:length(QRS_ON)-1\r\n%         lengthTzone = int64((QRS_ON(i+1)-QRS_OFF(i))*2/3);\r\n%         if 
QRS_OFF(i)+lengthTzone-1 > length(ECG)\r\n%             Tzone = ECG(QRS_OFF(i):length(ECG));\r\n%         else\r\n%             Tzone = ECG(QRS_OFF(i):QRS_OFF(i)+lengthTzone-1);\r\n%         end\r\n%         [Tpks(end+1),Tind] = max(abs(Tzone));\r\n%         Tposition(end+1) = QRS_OFF(i)+Tind-1;\r\n%         lengthPzone = int64((-QRS_OFF(i)+QRS_ON(i+1))/3);\r\n%         Pzone = ECG(QRS_ON(i+1)-lengthPzone-1:QRS_ON(i+1));\r\n%         [Ppks(end+1),Pind] = max(Pzone);\r\n%         Pposition(end+1) = QRS_ON(i+1)+Pind-lengthPzone;\r\n%     end\r\n% end\r\n% for i = 1:length(Tpks)\r\n%     if Tpks(i)>0\r\n%         display('T Upright');\r\n%     else\r\n%         display('T Inverted');\r\n%     end\r\n% end\r\n% hold on;plot(Tposition,Tpks,'o');\r\n% hold on;plot(Pposition,Ppks,'*');\r\n %%   \r\n% [pks,locs] = findpeaks(A5);hold on; plot(locs,pks,'*');\r\n% \r\n% for i = 1:length(S)\r\n%     SearchArea = [];\r\n%     for j = 1:length(locs)\r\n%         if locs(j) > S(i)\r\n%             SearchArea(end+1) = locs(j);\r\n%         end\r\n%     end\r\n%     su = [];\r\n%     for k = 1:length(SearchArea)\r\n%         su(end+1)  = SearchArea(k)-S(i);\r\n%     end\r\n%     if ~isempty(su)\r\n%         Tpks(end+1) = S(i) + min(su);\r\n%     end\r\n% end\r\n% \r\n% % P detection\r\n% \r\n% \r\n% for i = 1:length(Q)\r\n%     SearchArea = [];\r\n%     for j = 1:length(locs)\r\n%         if locs(j) < Q(i)\r\n%             SearchArea(end+1) = locs(j);\r\n%         end\r\n%     end\r\n%     su = [];\r\n%     for k = 1:length(SearchArea)\r\n%         su(end+1)  = Q(i)-SearchArea(k);\r\n%     end\r\n%     if ~isempty(su)\r\n%         Ppks(end+1) = Q(i) - min(su);\r\n%     end\r\n% end\r\n% % check the result\r\n% for i = 1:length(Tpks)\r\n%     T(i) = A5(Tpks(i));\r\n% end\r\n% for i = 1:length(Ppks)\r\n%     P(i) = A5(Ppks(i));\r\n% end\r\n% figure(5);plot(A5);\r\n% hold on;plot(Tpks,T,'*');\r\n% hold on;plot(Ppks,P,'+');\r\n\r\n%%\r\n%Oops after interpolation, we slightly deviate 
from the true value\r\n%Thus we need to search for the local minimum in the whole signal again\r\n% Qposition = int64(Q.*Scale);\r\n% Sposition = int64(S.*Scale);\r\n% \r\n% for i = 1:length(R)\r\n%     [Sv,Sp] = min(ECG(Sposition(i):Sposition(i)+20));\r\n%     Sposition(i) = Sp + Sposition(i)-1;\r\n%     [Qv,Qp] = min(ECG(Qposition(i)-20:Qposition(i)));\r\n%     Qposition(i) = Qposition(i)-21 + Qp;\r\n% end\r\n% \r\n% Pposition = int64(Ppks.*Scale);\r\n% Tposition = int64(Tpks.*Scale);\r\n% \r\n% for i = length(Pposition)\r\n%     [Pv,Pp] = max(ECG(Pposition(i)-20:Pposition(i)));\r\n%     Pposition(i) = Pposition(i)-21 + Pp;\r\n% end\r\n% \r\n% for i = length(Tposition)\r\n%     [Tv,Tp] = max(ECG(Tposition(i):Tposition(i)+20));\r\n%     Tposition(i) = Tp + Tposition(i)-1;\r\n% end\r\n% ecgS = [];\r\n% for i = 1:2271\r\n%     ecgS(end+1) = issorted(ECGpeaks(1,:));\r\n% end\r\n% prod(ecgS)"
  },
  {
    "path": "data preprocessing_Matlab/readme",
    "content": "1. Install the WFDB Toolbox for MATLAB from https://physionet.org/physiotools/matlab/wfdb-app-matlab/\n2. Run download_MITBIHDB.m to download the MIT-BIH database.\n  -- In the code, set the local path where the database will be saved and the path to the installed WFDB Toolbox.\nFor example:\n  -- path_to_exes = 'path_to_WFDB_Toolbox\\mcode\\nativelibs\\windows\\bin';\n  -- path_to_save_records = 'path_to_save_database';\n\n3. Run seq2seq_mitbih_AAMI.m to prepare the data for the intra-patient paradigm (it saves s2s_mitbih_aami.mat).\n -- Set the line below to the local path of the MIT-BIH database:\n   -- addr = '.\\mitbihdb';\n4. Run seq2seq_mitbih_AAMI_DS1DS2.m to prepare the data for the inter-patient paradigm (it saves s2s_mitbih_aami_DS1DS2.mat).\n -- Set the line below to the local path of the MIT-BIH database:\n   -- addr = '.\\mitbihdb';\n"
  },
  {
    "path": "data preprocessing_Matlab/seq2seq_mitbih_AAMI.m",
    "content": "clear all\r\nclc\r\ntic\r\naddr = '.\\mitbihdb';\r\nFiles=dir(strcat(addr,'\\*.mat'));\r\n\r\n%% Translate PhysioNet annotations to the AAMI and AAMI2 labeling schemes\r\n% AAMI Classes:\r\n% % N = N, L, R, e, j\r\n% % S = A, a, J, S\r\n% % V = V, E\r\n% % F = F\r\n% % Q = /, f, Q\r\n\r\n% AAMI2 Classes:\r\n% % N = N, L, R, e, j\r\n% % S = A, a, J, S\r\n% % V = V, E, F\r\n% % Q = /, f, Q\r\n% https://github.com/ehendryx/deim-cur-ecg/blob/master/DS1_MIT_CUR_beat_classification.m\r\nAAMI_annotations = {'N' 'S' 'V' 'F' 'Q'};\r\nAAMI2_annotations = {'N' 'S' 'V_hat' 'Q'};\r\n\r\nindex = 1;\r\nbeat_len = 280;\r\nn_cycles = 0;\r\nfeaturesSeg = [];\r\ngroupN = [];\r\ngroupV = [];\r\ngroupS = [];\r\ngroupF = [];\r\ngroupQ = [];\r\nN_class = 0;V_class=0;F_class=0;Q_class=0;S_class=0;\r\nfor i=1:length(Files) % loop over all records\r\n    \r\n    %% load the files\r\n    % load ('100m.mat')          % the signal will be loaded to \"val\" matrix\r\n    % val = (val - 1024)/200;    % you have to remove \"base\" and \"gain\"\r\n    % ECGsignal = val(1,1:1000); % select the first signal (lead)\r\n    % Fs = 360;                  % sampling frequency\r\n    % t = (0:length(ECGsignal)-1)/Fs;  % time\r\n    % plot(t,ECGsignal)\r\n    %\r\n    \r\n    [pathstr,name,ext] = fileparts(Files(i).name);\r\n    nsig = 1;\r\n    \r\n    [tm,ecgsig,ann,Fs,sizeEcgSig,timeEcgSig] = loadEcgSig([addr filesep name]);\r\n    \r\n    signal = ecgsig(nsig,:);\r\n    \r\n    \r\n    \r\n    %%\r\n    %     rPeaks = rDetection(signal, Fs);\r\n    %     rPeaks = get_rpeaks(signal, Fs);\r\n    rPeaks  = cell2mat(ann(3))+1; % WFDB sample numbers start at 0, MATLAB indices at 1\r\n    n_cycles = n_cycles + length(rPeaks);\r\n    %     [R_i,R_amp,S_i,S_amp,T_i,T_amp,Q_i,Q_amp] = peakdetect(signal,Fs);\r\n    %      rPeaks =  R_i;\r\n    rPeaks = double(rPeaks);\r\n    \r\n    peaks = qsPeaks(signal, rPeaks, Fs);\r\n    tpeaks = peaks(:,7);  \r\n    \r\n    %    %% Plot P Q R S T points\r\n    %     N = length(signal);\r\n    %     tm = 
1/Fs:1/Fs:N/Fs;\r\n    %     figure;plot(tm,signal);hold on\r\n    %     scatter(peaks(:,1)/Fs,signal(peaks(:,1)),'g*') % P points\r\n    %     scatter(peaks(:,3)/Fs,signal(peaks(:,3)),'k+') % Q points\r\n    %     scatter(peaks(:,4)/Fs,signal(peaks(:,4)),'ro') % R points\r\n    %     scatter(peaks(:,5)/Fs,signal(peaks(:,5)),'c^') % S points\r\n    %     scatter(peaks(:,7)/Fs,signal(peaks(:,7)),'mo') % T points\r\n    %     xlabel('Seconds'); ylabel('Amplitude')\r\n    %     title('ECG peaks detection')\r\n    %     legend('Raw signal','P','Q','R','S','T')\r\n    %     hold off\r\n    %\r\n    \r\n    %% grouping\r\n    % group 0: N (normal and bundle branch block beats); group 1: S (supraventricular\r\n    % ectopic beats); group 2: V (ventricular ectopic beats); group 3: F (fusion of N and V beats);\r\n    % group 4: Q (unknown beats)\r\n    % consider just absolute features, where each row of extracted features is\r\n    % related to one segment\r\n    \r\n    annots_list = ['N','L','R','e','j','S','A','a','J','V','E','F','/','f','Q'];\r\n    \r\n    annot  = cell2mat(ann(4));\r\n    % keep only the annotations whose beats were retained by qsPeaks\r\n    indices  = ismember(rPeaks,peaks(:,4));\r\n    annot = annot(indices);\r\n    % rps = peaks(:,4);\r\n    \r\n    % AAMI Classes:\r\n    % % N = N, L, R, e, j\r\n    % % S = A, a, J, S\r\n    % % V = V, E\r\n    % % F = F\r\n    % % Q = /, f, Q\r\n    \r\n    seg_values = {};\r\n    seg_labels =[];  \r\n    \r\n    ind_seg = 1;\r\n    % z-score normalization\r\n    signal = normalize(signal);\r\n    for ind=1:length(annot)\r\n        if ~ismember(annot(ind),annots_list)\r\n            continue;\r\n        end\r\n        \r\n        N_g = ['N', 'L', 'R', 'e', 'j'];%0\r\n        S_g = ['A', 'a', 'J', 'S'];%1\r\n        V_g = ['V', 'E'];%2\r\n        F_g = ['F'];%3\r\n        Q_g = ['/', 'f', 'Q'];%4\r\n        if(ismember(annot(ind),N_g))\r\n              lebel = 'N';\r\n%              if(N_class >8031) %(N_class >8031)\r\n%                 continue\r\n%              end\r\n          \r\n            \r\n  
      elseif(ismember(annot(ind),S_g))\r\n            lebel = 'S';\r\n        elseif(ismember(annot(ind),V_g))\r\n            lebel = 'V';\r\n        elseif(ismember(annot(ind),F_g))\r\n            lebel = 'F';\r\n        elseif(ismember(annot(ind),Q_g))\r\n            lebel = 'Q';\r\n        else\r\n            error('No label! :(');\r\n            \r\n        end\r\n        \r\n        if ind==1\r\n            % first beat: from the start of the signal to its T peak, truncated to at most Fs samples (1 s) and resized to beat_len samples\r\n            seg_values{ind_seg} = signal(1:tpeaks(ind)-1)';\r\n            t_sig = imresize(seg_values{ind_seg}(1:min(Fs,length(seg_values{ind_seg}))), [beat_len 1]);\r\n            seg_values{ind_seg} = t_sig;\r\n            seg_labels(ind_seg) = lebel;\r\n            % plot(cell2mat(seg_values(ind_seg)))\r\n            ind_seg = ind_seg+1;\r\n            continue;\r\n        end\r\n        % other beats: from the previous T peak to the sample before the current T peak, resized to beat_len samples\r\n        t_sig = imresize(signal(tpeaks(ind-1):tpeaks(ind)-1)', [beat_len 1]);\r\n        seg_values{ind_seg} =t_sig ;\r\n        %     figure;\r\n        %      plot(cell2mat(seg_values(ind_seg)))\r\n        % store the label\r\n        \r\n        \r\n        seg_labels(ind_seg) =  lebel;\r\n        ind_seg = ind_seg+1;\r\n        \r\n    end\r\n    s2s_mitbih(i).seg_values = seg_values';\r\n    s2s_mitbih(i).seg_labels = char(seg_labels);\r\n    % featuresSeg = [featuresSeg; peakSegFeats(N_inds,:),repmat(0,length(N_inds),1)];\r\n    \r\n    \r\n    % group N:0\r\n    % N = N, L, R, e, j\r\n    N_inds = find(annot=='N');\r\n    N_inds = [N_inds;find(annot=='L')];\r\n    N_inds = [N_inds;find(annot=='R')];\r\n    N_inds = [N_inds;find(annot=='e')];\r\n    N_inds = [N_inds;find(annot=='j')];\r\n    N_class = N_class + length(N_inds);\r\n    \r\n    % group S:1\r\n    % S = A, a, J, S\r\n    S_inds = find(annot=='S');\r\n    S_inds = [S_inds;find(annot=='A')];\r\n    S_inds = [S_inds;find(annot=='a')];\r\n    S_inds = [S_inds;find(annot=='J')];\r\n    S_class = S_class + length(S_inds);\r\n    \r\n    % group V:2\r\n    % V = V, E\r\n    V_inds = find(annot=='V');\r\n    V_inds = [V_inds;find(annot=='E')];\r\n    V_class = V_class + length(V_inds);\r\n    \r\n    % featuresSeg = [featuresSeg; peakSegFeats(V_inds,:),repmat(2,length(V_inds),1)];\r\n    \r\n    % group F:3\r\n    % F = F\r\n    F_inds = find(annot=='F');\r\n    F_class = F_class + length(F_inds);\r\n    \r\n    % group Q:4\r\n    % Q = /, f, Q\r\n    Q_inds = find(annot=='/');\r\n    Q_inds = [Q_inds;find(annot=='f')];\r\n    Q_inds = [Q_inds;find(annot=='Q')];\r\n    Q_class = Q_class + length(Q_inds);\r\n    \r\nend\r\n\r\n% % calculate the mean length of all beats in the dataset: it is 280\r\n% sizes = [];\r\n% for ind=1:length(s2s_mitbih)\r\n%     sizes= [sizes;cellfun(@length,s2s_mitbih(ind).seg_values)];\r\n% end\r\n% beat_len = floor(mean(sizes))\r\n\r\nsave s2s_mitbih_aami.mat s2s_mitbih\r\ntoc\r\nF_class\r\nN_class\r\nQ_class\r\nS_class\r\nV_class\r\nF_class+N_class+Q_class+S_class+V_class\r\ndisp('Successfully generated :)')\r\n"
  },
  {
    "path": "data preprocessing_Matlab/seq2seq_mitbih_AAMI_DS1DS2.m",
    "content": "clear all\r\nclc\r\ntic\r\naddr = '.\\mitbihdb';\r\nFiles=dir(strcat(addr,'\\*.mat'));\r\n\r\n%% Translate PhysioNet annotations to the AAMI and AAMI2 labeling schemes\r\n% AAMI Classes:\r\n% % N = N, L, R, e, j\r\n% % S = A, a, J, S\r\n% % V = V, E\r\n% % F = F\r\n% % Q = /, f, Q\r\n\r\n% AAMI2 Classes:\r\n% % N = N, L, R, e, j\r\n% % S = A, a, J, S\r\n% % V = V, E, F\r\n% % Q = /, f, Q\r\n% https://github.com/ehendryx/deim-cur-ecg/blob/master/DS1_MIT_CUR_beat_classification.m\r\nAAMI_annotations = {'N' 'S' 'V' 'F' 'Q'};\r\nAAMI2_annotations = {'N' 'S' 'V_hat' 'Q'};\r\nDS1=[101, 106, 108, 109, 112, 114, 115, 116, 118,119, 122, 124, 201, 203, 205, 207, 208, 209, 215, 220, 223,230];\r\nDS2 =[100, 103, 105, 111, 113, 117, 121, 123,200, 202, 210, 212, 213, 214, 219, 221, 222, 228, 231, 232, 233,234];\r\nindex = 1;\r\nbeat_len = 280;\r\nn_cycles = 0;\r\nfeaturesSeg = [];\r\ngroupN = [];\r\ngroupV = [];\r\ngroupS = [];\r\ngroupF = [];\r\ngroupQ = [];\r\n\r\nfor j=1:2\r\n    N_class = 0;V_class=0;F_class=0;Q_class=0;S_class=0;\r\n    if j==1\r\n        DS = DS1;\r\n    else\r\n        DS= DS2;\r\n    end\r\n        for i=1:length(DS) % loop over the records of DS1 or DS2\r\n            \r\n            %% load the files\r\n            % load ('100m.mat')          % the signal will be loaded to \"val\" matrix\r\n            % val = (val - 1024)/200;    % you have to remove \"base\" and \"gain\"\r\n            % ECGsignal = val(1,1:1000); % select the first signal (lead)\r\n            % Fs = 360;                  % sampling frequency\r\n            % t = (0:length(ECGsignal)-1)/Fs;  % time\r\n            % plot(t,ECGsignal)\r\n            %\r\n            \r\n            [pathstr,name,ext] = fileparts(strcat(num2str (DS(i)),'m'));\r\n            nsig = 1;\r\n            \r\n            [tm,ecgsig,ann,Fs,sizeEcgSig,timeEcgSig] = loadEcgSig([addr filesep name]);\r\n            \r\n            signal = ecgsig(nsig,:);\r\n            \r\n            \r\n            %%\r\n       
     %     rPeaks = rDetection(signal, Fs);\r\n            %     rPeaks = get_rpeaks(signal, Fs);\r\n            rPeaks  = cell2mat(ann(3))+1; % WFDB sample numbers start at 0, MATLAB indices at 1\r\n            n_cycles = n_cycles + length(rPeaks);\r\n            %     [R_i,R_amp,S_i,S_amp,T_i,T_amp,Q_i,Q_amp] = peakdetect(signal,Fs);\r\n            %      rPeaks =  R_i;\r\n            rPeaks = double(rPeaks);\r\n            \r\n            peaks = qsPeaks(signal, rPeaks, Fs);\r\n            tpeaks = peaks(:,7);\r\n            \r\n            %    %% Plot P Q R S T points\r\n            %     N = length(signal);\r\n            %     tm = 1/Fs:1/Fs:N/Fs;\r\n            %     figure;plot(tm,signal);hold on\r\n            %     scatter(peaks(:,1)/Fs,signal(peaks(:,1)),'g*') % P points\r\n            %     scatter(peaks(:,3)/Fs,signal(peaks(:,3)),'k+') % Q points\r\n            %     scatter(peaks(:,4)/Fs,signal(peaks(:,4)),'ro') % R points\r\n            %     scatter(peaks(:,5)/Fs,signal(peaks(:,5)),'c^') % S points\r\n            %     scatter(peaks(:,7)/Fs,signal(peaks(:,7)),'mo') % T points\r\n            %     xlabel('Seconds'); ylabel('Amplitude')\r\n            %     title('ECG peaks detection')\r\n            %     legend('Raw signal','P','Q','R','S','T')\r\n            %     hold off\r\n            %\r\n            \r\n            %% grouping\r\n            % group 0: N (normal and bundle branch block beats); group 1: S (supraventricular\r\n            % ectopic beats); group 2: V (ventricular ectopic beats); group 3: F (fusion of N and V beats);\r\n            % group 4: Q (unknown beats)\r\n            % consider just absolute features, where each row of extracted features is\r\n            % related to one segment\r\n            \r\n            annots_list = ['N','L','R','e','j','S','A','a','J','V','E','F','/','f','Q'];\r\n            \r\n            annot  = cell2mat(ann(4));\r\n            % keep only the annotations whose beats were retained by qsPeaks\r\n            indices  = ismember(rPeaks,peaks(:,4));\r\n            annot = annot(indices);\r\n            % rps = peaks(:,4);\r\n           
\r\n            % AAMI Classes:\r\n            % % N = N, L, R, e, j\r\n            % % S = A, a, J, S\r\n            % % V = V, E\r\n            % % F = F\r\n            % % Q = /, f, Q\r\n            \r\n            seg_values = {};\r\n            seg_labels =[];\r\n            \r\n            ind_seg = 1;\r\n            % z-score normalization\r\n            signal = normalize(signal);\r\n            for ind=1:length(annot)\r\n                if ~ismember(annot(ind),annots_list)\r\n                    continue;\r\n                end\r\n                \r\n                N_g = ['N', 'L', 'R', 'e', 'j'];%0\r\n                S_g = ['A', 'a', 'J', 'S'];%1\r\n                V_g = ['V', 'E'];%2\r\n                F_g = ['F'];%3\r\n                Q_g = ['/', 'f', 'Q'];%4\r\n                if(ismember(annot(ind),N_g))\r\n                    lebel = 'N';\r\n                    %              if(N_class >8031) %(N_class >8031)\r\n                    %                 continue\r\n                    %              end\r\n                    \r\n                    \r\n                elseif(ismember(annot(ind),S_g))\r\n                    lebel = 'S';\r\n                elseif(ismember(annot(ind),V_g))\r\n                    lebel = 'V';\r\n                elseif(ismember(annot(ind),F_g))\r\n                    lebel = 'F';\r\n                elseif(ismember(annot(ind),Q_g))\r\n                    lebel = 'Q';\r\n                else\r\n                    error('No label! :(');\r\n                    \r\n                end\r\n                if ind==1\r\n                    % first beat: from the start of the signal to its T peak, truncated to at most Fs samples (1 s) and resized to beat_len samples\r\n                    seg_values{ind_seg} = signal(1:tpeaks(ind)-1)';\r\n                    t_sig = imresize(seg_values{ind_seg}(1:min(Fs,length(seg_values{ind_seg}))), [beat_len 1]);\r\n                    seg_values{ind_seg} = t_sig;\r\n                    seg_labels(ind_seg) = lebel;\r\n                    % plot(cell2mat(seg_values(ind_seg)))\r\n                    ind_seg = ind_seg+1;\r\n                    continue;\r\n                end\r\n                % other beats: from the previous T peak to the sample before the current T peak, resized to beat_len samples\r\n                t_sig = imresize(signal(tpeaks(ind-1):tpeaks(ind)-1)', [beat_len 1]);\r\n                seg_values{ind_seg} =t_sig ;\r\n                %     figure;\r\n                %      plot(cell2mat(seg_values(ind_seg)))\r\n                % store the label\r\n                \r\n                seg_labels(ind_seg) =  lebel;\r\n                ind_seg = ind_seg+1;\r\n                \r\n            end\r\n            if j==1\r\n                s2s_mitbih_DS1(i).seg_values = seg_values';\r\n                s2s_mitbih_DS1(i).seg_labels = char(seg_labels);\r\n            else\r\n                s2s_mitbih_DS2(i).seg_values = seg_values';\r\n                s2s_mitbih_DS2(i).seg_labels = char(seg_labels);\r\n            end\r\n            \r\n            % featuresSeg = [featuresSeg; peakSegFeats(N_inds,:),repmat(0,length(N_inds),1)];\r\n            \r\n            \r\n            % group N:0\r\n            % N = N, L, R, e, j\r\n            N_inds = find(annot=='N');\r\n            N_inds = [N_inds;find(annot=='L')];\r\n            N_inds = [N_inds;find(annot=='R')];\r\n            N_inds = [N_inds;find(annot=='e')];\r\n            N_inds = [N_inds;find(annot=='j')];\r\n            N_class = N_class + length(N_inds);\r\n            \r\n            % group S:1\r\n            % S = A, a, J, S\r\n            S_inds = find(annot=='S');\r\n            S_inds = [S_inds;find(annot=='A')];\r\n            S_inds = 
[S_inds;find(annot=='a')];\r\n            S_inds = [S_inds;find(annot=='J')];\r\n            S_class = S_class + length(S_inds);\r\n            \r\n            % group V:2\r\n            % V = V, E\r\n            V_inds = find(annot=='V');\r\n            V_inds = [V_inds;find(annot=='E')];\r\n            V_class = V_class + length(V_inds);\r\n            \r\n            % featuresSeg = [featuresSeg; peakSegFeats(V_inds,:),repmat(2,length(V_inds),1)];\r\n            \r\n            % group F:3\r\n            % F = F\r\n            F_inds = find(annot=='F');\r\n            F_class = F_class + length(F_inds);\r\n            \r\n            % group Q:4\r\n            % Q = /, f, Q\r\n            Q_inds = find(annot=='/');\r\n            Q_inds = [Q_inds;find(annot=='f')];\r\n            Q_inds = [Q_inds;find(annot=='Q')];\r\n            Q_class = Q_class + length(Q_inds);\r\n            \r\n        end\r\n        \r\n    F_class\r\n    N_class\r\n    Q_class\r\n    S_class\r\n    V_class\r\n    F_class+N_class+Q_class+S_class+V_class\r\n        \r\n    end\r\n    \r\n    \r\n    % % calculate the mean length of all beats in the dataset: it is 280\r\n    % sizes = [];\r\n    % for ind=1:length(s2s_mitbih)\r\n    %     sizes= [sizes;cellfun(@length,s2s_mitbih(ind).seg_values)];\r\n    % end\r\n    % beat_len = floor(mean(sizes))\r\n    \r\n    save s2s_mitbih_aami_DS1DS2.mat s2s_mitbih_DS1 s2s_mitbih_DS2\r\n    toc\r\n\r\n    disp('Successfully generated :)')\r\n"
  },
  {
    "path": "seq_seq_annot_DS1DS2.py",
    "content": "import numpy as np\nimport matplotlib.pyplot as plt\nimport scipy.io as spio\nfrom  sklearn.preprocessing import MinMaxScaler\nimport random\nimport time\nimport os\nfrom datetime import datetime\nfrom sklearn.metrics import confusion_matrix\nimport tensorflow as tf\nfrom imblearn.over_sampling import SMOTE\nfrom sklearn.model_selection import train_test_split\nimport argparse\n\nrandom.seed(654)\ndef read_mitbih(filename, max_time=100, classes= ['F', 'N', 'S', 'V', 'Q'], max_nlabel=100, trainset=1):\n    def normalize(data):\n        data = np.nan_to_num(data)  # removing NaNs and Infs\n        data = data - np.mean(data)\n        data = data / np.std(data)\n        return data\n\n    # read data\n    data = []\n    samples = spio.loadmat(filename + \".mat\")\n    if trainset == 1: #DS1\n        samples = samples['s2s_mitbih_DS1']\n\n    else: # DS2\n        samples = samples['s2s_mitbih_DS2']\n\n    values = samples[0]['seg_values']\n    labels = samples[0]['seg_labels']\n    items_len = len(labels)\n    num_annots = sum([item.shape[0] for item in values])\n\n    n_seqs = num_annots / max_time\n    #  add all segments(beats) together\n    l_data = 0\n    for i, item in enumerate(values):\n        l = item.shape[0]\n        for itm in item:\n            if l_data == n_seqs * max_time:\n                break\n            data.append(itm[0])\n            l_data = l_data + 1\n\n    #  add all labels together\n    l_lables  = 0\n    t_lables = []\n    for i, item in enumerate(labels):\n        if len(t_lables)==n_seqs*max_time:\n            break\n        item= item[0]\n        for lebel in item:\n            if l_lables == n_seqs * max_time:\n                break\n            t_lables.append(str(lebel))\n            l_lables = l_lables + 1\n\n    del values\n    data = np.asarray(data)\n    shape_v = data.shape\n    data = np.reshape(data, [shape_v[0], -1])\n    t_lables = np.array(t_lables)\n    _data  = 
np.asarray([],dtype=np.float64).reshape(0,shape_v[1])\n    _labels = np.asarray([],dtype=np.dtype('|S1')).reshape(0,)\n    for cl in classes:\n        _label = np.where(t_lables == cl)\n        permute = np.random.permutation(len(_label[0]))\n        _label = _label[0][permute[:max_nlabel]]\n        # _label = _label[0][:max_nlabel]\n        # permute = np.random.permutation(len(_label))\n        # _label = _label[permute]\n        _data = np.concatenate((_data, data[_label]))\n        _labels = np.concatenate((_labels, t_lables[_label]))\n\n    data = _data[:(len(_data)/ max_time) * max_time, :]\n    _labels = _labels[:(len(_data) / max_time) * max_time]\n\n    # data = _data\n    #  split data into sublists of max_time beats each\n    data = [data[i:i + max_time] for i in range(0, len(data), max_time)]\n    labels = [_labels[i:i + max_time] for i in range(0, len(_labels), max_time)]\n    # shuffle\n    permute = np.random.permutation(len(labels))\n    data = np.asarray(data)\n    labels = np.asarray(labels)\n    data= data[permute]\n    labels = labels[permute]\n\n    print('Records processed!')\n\n    return data, labels\ndef evaluate_metrics(confusion_matrix):\n    # https://stackoverflow.com/questions/31324218/scikit-learn-how-to-obtain-true-positive-true-negative-false-positive-and-fal\n    FP = confusion_matrix.sum(axis=0) - np.diag(confusion_matrix)\n    FN = confusion_matrix.sum(axis=1) - np.diag(confusion_matrix)\n    TP = np.diag(confusion_matrix)\n    TN = confusion_matrix.sum() - (FP + FN + TP)\n    # Sensitivity, hit rate, recall, or true positive rate\n    TPR = TP / (TP + FN)\n    # Specificity or true negative rate\n    TNR = TN / (TN + FP)\n    # Precision or positive predictive value\n    PPV = TP / (TP + FP)\n    # Negative predictive value\n    NPV = TN / (TN + FN)\n    # Fall out or false positive rate\n    FPR = FP / (FP + TN)\n    # False negative rate\n    FNR = FN / (TP + FN)\n    # False discovery rate\n    FDR = FP / (TP + FP)\n\n    # 
Overall accuracy\n    ACC = (TP + TN) / (TP + FP + FN + TN)\n    # ACC_micro = (sum(TP) + sum(TN)) / (sum(TP) + sum(FP) + sum(FN) + sum(TN))\n    ACC_macro = np.mean(\n        ACC)  # to get a sense of effectiveness of our method on the small classes we computed this average (macro-average)\n\n    return ACC_macro, ACC, TPR, TNR, PPV\ndef batch_data(x, y, batch_size):\n    shuffle = np.random.permutation(len(x))\n    start = 0\n#     from IPython.core.debugger import Tracer; Tracer()()\n    x = x[shuffle]\n    y = y[shuffle]\n    while start + batch_size <= len(x):\n        yield x[start:start+batch_size], y[start:start+batch_size]\n        start += batch_size\ndef build_network(inputs, dec_inputs,char2numY,n_channels=10,input_depth=280,num_units=128,max_time=10,bidirectional=False):\n    _inputs = tf.reshape(inputs, [-1, n_channels, input_depth / n_channels])\n    # _inputs = tf.reshape(inputs, [-1,input_depth,n_channels])\n\n    # #(batch*max_time, 280, 1) --> (N, 280, 18)\n    conv1 = tf.layers.conv1d(inputs=_inputs, filters=32, kernel_size=2, strides=1,\n                             padding='same', activation=tf.nn.relu)\n    max_pool_1 = tf.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same')\n\n    conv2 = tf.layers.conv1d(inputs=max_pool_1, filters=64, kernel_size=2, strides=1,\n                             padding='same', activation=tf.nn.relu)\n    max_pool_2 = tf.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same')\n\n    conv3 = tf.layers.conv1d(inputs=max_pool_2, filters=128, kernel_size=2, strides=1,\n                              padding='same', activation=tf.nn.relu)\n\n    shape = conv3.get_shape().as_list()\n    data_input_embed = tf.reshape(conv3, (-1, max_time,shape[1]*shape[2]))\n\n    # timesteps = max_time\n    #\n    # lstm_in = tf.unstack(data_input_embed, timesteps, 1)\n    # lstm_size = 128\n    # # Get lstm cell output\n    # # Add LSTM layers\n    # lstm_cell = 
tf.contrib.rnn.BasicLSTMCell(lstm_size)\n    # data_input_embed, states = tf.contrib.rnn.static_rnn(lstm_cell, lstm_in, dtype=tf.float32)\n    # data_input_embed = tf.stack(data_input_embed, 1)\n    # shape = data_input_embed.get_shape().as_list()\n\n    embed_size = 10 #128 lstm_size # shape[1]*shape[2]\n\n    # Embedding layers\n    output_embedding = tf.Variable(tf.random_uniform((len(char2numY), embed_size), -1.0, 1.0), name='dec_embedding')\n    data_output_embed = tf.nn.embedding_lookup(output_embedding, dec_inputs)\n\n\n    with tf.variable_scope(\"encoding\") as encoding_scope:\n        if not bidirectional:\n\n            # Regular approach with LSTM units\n            lstm_enc = tf.contrib.rnn.LSTMCell(num_units)\n            _, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=data_input_embed, dtype=tf.float32)\n\n        else:\n\n            # Using a bidirectional LSTM architecture instead\n            enc_fw_cell = tf.contrib.rnn.LSTMCell(num_units)\n            enc_bw_cell = tf.contrib.rnn.LSTMCell(num_units)\n\n            ((enc_fw_out, enc_bw_out), (enc_fw_final, enc_bw_final)) = tf.nn.bidirectional_dynamic_rnn(\n                cell_fw=enc_fw_cell,\n                cell_bw=enc_bw_cell,\n                inputs=data_input_embed,\n                dtype=tf.float32)\n            enc_fin_c = tf.concat((enc_fw_final.c, enc_bw_final.c), 1)\n            enc_fin_h = tf.concat((enc_fw_final.h, enc_bw_final.h), 1)\n            last_state = tf.contrib.rnn.LSTMStateTuple(c=enc_fin_c, h=enc_fin_h)\n\n    with tf.variable_scope(\"decoding\") as decoding_scope:\n        if not bidirectional:\n            lstm_dec = tf.contrib.rnn.LSTMCell(num_units)\n        else:\n            lstm_dec = tf.contrib.rnn.LSTMCell(2 * num_units)\n\n        dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=data_output_embed, initial_state=last_state)\n\n    logits = tf.layers.dense(dec_outputs, units=len(char2numY), use_bias=True)\n\n    return logits\ndef str2bool(v):\n    if 
v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\ndef main():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--epochs', type=int, default=500)\n    parser.add_argument('--max_time', type=int, default=10)\n    parser.add_argument('--test_steps', type=int, default=10)\n    parser.add_argument('--batch_size', type=int, default=20)\n    parser.add_argument('--data_dir', type=str, default='data/s2s_mitbih_aami_DS1DS2')\n    parser.add_argument('--bidirectional', type=str2bool, default=str2bool('False'))\n    # parser.add_argument('--lstm_layers', type=int, default=2)\n    parser.add_argument('--num_units', type=int, default=128)\n    parser.add_argument('--n_oversampling', type=int, default=6000)\n    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints-seq2seq_DS1DS2')\n    parser.add_argument('--ckpt_name', type=str, default='seq2seq_mitbih_DS1DS2.ckpt')\n    parser.add_argument('--classes', nargs='+', type=str,\n                        default=['N', 'S','V'])\n    args = parser.parse_args()\n    run_program(args)\ndef run_program(args):\n    print(args)\n    max_time = args.max_time # 5 3 second best 10# 40 # 100\n    epochs = args.epochs # 300\n    batch_size = args.batch_size # 10\n    num_units = args.num_units\n    bidirectional = args.bidirectional\n    # lstm_layers = args.lstm_layers\n    n_oversampling = args.n_oversampling\n    checkpoint_dir = args.checkpoint_dir\n    ckpt_name = args.ckpt_name\n    test_steps = args.test_steps\n    classes= args.classes # ['N', 'S','V']\n    filename = args.data_dir\n\n    X_train, y_train = read_mitbih(filename, max_time, classes=classes, max_nlabel=50000,trainset=1)\n    X_test, y_test = read_mitbih(filename, max_time, classes=classes, max_nlabel=50000,trainset=0)\n    input_depth = X_train.shape[2]\n    
n_channels = 10\n    print (\"# of sequences: \", len(X_train))\n\n    classes = np.unique(y_train)\n    char2numY = dict(zip(classes, range(len(classes))))\n    n_classes = len(classes)\n    print ('Classes (training): ', classes)\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        print (cl, len(np.where(y_train.flatten() == cl)[0]))\n\n    print ('Classes (test): ', classes)\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        print (cl, len(np.where(y_test.flatten() == cl)[0]))\n\n\n    char2numY['<GO>'] = len(char2numY)\n    num2charY = dict(zip(char2numY.values(), char2numY.keys()))\n\n    y_train = [[char2numY['<GO>']] + [char2numY[y_] for y_ in date] for date in y_train]\n    y_test = [[char2numY['<GO>']] + [char2numY[y_] for y_ in date] for date in y_test]\n    y_test = np.asarray(y_test)\n    y_train = np.array(y_train)\n\n    x_seq_length = len(X_train[0])\n    y_seq_length = len(y_train[0]) - 1\n\n    # Placeholders\n    inputs = tf.placeholder(tf.float32, [None, max_time, input_depth], name='inputs')\n    targets = tf.placeholder(tf.int32, (None, None), 'targets')\n    dec_inputs = tf.placeholder(tf.int32, (None, None), 'output')\n\n    logits = build_network(inputs, dec_inputs, char2numY, n_channels=n_channels, input_depth=input_depth, num_units=num_units, max_time=max_time,\n                  bidirectional=bidirectional)\n    # decoder_prediction = tf.argmax(logits, 2)\n    # confusion = tf.confusion_matrix(labels=tf.argmax(targets, 1), predictions=tf.argmax(logits, 2), num_classes=len(char2numY) - 1)# it is wrong\n    # mean_accuracy,update_mean_accuracy = tf.metrics.mean_per_class_accuracy(labels=targets, predictions=decoder_prediction, num_classes=len(char2numY) - 1)\n\n    with tf.name_scope(\"optimization\"):\n        # Loss function\n        vars = tf.trainable_variables()\n        beta = 0.001\n        lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars\n                            if 'bias' not 
in v.name]) * beta\n        loss = tf.contrib.seq2seq.sequence_loss(logits, targets, tf.ones([batch_size, y_seq_length]))\n        # Optimizer\n        loss = tf.reduce_mean(loss + lossL2)\n        optimizer = tf.train.RMSPropOptimizer(1e-3).minimize(loss)\n\n\n    # train the graph\n\n    # over-sampling: SMOTE\n    X_train = np.reshape(X_train,[X_train.shape[0]*X_train.shape[1],-1])\n    y_train= y_train[:,1:].flatten()\n\n    nums = []\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        nums.append(len(np.where(y_train.flatten()==ind)[0]))\n\n    # ratio={0:nums[0],1:nums[0],2:nums[0]}\n    # ratio={0:7000,1:nums[1],2:7000,3:7000}\n    ratio={0:nums[0],1:n_oversampling+1000,2:n_oversampling}\n    sm = SMOTE(random_state=12,ratio=ratio)\n    X_train, y_train = sm.fit_sample(X_train, y_train)\n\n    X_train = X_train[:(X_train.shape[0]/max_time)*max_time,:]\n    y_train = y_train[:(X_train.shape[0]/max_time)*max_time]\n\n    X_train = np.reshape(X_train,[-1,X_test.shape[1],X_test.shape[2]])\n    y_train = np.reshape(y_train,[-1,y_test.shape[1]-1,])\n    y_train= [[char2numY['<GO>']] + [y_ for y_ in date] for date in y_train]\n    y_train = np.array(y_train)\n\n    print ('Classes in the training set: ', classes)\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        print (cl, len(np.where(y_train.flatten()==ind)[0]))\n    print (\"------------------y_train samples--------------------\")\n    for ii in range(2):\n      print(''.join([num2charY[y_] for y_ in list(y_train[ii+5])]))\n\n    print ('Classes in the test set: ', classes)\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        print (cl, len(np.where(y_test.flatten()==ind)[0]))\n\n    print (\"------------------y_test samples--------------------\")\n    for ii in range(2):\n      print(''.join([num2charY[y_] for y_ in list(y_test[ii+5])]))\n\n    def test_model():\n        # source_batch, target_batch = next(batch_data(X_test, 
y_test, batch_size))\n        acc_track = []\n        sum_test_conf = []\n        for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_test, y_test, batch_size)):\n\n            dec_input = np.zeros((len(source_batch), 1)) + char2numY['<GO>']\n            for i in range(y_seq_length):\n                batch_logits = sess.run(logits,\n                                        feed_dict={inputs: source_batch, dec_inputs: dec_input})\n                prediction = batch_logits[:, -1].argmax(axis=-1)\n                dec_input = np.hstack([dec_input, prediction[:, None]])\n            # acc_track.append(np.mean(dec_input == target_batch))\n            acc_track.append(dec_input[:, 1:] == target_batch[:, 1:])\n            y_true= target_batch[:, 1:].flatten()\n            y_pred = dec_input[:, 1:].flatten()\n            sum_test_conf.append(confusion_matrix(y_true, y_pred,labels=range(len(char2numY)-1)))\n\n        sum_test_conf= np.mean(np.array(sum_test_conf, dtype=np.float32), axis=0)\n\n        # print('Accuracy on test set is: {:>6.4f}'.format(np.mean(acc_track)))\n\n        # mean_p_class, accuracy_classes = sess.run([mean_accuracy, update_mean_accuracy],\n        #                                           feed_dict={inputs: source_batch,\n        #                                                      dec_inputs: dec_input[:, :-1],\n        #                                                      targets: target_batch[:, 1:]})\n        # print (mean_p_class)\n        # print (accuracy_classes)\n        acc_avg, acc, sensitivity, specificity, PPV = evaluate_metrics(sum_test_conf)\n        print('Average Accuracy is: {:>6.4f} on test set'.format(acc_avg))\n        for index_ in range(n_classes):\n            print(\"\\t{} rhythm -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}\".format(\n                classes[index_],\n                sensitivity[\n                    index_],\n                specificity[\n  
                  index_], PPV[index_],\n                acc[index_]))\n        print(\"\\t Average -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}\".format(\n            np.mean(sensitivity), np.mean(specificity), np.mean(PPV), np.mean(acc)))\n        return  acc_avg, acc, sensitivity, specificity, PPV\n    loss_track = []\n    def count_prameters():\n        print ('# of Params: ', np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))\n\n    count_prameters()\n    if (os.path.exists(checkpoint_dir) == False):\n        os.mkdir(checkpoint_dir)\n\n    with tf.Session() as sess:\n        sess.run(tf.global_variables_initializer())\n        sess.run(tf.local_variables_initializer())\n        saver = tf.train.Saver()\n        print(str(datetime.now()))\n        pre_acc_avg = 0.0\n        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)\n        if ckpt and ckpt.model_checkpoint_path:\n            # # Restore\n            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)\n            # saver.restore(session, os.path.join(checkpoint_dir, ckpt_name))\n            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))\n            # or 'load meta graph' and restore weights\n            # saver = tf.train.import_meta_graph(ckpt_name+\".meta\")\n            # saver.restore(session,tf.train.latest_checkpoint(checkpoint_dir))\n            test_model()\n        else:\n\n            for epoch_i in range(epochs):\n                start_time = time.time()\n                train_acc = []\n                for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, y_train, batch_size)):\n                    _, batch_loss, batch_logits = sess.run([optimizer, loss, logits],\n                        feed_dict = {inputs: source_batch,\n                                     dec_inputs: target_batch[:, :-1],\n                                     targets: target_batch[:, 1:]})\n         
           loss_track.append(batch_loss)\n                    train_acc.append(batch_logits.argmax(axis=-1) == target_batch[:,1:])\n\n                accuracy = np.mean(train_acc)\n                print('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f} Epoch duration: {:>6.3f}s'.format(epoch_i, batch_loss,\n                                                                                  accuracy, time.time() - start_time))\n\n                if epoch_i%test_steps==0:\n                    acc_avg, acc, sensitivity, specificity, PPV= test_model()\n\n                    print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size))\n                    save_path = os.path.join(checkpoint_dir, ckpt_name)\n                    saver.save(sess, save_path)\n                    print(\"Model saved in path: %s\" % save_path)\n\n                    # if np.nan_to_num(acc_avg) > pre_acc_avg:  # save the better model based on the f1 score\n                    #     print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size))\n                    #     pre_acc_avg = acc_avg\n                    #     save_path =os.path.join(checkpoint_dir, ckpt_name)\n                    #     saver.save(sess, save_path)\n                    #     print(\"The best model (till now) saved in path: %s\" % save_path)\n\n            plt.plot(loss_track)\n            plt.show()\n        print(str(datetime.now()))\n\n        # test_model()\nif __name__ == '__main__':\n    main()\n\n\n\n\n"
  },
  {
    "path": "seq_seq_annot_aami.py",
    "content": "import numpy as np\nimport matplotlib.pyplot as plt\nimport scipy.io as spio\nfrom  sklearn.preprocessing import MinMaxScaler\nimport random\nimport time\nimport os\nfrom datetime import datetime\nfrom sklearn.metrics import confusion_matrix\nimport tensorflow as tf\nfrom imblearn.over_sampling import SMOTE\nfrom sklearn.model_selection import train_test_split\nimport argparse\nrandom.seed(654)\ndef read_mitbih(filename, max_time=100, classes= ['F', 'N', 'S', 'V', 'Q'], max_nlabel=100):\n    def normalize(data):\n        data = np.nan_to_num(data)  # removing NaNs and Infs\n        data = data - np.mean(data)\n        data = data / np.std(data)\n        return data\n\n    # read data\n    data = []\n    samples = spio.loadmat(filename + \".mat\")\n    samples = samples['s2s_mitbih']\n    values = samples[0]['seg_values']\n    labels = samples[0]['seg_labels']\n    num_annots = sum([item.shape[0] for item in values])\n\n    n_seqs = num_annots / max_time\n    #  add all segments(beats) together\n    l_data = 0\n    for i, item in enumerate(values):\n        l = item.shape[0]\n        for itm in item:\n            if l_data == n_seqs * max_time:\n                break\n            data.append(itm[0])\n            l_data = l_data + 1\n\n    #  add all labels together\n    l_lables  = 0\n    t_lables = []\n    for i, item in enumerate(labels):\n        if len(t_lables)==n_seqs*max_time:\n            break\n        item= item[0]\n        for lebel in item:\n            if l_lables == n_seqs * max_time:\n                break\n            t_lables.append(str(lebel))\n            l_lables = l_lables + 1\n\n    del values\n    data = np.asarray(data)\n    shape_v = data.shape\n    data = np.reshape(data, [shape_v[0], -1])\n    t_lables = np.array(t_lables)\n    _data  = np.asarray([],dtype=np.float64).reshape(0,shape_v[1])\n    _labels = np.asarray([],dtype=np.dtype('|S1')).reshape(0,)\n    for cl in classes:\n        _label = np.where(t_lables == cl)\n     
   permute = np.random.permutation(len(_label[0]))\n        _label = _label[0][permute[:max_nlabel]]\n\n        # _label = _label[0][:max_nlabel]\n        # permute = np.random.permutation(len(_label))\n        # _label = _label[permute]\n        _data = np.concatenate((_data, data[_label]))\n        _labels = np.concatenate((_labels, t_lables[_label]))\n\n    data = _data[:(len(_data)/ max_time) * max_time, :]\n    _labels = _labels[:(len(_data) / max_time) * max_time]\n\n    # data = _data\n    #  split data into sublists of max_time beats each\n    data = [data[i:i + max_time] for i in range(0, len(data), max_time)]\n    labels = [_labels[i:i + max_time] for i in range(0, len(_labels), max_time)]\n    # shuffle\n    permute = np.random.permutation(len(labels))\n    data = np.asarray(data)\n    labels = np.asarray(labels)\n    data= data[permute]\n    labels = labels[permute]\n\n    print('Records processed!')\n\n    return data, labels\ndef evaluate_metrics(confusion_matrix):\n    # https://stackoverflow.com/questions/31324218/scikit-learn-how-to-obtain-true-positive-true-negative-false-positive-and-fal\n    FP = confusion_matrix.sum(axis=0) - np.diag(confusion_matrix)\n    FN = confusion_matrix.sum(axis=1) - np.diag(confusion_matrix)\n    TP = np.diag(confusion_matrix)\n    TN = confusion_matrix.sum() - (FP + FN + TP)\n    # Sensitivity, hit rate, recall, or true positive rate\n    TPR = TP / (TP + FN)\n    # Specificity or true negative rate\n    TNR = TN / (TN + FP)\n    # Precision or positive predictive value\n    PPV = TP / (TP + FP)\n    # Negative predictive value\n    NPV = TN / (TN + FN)\n    # Fall out or false positive rate\n    FPR = FP / (FP + TN)\n    # False negative rate\n    FNR = FN / (TP + FN)\n    # False discovery rate\n    FDR = FP / (TP + FP)\n\n    # Overall accuracy\n    ACC = (TP + TN) / (TP + FP + FN + TN)\n    # ACC_micro = (sum(TP) + sum(TN)) / (sum(TP) + sum(FP) + sum(FN) + sum(TN))\n    ACC_macro = np.mean(ACC) # to get a sense of 
effectiveness of our method on the small classes we computed this average (macro-average)\n\n    return ACC_macro, ACC, TPR, TNR, PPV\ndef batch_data(x, y, batch_size):\n    shuffle = np.random.permutation(len(x))\n    start = 0\n    #     from IPython.core.debugger import Tracer; Tracer()()\n    x = x[shuffle]\n    y = y[shuffle]\n    while start + batch_size <= len(x):\n        yield x[start:start + batch_size], y[start:start + batch_size]\n        start += batch_size\ndef build_network(inputs, dec_inputs,char2numY,n_channels=10,input_depth=280,num_units=128,max_time=10,bidirectional=False):\n    _inputs = tf.reshape(inputs, [-1, n_channels, input_depth / n_channels])\n    # _inputs = tf.reshape(inputs, [-1,input_depth,n_channels])\n\n    # #(batch*max_time, 280, 1) --> (N, 280, 18)\n    conv1 = tf.layers.conv1d(inputs=_inputs, filters=32, kernel_size=2, strides=1,\n                             padding='same', activation=tf.nn.relu)\n    max_pool_1 = tf.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same')\n\n    conv2 = tf.layers.conv1d(inputs=max_pool_1, filters=64, kernel_size=2, strides=1,\n                             padding='same', activation=tf.nn.relu)\n    max_pool_2 = tf.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same')\n\n    conv3 = tf.layers.conv1d(inputs=max_pool_2, filters=128, kernel_size=2, strides=1,\n                             padding='same', activation=tf.nn.relu)\n\n    shape = conv3.get_shape().as_list()\n    data_input_embed = tf.reshape(conv3, (-1, max_time, shape[1] * shape[2]))\n\n    # timesteps = max_time\n    #\n    # lstm_in = tf.unstack(data_input_embed, timesteps, 1)\n    # lstm_size = 128\n    # # Get lstm cell output\n    # # Add LSTM layers\n    # lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)\n    # data_input_embed, states = tf.contrib.rnn.static_rnn(lstm_cell, lstm_in, dtype=tf.float32)\n    # data_input_embed = tf.stack(data_input_embed, 1)\n\n    # shape = 
data_input_embed.get_shape().as_list()\n\n    embed_size = 10  # 128 lstm_size # shape[1]*shape[2]\n\n    # Embedding layers\n    output_embedding = tf.Variable(tf.random_uniform((len(char2numY), embed_size), -1.0, 1.0), name='dec_embedding')\n    data_output_embed = tf.nn.embedding_lookup(output_embedding, dec_inputs)\n\n    with tf.variable_scope(\"encoding\") as encoding_scope:\n        if not bidirectional:\n\n            # Regular approach with LSTM units\n            lstm_enc = tf.contrib.rnn.LSTMCell(num_units)\n            _, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=data_input_embed, dtype=tf.float32)\n\n        else:\n\n            # Using a bidirectional LSTM architecture instead\n            enc_fw_cell = tf.contrib.rnn.LSTMCell(num_units)\n            enc_bw_cell = tf.contrib.rnn.LSTMCell(num_units)\n\n            ((enc_fw_out, enc_bw_out), (enc_fw_final, enc_bw_final)) = tf.nn.bidirectional_dynamic_rnn(\n                cell_fw=enc_fw_cell,\n                cell_bw=enc_bw_cell,\n                inputs=data_input_embed,\n                dtype=tf.float32)\n            enc_fin_c = tf.concat((enc_fw_final.c, enc_bw_final.c), 1)\n            enc_fin_h = tf.concat((enc_fw_final.h, enc_bw_final.h), 1)\n            last_state = tf.contrib.rnn.LSTMStateTuple(c=enc_fin_c, h=enc_fin_h)\n\n    with tf.variable_scope(\"decoding\") as decoding_scope:\n        if not bidirectional:\n            lstm_dec = tf.contrib.rnn.LSTMCell(num_units)\n        else:\n            lstm_dec = tf.contrib.rnn.LSTMCell(2 * num_units)\n\n        dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=data_output_embed, initial_state=last_state)\n\n    logits = tf.layers.dense(dec_outputs, units=len(char2numY), use_bias=True)\n\n    return logits\ndef str2bool(v):\n    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value 
expected.')\ndef main():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--epochs', type=int, default=500)\n    parser.add_argument('--max_time', type=int, default=10)\n    parser.add_argument('--test_steps', type=int, default=10)\n    parser.add_argument('--batch_size', type=int, default=20)\n    parser.add_argument('--data_dir', type=str, default='data/s2s_mitbih_aami')\n    parser.add_argument('--bidirectional', type=str2bool, default=str2bool('False'))\n    # parser.add_argument('--lstm_layers', type=int, default=2)\n    parser.add_argument('--num_units', type=int, default=128)\n    parser.add_argument('--n_oversampling', type=int, default=10000)\n    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints-seq2seq')\n    parser.add_argument('--ckpt_name', type=str, default='seq2seq_mitbih.ckpt')\n    parser.add_argument('--classes', nargs='+', type=str,\n                        default=['F','N', 'S','V'])\n    args = parser.parse_args()\n    run_program(args)\ndef run_program(args):\n    print(args)\n    max_time = args.max_time # 5 3 second best 10# 40 # 100\n    epochs = args.epochs # 300\n    batch_size = args.batch_size # 10\n    num_units = args.num_units\n    bidirectional = args.bidirectional\n    # lstm_layers = args.lstm_layers\n    n_oversampling = args.n_oversampling\n    checkpoint_dir = args.checkpoint_dir\n    ckpt_name = args.ckpt_name\n    test_steps = args.test_steps\n    classes= args.classes\n    filename = args.data_dir\n\n    X, Y = read_mitbih(filename,max_time,classes=classes,max_nlabel=100000) #11000\n    print (\"# of sequences: \", len(X))\n    input_depth = X.shape[2]\n    n_channels = 10\n    classes = np.unique(Y)\n    char2numY = dict(zip(classes, range(len(classes))))\n    n_classes = len(classes)\n    print ('Classes: ', classes)\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        print (cl, len(np.where(Y.flatten()==cl)[0]))\n    # char2numX['<PAD>'] = len(char2numX)\n 
   # num2charX = dict(zip(char2numX.values(), char2numX.keys()))\n    # max_len = max([len(date) for date in x])\n    #\n    # x = [[char2numX['<PAD>']]*(max_len - len(date)) +[char2numX[x_] for x_ in date] for date in x]\n    # print(''.join([num2charX[x_] for x_ in x[4]]))\n    # x = np.array(x)\n\n    char2numY['<GO>'] = len(char2numY)\n    num2charY = dict(zip(char2numY.values(), char2numY.keys()))\n\n    Y = [[char2numY['<GO>']] + [char2numY[y_] for y_ in date] for date in Y]\n    Y = np.array(Y)\n\n    x_seq_length = len(X[0])\n    y_seq_length = len(Y[0])- 1\n\n    # Placeholders\n    inputs = tf.placeholder(tf.float32, [None, max_time, input_depth], name = 'inputs')\n    targets = tf.placeholder(tf.int32, (None, None), 'targets')\n    dec_inputs = tf.placeholder(tf.int32, (None, None), 'output')\n\n    # logits = build_network(inputs,dec_inputs=dec_inputs)\n    logits = build_network(inputs, dec_inputs, char2numY, n_channels=n_channels, input_depth=input_depth, num_units=num_units, max_time=max_time,\n                  bidirectional=bidirectional)\n    # decoder_prediction = tf.argmax(logits, 2)\n    # confusion = tf.confusion_matrix(labels=tf.argmax(targets, 1), predictions=tf.argmax(logits, 2), num_classes=len(char2numY) - 1)# it is wrong\n    # mean_accuracy,update_mean_accuracy = tf.metrics.mean_per_class_accuracy(labels=targets, predictions=decoder_prediction, num_classes=len(char2numY) - 1)\n\n    with tf.name_scope(\"optimization\"):\n        # Loss function\n        vars = tf.trainable_variables()\n        beta = 0.001\n        lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars\n                            if 'bias' not in v.name]) * beta\n        loss = tf.contrib.seq2seq.sequence_loss(logits, targets, tf.ones([batch_size, y_seq_length]))\n        # Optimizer\n        loss = tf.reduce_mean(loss + lossL2)\n        optimizer = tf.train.RMSPropOptimizer(1e-3).minimize(loss)\n\n\n    # split the dataset into the training and test sets\n    X_train, 
X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)\n\n    # over-sampling: SMOTE\n    X_train = np.reshape(X_train,[X_train.shape[0]*X_train.shape[1],-1])\n    y_train= y_train[:,1:].flatten()\n\n    nums = []\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        nums.append(len(np.where(y_train.flatten()==ind)[0]))\n    # ratio={0:nums[3],1:nums[1],2:nums[3],3:nums[3]} # the best with 11000 for N\n    ratio={0:n_oversampling,1:nums[1],2:n_oversampling,3:n_oversampling}\n    sm = SMOTE(random_state=12,ratio=ratio)\n    X_train, y_train = sm.fit_sample(X_train, y_train)\n\n    X_train = X_train[:(X_train.shape[0]/max_time)*max_time,:]\n    y_train = y_train[:(X_train.shape[0]/max_time)*max_time]\n\n    X_train = np.reshape(X_train,[-1,X_test.shape[1],X_test.shape[2]])\n    y_train = np.reshape(y_train,[-1,y_test.shape[1]-1,])\n    y_train= [[char2numY['<GO>']] + [y_ for y_ in date] for date in y_train]\n    y_train = np.array(y_train)\n\n    print ('Classes in the training set: ', classes)\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        print (cl, len(np.where(y_train.flatten()==ind)[0]))\n    print (\"------------------y_train samples--------------------\")\n    for ii in range(2):\n      print(''.join([num2charY[y_] for y_ in list(y_train[ii+5])]))\n    print (\"------------------y_test samples--------------------\")\n    for ii in range(2):\n      print(''.join([num2charY[y_] for y_ in list(y_test[ii+5])]))\n\n    def test_model():\n        # source_batch, target_batch = next(batch_data(X_test, y_test, batch_size))\n        acc_track = []\n        sum_test_conf = []\n        for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_test, y_test, batch_size)):\n\n            dec_input = np.zeros((len(source_batch), 1)) + char2numY['<GO>']\n            for i in range(y_seq_length):\n                batch_logits = sess.run(logits,\n                                        
feed_dict={inputs: source_batch, dec_inputs: dec_input})\n                prediction = batch_logits[:, -1].argmax(axis=-1)\n                dec_input = np.hstack([dec_input, prediction[:, None]])\n            # acc_track.append(np.mean(dec_input == target_batch))\n            acc_track.append(dec_input[:, 1:] == target_batch[:, 1:])\n            y_true= target_batch[:, 1:].flatten()\n            y_pred = dec_input[:, 1:].flatten()\n            sum_test_conf.append(confusion_matrix(y_true, y_pred,labels=range(len(char2numY)-1)))\n\n        sum_test_conf= np.mean(np.array(sum_test_conf, dtype=np.float32), axis=0)\n\n        # print('Accuracy on test set is: {:>6.4f}'.format(np.mean(acc_track)))\n\n        # mean_p_class, accuracy_classes = sess.run([mean_accuracy, update_mean_accuracy],\n        #                                           feed_dict={inputs: source_batch,\n        #                                                      dec_inputs: dec_input[:, :-1],\n        #                                                      targets: target_batch[:, 1:]})\n        # print (mean_p_class)\n        # print (accuracy_classes)\n        acc_avg, acc, sensitivity, specificity, PPV = evaluate_metrics(sum_test_conf)\n        print('Average Accuracy is: {:>6.4f} on test set'.format(acc_avg))\n        for index_ in range(n_classes):\n            print(\"\\t{} rhythm -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}\".format(classes[index_],\n                                                                                                          sensitivity[\n                                                                                                              index_],\n                                                                                                          specificity[\n                                                                                                              index_],PPV[index_],\n         
                                                                                                 acc[index_]))\n        print(\"\\t Average -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}\".format(np.mean(sensitivity),np.mean(specificity),np.mean(PPV),np.mean(acc)))\n        return acc_avg, acc, sensitivity, specificity, PPV\n    loss_track = []\n    def count_prameters():\n        print ('# of Params: ', np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))\n\n    count_prameters()\n\n    if (os.path.exists(checkpoint_dir) == False):\n        os.mkdir(checkpoint_dir)\n    # train the graph\n    with tf.Session() as sess:\n        sess.run(tf.global_variables_initializer())\n        sess.run(tf.local_variables_initializer())\n        saver = tf.train.Saver()\n        print(str(datetime.now()))\n        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)\n        pre_acc_avg = 0.0\n        if ckpt and ckpt.model_checkpoint_path:\n            # # Restore\n            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)\n            # saver.restore(session, os.path.join(checkpoint_dir, ckpt_name))\n            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))\n            # or 'load meta graph' and restore weights\n            # saver = tf.train.import_meta_graph(ckpt_name+\".meta\")\n            # saver.restore(session,tf.train.latest_checkpoint(checkpoint_dir))\n            test_model()\n        else:\n\n            for epoch_i in range(epochs):\n                start_time = time.time()\n                train_acc = []\n                for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, y_train, batch_size)):\n                    _, batch_loss, batch_logits = sess.run([optimizer, loss, logits],\n                        feed_dict = {inputs: source_batch,\n                                     dec_inputs: target_batch[:, :-1],\n                                 
    targets: target_batch[:, 1:]})\n                    loss_track.append(batch_loss)\n                    train_acc.append(batch_logits.argmax(axis=-1) == target_batch[:,1:])\n                # mean_p_class,accuracy_classes = sess.run([mean_accuracy,update_mean_accuracy],\n                #                         feed_dict={inputs: source_batch,\n                #                                               dec_inputs: target_batch[:, :-1],\n                #                                               targets: target_batch[:, 1:]})\n\n                # accuracy = np.mean(batch_logits.argmax(axis=-1) == target_batch[:,1:])\n                accuracy = np.mean(train_acc)\n                print('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f} Epoch duration: {:>6.3f}s'.format(epoch_i, batch_loss,\n                                                                                  accuracy, time.time() - start_time))\n\n                if epoch_i%test_steps==0:\n                    acc_avg, acc, sensitivity, specificity, PPV= test_model()\n\n                    print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size))\n                    save_path = os.path.join(checkpoint_dir, ckpt_name)\n                    saver.save(sess, save_path)\n                    print(\"Model saved in path: %s\" % save_path)\n\n                    # if np.nan_to_num(acc_avg) > pre_acc_avg:  # save the better model based on the f1 score\n                    #     print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size))\n                    #     pre_acc_avg = acc_avg\n                    #     save_path =os.path.join(checkpoint_dir, ckpt_name)\n                    #     saver.save(sess, save_path)\n                    #     print(\"The best model (till now) saved in path: %s\" % save_path)\n\n            plt.plot(loss_track)\n            plt.show()\n        print(str(datetime.now()))\n        # 
test_model()\nif __name__ == '__main__':\n    main()\n\n\n\n\n\n"
  },
  {
    "path": "tf2/seq_seq_annot_aami.py",
    "content": "import numpy as np\nimport matplotlib.pyplot as plt\nimport scipy.io as spio\nfrom  sklearn.preprocessing import MinMaxScaler\nimport random\nimport time\nimport os\nfrom datetime import datetime\nfrom sklearn.metrics import confusion_matrix\nimport tensorflow as tf\nimport tensorflow_addons as tfa\nfrom imblearn.over_sampling import SMOTE\nfrom sklearn.model_selection import train_test_split\nimport argparse\nrandom.seed(654)\ndef read_mitbih(filename, max_time=100, classes= ['F', 'N', 'S', 'V', 'Q'], max_nlabel=100):\n    def normalize(data):\n        data = np.nan_to_num(data)  # removing NaNs and Infs\n        data = data - np.mean(data)\n        data = data / np.std(data)\n        return data\n\n    # read data\n    data = []\n    samples = spio.loadmat(filename + \".mat\")\n    samples = samples['s2s_mitbih']\n    values = samples[0]['seg_values']\n    labels = samples[0]['seg_labels']\n    num_annots = sum([item.shape[0] for item in values])\n\n    n_seqs = num_annots / max_time\n    #  add all segments(beats) together\n    l_data = 0\n    for i, item in enumerate(values):\n        l = item.shape[0]\n        for itm in item:\n            if l_data == n_seqs * max_time:\n                break\n            data.append(itm[0])\n            l_data = l_data + 1\n\n    #  add all labels together\n    l_lables  = 0\n    t_lables = []\n    for i, item in enumerate(labels):\n        if len(t_lables)==n_seqs*max_time:\n            break\n        item= item[0]\n        for lebel in item:\n            if l_lables == n_seqs * max_time:\n                break\n            t_lables.append(str(lebel))\n            l_lables = l_lables + 1\n\n    del values\n    data = np.asarray(data)\n    shape_v = data.shape\n    data = np.reshape(data, [shape_v[0], -1])\n    t_lables = np.array(t_lables)\n    _data  = np.asarray([],dtype=np.float64).reshape(0,shape_v[1])\n    _labels = np.asarray([],dtype=np.dtype('|S1')).reshape(0,)\n    for cl in classes:\n        _label 
= np.where(t_lables == cl)\n        permute = np.random.permutation(len(_label[0]))\n        _label = _label[0][permute[:max_nlabel]]\n\n        # _label = _label[0][:max_nlabel]\n        # permute = np.random.permutation(len(_label))\n        # _label = _label[permute]\n        _data = np.concatenate((_data, data[_label]))\n        _labels = np.concatenate((_labels, t_lables[_label]))\n\n    data = _data[:int(len(_data)/ max_time) * max_time, :]\n    _labels = _labels[:int(len(_data) / max_time) * max_time]\n\n    # data = _data\n    #  split data into sublist of 100=se_len values\n    data = [data[i:i + max_time] for i in range(0, len(data), max_time)]\n    labels = [_labels[i:i + max_time] for i in range(0, len(_labels), max_time)]\n    # shuffle\n    permute = np.random.permutation(len(labels))\n    data = np.asarray(data)\n    labels = np.asarray(labels)\n    data= data[permute]\n    labels = labels[permute]\n\n    print('Records processed!')\n\n    return data, labels\ndef evaluate_metrics(confusion_matrix):\n    # https://stackoverflow.com/questions/31324218/scikit-learn-how-to-obtain-true-positive-true-negative-false-positive-and-fal\n    FP = confusion_matrix.sum(axis=0) - np.diag(confusion_matrix)\n    FN = confusion_matrix.sum(axis=1) - np.diag(confusion_matrix)\n    TP = np.diag(confusion_matrix)\n    TN = confusion_matrix.sum() - (FP + FN + TP)\n    # Sensitivity, hit rate, recall, or true positive rate\n    TPR = TP / (TP + FN)\n    # Specificity or true negative rate\n    TNR = TN / (TN + FP)\n    # Precision or positive predictive value\n    PPV = TP / (TP + FP)\n    # Negative predictive value\n    NPV = TN / (TN + FN)\n    # Fall out or false positive rate\n    FPR = FP / (FP + TN)\n    # False negative rate\n    FNR = FN / (TP + FN)\n    # False discovery rate\n    FDR = FP / (TP + FP)\n\n    # Overall accuracy\n    ACC = (TP + TN) / (TP + FP + FN + TN)\n    # ACC_micro = (sum(TP) + sum(TN)) / (sum(TP) + sum(FP) + sum(FN) + sum(TN))\n    
ACC_macro = np.mean(ACC) # to get a sense of effectiveness of our method on the small classes we computed this average (macro-average)\n\n    return ACC_macro, ACC, TPR, TNR, PPV\ndef batch_data(x, y, batch_size):\n    shuffle = np.random.permutation(len(x))\n    start = 0\n    #     from IPython.core.debugger import Tracer; Tracer()()\n    x = x[shuffle]\n    y = y[shuffle]\n    while start + batch_size <= len(x):\n        yield x[start:start + batch_size], y[start:start + batch_size]\n        start += batch_size\ndef build_network(inputs, dec_inputs,char2numY,n_channels=10,input_depth=280,num_units=128,max_time=10,bidirectional=False):\n    _inputs = tf.reshape(inputs, [-1, n_channels, int(input_depth / n_channels)])\n    # _inputs = tf.reshape(inputs, [-1,input_depth,n_channels])\n\n    # #(batch*max_time, 280, 1) --> (N, 280, 18)\n    conv1 = tf.compat.v1.layers.conv1d(inputs=_inputs, filters=32, kernel_size=2, strides=1,\n                             padding='same', activation=tf.nn.relu)\n    max_pool_1 = tf.compat.v1.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same')\n\n    conv2 = tf.compat.v1.layers.conv1d(inputs=max_pool_1, filters=64, kernel_size=2, strides=1,\n                             padding='same', activation=tf.nn.relu)\n    max_pool_2 = tf.compat.v1.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same')\n\n    conv3 = tf.compat.v1.layers.conv1d(inputs=max_pool_2, filters=128, kernel_size=2, strides=1,\n                             padding='same', activation=tf.nn.relu)\n\n    shape = conv3.get_shape().as_list()\n    data_input_embed = tf.reshape(conv3, (-1, max_time, shape[1] * shape[2]))\n\n    # timesteps = max_time\n    #\n    # lstm_in = tf.unstack(data_input_embed, timesteps, 1)\n    # lstm_size = 128\n    # # Get lstm cell output\n    # # Add LSTM layers\n    # lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)\n    # data_input_embed, states = tf.contrib.rnn.static_rnn(lstm_cell, lstm_in, 
dtype=tf.float32)\n    # data_input_embed = tf.stack(data_input_embed, 1)\n\n    # shape = data_input_embed.get_shape().as_list()\n\n    embed_size = 10  # 128 lstm_size # shape[1]*shape[2]\n\n    # Embedding layers\n    output_embedding = tf.Variable(tf.random.uniform((len(char2numY), embed_size), -1.0, 1.0), name='dec_embedding')\n    data_output_embed = tf.nn.embedding_lookup(params=output_embedding, ids=dec_inputs)\n\n    with tf.compat.v1.variable_scope(\"encoding\") as encoding_scope:\n        if not bidirectional:\n\n            # Regular approach with LSTM units\n            lstm_enc = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)\n            _, last_state = tf.compat.v1.nn.dynamic_rnn(lstm_enc, inputs=data_input_embed, dtype=tf.float32)\n\n        else:\n\n            # Using a bidirectional LSTM architecture instead\n            enc_fw_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)\n            enc_bw_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)\n\n            ((enc_fw_out, enc_bw_out), (enc_fw_final, enc_bw_final)) = tf.compat.v1.nn.bidirectional_dynamic_rnn(\n                cell_fw=enc_fw_cell,\n                cell_bw=enc_bw_cell,\n                inputs=data_input_embed,\n                dtype=tf.float32)\n            enc_fin_c = tf.concat((enc_fw_final.c, enc_bw_final.c), 1)\n            enc_fin_h = tf.concat((enc_fw_final.h, enc_bw_final.h), 1)\n            last_state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(c=enc_fin_c, h=enc_fin_h)\n\n    with tf.compat.v1.variable_scope(\"decoding\") as decoding_scope:\n        if not bidirectional:\n            lstm_dec = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)\n        else:\n            lstm_dec = tf.compat.v1.nn.rnn_cell.LSTMCell(2 * num_units)\n\n        dec_outputs, _ = tf.compat.v1.nn.dynamic_rnn(lstm_dec, inputs=data_output_embed, initial_state=last_state)\n\n    logits = tf.compat.v1.layers.dense(dec_outputs, units=len(char2numY), use_bias=True)\n\n    return logits\ndef str2bool(v):\n    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\ndef main():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--epochs', type=int, default=500)\n    parser.add_argument('--max_time', type=int, default=10)\n    parser.add_argument('--test_steps', type=int, default=10)\n    parser.add_argument('--batch_size', type=int, default=20)\n    parser.add_argument('--data_dir', type=str, default='data/s2s_mitbih_aami')\n    parser.add_argument('--bidirectional', type=str2bool, default=str2bool('False'))\n    # parser.add_argument('--lstm_layers', type=int, default=2)\n    parser.add_argument('--num_units', type=int, default=128)\n    parser.add_argument('--n_oversampling', type=int, default=10000)\n    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints-seq2seq')\n    parser.add_argument('--ckpt_name', type=str, default='seq2seq_mitbih.ckpt')\n    parser.add_argument('--classes', nargs='+', type=str,\n                        default=['F','N', 'S','V'])\n    args = parser.parse_args()\n    run_program(args)\ndef run_program(args):\n    print(args)\n    max_time = args.max_time # 5 3 second best 10# 40 # 100\n    epochs = args.epochs # 300\n    batch_size = args.batch_size # 10\n    num_units = args.num_units\n    bidirectional = args.bidirectional\n    # lstm_layers = args.lstm_layers\n    n_oversampling = args.n_oversampling\n    checkpoint_dir = args.checkpoint_dir\n    ckpt_name = args.ckpt_name\n    test_steps = args.test_steps\n    classes= args.classes\n    filename = args.data_dir\n\n    X, Y = read_mitbih(filename,max_time,classes=classes,max_nlabel=100000) #11000\n    print(\"# of sequences: \", len(X))\n    input_depth = X.shape[2]\n    n_channels = 10\n    classes = np.unique(Y)\n    char2numY = dict(list(zip(classes, list(range(len(classes))))))\n    n_classes = len(classes)\n    print('Classes: ', classes)\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        print(cl, len(np.where(Y.flatten()==cl)[0]))\n    # char2numX['<PAD>'] = len(char2numX)\n    # num2charX = dict(zip(char2numX.values(), char2numX.keys()))\n    # max_len = max([len(date) for date in x])\n    #\n    # x = [[char2numX['<PAD>']]*(max_len - len(date)) +[char2numX[x_] for x_ in date] for date in x]\n    # print(''.join([num2charX[x_] for x_ in x[4]]))\n    # x = np.array(x)\n\n    char2numY['<GO>'] = len(char2numY)\n    num2charY = dict(list(zip(list(char2numY.values()), list(char2numY.keys()))))\n\n    Y = [[char2numY['<GO>']] + [char2numY[y_] for y_ in date] for date in Y]\n    Y = np.array(Y)\n\n    x_seq_length = len(X[0])\n    y_seq_length = len(Y[0])- 1\n\n    # Placeholders\n    tf.compat.v1.disable_eager_execution()\n    inputs = tf.compat.v1.placeholder(tf.float32, [None, max_time, input_depth], name = 'inputs')\n    targets = tf.compat.v1.placeholder(tf.int32, (None, None), 'targets')\n    dec_inputs = tf.compat.v1.placeholder(tf.int32, (None, None), 'output')\n\n    # logits = build_network(inputs,dec_inputs=dec_inputs)\n    logits = build_network(inputs, dec_inputs, char2numY, n_channels=n_channels, input_depth=input_depth, num_units=num_units, max_time=max_time,\n                  bidirectional=bidirectional)\n    # decoder_prediction = tf.argmax(logits, 2)\n    # confusion = tf.confusion_matrix(labels=tf.argmax(targets, 1), predictions=tf.argmax(logits, 2), num_classes=len(char2numY) - 1)# it is wrong\n    # mean_accuracy,update_mean_accuracy = tf.metrics.mean_per_class_accuracy(labels=targets, predictions=decoder_prediction, num_classes=len(char2numY) - 1)\n\n    with tf.compat.v1.name_scope(\"optimization\"):\n        # Loss function\n        vars = tf.compat.v1.trainable_variables()\n        beta = 0.001\n        lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars\n                            if 'bias' not in v.name]) * beta\n        loss = tfa.seq2seq.sequence_loss(logits, targets, tf.ones([batch_size, y_seq_length]))\n        # Optimizer\n        loss = tf.reduce_mean(input_tensor=loss + lossL2)\n        optimizer = tf.compat.v1.train.RMSPropOptimizer(1e-3).minimize(loss)\n\n\n    # split the dataset into the training and test sets\n    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)\n\n    # over-sampling: SMOTE\n    X_train = np.reshape(X_train,[X_train.shape[0]*X_train.shape[1],-1])\n    y_train= y_train[:,1:].flatten()\n\n    nums = []\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        nums.append(len(np.where(y_train.flatten()==ind)[0]))\n    # ratio={0:nums[3],1:nums[1],2:nums[3],3:nums[3]} # the best with 11000 for N\n    # classes 0, 2 and 3 are over-sampled to n_oversampling beats; class 1 (N with the default classes) keeps its count\n    ratio={0:n_oversampling,1:nums[1],2:n_oversampling,3:n_oversampling}\n    sm = SMOTE(random_state=12, sampling_strategy=ratio)\n    X_train, y_train = sm.fit_resample(X_train, y_train)\n\n    X_train = X_train[:int(X_train.shape[0]/max_time)*max_time,:]\n    y_train = y_train[:int(X_train.shape[0]/max_time)*max_time]\n\n    X_train = np.reshape(X_train,[-1,X_test.shape[1],X_test.shape[2]])\n    y_train = np.reshape(y_train,[-1,y_test.shape[1]-1,])\n    y_train= [[char2numY['<GO>']] + [y_ for y_ in date] for date in y_train]\n    y_train = np.array(y_train)\n\n    print('Classes in the training set: ', classes)\n    for cl in classes:\n        ind = np.where(classes == cl)[0][0]\n        print(cl, len(np.where(y_train.flatten()==ind)[0]))\n    print(\"------------------y_train samples--------------------\")\n    for ii in range(2):\n      print(''.join([num2charY[y_] for y_ in list(y_train[ii+5])]))\n    print(\"------------------y_test samples--------------------\")\n    for ii in range(2):\n      print(''.join([num2charY[y_] for y_ in list(y_test[ii+5])]))\n\n    def test_model():\n        # source_batch, target_batch = next(batch_data(X_test, y_test, batch_size))\n        acc_track = []\n        sum_test_conf = []\n        for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_test, y_test, batch_size)):\n\n            dec_input = np.zeros((len(source_batch), 1)) + char2numY['<GO>']\n            for i in range(y_seq_length):\n                batch_logits = sess.run(logits,\n                                        feed_dict={inputs: source_batch, dec_inputs: dec_input})\n                prediction = batch_logits[:, -1].argmax(axis=-1)\n                dec_input = np.hstack([dec_input, prediction[:, None]])\n            # acc_track.append(np.mean(dec_input == target_batch))\n            acc_track.append(dec_input[:, 1:] == target_batch[:, 1:])\n            y_true= target_batch[:, 1:].flatten()\n            y_pred = dec_input[:, 1:].flatten()\n            sum_test_conf.append(confusion_matrix(y_true, y_pred,labels=list(range(len(char2numY)-1))))\n\n        sum_test_conf= np.mean(np.array(sum_test_conf, dtype=np.float32), axis=0)\n\n        # print('Accuracy on test set is: {:>6.4f}'.format(np.mean(acc_track)))\n\n        # mean_p_class, accuracy_classes = sess.run([mean_accuracy, update_mean_accuracy],\n        #                                           feed_dict={inputs: source_batch,\n        #                                                      dec_inputs: dec_input[:, :-1],\n        #                                                      targets: target_batch[:, 1:]})\n        # print (mean_p_class)\n        # print (accuracy_classes)\n        acc_avg, acc, sensitivity, specificity, PPV = evaluate_metrics(sum_test_conf)\n        print('Average Accuracy is: {:>6.4f} on test set'.format(acc_avg))\n        for index_ in range(n_classes):\n            print(\"\\t{} rhythm -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}\".format(\n                classes[index_], sensitivity[index_], specificity[index_], PPV[index_], acc[index_]))\n        print(\"\\t Average -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}\".format(np.mean(sensitivity),np.mean(specificity),np.mean(PPV),np.mean(acc)))\n        return acc_avg, acc, sensitivity, specificity, PPV\n    loss_track = []\n    def count_parameters():\n        print('# of Params: ', np.sum([np.prod(v.get_shape().as_list()) for v in tf.compat.v1.trainable_variables()]))\n\n    count_parameters()\n\n    if not os.path.exists(checkpoint_dir):\n        os.mkdir(checkpoint_dir)\n    # train the graph\n    with tf.compat.v1.Session() as sess:\n        sess.run(tf.compat.v1.global_variables_initializer())\n        sess.run(tf.compat.v1.local_variables_initializer())\n        saver = tf.compat.v1.train.Saver()\n        print(str(datetime.now()))\n        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)\n        pre_acc_avg = 0.0\n        if ckpt and ckpt.model_checkpoint_path:\n            # # Restore\n            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)\n            # saver.restore(session, os.path.join(checkpoint_dir, ckpt_name))\n            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))\n            # or 'load meta graph' and restore weights\n            # saver = tf.train.import_meta_graph(ckpt_name+\".meta\")\n            # saver.restore(session,tf.train.latest_checkpoint(checkpoint_dir))\n            test_model()\n        else:\n\n            for epoch_i in range(epochs):\n                start_time = time.time()\n                train_acc = []\n                for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, y_train, batch_size)):\n                    _, batch_loss, batch_logits = sess.run([optimizer, loss, logits],\n                        feed_dict = {inputs: source_batch,\n                                     dec_inputs: target_batch[:, :-1],\n                                     targets: target_batch[:, 1:]})\n                    loss_track.append(batch_loss)\n                    train_acc.append(batch_logits.argmax(axis=-1) == target_batch[:,1:])\n                # mean_p_class,accuracy_classes = sess.run([mean_accuracy,update_mean_accuracy],\n                #                         feed_dict={inputs: source_batch,\n                #                                               dec_inputs: target_batch[:, :-1],\n                #                                               targets: target_batch[:, 1:]})\n\n                # accuracy = np.mean(batch_logits.argmax(axis=-1) == target_batch[:,1:])\n                accuracy = np.mean(train_acc)\n                print('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f} Epoch duration: {:>6.3f}s'.format(\n                    epoch_i, batch_loss, accuracy, time.time() - start_time))\n\n                if epoch_i%test_steps==0:\n                    acc_avg, acc, sensitivity, specificity, PPV= test_model()\n\n                    print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size))\n                    save_path = os.path.join(checkpoint_dir, ckpt_name)\n                    saver.save(sess, save_path)\n                    print(\"Model saved in path: %s\" % save_path)\n\n                    # if np.nan_to_num(acc_avg) > pre_acc_avg:  # save the better model based on the average accuracy\n                    #     print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size))\n                    #     pre_acc_avg = acc_avg\n                    #     save_path =os.path.join(checkpoint_dir, ckpt_name)\n                    #     saver.save(sess, save_path)\n                    #     print(\"The best model (till now) saved in path: %s\" % save_path)\n\n            plt.plot(loss_track)\n            plt.show()\n        print(str(datetime.now()))\n        # test_model()\nif __name__ == '__main__':\n    main()\n"
  }
]