Repository: MousaviSajad/ECG-Heartbeat-Classification-seq2seq-model Branch: master Commit: 52ed4efd5871 Files: 14 Total size: 86.4 KB Directory structure: gitextract_29kwkvv7/ ├── .gitignore ├── LICENSE ├── README.md ├── data/ │ └── readme ├── data preprocessing_Matlab/ │ ├── download_MITBIHDB.m │ ├── loadEcgSig.m │ ├── onoffset.m │ ├── qsPeaks.m │ ├── readme │ ├── seq2seq_mitbih_AAMI.m │ └── seq2seq_mitbih_AAMI_DS1DS2.m ├── seq_seq_annot_DS1DS2.py ├── seq_seq_annot_aami.py └── tf2/ └── seq_seq_annot_aami.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ ================================================ FILE: LICENSE ================================================ ================================================ FILE: README.md ================================================ # Inter- and intra- patient ECG heartbeat classification for arrhythmia detection: a sequence to sequence deep learning approach # Paper Our paper can be downloaded from the [arxiv website](https://arxiv.org/pdf/1812.07421v2) * The Network architecture ![Alt text](/images/seq2seq_b.jpg) ## Requirements * Python 2.7 * tensorflow/tensorflow-gpu * numpy * scipy * scikit-learn * matplotlib * imbalanced-learn (0.4.3) ## Dataset We evaluated our model using [the PhysioNet MIT-BIH Arrhythmia database](https://www.physionet.org/physiobank/database/mitdb/) * To download our pre-processed datasets use [this link](https://drive.google.com/drive/folders/19bDrAqlSGQuNLRmA-7pQRU9R81gSuY70?usp=sharing), then put them into the "data" folder. * Or you can follow the instructions of the readme file in the "data preprocessing_Matlab" folder to download the MIT-BIH database and perform data pre-processing. Then, put the pre-processed datasets into the "data" folder. 
## Train * Modify args settings in seq_seq_annot_aami.py for the intra-patient ECG heartbeat classification * Modify args settings in seq_seq_annot_DS1DS2.py for the inter-patient ECG heartbeat classification * Run each file to reproduce the model described in the paper, use: ``` python seq_seq_annot_aami.py --data_dir data/s2s_mitbih_aami --epochs 500 ``` ``` python seq_seq_annot_DS1DS2.py --data_dir data/s2s_mitbih_aami_DS1DS2 --epochs 500 ``` ## Results ![Alt text](/images/results.jpg) ## Citation If you find it useful, please cite our paper as follows: ``` @article{mousavi2018inter, title={Inter-and intra-patient ECG heartbeat classification for arrhythmia detection: a sequence to sequence deep learning approach}, author={Mousavi, Sajad and Afghah, Fatemeh}, journal={arXiv preprint arXiv:1812.07421}, year={2018} } ``` ## References [deepschool.io](https://github.com/sachinruk/deepschool.io/blob/master/DL-Keras_Tensorflow) ## Licence For academic and non-commercial usage ================================================ FILE: data/readme ================================================ Download the datasets using the link below and put them into the "data" folder: [this link](https://drive.google.com/drive/folders/19bDrAqlSGQuNLRmA-7pQRU9R81gSuY70?usp=sharing) ================================================ FILE: data preprocessing_Matlab/download_MITBIHDB.m ================================================
% DOWNLOAD_MITBIHDB Create the MIT-BIH dataset used by the preprocessing scripts.
% For every record of PhysioNet's 'mitdb' this runs the WFDB command-line tools
% and writes, into path_to_save_records:
%   RECORDm.txt  - annotation listing printed by 'rdann -a atr -v'
%   RECORDm.info - header text printed by 'wfdb2mat' (which also writes RECORDm.mat)
% Requires the open-source WFDB Software Package / WFDB Toolbox:
%   http://physionet.org/physiotools/wfdb.shtml
% Set the two paths below before running. Because it calls the *.exe tools from
% 'windows\bin', the script is Windows-only as written.
% (rc = physionetdb('mitdb',1) would instead download the raw .dat/.atr files.)
record_list = physionetdb('mitdb');
path_to_exes = 'path_to_WFDB_Toolbox\mcode\nativelibs\windows\bin';
path_to_save_records = 'path_to_downloaded_database';
% e.g. path_to_exes = 'C:\my_files\ECG_research\mcode\nativelibs\windows\bin';
%      path_to_save_records = 'C:\my_files\ECG_dataset\MIT-BIH\mitbihdb';
mkdir(path_to_save_records);
cd(path_to_save_records);
tic
for i = 1:length(record_list)
    command_annot = char(strcat(path_to_exes, filesep, 'rdann.exe -r mitdb/', record_list(i), ' -a atr -v >', record_list(i), 'm.txt'));
    status = system(command_annot);
    if status ~= 0
        % report instead of silently producing a missing/empty annotation file
        warning('download_MITBIHDB:rdann', 'rdann failed (status %d): %s', status, command_annot);
    end
    command_mat_info_ = char(strcat(path_to_exes, filesep, 'wfdb2mat.exe -r mitdb/', record_list(i), ' >', record_list(i), 'm.info'));
    status = system(command_mat_info_);
    if status ~= 0
        warning('download_MITBIHDB:wfdb2mat', 'wfdb2mat failed (status %d): %s', status, command_mat_info_);
    end
end
toc
disp('Successfully generated :)')
================================================ FILE: data preprocessing_Matlab/loadEcgSig.m ================================================
function [tm,ecgsig,ann,Fs,sizeEcgSig,timeEcgSig] = loadEcgSig(Name)
% LOADECGSIG Load a wfdb2mat-converted PhysioBank record and its annotations.
% USAGE: [tm,ecgsig,ann,Fs,sizeEcgSig,timeEcgSig] = loadEcgSig('../data/200m')
%
% Reads Name.mat (matrix 'val', one row per signal), Name.info (written by
% 'wfdb2mat') and Name.txt (annotation listing written by 'rdann').
%   tm         - sample times in seconds: (1:sizeEcgSig) * sampling interval
%   ecgsig     - rows of 'val' with the baseline subtracted and divided by the
%                gain; raw samples equal to -32768 become NaN
%   ann        - textscan (CollectOutput) cells of the annotation listing;
%                ann{3} holds the sample numbers and ann{4} the beat codes
%   Fs         - sampling frequency in Hz
%   sizeEcgSig - number of samples per signal
%   timeEcgSig - record duration in seconds
% Errors if any of the text files cannot be opened.
%
% Adapted from loadEcgSignal.m (O. Abdala 16 March 2009; James Hislop
% 27 January 2014, v1.1; Davi Kawasaki 15 June 2017, v1.0).
infoName = strcat(Name, '.info');
matName = strcat(Name, '.mat');
annotationName = strcat(Name, '.txt');

% load into a struct rather than creating 'val' implicitly in this workspace
matData = load(matName);
ecgsig = matData.val;

fid = fopen(infoName, 'rt');
if fid < 0
    error('loadEcgSig:fileNotFound', 'Cannot open %s', infoName);
end
% skip the first three lines; the fourth holds the sampling frequency/interval
fgetl(fid);
fgetl(fid);
fgetl(fid);
[freqint] = sscanf(fgetl(fid), 'Sampling frequency: %f Hz Sampling interval: %f sec');
Fs = freqint(1);
interval = freqint(2);
fgetl(fid); % skip one more line (presumably the header of the signal table)
for i = 1:size(ecgsig, 1)
    % one tab-separated row per signal: row, signal name, gain, base, units
    [row(i), signal(i), gain(i), base(i), units(i)] = strread(fgetl(fid), '%d%s%f%f%s', 'delimiter', '\t');
end
fclose(fid);

ecgsig(ecgsig == -32768) = NaN;
for i = 1:size(ecgsig, 1)
    ecgsig(i, :) = (ecgsig(i, :) - base(i)) / gain(i);
end
sizeEcgSig = size(ecgsig, 2);
tm = (1:sizeEcgSig) * interval;
timeEcgSig = sizeEcgSig * interval;

% load annotations (first line of the listing is a header)
fid = fopen(annotationName, 'rt');
if fid < 0
    error('loadEcgSig:fileNotFound', 'Cannot open %s', annotationName);
end
ann = textscan(fid, '%d:%f %d %c %d %d %d %s', 'HeaderLines', 1, 'CollectOutput', 1);
fclose(fid);
end
================================================ FILE: data preprocessing_Matlab/onoffset.m ================================================
function [ ind ] = onoffset( interval,mode )
%ONOFFSET Locate the QRS onset/offset inside a short ECG window.
%   The slope around sample k+1 is the central difference
%   interval(k+2) - interval(k); IND indexes that slope vector, which is two
%   samples shorter than INTERVAL.
%   mode 'on'  : IND of the slope with the smallest magnitude.
%   mode 'off' : IND of the first slope whose magnitude is at least 20% of the
%                largest slope magnitude.
%   Any other mode, or an INTERVAL with fewer than 3 samples, raises an error
%   (the output would otherwise be left unassigned).
if ~any(strcmp(mode, {'on', 'off'}))
    error('onoffset:badMode', 'wrong input, please select on/off set');
end
if numel(interval) < 3
    error('onoffset:shortInterval', 'interval must contain at least 3 samples');
end
slope = interval(3:end) - interval(1:end-2);
if strcmp(mode, 'on')
    [~, ind] = min(abs(slope));
else
    slope_th = 0.2*max(abs(slope));
    ind = find(abs(slope) >= slope_th, 1);
end
end
================================================ FILE: data
preprocessing_Matlab/qsPeaks.m ================================================ function [ECGpeaks] = QSpeaks( ECG,Rposition,fs ) %Q,S peaks detection % point to point time duration is determined by sampling frequency: 1/fs % the duration of QRS complex varies from 0.1s to 0.2s % complex--fs*0.2 aveHB = length(ECG)/length(Rposition); fid_pks = zeros(length(Rposition),7); % P QRSon Q R S QRSoff %% set up searching windows windowS = round(fs*0.1); windowQ = round(fs*0.05); windowP = round(aveHB/3); windowT = round(aveHB*2/3); windowOF = round(fs*0.04); % initialization for i = 1:length(Rposition) thisR = Rposition(i); if i==1 fid_pks(i,4) = thisR; fid_pks(i,6) = thisR+windowS; elseif i==length(Rposition) fid_pks(i,4) = thisR; %(thisR+windowT) < length(ECG) && (thisR - windowP) >=1 fid_pks(i,2) = thisR-windowQ; else if (thisR+windowT) < length(ECG) && (thisR - windowP) >=1 % Q S peaks fid_pks(i,4) = thisR; [Sv,Sp] = min(ECG(thisR:thisR+windowS)); thisS = Sp + thisR-1; fid_pks(i,5) = thisS; [Qv,Qp] = min(ECG(thisR-windowQ:thisR)); thisQ = thisR-(windowQ+1) + Qp; fid_pks(i,3)=thisQ; % onset and offset detection interval_q = ECG(thisQ-windowOF:thisQ); [ ind ] = onoffset(interval_q,'on' ); thisON = thisQ - (windowOF+1) + ind; interval_s = ECG(thisS:thisS+windowOF); [ ind ] = onoffset( interval_s,'off' ); thisOFF = thisS + ind-1; fid_pks(i,2) = thisON; fid_pks(i,6) = thisOFF; % % T and P waves detection % lastOFF = fid_pks(i-1,6); % nextON = end end end %% % P,T detection % Detection T waves first and distinguish the type of T waves for i = 2:length(Rposition)-1 lastOFF = fid_pks(i-1,6); thisON = fid_pks(i,2); thisOFF = fid_pks(i,6); nextON = fid_pks(i+1,2); if thisON>lastOFF && thisOFF length(ECG) % Tzone = ECG(QRS_OFF(i):length(ECG)); % else % Tzone = ECG(QRS_OFF(i):QRS_OFF(i)+lengthTzone-1); % end % % if length(max(Tzone))~=1 % % deb = 1; % % end % [Tpks(end+1),Tind] = max(Tzone); % Tposition(end+1) = QRS_OFF(i)+Tind-1; % lengthPzone = 
int64((-QRS_OFF(i)+QRS_ON(i+1))/3); % if lengthPzone <= 0 % lengthPzone = abs(lengthPzone); % disp('abnormal heart beat'); % end % Pzone = ECG(QRS_ON(i+1)-lengthPzone-1:QRS_ON(i+1)); % [Ppks(end+1),Pind] = max(Pzone); % Pposition(end+1) = QRS_ON(i+1)+Pind-lengthPzone; % end % else % for i = 1:length(QRS_ON)-1 % lengthTzone = int64((QRS_ON(i+1)-QRS_OFF(i))*2/3); % if QRS_OFF(i)+lengthTzone-1 > length(ECG) % Tzone = ECG(QRS_OFF(i):length(ECG)); % else % Tzone = ECG(QRS_OFF(i):QRS_OFF(i)+lengthTzone-1); % end % [Tpks(end+1),Tind] = max(abs(Tzone)); % Tposition(end+1) = QRS_OFF(i)+Tind-1; % lengthPzone = int64((-QRS_OFF(i)+QRS_ON(i+1))/3); % Pzone = ECG(QRS_ON(i+1)-lengthPzone-1:QRS_ON(i+1)); % [Ppks(end+1),Pind] = max(Pzone); % Pposition(end+1) = QRS_ON(i+1)+Pind-lengthPzone; % end % end % for i = 1:length(Tpks) % if Tpks(i)>0 % display('T Upright'); % else % display('T Inverted'); % end % end % hold on;plot(Tposition,Tpks,'o'); % hold on;plot(Pposition,Ppks,'*'); %% % [pks,locs] = findpeaks(A5);hold on; plot(locs,pks,'*'); % % for i = 1:length(S) % SearchArea = []; % for j = 1:length(locs) % if locs(j) > S(i) % SearchArea(end+1) = locs(j); % end % end % su = []; % for k = 1:length(SearchArea) % su(end+1) = SearchArea(k)-S(i); % end % if ~isempty(su) % Tpks(end+1) = S(i) + min(su); % end % end % % % P detection % % % for i = 1:length(Q) % SearchArea = []; % for j = 1:length(locs) % if locs(j) < Q(i) % SearchArea(end+1) = locs(j); % end % end % su = []; % for k = 1:length(SearchArea) % su(end+1) = Q(i)-SearchArea(k); % end % if ~isempty(su) % Ppks(end+1) = Q(i) - min(su); % end % end % % check the result % for i = 1:length(Tpks) % T(i) = A5(Tpks(i)); % end % for i = 1:length(Ppks) % P(i) = A5(Ppks(i)); % end % figure(5);plot(A5); % hold on;plot(Tpks,T,'*'); % hold on;plot(Ppks,P,'+'); %% %Oops after interpolation, we slightly deviate from the true value %Thus we need to search for the local minimum in the whole signal again % Qposition = int64(Q.*Scale); % Sposition = 
int64(S.*Scale); % % for i = 1:length(R) % [Sv,Sp] = min(ECG(Sposition(i):Sposition(i)+20)); % Sposition(i) = Sp + Sposition(i)-1; % [Qv,Qp] = min(ECG(Qposition(i)-20:Qposition(i))); % Qposition(i) = Qposition(i)-21 + Qp; % end % % Pposition = int64(Ppks.*Scale); % Tposition = int64(Tpks.*Scale); % % for i = length(Pposition) % [Pv,Pp] = max(ECG(Pposition(i)-20:Pposition(i))); % Pposition(i) = Pposition(i)-21 + Pp; % end % % for i = length(Tposition) % [Tv,Tp] = max(ECG(Tposition(i):Tposition(i)+20)); % Tposition(i) = Tp + Tposition(i)-1; % end % ecgS = []; % for i = 1:2271 % ecgS(end+1) = issorted(ECGpeaks(1,:)); % end % prod(ecgS) ================================================ FILE: data preprocessing_Matlab/readme ================================================ 1. Install WFDB Toolbox for MATLAB from https://physionet.org/physiotools/matlab/wfdb-app-matlab/ 2. Run download_MITBIHDB.m to download the MIT-BIH database -- In the code, you need to determine the local path to save the database and the path for the installed WFDB Toolbox For example: -- path_to_exes = 'path_to_WFDB_Toolbox\mcode\nativelibs\windows\bin'; -- path_to_save_records = 'path_to_save_database'; 3. Run seq2seq_mitbih_AAMI.m to prepare data for the intra-patient paradigm -- Modify the below line of the code based on your local address to the mitbih database --addr = '.\mitbihdb'; 4. 
Run seq2seq_mitbih_AAMI_DS1DS2.m to prepare data for the inter-patient paradigm -- Modify the below line of the code based on your local address to the mitbih database -- addr = '.\mitbihdb'; ================================================ FILE: data preprocessing_Matlab/seq2seq_mitbih_AAMI.m ================================================
% SEQ2SEQ_MITBIH_AAMI Build s2s_mitbih_aami.mat for the intra-patient paradigm.
% For every RECORDm.mat in addr:
%   - lead 1 is loaded with loadEcgSig and normalized with normalize();
%   - the signal is cut into beats between consecutive T peaks (column 7 of
%     qsPeaks);
%   - each beat is resized to beat_len samples with imresize and labelled with
%     its AAMI class.
% Result: struct array s2s_mitbih with fields seg_values (cell column of beats)
% and seg_labels (char array of labels).
%
% PhysioNet beat code -> AAMI class
%   N = N, L, R, e, j
%   S = A, a, J, S
%   V = V, E
%   F = F
%   Q = /, f, Q
% (The AAMI2 scheme merges F into V: V_hat = V, E, F.)
% https://github.com/ehendryx/deim-cur-ecg/blob/master/DS1_MIT_CUR_beat_classification.m
clear all
clc
tic
addr = '.\mitbihdb';  % local path to the MIT-BIH records (see download_MITBIHDB.m)
Files = dir(fullfile(addr, '*.mat'));

% beat codes of each AAMI class (loop-invariant, so defined once)
N_g = 'NLRej';
S_g = 'AaJS';
V_g = 'VE';
F_g = 'F';
Q_g = '/fQ';
annots_list = [N_g S_g V_g F_g Q_g];  % every beat code that is kept

% mean beat length of the dataset (computed once as floor(mean(...)) of the segment lengths)
beat_len = 280;
n_cycles = 0;
N_class = 0; V_class = 0; F_class = 0; Q_class = 0; S_class = 0;

for i = 1:length(Files)
    %% load lead 1 of the record
    [~, name] = fileparts(Files(i).name);
    nsig = 1;
    [~, ecgsig, ann, Fs] = loadEcgSig([addr filesep name]);
    signal = ecgsig(nsig, :);

    %% reference R peaks (+1 presumably maps WFDB 0-based sample numbers to MATLAB indices)
    rPeaks = cell2mat(ann(3)) + 1;
    n_cycles = n_cycles + length(rPeaks);
    rPeaks = double(rPeaks);
    peaks = qsPeaks(signal, rPeaks, Fs);
    tpeaks = peaks(:,7);

    %% beat codes of the R peaks that appear in column 4 of qsPeaks
    annot = cell2mat(ann(4));
    annot = annot(ismember(rPeaks, peaks(:,4)));

    %% segment and label the beats
    seg_values = {};
    seg_labels = [];
    ind_seg = 1;
    signal = normalize(signal);
    for ind = 1:length(annot)
        beat_code = annot(ind);
        if ~ismember(beat_code, annots_list)
            continue;
        end
        if ismember(beat_code, N_g)
            label = 'N';
        elseif ismember(beat_code, S_g)
            label = 'S';
        elseif ismember(beat_code, V_g)
            label = 'V';
        elseif ismember(beat_code, F_g)
            label = 'F';
        else
            label = 'Q';  % annots_list leaves only Q_g codes here
        end

        if ind == 1
            % first beat: from the record start, at most Fs samples (1 s)
            first_seg = signal(1:tpeaks(ind)-1)';
            t_sig = imresize(first_seg(1:min(Fs, length(first_seg))), [beat_len 1]);
        else
            t_sig = imresize(signal(tpeaks(ind-1):tpeaks(ind)-1)', [beat_len 1]);
        end
        seg_values{ind_seg} = t_sig;
        seg_labels(ind_seg) = label;
        ind_seg = ind_seg + 1;
    end
    s2s_mitbih(i).seg_values = seg_values';
    s2s_mitbih(i).seg_labels = char(seg_labels);

    %% per-class beat counts (the code sets are disjoint)
    N_class = N_class + sum(ismember(annot, N_g));
    S_class = S_class + sum(ismember(annot, S_g));
    V_class = V_class + sum(ismember(annot, V_g));
    F_class = F_class + sum(ismember(annot, F_g));
    Q_class = Q_class + sum(ismember(annot, Q_g));
end

save s2s_mitbih_aami.mat s2s_mitbih
toc
F_class
N_class
Q_class
S_class
V_class
F_class+N_class+Q_class+S_class+V_class
disp('Successfully generated :)')
================================================ FILE: data
preprocessing_Matlab/seq2seq_mitbih_AAMI_DS1DS2.m ================================================
% SEQ2SEQ_MITBIH_AAMI_DS1DS2 Build s2s_mitbih_aami_DS1DS2.mat for the
% inter-patient paradigm. The records listed in DS1 and DS2 are loaded (lead 1),
% normalized with normalize(), cut into beats between consecutive T peaks
% (column 7 of qsPeaks), resized to beat_len samples and AAMI-labelled. They are
% stored in the struct arrays s2s_mitbih_DS1 and s2s_mitbih_DS2 (fields
% seg_values, seg_labels). Per-class beat counts are printed for each split.
%
% PhysioNet beat code -> AAMI class
%   N = N, L, R, e, j
%   S = A, a, J, S
%   V = V, E
%   F = F
%   Q = /, f, Q
% (The AAMI2 scheme merges F into V: V_hat = V, E, F.)
% https://github.com/ehendryx/deim-cur-ecg/blob/master/DS1_MIT_CUR_beat_classification.m
clear all
clc
tic
addr = '.\mitbihdb';  % local path to the MIT-BIH records (see download_MITBIHDB.m)

DS1 = [101, 106, 108, 109, 112, 114, 115, 116, 118, 119, 122, 124, 201, 203, 205, 207, 208, 209, 215, 220, 223, 230];
DS2 = [100, 103, 105, 111, 113, 117, 121, 123, 200, 202, 210, 212, 213, 214, 219, 221, 222, 228, 231, 232, 233, 234];

% beat codes of each AAMI class (loop-invariant, so defined once)
N_g = 'NLRej';
S_g = 'AaJS';
V_g = 'VE';
F_g = 'F';
Q_g = '/fQ';
annots_list = [N_g S_g V_g F_g Q_g];  % every beat code that is kept

% mean beat length of the dataset (computed once as floor(mean(...)) of the segment lengths)
beat_len = 280;
n_cycles = 0;

for j = 1:2
    N_class = 0; V_class = 0; F_class = 0; Q_class = 0; S_class = 0;
    if j == 1
        DS = DS1;
    else
        DS = DS2;
    end
    for i = 1:length(DS)
        %% load lead 1 of the record
        name = [num2str(DS(i)) 'm'];
        nsig = 1;
        [~, ecgsig, ann, Fs] = loadEcgSig([addr filesep name]);
        signal = ecgsig(nsig, :);

        %% reference R peaks (+1 presumably maps WFDB 0-based sample numbers to MATLAB indices)
        rPeaks = cell2mat(ann(3)) + 1;
        n_cycles = n_cycles + length(rPeaks);
        rPeaks = double(rPeaks);
        peaks = qsPeaks(signal, rPeaks, Fs);
        tpeaks = peaks(:,7);

        %% beat codes of the R peaks that appear in column 4 of qsPeaks
        annot = cell2mat(ann(4));
        annot = annot(ismember(rPeaks, peaks(:,4)));

        %% segment and label the beats
        seg_values = {};
        seg_labels = [];
        ind_seg = 1;
        signal = normalize(signal);
        for ind = 1:length(annot)
            beat_code = annot(ind);
            if ~ismember(beat_code, annots_list)
                continue;
            end
            if ismember(beat_code, N_g)
                label = 'N';
            elseif ismember(beat_code, S_g)
                label = 'S';
            elseif ismember(beat_code, V_g)
                label = 'V';
            elseif ismember(beat_code, F_g)
                label = 'F';
            else
                label = 'Q';  % annots_list leaves only Q_g codes here
            end

            if ind == 1
                % first beat: from the record start, at most Fs samples (1 s)
                first_seg = signal(1:tpeaks(ind)-1)';
                t_sig = imresize(first_seg(1:min(Fs, length(first_seg))), [beat_len 1]);
            else
                t_sig = imresize(signal(tpeaks(ind-1):tpeaks(ind)-1)', [beat_len 1]);
            end
            seg_values{ind_seg} = t_sig;
            seg_labels(ind_seg) = label;
            ind_seg = ind_seg + 1;
        end
        if j == 1
            s2s_mitbih_DS1(i).seg_values = seg_values';
            s2s_mitbih_DS1(i).seg_labels = char(seg_labels);
        else
            s2s_mitbih_DS2(i).seg_values = seg_values';
            s2s_mitbih_DS2(i).seg_labels = char(seg_labels);
        end

        %% per-class beat counts (the code sets are disjoint)
        N_class = N_class + sum(ismember(annot, N_g));
        S_class = S_class + sum(ismember(annot, S_g));
        V_class = V_class + sum(ismember(annot, V_g));
        F_class = F_class + sum(ismember(annot, F_g));
        Q_class = Q_class + sum(ismember(annot, Q_g));
    end
    % class counts of the current split
    F_class
    N_class
    Q_class
    S_class
    V_class
    F_class+N_class+Q_class+S_class+V_class
end

% writes s2s_mitbih_DS1 and s2s_mitbih_DS2 to s2s_mitbih_aami_DS1DS2.mat
save
s2s_mitbih_aami_DS1DS2.mat s2s_mitbih_DS1 s2s_mitbih_DS2 toc disp('Successfully generated :)') ================================================ FILE: seq_seq_annot_DS1DS2.py ================================================ import numpy as np import matplotlib.pyplot as plt import scipy.io as spio from sklearn.preprocessing import MinMaxScaler import random import time import os from datetime import datetime from sklearn.metrics import confusion_matrix import tensorflow as tf from imblearn.over_sampling import SMOTE from sklearn.model_selection import train_test_split import argparse random.seed(654) def read_mitbih(filename, max_time=100, classes= ['F', 'N', 'S', 'V', 'Q'], max_nlabel=100, trainset=1): def normalize(data): data = np.nan_to_num(data) # removing NaNs and Infs data = data - np.mean(data) data = data / np.std(data) return data # read data data = [] samples = spio.loadmat(filename + ".mat") if trainset == 1: #DS1 samples = samples['s2s_mitbih_DS1'] else: # DS2 samples = samples['s2s_mitbih_DS2'] values = samples[0]['seg_values'] labels = samples[0]['seg_labels'] items_len = len(labels) num_annots = sum([item.shape[0] for item in values]) n_seqs = num_annots / max_time # add all segments(beats) together l_data = 0 for i, item in enumerate(values): l = item.shape[0] for itm in item: if l_data == n_seqs * max_time: break data.append(itm[0]) l_data = l_data + 1 # add all labels together l_lables = 0 t_lables = [] for i, item in enumerate(labels): if len(t_lables)==n_seqs*max_time: break item= item[0] for lebel in item: if l_lables == n_seqs * max_time: break t_lables.append(str(lebel)) l_lables = l_lables + 1 del values data = np.asarray(data) shape_v = data.shape data = np.reshape(data, [shape_v[0], -1]) t_lables = np.array(t_lables) _data = np.asarray([],dtype=np.float64).reshape(0,shape_v[1]) _labels = np.asarray([],dtype=np.dtype('|S1')).reshape(0,) for cl in classes: _label = np.where(t_lables == cl) permute = np.random.permutation(len(_label[0])) 
_label = _label[0][permute[:max_nlabel]] # _label = _label[0][:max_nlabel] # permute = np.random.permutation(len(_label)) # _label = _label[permute] _data = np.concatenate((_data, data[_label])) _labels = np.concatenate((_labels, t_lables[_label])) data = _data[:(len(_data)/ max_time) * max_time, :] _labels = _labels[:(len(_data) / max_time) * max_time] # data = _data # split data into sublist of 100=se_len values data = [data[i:i + max_time] for i in range(0, len(data), max_time)] labels = [_labels[i:i + max_time] for i in range(0, len(_labels), max_time)] # shuffle permute = np.random.permutation(len(labels)) data = np.asarray(data) labels = np.asarray(labels) data= data[permute] labels = labels[permute] print('Records processed!') return data, labels def evaluate_metrics(confusion_matrix): # https://stackoverflow.com/questions/31324218/scikit-learn-how-to-obtain-true-positive-true-negative-false-positive-and-fal # Sensitivity, hit rate, recall, or true positive rate FP = confusion_matrix.sum(axis=0) - np.diag(confusion_matrix) FN = confusion_matrix.sum(axis=1) - np.diag(confusion_matrix) TP = np.diag(confusion_matrix) TN = confusion_matrix.sum() - (FP + FN + TP) TPR = TP / (TP + FN) # Specificity or true negative rate TNR = TN / (TN + FP) # Precision or positive predictive value PPV = TP / (TP + FP) # Negative predictive value NPV = TN / (TN + FN) # Fall out or false positive rate FPR = FP / (FP + TN) # False negative rate FNR = FN / (TP + FN) # False discovery rate FDR = FP / (TP + FP) # Overall accuracy ACC = (TP + TN) / (TP + FP + FN + TN) # ACC_micro = (sum(TP) + sum(TN)) / (sum(TP) + sum(FP) + sum(FN) + sum(TN)) ACC_macro = np.mean( ACC) # to get a sense of effectiveness of our method on the small classes we computed this average (macro-average) return ACC_macro, ACC, TPR, TNR, PPV def batch_data(x, y, batch_size): shuffle = np.random.permutation(len(x)) start = 0 # from IPython.core.debugger import Tracer; Tracer()() x = x[shuffle] y = y[shuffle] while 
start + batch_size <= len(x): yield x[start:start+batch_size], y[start:start+batch_size] start += batch_size def build_network(inputs, dec_inputs,char2numY,n_channels=10,input_depth=280,num_units=128,max_time=10,bidirectional=False): _inputs = tf.reshape(inputs, [-1, n_channels, input_depth / n_channels]) # _inputs = tf.reshape(inputs, [-1,input_depth,n_channels]) # #(batch*max_time, 280, 1) --> (N, 280, 18) conv1 = tf.layers.conv1d(inputs=_inputs, filters=32, kernel_size=2, strides=1, padding='same', activation=tf.nn.relu) max_pool_1 = tf.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same') conv2 = tf.layers.conv1d(inputs=max_pool_1, filters=64, kernel_size=2, strides=1, padding='same', activation=tf.nn.relu) max_pool_2 = tf.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same') conv3 = tf.layers.conv1d(inputs=max_pool_2, filters=128, kernel_size=2, strides=1, padding='same', activation=tf.nn.relu) shape = conv3.get_shape().as_list() data_input_embed = tf.reshape(conv3, (-1, max_time,shape[1]*shape[2])) # timesteps = max_time # # lstm_in = tf.unstack(data_input_embed, timesteps, 1) # lstm_size = 128 # # Get lstm cell output # # Add LSTM layers # lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size) # data_input_embed, states = tf.contrib.rnn.static_rnn(lstm_cell, lstm_in, dtype=tf.float32) # data_input_embed = tf.stack(data_input_embed, 1) # shape = data_input_embed.get_shape().as_list() embed_size = 10 #128 lstm_size # shape[1]*shape[2] # Embedding layers output_embedding = tf.Variable(tf.random_uniform((len(char2numY), embed_size), -1.0, 1.0), name='dec_embedding') data_output_embed = tf.nn.embedding_lookup(output_embedding, dec_inputs) with tf.variable_scope("encoding") as encoding_scope: if not bidirectional: # Regular approach with LSTM units lstm_enc = tf.contrib.rnn.LSTMCell(num_units) _, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=data_input_embed, dtype=tf.float32) else: # Using a bidirectional LSTM architecture 
instead enc_fw_cell = tf.contrib.rnn.LSTMCell(num_units) enc_bw_cell = tf.contrib.rnn.LSTMCell(num_units) ((enc_fw_out, enc_bw_out), (enc_fw_final, enc_bw_final)) = tf.nn.bidirectional_dynamic_rnn( cell_fw=enc_fw_cell, cell_bw=enc_bw_cell, inputs=data_input_embed, dtype=tf.float32) enc_fin_c = tf.concat((enc_fw_final.c, enc_bw_final.c), 1) enc_fin_h = tf.concat((enc_fw_final.h, enc_bw_final.h), 1) last_state = tf.contrib.rnn.LSTMStateTuple(c=enc_fin_c, h=enc_fin_h) with tf.variable_scope("decoding") as decoding_scope: if not bidirectional: lstm_dec = tf.contrib.rnn.LSTMCell(num_units) else: lstm_dec = tf.contrib.rnn.LSTMCell(2 * num_units) dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=data_output_embed, initial_state=last_state) logits = tf.layers.dense(dec_outputs, units=len(char2numY), use_bias=True) return logits def str2bool(v): if v.lower() in ('yes', 'true', 't', 'y', '1'): return True elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: raise argparse.ArgumentTypeError('Boolean value expected.') def main(): parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=500) parser.add_argument('--max_time', type=int, default=10) parser.add_argument('--test_steps', type=int, default=10) parser.add_argument('--batch_size', type=int, default=20) parser.add_argument('--data_dir', type=str, default='data/s2s_mitbih_aami_DS1DS2') parser.add_argument('--bidirectional', type=str2bool, default=str2bool('False')) # parser.add_argument('--lstm_layers', type=int, default=2) parser.add_argument('--num_units', type=int, default=128) parser.add_argument('--n_oversampling', type=int, default=6000) parser.add_argument('--checkpoint_dir', type=str, default='checkpoints-seq2seq_DS1DS2') parser.add_argument('--ckpt_name', type=str, default='seq2seq_mitbih_DS1DS2.ckpt') parser.add_argument('--classes', nargs='+', type=chr, default=['N', 'S','V']) args = parser.parse_args() run_program(args) def run_program(args): print(args) max_time = 
args.max_time # 5 3 second best 10# 40 # 100 epochs = args.epochs # 300 batch_size = args.batch_size # 10 num_units = args.num_units bidirectional = args.bidirectional # lstm_layers = args.lstm_layers n_oversampling = args.n_oversampling checkpoint_dir = args.checkpoint_dir ckpt_name = args.ckpt_name test_steps = args.test_steps classes= args.classes # ['N', 'S','V'] filename = args.data_dir X_train, y_train = read_mitbih(filename, max_time, classes=classes, max_nlabel=50000,trainset=1) X_test, y_test = read_mitbih(filename, max_time, classes=classes, max_nlabel=50000,trainset=0) input_depth = X_train.shape[2] n_channels = 10 print ("# of sequences: ", len(X_train)) classes = np.unique(y_train) char2numY = dict(zip(classes, range(len(classes)))) n_classes = len(classes) print ('Classes (training): ', classes) for cl in classes: ind = np.where(classes == cl)[0][0] print (cl, len(np.where(y_train.flatten() == cl)[0])) print ('Classes (test): ', classes) for cl in classes: ind = np.where(classes == cl)[0][0] print (cl, len(np.where(y_test.flatten() == cl)[0])) char2numY[''] = len(char2numY) num2charY = dict(zip(char2numY.values(), char2numY.keys())) y_train = [[char2numY['']] + [char2numY[y_] for y_ in date] for date in y_train] y_test = [[char2numY['']] + [char2numY[y_] for y_ in date] for date in y_test] y_test = np.asarray(y_test) y_train = np.array(y_train) x_seq_length = len(X_train[0]) y_seq_length = len(y_train[0]) - 1 # Placeholders inputs = tf.placeholder(tf.float32, [None, max_time, input_depth], name='inputs') targets = tf.placeholder(tf.int32, (None, None), 'targets') dec_inputs = tf.placeholder(tf.int32, (None, None), 'output') logits = build_network(inputs, dec_inputs, char2numY, n_channels=n_channels, input_depth=input_depth, num_units=num_units, max_time=max_time, bidirectional=bidirectional) # decoder_prediction = tf.argmax(logits, 2) # confusion = tf.confusion_matrix(labels=tf.argmax(targets, 1), predictions=tf.argmax(logits, 2), 
num_classes=len(char2numY) - 1)# it is wrong # mean_accuracy,update_mean_accuracy = tf.metrics.mean_per_class_accuracy(labels=targets, predictions=decoder_prediction, num_classes=len(char2numY) - 1) with tf.name_scope("optimization"): # Loss function vars = tf.trainable_variables() beta = 0.001 lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if 'bias' not in v.name]) * beta loss = tf.contrib.seq2seq.sequence_loss(logits, targets, tf.ones([batch_size, y_seq_length])) # Optimizer loss = tf.reduce_mean(loss + lossL2) optimizer = tf.train.RMSPropOptimizer(1e-3).minimize(loss) # train the graph # over-sampling: SMOTE X_train = np.reshape(X_train,[X_train.shape[0]*X_train.shape[1],-1]) y_train= y_train[:,1:].flatten() nums = [] for cl in classes: ind = np.where(classes == cl)[0][0] nums.append(len(np.where(y_train.flatten()==ind)[0])) # ratio={0:nums[0],1:nums[0],2:nums[0]} # ratio={0:7000,1:nums[1],2:7000,3:7000} ratio={0:nums[0],1:n_oversampling+1000,2:n_oversampling} sm = SMOTE(random_state=12,ratio=ratio) X_train, y_train = sm.fit_sample(X_train, y_train) X_train = X_train[:(X_train.shape[0]/max_time)*max_time,:] y_train = y_train[:(X_train.shape[0]/max_time)*max_time] X_train = np.reshape(X_train,[-1,X_test.shape[1],X_test.shape[2]]) y_train = np.reshape(y_train,[-1,y_test.shape[1]-1,]) y_train= [[char2numY['']] + [y_ for y_ in date] for date in y_train] y_train = np.array(y_train) print ('Classes in the training set: ', classes) for cl in classes: ind = np.where(classes == cl)[0][0] print (cl, len(np.where(y_train.flatten()==ind)[0])) print ("------------------y_train samples--------------------") for ii in range(2): print(''.join([num2charY[y_] for y_ in list(y_train[ii+5])])) print ('Classes in the training set: ', classes) for cl in classes: ind = np.where(classes == cl)[0][0] print (cl, len(np.where(y_test.flatten()==ind)[0])) print ("------------------y_test samples--------------------") for ii in range(2): print(''.join([num2charY[y_] for y_ in 
list(y_test[ii+5])])) def test_model(): # source_batch, target_batch = next(batch_data(X_test, y_test, batch_size)) acc_track = [] sum_test_conf = [] for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_test, y_test, batch_size)): dec_input = np.zeros((len(source_batch), 1)) + char2numY[''] for i in range(y_seq_length): batch_logits = sess.run(logits, feed_dict={inputs: source_batch, dec_inputs: dec_input}) prediction = batch_logits[:, -1].argmax(axis=-1) dec_input = np.hstack([dec_input, prediction[:, None]]) # acc_track.append(np.mean(dec_input == target_batch)) acc_track.append(dec_input[:, 1:] == target_batch[:, 1:]) y_true= target_batch[:, 1:].flatten() y_pred = dec_input[:, 1:].flatten() sum_test_conf.append(confusion_matrix(y_true, y_pred,labels=range(len(char2numY)-1))) sum_test_conf= np.mean(np.array(sum_test_conf, dtype=np.float32), axis=0) # print('Accuracy on test set is: {:>6.4f}'.format(np.mean(acc_track))) # mean_p_class, accuracy_classes = sess.run([mean_accuracy, update_mean_accuracy], # feed_dict={inputs: source_batch, # dec_inputs: dec_input[:, :-1], # targets: target_batch[:, 1:]}) # print (mean_p_class) # print (accuracy_classes) acc_avg, acc, sensitivity, specificity, PPV = evaluate_metrics(sum_test_conf) print('Average Accuracy is: {:>6.4f} on test set'.format(acc_avg)) for index_ in range(n_classes): print("\t{} rhythm -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}".format( classes[index_], sensitivity[ index_], specificity[ index_], PPV[index_], acc[index_])) print("\t Average -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}".format( np.mean(sensitivity), np.mean(specificity), np.mean(PPV), np.mean(acc))) return acc_avg, acc, sensitivity, specificity, PPV loss_track = [] def count_prameters(): print ('# of Params: ', np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])) count_prameters() if 
(os.path.exists(checkpoint_dir) == False): os.mkdir(checkpoint_dir) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) saver = tf.train.Saver() print(str(datetime.now())) pre_acc_avg = 0.0 ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: # # Restore ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # saver.restore(session, os.path.join(checkpoint_dir, ckpt_name)) saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir)) # or 'load meta graph' and restore weights # saver = tf.train.import_meta_graph(ckpt_name+".meta") # saver.restore(session,tf.train.latest_checkpoint(checkpoint_dir)) test_model() else: for epoch_i in range(epochs): start_time = time.time() train_acc = [] for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, y_train, batch_size)): _, batch_loss, batch_logits = sess.run([optimizer, loss, logits], feed_dict = {inputs: source_batch, dec_inputs: target_batch[:, :-1], targets: target_batch[:, 1:]}) loss_track.append(batch_loss) train_acc.append(batch_logits.argmax(axis=-1) == target_batch[:,1:]) accuracy = np.mean(train_acc) print('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f} Epoch duration: {:>6.3f}s'.format(epoch_i, batch_loss, accuracy, time.time() - start_time)) if epoch_i%test_steps==0: acc_avg, acc, sensitivity, specificity, PPV= test_model() print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size)) save_path = os.path.join(checkpoint_dir, ckpt_name) saver.save(sess, save_path) print("Model saved in path: %s" % save_path) # if np.nan_to_num(acc_avg) > pre_acc_avg: # save the better model based on the f1 score # print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size)) # pre_acc_avg = acc_avg # save_path =os.path.join(checkpoint_dir, ckpt_name) # saver.save(sess, save_path) # print("The best model (till now) saved in path: %s" 
% save_path) plt.plot(loss_track) plt.show() print(str(datetime.now())) # test_model() if __name__ == '__main__': main() ================================================ FILE: seq_seq_annot_aami.py ================================================ import numpy as np import matplotlib.pyplot as plt import scipy.io as spio from sklearn.preprocessing import MinMaxScaler import random import time import os from datetime import datetime from sklearn.metrics import confusion_matrix import tensorflow as tf from imblearn.over_sampling import SMOTE from sklearn.model_selection import train_test_split import argparse random.seed(654) def read_mitbih(filename, max_time=100, classes= ['F', 'N', 'S', 'V', 'Q'], max_nlabel=100): def normalize(data): data = np.nan_to_num(data) # removing NaNs and Infs data = data - np.mean(data) data = data / np.std(data) return data # read data data = [] samples = spio.loadmat(filename + ".mat") samples = samples['s2s_mitbih'] values = samples[0]['seg_values'] labels = samples[0]['seg_labels'] num_annots = sum([item.shape[0] for item in values]) n_seqs = num_annots / max_time # add all segments(beats) together l_data = 0 for i, item in enumerate(values): l = item.shape[0] for itm in item: if l_data == n_seqs * max_time: break data.append(itm[0]) l_data = l_data + 1 # add all labels together l_lables = 0 t_lables = [] for i, item in enumerate(labels): if len(t_lables)==n_seqs*max_time: break item= item[0] for lebel in item: if l_lables == n_seqs * max_time: break t_lables.append(str(lebel)) l_lables = l_lables + 1 del values data = np.asarray(data) shape_v = data.shape data = np.reshape(data, [shape_v[0], -1]) t_lables = np.array(t_lables) _data = np.asarray([],dtype=np.float64).reshape(0,shape_v[1]) _labels = np.asarray([],dtype=np.dtype('|S1')).reshape(0,) for cl in classes: _label = np.where(t_lables == cl) permute = np.random.permutation(len(_label[0])) _label = _label[0][permute[:max_nlabel]] # _label = _label[0][:max_nlabel] # permute = 
np.random.permutation(len(_label)) # _label = _label[permute] _data = np.concatenate((_data, data[_label])) _labels = np.concatenate((_labels, t_lables[_label])) data = _data[:(len(_data)/ max_time) * max_time, :] _labels = _labels[:(len(_data) / max_time) * max_time] # data = _data # split data into sublist of 100=se_len values data = [data[i:i + max_time] for i in range(0, len(data), max_time)] labels = [_labels[i:i + max_time] for i in range(0, len(_labels), max_time)] # shuffle permute = np.random.permutation(len(labels)) data = np.asarray(data) labels = np.asarray(labels) data= data[permute] labels = labels[permute] print('Records processed!') return data, labels def evaluate_metrics(confusion_matrix): # https://stackoverflow.com/questions/31324218/scikit-learn-how-to-obtain-true-positive-true-negative-false-positive-and-fal FP = confusion_matrix.sum(axis=0) - np.diag(confusion_matrix) FN = confusion_matrix.sum(axis=1) - np.diag(confusion_matrix) TP = np.diag(confusion_matrix) TN = confusion_matrix.sum() - (FP + FN + TP) # Sensitivity, hit rate, recall, or true positive rate TPR = TP / (TP + FN) # Specificity or true negative rate TNR = TN / (TN + FP) # Precision or positive predictive value PPV = TP / (TP + FP) # Negative predictive value NPV = TN / (TN + FN) # Fall out or false positive rate FPR = FP / (FP + TN) # False negative rate FNR = FN / (TP + FN) # False discovery rate FDR = FP / (TP + FP) # Overall accuracy ACC = (TP + TN) / (TP + FP + FN + TN) # ACC_micro = (sum(TP) + sum(TN)) / (sum(TP) + sum(FP) + sum(FN) + sum(TN)) ACC_macro = np.mean(ACC) # to get a sense of effectiveness of our method on the small classes we computed this average (macro-average) return ACC_macro, ACC, TPR, TNR, PPV def batch_data(x, y, batch_size): shuffle = np.random.permutation(len(x)) start = 0 # from IPython.core.debugger import Tracer; Tracer()() x = x[shuffle] y = y[shuffle] while start + batch_size <= len(x): yield x[start:start + batch_size], y[start:start + 
batch_size] start += batch_size def build_network(inputs, dec_inputs,char2numY,n_channels=10,input_depth=280,num_units=128,max_time=10,bidirectional=False): _inputs = tf.reshape(inputs, [-1, n_channels, input_depth / n_channels]) # _inputs = tf.reshape(inputs, [-1,input_depth,n_channels]) # #(batch*max_time, 280, 1) --> (N, 280, 18) conv1 = tf.layers.conv1d(inputs=_inputs, filters=32, kernel_size=2, strides=1, padding='same', activation=tf.nn.relu) max_pool_1 = tf.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same') conv2 = tf.layers.conv1d(inputs=max_pool_1, filters=64, kernel_size=2, strides=1, padding='same', activation=tf.nn.relu) max_pool_2 = tf.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same') conv3 = tf.layers.conv1d(inputs=max_pool_2, filters=128, kernel_size=2, strides=1, padding='same', activation=tf.nn.relu) shape = conv3.get_shape().as_list() data_input_embed = tf.reshape(conv3, (-1, max_time, shape[1] * shape[2])) # timesteps = max_time # # lstm_in = tf.unstack(data_input_embed, timesteps, 1) # lstm_size = 128 # # Get lstm cell output # # Add LSTM layers # lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size) # data_input_embed, states = tf.contrib.rnn.static_rnn(lstm_cell, lstm_in, dtype=tf.float32) # data_input_embed = tf.stack(data_input_embed, 1) # shape = data_input_embed.get_shape().as_list() embed_size = 10 # 128 lstm_size # shape[1]*shape[2] # Embedding layers output_embedding = tf.Variable(tf.random_uniform((len(char2numY), embed_size), -1.0, 1.0), name='dec_embedding') data_output_embed = tf.nn.embedding_lookup(output_embedding, dec_inputs) with tf.variable_scope("encoding") as encoding_scope: if not bidirectional: # Regular approach with LSTM units lstm_enc = tf.contrib.rnn.LSTMCell(num_units) _, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=data_input_embed, dtype=tf.float32) else: # Using a bidirectional LSTM architecture instead enc_fw_cell = tf.contrib.rnn.LSTMCell(num_units) enc_bw_cell = 
tf.contrib.rnn.LSTMCell(num_units) ((enc_fw_out, enc_bw_out), (enc_fw_final, enc_bw_final)) = tf.nn.bidirectional_dynamic_rnn( cell_fw=enc_fw_cell, cell_bw=enc_bw_cell, inputs=data_input_embed, dtype=tf.float32) enc_fin_c = tf.concat((enc_fw_final.c, enc_bw_final.c), 1) enc_fin_h = tf.concat((enc_fw_final.h, enc_bw_final.h), 1) last_state = tf.contrib.rnn.LSTMStateTuple(c=enc_fin_c, h=enc_fin_h) with tf.variable_scope("decoding") as decoding_scope: if not bidirectional: lstm_dec = tf.contrib.rnn.LSTMCell(num_units) else: lstm_dec = tf.contrib.rnn.LSTMCell(2 * num_units) dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=data_output_embed, initial_state=last_state) logits = tf.layers.dense(dec_outputs, units=len(char2numY), use_bias=True) return logits def str2bool(v): if v.lower() in ('yes', 'true', 't', 'y', '1'): return True elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: raise argparse.ArgumentTypeError('Boolean value expected.') def main(): parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=500) parser.add_argument('--max_time', type=int, default=10) parser.add_argument('--test_steps', type=int, default=10) parser.add_argument('--batch_size', type=int, default=20) parser.add_argument('--data_dir', type=str, default='data/s2s_mitbih_aami') parser.add_argument('--bidirectional', type=str2bool, default=str2bool('False')) # parser.add_argument('--lstm_layers', type=int, default=2) parser.add_argument('--num_units', type=int, default=128) parser.add_argument('--n_oversampling', type=int, default=10000) parser.add_argument('--checkpoint_dir', type=str, default='checkpoints-seq2seq') parser.add_argument('--ckpt_name', type=str, default='seq2seq_mitbih.ckpt') parser.add_argument('--classes', nargs='+', type=chr, default=['F','N', 'S','V']) args = parser.parse_args() run_program(args) def run_program(args): print(args) max_time = args.max_time # 5 3 second best 10# 40 # 100 epochs = args.epochs # 300 batch_size = 
args.batch_size # 10 num_units = args.num_units bidirectional = args.bidirectional # lstm_layers = args.lstm_layers n_oversampling = args.n_oversampling checkpoint_dir = args.checkpoint_dir ckpt_name = args.ckpt_name test_steps = args.test_steps classes= args.classes filename = args.data_dir X, Y = read_mitbih(filename,max_time,classes=classes,max_nlabel=100000) #11000 print ("# of sequences: ", len(X)) input_depth = X.shape[2] n_channels = 10 classes = np.unique(Y) char2numY = dict(zip(classes, range(len(classes)))) n_classes = len(classes) print ('Classes: ', classes) for cl in classes: ind = np.where(classes == cl)[0][0] print (cl, len(np.where(Y.flatten()==cl)[0])) # char2numX[''] = len(char2numX) # num2charX = dict(zip(char2numX.values(), char2numX.keys())) # max_len = max([len(date) for date in x]) # # x = [[char2numX['']]*(max_len - len(date)) +[char2numX[x_] for x_ in date] for date in x] # print(''.join([num2charX[x_] for x_ in x[4]])) # x = np.array(x) char2numY[''] = len(char2numY) num2charY = dict(zip(char2numY.values(), char2numY.keys())) Y = [[char2numY['']] + [char2numY[y_] for y_ in date] for date in Y] Y = np.array(Y) x_seq_length = len(X[0]) y_seq_length = len(Y[0])- 1 # Placeholders inputs = tf.placeholder(tf.float32, [None, max_time, input_depth], name = 'inputs') targets = tf.placeholder(tf.int32, (None, None), 'targets') dec_inputs = tf.placeholder(tf.int32, (None, None), 'output') # logits = build_network(inputs,dec_inputs=dec_inputs) logits = build_network(inputs, dec_inputs, char2numY, n_channels=n_channels, input_depth=input_depth, num_units=num_units, max_time=max_time, bidirectional=bidirectional) # decoder_prediction = tf.argmax(logits, 2) # confusion = tf.confusion_matrix(labels=tf.argmax(targets, 1), predictions=tf.argmax(logits, 2), num_classes=len(char2numY) - 1)# it is wrong # mean_accuracy,update_mean_accuracy = tf.metrics.mean_per_class_accuracy(labels=targets, predictions=decoder_prediction, num_classes=len(char2numY) - 1) with 
tf.name_scope("optimization"): # Loss function vars = tf.trainable_variables() beta = 0.001 lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if 'bias' not in v.name]) * beta loss = tf.contrib.seq2seq.sequence_loss(logits, targets, tf.ones([batch_size, y_seq_length])) # Optimizer loss = tf.reduce_mean(loss + lossL2) optimizer = tf.train.RMSPropOptimizer(1e-3).minimize(loss) # split the dataset into the training and test sets X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42) # over-sampling: SMOTE X_train = np.reshape(X_train,[X_train.shape[0]*X_train.shape[1],-1]) y_train= y_train[:,1:].flatten() nums = [] for cl in classes: ind = np.where(classes == cl)[0][0] nums.append(len(np.where(y_train.flatten()==ind)[0])) # ratio={0:nums[3],1:nums[1],2:nums[3],3:nums[3]} # the best with 11000 for N ratio={0:n_oversampling,1:nums[1],2:n_oversampling,3:n_oversampling} sm = SMOTE(random_state=12,ratio=ratio) X_train, y_train = sm.fit_sample(X_train, y_train) X_train = X_train[:(X_train.shape[0]/max_time)*max_time,:] y_train = y_train[:(X_train.shape[0]/max_time)*max_time] X_train = np.reshape(X_train,[-1,X_test.shape[1],X_test.shape[2]]) y_train = np.reshape(y_train,[-1,y_test.shape[1]-1,]) y_train= [[char2numY['']] + [y_ for y_ in date] for date in y_train] y_train = np.array(y_train) print ('Classes in the training set: ', classes) for cl in classes: ind = np.where(classes == cl)[0][0] print (cl, len(np.where(y_train.flatten()==ind)[0])) print ("------------------y_train samples--------------------") for ii in range(2): print(''.join([num2charY[y_] for y_ in list(y_train[ii+5])])) print ("------------------y_test samples--------------------") for ii in range(2): print(''.join([num2charY[y_] for y_ in list(y_test[ii+5])])) def test_model(): # source_batch, target_batch = next(batch_data(X_test, y_test, batch_size)) acc_track = [] sum_test_conf = [] for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_test, y_test, 
batch_size)): dec_input = np.zeros((len(source_batch), 1)) + char2numY[''] for i in range(y_seq_length): batch_logits = sess.run(logits, feed_dict={inputs: source_batch, dec_inputs: dec_input}) prediction = batch_logits[:, -1].argmax(axis=-1) dec_input = np.hstack([dec_input, prediction[:, None]]) # acc_track.append(np.mean(dec_input == target_batch)) acc_track.append(dec_input[:, 1:] == target_batch[:, 1:]) y_true= target_batch[:, 1:].flatten() y_pred = dec_input[:, 1:].flatten() sum_test_conf.append(confusion_matrix(y_true, y_pred,labels=range(len(char2numY)-1))) sum_test_conf= np.mean(np.array(sum_test_conf, dtype=np.float32), axis=0) # print('Accuracy on test set is: {:>6.4f}'.format(np.mean(acc_track))) # mean_p_class, accuracy_classes = sess.run([mean_accuracy, update_mean_accuracy], # feed_dict={inputs: source_batch, # dec_inputs: dec_input[:, :-1], # targets: target_batch[:, 1:]}) # print (mean_p_class) # print (accuracy_classes) acc_avg, acc, sensitivity, specificity, PPV = evaluate_metrics(sum_test_conf) print('Average Accuracy is: {:>6.4f} on test set'.format(acc_avg)) for index_ in range(n_classes): print("\t{} rhythm -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}".format(classes[index_], sensitivity[ index_], specificity[ index_],PPV[index_], acc[index_])) print("\t Average -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}".format(np.mean(sensitivity),np.mean(specificity),np.mean(PPV),np.mean(acc))) return acc_avg, acc, sensitivity, specificity, PPV loss_track = [] def count_prameters(): print ('# of Params: ', np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])) count_prameters() if (os.path.exists(checkpoint_dir) == False): os.mkdir(checkpoint_dir) # train the graph with tf.Session() as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) saver = tf.train.Saver() print(str(datetime.now())) ckpt = 
tf.train.get_checkpoint_state(checkpoint_dir) pre_acc_avg = 0.0 if ckpt and ckpt.model_checkpoint_path: # # Restore ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # saver.restore(session, os.path.join(checkpoint_dir, ckpt_name)) saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir)) # or 'load meta graph' and restore weights # saver = tf.train.import_meta_graph(ckpt_name+".meta") # saver.restore(session,tf.train.latest_checkpoint(checkpoint_dir)) test_model() else: for epoch_i in range(epochs): start_time = time.time() train_acc = [] for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, y_train, batch_size)): _, batch_loss, batch_logits = sess.run([optimizer, loss, logits], feed_dict = {inputs: source_batch, dec_inputs: target_batch[:, :-1], targets: target_batch[:, 1:]}) loss_track.append(batch_loss) train_acc.append(batch_logits.argmax(axis=-1) == target_batch[:,1:]) # mean_p_class,accuracy_classes = sess.run([mean_accuracy,update_mean_accuracy], # feed_dict={inputs: source_batch, # dec_inputs: target_batch[:, :-1], # targets: target_batch[:, 1:]}) # accuracy = np.mean(batch_logits.argmax(axis=-1) == target_batch[:,1:]) accuracy = np.mean(train_acc) print('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f} Epoch duration: {:>6.3f}s'.format(epoch_i, batch_loss, accuracy, time.time() - start_time)) if epoch_i%test_steps==0: acc_avg, acc, sensitivity, specificity, PPV= test_model() print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size)) save_path = os.path.join(checkpoint_dir, ckpt_name) saver.save(sess, save_path) print("Model saved in path: %s" % save_path) # if np.nan_to_num(acc_avg) > pre_acc_avg: # save the better model based on the f1 score # print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size)) # pre_acc_avg = acc_avg # save_path =os.path.join(checkpoint_dir, ckpt_name) # saver.save(sess, save_path) # print("The best model (till now) 
saved in path: %s" % save_path) plt.plot(loss_track) plt.show() print(str(datetime.now())) # test_model() if __name__ == '__main__': main() ================================================ FILE: tf2/seq_seq_annot_aami.py ================================================ import numpy as np import matplotlib.pyplot as plt import scipy.io as spio from sklearn.preprocessing import MinMaxScaler import random import time import os from datetime import datetime from sklearn.metrics import confusion_matrix import tensorflow as tf import tensorflow_addons as tfa from imblearn.over_sampling import SMOTE from sklearn.model_selection import train_test_split import argparse random.seed(654) def read_mitbih(filename, max_time=100, classes= ['F', 'N', 'S', 'V', 'Q'], max_nlabel=100): def normalize(data): data = np.nan_to_num(data) # removing NaNs and Infs data = data - np.mean(data) data = data / np.std(data) return data # read data data = [] samples = spio.loadmat(filename + ".mat") samples = samples['s2s_mitbih'] values = samples[0]['seg_values'] labels = samples[0]['seg_labels'] num_annots = sum([item.shape[0] for item in values]) n_seqs = num_annots / max_time # add all segments(beats) together l_data = 0 for i, item in enumerate(values): l = item.shape[0] for itm in item: if l_data == n_seqs * max_time: break data.append(itm[0]) l_data = l_data + 1 # add all labels together l_lables = 0 t_lables = [] for i, item in enumerate(labels): if len(t_lables)==n_seqs*max_time: break item= item[0] for lebel in item: if l_lables == n_seqs * max_time: break t_lables.append(str(lebel)) l_lables = l_lables + 1 del values data = np.asarray(data) shape_v = data.shape data = np.reshape(data, [shape_v[0], -1]) t_lables = np.array(t_lables) _data = np.asarray([],dtype=np.float64).reshape(0,shape_v[1]) _labels = np.asarray([],dtype=np.dtype('|S1')).reshape(0,) for cl in classes: _label = np.where(t_lables == cl) permute = np.random.permutation(len(_label[0])) _label = 
_label[0][permute[:max_nlabel]] # _label = _label[0][:max_nlabel] # permute = np.random.permutation(len(_label)) # _label = _label[permute] _data = np.concatenate((_data, data[_label])) _labels = np.concatenate((_labels, t_lables[_label])) data = _data[:int(len(_data)/ max_time) * max_time, :] _labels = _labels[:int(len(_data) / max_time) * max_time] # data = _data # split data into sublist of 100=se_len values data = [data[i:i + max_time] for i in range(0, len(data), max_time)] labels = [_labels[i:i + max_time] for i in range(0, len(_labels), max_time)] # shuffle permute = np.random.permutation(len(labels)) data = np.asarray(data) labels = np.asarray(labels) data= data[permute] labels = labels[permute] print('Records processed!') return data, labels def evaluate_metrics(confusion_matrix): # https://stackoverflow.com/questions/31324218/scikit-learn-how-to-obtain-true-positive-true-negative-false-positive-and-fal FP = confusion_matrix.sum(axis=0) - np.diag(confusion_matrix) FN = confusion_matrix.sum(axis=1) - np.diag(confusion_matrix) TP = np.diag(confusion_matrix) TN = confusion_matrix.sum() - (FP + FN + TP) # Sensitivity, hit rate, recall, or true positive rate TPR = TP / (TP + FN) # Specificity or true negative rate TNR = TN / (TN + FP) # Precision or positive predictive value PPV = TP / (TP + FP) # Negative predictive value NPV = TN / (TN + FN) # Fall out or false positive rate FPR = FP / (FP + TN) # False negative rate FNR = FN / (TP + FN) # False discovery rate FDR = FP / (TP + FP) # Overall accuracy ACC = (TP + TN) / (TP + FP + FN + TN) # ACC_micro = (sum(TP) + sum(TN)) / (sum(TP) + sum(FP) + sum(FN) + sum(TN)) ACC_macro = np.mean(ACC) # to get a sense of effectiveness of our method on the small classes we computed this average (macro-average) return ACC_macro, ACC, TPR, TNR, PPV def batch_data(x, y, batch_size): shuffle = np.random.permutation(len(x)) start = 0 # from IPython.core.debugger import Tracer; Tracer()() x = x[shuffle] y = y[shuffle] while start + 
batch_size <= len(x): yield x[start:start + batch_size], y[start:start + batch_size] start += batch_size def build_network(inputs, dec_inputs,char2numY,n_channels=10,input_depth=280,num_units=128,max_time=10,bidirectional=False): _inputs = tf.reshape(inputs, [-1, n_channels, int(input_depth / n_channels)]) # _inputs = tf.reshape(inputs, [-1,input_depth,n_channels]) # #(batch*max_time, 280, 1) --> (N, 280, 18) conv1 = tf.compat.v1.layers.conv1d(inputs=_inputs, filters=32, kernel_size=2, strides=1, padding='same', activation=tf.nn.relu) max_pool_1 = tf.compat.v1.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same') conv2 = tf.compat.v1.layers.conv1d(inputs=max_pool_1, filters=64, kernel_size=2, strides=1, padding='same', activation=tf.nn.relu) max_pool_2 = tf.compat.v1.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same') conv3 = tf.compat.v1.layers.conv1d(inputs=max_pool_2, filters=128, kernel_size=2, strides=1, padding='same', activation=tf.nn.relu) shape = conv3.get_shape().as_list() data_input_embed = tf.reshape(conv3, (-1, max_time, shape[1] * shape[2])) # timesteps = max_time # # lstm_in = tf.unstack(data_input_embed, timesteps, 1) # lstm_size = 128 # # Get lstm cell output # # Add LSTM layers # lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size) # data_input_embed, states = tf.contrib.rnn.static_rnn(lstm_cell, lstm_in, dtype=tf.float32) # data_input_embed = tf.stack(data_input_embed, 1) # shape = data_input_embed.get_shape().as_list() embed_size = 10 # 128 lstm_size # shape[1]*shape[2] # Embedding layers output_embedding = tf.Variable(tf.random.uniform((len(char2numY), embed_size), -1.0, 1.0), name='dec_embedding') data_output_embed = tf.nn.embedding_lookup(params=output_embedding, ids=dec_inputs) with tf.compat.v1.variable_scope("encoding") as encoding_scope: if not bidirectional: # Regular approach with LSTM units lstm_enc = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units) _, last_state = 
tf.compat.v1.nn.dynamic_rnn(lstm_enc, inputs=data_input_embed, dtype=tf.float32) else: # Using a bidirectional LSTM architecture instead enc_fw_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units) enc_bw_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units) ((enc_fw_out, enc_bw_out), (enc_fw_final, enc_bw_final)) = tf.compat.v1.nn.bidirectional_dynamic_rnn( cell_fw=enc_fw_cell, cell_bw=enc_bw_cell, inputs=data_input_embed, dtype=tf.float32) enc_fin_c = tf.concat((enc_fw_final.c, enc_bw_final.c), 1) enc_fin_h = tf.concat((enc_fw_final.h, enc_bw_final.h), 1) last_state = tf.nn.rnn_cell.LSTMStateTuple(c=enc_fin_c, h=enc_fin_h) with tf.compat.v1.variable_scope("decoding") as decoding_scope: if not bidirectional: lstm_dec = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units) else: lstm_dec = tf.compat.v1.nn.rnn_cell.LSTMCell(2 * num_units) dec_outputs, _ = tf.compat.v1.nn.dynamic_rnn(lstm_dec, inputs=data_output_embed, initial_state=last_state) logits = tf.compat.v1.layers.dense(dec_outputs, units=len(char2numY), use_bias=True) return logits def str2bool(v): if v.lower() in ('yes', 'true', 't', 'y', '1'): return True elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: raise argparse.ArgumentTypeError('Boolean value expected.') def main(): parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=500) parser.add_argument('--max_time', type=int, default=10) parser.add_argument('--test_steps', type=int, default=10) parser.add_argument('--batch_size', type=int, default=20) parser.add_argument('--data_dir', type=str, default='data/s2s_mitbih_aami') parser.add_argument('--bidirectional', type=str2bool, default=str2bool('False')) # parser.add_argument('--lstm_layers', type=int, default=2) parser.add_argument('--num_units', type=int, default=128) parser.add_argument('--n_oversampling', type=int, default=10000) parser.add_argument('--checkpoint_dir', type=str, default='checkpoints-seq2seq') parser.add_argument('--ckpt_name', type=str, 
default='seq2seq_mitbih.ckpt') parser.add_argument('--classes', nargs='+', type=chr, default=['F','N', 'S','V']) args = parser.parse_args() run_program(args) def run_program(args): print(args) max_time = args.max_time # 5 3 second best 10# 40 # 100 epochs = args.epochs # 300 batch_size = args.batch_size # 10 num_units = args.num_units bidirectional = args.bidirectional # lstm_layers = args.lstm_layers n_oversampling = args.n_oversampling checkpoint_dir = args.checkpoint_dir ckpt_name = args.ckpt_name test_steps = args.test_steps classes= args.classes filename = args.data_dir X, Y = read_mitbih(filename,max_time,classes=classes,max_nlabel=100000) #11000 print(("# of sequences: ", len(X))) input_depth = X.shape[2] n_channels = 10 classes = np.unique(Y) char2numY = dict(list(zip(classes, list(range(len(classes)))))) n_classes = len(classes) print(('Classes: ', classes)) for cl in classes: ind = np.where(classes == cl)[0][0] print((cl, len(np.where(Y.flatten()==cl)[0]))) # char2numX[''] = len(char2numX) # num2charX = dict(zip(char2numX.values(), char2numX.keys())) # max_len = max([len(date) for date in x]) # # x = [[char2numX['']]*(max_len - len(date)) +[char2numX[x_] for x_ in date] for date in x] # print(''.join([num2charX[x_] for x_ in x[4]])) # x = np.array(x) char2numY[''] = len(char2numY) num2charY = dict(list(zip(list(char2numY.values()), list(char2numY.keys())))) Y = [[char2numY['']] + [char2numY[y_] for y_ in date] for date in Y] Y = np.array(Y) x_seq_length = len(X[0]) y_seq_length = len(Y[0])- 1 # Placeholders tf.compat.v1.disable_eager_execution() inputs = tf.compat.v1.placeholder(tf.float32, [None, max_time, input_depth], name = 'inputs') targets = tf.compat.v1.placeholder(tf.int32, (None, None), 'targets') dec_inputs = tf.compat.v1.placeholder(tf.int32, (None, None), 'output') # logits = build_network(inputs,dec_inputs=dec_inputs) logits = build_network(inputs, dec_inputs, char2numY, n_channels=n_channels, input_depth=input_depth, num_units=num_units, 
max_time=max_time, bidirectional=bidirectional) # decoder_prediction = tf.argmax(logits, 2) # confusion = tf.confusion_matrix(labels=tf.argmax(targets, 1), predictions=tf.argmax(logits, 2), num_classes=len(char2numY) - 1)# it is wrong # mean_accuracy,update_mean_accuracy = tf.metrics.mean_per_class_accuracy(labels=targets, predictions=decoder_prediction, num_classes=len(char2numY) - 1) with tf.compat.v1.name_scope("optimization"): # Loss function vars = tf.compat.v1.trainable_variables() beta = 0.001 lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if 'bias' not in v.name]) * beta loss = tfa.seq2seq.sequence_loss(logits, targets, tf.ones([batch_size, y_seq_length])) # Optimizer loss = tf.reduce_mean(input_tensor=loss + lossL2) optimizer = tf.compat.v1.train.RMSPropOptimizer(1e-3).minimize(loss) # split the dataset into the training and test sets X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42) # over-sampling: SMOTE X_train = np.reshape(X_train,[X_train.shape[0]*X_train.shape[1],-1]) y_train= y_train[:,1:].flatten() nums = [] for cl in classes: ind = np.where(classes == cl)[0][0] nums.append(len(np.where(y_train.flatten()==ind)[0])) # ratio={0:nums[3],1:nums[1],2:nums[3],3:nums[3]} # the best with 11000 for N ratio={0:n_oversampling,1:nums[1],2:n_oversampling,3:n_oversampling} sm = SMOTE(random_state=12,ratio=ratio) X_train, y_train = sm.fit_sample(X_train, y_train) X_train = X_train[:int(X_train.shape[0]/max_time)*max_time,:] y_train = y_train[:int(X_train.shape[0]/max_time)*max_time] X_train = np.reshape(X_train,[-1,X_test.shape[1],X_test.shape[2]]) y_train = np.reshape(y_train,[-1,y_test.shape[1]-1,]) y_train= [[char2numY['']] + [y_ for y_ in date] for date in y_train] y_train = np.array(y_train) print(('Classes in the training set: ', classes)) for cl in classes: ind = np.where(classes == cl)[0][0] print((cl, len(np.where(y_train.flatten()==ind)[0]))) print ("------------------y_train samples--------------------") for 
ii in range(2): print((''.join([num2charY[y_] for y_ in list(y_train[ii+5])]))) print ("------------------y_test samples--------------------") for ii in range(2): print((''.join([num2charY[y_] for y_ in list(y_test[ii+5])]))) def test_model(): # source_batch, target_batch = next(batch_data(X_test, y_test, batch_size)) acc_track = [] sum_test_conf = [] for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_test, y_test, batch_size)): dec_input = np.zeros((len(source_batch), 1)) + char2numY[''] for i in range(y_seq_length): batch_logits = sess.run(logits, feed_dict={inputs: source_batch, dec_inputs: dec_input}) prediction = batch_logits[:, -1].argmax(axis=-1) dec_input = np.hstack([dec_input, prediction[:, None]]) # acc_track.append(np.mean(dec_input == target_batch)) acc_track.append(dec_input[:, 1:] == target_batch[:, 1:]) y_true= target_batch[:, 1:].flatten() y_pred = dec_input[:, 1:].flatten() sum_test_conf.append(confusion_matrix(y_true, y_pred,labels=list(range(len(char2numY)-1)))) sum_test_conf= np.mean(np.array(sum_test_conf, dtype=np.float32), axis=0) # print('Accuracy on test set is: {:>6.4f}'.format(np.mean(acc_track))) # mean_p_class, accuracy_classes = sess.run([mean_accuracy, update_mean_accuracy], # feed_dict={inputs: source_batch, # dec_inputs: dec_input[:, :-1], # targets: target_batch[:, 1:]}) # print (mean_p_class) # print (accuracy_classes) acc_avg, acc, sensitivity, specificity, PPV = evaluate_metrics(sum_test_conf) print(('Average Accuracy is: {:>6.4f} on test set'.format(acc_avg))) for index_ in range(n_classes): print(("\t{} rhythm -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}".format(classes[index_], sensitivity[ index_], specificity[ index_],PPV[index_], acc[index_]))) print(("\t Average -> Sensitivity: {:1.4f}, Specificity : {:1.4f}, Precision (PPV) : {:1.4f}, Accuracy : {:1.4f}".format(np.mean(sensitivity),np.mean(specificity),np.mean(PPV),np.mean(acc)))) return acc_avg, acc, 
sensitivity, specificity, PPV loss_track = [] def count_prameters(): print(('# of Params: ', np.sum([np.prod(v.get_shape().as_list()) for v in tf.compat.v1.trainable_variables()]))) count_prameters() if (os.path.exists(checkpoint_dir) == False): os.mkdir(checkpoint_dir) # train the graph with tf.compat.v1.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer()) saver = tf.compat.v1.train.Saver() print((str(datetime.now()))) ckpt = tf.train.get_checkpoint_state(checkpoint_dir) pre_acc_avg = 0.0 if ckpt and ckpt.model_checkpoint_path: # # Restore ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # saver.restore(session, os.path.join(checkpoint_dir, ckpt_name)) saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir)) # or 'load meta graph' and restore weights # saver = tf.train.import_meta_graph(ckpt_name+".meta") # saver.restore(session,tf.train.latest_checkpoint(checkpoint_dir)) test_model() else: for epoch_i in range(epochs): start_time = time.time() train_acc = [] for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, y_train, batch_size)): _, batch_loss, batch_logits = sess.run([optimizer, loss, logits], feed_dict = {inputs: source_batch, dec_inputs: target_batch[:, :-1], targets: target_batch[:, 1:]}) loss_track.append(batch_loss) train_acc.append(batch_logits.argmax(axis=-1) == target_batch[:,1:]) # mean_p_class,accuracy_classes = sess.run([mean_accuracy,update_mean_accuracy], # feed_dict={inputs: source_batch, # dec_inputs: target_batch[:, :-1], # targets: target_batch[:, 1:]}) # accuracy = np.mean(batch_logits.argmax(axis=-1) == target_batch[:,1:]) accuracy = np.mean(train_acc) print(('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f} Epoch duration: {:>6.3f}s'.format(epoch_i, batch_loss, accuracy, time.time() - start_time))) if epoch_i%test_steps==0: acc_avg, acc, sensitivity, specificity, PPV= test_model() print(('loss {:.4f} after {} epochs 
(batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size))) save_path = os.path.join(checkpoint_dir, ckpt_name) saver.save(sess, save_path) print(("Model saved in path: %s" % save_path)) # if np.nan_to_num(acc_avg) > pre_acc_avg: # save the better model based on the f1 score # print('loss {:.4f} after {} epochs (batch_size={})'.format(loss_track[-1], epoch_i + 1, batch_size)) # pre_acc_avg = acc_avg # save_path =os.path.join(checkpoint_dir, ckpt_name) # saver.save(sess, save_path) # print("The best model (till now) saved in path: %s" % save_path) plt.plot(loss_track) plt.show() print((str(datetime.now()))) # test_model() if __name__ == '__main__': main()