Repository: NeuSpeech/EEG-To-Text Branch: main Commit: 01582402809a Files: 28 Total size: 173.2 KB Directory structure: gitextract_ginxdd5t/ ├── .gitignore ├── README.md ├── config.py ├── data.py ├── environment.yml ├── eval_decoding.py ├── eval_sentiment.py ├── model_decoding.py ├── model_sentiment.py ├── scripts/ │ ├── eval_decoding_1.sh │ ├── eval_decoding_2.sh │ ├── eval_decoding_3.sh │ ├── eval_decoding_4.sh │ ├── eval_sentiment_zeroshot_pipeline.sh │ ├── prepare_dataset.sh │ ├── train_decoding.sh │ ├── train_decoding_1.sh │ ├── train_eeg_sentiment_baseline.sh │ ├── train_eval_zeroshot_pipeline.sh │ └── train_text_sentiment_classifier.sh ├── train_decoding.py ├── train_sentiment_baseline.py ├── train_sentiment_textbased.py └── util/ ├── construct_dataset_mat_to_pickle_v1.py ├── construct_dataset_mat_to_pickle_v2.py ├── data_loading_helpers_modified.py ├── get_SST_ternary_dataset.py └── get_sentiment_labels.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.pt *.pickle *.mat *.json *.txt # Byte-compiled / optimized / DLL files __pycache__/ csv_results/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ ================================================ FILE: README.md ================================================ The **main branch** contains the final code for our "Are EEG-to-Text Models Working?" paper. Accepted by [IJCAI workshop 2024](https://github.com/user-attachments/files/16624318/IJCAI_hyejeongjo_poster_Final.pdf) If you have any questions, you can write them in the Issues section or email Hyejeong Jo at girlsending0@khu.ac.kr. 
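
The correction documented in this README concerns teacher forcing at evaluation time. As a minimal sketch, assuming a hypothetical toy bigram decoder in place of the actual BART-based model (`BIGRAM`, `next_token`, `teacher_forced_decode`, `greedy_decode`, and `token_accuracy` are illustrative names, not part of this repository), the following compares decoding conditioned on the ground-truth prefix with decoding conditioned on the model's own previous predictions:

```python
# Toy bigram "decoder": maps the previous token to its most likely next token.
# It never looks at any EEG input, so any overlap with the target comes from
# the decoding procedure alone. All names and values here are hypothetical.
BIGRAM = {
    "<s>": "the",
    "the": "movie",
    "it": "was",
    "was": "a",
    "a": "movie",
    "movie": "was",
    "good": "movie",
}


def next_token(prev):
    """Return the toy model's prediction given the previous token."""
    return BIGRAM.get(prev, "the")


def teacher_forced_decode(target):
    """Predict each position from the ground-truth previous token."""
    decoder_inputs = ["<s>"] + target[:-1]
    return [next_token(tok) for tok in decoder_inputs]


def greedy_decode(length):
    """Predict each position from the model's own previous prediction."""
    out, prev = [], "<s>"
    for _ in range(length):
        prev = next_token(prev)
        out.append(prev)
    return out


def token_accuracy(pred, target):
    """Fraction of positions where prediction and target agree."""
    return sum(p == t for p, t in zip(pred, target)) / len(target)


target = ["it", "was", "a", "good", "movie"]
tf_pred = teacher_forced_decode(target)
free_pred = greedy_decode(len(target))

print("teacher forced:", tf_pred, token_accuracy(tf_pred, target))
print("free running:  ", free_pred, token_accuracy(free_pred, target))
```

In this toy setting the teacher-forced output overlaps with the target far more than the free-running output, even though the "model" ignores its input entirely. This is only an illustration of the effect; it says nothing about the actual scores of any model in this repository.
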
Check our new paper, with a full, detailed comparison of different models on this task, at [https://arxiv.org/abs/2405.06459](https://arxiv.org/abs/2405.06459)

overview
![image](https://github.com/NeuSpeech/EEG-To-Text/assets/151606332/57212488-b75f-44c7-a265-e2a51483e9f5)

performance
![image](https://github.com/NeuSpeech/EEG-To-Text/assets/151606332/df58870c-5277-4935-8c66-15efd58e9283)

# Correction on [(AAAI 2022) Open Vocabulary EEG-To-Text Decoding and Zero-shot sentiment classification](https://arxiv.org/abs/2112.02690)

# Results and code are updated on the **master** branch

**First of all, we are not pointing fingers at anyone. We make this correction not to offend, but as a kind reminder to be careful with the string generation process. We respect Mr. Wang very much and appreciate his great contribution to this area.**

After scrutinizing [the original code shared by Wang Zhenhailong](https://github.com/MikeWangWZHL/EEG-To-Text), we discovered that the evaluation method has an unintentional but very serious mistake in generating predicted strings: it implicitly uses teacher forcing, so each token is predicted from the ground-truth preceding tokens rather than from the model's own earlier output. The code in question is:

```python
seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch)
logits = seq2seqLMoutput.logits # bs*seq_len*voc_sz
probs = logits[0].softmax(dim = 1)
values, predictions = probs.topk(1)
predictions = torch.squeeze(predictions)
predicted_string = tokenizer.decode(predictions)
```

This results in [predictions like those below](https://github.com/MikeWangWZHL/EEG-To-Text/blob/main/results/task1_task2_taskNRv2-BrainTranslator_skipstep1-all_generation_results-7_22.txt#L61):

```
target string: It isn't that Stealing Harvard is a horrible movie -- if only it were that grand a failure!
predicted string: was't a the. is was a bad place, it it it were a.. movie.
################################################ target string: It just doesn't have much else... especially in a moral sense. predicted string: was so't work the to to and not the country sense. ################################################ target string: Those unfamiliar with Mormon traditions may find The Singles Ward occasionally bewildering. predicted string: who with the history may be themselves Mormoning''s amusingering. ################################################ target string: Viewed as a comedy, a romance, a fairy tale, or a drama, there's nothing remotely triumphant about this motion picture. predicted string: the from a whole, it film, and comedy tale, and a tragic, it is nothing quite romantic about it. picture. ################################################ target string: But the talented cast alone will keep you watching, as will the fight scenes. predicted string: the most and of cannot not the entertained. and they the music against. ################################################ target string: It's solid and affecting and exactly as thought-provoking as it should be. predicted string: was a, it, it what it.provoking as it is be. ################################################ target string: Thanks largely to Williams, all the interesting developments are processed in 60 minutes -- the rest is just an overexposed waste of film. predicted string: to to the, the of films and in in in a minutes. and longest is a a afteragerposure, of time time ################################################ target string: Cantet perfectly captures the hotel lobbies, two-lane highways, and roadside cafes that permeate Vincent's days predicted string: urtor was describes the spirit'sies and the ofstory streets, and the parking of areate the's life.'sgggggggg,,,,,,,,,,,,,,,,,,, ################################################ target string: An important movie, a reminder of the power of film to move us and to make us examine our values. 
predicted string: nie part in " classic of the importance of the, shape people, our make us think our lives,
################################################
target string: Too much of this well-acted but dangerously slow thriller feels like a preamble to a bigger, more complicated story, one that never materializes.
predicted string: bad of a is-known film not over- is like a film-ble to a much, more dramatic story. which that is endsizes.
```

In addition, we noticed that some researchers are using this code as a code base, which produces concerning results. We are not condemning these researchers; we just want to notify them, and perhaps we can work together to resolve this problem.

[BELT Bootstrapping Electroencephalography-to-Language Decoding and Zero-Shot SenTiment Classification by Natural Language Supervision](https://arxiv.org/pdf/2309.12056)

[Aligning Semantic in Brain and Language: A Curriculum Contrastive Method for Electroencephalography-to-Text Generation](https://ieeexplore.ieee.org/iel7/7333/4359219/10248031.pdf)

[UniCoRN: Unified Cognitive Signal ReconstructioN bridging cognitive signals and human language](https://arxiv.org/pdf/2307.05355)

[Semantic-aware Contrastive Learning for Electroencephalography-to-Text Generation with Curriculum Learning](https://arxiv.org/pdf/2301.09237)

[DeWave: Discrete EEG Waves Encoding for Brain Dynamics to Text Translation](https://arxiv.org/pdf/2309.14030)

We have written a corrected version that uses `model.generate` to evaluate the model; the results are not so good. Basically, we changed `model_decoding.py` and `eval_decoding.py` to add a `generate` method to the model, which was originally a plain `nn.Module`, and used `model.generate` to predict strings.

**Everyone is welcome to scrutinize and run this corrected code. We will then show the final performance of this model in this repo and write it up as a technical paper.**

# We really appreciate the great contribution made by Mr.
Wang, however, we should prevent others from continuing this misunderstanding. This work was supported by the Culture, Sports and Tourism R&D Program through the Korea Creative Content Agency grant funded by the Ministry of Culture, Sports and Tourism (RS-2023-00226263), the Institute for Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. RS-2024-00509257, Global AI Frontier Lab), the Information Technology Research Center (ITRC) support program (IITP-2024-RS-2024-00438239) supervised by the IITP, and the IITP grant funded by the Korea government (MSIT) (No. RS-2022-00155911, Artificial Intelligence Convergence Innovation Human Resources Development, Kyung Hee University). ================================================ FILE: config.py ================================================ import argparse def str2bool(v): if isinstance(v, bool): return v if v.lower() in ('yes', 'true', 't', 'y', '1'): return True elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: raise argparse.ArgumentTypeError('Boolean value expected.') def get_config(case): if case == 'train_decoding': # args config for training EEG-To-Text decoder parser = argparse.ArgumentParser(description='Specify config args for training EEG-To-Text decoder') parser.add_argument('-m', '--model_name', help='choose from {BrainTranslator, BrainTranslatorNaive}', default = "BrainTranslator" ,required=True) parser.add_argument('-t', '--task_name', help='choose from {task1,task1_task2, task1_task2_task3,task1_task2_taskNRv2}', default = "task1", required=True) parser.add_argument('-1step', '--one_step', dest='skip_step_one', action='store_true') parser.add_argument('-2step', '--two_step', dest='skip_step_one', action='store_false') parser.add_argument('-pre', '--pretrained', dest='use_random_init', action='store_false') parser.add_argument('-rand', '--rand_init', dest='use_random_init', action='store_true') 
parser.add_argument('-load1', '--load_step1_checkpoint', dest='load_step1_checkpoint', action='store_true') parser.add_argument('-no-load1', '--not_load_step1_checkpoint', dest='load_step1_checkpoint', action='store_false') parser.add_argument('-ne1', '--num_epoch_step1', type = int, help='num_epoch_step1', default = 20, required=True) parser.add_argument('-ne2', '--num_epoch_step2', type = int, help='num_epoch_step2', default = 30, required=True) parser.add_argument('-lr1', '--learning_rate_step1', type = float, help='learning_rate_step1', default = 0.00005, required=True) parser.add_argument('-lr2', '--learning_rate_step2', type = float, help='learning_rate_step2', default = 0.0000005, required=True) parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True) parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/decoding', required=True) parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False) parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False) parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify freqency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False) parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. 
cuda:0, cuda:1, etc', default = 'cuda:0') parser.add_argument('-train_input', '--train_input', help='add noise' ,required=True) args = vars(parser.parse_args()) elif case == 'train_sentiment_baseline': # args config for training EEG-based sentiment baselines parser = argparse.ArgumentParser(description='Specify config args for training EEG-To-Text decoder') parser.add_argument('-m', '--model_name', help='choose from {BaselineMLP, BaselineLSTM, NaiveFinetuneBert}', default = "NaiveFinetuneBert" ,required=True) parser.add_argument('-ne', '--num_epoch', type = int, help='num_epoch', default = 30, required=True) parser.add_argument('-lr', '--learning_rate', type = float, help='learning_rate', default = 0.00001, required=True) parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True) parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/eeg_sentiment', required=True) parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False) parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False) parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify freqency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False) parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. 
cuda:0, cuda:1, etc', default = 'cuda:0') args = vars(parser.parse_args()) elif case == 'train_sentiment_textbased': # args config for training text-based sentiment classification models parser = argparse.ArgumentParser(description='Specify config args for training text-based sentiment classifiers') parser.add_argument('-d', '--dataset_name', help='zero-shot setting: using external dataset from stanford sentiment treebank, pass in SST; to use ZuCo\'s own text-sentiment pairs, pass in ZuCo', default = "SST" ,required=True) parser.add_argument('-m', '--model_name', help='choose from {pretrain_Bert, pretrain_RoBerta, pretrain_Bart}', default = "pretrain_Bart" ,required=True) parser.add_argument('-ne', '--num_epoch', type = int, help='num_epoch', default = 20, required=True) parser.add_argument('-lr', '--learning_rate', type = float, help='learning_rate', default = 0.0001, required=True) parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True) parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/text_sentiment_classifier', required=True) parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False) parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False) parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify freqency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False) parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. 
cuda:0, cuda:1, etc', default = 'cuda:0') args = vars(parser.parse_args()) elif case == 'eval_decoding': # args config for evaluating EEG-To-Text decoder parser = argparse.ArgumentParser(description='Specify config args for evaluate EEG-To-Text decoder') parser.add_argument('-checkpoint', '--checkpoint_path', help='specify model checkpoint' ,required=True) parser.add_argument('-conf', '--config_path', help='specify training config json' ,required=True) parser.add_argument('-test_input', '--test_input', help='add noise' ,required=True) parser.add_argument('-train_input', '--train_input', help='add noise' ,required=True) parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0') args = vars(parser.parse_args()) elif case == 'eval_sentiment': # args config for sentiment classification models parser = argparse.ArgumentParser(description='Specify config args for evaluate EEG-based sentiment classification, including Zero-shot pipeline') # choose model_name = 'ZeroShotSentimentDiscovery' to evaluate Zero-shot pipeline parser.add_argument('-m', '--model_name', help='choose from {BaselineMLP, BaselineLSTM, NaiveFinetuneBert, FinetunedBertOnText, FinetunedRoBertaOnText, FinetunedBartOnText, ZeroShotSentimentDiscovery}', default = "ZeroShotSentimentDiscovery" ,required=True) parser.add_argument('-checkpoint', '--checkpoint_path', help='specify model checkpoint' ,required=False) # required if NOT evaluating Zero-shot pipeline parser.add_argument('-conf', '--config_path', help='specify model config json' ,required=False) # required if NOT evaluating Zero-shot pipeline parser.add_argument('-checkpoint_DEC', '--decoder_checkpoint_path', help='specify decoder checkpoint for Zero-shot pipeline ', required=False) # required if evaluating Zero-shot pipeline parser.add_argument('-checkpoint_CLS', '--classifier_checkpoint_path', help='specify classifier checkpoint for Zero-shot pipeline ', required=False) # required if 
evaluating Zero-shot pipeline parser.add_argument('-conf_DEC', '--decoder_config_path', help='specify decoder config json' ,required=False) # required if evaluating Zero-shot pipeline parser.add_argument('-conf_CLS', '--classifier_config_path', help='specify classifier config json' ,required=False) # required if evaluating Zero-shot pipeline parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False) parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False) parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify freqency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False) parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0') args = vars(parser.parse_args()) return args ================================================ FILE: data.py ================================================ import os import numpy as np import torch import pickle from torch.utils.data import Dataset, DataLoader import json import matplotlib.pyplot as plt from glob import glob from transformers import BartTokenizer, BertTokenizer from tqdm import tqdm from fuzzy_match import match from fuzzy_match import algorithims from transformers import T5Tokenizer # macro #ZUCO_SENTIMENT_LABELS = json.load(open('./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json')) #SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json')) def normalize_1d(input_tensor): # normalize a 1d tensor mean = torch.mean(input_tensor) std = torch.std(input_tensor) input_tensor = (input_tensor - mean)/std return input_tensor def get_input_sample(sent_obj, tokenizer, eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], max_len = 56, add_CLS_token = False, test_input="noise"): def get_word_embedding_eeg_tensor(word_obj, eeg_type, bands): 
frequency_features = [] for band in bands: frequency_features.append(word_obj['word_level_EEG'][eeg_type][eeg_type+band]) word_eeg_embedding = np.concatenate(frequency_features) if len(word_eeg_embedding) != 105*len(bands): print(f'expect word eeg embedding dim to be {105*len(bands)}, but got {len(word_eeg_embedding)}, return None') return None # assert len(word_eeg_embedding) == 105*len(bands) return_tensor = torch.from_numpy(word_eeg_embedding) return normalize_1d(return_tensor) def get_sent_eeg(sent_obj, bands): sent_eeg_features = [] for band in bands: key = 'mean'+band sent_eeg_features.append(sent_obj['sentence_level_EEG'][key]) sent_eeg_embedding = np.concatenate(sent_eeg_features) assert len(sent_eeg_embedding) == 105*len(bands) return_tensor = torch.from_numpy(sent_eeg_embedding) return normalize_1d(return_tensor) if sent_obj is None: # print(f' - skip bad sentence') return None input_sample = {} # get target label target_string = sent_obj['content'] target_tokenized = tokenizer(target_string, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True) input_sample['target_ids'] = target_tokenized['input_ids'][0] # get sentence level EEG features sent_level_eeg_tensor = get_sent_eeg(sent_obj, bands) # try: # sent_level_eeg_tensor = torch.from_numpy(sent_obj['sentence_level_EEG']) # This gives a dictionary # except: # return None if torch.isnan(sent_level_eeg_tensor).any(): # print('[NaN sent level eeg]: ', target_string) return None # if sent_level_eeg_tensor.shape[1] < 30: # return None input_sample['sent_level_EEG'] = sent_level_eeg_tensor #input_sample['sent_level_EEG'] = torch.randn(sent_level_eeg_tensor.size()) # random input code #print("NOISE:", input_sample['sent_level_EEG']) # get sentiment label # handle some wierd case if 'emp11111ty' in target_string: target_string = target_string.replace('emp11111ty','empty') if 'film.1' in target_string: target_string = target_string.replace('film.1','film.') 
#if target_string in ZUCO_SENTIMENT_LABELS: # input_sample['sentiment_label'] = torch.tensor(ZUCO_SENTIMENT_LABELS[target_string]+1) # 0:Negative, 1:Neutral, 2:Positive #else: # input_sample['sentiment_label'] = torch.tensor(-100) # dummy value input_sample['sentiment_label'] = torch.tensor(-100) # dummy value # get input embeddings word_embeddings = [] """add CLS token embedding at the front""" if add_CLS_token: word_embeddings.append(torch.ones(105*len(bands))) for word in sent_obj['word']: # add each word's EEG embedding as Tensors word_level_eeg_tensor = get_word_embedding_eeg_tensor(word, eeg_type, bands = bands) # check none, for v2 dataset if word_level_eeg_tensor is None: return None # check nan: if torch.isnan(word_level_eeg_tensor).any(): # print() # print('[NaN ERROR] problem sent:',sent_obj['content']) # print('[NaN ERROR] problem word:',word['content']) # print('[NaN ERROR] problem word feature:',word_level_eeg_tensor) # print() return None word_embeddings.append(word_level_eeg_tensor) # pad to max_len while len(word_embeddings) < max_len: word_embeddings.append(torch.zeros(105*len(bands))) if test_input=='noise': rand_eeg= torch.randn(torch.stack(word_embeddings).size()) input_sample['input_embeddings'] = rand_eeg # max_len * (105*num_bands) # print("rand_eeg:", rand_eeg) # print("input_embeddings:", input_sample['input_embeddings'].shape) else: input_sample['input_embeddings'] = torch.stack(word_embeddings) # max_len * (105*num_bands) print("EEG", input_sample['input_embeddings']) # mask out padding tokens input_sample['input_attn_mask'] = torch.zeros(max_len) # 0 is masked out if add_CLS_token: input_sample['input_attn_mask'][:len(sent_obj['word'])+1] = torch.ones(len(sent_obj['word'])+1) # 1 is not masked else: input_sample['input_attn_mask'][:len(sent_obj['word'])] = torch.ones(len(sent_obj['word'])) # 1 is not masked # mask out padding tokens reverted: handle different use case: this is for pytorch transformers 
input_sample['input_attn_mask_invert'] = torch.ones(max_len) # 1 is masked out if add_CLS_token: input_sample['input_attn_mask_invert'][:len(sent_obj['word'])+1] = torch.zeros(len(sent_obj['word'])+1) # 0 is not masked else: input_sample['input_attn_mask_invert'][:len(sent_obj['word'])] = torch.zeros(len(sent_obj['word'])) # 0 is not masked # mask out target padding for computing cross entropy loss input_sample['target_mask'] = target_tokenized['attention_mask'][0] input_sample['seq_len'] = len(sent_obj['word']) # clean 0 length data if input_sample['seq_len'] == 0: print('discard length zero instance: ', target_string) return None return input_sample class ZuCo_dataset(Dataset): def __init__(self, input_dataset_dicts, phase, tokenizer, subject = 'ALL', eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], setting = 'unique_sent', is_add_CLS_token = False, test_input='noise'): self.inputs = [] self.tokenizer = tokenizer if not isinstance(input_dataset_dicts,list): input_dataset_dicts = [input_dataset_dicts] print(f'[INFO]loading {len(input_dataset_dicts)} task datasets') for input_dataset_dict in input_dataset_dicts: if subject == 'ALL': subjects = list(input_dataset_dict.keys()) print('[INFO]using subjects: ', subjects) else: subjects = [subject] total_num_sentence = len(input_dataset_dict[subjects[0]]) train_divider = int(0.8*total_num_sentence) dev_divider = train_divider + int(0.1*total_num_sentence) print(f'train divider = {train_divider}') print(f'dev divider = {dev_divider}') if setting == 'unique_sent': # take first 80% as trainset, 10% as dev and 10% as test if phase == 'train': print('[INFO]initializing a train set...') for key in subjects: for i in range(train_divider): input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input) if input_sample is not None: self.inputs.append(input_sample) elif phase == 'dev': print('[INFO]initializing a 
dev set...') for key in subjects: for i in range(train_divider,dev_divider): input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input) if input_sample is not None: self.inputs.append(input_sample) elif phase == 'test': print('[INFO]initializing a test set...') for key in subjects: for i in range(dev_divider,total_num_sentence): input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input) if input_sample is not None: self.inputs.append(input_sample) elif setting == 'unique_subj': print('WARNING!!! only implemented for SR v1 dataset ') # subject ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW'] for train # subject ['ZMG'] for dev # subject ['ZPH'] for test if phase == 'train': print(f'[INFO]initializing a train set using {setting} setting...') for i in range(total_num_sentence): for key in ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH','ZKW']: input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token) if input_sample is not None: self.inputs.append(input_sample) if phase == 'dev': print(f'[INFO]initializing a dev set using {setting} setting...') for i in range(total_num_sentence): for key in ['ZMG']: input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token) if input_sample is not None: self.inputs.append(input_sample) if phase == 'test': print(f'[INFO]initializing a test set using {setting} setting...') for i in range(total_num_sentence): for key in ['ZPH']: input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token) if input_sample is not None: self.inputs.append(input_sample) print('++ adding task to dataset, now we have:', 
len(self.inputs)) print('[INFO]input tensor size:', self.inputs[0]['input_embeddings'].size()) print() def __len__(self): return len(self.inputs) def __getitem__(self, idx): input_sample = self.inputs[idx] return ( input_sample['input_embeddings'], input_sample['seq_len'], input_sample['input_attn_mask'], input_sample['input_attn_mask_invert'], input_sample['target_ids'], input_sample['target_mask'], input_sample['sentiment_label'], #input_sample['sent_level_EEG'] ) # keys: input_embeddings, input_attn_mask, input_attn_mask_invert, target_ids, target_mask, """for train classifier on stanford sentiment treebank text-sentiment pairs""" class SST_tenary_dataset(Dataset): def __init__(self, ternary_labels_dict, tokenizer, max_len = 56, balance_class = True): self.inputs = [] pos_samples = [] neg_samples = [] neu_samples = [] for key,value in ternary_labels_dict.items(): tokenized_inputs = tokenizer(key, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True) input_ids = tokenized_inputs['input_ids'][0] attn_masks = tokenized_inputs['attention_mask'][0] label = torch.tensor(value) # count: if value == 0: neg_samples.append((input_ids,attn_masks,label)) elif value == 1: neu_samples.append((input_ids,attn_masks,label)) elif value == 2: pos_samples.append((input_ids,attn_masks,label)) print(f'Original distribution:\n\tVery positive: {len(pos_samples)}\n\tNeutral: {len(neu_samples)}\n\tVery negative: {len(neg_samples)}') if balance_class: print(f'balance class to {min([len(pos_samples),len(neg_samples),len(neu_samples)])} each...') for i in range(min([len(pos_samples),len(neg_samples),len(neu_samples)])): self.inputs.append(pos_samples[i]) self.inputs.append(neg_samples[i]) self.inputs.append(neu_samples[i]) else: self.inputs = pos_samples + neg_samples + neu_samples def __len__(self): return len(self.inputs) def __getitem__(self, idx): input_sample = self.inputs[idx] return input_sample # keys: input_embeddings, 
input_attn_mask, input_attn_mask_invert, target_ids, target_mask, '''sanity test''' if __name__ == '__main__': check_dataset = 'stanford_sentiment' if check_dataset == 'ZuCo': whole_dataset_dicts = [] dataset_path_task1 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task1-SR/pickle/task1-SR-dataset-with-tokens_6-25.pickle' with open(dataset_path_task1, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) dataset_path_task2 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task2-NR/pickle/task2-NR-dataset-with-tokens_7-10.pickle' with open(dataset_path_task2, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) # dataset_path_task3 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task3-TSR/pickle/task3-TSR-dataset-with-tokens_7-10.pickle' # with open(dataset_path_task3, 'rb') as handle: # whole_dataset_dicts.append(pickle.load(handle)) dataset_path_task2_v2 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset-with-tokens_7-15.pickle' with open(dataset_path_task2_v2, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) print() for key in whole_dataset_dicts[0]: print(f'task2_v2, sentence num in {key}:',len(whole_dataset_dicts[0][key])) print() tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') dataset_setting = 'unique_sent' subject_choice = 'ALL' print(f'![Debug]using {subject_choice}') eeg_type_choice = 'GD' print(f'[INFO]eeg type {eeg_type_choice}') bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] print(f'[INFO]using bands {bands_choice}') train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting) dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting) test_set = ZuCo_dataset(whole_dataset_dicts, 'test', 
tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting) print('trainset size:',len(train_set)) print('devset size:',len(dev_set)) print('testset size:',len(test_set)) elif check_dataset == 'stanford_sentiment': tokenizer = BertTokenizer.from_pretrained('bert-base-cased') SST_dataset = SST_tenary_dataset(SST_SENTIMENT_LABELS, tokenizer) print('SST dataset size:',len(SST_dataset)) print(SST_dataset[0]) print(SST_dataset[1]) ================================================ FILE: environment.yml ================================================ name: EEGToText channels: - pytorch - anaconda - conda-forge - huggingface dependencies: - pytorch=1.9.0 - torchaudio=0.9.0 - cudatoolkit=11.1 - scipy=1.6.2 - h5py=3.4.0 - tqdm=4.62.0 - matplotlib=3.3.2 - transformers=4.6.1 - nltk=3.5 - pip=21.0.1 - pip: - fuzzy-match==0.0.1 - rouge==1.0.0 ================================================ FILE: eval_decoding.py ================================================ import os import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler import pickle import json import matplotlib.pyplot as plt from glob import glob import time import copy from tqdm import tqdm import torch.nn.functional as F import time from transformers import BertLMHeadModel, BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification, PegasusForConditionalGeneration, PegasusTokenizer, T5Tokenizer, T5ForConditionalGeneration, BertGenerationDecoder from data import ZuCo_dataset from model_decoding import BrainTranslator, BrainTranslatorNaive, T5Translator from nltk.translate.bleu_score import sentence_bleu, corpus_bleu from rouge import Rouge from config import get_config import evaluate 
from evaluate import load metric = evaluate.load("sacrebleu") cer_metric = load("cer") wer_metric = load("wer") def remove_text_after_token(text, token='</s>'): # find and remove the text after the given token token_index = text.find(token) if token_index != -1: # token found return text[:token_index] # return the text before the token return text # no token found: return the original text def eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path = './results/temp.txt' , score_results='./score_results/task.txt'): # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html start_time = time.time() model.eval() # Set model to evaluate mode target_tokens_list = [] target_string_list = [] pred_tokens_list = [] pred_string_list = [] pred_tokens_list_previous = [] pred_string_list_previous = [] with open(output_all_results_path,'w') as f: for input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels in tqdm(dataloaders['test']): # load in batch input_embeddings_batch = input_embeddings.to(device).float() # B, 56, 840 input_masks_batch = input_masks.to(device) # B, 56 target_ids_batch = target_ids.to(device) # B, 56 input_mask_invert_batch = input_mask_invert.to(device) # B, 56 target_tokens = tokenizer.convert_ids_to_tokens(target_ids_batch[0].tolist(), skip_special_tokens = True) target_string = tokenizer.decode(target_ids_batch[0], skip_special_tokens = True) # print('target ids tensor:',target_ids_batch[0]) # print('target ids:',target_ids_batch[0].tolist()) # print('target tokens:',target_tokens) # print('target string:',target_string) f.write(f'target string: {target_string}\n') # add to list for later calculate bleu metric target_tokens_list.append([target_tokens]) target_string_list.append(target_string) """replace padding ids in target_ids with -100""" target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100 # target_ids_batch_label = target_ids_batch.clone().detach() # 
target_ids_batch_label[target_ids_batch_label == tokenizer.pad_token_id] = -100 # Original code seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch) # (batch, time, n_class) logits_previous = seq2seqLMoutput.logits probs_previous = logits_previous[0].softmax(dim = 1) values_previous, predictions_previous = probs_previous.topk(1) predictions_previous = torch.squeeze(predictions_previous) predicted_string_previous = remove_text_after_token(tokenizer.decode(predictions_previous).split('</s></s>')[0].replace('<s>','')) f.write(f'predicted string with tf: {predicted_string_previous}\n') predictions_previous = predictions_previous.tolist() truncated_prediction_previous = [] for t in predictions_previous: if t != tokenizer.eos_token_id: truncated_prediction_previous.append(t) else: break pred_tokens_previous = tokenizer.convert_ids_to_tokens(truncated_prediction_previous, skip_special_tokens = True) pred_tokens_list_previous.append(pred_tokens_previous) pred_string_list_previous.append(predicted_string_previous) # Modified code predictions=model.generate(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch, max_length=56, num_beams=5, do_sample=True, repetition_penalty= 5.0, no_repeat_ngram_size = 2 # num_beams=5,encoder_no_repeat_ngram_size =1, # do_sample=True, top_k=15,temperature=0.5,num_return_sequences=5, # early_stopping=True ) predicted_string=tokenizer.batch_decode(predictions, skip_special_tokens=True)[0] # predicted_string=predicted_string.squeeze() predictions=tokenizer.encode(predicted_string) # print('predicted string:',predicted_string) f.write(f'predicted string: {predicted_string}\n') f.write(f'################################################\n\n\n') # convert to int list # predictions = predictions.tolist() # already a list 
truncated_prediction = [] for t in predictions: if t != tokenizer.eos_token_id: truncated_prediction.append(t) else: break pred_tokens = tokenizer.convert_ids_to_tokens(truncated_prediction, skip_special_tokens = True) # print('predicted tokens:',pred_tokens) pred_tokens_list.append(pred_tokens) pred_string_list.append(predicted_string) # pred_tokens_list.extend(pred_tokens) # pred_string_list.extend(predicted_string) # print('################################################') # print() # print(f"pred_string_list : {pred_string_list}") """ calculate corpus bleu score """ weights_list = [(1.0,),(0.5,0.5),(1./3.,1./3.,1./3.),(0.25,0.25,0.25,0.25)] corpus_bleu_scores = [] corpus_bleu_scores_previous = [] for weight in weights_list: # print('weight:',weight) corpus_bleu_score = corpus_bleu(target_tokens_list, pred_tokens_list, weights = weight) corpus_bleu_score_previous = corpus_bleu(target_tokens_list, pred_tokens_list_previous, weights = weight) corpus_bleu_scores.append(corpus_bleu_score) corpus_bleu_scores_previous.append(corpus_bleu_score_previous) print(f'corpus BLEU-{len(list(weight))} score:', corpus_bleu_score) print(f'corpus BLEU-{len(list(weight))} score with tf:', corpus_bleu_score_previous) """ calculate sacre bleu score """ reference_list = [[item] for item in target_string_list] #print(f'ref: {reference_list}') #print(f'pred: {prediction_list}') sacre_blue = metric.compute(predictions=pred_string_list, references=reference_list) sacre_blue_previous = metric.compute(predictions=pred_string_list_previous, references=reference_list) print("sacreblue score: ", sacre_blue, '\n') print("sacreblue score with tf: ", sacre_blue_previous) print() """ calculate rouge score """ rouge = Rouge() # pred_string_list = [item for sublist in pred_string_list for item in sublist] # pred_string_list = [item for sublist in pred_string_list for item in sublist] # pred_string_list_previous = [item for sublist in pred_string_list_previous for item in sublist] # rouge_scores = 
rouge.get_scores(pred_string_list, target_string_list, avg = True, ignore_empty=True) # rouge_scores_previous = rouge.get_scores(pred_string_list_previous, target_string_list, avg = True, ignore_empty=True) # print('rouge_scores: ', rouge_scores) # print('rouge_scores with tf:', rouge_scores_previous) # rouge_scores_previous = rouge.get_scores(pred_string_list_previous, target_string_list, avg = True, ignore_empty=True) # print('rouge_scores', rouge_scores) # print('previous rouge_scores', rouge_scores_previous) try: rouge_scores = rouge.get_scores(pred_string_list, target_string_list, avg = True, ignore_empty=True) except ValueError as e: rouge_scores = 'Hypothesis is empty' try: rouge_scores_previous = rouge.get_scores(pred_string_list_previous, target_string_list, avg = True, ignore_empty=True) except ValueError as e: rouge_scores_previous = 'Hypothesis is empty' print() print() """ calculate WER score """ #wer = WordErrorRate() wer_scores = wer_metric.compute(predictions=pred_string_list, references=target_string_list) wer_scores_previous = wer_metric.compute(predictions=pred_string_list_previous, references=target_string_list) print("WER score:", wer_scores) print("WER score with tf:", wer_scores_previous) """ calculate CER score """ cer_scores = cer_metric.compute(predictions=pred_string_list, references=target_string_list) cer_scores_previous = cer_metric.compute(predictions=pred_string_list_previous, references=target_string_list) print("CER score:", cer_scores) print("CER score with tf:", cer_scores_previous) end_time = time.time() print(f"Evaluation took {(end_time-start_time)/60} minutes to execute.") # score_results (only fix teacher-forcing) file_content = [ f'corpus_bleu_score = {corpus_bleu_scores}', f'sacre_blue_score = {sacre_blue}', f'rouge_scores = {rouge_scores}', f'wer_scores = {wer_scores}', f'cer_scores = {cer_scores}', f'corpus_bleu_score_with_tf = {corpus_bleu_scores_previous}', f'sacre_blue_score_with_tf = {sacre_blue_previous}', 
f'rouge_scores_with_tf = {rouge_scores_previous}', f'wer_scores_with_tf = {wer_scores_previous}', f'cer_scores_with_tf = {cer_scores_previous}', ] with open(score_results, "a") as file_results: for line in file_content: if isinstance(line, list): for item in line: file_results.write(str(item) + "\n") else: file_results.write(str(line) + "\n") if __name__ == '__main__': batch_size = 1 ''' get args''' args = get_config('eval_decoding') test_input = args['test_input'] print("test_input is:", test_input) train_input = args['train_input'] print("train_input is:", train_input) ''' load training config''' training_config = json.load(open(args['config_path'])) subject_choice = training_config['subjects'] print(f'[INFO]subjects: {subject_choice}') eeg_type_choice = training_config['eeg_type'] print(f'[INFO]eeg type: {eeg_type_choice}') bands_choice = training_config['eeg_bands'] print(f'[INFO]using bands: {bands_choice}') dataset_setting = 'unique_sent' task_name = training_config['task_name'] model_name = training_config['model_name'] if test_input == 'EEG' and train_input=='EEG': print("EEG and EEG") output_all_results_path = f'./results/{task_name}-{model_name}-all_decoding_results.txt' score_results = f'./score_results/{task_name}-{model_name}.txt' else: output_all_results_path = f'./results/{task_name}-{model_name}-{train_input}_{test_input}-all_decoding_results.txt' score_results = f'./score_results/{task_name}-{model_name}-{train_input}_{test_input}.txt' ''' set random seeds ''' seed_val = 20 #500 np.random.seed(seed_val) torch.manual_seed(seed_val) torch.cuda.manual_seed_all(seed_val) ''' set up device ''' # use cuda if torch.cuda.is_available(): dev = 0 else: dev = "cpu" # CUDA_VISIBLE_DEVICES=0,1,2,3 device = torch.device(dev) print(f'[INFO]using device {dev}') # task_name = 'task1_task2_task3' ''' set up dataloader ''' whole_dataset_dicts = [] if 'task1' in task_name: dataset_path_task1 = '/data/johj/ZuCo_data/task1-SR/task1_source.pkl' with 
open(dataset_path_task1, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) if 'task2' in task_name: dataset_path_task2 = '/data/johj/ZuCo_data/task2-NR/task2_source.pkl' with open(dataset_path_task2, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) if 'task3' in task_name: dataset_path_task3 = '/data/johj/ZuCo_data/task3-TSR/task3_source.pkl' with open(dataset_path_task3, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) if 'taskNRv2' in task_name: dataset_path_taskNRv2 = '/data/johj/ZuCo_data/task2-NR-2.0/taskNRv2_source.pkl' with open(dataset_path_taskNRv2, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) print() if model_name in ['BrainTranslator','BrainTranslatorNaive']: tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') elif model_name == 'PegasusTranslator': tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-xsum') elif model_name == 'T5Translator': tokenizer = T5Tokenizer.from_pretrained("t5-large") # tokenizer.set_prefix_tokens(language='english') # test dataset test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, test_input=test_input) dataset_sizes = {"test_set":len(test_set)} print('[INFO]test_set size: ', len(test_set)) # dataloaders test_dataloader = DataLoader(test_set, batch_size = batch_size, shuffle=False, num_workers=4) dataloaders = {'test':test_dataloader} ''' set up model ''' checkpoint_path = args['checkpoint_path'] if model_name == 'BrainTranslator': pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large') model = BrainTranslator(pretrained_bart, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048) elif model_name == 'BrainTranslatorNaive': pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large') model = 
BrainTranslatorNaive(pretrained_bart, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048) elif model_name == 'BertGeneration': pretrained = BertGenerationDecoder.from_pretrained('google-bert/bert-large-uncased', is_decoder = True) model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048) elif model_name == 'PegasusTranslator': pretrained = PegasusForConditionalGeneration.from_pretrained('google/pegasus-xsum') model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048) elif model_name == 'T5Translator': pretrained = T5ForConditionalGeneration.from_pretrained("t5-large") model = T5Translator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048) state_dict = torch.load(checkpoint_path) new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} model.load_state_dict(new_state_dict) ''' if isinstance(model, nn.DataParallel): model.module.load_state_dict(torch.load(checkpoint_path)) else: model.load_state_dict(torch.load(checkpoint_path)) ''' # model.load_state_dict(torch.load(checkpoint_path)) model.to(device) criterion = nn.CrossEntropyLoss() ''' eval ''' eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path = output_all_results_path, score_results=score_results) ================================================ FILE: eval_sentiment.py ================================================ import os import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler from torch.nn.utils.rnn 
import pack_padded_sequence import pickle import json import matplotlib.pyplot as plt from glob import glob import time import copy from tqdm import tqdm from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification from data import ZuCo_dataset from model_sentiment import BaselineMLPSentence, BaselineLSTM, FineTunePretrainedTwoStep, ZeroShotSentimentDiscovery, JointBrainTranslatorSentimentClassifier from model_decoding import BrainTranslator, BrainTranslatorNaive from sklearn.metrics import precision_recall_fscore_support from sklearn.metrics import accuracy_score from config import get_config # Function to calculate the accuracy of our predictions vs labels def flat_accuracy(preds, labels): # preds: numpy array: N * 3 # labels: numpy array: N pred_flat = np.argmax(preds, axis=1).flatten() labels_flat = labels.flatten() return np.sum(pred_flat == labels_flat) / len(labels_flat) def flat_accuracy_top_k(preds, labels,k): topk_preds = [] for pred in preds: topk = pred.argsort()[-k:][::-1] topk_preds.append(list(topk)) # print(topk_preds) topk_preds = list(topk_preds) right_count = 0 # print(len(labels)) for i in range(len(labels)): l = labels[i][0] if l in topk_preds[i]: right_count+=1 return right_count/len(labels) def eval_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')): def logits2PredString(logits, tokenizer): probs = logits[0].softmax(dim = 1) # print('probs size:', probs.size()) values, predictions = probs.topk(1) # print('predictions before squeeze:',predictions.size()) predictions = torch.squeeze(predictions) predict_string = tokenizer.decode(predictions) return predict_string # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html since = time.time() best_model_wts = 
copy.deepcopy(model.state_dict()) best_loss = 100000000000 best_acc = 0.0 total_pred_labels = np.array([]) total_true_labels = np.array([]) for epoch in range(1): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) # Each epoch has a training and validation phase for phase in ['test']: total_accuracy = 0.0 if phase == 'train': model.train() # Set model to training mode else: model.eval() # Set model to evaluate mode running_loss = 0.0 # Iterate over data. for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in dataloaders[phase]: input_word_eeg_features = input_word_eeg_features.to(device).float() input_masks = input_masks.to(device) input_mask_invert = input_mask_invert.to(device) sent_level_EEG = sent_level_EEG.to(device) sentiment_labels = sentiment_labels.to(device) target_ids = target_ids.to(device) target_mask = target_mask.to(device) ## forward ################### if isinstance(model, BaselineMLPSentence): logits = model(sent_level_EEG) # before softmax # calculate loss loss = criterion(logits, sentiment_labels) elif isinstance(model, BaselineLSTM): x_packed = pack_padded_sequence(input_word_eeg_features, seq_lens, batch_first=True, enforce_sorted=False) logits = model(x_packed) # calculate loss loss = criterion(logits, sentiment_labels) elif isinstance(model, BertForSequenceClassification) or isinstance(model, RobertaForSequenceClassification) or isinstance(model, BartForSequenceClassification): output = model(input_ids = target_ids, attention_mask = target_mask, return_dict = True, labels = sentiment_labels) logits = output.logits loss = output.loss elif isinstance(model, FineTunePretrainedTwoStep): output = model(input_word_eeg_features, input_masks, input_mask_invert, sentiment_labels) logits = output.logits loss = output.loss elif isinstance(model, ZeroShotSentimentDiscovery): print() print('target 
string:',tokenizer.decode(target_ids[0]).replace('<pad>','').split('</s>')[0]) """replace padding ids in target_ids with -100""" target_ids[target_ids == tokenizer.pad_token_id] = -100 output = model(input_word_eeg_features, input_masks, input_mask_invert, target_ids, sentiment_labels) logits = output.logits loss = output.loss elif isinstance(model, JointBrainTranslatorSentimentClassifier): print() print('target string:',tokenizer.decode(target_ids[0]).replace('<pad>','').split('</s>')[0]) """replace padding ids in target_ids with -100""" target_ids[target_ids == tokenizer.pad_token_id] = -100 LM_output, classification_output = model(input_word_eeg_features, input_masks, input_mask_invert, target_ids, sentiment_labels) LM_logits = LM_output.logits print('pred string:', logits2PredString(LM_logits, tokenizer).split('</s></s>')[0].replace('<s>','')) classification_loss = classification_output['loss'] logits = classification_output['logits'] loss = classification_loss ############################### # backward + optimize only if in training phase if phase == 'train': # with torch.autograd.detect_anomaly(): loss.backward() optimizer.step() # calculate accuracy preds_cpu = logits.detach().cpu().numpy() label_cpu = sentiment_labels.cpu().numpy() total_accuracy += flat_accuracy(preds_cpu, label_cpu) # add to total pred and label array, for cal F1, precision, recall pred_flat = np.argmax(preds_cpu, axis=1).flatten() labels_flat = label_cpu.flatten() total_pred_labels = np.concatenate((total_pred_labels,pred_flat)) total_true_labels = np.concatenate((total_true_labels,labels_flat)) # statistics running_loss += loss.item() * sent_level_EEG.size()[0] # batch loss # print('[DEBUG]loss:',loss.item()) # print('#################################') if phase == 'train': scheduler.step() epoch_loss = running_loss / dataset_sizes[phase] epoch_acc = total_accuracy / len(dataloaders[phase]) print('{} Loss: {:.4f}'.format(phase, epoch_loss)) print('{} Acc: {:.4f}'.format(phase, epoch_acc)) # deep copy the model if 
phase == 'test' and epoch_loss < best_loss: best_loss = epoch_loss best_acc = epoch_acc print() time_elapsed = time.time() - since print('Evaluation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) print('Best test loss: {:.4f}'.format(best_loss)) print('Best test acc: {:.4f}'.format(best_acc)) print() print('test sample num:', len(total_pred_labels)) print('total preds:',total_pred_labels) print('total truth:',total_true_labels) print('sklearn macro: precision, recall, F1:') print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='macro')) print() print('sklearn micro: precision, recall, F1:') print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='micro')) print() print('sklearn accuracy:') print(accuracy_score(total_true_labels,total_pred_labels)) print() if __name__ == '__main__': args = get_config('eval_sentiment') ''' config param''' num_epochs = 1 dataset_setting = 'unique_sent' '''model name''' # model_name = 'BaselineMLP' # model_name = 'BaselineLSTM' # model_name = 'NaiveFinetuneBert' # model_name = 'FinetunedBertOnText' # model_name = 'FinetunedRoBertaOnText' # model_name = 'FinetunedBartOnText' # model_name = 'ZeroShotSentimentDiscovery' model_name = args['model_name'] print(f'[INFO] eval {model_name}') if model_name == 'ZeroShotSentimentDiscovery': '''load decoder and classifier config''' config_decoder = json.load(open(args['decoder_config_path'])) config_classifier = json.load(open(args['classifier_config_path'])) '''choose generator''' # decoder_name = 'BrainTranslator' # decoder_name = 'BrainTranslatorNaive' decoder_name = config_decoder['model_name'] decoder_checkpoint = args['decoder_checkpoint_path'] print(f'[INFO] using decoder: {decoder_name}') '''choose classifier''' # pretrain_Bert, pretrain_RoBerta, pretrain_Bart classifier_name = config_classifier['model_name'] classifier_checkpoint = args['classifier_checkpoint_path'] print(f'[INFO] using classifier: 
{classifier_name}') else: checkpoint_path = args['checkpoint_path'] print('[INFO] loading baseline:', checkpoint_path) batch_size = 1 # subject_choice = 'ALL subject_choice = args['subjects'] print(f'![Debug]using {subject_choice}') # eeg_type_choice = 'GD eeg_type_choice = args['eeg_type'] print(f'[INFO]eeg type {eeg_type_choice}') # bands_choice = ['_t1'] # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] bands_choice = args['eeg_bands'] print(f'[INFO]using bands {bands_choice}') ''' set random seeds ''' seed_val = 312 np.random.seed(seed_val) torch.manual_seed(seed_val) torch.cuda.manual_seed_all(seed_val) ''' set up device ''' # use cuda if torch.cuda.is_available(): dev = args['cuda'] else: dev = "cpu" # CUDA_VISIBLE_DEVICES=0,1,2,3 device = torch.device(dev) print(f'[INFO]using device {dev}') ''' load pickle''' whole_dataset_dict = [] dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle' with open(dataset_path_task1, 'rb') as handle: whole_dataset_dict.append(pickle.load(handle)) '''set up tokenizer''' if model_name in ['BaselineMLP','BaselineLSTM', 'NaiveFinetuneBert', 'FinetunedBertOnText']: print('[INFO]using Bert tokenizer') tokenizer = BertTokenizer.from_pretrained('bert-base-cased') elif model_name == 'FinetunedBartOnText': print('[INFO]using Bart tokenizer') tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') elif model_name == 'FinetunedRoBertaOnText': print('[INFO]using RoBerta tokenizer') tokenizer = RobertaTokenizer.from_pretrained('roberta-base') elif model_name == 'ZeroShotSentimentDiscovery': decoder_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') # Bart tokenizer = decoder_tokenizer if classifier_name == 'pretrain_Bert': sentiment_tokenizer = BertTokenizer.from_pretrained('bert-base-cased') # Bert elif classifier_name == 'pretrain_Bart': sentiment_tokenizer = decoder_tokenizer elif classifier_name == 'pretrain_RoBerta': sentiment_tokenizer = 
RobertaTokenizer.from_pretrained('roberta-base') ''' set up model ''' if model_name == 'BaselineMLP': print('[INFO]Model: BaselineMLP') model = BaselineMLPSentence(input_dim = 840, hidden_dim = 128, output_dim = 3) elif model_name == 'BaselineLSTM': print('[INFO]Model: BaselineLSTM') # model = BaselineLSTM(input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 1) model = BaselineLSTM(input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 4) elif model_name == 'FinetunedBertOnText': print('[INFO]Model: FinetunedBertOnText') model = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3) elif model_name == 'FinetunedRoBertaOnText': print('[INFO]Model: FinetunedRoBertaOnText') model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3) elif model_name == 'FinetunedBartOnText': print('[INFO]Model: FinetunedBartOnText') model = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels=3) elif model_name == 'ZeroShotSentimentDiscovery': print(f'[INFO]Model: ZeroShotSentimentDiscovery, using classifer:{classifier_name}, using generator: {decoder_name}') pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large') if decoder_name == 'BrainTranslator': decoder = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048) elif decoder_name == 'BrainTranslatorNaive': decoder = BrainTranslatorNaive(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048) decoder.load_state_dict(torch.load(decoder_checkpoint)) if classifier_name == 'pretrain_Bert': classifier = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3) elif classifier_name == 'pretrain_Bart': classifier = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels=3) 
elif classifier_name == 'pretrain_RoBerta': classifier = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3) classifier.load_state_dict(torch.load(classifier_checkpoint)) model = ZeroShotSentimentDiscovery(decoder, classifier, decoder_tokenizer, sentiment_tokenizer, device = device) model.to(device) if model_name != 'ZeroShotSentimentDiscovery': # load model and send to device model.load_state_dict(torch.load(checkpoint_path)) model.to(device) ''' set up dataloader ''' # test dataset test_set = ZuCo_dataset(whole_dataset_dict, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = 'unique_sent') dataset_sizes = {'test': len(test_set)} # print('[INFO]train_set size: ', len(train_set)) print('[INFO]test_set size: ', len(test_set)) test_dataloader = DataLoader(test_set, batch_size = 1, shuffle=False, num_workers=4) # dataloaders dataloaders = {'test':test_dataloader} ''' set up optimizer and scheduler''' optimizer_step1 = None exp_lr_scheduler_step1 = None ''' set up loss function ''' criterion = nn.CrossEntropyLoss() print('=== start training ... 
===') # return best loss model from step1 training model = eval_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs, tokenizer = tokenizer) ================================================ FILE: model_decoding.py ================================================ import torch.nn as nn import torch.nn.functional as F import torch.utils.data from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig import math import numpy as np """ main architecture for open vocabulary EEG-To-Text decoding""" class BrainTranslator(nn.Module): def __init__(self, pretrained_layers, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048): super(BrainTranslator, self).__init__() self.pretrained = pretrained_layers # additional transformer encoder, following BART paper about self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True) self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6) # print('[INFO]adding positional embedding') # self.positional_embedding = PositionalEncoding(in_feature) self.fc1 = nn.Linear(in_feature, decoder_embedding_size) def addin_forward(self,input_embeddings_batch, input_masks_invert): """input_embeddings_batch: batch_size*Seq_len*840""" """input_mask: 1 is not masked, 0 is masked""" """input_masks_invert: 1 is masked, 0 is not masked""" # input_embeddings_batch = self.positional_embedding(input_embeddings_batch) # use src_key_padding_masks encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask=input_masks_invert) # encoded_embedding = self.additional_encoder(input_embeddings_batch) encoded_embedding = F.relu(self.fc1(encoded_embedding)) return encoded_embedding @torch.no_grad() def generate( self, input_embeddings_batch, 
        input_masks_batch,
        input_masks_invert,
        target_ids_batch_converted,
        generation_config = None,
        logits_processor = None,
        stopping_criteria = None,
        prefix_allowed_tokens_fn= None,
        synced_gpus= None,
        assistant_model = None,
        streamer= None,
        negative_prompt_ids= None,
        negative_prompt_attention_mask = None,
        **kwargs,
    ):
        encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)
        output=self.pretrained.generate(
            inputs_embeds = encoded_embedding,
            attention_mask = input_masks_batch[:,:encoded_embedding.shape[1]],
            labels = target_ids_batch_converted,
            return_dict = True,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            **kwargs,)

        return output

    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
        encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)
        # print(f'forward:{input_embeddings_batch.shape,input_masks_batch.shape,input_masks_invert.shape,target_ids_batch_converted.shape,encoded_embedding.shape}')
        out = self.pretrained(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch,
                              return_dict = True, labels = target_ids_batch_converted)

        return out


from transformers import T5Tokenizer
""" main architecture for open vocabulary EEG-To-Text decoding"""
class T5Translator(nn.Module):
    def __init__(self, pretrained_layers, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):
        super(T5Translator, self).__init__()

        self.pretrained = pretrained_layers
        self.tokenizer = T5Tokenizer.from_pretrained("t5-large")
        # additional transformer encoder, following BART paper about
        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)

        # print('[INFO]adding positional embedding')
        # self.positional_embedding = PositionalEncoding(in_feature)

        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)

    def addin_forward(self,input_embeddings_batch, input_masks_invert):
        """input_embeddings_batch: batch_size*Seq_len*840"""
        """input_mask: 1 is not masked, 0 is masked"""
        """input_masks_invert: 1 is masked, 0 is not masked"""

        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)
        # use src_key_padding_masks
        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask=input_masks_invert)

        # encoded_embedding = self.additional_encoder(input_embeddings_batch)
        encoded_embedding = F.relu(self.fc1(encoded_embedding))
        return encoded_embedding

    @torch.no_grad()
    def generate(
        self,
        input_embeddings_batch,
        input_masks_batch,
        input_masks_invert,
        target_ids_batch_converted,
        generation_config = None,
        logits_processor = None,
        stopping_criteria = None,
        prefix_allowed_tokens_fn= None,
        synced_gpus= None,
        assistant_model = None,
        streamer= None,
        negative_prompt_ids= None,
        negative_prompt_attention_mask = None,
        **kwargs,
    ):
        encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)
        input_ids = self.tokenizer("transcribe in English: ", return_tensors="pt").input_ids.to(encoded_embedding.device)
        self.task_embedding = self.pretrained.shared(input_ids).to(encoded_embedding.device)
        task_embedding = self.task_embedding.repeat(encoded_embedding.size(0), 1, 1).to(encoded_embedding.device)
        encoded_embedding = torch.cat((task_embedding, encoded_embedding), dim=1)
        input_masks_batch = torch.cat((torch.ones(encoded_embedding.size(0), task_embedding.size(1)).to(encoded_embedding.device), input_masks_batch), dim=1)
        output=self.pretrained.generate(
            inputs_embeds = encoded_embedding,
            attention_mask = input_masks_batch[:,:encoded_embedding.shape[1]],
            labels = target_ids_batch_converted,
            return_dict = True,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            **kwargs,)

        return output

    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
        encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)
        # task definition
        input_ids = self.tokenizer("transcribe in English: ", return_tensors="pt").input_ids.to(encoded_embedding.device)
        self.task_embedding = self.pretrained.shared(input_ids).to(encoded_embedding.device)
        task_embedding = self.task_embedding.repeat(encoded_embedding.size(0), 1, 1).to(encoded_embedding.device)
        encoded_embedding = torch.cat((task_embedding, encoded_embedding), dim=1)
        input_masks_batch = torch.cat((torch.ones(encoded_embedding.size(0), task_embedding.size(1)).to(encoded_embedding.device), input_masks_batch), dim=1)
        out = self.pretrained(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted)

        return out


""" crippled open vocabulary EEG-To-Text decoding model w/o additional MTE encoder"""
class BrainTranslatorNaive(nn.Module):
    def __init__(self, pretrained_layers, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):
        super(BrainTranslatorNaive, self).__init__()
        '''no additional transformer encoder version'''
        self.pretrained = pretrained_layers
        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)

    def forward(self, input_embeddings_batch, input_masks_batch,
                input_masks_invert, target_ids_batch_converted):
        """input_embeddings_batch: batch_size*Seq_len*840"""
        """input_mask: 1 is not masked, 0 is masked"""
        """input_masks_invert: 1 is masked, 0 is not masked"""
        encoded_embedding = F.relu(self.fc1(input_embeddings_batch))
        out = self.pretrained(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted)
        return out


""" helper modules """
# modified from BertPooler
class Pooler(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # expects x of shape (seq_len, batch, d_model): dim 0 is the sequence position
        # print('[DEBUG] input size:', x.size())
        # print('[DEBUG] positional embedding size:', self.pe.size())
        x = x + self.pe[:x.size(0), :]
        # print('[DEBUG] output x with pe size:', x.size())
        return self.dropout(x)


""" Miscellaneous (not working well) """
class BrainTranslatorBert(nn.Module):
    def __init__(self, pretrained_layers, in_feature = 840, hidden_size = 768):
        super(BrainTranslatorBert, self).__init__()

        self.pretrained_Bert = pretrained_layers
        self.fc1 = nn.Linear(in_feature, hidden_size)

    def forward(self, input_embeddings_batch, input_masks_batch, target_ids_batch):
        embedding = F.relu(self.fc1(input_embeddings_batch))
        out = self.pretrained_Bert(inputs_embeds = embedding, attention_mask = input_masks_batch, labels = target_ids_batch, return_dict = True)
        return out

class EEG2BertMapping(nn.Module):
    def __init__(self, in_feature = 840, hidden_size = 512, out_feature = 768):
        super(EEG2BertMapping, self).__init__()
        self.fc1 = nn.Linear(in_feature, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out_feature)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(out)
        return out

class ContrastiveBrainTextEncoder(nn.Module):
    def __init__(self, pretrained_text_encoder, in_feature = 840, eeg_encoder_nhead=8, eeg_encoder_dim_feedforward = 2048, embed_dim = 768):
        super(ContrastiveBrainTextEncoder, self).__init__()
        # EEG Encoder
        self.positional_embedding = PositionalEncoding(in_feature)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=eeg_encoder_nhead, dim_feedforward = eeg_encoder_dim_feedforward, batch_first=True)
        self.EEG_Encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        self.EEG_pooler = Pooler(in_feature)
        self.ln_final = nn.LayerNorm(in_feature) # to be considered

        # project to text embedding
        self.EEG_projection = nn.Parameter(torch.empty(in_feature, embed_dim))
        # torch.empty leaves the values uninitialized; draw them from a small normal instead
        nn.init.normal_(self.EEG_projection, std=in_feature ** -0.5)

        # Text Encoder
        self.TextEncoder = pretrained_text_encoder

        # learned temperature parameter
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_EEG_features, input_EEG_attn_mask, input_ids, input_text_attention_masks):
        # add positional embedding; PositionalEncoding indexes dim 0 as the sequence
        # position, so move the batch_first input to (seq_len, N, 840) and back
        input_EEG_features = self.positional_embedding(input_EEG_features.transpose(0, 1)).transpose(0, 1)
        # get EEG feature embedding
        EEG_hiddenstates = self.EEG_Encoder(input_EEG_features, src_key_padding_mask = input_EEG_attn_mask)
        EEG_hiddenstates = self.ln_final(EEG_hiddenstates)
        EEG_features = self.EEG_pooler(EEG_hiddenstates) # [N, 840]

        # project to text embed size
        EEG_features = EEG_features @ self.EEG_projection # [N, 768]

        # get text feature embedding
        Text_features = self.TextEncoder(input_ids = input_ids, attention_mask = input_text_attention_masks, return_dict = True).pooler_output # [N, 768]

        # normalized features
        EEG_features = EEG_features / EEG_features.norm(dim=-1, keepdim=True) # [N, 768]
        Text_features = Text_features / Text_features.norm(dim=-1, keepdim=True) # [N, 768]

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_EEG = logit_scale * EEG_features @ Text_features.t() # [N, N]
        logits_per_text = logit_scale * Text_features @ EEG_features.t() # [N, N]

        return logits_per_EEG, logits_per_text


================================================
FILE: model_sentiment.py
================================================
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BertForSequenceClassification
import math
import numpy as np
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

"""MLP baseline using sentence level eeg"""
# using sent level EEG, MLP baseline for sentiment
class BaselineMLPSentence(nn.Module):
    def __init__(self, input_dim = 840, hidden_dim = 128, output_dim = 3):
        super(BaselineMLPSentence, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim, output_dim) # positive, negative, neutral
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.dropout(out)
        out = self.fc3(out)
        return out

"""bidirectional LSTM baseline using word level eeg"""
class BaselineLSTM(nn.Module):
    def __init__(self, input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 1):
        super(BaselineLSTM, self).__init__()

        self.hidden_dim = hidden_dim

        # pass num_layers through instead of hard-coding 1 (the default is unchanged)
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers = num_layers, batch_first = True, bidirectional = True)

        self.hidden2sentiment = nn.Linear(hidden_dim*2, output_dim)

    def forward(self, x_packed):
        # input: (N,seq_len,input_dim)
        # print(x_packed.data.size())
        lstm_out, _ = self.lstm(x_packed)
        # unpack: padded_out is (N, max_len, 2*hidden_dim), lengths holds each true sequence length
        padded_out, lengths = pad_packed_sequence(lstm_out, batch_first = True)
        # take the output at each sequence's last valid step;
        # indexing [:,-1,:] would read zero padding for every sequence shorter than max_len
        batch_idx = torch.arange(padded_out.size(0), device = padded_out.device)
        last_hidden_state = padded_out[batch_idx, (lengths - 1).to(padded_out.device)]
        # print(last_hidden_state.size())
        out = self.hidden2sentiment(last_hidden_state)
        return out

""" Bert Baseline: Finetuning from a pretrained language model Bert"""
class NaiveFineTunePretrainedBert(nn.Module):
    def __init__(self, input_dim = 840, hidden_dim = 768, output_dim = 3, pretrained_checkpoint = None):
        super(NaiveFineTunePretrainedBert, self).__init__()
        # mapping hidden states dimension
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.pretrained_Bert = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=output_dim)

        if pretrained_checkpoint is not None:
            self.pretrained_Bert.load_state_dict(torch.load(pretrained_checkpoint))

    def forward(self, input_embeddings_batch, input_masks_batch, labels):
        embedding = F.relu(self.fc1(input_embeddings_batch))
        out = self.pretrained_Bert(inputs_embeds = embedding, attention_mask = input_masks_batch, labels = labels, return_dict = True)
        return out

""" Finetuning from a pretrained language model BART, two step training"""
class FineTunePretrainedTwoStep(nn.Module):
    def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):
        super(FineTunePretrainedTwoStep, self).__init__()

        self.pretrained_layers = pretrained_layers
        # additional transformer encoder, following BART paper about
        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)

        # NOTE: add positional embedding?
        # print('[INFO]adding positional embedding')
        # self.positional_embedding = PositionalEncoding(in_feature)

        self.fc1 = nn.Linear(in_feature, d_model)

    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, labels):
        """input_embeddings_batch: batch_size*Seq_len*840"""
        """input_mask: 1 is not masked, 0 is masked"""
        """input_masks_invert: 1 is masked, 0 is not masked"""
        """labels: sentiment labels 0,1,2"""

        # NOTE: add positional embedding?
        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)

        # use src_key_padding_masks
        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert)

        # encoded_embedding = self.additional_encoder(input_embeddings_batch)
        encoded_embedding = F.relu(self.fc1(encoded_embedding))
        out = self.pretrained_layers(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = labels)

        return out

""" Zero-shot sentiment discovery using a finetuned generation model and a sentiment model pretrained on text """
class ZeroShotSentimentDiscovery(nn.Module):
    def __init__(self, brain2text_translator, sentiment_classifier, translation_tokenizer, sentiment_tokenizer, device = 'cpu'):
        # only for inference
        super(ZeroShotSentimentDiscovery, self).__init__()

        self.brain2text_translator = brain2text_translator
        self.sentiment_classifier = sentiment_classifier
        self.translation_tokenizer = translation_tokenizer
        self.sentiment_tokenizer = sentiment_tokenizer
        self.device = device

    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):
        """input_embeddings_batch: batch_size*Seq_len*840"""
        """input_mask: 1 is not masked, 0 is masked"""
        """input_masks_invert: 1 is masked, 0 is not masked"""
        """labels: sentiment labels 0,1,2"""

        def logits2PredString(logits):
            probs = logits[0].softmax(dim = 1)
            # print('probs size:', probs.size())
            values, predictions = probs.topk(1)
            #
            # print('predictions before squeeze:',predictions.size())
            predictions = torch.squeeze(predictions)
            predict_string = self.translation_tokenizer.decode(predictions)
            return predict_string

        # only works when the batch size is one
        assert input_embeddings_batch.size()[0] == 1

        seq2seqLMoutput = self.brain2text_translator(input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted)
        predict_string = logits2PredString(seq2seqLMoutput.logits)
        # cut at the end-of-sequence marker and drop the start token (BART special tokens)
        predict_string = predict_string.split('</s></s>')[0]
        predict_string = predict_string.replace('<s>','')
        print('predict string:', predict_string)
        re_tokenized = self.sentiment_tokenizer(predict_string, return_tensors='pt', return_attention_mask = True)
        input_ids = re_tokenized['input_ids'].to(self.device) # batch = 1
        attn_mask = re_tokenized['attention_mask'].to(self.device) # batch = 1

        out = self.sentiment_classifier(input_ids = input_ids, attention_mask = attn_mask, return_dict = True, labels = sentiment_labels)

        return out


""" Miscellaneous: jointly learn generation and classification (not working well) """
class BartClassificationHead(nn.Module):
    # from transformers: https://huggingface.co/transformers/_modules/transformers/models/bart/modeling_bart.html
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states

class JointBrainTranslatorSentimentClassifier(nn.Module):
    def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048,
                 num_labels = 3):
        super(JointBrainTranslatorSentimentClassifier, self).__init__()

        self.pretrained_generator = pretrained_layers
        # additional transformer encoder, following BART paper about
        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
        self.fc1 = nn.Linear(in_feature, d_model)
        self.num_labels = num_labels

        self.pooler = Pooler(d_model)
        self.classifier = BartClassificationHead(input_dim = d_model, inner_dim = d_model, num_classes = num_labels, pooler_dropout = pretrained_layers.config.classifier_dropout)

    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):
        """input_embeddings_batch: batch_size*Seq_len*840"""
        """input_mask: 1 is not masked, 0 is masked"""
        """input_masks_invert: 1 is masked, 0 is not masked"""

        # NOTE: add positional embedding?
        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)

        # use src_key_padding_masks
        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert)

        # encoded_embedding = self.additional_encoder(input_embeddings_batch)
        encoded_embedding = F.relu(self.fc1(encoded_embedding))
        LMoutput = self.pretrained_generator(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted, output_hidden_states = True)
        hidden_states = LMoutput.decoder_hidden_states # N, seq_len, hidden_dim
        # print('hidden states len:', len(hidden_states))
        last_hidden_states = hidden_states[-1]
        # print('last hidden states size:', last_hidden_states.size())
        sentence_representation = self.pooler(last_hidden_states)

        classification_logits = self.classifier(sentence_representation)
        loss_fct = nn.CrossEntropyLoss()
        classification_loss = loss_fct(classification_logits.view(-1, self.num_labels), sentiment_labels.view(-1))
        classification_output = {'loss':classification_loss,'logits':classification_logits}
        # print('successful one forward!!!!')
        return LMoutput, classification_output


""" helper modules """
# modified from BertPooler
class Pooler(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # print('[DEBUG] input size:', x.size())
        # print('[DEBUG] positional embedding size:', self.pe.size())
        x = x + self.pe[:x.size(0), :]
        # print('[DEBUG] output x with pe size:', x.size())
        return self.dropout(x)


================================================
FILE: scripts/eval_decoding_1.sh
================================================
CUDA_VISIBLE_DEVICES=0 python3 eval_decoding.py \
    --checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \
    --config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \
    --test_input EEG \
    --train_input EEG \
    -cuda cuda:0

CUDA_VISIBLE_DEVICES=0 python3 eval_decoding.py \
    --checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \
    --config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \
    --test_input noise \
    --train_input EEG \
    -cuda cuda:0


================================================
FILE: scripts/eval_decoding_2.sh
================================================
CUDA_VISIBLE_DEVICES=1 \
python3 eval_decoding.py \
    --checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \
    --config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \
    --test_input EEG \
    --train_input EEG \
    -cuda cuda:0

CUDA_VISIBLE_DEVICES=1 python3 eval_decoding.py \
    --checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \
    --config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \
    --test_input noise \
    --train_input EEG \
    -cuda cuda:0


================================================
FILE: scripts/eval_decoding_3.sh
================================================
CUDA_VISIBLE_DEVICES=2 python3 eval_decoding.py \
    --checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \
    --config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \
    --test_input EEG \
    --train_input noise \
    -cuda cuda:0

CUDA_VISIBLE_DEVICES=2 python3 eval_decoding.py \
    --checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \
    --config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \
    --test_input noise \
    --train_input noise \
    -cuda cuda:0


================================================
FILE: scripts/eval_decoding_4.sh
================================================
CUDA_VISIBLE_DEVICES=3 python3 eval_decoding.py \
    --checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \
    --config_path \
    config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \
    --test_input EEG \
    --train_input noise \
    -cuda cuda:0

CUDA_VISIBLE_DEVICES=3 python3 eval_decoding.py \
    --checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \
    --config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \
    --test_input noise \
    --train_input noise \
    -cuda cuda:0


================================================
FILE: scripts/eval_sentiment_zeroshot_pipeline.sh
================================================
python3 eval_sentiment.py --model_name ZeroShotSentimentDiscovery \
    --decoder_checkpoint_path ./checkpoints/decoding/best/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.pt \
    --classifier_checkpoint_path ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt \
    --decoder_config_path ./config/decoding/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.json \
    --classifier_config_path ./config/text_sentiment_classifier/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.json \
    --cuda cuda:0


================================================
FILE: scripts/prepare_dataset.sh
================================================
echo "This script constructs .pickle files from the .mat files of the ZuCo dataset."
echo "This script also generates the ternary sentiment_labels.json file for ZuCo task1-SR v1.0 and ternary_dataset.json from the filtered StanfordSentimentTreebank"
echo "Note: the sentences in ZuCo task1-SR do not overlap with sentences in filtered StanfordSentimentTreebank "
echo "Note: This process can take time, please be patient..."

python3 ./util/construct_dataset_mat_to_pickle_v1.py -t task1-SR
python3 ./util/construct_dataset_mat_to_pickle_v1.py -t task2-NR
python3 ./util/construct_dataset_mat_to_pickle_v1.py -t task3-TSR
python3 ./util/construct_dataset_mat_to_pickle_v2.py

python3 ./util/get_sentiment_labels.py
python3 ./util/get_SST_ternary_dataset.py


================================================
FILE: scripts/train_decoding.sh
================================================
CUDA_VISIBLE_DEVICES=0 python3 train_decoding.py --model_name BrainTranslator \
    --task_name task1_task2_task3 \
    --one_step \
    --pretrained \
    --not_load_step1_checkpoint \
    --num_epoch_step1 20 \
    --num_epoch_step2 30 \
    --train_input noise \
    -lr1 0.00002 \
    -lr2 0.00002 \
    -b 32 \
    -s ./checkpoints/decoding

CUDA_VISIBLE_DEVICES=0,1 python3 train_decoding.py --model_name T5Translator \
    --task_name task1_task2_task3 \
    --one_step \
    --pretrained \
    --not_load_step1_checkpoint \
    --num_epoch_step1 20 \
    --num_epoch_step2 30 \
    --train_input noise \
    -lr1 0.00002 \
    -lr2 0.00002 \
    -b 32 \
    -s ./checkpoints/decoding


================================================
FILE: scripts/train_decoding_1.sh
================================================
CUDA_VISIBLE_DEVICES=2,3 python3 train_decoding.py --model_name T5Translator \
    --task_name task1_task2_taskNRv2 \
    --one_step \
    --pretrained \
    --not_load_step1_checkpoint \
    --num_epoch_step1 20 \
    --num_epoch_step2 30 \
    --train_input EEG \
    -lr1 0.00002 \
    -lr2 0.00002 \
    -b 32 \
    -s ./checkpoints/decoding

CUDA_VISIBLE_DEVICES=2,3 python3 train_decoding.py --model_name T5Translator \
    --task_name task1_task2_task3 \
    --one_step \
    --pretrained \
    --not_load_step1_checkpoint \
    --num_epoch_step1 20 \
    --num_epoch_step2 30 \
    --train_input EEG \
    -lr1 0.00002 \
    -lr2 0.00002 \
    -b 32 \
    -s ./checkpoints/decoding


================================================
FILE: scripts/train_eeg_sentiment_baseline.sh
================================================
python3 train_sentiment_baseline.py --model_name \
    BaselineMLP --num_epoch 20 -lr 0.00005 -b 32 -s ./checkpoints/eeg_sentiment -cuda cuda:0


================================================
FILE: scripts/train_eval_zeroshot_pipeline.sh
================================================
echo "###################################"
echo "Training decoder: BART, task1-SR..."
echo "###################################"
echo ""
python3 train_decoding.py --model_name BrainTranslator \
    --task_name task1 \
    --one_step \
    --pretrained \
    --not_load_step1_checkpoint \
    --num_epoch_step1 20 \
    --num_epoch_step2 30 \
    -lr1 0.00005 \
    -lr2 0.0000005 \
    -b 32 \
    -s ./checkpoints/decoding \
    -cuda cuda:0

echo "###################################"
echo "Training classifier: BART, filtered Stanford Sentiment Treebank..."
echo "###################################"
echo ""
python3 train_sentiment_textbased.py \
    --dataset_name SST \
    --model_name pretrain_Bart \
    --num_epoch 20 \
    -lr 0.0001 \
    -b 32 \
    -s ./checkpoints/text_sentiment_classifier \
    -cuda cuda:0

echo "###################################"
echo "Evaluating Zero-shot pipeline: DEC(BART) + CLS(BART)"
echo "###################################"
echo ""
python3 eval_sentiment.py --model_name ZeroShotSentimentDiscovery \
    --decoder_checkpoint_path ./checkpoints/decoding/best/task1_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.pt \
    --classifier_checkpoint_path ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt \
    --decoder_config_path ./config/decoding/task1_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.json \
    --classifier_config_path ./config/text_sentiment_classifier/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.json \
    --cuda cuda:0


================================================
FILE: scripts/train_text_sentiment_classifier.sh
================================================
python3 train_sentiment_textbased.py \
    --dataset_name SST \
    --model_name pretrain_Bart \
    --num_epoch 20 \
    -lr 0.0001 \
    -b 32 \
    -s ./checkpoints/text_sentiment_classifier \
    -cuda cuda:0


================================================
FILE: train_decoding.py
================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm
from transformers import BertLMHeadModel, BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification, PegasusForConditionalGeneration, PegasusTokenizer, T5Tokenizer, T5ForConditionalGeneration, BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderConfig, EncoderDecoderModel

from data import ZuCo_dataset
from model_decoding import BrainTranslator, BrainTranslatorNaive, T5Translator
from config import get_config

def train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/decoding/best/temp_decoding.pt', checkpoint_path_last = './checkpoints/decoding/last/temp_decoding.pt'):
    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 100000000000

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'dev']:
            if phase == 'train':
                model.train() # Set model to training mode
            else:
                model.eval() # Set model to evaluate mode

            running_loss = 0.0

            # Iterate over data.
for input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels in tqdm(dataloaders[phase]): # load in batch input_embeddings_batch = input_embeddings.to(device).float() input_masks_batch = input_masks.to(device) input_mask_invert_batch = input_mask_invert.to(device) target_ids_batch = target_ids.to(device) """replace padding ids in target_ids with -100""" target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100 # zero the parameter gradients optimizer.zero_grad() # forward # track history if only in train with torch.set_grad_enabled(phase == 'train'): seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch) """calculate loss""" # logits = seq2seqLMoutput.logits # 8*48*50265 # logits = logits.permute(0,2,1) # 8*50265*48 # loss = criterion(logits, target_ids_batch_label) # calculate cross entropy loss only on encoded target parts # NOTE: my criterion not used loss = seq2seqLMoutput.loss # use the BART language modeling loss # """check prediction, instance 0 of each batch""" # print('target size:', target_ids_batch.size(), ',original logits size:', logits.size(), ',target_mask size', target_mask_batch.size()) # logits = logits.permute(0,2,1) # for idx in [0]: # print(f'-- instance {idx} --') # # print('permuted logits size:', logits.size()) # probs = logits[idx].softmax(dim = 1) # # print('probs size:', probs.size()) # values, predictions = probs.topk(1) # # print('predictions before squeeze:',predictions.size()) # predictions = torch.squeeze(predictions) # # print('predictions:',predictions) # # print('target mask:', target_mask_batch[idx]) # # print('[DEBUG]target tokens:',tokenizer.decode(target_ids_batch_copy[idx])) # print('[DEBUG]predicted tokens:',tokenizer.decode(predictions)) # backward + optimize only if in training phase if phase == 'train': # with torch.autograd.detect_anomaly(): loss.sum().backward() optimizer.step() # statistics running_loss += 
loss.sum().item() * input_embeddings_batch.size()[0] # batch loss # print('[DEBUG]loss:',loss.item()) # print('#################################') if phase == 'train': scheduler.step() epoch_loss = running_loss / dataset_sizes[phase] print('{} Loss: {:.4f}'.format(phase, epoch_loss)) # deep copy the model if phase == 'dev' and epoch_loss < best_loss: best_loss = epoch_loss best_model_wts = copy.deepcopy(model.state_dict()) '''save checkpoint''' torch.save(model.state_dict(), checkpoint_path_best) print(f'update best on dev checkpoint: {checkpoint_path_best}') # with torch.set_grad_enabled(False): # traced_model_1 = torch.jit.trace(model, (torch.rand(1, 56, 840).to(device), torch.randint(1, 56).to(device), torch.rand(1, 56).to(device), torch.rand(1, 56).to(device))) # traced_model_32 = torch.jit.trace(model, (torch.rand(32, 56, 840).to(device), torch.randint(32, 56).to(device), torch.rand(32, 56).to(device), torch.rand(32, 56).to(device))) # torch.jit.save(traced_model_1, checkpoint_path_best[:-3]+'_1_jit.pt') # torch.jit.save(traced_model_32, checkpoint_path_best[:-3]+'_32_jit.pt') print() time_elapsed = time.time() - since print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) print('Best val loss: {:.4f}'.format(best_loss)) torch.save(model.state_dict(), checkpoint_path_last) print(f'update last checkpoint: {checkpoint_path_last}') # load best model weights model.load_state_dict(best_model_wts) return model def show_require_grad_layers(model): print() print(' require_grad layers:') # sanity check for name, param in model.named_parameters(): if param.requires_grad: print(' ', name) if __name__ == '__main__': args = get_config('train_decoding') ''' config param''' dataset_setting = 'unique_sent' num_epochs_step1 = args['num_epoch_step1'] num_epochs_step2 = args['num_epoch_step2'] step1_lr = args['learning_rate_step1'] step2_lr = args['learning_rate_step2'] batch_size = args['batch_size'] model_name = args['model_name'] # 
model_name = 'BrainTranslatorNaive' # with no additional transformers # model_name = 'BrainTranslator' # task_name = 'task1' # task_name = 'task1_task2' # task_name = 'task1_task2_task3' # task_name = 'task1_task2_taskNRv2' task_name = args['task_name'] train_input = args['train_input'] print("train_input is:", train_input) save_path = args['save_path'] if not os.path.exists(save_path): os.makedirs(save_path) skip_step_one = args['skip_step_one'] load_step1_checkpoint = args['load_step1_checkpoint'] use_random_init = args['use_random_init'] device_ids = [0] # device setting if use_random_init and skip_step_one: step2_lr = 5*1e-4 print(f'[INFO]using model: {model_name}') if skip_step_one: save_name = f'{task_name}_finetune_{model_name}_skipstep1_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}_{train_input}' else: save_name = f'{task_name}_finetune_{model_name}_2steptraining_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}_{train_input}' if use_random_init: save_name = 'randinit_' + save_name save_path_best = os.path.join(save_path, 'best') if not os.path.exists(save_path_best): os.makedirs(save_path_best) output_checkpoint_name_best = os.path.join(save_path_best, f'{save_name}.pt') save_path_last = os.path.join(save_path, 'last') if not os.path.exists(save_path_last): os.makedirs(save_path_last) output_checkpoint_name_last = os.path.join(save_path_last, f'{save_name}.pt') # subject_choice = 'ALL subject_choice = args['subjects'] print(f'![Debug]using {subject_choice}') # eeg_type_choice = 'GD eeg_type_choice = args['eeg_type'] print(f'[INFO]eeg type {eeg_type_choice}') # bands_choice = ['_t1'] # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] bands_choice = args['eeg_bands'] print(f'[INFO]using bands {bands_choice}') ''' set random seeds ''' seed_val = 312 np.random.seed(seed_val) torch.manual_seed(seed_val) torch.cuda.manual_seed_all(seed_val) ''' set up device 
''' # use cuda if torch.cuda.is_available(): # dev = "cuda:3" dev = args['cuda'] else: dev = "cpu" # CUDA_VISIBLE_DEVICES=0,1,2,3 device = torch.device(dev) print(f'[INFO]using device {dev}') print() ''' set up dataloader ''' whole_dataset_dicts = [] if 'task1' in task_name: dataset_path_task1 = '/data/johj/ZuCo_data/task1-SR/task1_source.pkl' with open(dataset_path_task1, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) if 'task2' in task_name: dataset_path_task2 = '/data/johj/ZuCo_data/task2-NR/task2_source.pkl' with open(dataset_path_task2, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) if 'task3' in task_name: dataset_path_task3 = '/data/johj/ZuCo_data/task3-TSR/task3_source.pkl' with open(dataset_path_task3, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) if 'taskNRv2' in task_name: dataset_path_taskNRv2 = '/data/johj/ZuCo_data/task2-NR-2.0/taskNRv2_source.pkl' with open(dataset_path_taskNRv2, 'rb') as handle: whole_dataset_dicts.append(pickle.load(handle)) print() """save config""" cfg_dir = './config/decoding/' if not os.path.exists(cfg_dir): os.makedirs(cfg_dir) with open(os.path.join(cfg_dir,f'{save_name}.json'), 'w') as out_config: json.dump(args, out_config, indent = 4) if model_name in ['BrainTranslator','BrainTranslatorNaive']: tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') elif model_name == 'PegasusTranslator': tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-xsum') elif model_name == 'T5Translator': tokenizer = T5Tokenizer.from_pretrained("t5-large") #tokenizer.set_prefix_tokens(language='english') # train dataset train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, test_input=train_input) # dev dataset dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = 
dataset_setting, test_input=train_input) # test dataset # test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting) dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)} print('[INFO]train_set size: ', len(train_set)) print('[INFO]dev_set size: ', len(dev_set)) # print('[INFO]test_set size: ', len(test_set)) # train dataloader train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4) # dev dataloader val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4) # dataloaders dataloaders = {'train':train_dataloader, 'dev':val_dataloader} ''' set up model ''' if model_name == 'BrainTranslator': if use_random_init: config = BartConfig.from_pretrained('facebook/bart-large') pretrained = BartForConditionalGeneration(config) else: pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large') model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048) elif model_name == 'BrainTranslatorNaive': pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large') model = BrainTranslatorNaive(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048) elif model_name == 'PegasusTranslator': pretrained = PegasusForConditionalGeneration.from_pretrained('google/pegasus-xsum') model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048) elif model_name == 'T5Translator': pretrained = T5ForConditionalGeneration.from_pretrained("t5-large") model = T5Translator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, 
additional_encoder_dim_feedforward = 2048) model.to(device) model = torch.nn.DataParallel(model, device_ids=device_ids) ''' training loop ''' ###################################################### '''step one training: freeze most of BART params''' ###################################################### # closely follow BART paper if model_name in ['BrainTranslator','BrainTranslatorNaive', 'PegasusTranslator', 'T5Translator']: for name, param in model.named_parameters(): if param.requires_grad and 'pretrained' in name: if ('shared' in name) or ('embed_positions' in name) or ('encoder.layers.0' in name): continue else: param.requires_grad = False elif model_name == 'BertGeneration': for name, param in model.named_parameters(): if param.requires_grad and 'pretrained' in name: if ('embeddings' in name) or ('encoder.layer.0' in name): continue else: param.requires_grad = False if skip_step_one: if load_step1_checkpoint: stepone_checkpoint = 'path_to_step_1_checkpoint.pt' print(f'skip step one, load checkpoint: {stepone_checkpoint}') model.load_state_dict(torch.load(stepone_checkpoint)) else: print('skip step one, start from scratch at step two') else: ''' set up optimizer and scheduler''' optimizer_step1 = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=step1_lr, momentum=0.9) exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=20, gamma=0.1) ''' set up loss function ''' criterion = nn.CrossEntropyLoss() print('=== start Step1 training ... 
===') # print training layers show_require_grad_layers(model) # return best loss model from step1 training model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs_step1, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last) ###################################################### '''step two training: update whole model for a few iterations''' ###################################################### for name, param in model.named_parameters(): param.requires_grad = True ''' set up optimizer and scheduler''' optimizer_step2 = optim.SGD(model.parameters(), lr=step2_lr, momentum=0.9) exp_lr_scheduler_step2 = lr_scheduler.StepLR(optimizer_step2, step_size=30, gamma=0.1) ''' set up loss function ''' criterion = nn.CrossEntropyLoss() print() print('=== start Step2 training ... ===') # print training layers show_require_grad_layers(model) '''main loop''' trained_model = train_model(dataloaders, device, model, criterion, optimizer_step2, exp_lr_scheduler_step2, num_epochs=num_epochs_step2, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last) # '''save checkpoint''' # torch.save(trained_model.state_dict(), os.path.join(save_path,output_checkpoint_name)) ================================================ FILE: train_sentiment_baseline.py ================================================ import os import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler from torch.nn.utils.rnn import pack_padded_sequence import pickle import json import matplotlib.pyplot as plt from glob import glob import time import copy from tqdm import tqdm from transformers import BertTokenizer, BertLMHeadModel, BertConfig from data import ZuCo_dataset from model_sentiment import BaselineMLPSentence, 
BaselineLSTM, NaiveFineTunePretrainedBert from config import get_config # Function to calculate the accuracy of our predictions vs labels def flat_accuracy(preds, labels): # preds: numpy array: N * 3 # labels: numpy array: N pred_flat = np.argmax(preds, axis=1).flatten() labels_flat = labels.flatten() return np.sum(pred_flat == labels_flat) / len(labels_flat) def flat_accuracy_top_k(preds, labels,k): topk_preds = [] for pred in preds: topk = pred.argsort()[-k:][::-1] topk_preds.append(list(topk)) # print(topk_preds) topk_preds = list(topk_preds) right_count = 0 # print(len(labels)) for i in range(len(labels)): l = labels[i][0] if l in topk_preds[i]: right_count+=1 return right_count/len(labels) def train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/eeg_sentiment/best/test.pt', checkpoint_path_last = './checkpoints/eeg_sentiment/last/test.pt'): # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html since = time.time() best_model_wts = copy.deepcopy(model.state_dict()) best_loss = 100000000000 best_acc = 0.0 for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) # Each epoch has a training and validation phase for phase in ['train', 'dev']: total_accuracy = 0.0 if phase == 'train': model.train() # Set model to training mode else: model.eval() # Set model to evaluate mode running_loss = 0.0 # Iterate over data. 
for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in tqdm(dataloaders[phase]): input_word_eeg_features = input_word_eeg_features.to(device).float() sent_level_EEG = sent_level_EEG.to(device) input_masks = input_masks.to(device) sentiment_labels = sentiment_labels.to(device) # zero the parameter gradients optimizer.zero_grad() if isinstance(model, BaselineMLPSentence): # forward logits = model(sent_level_EEG) # before softmax # calculate loss loss = criterion(logits, sentiment_labels) elif isinstance(model, BaselineLSTM): x_packed = pack_padded_sequence(input_word_eeg_features, seq_lens, batch_first=True, enforce_sorted=False) logits = model(x_packed) # calculate loss loss = criterion(logits, sentiment_labels) elif isinstance(model, NaiveFineTunePretrainedBert): output = model(input_word_eeg_features, input_masks, sentiment_labels) logits = output.logits loss = output.loss # backward + optimize only if in training phase if phase == 'train': # with torch.autograd.detect_anomaly(): loss.backward() optimizer.step() # calculate accuracy preds_cpu = logits.detach().cpu().numpy() label_cpu = sentiment_labels.cpu().numpy() total_accuracy += flat_accuracy(preds_cpu, label_cpu) # statistics running_loss += loss.item() * sent_level_EEG.size()[0] # batch loss # print('[DEBUG]loss:',loss.item()) # print('#################################') if phase == 'train': scheduler.step() epoch_loss = running_loss / dataset_sizes[phase] epoch_acc = total_accuracy / len(dataloaders[phase]) print('{} Loss: {:.4f}'.format(phase, epoch_loss)) print('{} Acc: {:.4f}'.format(phase, epoch_acc)) # deep copy the model if phase == 'dev' and (epoch_acc > best_acc): best_loss = epoch_loss best_acc = epoch_acc best_model_wts = copy.deepcopy(model.state_dict()) '''save checkpoint''' torch.save(model.state_dict(), checkpoint_path_best) print(f'update best on dev checkpoint: {checkpoint_path_best}') print() time_elapsed = 
time.time() - since print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) print('Best val loss: {:.4f}'.format(best_loss)) print('Best val acc: {:.4f}'.format(best_acc)) torch.save(model.state_dict(), checkpoint_path_last) print(f'update last checkpoint: {checkpoint_path_last}') # load best model weights model.load_state_dict(best_model_wts) return model if __name__ == '__main__': args = get_config('train_sentiment_baseline') ''' config param''' num_epochs = args['num_epoch'] step_lr = args['learning_rate'] '''dataset division''' dataset_setting = 'unique_sent' # subject_choice = 'ALL subject_choice = args['subjects'] print(f'![Debug]using {subject_choice}') # eeg_type_choice = 'GD eeg_type_choice = args['eeg_type'] print(f'[INFO]eeg type {eeg_type_choice}') # bands_choice = ['_t1'] # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] bands_choice = args['eeg_bands'] print(f'[INFO]using bands {bands_choice}') '''model name''' # model_name = 'BaselineMLP' # model_name = 'BaselineLSTM' # model_name = 'NaiveFinetuneBert' model_name = args['model_name'] batch_size = 32 save_path = args['save_path'] save_name = f'{model_name}_{step_lr}_b{batch_size}_{dataset_setting}_{eeg_type_choice}' if model_name == 'BaselineLSTM': num_layers = 4 save_name = f'{model_name}_numLayers-{num_layers}_{step_lr}_b{batch_size}_{dataset_setting}_{eeg_type_choice}' output_checkpoint_name_best = save_path + f'/best/{save_name}.pt' output_checkpoint_name_last = save_path + f'/last/{save_name}.pt' ''' set random seeds ''' seed_val = 312 np.random.seed(seed_val) torch.manual_seed(seed_val) torch.cuda.manual_seed_all(seed_val) ''' set up device ''' # use cuda if torch.cuda.is_available(): dev = args['cuda'] else: dev = "cpu" # CUDA_VISIBLE_DEVICES=0,1,2,3 device = torch.device(dev) print(f'[INFO]using device {dev}') ''' load pickle''' whole_dataset_dict = [] dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle' with 
open(dataset_path_task1, 'rb') as handle: whole_dataset_dict.append(pickle.load(handle)) '''set up tokenizer''' tokenizer = BertTokenizer.from_pretrained('bert-base-cased') ''' set up dataloader ''' # train dataset train_set = ZuCo_dataset(whole_dataset_dict, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting) # dev dataset dev_set = ZuCo_dataset(whole_dataset_dict, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting) # test dataset # test_set = ZuCo_dataset(whole_dataset_dict, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice) dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)} print('[INFO]train_set size: ', len(train_set)) print('[INFO]dev_set size: ', len(dev_set)) # train dataloader train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4) # dev dataloader val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4) # dataloaders dataloaders = {'train':train_dataloader, 'dev':val_dataloader} ''' set up model ''' if model_name == 'BaselineMLP': print('[INFO]Model: BaselineMLP') model = BaselineMLPSentence(input_dim = 105*len(bands_choice), hidden_dim = 128, output_dim = 3) elif model_name == 'BaselineLSTM': print('[INFO]Model: BaselineLSTM') model = BaselineLSTM(input_dim = 105*len(bands_choice), hidden_dim = 256, output_dim = 3, num_layers = num_layers) elif model_name == 'NaiveFinetuneBert': print('[INFO]Model: NaiveFinetuneBert') model = NaiveFineTunePretrainedBert(input_dim = 105*len(bands_choice), hidden_dim = 768, output_dim = 3) model.to(device) """save config""" with open(f'./config/eeg_sentiment/{save_name}.json', 'w') as out_config: json.dump(args, out_config, indent = 4) ''' training loop ''' ''' set up optimizer and scheduler''' optimizer_step1 = optim.SGD(model.parameters(), lr=step_lr, momentum=0.9) 
exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=20, gamma=0.5) ''' set up loss function ''' criterion = nn.CrossEntropyLoss() print('=== start training ... ===') # return best loss model from step1 training model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last) ================================================ FILE: train_sentiment_textbased.py ================================================ import os import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, random_split import pickle import json import matplotlib.pyplot as plt from glob import glob import time import copy from tqdm import tqdm from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification from data import ZuCo_dataset, SST_tenary_dataset from model_sentiment import FineTunePretrainedTwoStep from config import get_config # Function to calculate the accuracy of our predictions vs labels def flat_accuracy(preds, labels): # preds: numpy array: N * 3 # labels: numpy array: N pred_flat = np.argmax(preds, axis=1).flatten() labels_flat = labels.flatten() return np.sum(pred_flat == labels_flat) / len(labels_flat) def flat_accuracy_top_k(preds, labels,k): topk_preds = [] for pred in preds: topk = pred.argsort()[-k:][::-1] topk_preds.append(list(topk)) # print(topk_preds) topk_preds = list(topk_preds) right_count = 0 # print(len(labels)) for i in range(len(labels)): l = labels[i][0] if l in topk_preds[i]: right_count+=1 return right_count/len(labels) def train_model_ZuCo(dataloaders, device, model, criterion, 
optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/text_sentiment_classifier/best/test.pt', checkpoint_path_last = './checkpoints/text_sentiment_classifier/last/test.pt'): # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html since = time.time() best_model_wts = copy.deepcopy(model.state_dict()) best_loss = 100000000000 best_acc = 0.0 for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) # Each epoch has a training and validation phase for phase in ['train', 'dev']: total_accuracy = 0.0 if phase == 'train': model.train() # Set model to training mode else: model.eval() # Set model to evaluate mode running_loss = 0.0 # Iterate over data. for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in tqdm(dataloaders[phase]): # input_word_eeg_features = input_word_eeg_features.to(device).float() # input_masks = input_masks.to(device) # input_mask_invert = input_mask_invert.to(device) target_ids = target_ids.to(device) target_mask = target_mask.to(device) sentiment_labels = sentiment_labels.to(device) # zero the parameter gradients optimizer.zero_grad() # forward output = model(input_ids = target_ids, attention_mask = target_mask, return_dict = True, labels = sentiment_labels) logits = output.logits loss = output.loss # backward + optimize only if in training phase if phase == 'train': # with torch.autograd.detect_anomaly(): loss.backward() optimizer.step() # calculate accuracy preds_cpu = logits.detach().cpu().numpy() label_cpu = sentiment_labels.cpu().numpy() total_accuracy += flat_accuracy(preds_cpu, label_cpu) # statistics running_loss += loss.item() * sent_level_EEG.size()[0] # batch loss # print('[DEBUG]loss:',loss.item()) # print('#################################') if phase == 'train': scheduler.step() epoch_loss = running_loss / dataset_sizes[phase] epoch_acc = total_accuracy / 
len(dataloaders[phase]) print('{} Loss: {:.4f}'.format(phase, epoch_loss)) print('{} Acc: {:.4f}'.format(phase, epoch_acc)) # deep copy the model if phase == 'dev' and (epoch_acc > best_acc): best_loss = epoch_loss best_acc = epoch_acc best_model_wts = copy.deepcopy(model.state_dict()) '''save checkpoint''' torch.save(model.state_dict(), checkpoint_path_best) print(f'update best on dev checkpoint: {checkpoint_path_best}') print() time_elapsed = time.time() - since print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) print('Best val loss: {:.4f}'.format(best_loss)) print('Best val acc: {:.4f}'.format(best_acc)) torch.save(model.state_dict(), checkpoint_path_last) print(f'update last checkpoint: {checkpoint_path_last}') # write to log; output_log_file_name is not defined anywhere in this file, so as an assumption it is derived from the last-checkpoint path output_log_file_name = os.path.splitext(checkpoint_path_last)[0] + '.log' with open(output_log_file_name, 'w') as outlog: outlog.write(f'best val loss: {best_loss}\n') outlog.write('Best val acc: {:.4f}'.format(best_acc)) # load best model weights model.load_state_dict(best_model_wts) return model def train_model_SST(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/text_sentiment_classifier/best/test.pt', checkpoint_path_last = './checkpoints/text_sentiment_classifier/last/test.pt'): # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html since = time.time() best_model_wts = copy.deepcopy(model.state_dict()) best_loss = 100000000000 best_acc = 0.0 for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) # Each epoch has a training and validation phase for phase in ['train', 'dev']: total_accuracy = 0.0 if phase == 'train': model.train() # Set model to training mode else: model.eval() # Set model to evaluate mode running_loss = 0.0 # Iterate over data. 
for input_ids,input_masks,sentiment_labels in tqdm(dataloaders[phase]): input_ids = input_ids.to(device) input_masks = input_masks.to(device) sentiment_labels = sentiment_labels.to(device) # zero the parameter gradients optimizer.zero_grad() # forward output = model(input_ids = input_ids, attention_mask = input_masks, return_dict = True, labels = sentiment_labels) logits = output.logits loss = output.loss # backward + optimize only if in training phase if phase == 'train': # with torch.autograd.detect_anomaly(): loss.backward() optimizer.step() # calculate accuracy preds_cpu = logits.detach().cpu().numpy() label_cpu = sentiment_labels.cpu().numpy() total_accuracy += flat_accuracy(preds_cpu, label_cpu) # statistics running_loss += loss.item() * input_ids.size()[0] # batch loss # print('[DEBUG]loss:',loss.item()) # print('#################################') if phase == 'train': scheduler.step() epoch_loss = running_loss / dataset_sizes[phase] epoch_acc = total_accuracy / len(dataloaders[phase]) print('{} Loss: {:.4f}'.format(phase, epoch_loss)) print('{} Acc: {:.4f}'.format(phase, epoch_acc)) # deep copy the model if phase == 'dev' and (epoch_acc > best_acc): best_loss = epoch_loss best_acc = epoch_acc best_model_wts = copy.deepcopy(model.state_dict()) '''save checkpoint''' torch.save(model.state_dict(), checkpoint_path_best) print(f'update best on dev checkpoint: {checkpoint_path_best}') print() time_elapsed = time.time() - since print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) print('Best val loss: {:.4f}'.format(best_loss)) print('Best val acc: {:.4f}'.format(best_acc)) torch.save(model.state_dict(), checkpoint_path_last) print(f'update last checkpoint: {checkpoint_path_last}') # load best model weights model.load_state_dict(best_model_wts) return model if __name__ == '__main__': args = get_config('train_sentiment_textbased') ''' config param''' num_epoch = args['num_epoch'] # lr = 1e-3 # Bert, RoBerta # lr = 1e-4 # Bart lr 
= args['learning_rate'] dataset_name = args['dataset_name'] # zero-shot setting: using external dataset from stanford sentiment treebank, pass in 'SST'; or pass in 'ZuCo' to train on ZuCo's text-sentiment pairs dataset_setting = 'unique_sent' batch_size = args['batch_size'] # model_name = 'pretrain_Bert' # model_name = 'pretrain_RoBerta' # model_name = 'pretrain_Bart' model_name = args['model_name'] print(f'[INFO]model name: {model_name}') save_path = args['save_path'] if dataset_name == 'ZuCo': # subject_choice = 'ALL subject_choice = args['subjects'] print(f'![Debug]using {subject_choice}') # eeg_type_choice = 'GD eeg_type_choice = args['eeg_type'] print(f'[INFO]eeg type {eeg_type_choice}') # bands_choice = ['_t1'] # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] bands_choice = args['eeg_bands'] print(f'[INFO]using bands {bands_choice}') save_name = f'Textbased_ZuCo_{model_name}_b{batch_size}_{num_epoch}_{lr}_{dataset_setting}_{eeg_type_choice}' elif dataset_name == 'SST': save_name = f'Textbased_StanfordSentitmentTreeband_{model_name}_b{batch_size}_{num_epoch}_{lr}' output_checkpoint_name_best = save_path + f'/best/{save_name}.pt' output_checkpoint_name_last = save_path + f'/last/{save_name}.pt' ''' set random seeds ''' seed_val = 312 np.random.seed(seed_val) torch.manual_seed(seed_val) torch.cuda.manual_seed_all(seed_val) ''' set up device ''' # use cuda if torch.cuda.is_available(): dev = args['cuda'] else: dev = "cpu" # CUDA_VISIBLE_DEVICES=0,1,2,3 device = torch.device(dev) print(f'[INFO]using device {dev}') ''' load pickle ''' if dataset_name == 'ZuCo': whole_dataset_dict = [] dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle' with open(dataset_path_task1, 'rb') as handle: whole_dataset_dict.append(pickle.load(handle)) '''tokenizer''' if model_name == 'pretrain_Bert': print('[INFO]pretrained checkpoint: bert-base-cased') tokenizer = BertTokenizer.from_pretrained('bert-base-cased') elif model_name == 
'pretrain_RoBerta': print('[INFO]pretrained checkpoint: roberta-base') tokenizer = RobertaTokenizer.from_pretrained('roberta-base') elif model_name == 'pretrain_Bart': print('[INFO]pretrained checkpoint: bart-large') tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') ''' set up dataloader ''' if dataset_name == 'ZuCo': # train dataset train_set = ZuCo_dataset(whole_dataset_dict, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting) # dev dataset dev_set = ZuCo_dataset(whole_dataset_dict, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting) elif dataset_name == 'SST': SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json')) SST_dataset = SST_tenary_dataset(SST_SENTIMENT_LABELS, tokenizer) train_size = int(0.9 * len(SST_dataset)) val_size = len(SST_dataset) - train_size train_set, dev_set = random_split(SST_dataset, [train_size, val_size]) print('{:>5,} training samples'.format(len(train_set))) print('{:>5,} validation samples'.format(len(dev_set))) dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)} print('[INFO]train_set size: ', len(train_set)) print('[INFO]dev_set size: ', len(dev_set)) # train dataloader train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4) # dev dataloader val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4) # dataloaders dataloaders = {'train':train_dataloader, 'dev':val_dataloader} ''' set up model ''' if model_name == 'pretrain_Bert': model = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3) elif model_name == 'pretrain_RoBerta': model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3) elif model_name == 'pretrain_Bart': model = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels = 3) 
model.to(device) """save config""" with open(f'./config/text_sentiment_classifier/{save_name}.json', 'w') as out_config: json.dump(args, out_config, indent = 4) ''' training loop ''' ###################################################### '''single-stage training: all parameters of the pretrained classifier are updated (nothing is frozen here)''' ###################################################### ''' set up optimizer and scheduler''' optimizer_step1 = optim.SGD(model.parameters(), lr=lr, momentum=0.9) exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=10, gamma=0.1) # TODO: rethink about the loss function ''' set up loss function ''' criterion = nn.CrossEntropyLoss() # returns the model with the best dev accuracy print(f'=== start training {dataset_name} ... ===') if dataset_name == 'ZuCo': model = train_model_ZuCo(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epoch, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last) elif dataset_name == 'SST': model = train_model_SST(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epoch, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last) ================================================ FILE: util/construct_dataset_mat_to_pickle_v1.py ================================================ import scipy.io as io import h5py import os import json from glob import glob from tqdm import tqdm import numpy as np import pickle import argparse parser = argparse.ArgumentParser(description='Specify task name for converting ZuCo v1.0 Mat file to Pickle') parser.add_argument('-t', '--task_name', help='name of the task in /dataset/ZuCo, choose from {task1-SR,task2-NR,task3-TSR}', required=True) args = vars(parser.parse_args()) """config""" version = 'v1' # 'old' # version = 'v2' # 'new' task_name = args['task_name'] # task_name = 'task1-SR' # task_name = 'task2-NR' # task_name = 
'task3-TSR' print('##############################') print(f'start processing ZuCo {task_name}...') if version == 'v1': # old version input_mat_files_dir = f'./dataset/ZuCo/{task_name}/Matlab_files' elif version == 'v2': # new version, mat73 input_mat_files_dir = f'./dataset/ZuCo/{task_name}/Matlab_files' output_dir = f'./dataset/ZuCo/{task_name}/pickle' if not os.path.exists(output_dir): os.makedirs(output_dir) """load files""" mat_files = glob(os.path.join(input_mat_files_dir,'*.mat')) mat_files = sorted(mat_files) if len(mat_files) == 0: print(f'No mat files found for {task_name}') quit() dataset_dict = {} for mat_file in tqdm(mat_files): subject_name = os.path.basename(mat_file).split('_')[0].replace('results','').strip() dataset_dict[subject_name] = [] if version == 'v1': matdata = io.loadmat(mat_file, squeeze_me=True, struct_as_record=False)['sentenceData'] elif version == 'v2': matdata = h5py.File(mat_file,'r') print(matdata) for sent in matdata: word_data = sent.word if not isinstance(word_data, float): # sentence level: sent_obj = {'content':sent.content} sent_obj['sentence_level_EEG'] = {'mean_t1':sent.mean_t1, 'mean_t2':sent.mean_t2, 'mean_a1':sent.mean_a1, 'mean_a2':sent.mean_a2, 'mean_b1':sent.mean_b1, 'mean_b2':sent.mean_b2, 'mean_g1':sent.mean_g1, 'mean_g2':sent.mean_g2} if task_name == 'task1-SR': sent_obj['answer_EEG'] = {'answer_mean_t1':sent.answer_mean_t1, 'answer_mean_t2':sent.answer_mean_t2, 'answer_mean_a1':sent.answer_mean_a1, 'answer_mean_a2':sent.answer_mean_a2, 'answer_mean_b1':sent.answer_mean_b1, 'answer_mean_b2':sent.answer_mean_b2, 'answer_mean_g1':sent.answer_mean_g1, 'answer_mean_g2':sent.answer_mean_g2} # word level: sent_obj['word'] = [] word_tokens_has_fixation = [] word_tokens_with_mask = [] word_tokens_all = [] for word in word_data: word_obj = {'content':word.content} word_tokens_all.append(word.content) # TODO: add more version of word level eeg: GD, SFD, GPT word_obj['nFixations'] = word.nFixations if word.nFixations > 0: 
word_obj['word_level_EEG'] = {'FFD':{'FFD_t1':word.FFD_t1, 'FFD_t2':word.FFD_t2, 'FFD_a1':word.FFD_a1, 'FFD_a2':word.FFD_a2, 'FFD_b1':word.FFD_b1, 'FFD_b2':word.FFD_b2, 'FFD_g1':word.FFD_g1, 'FFD_g2':word.FFD_g2}} word_obj['word_level_EEG']['TRT'] = {'TRT_t1':word.TRT_t1, 'TRT_t2':word.TRT_t2, 'TRT_a1':word.TRT_a1, 'TRT_a2':word.TRT_a2, 'TRT_b1':word.TRT_b1, 'TRT_b2':word.TRT_b2, 'TRT_g1':word.TRT_g1, 'TRT_g2':word.TRT_g2} word_obj['word_level_EEG']['GD'] = {'GD_t1':word.GD_t1, 'GD_t2':word.GD_t2, 'GD_a1':word.GD_a1, 'GD_a2':word.GD_a2, 'GD_b1':word.GD_b1, 'GD_b2':word.GD_b2, 'GD_g1':word.GD_g1, 'GD_g2':word.GD_g2} sent_obj['word'].append(word_obj) word_tokens_has_fixation.append(word.content) word_tokens_with_mask.append(word.content) else: word_tokens_with_mask.append('[MASK]') # if a word has no fixation, use sentence level feature # word_obj['word_level_EEG'] = {'FFD':{'FFD_t1':sent.mean_t1, 'FFD_t2':sent.mean_t2, 'FFD_a1':sent.mean_a1, 'FFD_a2':sent.mean_a2, 'FFD_b1':sent.mean_b1, 'FFD_b2':sent.mean_b2, 'FFD_g1':sent.mean_g1, 'FFD_g2':sent.mean_g2}} # word_obj['word_level_EEG']['TRT'] = {'TRT_t1':sent.mean_t1, 'TRT_t2':sent.mean_t2, 'TRT_a1':sent.mean_a1, 'TRT_a2':sent.mean_a2, 'TRT_b1':sent.mean_b1, 'TRT_b2':sent.mean_b2, 'TRT_g1':sent.mean_g1, 'TRT_g2':sent.mean_g2} # NOTE:if a word has no fixation, simply skip it continue sent_obj['word_tokens_has_fixation'] = word_tokens_has_fixation sent_obj['word_tokens_with_mask'] = word_tokens_with_mask sent_obj['word_tokens_all'] = word_tokens_all dataset_dict[subject_name].append(sent_obj) else: print(f'missing sent: subj:{subject_name} content:{sent.content}, return None') dataset_dict[subject_name].append(None) continue # print(dataset_dict.keys()) # print(dataset_dict[subject_name][0].keys()) # print(dataset_dict[subject_name][0]['content']) # print(dataset_dict[subject_name][0]['word'][0].keys()) # print(dataset_dict[subject_name][0]['word'][0]['word_level_EEG']['FFD']) """output""" output_name = 
f'{task_name}-dataset.pickle' # with open(os.path.join(output_dir,'task1-SR-dataset.json'), 'w') as out: # json.dump(dataset_dict,out,indent = 4) with open(os.path.join(output_dir,output_name), 'wb') as handle: pickle.dump(dataset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) print('write to:', os.path.join(output_dir,output_name)) """sanity check""" # check dataset with open(os.path.join(output_dir,output_name), 'rb') as handle: whole_dataset = pickle.load(handle) print('subjects:', whole_dataset.keys()) if version == 'v1': print('num of sent:', len(whole_dataset['ZAB'])) print() ================================================ FILE: util/construct_dataset_mat_to_pickle_v2.py ================================================ import os import numpy as np import h5py import data_loading_helpers_modified as dh from glob import glob from tqdm import tqdm import pickle task = "NR" rootdir = "./dataset/ZuCo/task2-NR-2.0/Matlab_files/" print('##############################') print(f'start processing ZuCo task2-NR-2.0...') dataset_dict = {} for file in tqdm(os.listdir(rootdir)): if file.endswith(task+".mat"): file_name = rootdir + file # print('file name:', file_name) subject = file_name.split("ts")[1].split("_")[0] # print('subject: ', subject) # exclude YMH due to incomplete data because of dyslexia if subject != 'YMH': assert subject not in dataset_dict dataset_dict[subject] = [] f = h5py.File(file_name,'r') # print('keys in f:', list(f.keys())) sentence_data = f['sentenceData'] # print('keys in sentence_data:', list(sentence_data.keys())) # sent level eeg # mean_t1 = np.squeeze(f[sentence_data['mean_t1'][0][0]][()]) mean_t1_objs = sentence_data['mean_t1'] mean_t2_objs = sentence_data['mean_t2'] mean_a1_objs = sentence_data['mean_a1'] mean_a2_objs = sentence_data['mean_a2'] mean_b1_objs = sentence_data['mean_b1'] mean_b2_objs = sentence_data['mean_b2'] mean_g1_objs = sentence_data['mean_g1'] mean_g2_objs = sentence_data['mean_g2'] rawData = sentence_data['rawData'] 
contentData = sentence_data['content'] # print('contentData shape:', contentData.shape, 'dtype:', contentData.dtype) omissionR = sentence_data['omissionRate'] wordData = sentence_data['word'] for idx in range(len(rawData)): # get sentence string obj_reference_content = contentData[idx][0] sent_string = dh.load_matlab_string(f[obj_reference_content]) # print('sentence string:', sent_string) sent_obj = {'content':sent_string} # get sentence level EEG sent_obj['sentence_level_EEG'] = { 'mean_t1':np.squeeze(f[mean_t1_objs[idx][0]][()]), 'mean_t2':np.squeeze(f[mean_t2_objs[idx][0]][()]), 'mean_a1':np.squeeze(f[mean_a1_objs[idx][0]][()]), 'mean_a2':np.squeeze(f[mean_a2_objs[idx][0]][()]), 'mean_b1':np.squeeze(f[mean_b1_objs[idx][0]][()]), 'mean_b2':np.squeeze(f[mean_b2_objs[idx][0]][()]), 'mean_g1':np.squeeze(f[mean_g1_objs[idx][0]][()]), 'mean_g2':np.squeeze(f[mean_g2_objs[idx][0]][()]) } # print(sent_obj) sent_obj['word'] = [] # get word level data word_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask = dh.extract_word_level_data(f, f[wordData[idx][0]]) if word_data == {}: print(f'missing sent: subj:{subject} content:{sent_string}, append None') dataset_dict[subject].append(None) continue elif len(word_tokens_all) == 0: print(f'no word level features: subj:{subject} content:{sent_string}, append None') dataset_dict[subject].append(None) continue else: for widx in range(len(word_data)): data_dict = word_data[widx] word_obj = {'content':data_dict['content'], 'nFixations': data_dict['nFix']} if 'GD_EEG' in data_dict: # print('has fixation: ', data_dict['content']) gd = data_dict["GD_EEG"] ffd = data_dict["FFD_EEG"] trt = data_dict["TRT_EEG"] assert len(gd) == len(trt) == len(ffd) == 8 word_obj['word_level_EEG'] = { 'GD':{'GD_t1':gd[0], 'GD_t2':gd[1], 'GD_a1':gd[2], 'GD_a2':gd[3], 'GD_b1':gd[4], 'GD_b2':gd[5], 'GD_g1':gd[6], 'GD_g2':gd[7]}, 'FFD':{'FFD_t1':ffd[0], 'FFD_t2':ffd[1], 'FFD_a1':ffd[2], 'FFD_a2':ffd[3], 'FFD_b1':ffd[4], 'FFD_b2':ffd[5], 
'FFD_g1':ffd[6], 'FFD_g2':ffd[7]}, 'TRT':{'TRT_t1':trt[0], 'TRT_t2':trt[1], 'TRT_a1':trt[2], 'TRT_a2':trt[3], 'TRT_b1':trt[4], 'TRT_b2':trt[5], 'TRT_g1':trt[6], 'TRT_g2':trt[7]} } sent_obj['word'].append(word_obj) sent_obj['word_tokens_has_fixation'] = word_tokens_has_fixation sent_obj['word_tokens_with_mask'] = word_tokens_with_mask sent_obj['word_tokens_all'] = word_tokens_all # print(sent_obj.keys()) # print(len(sent_obj['word'])) # print(sent_obj['word'][0]) dataset_dict[subject].append(sent_obj) """output""" task_name = 'task2-NR-2.0' if dataset_dict == {}: print(f'No mat file found for {task_name}') quit() output_dir = f'./dataset/ZuCo/{task_name}/pickle' if not os.path.exists(output_dir): os.makedirs(output_dir) output_name = f'{task_name}-dataset.pickle' # with open(os.path.join(output_dir,'task1-SR-dataset.json'), 'w') as out: # json.dump(dataset_dict,out,indent = 4) with open(os.path.join(output_dir,output_name), 'wb') as handle: pickle.dump(dataset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) print('write to:', os.path.join(output_dir,output_name)) """sanity check""" print('subjects:', dataset_dict.keys()) print('num of sent:', len(dataset_dict['YAC'])) ================================================ FILE: util/data_loading_helpers_modified.py ================================================ import numpy as np import re eeg_float_resolution=np.float16 Alpha_ffd_names = ['FFD_a1', 'FFD_a1_diff', 'FFD_a2', 'FFD_a2_diff'] Beta_ffd_names = ['FFD_b1', 'FFD_b1_diff', 'FFD_b2', 'FFD_b2_diff'] Gamma_ffd_names = ['FFD_g1', 'FFD_g1_diff', 'FFD_g2', 'FFD_g2_diff'] Theta_ffd_names = ['FFD_t1', 'FFD_t1_diff', 'FFD_t2', 'FFD_t2_diff'] Alpha_gd_names = ['GD_a1', 'GD_a1_diff', 'GD_a2', 'GD_a2_diff'] Beta_gd_names = ['GD_b1', 'GD_b1_diff', 'GD_b2', 'GD_b2_diff'] Gamma_gd_names = ['GD_g1', 'GD_g1_diff', 'GD_g2', 'GD_g2_diff'] Theta_gd_names = ['GD_t1', 'GD_t1_diff', 'GD_t2', 'GD_t2_diff'] Alpha_gpt_names = ['GPT_a1', 'GPT_a1_diff', 'GPT_a2', 'GPT_a2_diff'] 
Beta_gpt_names = ['GPT_b1', 'GPT_b1_diff', 'GPT_b2', 'GPT_b2_diff'] Gamma_gpt_names = ['GPT_g1', 'GPT_g1_diff', 'GPT_g2', 'GPT_g2_diff'] Theta_gpt_names = ['GPT_t1', 'GPT_t1_diff', 'GPT_t2', 'GPT_t2_diff'] Alpha_sfd_names = ['SFD_a1', 'SFD_a1_diff', 'SFD_a2', 'SFD_a2_diff'] Beta_sfd_names = ['SFD_b1', 'SFD_b1_diff', 'SFD_b2', 'SFD_b2_diff'] Gamma_sfd_names = ['SFD_g1', 'SFD_g1_diff', 'SFD_g2', 'SFD_g2_diff'] Theta_sfd_names = ['SFD_t1', 'SFD_t1_diff', 'SFD_t2', 'SFD_t2_diff'] Alpha_trt_names = ['TRT_a1', 'TRT_a1_diff', 'TRT_a2', 'TRT_a2_diff'] Beta_trt_names = ['TRT_b1', 'TRT_b1_diff', 'TRT_b2', 'TRT_b2_diff'] Gamma_trt_names = ['TRT_g1', 'TRT_g1_diff', 'TRT_g2', 'TRT_g2_diff'] Theta_trt_names = ['TRT_t1', 'TRT_t1_diff', 'TRT_t2', 'TRT_t2_diff'] # IF YOU CHANGE THOSE YOU MUST ALSO CHANGE CONSTANTS Alpha_features = Alpha_ffd_names + Alpha_gd_names + Alpha_gpt_names + Alpha_trt_names# + Alpha_sfd_names Beta_features = Beta_ffd_names + Beta_gd_names + Beta_gpt_names + Beta_trt_names# + Beta_sfd_names Gamma_features = Gamma_ffd_names + Gamma_gd_names + Gamma_gpt_names + Gamma_trt_names# + Gamma_sfd_names Theta_features = Theta_ffd_names + Theta_gd_names + Theta_gpt_names + Theta_trt_names# + Theta_sfd_names # print(Alpha_features) # GD_EEG_features def extract_all_fixations(data_container, word_data_object, float_resolution = np.float16): """ Extracts all fixations from a word data object :param data_container: (h5py) Container of the whole data, h5py object :param word_data_object: (h5py) Container of fixation objects, h5py object :param float_resolution: (type) Resolution to which data are to be converted, used for data compression :return: fixations_data (list) Data arrays representing each fixation """ word_data = data_container[word_data_object] fixations_data = [] if len(word_data.shape) > 1: for fixation_idx in range(word_data.shape[0]): fixations_data.append(np.array(data_container[word_data[fixation_idx][0]]).astype(float_resolution)) return fixations_data def 
is_real_word(word): """ Check if the word is a real word :param word: (str) word string :return: is_word (bool) True if it is a real word """ is_word = re.search('[a-zA-Z0-9]', word) is not None return is_word def load_matlab_string(matlab_extracted_object): """ Converts a string loaded from h5py into a python string :param matlab_extracted_object: (h5py) matlab string object :return: extracted_string (str) translated string """ extracted_string = u''.join(chr(c[0]) for c in matlab_extracted_object) return extracted_string def extract_word_level_data(data_container, word_objects, eeg_float_resolution = np.float16): """ Extracts word level data for a specific sentence :param data_container: (h5py) Container of the whole data, h5py object :param word_objects: (h5py) Container of all word data for a specific sentence :param eeg_float_resolution: (type) Resolution with which to save EEG, used for data compression :return: (word_level_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask) word_level_data (dict) contains all word level data indexed by their index number in the sentence; the three lists hold all word strings, the word strings that have fixations, and the word strings with '[MASK]' in place of words without fixations """ available_objects = list(word_objects) #print(available_objects) #print(len(available_objects)) # print('available_objects:', available_objects) if isinstance(available_objects[0], str): contentData = word_objects['content'] #fixations_order_per_word = [] if "rawEEG" in available_objects: rawData = word_objects['rawEEG'] etData = word_objects['rawET'] ffdData = word_objects['FFD'] gdData = word_objects['GD'] gptData = word_objects['GPT'] trtData = word_objects['TRT'] try: sfdData = word_objects['SFD'] except KeyError: print("no SFD!") sfdData = [] nFixData = word_objects['nFixations'] fixPositions = word_objects["fixPositions"] Alpha_features_data = [word_objects[feature] for feature in Alpha_features] Beta_features_data = [word_objects[feature] for feature in Beta_features] Gamma_features_data = [word_objects[feature] for feature in Gamma_features] Theta_features_data = [word_objects[feature] for 
feature in Theta_features] #### GD_EEG_features = [word_objects[feature] for feature in ['GD_t1','GD_t2','GD_a1','GD_a2','GD_b1','GD_b2','GD_g1','GD_g2']] FFD_EEG_features = [word_objects[feature] for feature in ['FFD_t1','FFD_t2','FFD_a1','FFD_a2','FFD_b1','FFD_b2','FFD_g1','FFD_g2']] TRT_EEG_features = [word_objects[feature] for feature in ['TRT_t1','TRT_t2','TRT_a1','TRT_a2','TRT_b1','TRT_b2','TRT_g1','TRT_g2']] #### assert len(contentData) == len(etData) == len(rawData), "different amounts of different data!!" zipped_data = zip(rawData, etData, contentData, ffdData, gdData, gptData, trtData, sfdData, nFixData, fixPositions) word_level_data = {} word_idx = 0 word_tokens_has_fixation = [] word_tokens_with_mask = [] word_tokens_all = [] for raw_eegs_obj, ets_obj, word_obj, ffd, gd, gpt, trt, sfd, nFix, fixPos in zipped_data: word_string = load_matlab_string(data_container[word_obj[0]]) if is_real_word(word_string): data_dict = {} data_dict["RAW_EEG"] = extract_all_fixations(data_container, raw_eegs_obj[0], eeg_float_resolution) data_dict["RAW_ET"] = extract_all_fixations(data_container, ets_obj[0], np.float32) data_dict["FFD"] = data_container[ffd[0]][()][0, 0] if len(data_container[ffd[0]][()].shape) == 2 else None data_dict["GD"] = data_container[gd[0]][()][0, 0] if len(data_container[gd[0]][()].shape) == 2 else None data_dict["GPT"] = data_container[gpt[0]][()][0, 0] if len(data_container[gpt[0]][()].shape) == 2 else None data_dict["TRT"] = data_container[trt[0]][()][0, 0] if len(data_container[trt[0]][()].shape) == 2 else None data_dict["SFD"] = data_container[sfd[0]][()][0, 0] if len(data_container[sfd[0]][()].shape) == 2 else None data_dict["nFix"] = data_container[nFix[0]][()][0, 0] if len(data_container[nFix[0]][()].shape) == 2 else None #fixations_order_per_word.append(np.array(data_container[fixPos[0]])) #print([data_container[obj[word_idx][0]][()] for obj in Alpha_features_data]) data_dict["ALPHA_EEG"] = 
np.concatenate([data_container[obj[word_idx][0]][()] if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in Alpha_features_data], 0) data_dict["BETA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()] if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in Beta_features_data], 0) data_dict["GAMMA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()] if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in Gamma_features_data], 0) data_dict["THETA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()] if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in Theta_features_data], 0) data_dict["word_idx"] = word_idx data_dict["content"] = word_string #################################### word_tokens_all.append(word_string) if data_dict["nFix"] is not None: #################################### data_dict["GD_EEG"] = [np.squeeze(data_container[obj[word_idx][0]][()]) if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in GD_EEG_features] data_dict["FFD_EEG"] = [np.squeeze(data_container[obj[word_idx][0]][()]) if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in FFD_EEG_features] data_dict["TRT_EEG"] = [np.squeeze(data_container[obj[word_idx][0]][()]) if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in TRT_EEG_features] #################################### word_tokens_has_fixation.append(word_string) word_tokens_with_mask.append(word_string) else: word_tokens_with_mask.append('[MASK]') word_level_data[word_idx] = data_dict word_idx += 1 else: print(word_string + " is not a real word.") else: # If there are no word-level data it will be word embeddings alone word_level_data = {} word_idx = 0 word_tokens_has_fixation = [] word_tokens_with_mask = [] word_tokens_all = [] for word_obj in contentData: word_string = load_matlab_string(data_container[word_obj[0]]) if is_real_word(word_string): data_dict = {} 
data_dict["RAW_EEG"] = [] data_dict["ICA_EEG"] = [] data_dict["RAW_ET"] = [] data_dict["FFD"] = None data_dict["GD"] = None data_dict["GPT"] = None data_dict["TRT"] = None data_dict["SFD"] = None data_dict["nFix"] = None data_dict["ALPHA_EEG"] = [] data_dict["BETA_EEG"] = [] data_dict["GAMMA_EEG"] = [] data_dict["THETA_EEG"] = [] data_dict["word_idx"] = word_idx data_dict["content"] = word_string word_level_data[word_idx] = data_dict word_idx += 1 else: print(word_string + " is not a real word.") sentence = " ".join([load_matlab_string(data_container[word_obj[0]]) for word_obj in word_objects['content']]) #print("Only available objects for the sentence '{}' are {}.".format(sentence, available_objects)) #word_level_data["word_reading_order"] = extract_word_order_from_fixations(fixations_order_per_word) else: word_tokens_has_fixation = [] word_tokens_with_mask = [] word_tokens_all = [] word_level_data = {} return word_level_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask ================================================ FILE: util/get_SST_ternary_dataset.py ================================================ import os import numpy as np import torch import pickle from torch.utils.data import Dataset, DataLoader import json import matplotlib.pyplot as plt from glob import glob from transformers import BartTokenizer from tqdm import tqdm from fuzzy_match import match from fuzzy_match import algorithims def get_SST_dataset(SST_dir_path, ZuCo_used_sentences, ZUCO_SENTIMENT_LABELS): def get_sentiment_label_dict(SST_dictionary_file_path): ''' return {phrase_id:sentiment_score(0-1)} ''' ret_dict = {} with open(SST_dictionary_file_path) as f: for line in f: if line.startswith('phrase'): continue else: phrase_id = int(line.split('|')[0]) label = float(line.split('|')[1].strip()) assert phrase_id not in ret_dict ret_dict[phrase_id] = label return ret_dict def get_phrasestr_phrase_dict(SST_dictionary_file_path): ''' return {phrase_str: phrase_id} ''' ret_dict 
= {} with open(SST_dictionary_file_path) as f: for line in f: phrase_str = line.split('|')[0] phrase_id = int(line.split('|')[1].strip()) assert phrase_str not in ret_dict ret_dict[phrase_str] = phrase_id return ret_dict def get_sentence_label_dict(SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path): ''' return {sentence_str:label(0-1)} ''' phraseID_2_label = get_sentiment_label_dict(SST_labels_file_path) phraseStr_2_phraseID = get_phrasestr_phrase_dict(SST_dictionary_file_path) sentence_2_label_all = {} sentence_2_label_ternary = {} with open(SST_sentences_file_path) as f: for line in f: if line.startswith('sentence_index'): continue else: parsed_line = line.split('\t') assert len(parsed_line) == 2 sentence = parsed_line[1].strip() # convert -LRB- to (, -RRB- to ): sentence = sentence.replace('-LRB-','(').replace('-RRB-',')').replace('Ã©','é') if sentence not in phraseStr_2_phraseID: # print(f'[ERROR]sentence-phrase match not found in dictionary, skipped: {sentence}') # print() continue sent_phrase_id = phraseStr_2_phraseID[sentence] label = phraseID_2_label[sent_phrase_id] # add to all dict if sentence not in sentence_2_label_all: sentence_2_label_all[sentence] = label # add to ternary dict if sentence not in sentence_2_label_ternary: if label<=0.2: label = 0 sentence_2_label_ternary[sentence] = label elif (label > 0.4) and (label<=0.6): label = 1 sentence_2_label_ternary[sentence] = label elif label>0.8: label = 2 sentence_2_label_ternary[sentence] = label return sentence_2_label_all, sentence_2_label_ternary SST_sentences_file_path = os.path.join(SST_dir_path,'datasetSentences.txt') if not os.path.isfile(SST_sentences_file_path): print(f'NOT FOUND file: {SST_sentences_file_path}') SST_labels_file_path = os.path.join(SST_dir_path,'sentiment_labels.txt') if not os.path.isfile(SST_labels_file_path): print(f'NOT FOUND file: {SST_labels_file_path}') SST_dictionary_file_path = os.path.join(SST_dir_path,'dictionary.txt') if not 
os.path.isfile(SST_dictionary_file_path): print(f'NOT FOUND file: {SST_dictionary_file_path}') sentence_2_label_all, sentence_2_label_ternary = get_sentence_label_dict(SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path) print('original ternary dataset size:', len(sentence_2_label_ternary)) ZuCo_used_sentences = list(ZUCO_SENTIMENT_LABELS) filtered_ternary_dataset = {} filtered_pairs = [] for key,value in sentence_2_label_ternary.items(): add_instance = True for used_sent in ZuCo_used_sentences: if algorithims.trigram(used_sent, key) > 0.7: # print(f'Filter match: \n\t{used_sent}\n\t{key}') # print('###########################') filtered_pairs.append((used_sent, key)) ZuCo_used_sentences.remove(used_sent) add_instance = False break if add_instance: filtered_ternary_dataset[key] = value print('filtered instance number:', len(filtered_pairs)) print('filtered ternary dataset size:', len(filtered_ternary_dataset)) print('unmatched remaining sentences:', ZuCo_used_sentences) print('unmatched remaining sentences length:', len(ZuCo_used_sentences)) with open('temp.txt','w') as temp: for matched_pair in filtered_pairs: temp.write('#######\n') temp.write('\t'+matched_pair[0]+'\n') temp.write('\t'+matched_pair[1]+'\n') temp.write('\n') with open('./dataset/stanfordsentiment/ternary_dataset.json', 'w') as out: json.dump(filtered_ternary_dataset,out, indent = 4) print('write json to /dataset/stanfordsentiment/ternary_dataset.json') if __name__ == '__main__': print('##############################') print('start generating stanfordSentimentTreebank ternary sentiment dataset...') SST_dir_path = './dataset/stanfordsentiment/stanfordSentimentTreebank' ZuCo_task1_csv_path = './dataset/ZuCo/task_materials/sentiment_labels_task1.csv' ZUCO_SENTIMENT_LABELS = json.load(open('./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json')) get_SST_dataset(SST_dir_path, ZuCo_task1_csv_path, ZUCO_SENTIMENT_LABELS) ================================================ FILE: 
util/get_sentiment_labels.py ================================================ import os import json print('##############################') print('start generating ZuCo task1-SR sentiment labels...') sentiment_labels_task1_csv_path = './dataset/ZuCo/task_materials/sentiment_labels_task1.csv' sentiment_labels = {} with open(sentiment_labels_task1_csv_path, 'r') as f: for line in f: if line.startswith('sentence_id') or line.startswith('#'): continue else: parsed_line = line.split(';') # handle edge case: if '\";' in line: sent_text = line.split('\";')[0].split('\"')[1] else: sent_text = parsed_line[1] label = int(parsed_line[-1].strip()) sentiment_labels[sent_text] = label output_dir = './dataset/ZuCo/task1-SR/sentiment_labels' if not os.path.exists(output_dir): os.makedirs(output_dir) with open(os.path.join(output_dir, 'sentiment_labels.json'), 'w') as out: json.dump(sentiment_labels,out,indent = 4) print('write to ./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json')
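# The line parsing in the loop above, including the quoted-sentence edge case, can be isolated in a small helper
# for testing. This is a minimal sketch, assuming the sentence sits in the second ';'-separated field and the
# label in the last one, as the script does. `parse_sentiment_line` is a hypothetical name that does not exist
# in the repository, and the sample lines are made up for illustration.

```python
from typing import Optional, Tuple


def parse_sentiment_line(line: str) -> Optional[Tuple[str, int]]:
    """Return (sentence_text, label), or None for header and comment lines."""
    if line.startswith('sentence_id') or line.startswith('#'):
        return None
    parsed_line = line.split(';')
    if '";' in line:
        # Quoted sentence: take the text between the first quote and the
        # closing '";' so that semicolons inside the sentence are kept.
        sent_text = line.split('";')[0].split('"')[1]
    else:
        sent_text = parsed_line[1]
    # The label is always the last field on the line.
    label = int(parsed_line[-1].strip())
    return sent_text, label


if __name__ == '__main__':
    # Made-up sample lines, not taken from the ZuCo materials.
    print(parse_sentiment_line('1;A fine film.;1\n'))
    print(parse_sentiment_line('2;"Slow; then brilliant.";0\n'))
```

# Keeping the parsing in a function makes the quoted-field branch easy to check without reading the real CSV.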