[
  {
    "path": ".gitignore",
    "content": "*.pt\n*.pickle\n*.mat\n*.json\n*.txt\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\ncsv_results/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n"
  },
  {
    "path": "README.md",
    "content": "The **main branch** contains the final code for our \"Are EEG-to-Text Models Working?\" paper. \n\nAccepted by [IJCAI workshop 2024](https://github.com/user-attachments/files/16624318/IJCAI_hyejeongjo_poster_Final.pdf)\n\nIf you have any questions, you can write them in the Issues section or email Hyejeong Jo at girlsending0@khu.ac.kr.\n\nCheck out our new paper, which gives a detailed comparison of different models on this task, at [https://arxiv.org/abs/2405.06459](https://arxiv.org/abs/2405.06459)\n\nOverview\n![image](https://github.com/NeuSpeech/EEG-To-Text/assets/151606332/57212488-b75f-44c7-a265-e2a51483e9f5)\n\nPerformance\n![image](https://github.com/NeuSpeech/EEG-To-Text/assets/151606332/df58870c-5277-4935-8c66-15efd58e9283)\n\n\n\n# Correction on [(AAAI 2022) Open Vocabulary EEG-To-Text Decoding and Zero-shot sentiment classification](https://arxiv.org/abs/2112.02690)\n# Results and code are updated on the **master** branch\n**First of all, we are not pointing fingers at anyone. This correction is not meant to offend; it is a kind reminder to be careful with the string generation process. \nWe respect Mr. Wang very much and appreciate his great contribution to this area.**\n\nAfter scrutinizing [the original code shared by Wang Zhenhailong](https://github.com/MikeWangWZHL/EEG-To-Text), we discovered that the evaluation code has an unintentional but very serious mistake in generating predicted strings: it implicitly uses teacher forcing. 
\n\nThe code that raised our concern is:\n\n\n```python\nseq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch)\nlogits = seq2seqLMoutput.logits # bs*seq_len*voc_sz\nprobs = logits[0].softmax(dim = 1)\nvalues, predictions = probs.topk(1)\npredictions = torch.squeeze(predictions)\npredicted_string = tokenizer.decode(predictions) \n```\n\nBecause `target_ids_batch` is passed into the forward call, the decoder is fed the ground-truth target tokens, so the top-1 token at each position is predicted from the correct preceding words rather than from the model's own earlier outputs. Decoding these logits is therefore teacher-forced prediction, not free-running generation.\n\nThis results in [predictions like the ones below](https://github.com/MikeWangWZHL/EEG-To-Text/blob/main/results/task1_task2_taskNRv2-BrainTranslator_skipstep1-all_generation_results-7_22.txt#L61):\n\n```\ntarget string: It isn't that Stealing Harvard is a horrible movie -- if only it were that grand a failure!\npredicted string:  was't a the. is was a bad place, it it it were a.. movie.\n################################################\n\n\ntarget string: It just doesn't have much else... especially in a moral sense.\npredicted string:  was so't work the to to and not the country sense.\n################################################\n\n\ntarget string: Those unfamiliar with Mormon traditions may find The Singles Ward occasionally bewildering.\npredicted string:  who with the history may be themselves Mormoning''s amusingering.\n################################################\n\n\ntarget string: Viewed as a comedy, a romance, a fairy tale, or a drama, there's nothing remotely triumphant about this motion picture.\npredicted string:  the from a whole, it film, and comedy tale, and a tragic, it is nothing quite romantic about it. 
picture.\n################################################\n\n\ntarget string: But the talented cast alone will keep you watching, as will the fight scenes.\npredicted string:  the most and of cannot not the entertained. 
and they the music against.\n################################################\n\n\ntarget string: It's solid and affecting and exactly as thought-provoking as it should be.\npredicted string:  was a, it, it what it.provoking as it is be.\n################################################\n\n\ntarget string: Thanks largely to Williams, all the interesting developments are processed in 60 minutes -- the rest is just an overexposed waste of film.\npredicted string:  to to the, the of films and in in in a minutes. and longest is a a afteragerposure, of time time\n################################################\n\n\ntarget string: Cantet perfectly captures the hotel lobbies, two-lane highways, and roadside cafes that permeate Vincent's days\npredicted string: urtor was describes the spirit'sies and the ofstory streets, and the parking of areate the's life.</s>'sgggggggg,,,,,,,,,,,,,,</s>,,,,,\n################################################\n\n\ntarget string: An important movie, a reminder of the power of film to move us and to make us examine our values.\npredicted string: nie part in \" classic of the importance of the, shape people, our make us think our lives,\n################################################\n\n\ntarget string: Too much of this well-acted but dangerously slow thriller feels like a preamble to a bigger, more complicated story, one that never materializes.\npredicted string:  bad of a is-known film not over- is like a film-ble to a much, more dramatic story. which that is endsizes.\n```\n\nIn addition, we noticed that several works build on this code base and report results that may be affected by the same issue. We are not condemning these researchers; we simply want to let them know, and we hope we can work together to resolve this problem. 
\n\n[BELT: Bootstrapping Electroencephalography-to-Language Decoding and Zero-Shot Sentiment Classification by Natural Language Supervision](https://arxiv.org/pdf/2309.12056)\n\n[Aligning Semantic in Brain and Language: A Curriculum Contrastive Method for Electroencephalography-to-Text Generation](https://ieeexplore.ieee.org/iel7/7333/4359219/10248031.pdf)\n\n[UniCoRN: Unified Cognitive Signal ReconstructioN bridging cognitive signals and human language](https://arxiv.org/pdf/2307.05355)\n\n[Semantic-aware Contrastive Learning for Electroencephalography-to-Text Generation with Curriculum Learning](https://arxiv.org/pdf/2301.09237)\n\n[DeWave: Discrete EEG Waves Encoding for Brain Dynamics to Text Translation](https://arxiv.org/pdf/2309.14030)\n\nWe have written a corrected version that uses `model.generate` to evaluate the model, and the results are not as good. \nBasically, we modified model_decoding.py and eval_decoding.py to add a `generate` method to the model, which was originally a plain nn.Module subclass, and used `model.generate` to predict strings.\n\n**We invite everyone to scrutinize and run this corrected code. We will then report the final performance of this model in this repo and write it up as a technical paper.**\n# We really appreciate the great contribution made by Mr. Wang; however, we should prevent others from continuing this misunderstanding. \n\n\nThis work was supported by the Culture, Sports and Tourism R&D Program through the Korea Creative Content Agency grant funded by the Ministry of Culture, Sports and Tourism (RS-2023-00226263), the Institute for Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 
RS-2024-00509257, Global AI Frontier Lab), the Information Technology Research Center (ITRC) support program (IITP-2024-RS-2024-00438239) supervised by the IITP, and the IITP grant funded by the Korea government (MSIT) (No. 
RS-2022-00155911, Artificial Intelligence Convergence Innovation Human Resources Development, Kyung Hee University).\n"
  },
  {
    "path": "config.py",
    "content": "import argparse\n\ndef str2bool(v):\n    if isinstance(v, bool):\n        return v\n    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\n\ndef get_config(case):\n    if case == 'train_decoding': \n        # args config for training EEG-To-Text decoder\n        parser = argparse.ArgumentParser(description='Specify config args for training EEG-To-Text decoder')\n        \n        parser.add_argument('-m', '--model_name', help='choose from {BrainTranslator, BrainTranslatorNaive}', default = \"BrainTranslator\" ,required=True)\n        parser.add_argument('-t', '--task_name', help='choose from {task1,task1_task2, task1_task2_task3,task1_task2_taskNRv2}', default = \"task1\", required=True)\n        \n        parser.add_argument('-1step', '--one_step', dest='skip_step_one', action='store_true')\n        parser.add_argument('-2step', '--two_step', dest='skip_step_one', action='store_false')\n\n        parser.add_argument('-pre', '--pretrained', dest='use_random_init', action='store_false')\n        parser.add_argument('-rand', '--rand_init', dest='use_random_init', action='store_true')\n        \n        parser.add_argument('-load1', '--load_step1_checkpoint', dest='load_step1_checkpoint', action='store_true')\n        parser.add_argument('-no-load1', '--not_load_step1_checkpoint', dest='load_step1_checkpoint', action='store_false')\n\n        parser.add_argument('-ne1', '--num_epoch_step1', type = int, help='num_epoch_step1', default = 20, required=True)\n        parser.add_argument('-ne2', '--num_epoch_step2', type = int, help='num_epoch_step2', default = 30, required=True)\n        parser.add_argument('-lr1', '--learning_rate_step1', type = float, help='learning_rate_step1', default = 0.00005, required=True)\n        parser.add_argument('-lr2', '--learning_rate_step2', type = 
float, help='learning_rate_step2', default = 0.0000005, required=True)\n        parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True)\n        \n        parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/decoding', required=True)\n        parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False)\n        parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False)\n        parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify frequency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False)\n        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')\n        \n        parser.add_argument('-train_input', '--train_input', help='add noise' ,required=True)\n        args = vars(parser.parse_args())\n\n    elif case == 'train_sentiment_baseline':\n        # args config for training EEG-based sentiment baselines\n        parser = argparse.ArgumentParser(description='Specify config args for training EEG-based sentiment baselines')\n        \n        parser.add_argument('-m', '--model_name', help='choose from {BaselineMLP, BaselineLSTM, NaiveFinetuneBert}', default = \"NaiveFinetuneBert\" ,required=True)\n        parser.add_argument('-ne', '--num_epoch', type = int, help='num_epoch', default = 30, required=True)\n        parser.add_argument('-lr', '--learning_rate', type = float, help='learning_rate', default = 0.00001, required=True)\n        parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True)\n        parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/eeg_sentiment', required=True)\n        parser.add_argument('-subj', '--subjects', help='use all subjects or specify a 
particular one', default = 'ALL', required=False)\n        parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False)\n        parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify frequency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False)\n        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')\n        args = vars(parser.parse_args())\n        \n    elif case == 'train_sentiment_textbased': \n        # args config for training text-based sentiment classification models\n        parser = argparse.ArgumentParser(description='Specify config args for training text-based sentiment classifiers')\n        parser.add_argument('-d', '--dataset_name', help='zero-shot setting: using external dataset from stanford sentiment treebank, pass in SST; to use ZuCo\\'s own text-sentiment pairs, pass in ZuCo', default = \"SST\" ,required=True)\n        parser.add_argument('-m', '--model_name', help='choose from {pretrain_Bert, pretrain_RoBerta, pretrain_Bart}', default = \"pretrain_Bart\" ,required=True)\n        parser.add_argument('-ne', '--num_epoch', type = int, help='num_epoch', default = 20, required=True)\n        parser.add_argument('-lr', '--learning_rate', type = float, help='learning_rate', default = 0.0001, required=True)\n        parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True)\n        parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/text_sentiment_classifier', required=True)\n        parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False)\n        parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False)\n        parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify 
frequency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False)\n        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')\n        args = vars(parser.parse_args())\n        \n    elif case == 'eval_decoding':\n        # args config for evaluating EEG-To-Text decoder\n        parser = argparse.ArgumentParser(description='Specify config args for evaluating EEG-To-Text decoder')\n        parser.add_argument('-checkpoint', '--checkpoint_path', help='specify model checkpoint' ,required=True)\n        parser.add_argument('-conf', '--config_path', help='specify training config json' ,required=True)\n        parser.add_argument('-test_input', '--test_input', help='add noise' ,required=True)\n        parser.add_argument('-train_input', '--train_input', help='add noise' ,required=True)\n        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')\n        args = vars(parser.parse_args())\n        \n    elif case == 'eval_sentiment':\n        # args config for sentiment classification models\n        parser = argparse.ArgumentParser(description='Specify config args for evaluating EEG-based sentiment classification, including Zero-shot pipeline')\n        # choose model_name = 'ZeroShotSentimentDiscovery' to evaluate Zero-shot pipeline\n        parser.add_argument('-m', '--model_name', help='choose from {BaselineMLP, BaselineLSTM, NaiveFinetuneBert, FinetunedBertOnText, FinetunedRoBertaOnText, FinetunedBartOnText, ZeroShotSentimentDiscovery}', default = \"ZeroShotSentimentDiscovery\" ,required=True)\n        parser.add_argument('-checkpoint', '--checkpoint_path', help='specify model checkpoint' ,required=False) # required if NOT evaluating Zero-shot pipeline\n        parser.add_argument('-conf', '--config_path', help='specify model config json' ,required=False) # required if NOT evaluating Zero-shot pipeline\n        
parser.add_argument('-checkpoint_DEC', '--decoder_checkpoint_path', help='specify decoder checkpoint for Zero-shot pipeline ', required=False) # required if evaluating Zero-shot pipeline\n        parser.add_argument('-checkpoint_CLS', '--classifier_checkpoint_path', help='specify classifier checkpoint for Zero-shot pipeline ', required=False) # required if evaluating Zero-shot pipeline\n        parser.add_argument('-conf_DEC', '--decoder_config_path', help='specify decoder config json' ,required=False) # required if evaluating Zero-shot pipeline\n        parser.add_argument('-conf_CLS', '--classifier_config_path', help='specify classifier config json' ,required=False) # required if evaluating Zero-shot pipeline\n        parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False)\n        parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False)\n        parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify frequency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False)\n        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')\n        args = vars(parser.parse_args())\n\n    return args"
  },
  {
    "path": "data.py",
    "content": "import os\nimport numpy as np\nimport torch\nimport pickle\nfrom torch.utils.data import Dataset, DataLoader\nimport json\nimport matplotlib.pyplot as plt\nfrom glob import glob\nfrom transformers import BartTokenizer, BertTokenizer\nfrom tqdm import tqdm\nfrom fuzzy_match import match\nfrom fuzzy_match import algorithims\nfrom transformers import T5Tokenizer\n# macro\n#ZUCO_SENTIMENT_LABELS = json.load(open('./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json'))\n#SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json'))\n\ndef normalize_1d(input_tensor):\n    # normalize a 1d tensor\n    mean = torch.mean(input_tensor)\n    std = torch.std(input_tensor)\n    input_tensor = (input_tensor - mean)/std\n    return input_tensor \n\ndef get_input_sample(sent_obj, tokenizer, eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], max_len = 56, add_CLS_token = False, test_input=\"noise\"):\n    \n    def get_word_embedding_eeg_tensor(word_obj, eeg_type, bands):\n        frequency_features = []\n\n        for band in bands:\n            frequency_features.append(word_obj['word_level_EEG'][eeg_type][eeg_type+band])\n        word_eeg_embedding = np.concatenate(frequency_features)\n\n        if len(word_eeg_embedding) != 105*len(bands):\n            print(f'expect word eeg embedding dim to be {105*len(bands)}, but got {len(word_eeg_embedding)}, return None')\n            return None\n        \n        # assert len(word_eeg_embedding) == 105*len(bands)\n        return_tensor = torch.from_numpy(word_eeg_embedding)\n        return normalize_1d(return_tensor)\n\n    def get_sent_eeg(sent_obj, bands):\n        sent_eeg_features = []\n\n        for band in bands:\n            key = 'mean'+band\n            sent_eeg_features.append(sent_obj['sentence_level_EEG'][key])\n\n        sent_eeg_embedding = np.concatenate(sent_eeg_features)\n        assert len(sent_eeg_embedding) == 105*len(bands)\n        
return_tensor = torch.from_numpy(sent_eeg_embedding)\n        return normalize_1d(return_tensor)\n\n    if sent_obj is None:\n        # print(f'  - skip bad sentence')   \n        return None\n\n    input_sample = {}\n    # get target label\n    target_string = sent_obj['content']\n    target_tokenized = tokenizer(target_string, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)\n    input_sample['target_ids'] = target_tokenized['input_ids'][0]\n    \n    # get sentence level EEG features\n    sent_level_eeg_tensor = get_sent_eeg(sent_obj, bands)\n    # try:\n    #     sent_level_eeg_tensor = torch.from_numpy(sent_obj['sentence_level_EEG']) # This gives a dictionary\n    # except:\n    #     return None\n    \n    if torch.isnan(sent_level_eeg_tensor).any():\n        # print('[NaN sent level eeg]: ', target_string)\n        return None\n    # if sent_level_eeg_tensor.shape[1] < 30:\n    #     return None\n    \n    input_sample['sent_level_EEG'] = sent_level_eeg_tensor\n    #input_sample['sent_level_EEG'] = torch.randn(sent_level_eeg_tensor.size()) # random input code\n    #print(\"NOISE:\", input_sample['sent_level_EEG'])\n\n    # get sentiment label\n    # handle some weird cases\n    if 'emp11111ty' in target_string:\n        target_string = target_string.replace('emp11111ty','empty')\n    if 'film.1' in target_string:\n        target_string = target_string.replace('film.1','film.')\n    \n    #if target_string in ZUCO_SENTIMENT_LABELS:\n    #    input_sample['sentiment_label'] = torch.tensor(ZUCO_SENTIMENT_LABELS[target_string]+1) # 0:Negative, 1:Neutral, 2:Positive\n    #else:\n    #    input_sample['sentiment_label'] = torch.tensor(-100) # dummy value\n    input_sample['sentiment_label'] = torch.tensor(-100) # dummy value\n\n    # get input embeddings\n    word_embeddings = []\n\n    \"\"\"add CLS token embedding at the front\"\"\"\n    if add_CLS_token:\n        
word_embeddings.append(torch.ones(105*len(bands)))\n\n    for word in sent_obj['word']:\n        # add each word's EEG embedding as Tensors\n        word_level_eeg_tensor = get_word_embedding_eeg_tensor(word, eeg_type, bands = bands)\n        # check none, for v2 dataset\n        if word_level_eeg_tensor is None:\n            return None\n        # check nan:\n        if torch.isnan(word_level_eeg_tensor).any():\n            # print()\n            # print('[NaN ERROR] problem sent:',sent_obj['content'])\n            # print('[NaN ERROR] problem word:',word['content'])\n            # print('[NaN ERROR] problem word feature:',word_level_eeg_tensor)\n            # print()\n            return None\n        \n        word_embeddings.append(word_level_eeg_tensor)\n\n    # pad to max_len\n    while len(word_embeddings) < max_len:\n        word_embeddings.append(torch.zeros(105*len(bands)))\n\n    if test_input=='noise':\n        rand_eeg= torch.randn(torch.stack(word_embeddings).size())\n        input_sample['input_embeddings'] = rand_eeg # max_len * (105*num_bands)\n        # print(\"rand_eeg:\", rand_eeg)\n        # print(\"input_embeddings:\", input_sample['input_embeddings'].shape)\n\n    else:\n        input_sample['input_embeddings'] = torch.stack(word_embeddings) # max_len * (105*num_bands)\n        # print(\"EEG\", input_sample['input_embeddings'])\n    \n    # mask out padding tokens\n    input_sample['input_attn_mask'] = torch.zeros(max_len) # 0 is masked out\n\n    if add_CLS_token:\n        input_sample['input_attn_mask'][:len(sent_obj['word'])+1] = torch.ones(len(sent_obj['word'])+1) # 1 is not masked\n    else:\n        input_sample['input_attn_mask'][:len(sent_obj['word'])] = torch.ones(len(sent_obj['word'])) # 1 is not masked\n    \n\n    # inverted padding mask (1 is masked out), the convention expected by PyTorch transformer layers\n    input_sample['input_attn_mask_invert'] = torch.ones(max_len) # 1 is masked out\n\n    if add_CLS_token:\n        
input_sample['input_attn_mask_invert'][:len(sent_obj['word'])+1] = torch.zeros(len(sent_obj['word'])+1) # 0 is not masked\n    else:\n        input_sample['input_attn_mask_invert'][:len(sent_obj['word'])] = torch.zeros(len(sent_obj['word'])) # 0 is not masked\n\n    # mask out target padding for computing cross entropy loss\n    input_sample['target_mask'] = target_tokenized['attention_mask'][0]\n    input_sample['seq_len'] = len(sent_obj['word'])\n    \n    # clean 0 length data\n    if input_sample['seq_len'] == 0:\n        print('discard length zero instance: ', target_string)\n        return None\n\n    return input_sample\n\nclass ZuCo_dataset(Dataset):\n    def __init__(self, input_dataset_dicts, phase, tokenizer, subject = 'ALL', eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], setting = 'unique_sent', is_add_CLS_token = False, test_input='noise'):\n        self.inputs = []\n        self.tokenizer = tokenizer\n\n        if not isinstance(input_dataset_dicts,list):\n            input_dataset_dicts = [input_dataset_dicts]\n        print(f'[INFO]loading {len(input_dataset_dicts)} task datasets')\n        for input_dataset_dict in input_dataset_dicts:\n            if subject == 'ALL':\n                subjects = list(input_dataset_dict.keys())\n                print('[INFO]using subjects: ', subjects)\n            else:\n                subjects = [subject]\n            \n            total_num_sentence = len(input_dataset_dict[subjects[0]])\n            \n            train_divider = int(0.8*total_num_sentence)\n            dev_divider = train_divider + int(0.1*total_num_sentence)\n            \n            print(f'train divider = {train_divider}')\n            print(f'dev divider = {dev_divider}')\n\n            if setting == 'unique_sent':\n                # take first 80% as trainset, 10% as dev and 10% as test\n                if phase == 'train':\n                    print('[INFO]initializing a train set...')\n                    
for key in subjects:\n                        for i in range(train_divider):\n                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)\n                            if input_sample is not None:\n                                self.inputs.append(input_sample)\n                elif phase == 'dev':\n                    print('[INFO]initializing a dev set...')\n                    for key in subjects:\n                        for i in range(train_divider,dev_divider):\n                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)\n                            if input_sample is not None:\n                                self.inputs.append(input_sample)\n                elif phase == 'test':\n                    print('[INFO]initializing a test set...')\n                    for key in subjects:\n                        for i in range(dev_divider,total_num_sentence):\n                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)\n                            if input_sample is not None:\n                                self.inputs.append(input_sample)\n            elif setting == 'unique_subj':\n                print('WARNING!!! 
only implemented for SR v1 dataset ')\n                # subject ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW'] for train\n                # subject ['ZMG'] for dev\n                # subject ['ZPH'] for test\n                if phase == 'train':\n                    print(f'[INFO]initializing a train set using {setting} setting...')\n                    for i in range(total_num_sentence):\n                        for key in ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH','ZKW']:\n                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)\n                            if input_sample is not None:\n                                self.inputs.append(input_sample)\n                if phase == 'dev':\n                    print(f'[INFO]initializing a dev set using {setting} setting...')\n                    for i in range(total_num_sentence):\n                        for key in ['ZMG']:\n                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)\n                            if input_sample is not None:\n                                self.inputs.append(input_sample)\n                if phase == 'test':\n                    print(f'[INFO]initializing a test set using {setting} setting...')\n                    for i in range(total_num_sentence):\n                        for key in ['ZPH']:\n                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)\n                            if input_sample is not None:\n                                self.inputs.append(input_sample)\n            print('++ adding task to dataset, now we have:', len(self.inputs))\n\n        print('[INFO]input tensor size:', 
self.inputs[0]['input_embeddings'].size())\n        
print()\n\n    def __len__(self):\n        return len(self.inputs)\n\n    def __getitem__(self, idx):\n        input_sample = self.inputs[idx]\n        return (\n            input_sample['input_embeddings'], \n            input_sample['seq_len'],\n            input_sample['input_attn_mask'], \n            input_sample['input_attn_mask_invert'],\n            input_sample['target_ids'], \n            input_sample['target_mask'], \n            input_sample['sentiment_label'], \n            #input_sample['sent_level_EEG']\n        )\n        # keys: input_embeddings, input_attn_mask, input_attn_mask_invert, target_ids, target_mask, \n\n\n\"\"\"for train classifier on stanford sentiment treebank text-sentiment pairs\"\"\"\nclass SST_tenary_dataset(Dataset):\n    def __init__(self, ternary_labels_dict, tokenizer, max_len = 56, balance_class = True):\n        self.inputs = []\n        \n        pos_samples = []\n        neg_samples = []\n        neu_samples = []\n\n        for key,value in ternary_labels_dict.items():\n            tokenized_inputs = tokenizer(key, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)\n            input_ids = tokenized_inputs['input_ids'][0]\n            attn_masks = tokenized_inputs['attention_mask'][0]\n            label = torch.tensor(value)\n            # count:\n            if value == 0:\n                neg_samples.append((input_ids,attn_masks,label))\n            elif value == 1:\n                neu_samples.append((input_ids,attn_masks,label))\n            elif value == 2:\n                pos_samples.append((input_ids,attn_masks,label))\n        print(f'Original distribution:\\n\\tVery positive: {len(pos_samples)}\\n\\tNeutral: {len(neu_samples)}\\n\\tVery negative: {len(neg_samples)}')    \n        if balance_class:\n            print(f'balance class to {min([len(pos_samples),len(neg_samples),len(neu_samples)])} each...')\n            for i in 
range(min([len(pos_samples),len(neg_samples),len(neu_samples)])):\n                self.inputs.append(pos_samples[i])\n                self.inputs.append(neg_samples[i])\n                self.inputs.append(neu_samples[i])\n        else:\n            self.inputs = pos_samples + neg_samples + neu_samples\n        \n    def __len__(self):\n        return len(self.inputs)\n\n    def __getitem__(self, idx):\n        input_sample = self.inputs[idx]\n        return input_sample\n        # keys: input_embeddings, input_attn_mask, input_attn_mask_invert, target_ids, target_mask, \n        \n\n\n'''sanity test'''\nif __name__ == '__main__':\n\n    check_dataset = 'stanford_sentiment'\n\n    if check_dataset == 'ZuCo':\n        whole_dataset_dicts = []\n        \n        dataset_path_task1 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task1-SR/pickle/task1-SR-dataset-with-tokens_6-25.pickle' \n        with open(dataset_path_task1, 'rb') as handle:\n            whole_dataset_dicts.append(pickle.load(handle))\n\n        dataset_path_task2 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task2-NR/pickle/task2-NR-dataset-with-tokens_7-10.pickle' \n        with open(dataset_path_task2, 'rb') as handle:\n            whole_dataset_dicts.append(pickle.load(handle))\n\n        # dataset_path_task3 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task3-TSR/pickle/task3-TSR-dataset-with-tokens_7-10.pickle' \n        # with open(dataset_path_task3, 'rb') as handle:\n        #     whole_dataset_dicts.append(pickle.load(handle))\n\n        dataset_path_task2_v2 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset-with-tokens_7-15.pickle' \n        with open(dataset_path_task2_v2, 'rb') as handle:\n            whole_dataset_dicts.append(pickle.load(handle))\n\n        print()\n        for key in whole_dataset_dicts[0]:\n            print(f'task2_v2, sentence num in {key}:',len(whole_dataset_dicts[0][key]))\n     
   print()\n\n        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')\n        dataset_setting = 'unique_sent'\n        subject_choice = 'ALL'\n        print(f'![Debug]using {subject_choice}')\n        eeg_type_choice = 'GD'\n        print(f'[INFO]eeg type {eeg_type_choice}') \n        bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] \n        print(f'[INFO]using bands {bands_choice}')\n        train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)\n        dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)\n        test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)\n\n        print('trainset size:',len(train_set))\n        print('devset size:',len(dev_set))\n        print('testset size:',len(test_set))\n\n    elif check_dataset == 'stanford_sentiment':\n        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n        SST_dataset = SST_tenary_dataset(SST_SENTIMENT_LABELS, tokenizer)\n        print('SST dataset size:',len(SST_dataset))\n        print(SST_dataset[0])\n        print(SST_dataset[1])\n"
  },
  {
    "path": "environment.yml",
    "content": "name: EEGToText\nchannels:\n  - pytorch\n  - anaconda\n  - conda-forge\n  - huggingface\ndependencies:\n  - pytorch=1.9.0\n  - torchaudio=0.9.0\n  - cudatoolkit=11.1\n  - scipy=1.6.2\n  - h5py=3.4.0\n  - tqdm=4.62.0\n  - matplotlib=3.3.2\n  - transformers=4.6.1\n  - nltk=3.5\n  - pip=21.0.1\n  - pip:\n    - fuzzy-match==0.0.1\n    - rouge==1.0.0\n"
  },
  {
    "path": "eval_decoding.py",
    "content": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\nimport pickle\nimport json\nimport matplotlib.pyplot as plt\nfrom glob import glob\nimport time\nimport copy\nfrom tqdm import tqdm\nimport torch.nn.functional as F\nfrom transformers import BertLMHeadModel, BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification, PegasusForConditionalGeneration, PegasusTokenizer, T5Tokenizer, T5ForConditionalGeneration, BertGenerationDecoder\nfrom data import ZuCo_dataset\nfrom model_decoding import BrainTranslator, BrainTranslatorNaive, T5Translator\nfrom nltk.translate.bleu_score import sentence_bleu, corpus_bleu\nfrom rouge import Rouge\nfrom config import get_config\nimport evaluate\nfrom evaluate import load\n\nmetric = evaluate.load(\"sacrebleu\")\ncer_metric = load(\"cer\")\nwer_metric = load(\"wer\")\n\ndef remove_text_after_token(text, token='</s>'):\n    # Find the given token and drop everything from it onward\n    token_index = text.find(token)\n    if token_index != -1:  # token found\n        return text[:token_index]  # return only the text before the token\n    return text  # token not found: return the original text\n\ndef eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path = './results/temp.txt' , score_results='./score_results/task.txt'):\n    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html\n    start_time = time.time()\n    model.eval()   # Set model to evaluate mode\n    \n    target_tokens_list = []\n    target_string_list = []\n    pred_tokens_list = []\n    pred_string_list = []\n    pred_tokens_list_previous = []\n    pred_string_list_previous = 
[]\n\n\n    with open(output_all_results_path,'w') as f:\n        for input_embeddings, 
seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels in tqdm(dataloaders['test']):\n            # load in batch\n            input_embeddings_batch = input_embeddings.to(device).float() # B, 56, 840\n            input_masks_batch = input_masks.to(device) # B, 56\n            target_ids_batch = target_ids.to(device) # B, 56\n            input_mask_invert_batch = input_mask_invert.to(device) # B, 56\n            \n            target_tokens = tokenizer.convert_ids_to_tokens(target_ids_batch[0].tolist(), skip_special_tokens = True)\n            target_string = tokenizer.decode(target_ids_batch[0], skip_special_tokens = True)\n            # print('target ids tensor:',target_ids_batch[0])\n            # print('target ids:',target_ids_batch[0].tolist())\n            # print('target tokens:',target_tokens)\n            # print('target string:',target_string)\n            \n            f.write(f'target string: {target_string}\\n')\n\n            # add to lists for computing BLEU later\n            target_tokens_list.append([target_tokens])\n            target_string_list.append(target_string)\n            \n            \"\"\"replace padding ids in target_ids with -100\"\"\"\n            target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100 \n\n            # target_ids_batch_label = target_ids_batch.clone().detach()\n            # target_ids_batch_label[target_ids_batch_label == tokenizer.pad_token_id] = -100\n\n            # Original code: teacher-forced forward pass, then greedy (top-1) decoding of the logits\n            seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch) # (batch, time, n_class)\n            logits_previous = seq2seqLMoutput.logits\n            probs_previous = logits_previous[0].softmax(dim = 1)\n            values_previous, predictions_previous = probs_previous.topk(1)\n            predictions_previous = torch.squeeze(predictions_previous)\n            predicted_string_previous = 
remove_text_after_token(tokenizer.decode(predictions_previous).split('</s></s>')[0].replace('<s>',''))\n            f.write(f'predicted string with tf: {predicted_string_previous}\\n')\n            predictions_previous = predictions_previous.tolist()\n            truncated_prediction_previous = []\n            for t in predictions_previous:\n                if t != tokenizer.eos_token_id:\n                    truncated_prediction_previous.append(t)\n                else:\n                    break\n            pred_tokens_previous = tokenizer.convert_ids_to_tokens(truncated_prediction_previous, skip_special_tokens = True)\n            pred_tokens_list_previous.append(pred_tokens_previous)\n            pred_string_list_previous.append(predicted_string_previous)\n            \n\n            # Modified code: autoregressive decoding with model.generate()\n            predictions=model.generate(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch,\n                                       max_length=56,\n                                       num_beams=5,\n                                       do_sample=True,\n                                       repetition_penalty= 5.0,\n                                       no_repeat_ngram_size = 2\n                                       # num_beams=5,encoder_no_repeat_ngram_size =1,\n                                       # do_sample=True, top_k=15,temperature=0.5,num_return_sequences=5,\n                                       # early_stopping=True\n                                       )\n            \n            predicted_string=tokenizer.batch_decode(predictions, skip_special_tokens=True)[0]\n            # predicted_string=predicted_string.squeeze()\n            \n            predictions=tokenizer.encode(predicted_string)\n            # print('predicted string:',predicted_string)\n            f.write(f'predicted string: {predicted_string}\\n')\n            f.write(f'################################################\\n\\n\\n')\n\n            # 
convert to int list\n            # predictions = predictions.tolist() # not needed: tokenizer.encode() already returns a list\n            truncated_prediction = []\n            for t in predictions:\n                if t != tokenizer.eos_token_id:\n                    truncated_prediction.append(t)\n                else:\n                    break\n            pred_tokens = tokenizer.convert_ids_to_tokens(truncated_prediction, skip_special_tokens = True)\n            # print('predicted tokens:',pred_tokens)\n            pred_tokens_list.append(pred_tokens)\n            pred_string_list.append(predicted_string)\n            # pred_tokens_list.extend(pred_tokens)\n            # pred_string_list.extend(predicted_string)\n            # print('################################################')\n            # print()\n    # print(f\"pred_string_list : {pred_string_list}\")\n    \n    \"\"\" calculate corpus bleu score \"\"\"\n    weights_list = [(1.0,),(0.5,0.5),(1./3.,1./3.,1./3.),(0.25,0.25,0.25,0.25)]\n    corpus_bleu_scores = []\n    corpus_bleu_scores_previous = []\n    for weight in weights_list:\n        # print('weight:',weight)\n        corpus_bleu_score = corpus_bleu(target_tokens_list, pred_tokens_list, weights = weight)\n        corpus_bleu_score_previous = corpus_bleu(target_tokens_list, pred_tokens_list_previous, weights = weight)\n        corpus_bleu_scores.append(corpus_bleu_score)\n        corpus_bleu_scores_previous.append(corpus_bleu_score_previous)\n        print(f'corpus BLEU-{len(list(weight))} score:', corpus_bleu_score)\n        print(f'corpus BLEU-{len(list(weight))} score with tf:', corpus_bleu_score_previous)\n\n\n    \"\"\" calculate SacreBLEU score \"\"\"\n    \n    reference_list = [[item] for item in target_string_list]\n\n    #print(f'ref: {reference_list}')\n    #print(f'pred: {pred_string_list}')\n    sacre_blue = metric.compute(predictions=pred_string_list, references=reference_list)\n    sacre_blue_previous = metric.compute(predictions=pred_string_list_previous, 
references=reference_list)\n    print(\"sacreBLEU score: \", sacre_blue, '\\n')\n    print(\"sacreBLEU score with tf: \", sacre_blue_previous)\n\n\n    print()\n    \"\"\" calculate rouge score \"\"\"\n    rouge = Rouge()\n    \n    # pred_string_list = [item for sublist in pred_string_list for item in sublist]\n    # pred_string_list_previous = [item for sublist in pred_string_list_previous for item in sublist]\n    # rouge_scores = rouge.get_scores(pred_string_list, target_string_list, avg = True, ignore_empty=True)\n    # rouge_scores_previous = rouge.get_scores(pred_string_list_previous, target_string_list, avg = True, ignore_empty=True)\n    # print('rouge_scores: ', rouge_scores)\n    # print('rouge_scores with tf:', rouge_scores_previous)\n\n    # rouge_scores_previous = rouge.get_scores(pred_string_list_previous, target_string_list, avg = True, ignore_empty=True)\n    # print('rouge_scores', rouge_scores)\n    # print('previous rouge_scores', rouge_scores_previous)\n\n    try:\n        rouge_scores = rouge.get_scores(pred_string_list, target_string_list, avg = True, ignore_empty=True)\n    except ValueError as e:\n        rouge_scores = 'Hypothesis is empty'\n\n    try:\n        rouge_scores_previous = rouge.get_scores(pred_string_list_previous, target_string_list, avg = True, ignore_empty=True)\n    except ValueError as e:\n        rouge_scores_previous = 'Hypothesis is empty'\n    print('ROUGE scores:', rouge_scores)\n    print('ROUGE scores with tf:', rouge_scores_previous)\n\n\n    print()\n    \"\"\" calculate WER score \"\"\"\n    #wer = WordErrorRate()\n    wer_scores = wer_metric.compute(predictions=pred_string_list, references=target_string_list)\n    wer_scores_previous = wer_metric.compute(predictions=pred_string_list_previous, references=target_string_list)\n    print(\"WER score:\", wer_scores)\n    print(\"WER score with tf:\", wer_scores_previous)\n    \n\n    \"\"\" calculate CER score \"\"\"\n    cer_scores = 
cer_metric.compute(predictions=pred_string_list, references=target_string_list)\n    cer_scores_previous = cer_metric.compute(predictions=pred_string_list_previous, references=target_string_list)\n    print(\"CER score:\", cer_scores)\n    print(\"CER score with tf:\", cer_scores_previous)\n\n\n    end_time = time.time()\n    print(f\"Evaluation took {(end_time-start_time)/60} minutes to execute.\")\n\n    # append both generation scores and teacher-forcing (tf) scores to the score_results file\n    file_content = [\n    f'corpus_bleu_score = {corpus_bleu_scores}',\n    f'sacre_blue_score = {sacre_blue}',\n    f'rouge_scores = {rouge_scores}',\n    f'wer_scores = {wer_scores}',\n    f'cer_scores = {cer_scores}',\n    f'corpus_bleu_score_with_tf = {corpus_bleu_scores_previous}',\n    f'sacre_blue_score_with_tf = {sacre_blue_previous}',\n    f'rouge_scores_with_tf = {rouge_scores_previous}',\n    f'wer_scores_with_tf = {wer_scores_previous}',\n    f'cer_scores_with_tf = {cer_scores_previous}',\n    ]\n    \n    with open(score_results, \"a\") as file_results:\n        for line in file_content:\n            if isinstance(line, list):\n                for item in line:\n                    file_results.write(str(item) + \"\\n\")\n            else:\n                file_results.write(str(line) + \"\\n\")\n\n\n\nif __name__ == '__main__': \n    batch_size = 1\n    ''' get args'''\n    args = get_config('eval_decoding')\n    test_input = args['test_input']\n    print(\"test_input is:\", test_input)\n    train_input = args['train_input']\n    print(\"train_input is:\", train_input)\n    ''' load training config'''\n    training_config = json.load(open(args['config_path']))\n\n\n    subject_choice = training_config['subjects']\n    print(f'[INFO]subjects: {subject_choice}')\n    eeg_type_choice = training_config['eeg_type']\n    print(f'[INFO]eeg type: {eeg_type_choice}')\n    bands_choice = training_config['eeg_bands']\n    print(f'[INFO]using bands: 
{bands_choice}')\n    \n    dataset_setting = 'unique_sent'\n\n    
task_name = training_config['task_name']\n    model_name = training_config['model_name']\n    \n\n    if test_input == 'EEG' and train_input=='EEG':\n        print(\"EEG and EEG\")\n        output_all_results_path = f'./results/{task_name}-{model_name}-all_decoding_results.txt'\n        score_results = f'./score_results/{task_name}-{model_name}.txt'\n    else:\n        output_all_results_path = f'./results/{task_name}-{model_name}-{train_input}_{test_input}-all_decoding_results.txt'\n        score_results = f'./score_results/{task_name}-{model_name}-{train_input}_{test_input}.txt'\n\n\n    ''' set random seeds '''\n    seed_val = 20 #500\n    np.random.seed(seed_val)\n    torch.manual_seed(seed_val)\n    torch.cuda.manual_seed_all(seed_val)\n\n    ''' set up device '''\n    # use cuda\n    if torch.cuda.is_available():  \n        dev = 0\n    else:  \n        dev = \"cpu\"\n    # CUDA_VISIBLE_DEVICES=0,1,2,3  \n    device = torch.device(dev)\n    print(f'[INFO]using device {dev}')\n\n    # task_name = 'task1_task2_task3'\n\n    ''' set up dataloader '''\n    whole_dataset_dicts = []\n    if 'task1' in task_name:\n        dataset_path_task1 = '/data/johj/ZuCo_data/task1-SR/task1_source.pkl' \n        with open(dataset_path_task1, 'rb') as handle:\n            whole_dataset_dicts.append(pickle.load(handle))\n    if 'task2' in task_name:\n        dataset_path_task2 = '/data/johj/ZuCo_data/task2-NR/task2_source.pkl' \n        with open(dataset_path_task2, 'rb') as handle:\n            whole_dataset_dicts.append(pickle.load(handle))\n    if 'task3' in task_name:\n        dataset_path_task3 = '/data/johj/ZuCo_data/task3-TSR/task3_source.pkl' \n        with open(dataset_path_task3, 'rb') as handle:\n            whole_dataset_dicts.append(pickle.load(handle))\n    if 'taskNRv2' in task_name:\n        dataset_path_taskNRv2 = '/data/johj/ZuCo_data/task2-NR-2.0/taskNRv2_source.pkl' \n        with open(dataset_path_taskNRv2, 'rb') as handle:\n            
whole_dataset_dicts.append(pickle.load(handle))\n    print()\n    \n    if model_name in ['BrainTranslator','BrainTranslatorNaive']:\n        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')\n\n    elif model_name == 'PegasusTranslator':\n        tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-xsum')\n    \n    elif model_name == 'T5Translator':\n        tokenizer = T5Tokenizer.from_pretrained(\"t5-large\")\n        # tokenizer.set_prefix_tokens(language='english')\n\n    # test dataset\n    test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, test_input=test_input)\n\n    dataset_sizes = {\"test_set\":len(test_set)}\n    print('[INFO]test_set size: ', len(test_set))\n    \n    # dataloaders\n    test_dataloader = DataLoader(test_set, batch_size = batch_size, shuffle=False, num_workers=4)\n\n    dataloaders = {'test':test_dataloader}\n\n    ''' set up model '''\n    checkpoint_path = args['checkpoint_path']\n    \n    if model_name == 'BrainTranslator':\n        pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')\n        model = BrainTranslator(pretrained_bart, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n    \n    elif model_name == 'BrainTranslatorNaive':\n        pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')\n        model = BrainTranslatorNaive(pretrained_bart, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n\n    elif model_name == 'BertGeneration':\n        pretrained = BertGenerationDecoder.from_pretrained('google-bert/bert-large-uncased', is_decoder = True)\n        model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 
1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n        \n    elif model_name == 'PegasusTranslator':\n        pretrained = PegasusForConditionalGeneration.from_pretrained('google/pegasus-xsum')\n        model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n    \n    elif model_name == 'T5Translator':\n        pretrained = T5ForConditionalGeneration.from_pretrained(\"t5-large\")\n        model = T5Translator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n    \n\n    state_dict = torch.load(checkpoint_path)\n    new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}\n    model.load_state_dict(new_state_dict)\n\n    '''\n    if isinstance(model, nn.DataParallel):\n        model.module.load_state_dict(torch.load(checkpoint_path))\n    else:\n        model.load_state_dict(torch.load(checkpoint_path))\n    '''\n\n    # model.load_state_dict(torch.load(checkpoint_path))\n    model.to(device)\n    \n    criterion = nn.CrossEntropyLoss()\n    \n    ''' eval '''\n    eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path = output_all_results_path, score_results=score_results)\n"
  },
  {
    "path": "eval_sentiment.py",
    "content": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\nfrom torch.nn.utils.rnn import pack_padded_sequence \nimport pickle\nimport json\nimport matplotlib.pyplot as plt\nfrom glob import glob\nimport time\nimport copy\nfrom tqdm import tqdm\n\nfrom transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification\nfrom data import ZuCo_dataset\nfrom model_sentiment import BaselineMLPSentence, BaselineLSTM, FineTunePretrainedTwoStep, ZeroShotSentimentDiscovery, JointBrainTranslatorSentimentClassifier\nfrom model_decoding import BrainTranslator, BrainTranslatorNaive\nfrom sklearn.metrics import precision_recall_fscore_support\nfrom sklearn.metrics import accuracy_score\nfrom config import get_config\n\n# Function to calculate the accuracy of our predictions vs labels\ndef flat_accuracy(preds, labels):\n    # preds: numpy array: N * 3 \n    # labels: numpy array: N \n    pred_flat = np.argmax(preds, axis=1).flatten()  \n    \n    labels_flat = labels.flatten()\n    \n    return np.sum(pred_flat == labels_flat) / len(labels_flat)\n\ndef flat_accuracy_top_k(preds, labels,k):\n    topk_preds = []\n    for pred in preds:\n        topk = pred.argsort()[-k:][::-1]\n        topk_preds.append(list(topk))\n    # print(topk_preds)\n    topk_preds = list(topk_preds)\n    right_count = 0\n    # print(len(labels))\n    for i in range(len(labels)):\n        l = labels[i][0]\n        if l in topk_preds[i]:\n            right_count+=1\n    return right_count/len(labels)\n\ndef eval_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')):\n\n    def 
logits2PredString(logits, tokenizer):\n        probs = logits[0].softmax(dim = 1)\n        # print('probs size:', probs.size())\n        values, predictions = probs.topk(1)\n        # print('predictions before squeeze:',predictions.size())\n        predictions = torch.squeeze(predictions)\n        predict_string = tokenizer.decode(predictions)\n        return predict_string\n    \n    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html\n    since = time.time()\n      \n    best_model_wts = copy.deepcopy(model.state_dict())\n    best_loss = 100000000000\n    best_acc = 0.0\n    \n    total_pred_labels = np.array([])\n    total_true_labels = np.array([])\n\n    for epoch in range(1):\n        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n        print('-' * 10)\n\n        # Each epoch has a training and validation phase\n        for phase in ['test']:\n            total_accuracy = 0.0\n            if phase == 'train':\n                model.train()  # Set model to training mode\n            else:\n                model.eval()   # Set model to evaluate mode\n\n            running_loss = 0.0\n\n            # Iterate over data.\n            for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in dataloaders[phase]:\n                \n                input_word_eeg_features = input_word_eeg_features.to(device).float()\n                input_masks = input_masks.to(device)\n                input_mask_invert = input_mask_invert.to(device)\n \n                sent_level_EEG = sent_level_EEG.to(device)\n                sentiment_labels = sentiment_labels.to(device)\n\n                target_ids = target_ids.to(device)\n                target_mask = target_mask.to(device)\n\n                ## forward ###################\n                if isinstance(model, BaselineMLPSentence):\n                    logits = model(sent_level_EEG) # before softmax\n             
       # calculate loss\n                    loss = criterion(logits, sentiment_labels)\n\n                elif isinstance(model, BaselineLSTM):\n                    x_packed = pack_padded_sequence(input_word_eeg_features, seq_lens, batch_first=True, enforce_sorted=False)\n                    logits = model(x_packed)\n                    # calculate loss\n                    loss = criterion(logits, sentiment_labels)\n\n                elif isinstance(model, BertForSequenceClassification) or isinstance(model, RobertaForSequenceClassification) or isinstance(model, BartForSequenceClassification):\n                    output = model(input_ids = target_ids, attention_mask = target_mask, return_dict = True, labels = sentiment_labels)\n                    logits = output.logits\n                    loss = output.loss\n                \n                elif isinstance(model, FineTunePretrainedTwoStep):\n                    output = model(input_word_eeg_features, input_masks, input_mask_invert, sentiment_labels)\n                    logits = output.logits\n                    loss = output.loss\n\n                elif isinstance(model, ZeroShotSentimentDiscovery):    \n                    print()\n                    print('target string:',tokenizer.decode(target_ids[0]).replace('<pad>','').split('</s>')[0]) \n\n                    \"\"\"replace padding ids in target_ids with -100\"\"\"\n                    target_ids[target_ids == tokenizer.pad_token_id] = -100 \n\n                    output = model(input_word_eeg_features, input_masks, input_mask_invert, target_ids, sentiment_labels)\n                    logits = output.logits\n                    loss = output.loss\n                \n                elif isinstance(model, JointBrainTranslatorSentimentClassifier):\n\n                    print()\n                    print('target string:',tokenizer.decode(target_ids[0]).replace('<pad>','').split('</s>')[0]) \n\n                    \"\"\"replace padding ids in target_ids 
with -100\"\"\"\n                    target_ids[target_ids == tokenizer.pad_token_id] = -100 \n\n                    LM_output, classification_output = model(input_word_eeg_features, input_masks, input_mask_invert, target_ids, sentiment_labels)\n                    LM_logits = LM_output.logits\n                    print('pred string:', logits2PredString(LM_logits, tokenizer).split('</s></s>')[0].replace('<s>',''))\n                    classification_loss = classification_output['loss']\n                    logits = classification_output['logits']\n                    loss = classification_loss \n                ###############################\n\n                # backward + optimize only if in training phase\n                if phase == 'train':\n                    # with torch.autograd.detect_anomaly():\n                    loss.backward()\n                    optimizer.step()\n\n                # calculate accuracy\n                preds_cpu = logits.detach().cpu().numpy()\n                label_cpu = sentiment_labels.cpu().numpy()\n\n                total_accuracy += flat_accuracy(preds_cpu, label_cpu)\n                \n                # add to total pred and label array, for cal F1, precision, recall\n                pred_flat = np.argmax(preds_cpu, axis=1).flatten()\n                labels_flat = label_cpu.flatten()\n\n                total_pred_labels = np.concatenate((total_pred_labels,pred_flat))\n                total_true_labels = np.concatenate((total_true_labels,labels_flat))\n                \n\n                # statistics\n                running_loss += loss.item() * sent_level_EEG.size()[0] # batch loss\n                # print('[DEBUG]loss:',loss.item())\n                # print('#################################')\n                \n\n            if phase == 'train':\n                scheduler.step()\n\n            epoch_loss = running_loss / dataset_sizes[phase]\n            epoch_acc = total_accuracy / len(dataloaders[phase])\n            
print('{} Loss: {:.4f}'.format(phase, epoch_loss))\n            print('{} Acc: {:.4f}'.format(phase, epoch_acc))\n\n            # keep the lowest-loss test result\n            if phase == 'test' and epoch_loss < best_loss:\n                best_loss = epoch_loss\n                best_acc = epoch_acc\n        print()\n\n    time_elapsed = time.time() - since\n    print('Evaluation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n    print('Best test loss: {:.4f}'.format(best_loss))\n    print('Best test acc: {:.4f}'.format(best_acc))\n    print()\n    print('test sample num:', len(total_pred_labels))\n    print('total preds:',total_pred_labels)\n    print('total truth:',total_true_labels)\n    print('sklearn macro: precision, recall, F1:')\n    print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='macro'))\n    print()\n    print('sklearn micro: precision, recall, F1:')\n    print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='micro'))\n    print()\n    print('sklearn accuracy:')\n    print(accuracy_score(total_true_labels,total_pred_labels))\n    print()\n\n\n\nif __name__ == '__main__':\n    args = get_config('eval_sentiment')\n\n    ''' config param'''\n    num_epochs = 1\n\n    dataset_setting = 'unique_sent'\n    \n    '''model name'''\n    # model_name = 'BaselineMLP'\n    # model_name = 'BaselineLSTM'\n    # model_name = 'NaiveFinetuneBert'\n    # model_name = 'FinetunedBertOnText'\n    # model_name = 'FinetunedRoBertaOnText'\n    # model_name = 'FinetunedBartOnText'\n    # model_name = 'ZeroShotSentimentDiscovery'\n    model_name = args['model_name']\n\n    print(f'[INFO] eval {model_name}')\n    if model_name == 'ZeroShotSentimentDiscovery':\n        '''load decoder and classifier config'''\n        config_decoder = json.load(open(args['decoder_config_path']))\n        config_classifier = json.load(open(args['classifier_config_path']))\n        '''choose generator'''\n        # 
decoder_name = 'BrainTranslator'\n        # decoder_name = 'BrainTranslatorNaive'\n        decoder_name = config_decoder['model_name']\n        decoder_checkpoint = args['decoder_checkpoint_path']\n        print(f'[INFO] using decoder: {decoder_name}')\n\n        '''choose classifier'''\n        # pretrain_Bert, pretrain_RoBerta, pretrain_Bart\n        classifier_name = config_classifier['model_name']\n        classifier_checkpoint = args['classifier_checkpoint_path']\n        print(f'[INFO] using classifier: {classifier_name}')\n    else:\n        checkpoint_path = args['checkpoint_path']\n        print('[INFO] loading baseline:', checkpoint_path)\n\n    batch_size = 1\n\n\n    # subject_choice = 'ALL'\n    subject_choice = args['subjects']\n    print(f'[INFO]subjects: {subject_choice}')\n    # eeg_type_choice = 'GD'\n    eeg_type_choice = args['eeg_type']\n    print(f'[INFO]eeg type {eeg_type_choice}')\n    # bands_choice = ['_t1'] \n    # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] \n    bands_choice = args['eeg_bands']\n    print(f'[INFO]using bands {bands_choice}')\n\n\n    \n    ''' set random seeds '''\n    seed_val = 312\n    np.random.seed(seed_val)\n    torch.manual_seed(seed_val)\n    torch.cuda.manual_seed_all(seed_val)\n\n\n    ''' set up device '''\n    # use cuda\n    if torch.cuda.is_available():  \n        dev = args['cuda']\n    else:  \n        dev = \"cpu\"\n    # CUDA_VISIBLE_DEVICES=0,1,2,3  \n    device = torch.device(dev)\n    print(f'[INFO]using device {dev}')\n\n\n    ''' load pickle'''\n    whole_dataset_dict = []\n    dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle' \n    with open(dataset_path_task1, 'rb') as handle:\n        whole_dataset_dict.append(pickle.load(handle))\n    \n    '''set up tokenizer'''\n    if model_name in ['BaselineMLP','BaselineLSTM', 'NaiveFinetuneBert', 'FinetunedBertOnText']:\n        print('[INFO]using Bert tokenizer')\n        tokenizer = 
BertTokenizer.from_pretrained('bert-base-cased')\n    elif model_name == 'FinetunedBartOnText':\n        print('[INFO]using Bart tokenizer')\n        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')\n    elif model_name == 'FinetunedRoBertaOnText':\n        print('[INFO]using RoBerta tokenizer')\n        tokenizer =  RobertaTokenizer.from_pretrained('roberta-base')\n    elif model_name == 'ZeroShotSentimentDiscovery':\n        decoder_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') # Bart\n        tokenizer = decoder_tokenizer\n        if classifier_name == 'pretrain_Bert':\n            sentiment_tokenizer = BertTokenizer.from_pretrained('bert-base-cased') # Bert\n        elif classifier_name == 'pretrain_Bart':\n            sentiment_tokenizer = decoder_tokenizer\n        elif classifier_name == 'pretrain_RoBerta':\n            sentiment_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')\n\n    ''' set up model '''\n    if model_name == 'BaselineMLP':\n        print('[INFO]Model: BaselineMLP')\n        model = BaselineMLPSentence(input_dim = 840, hidden_dim = 128, output_dim = 3)\n    elif model_name == 'BaselineLSTM':\n        print('[INFO]Model: BaselineLSTM')\n        # model = BaselineLSTM(input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 1)\n        model = BaselineLSTM(input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 4)\n    elif model_name == 'FinetunedBertOnText':\n        print('[INFO]Model: FinetunedBertOnText')\n        model = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)\n    elif model_name == 'FinetunedRoBertaOnText':\n        print('[INFO]Model: FinetunedRoBertaOnText')\n        model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)\n    elif model_name == 'FinetunedBartOnText':\n        print('[INFO]Model: FinetunedBartOnText')\n        model = BartForSequenceClassification.from_pretrained('facebook/bart-large', 
num_labels=3)\n    elif model_name == 'ZeroShotSentimentDiscovery':\n        print(f'[INFO]Model: ZeroShotSentimentDiscovery, using classifier: {classifier_name}, using generator: {decoder_name}')\n        pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')\n        if decoder_name == 'BrainTranslator':\n            decoder = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n        elif decoder_name == 'BrainTranslatorNaive':\n            decoder = BrainTranslatorNaive(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n        decoder.load_state_dict(torch.load(decoder_checkpoint))\n        \n        if classifier_name == 'pretrain_Bert':\n            classifier = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)\n        elif classifier_name == 'pretrain_Bart':\n            classifier = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels=3)\n        elif classifier_name == 'pretrain_RoBerta':\n            classifier = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)\n\n        classifier.load_state_dict(torch.load(classifier_checkpoint))\n\n        model = ZeroShotSentimentDiscovery(decoder, classifier, decoder_tokenizer, sentiment_tokenizer, device = device)\n        model.to(device)\n\n    if model_name != 'ZeroShotSentimentDiscovery':\n        # load model and send to device\n        model.load_state_dict(torch.load(checkpoint_path))\n        model.to(device)\n\n    ''' set up dataloader '''\n    # test dataset\n    test_set = ZuCo_dataset(whole_dataset_dict, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = 'unique_sent')\n\n    dataset_sizes = {'test': len(test_set)}\n    # 
print('[INFO]train_set size: ', len(train_set))\n    print('[INFO]test_set size: ', len(test_set))\n    \n    test_dataloader = DataLoader(test_set, batch_size = 1, shuffle=False, num_workers=4)\n    # dataloaders\n    dataloaders = {'test':test_dataloader}\n    \n    ''' optimizer and scheduler are not used during evaluation '''\n    optimizer_step1 = None\n    exp_lr_scheduler_step1 = None\n\n    ''' set up loss function '''\n    criterion = nn.CrossEntropyLoss()\n\n    print('=== start evaluation ... ===')\n    # evaluate the loaded model on the test set\n    model = eval_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs, tokenizer = tokenizer)\n"
  },
  {
    "path": "model_decoding.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.data\nfrom transformers import BartTokenizer, BartForConditionalGeneration, BartConfig\nimport math\nimport numpy as np\n\n\"\"\" main architecture for open vocabulary EEG-To-Text decoding\"\"\"\nclass BrainTranslator(nn.Module):\n    def __init__(self, pretrained_layers, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):\n        super(BrainTranslator, self).__init__()\n        \n        self.pretrained = pretrained_layers\n        # additional transformer encoder, following BART paper about \n        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead,  dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)\n        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)\n        \n        # print('[INFO]adding positional embedding')\n        # self.positional_embedding = PositionalEncoding(in_feature)\n\n        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)\n\n    def addin_forward(self,input_embeddings_batch,  input_masks_invert):\n        \"\"\"input_embeddings_batch: batch_size*Seq_len*840\"\"\"\n        \"\"\"input_mask: 1 is not masked, 0 is masked\"\"\"\n        \"\"\"input_masks_invert: 1 is masked, 0 is not masked\"\"\"\n\n        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)\n        # use src_key_padding_masks\n        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask=input_masks_invert)\n\n        # encoded_embedding = self.additional_encoder(input_embeddings_batch)\n        encoded_embedding = F.relu(self.fc1(encoded_embedding))\n        return encoded_embedding\n\n    @torch.no_grad()\n    def generate(\n            self,\n            input_embeddings_batch, input_masks_batch, input_masks_invert, 
target_ids_batch_converted,\n            generation_config = None,\n            logits_processor = None,\n            stopping_criteria = None,\n            prefix_allowed_tokens_fn= None,\n            synced_gpus= None,\n            assistant_model = None,\n            streamer= None,\n            negative_prompt_ids= None,\n            negative_prompt_attention_mask = None,\n            **kwargs,\n    ):\n        encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)\n        output=self.pretrained.generate(\n            inputs_embeds = encoded_embedding,\n            attention_mask = input_masks_batch[:,:encoded_embedding.shape[1]],\n            labels = target_ids_batch_converted,\n            return_dict = True,\n            generation_config=generation_config,\n            logits_processor=logits_processor,\n            stopping_criteria=stopping_criteria,\n            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n            synced_gpus=synced_gpus,\n            assistant_model=assistant_model,\n            streamer=streamer,\n            negative_prompt_ids=negative_prompt_ids,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n            **kwargs,)\n\n        return output\n\n    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):\n        encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)\n        # print(f'forward:{input_embeddings_batch.shape,input_masks_batch.shape,input_masks_invert.shape,target_ids_batch_converted.shape,encoded_embedding.shape}')\n        out = self.pretrained(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch,\n                              return_dict = True, labels = target_ids_batch_converted)\n        \n        return out\n\n\nfrom transformers import T5Tokenizer\n\"\"\" main architecture for open vocabulary EEG-To-Text decoding\"\"\"\nclass 
T5Translator(nn.Module):\n    def __init__(self, pretrained_layers, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):\n        super(T5Translator, self).__init__()\n        \n        self.pretrained = pretrained_layers\n\n        self.tokenizer = T5Tokenizer.from_pretrained(\"t5-large\")\n        \n        # additional transformer encoder, following BART paper about \n        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead,  dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)\n        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)\n        \n        # print('[INFO]adding positional embedding')\n        # self.positional_embedding = PositionalEncoding(in_feature)\n\n        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)\n\n    def addin_forward(self,input_embeddings_batch,  input_masks_invert):\n        \"\"\"input_embeddings_batch: batch_size*Seq_len*840\"\"\"\n        \"\"\"input_mask: 1 is not masked, 0 is masked\"\"\"\n        \"\"\"input_masks_invert: 1 is masked, 0 is not masked\"\"\"\n\n        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)\n        # use src_key_padding_masks\n        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask=input_masks_invert)\n\n        # encoded_embedding = self.additional_encoder(input_embeddings_batch)\n        encoded_embedding = F.relu(self.fc1(encoded_embedding))\n        return encoded_embedding\n\n    @torch.no_grad()\n    def generate(\n            self,\n            input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted,\n            generation_config = None,\n            logits_processor = None,\n            stopping_criteria = None,\n            prefix_allowed_tokens_fn= None,\n            synced_gpus= None,\n   
         assistant_model = None,\n            streamer= None,\n            negative_prompt_ids= None,\n            negative_prompt_attention_mask = None,\n            **kwargs,\n    ):\n        encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)\n\n\n        input_ids = self.tokenizer(\"transcribe in English: \", return_tensors=\"pt\").input_ids.to(encoded_embedding.device)\n        self.task_embedding = self.pretrained.shared(input_ids).to(encoded_embedding.device)\n        task_embedding = self.task_embedding.repeat(encoded_embedding.size(0), 1, 1).to(encoded_embedding.device)\n        encoded_embedding = torch.cat((task_embedding, encoded_embedding), dim=1)\n        input_masks_batch = torch.cat((torch.ones(encoded_embedding.size(0), task_embedding.size(1)).to(encoded_embedding.device), input_masks_batch), dim=1)\n\n\n        output=self.pretrained.generate(\n            inputs_embeds = encoded_embedding,\n            attention_mask = input_masks_batch[:,:encoded_embedding.shape[1]],\n            labels = target_ids_batch_converted,\n            return_dict = True,\n            generation_config=generation_config,\n            logits_processor=logits_processor,\n            stopping_criteria=stopping_criteria,\n            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n            synced_gpus=synced_gpus,\n            assistant_model=assistant_model,\n            streamer=streamer,\n            negative_prompt_ids=negative_prompt_ids,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n            **kwargs,)\n\n        return output\n\n    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):\n        encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)\n        \n        # task definition\n        input_ids = self.tokenizer(\"transcribe in English: \", return_tensors=\"pt\").input_ids.to(encoded_embedding.device)\n       
 self.task_embedding = self.pretrained.shared(input_ids).to(encoded_embedding.device)\n        task_embedding = self.task_embedding.repeat(encoded_embedding.size(0), 1, 1).to(encoded_embedding.device)\n        encoded_embedding = torch.cat((task_embedding, encoded_embedding), dim=1)\n        input_masks_batch = torch.cat((torch.ones(encoded_embedding.size(0), task_embedding.size(1)).to(encoded_embedding.device), input_masks_batch), dim=1)\n\n        out = self.pretrained(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch,\n                              return_dict = True, labels = target_ids_batch_converted)\n        return out\n\n\n\"\"\" crippled open vocabulary EEG-To-Text decoding model w/o additional MTE encoder\"\"\"\nclass BrainTranslatorNaive(nn.Module):\n    def __init__(self, pretrained_layers, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):\n        super(BrainTranslatorNaive, self).__init__()\n        '''no additional transformer encoder version'''\n        self.pretrained = pretrained_layers\n        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)\n\n    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):\n        \"\"\"input_embeddings_batch: batch_size*Seq_len*840\"\"\"\n        \"\"\"input_mask: 1 is not masked, 0 is masked\"\"\"\n        \"\"\"input_masks_invert: 1 is masked, 0 is not masked\"\"\"\n        encoded_embedding = F.relu(self.fc1(input_embeddings_batch))\n        out = self.pretrained(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted)                    \n        return out\n\n\n\"\"\" helper modules \"\"\"\n# modified from BertPooler\nclass Pooler(nn.Module):\n    def __init__(self, hidden_size):\n        super().__init__()\n        self.dense = nn.Linear(hidden_size, hidden_size)\n        
self.activation = nn.Tanh()\n\n    def forward(self, hidden_states):\n        # We \"pool\" the model by simply taking the hidden state corresponding\n        # to the first token.\n        first_token_tensor = hidden_states[:, 0]\n        pooled_output = self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\n# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html\nclass PositionalEncoding(nn.Module):\n\n    def __init__(self, d_model, dropout=0.1, max_len=5000):\n        super(PositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        # print('[DEBUG] input size:', x.size())\n        # print('[DEBUG] positional embedding size:', self.pe.size())\n        x = x + self.pe[:x.size(0), :]\n        # print('[DEBUG] output x with pe size:', x.size())\n        return self.dropout(x)\n\n\n\"\"\" Miscellaneous (not working well) \"\"\"\nclass BrainTranslatorBert(nn.Module):\n    def __init__(self, pretrained_layers, in_feature = 840, hidden_size = 768):\n        super(BrainTranslatorBert, self).__init__()\n\n        self.pretrained_Bert = pretrained_layers\n        self.fc1 = nn.Linear(in_feature, hidden_size)\n\n    def forward(self, input_embeddings_batch, input_masks_batch, target_ids_batch):\n        embedding = F.relu(self.fc1(input_embeddings_batch))\n        out = self.pretrained_Bert(inputs_embeds = embedding, attention_mask = input_masks_batch, labels = target_ids_batch, return_dict = True)\n        return 
out\n\nclass EEG2BertMapping(nn.Module):\n    def __init__(self, in_feature = 840, hidden_size = 512, out_feature = 768):\n        super(EEG2BertMapping, self).__init__()\n        self.fc1 = nn.Linear(in_feature, hidden_size)\n        self.fc2 = nn.Linear(hidden_size, out_feature)\n\n    def forward(self, x):\n        out = F.relu(self.fc1(x))\n        out = self.fc2(out)\n        return out\n\nclass ContrastiveBrainTextEncoder(nn.Module):\n    def __init__(self, pretrained_text_encoder, in_feature = 840, eeg_encoder_nhead=8, eeg_encoder_dim_feedforward = 2048, embed_dim = 768):\n        super(ContrastiveBrainTextEncoder, self).__init__()\n        # EEG Encoder\n        self.positional_embedding = PositionalEncoding(in_feature)\n        self.encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=eeg_encoder_nhead,  dim_feedforward = eeg_encoder_dim_feedforward, batch_first=True)\n        self.EEG_Encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)\n        self.EEG_pooler = Pooler(in_feature)\n        self.ln_final = nn.LayerNorm(in_feature) # to be considered\n        \n        # project to text embedding\n        self.EEG_projection = nn.Parameter(torch.empty(in_feature, embed_dim))\n        \n        # Text Encoder\n        self.TextEncoder = pretrained_text_encoder\n        \n        # learned temperature parameter\n        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n\n    def forward(self, input_EEG_features, input_EEG_attn_mask, input_ids, input_text_attention_masks):\n        # add positional embedding\n        input_EEG_features = self.positional_embedding(input_EEG_features)\n        # get EEG feature embedding\n        EEG_hiddenstates = self.EEG_Encoder(input_EEG_features,  src_key_padding_mask = input_EEG_attn_mask)\n        EEG_hiddenstates = self.ln_final(EEG_hiddenstates)\n        EEG_features = self.EEG_pooler(EEG_hiddenstates) # [N, 840]\n\n        # project to text embed size\n        
EEG_features = EEG_features @ self.EEG_projection # [N, 768]\n\n        # get text feature embedding\n        Text_features = self.TextEncoder(input_ids = input_ids, attention_mask = input_text_attention_masks, return_dict = True).pooler_output # [N, 768]\n        \n        # normalized features\n        EEG_features = EEG_features / EEG_features.norm(dim=-1, keepdim=True) # [N, 768]\n        Text_features = Text_features / Text_features.norm(dim=-1, keepdim=True) # [N, 768]\n\n        # cosine similarity as logits\n        logit_scale = self.logit_scale.exp() \n        logits_per_EEG = logit_scale * EEG_features @ Text_features.t() # [N, N]\n        logits_per_text = logit_scale * Text_features @ EEG_features.t() # [N, N]\n\n        return logits_per_EEG, logits_per_text\n"
  },
  {
    "path": "model_sentiment.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.data\nfrom transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BertForSequenceClassification\nimport math\nimport numpy as np\nfrom torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n\n\"\"\"MLP baseline using sentence level eeg\"\"\"\n# using sent level EEG, MLP baseline for sentiment\nclass BaselineMLPSentence(nn.Module):\n    def __init__(self, input_dim = 840, hidden_dim = 128, output_dim = 3):\n        super(BaselineMLPSentence, self).__init__()\n        self.fc1 = nn.Linear(input_dim, hidden_dim) \n        self.relu1 = nn.ReLU()\n        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n        self.relu2 = nn.ReLU()\n        self.fc3 = nn.Linear(hidden_dim, output_dim) # positive, negative, neutral  \n        self.dropout = nn.Dropout(0.25)\n\n    def forward(self, x):\n        out = self.fc1(x)\n        out = self.relu1(out)\n        out = self.fc2(out)\n        out = self.relu2(out)\n        out = self.dropout(out)\n        out = self.fc3(out)\n        return out\n\n\n\"\"\"bidirectional LSTM baseline using word level eeg\"\"\"\nclass BaselineLSTM(nn.Module):\n    def __init__(self, input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 1):\n        super(BaselineLSTM, self).__init__()\n        \n        self.hidden_dim = hidden_dim\n\n        # NOTE: the num_layers argument is currently ignored; the LSTM is always built with a single layer,\n        # so existing checkpoints contain one layer regardless of the value passed in\n        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers = 1, batch_first = True, bidirectional = True)\n\n        self.hidden2sentiment = nn.Linear(hidden_dim*2, output_dim)\n\n    def forward(self, x_packed):\n        # input: (N,seq_len,input_dim)\n        # print(x_packed.data.size())\n        lstm_out, _ = self.lstm(x_packed)\n        last_hidden_state = pad_packed_sequence(lstm_out, batch_first = True)[0][:,-1,:]\n        # print(last_hidden_state.size())\n        out = self.hidden2sentiment(last_hidden_state)\n        return out\n\n\"\"\" Bert Baseline: Finetuning from a 
pretrained language model Bert\"\"\"\nclass NaiveFineTunePretrainedBert(nn.Module):\n    def __init__(self, input_dim = 840, hidden_dim = 768, output_dim = 3, pretrained_checkpoint = None):\n        super(NaiveFineTunePretrainedBert, self).__init__()\n        # mapping hidden states dimensioin\n        self.fc1 = nn.Linear(input_dim, hidden_dim)\n        self.pretrained_Bert = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)\n        \n        if pretrained_checkpoint is not None:\n            self.pretrained_Bert.load_state_dict(torch.load(pretrained_checkpoint))\n\n    def forward(self, input_embeddings_batch, input_masks_batch, labels):\n        embedding = F.relu(self.fc1(input_embeddings_batch))\n        out = self.pretrained_Bert(inputs_embeds = embedding, attention_mask = input_masks_batch, labels = labels, return_dict = True)\n        return out\n\n\"\"\" Finetuning from a pretrained language model BART, two step training\"\"\"\nclass FineTunePretrainedTwoStep(nn.Module):\n    def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):\n        super(FineTunePretrainedTwoStep, self).__init__()\n        \n        self.pretrained_layers = pretrained_layers\n        # additional transformer encoder, following BART paper about \n        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead,  dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)\n        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)\n        \n        # NOTE: add positional embedding?\n        # print('[INFO]adding positional embedding')\n        # self.positional_embedding = PositionalEncoding(in_feature)\n\n        self.fc1 = nn.Linear(in_feature, d_model)\n\n    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, labels):\n        
\"\"\"input_embeddings_batch: batch_size*Seq_len*840\"\"\"\n        \"\"\"input_mask: 1 is not masked, 0 is masked\"\"\"\n        \"\"\"input_masks_invert: 1 is masked, 0 is not masked\"\"\"\n        \"\"\"labels: sentitment labels 0,1,2\"\"\"\n        \n        # NOTE: add positional embedding?\n        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch) \n\n        # use src_key_padding_masks\n        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert) \n        # encoded_embedding = self.additional_encoder(input_embeddings_batch) \n        \n        encoded_embedding = F.relu(self.fc1(encoded_embedding))\n        out = self.pretrained_layers(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = labels)                    \n        \n        return out\n\n\"\"\" Zero-shot sentiment discovery using a finetuned generation model and a sentiment model pretrained on text \"\"\"\nclass ZeroShotSentimentDiscovery(nn.Module):\n    def __init__(self, brain2text_translator, sentiment_classifier, translation_tokenizer, sentiment_tokenizer, device = 'cpu'):\n        # only for inference\n        super(ZeroShotSentimentDiscovery, self).__init__()\n        \n        self.brain2text_translator = brain2text_translator\n        self.sentiment_classifier = sentiment_classifier\n        self.translation_tokenizer = translation_tokenizer\n        self.sentiment_tokenizer = sentiment_tokenizer\n        self.device = device\n    \n\n    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):\n        \"\"\"input_embeddings_batch: batch_size*Seq_len*840\"\"\"\n        \"\"\"input_mask: 1 is not masked, 0 is masked\"\"\"\n        \"\"\"input_masks_invert: 1 is masked, 0 is not masked\"\"\"\n        \"\"\"labels: sentitment labels 0,1,2\"\"\"\n        \n        def 
logits2PredString(logits):\n            probs = logits[0].softmax(dim = 1)\n            # print('probs size:', probs.size())\n            values, predictions = probs.topk(1)\n            # print('predictions before squeeze:',predictions.size())\n            predictions = torch.squeeze(predictions)\n            predict_string = self.translation_tokenizer.decode(predictions)\n            return predict_string\n\n        # only works on batch is one\n        assert input_embeddings_batch.size()[0] == 1\n\n        seq2seqLMoutput = self.brain2text_translator(input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted)\n        predict_string = logits2PredString(seq2seqLMoutput.logits)\n        predict_string = predict_string.split('</s></s>')[0]\n        predict_string = predict_string.replace('<s>','')\n        print('predict string:', predict_string)\n        re_tokenized = self.sentiment_tokenizer(predict_string, return_tensors='pt', return_attention_mask = True)\n        input_ids = re_tokenized['input_ids'].to(self.device) # batch = 1\n        attn_mask = re_tokenized['attention_mask'].to(self.device) # batch = 1\n\n        out = self.sentiment_classifier(input_ids = input_ids, attention_mask = attn_mask, return_dict = True, labels = sentiment_labels)\n\n        return out\n\n\n\"\"\" Miscellaneous: jointly learn generation and classification (not working well) \"\"\"\nclass BartClassificationHead(nn.Module):\n    # from transformers: https://huggingface.co/transformers/_modules/transformers/models/bart/modeling_bart.html\n    \"\"\"Head for sentence-level classification tasks.\"\"\"\n    def __init__(\n        self,\n        input_dim: int,\n        inner_dim: int,\n        num_classes: int,\n        pooler_dropout: float,\n    ):\n        super().__init__()\n        self.dense = nn.Linear(input_dim, inner_dim)\n        self.dropout = nn.Dropout(p=pooler_dropout)\n        self.out_proj = nn.Linear(inner_dim, num_classes)\n\n    def 
forward(self, hidden_states: torch.Tensor):\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.dense(hidden_states)\n        hidden_states = torch.tanh(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.out_proj(hidden_states)\n        return hidden_states\n\nclass JointBrainTranslatorSentimentClassifier(nn.Module):\n    def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048, num_labels = 3):\n        super(JointBrainTranslatorSentimentClassifier, self).__init__()\n        \n        self.pretrained_generator = pretrained_layers\n        # additional transformer encoder, following BART paper about \n        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead,  dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)\n        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)\n        self.fc1 = nn.Linear(in_feature, d_model)\n        self.num_labels = num_labels\n\n        self.pooler = Pooler(d_model)\n        self.classifier = BartClassificationHead(input_dim = d_model, inner_dim = d_model, num_classes = num_labels, pooler_dropout = pretrained_layers.config.classifier_dropout)\n\n    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):\n        \"\"\"input_embeddings_batch: batch_size*Seq_len*840\"\"\"\n        \"\"\"input_mask: 1 is not masked, 0 is masked\"\"\"\n        \"\"\"input_masks_invert: 1 is masked, 0 is not masked\"\"\"\n        \n        # NOTE: add positional embedding?\n        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch) \n\n        # use src_key_padding_masks\n        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = 
input_masks_invert) \n        \n        # encoded_embedding = self.additional_encoder(input_embeddings_batch) \n        encoded_embedding = F.relu(self.fc1(encoded_embedding))\n        LMoutput = self.pretrained_generator(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted, output_hidden_states = True)                    \n        hidden_states = LMoutput.decoder_hidden_states # N, seq_len, hidden_dim\n        # print('hidden states len:', len(hidden_states))\n        last_hidden_states = hidden_states[-1]\n        # print('last hidden states size:', last_hidden_states.size())\n        sentence_representation = self.pooler(last_hidden_states)\n \n        classification_logits = self.classifier(sentence_representation) \n        loss_fct = nn.CrossEntropyLoss()\n        classification_loss = loss_fct(classification_logits.view(-1, self.num_labels), sentiment_labels.view(-1))\n        classification_output = {'loss':classification_loss,'logits':classification_logits}\n        # print('successful one forward!!!!')\n        return LMoutput, classification_output\n\n\n\"\"\" helper modules \"\"\"\n# modified from BertPooler\nclass Pooler(nn.Module):\n    def __init__(self, hidden_size):\n        super().__init__()\n        self.dense = nn.Linear(hidden_size, hidden_size)\n        self.activation = nn.Tanh()\n\n    def forward(self, hidden_states):\n        # We \"pool\" the model by simply taking the hidden state corresponding\n        # to the first token.\n        first_token_tensor = hidden_states[:, 0]\n        pooled_output = self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\n# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html\nclass PositionalEncoding(nn.Module):\n\n    def __init__(self, d_model, dropout=0.1, max_len=5000):\n        super(PositionalEncoding, self).__init__()\n        self.dropout = 
nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        # print('[DEBUG] input size:', x.size())\n        # print('[DEBUG] positional embedding size:', self.pe.size())\n        x = x + self.pe[:x.size(0), :]\n        # print('[DEBUG] output x with pe size:', x.size())\n        return self.dropout(x)\n\n"
  },
  {
    "path": "scripts/eval_decoding_1.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python3 eval_decoding.py \\\n    --checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \\\n    --config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \\\n    --test_input EEG \\\n    --train_input EEG \\\n    -cuda cuda:0\n\nCUDA_VISIBLE_DEVICES=0 python3 eval_decoding.py \\\n    --checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \\\n    --config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \\\n    --test_input noise \\\n    --train_input EEG \\\n    -cuda cuda:0\n\n"
  },
  {
    "path": "scripts/eval_decoding_2.sh",
    "content": "CUDA_VISIBLE_DEVICES=1 python3 eval_decoding.py \\\n    --checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \\\n    --config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \\\n    --test_input EEG \\\n    --train_input EEG \\\n    -cuda cuda:0\n\nCUDA_VISIBLE_DEVICES=1 python3 eval_decoding.py \\\n    --checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \\\n    --config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \\\n    --test_input noise \\\n    --train_input EEG \\\n    -cuda cuda:0\n\n"
  },
  {
    "path": "scripts/eval_decoding_3.sh",
    "content": "CUDA_VISIBLE_DEVICES=2 python3 eval_decoding.py \\\n    --checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \\\n    --config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \\\n    --test_input EEG \\\n    --train_input noise \\\n    -cuda cuda:0\n\nCUDA_VISIBLE_DEVICES=2 python3 eval_decoding.py \\\n    --checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \\\n    --config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \\\n    --test_input noise \\\n    --train_input noise \\\n    -cuda cuda:0\n\n"
  },
  {
    "path": "scripts/eval_decoding_4.sh",
    "content": "CUDA_VISIBLE_DEVICES=3 python3 eval_decoding.py \\\n    --checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \\\n    --config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \\\n    --test_input EEG \\\n    --train_input noise \\\n    -cuda cuda:0\n\nCUDA_VISIBLE_DEVICES=3 python3 eval_decoding.py \\\n    --checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \\\n    --config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \\\n    --test_input noise \\\n    --train_input noise \\\n    -cuda cuda:0\n\n"
  },
  {
    "path": "scripts/eval_sentiment_zeroshot_pipeline.sh",
    "content": "python3 eval_sentiment.py --model_name ZeroShotSentimentDiscovery \\\n    --decoder_checkpoint_path ./checkpoints/decoding/best/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.pt \\\n    --classifier_checkpoint_path ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt \\\n    --decoder_config_path ./config/decoding/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.json \\\n    --classifier_config_path ./config/text_sentiment_classifier/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.json \\\n    --cuda cuda:0"
  },
  {
    "path": "scripts/prepare_dataset.sh",
    "content": "echo \"This script constructs .pickle files from the .mat files of the ZuCo dataset.\"\necho \"This script also generates the ternary sentiment_labels.json file for ZuCo task1-SR v1.0 and ternary_dataset.json from the filtered StanfordSentimentTreebank.\"\necho \"Note: the sentences in ZuCo task1-SR do not overlap with the sentences in the filtered StanfordSentimentTreebank.\"\necho \"Note: this process can take a while, please be patient...\"\n\npython3 ./util/construct_dataset_mat_to_pickle_v1.py -t task1-SR\npython3 ./util/construct_dataset_mat_to_pickle_v1.py -t task2-NR\npython3 ./util/construct_dataset_mat_to_pickle_v1.py -t task3-TSR\npython3 ./util/construct_dataset_mat_to_pickle_v2.py\n\npython3 ./util/get_sentiment_labels.py\npython3 ./util/get_SST_ternary_dataset.py\n"
  },
  {
    "path": "scripts/train_decoding.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python3 train_decoding.py --model_name BrainTranslator \\\n    --task_name task1_task2_task3 \\\n    --one_step \\\n    --pretrained \\\n    --not_load_step1_checkpoint \\\n    --num_epoch_step1 20 \\\n    --num_epoch_step2 30 \\\n    --train_input noise \\\n    -lr1 0.00002 \\\n    -lr2 0.00002 \\\n    -b 32 \\\n    -s ./checkpoints/decoding \\\n\nCUDA_VISIBLE_DEVICES=0,1 python3 train_decoding.py --model_name T5Translator \\\n    --task_name task1_task2_task3 \\\n    --one_step \\\n    --pretrained \\\n    --not_load_step1_checkpoint \\\n    --num_epoch_step1 20 \\\n    --num_epoch_step2 30 \\\n    --train_input noise \\\n    -lr1 0.00002 \\\n    -lr2 0.00002 \\\n    -b 32 \\\n    -s ./checkpoints/decoding \\\n"
  },
  {
    "path": "scripts/train_decoding_1.sh",
    "content": "CUDA_VISIBLE_DEVICES=2,3 python3 train_decoding.py --model_name T5Translator \\\n    --task_name task1_task2_taskNRv2 \\\n    --one_step \\\n    --pretrained \\\n    --not_load_step1_checkpoint \\\n    --num_epoch_step1 20 \\\n    --num_epoch_step2 30 \\\n    --train_input EEG \\\n    -lr1 0.00002 \\\n    -lr2 0.00002 \\\n    -b 32 \\\n    -s ./checkpoints/decoding \\\n\nCUDA_VISIBLE_DEVICES=2,3 python3 train_decoding.py --model_name T5Translator \\\n    --task_name task1_task2_task3 \\\n    --one_step \\\n    --pretrained \\\n    --not_load_step1_checkpoint \\\n    --num_epoch_step1 20 \\\n    --num_epoch_step2 30 \\\n    --train_input EEG \\\n    -lr1 0.00002 \\\n    -lr2 0.00002 \\\n    -b 32 \\\n    -s ./checkpoints/decoding \\\n"
  },
  {
    "path": "scripts/train_eeg_sentiment_baseline.sh",
    "content": "python3 train_sentiment_baseline.py --model_name BaselineMLP --num_epoch 20 -lr 0.00005 -b 32 -s ./checkpoints/eeg_sentiment -cuda cuda:0"
  },
  {
    "path": "scripts/train_eval_zeroshot_pipeline.sh",
    "content": "\necho \"###################################\"\necho \"Training decoder: BART, task1-SR...\"\necho \"###################################\"\necho \"\"\npython3 train_decoding.py --model_name BrainTranslator \\\n    --task_name task1 \\\n    --one_step \\\n    --pretrained \\\n    --not_load_step1_checkpoint \\\n    --num_epoch_step1 20 \\\n    --num_epoch_step2 30 \\\n    -lr1 0.00005 \\\n    -lr2 0.0000005 \\\n    -b 32 \\\n    -s ./checkpoints/decoding \\\n    -cuda cuda:0\n\necho \"###################################\"\necho \"Training classifier: BART, filtered Stanford Sentiment Treebank...\"\necho \"###################################\"\necho \"\"\npython3 train_sentiment_textbased.py \\\n    --dataset_name SST \\\n    --model_name pretrain_Bart \\\n    --num_epoch 20 \\\n    -lr 0.0001 \\\n    -b 32 \\\n    -s ./checkpoints/text_sentiment_classifier \\\n    -cuda cuda:0\n\necho \"###################################\"\necho \"Evaluating Zero-shot pipeline: DEC(BART) + CLS(BART)\"\necho \"###################################\"\necho \"\"\npython3 eval_sentiment.py --model_name ZeroShotSentimentDiscovery \\\n    --decoder_checkpoint_path ./checkpoints/decoding/best/task1_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.pt \\\n    --classifier_checkpoint_path ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt \\\n    --decoder_config_path ./config/decoding/task1_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.json \\\n    --classifier_config_path ./config/text_sentiment_classifier/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.json \\\n    --cuda cuda:0"
  },
  {
    "path": "scripts/train_text_sentiment_classifier.sh",
    "content": "python3 train_sentiment_textbased.py \\\n    --dataset_name SST \\\n    --model_name pretrain_Bart \\\n    --num_epoch 20 \\\n    -lr 0.0001 \\\n    -b 32 \\\n    -s ./checkpoints/text_sentiment_classifier \\\n    -cuda cuda:0"
  },
  {
    "path": "train_decoding.py",
    "content": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\nimport pickle\nimport json\nimport matplotlib.pyplot as plt\nfrom glob import glob\nimport time\nimport copy\nfrom tqdm import tqdm\nfrom transformers import BertLMHeadModel, BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification, PegasusForConditionalGeneration, PegasusTokenizer, T5Tokenizer, T5ForConditionalGeneration, BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderConfig, EncoderDecoderModel\nfrom data import ZuCo_dataset\nfrom model_decoding import BrainTranslator, BrainTranslatorNaive, T5Translator\nfrom config import get_config\n\ndef train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/decoding/best/temp_decoding.pt', checkpoint_path_last = './checkpoints/decoding/last/temp_decoding.pt'):\n    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html\n    since = time.time()\n      \n    best_model_wts = copy.deepcopy(model.state_dict())\n    best_loss = 100000000000\n\n    for epoch in range(num_epochs):\n        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n        print('-' * 10)\n\n        # Each epoch has a training and validation phase\n        for phase in ['train', 'dev']:\n            if phase == 'train':\n                model.train()  # Set model to training mode\n            else:\n                model.eval()   # Set model to evaluate mode\n\n            running_loss = 0.0\n\n            # Iterate over data.\n            for input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels in tqdm(dataloaders[phase]):\n   
             \n                # load in batch\n                input_embeddings_batch = input_embeddings.to(device).float()\n                input_masks_batch = input_masks.to(device)\n                input_mask_invert_batch = input_mask_invert.to(device)\n                target_ids_batch = target_ids.to(device)\n                \"\"\"replace padding ids in target_ids with -100\"\"\"\n                target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100\n\n                # zero the parameter gradients\n                optimizer.zero_grad()\n\n                # forward\n                # track history if only in train\n                with torch.set_grad_enabled(phase == 'train'):\n                    seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch)\n\n                    \"\"\"calculate loss\"\"\"\n                    # logits = seq2seqLMoutput.logits # 8*48*50265\n                    # logits = logits.permute(0,2,1) # 8*50265*48\n\n                    # loss = criterion(logits, target_ids_batch_label) # calculate cross entropy loss only on encoded target parts\n                    # NOTE: the criterion argument is not used here\n                    loss = seq2seqLMoutput.loss # use the language modeling loss returned by the pretrained seq2seq model\n\n                    # \"\"\"check prediction, instance 0 of each batch\"\"\"\n                    # print('target size:', target_ids_batch.size(), ',original logits size:', logits.size(), ',target_mask size', target_mask_batch.size())\n                    # logits = logits.permute(0,2,1)\n                    # for idx in [0]:\n                    #     print(f'-- instance {idx} --')\n                    #     # print('permuted logits size:', logits.size())\n                    #     probs = logits[idx].softmax(dim = 1)\n                    #     # print('probs size:', probs.size())\n                    #     values, predictions = probs.topk(1)\n                    #     # print('predictions 
before squeeze:',predictions.size())\n                    #     predictions = torch.squeeze(predictions)\n                    #     # print('predictions:',predictions)\n                    #     # print('target mask:', target_mask_batch[idx])\n                    #     # print('[DEBUG]target tokens:',tokenizer.decode(target_ids_batch_copy[idx]))\n                    #     print('[DEBUG]predicted tokens:',tokenizer.decode(predictions))\n                \n                    # backward + optimize only if in training phase\n                    if phase == 'train':\n                        # with torch.autograd.detect_anomaly():\n                        loss.sum().backward()\n                        optimizer.step()\n\n                # statistics\n                running_loss += loss.sum().item() * input_embeddings_batch.size()[0] # batch loss\n                # print('[DEBUG]loss:',loss.item())\n                # print('#################################')\n                \n\n            if phase == 'train':\n                scheduler.step()\n\n            epoch_loss = running_loss / dataset_sizes[phase]\n\n            print('{} Loss: {:.4f}'.format(phase, epoch_loss))\n\n            # deep copy the model\n            if phase == 'dev' and epoch_loss < best_loss:\n                best_loss = epoch_loss\n                best_model_wts = copy.deepcopy(model.state_dict())\n                '''save checkpoint'''\n                torch.save(model.state_dict(), checkpoint_path_best)\n                print(f'update best on dev checkpoint: {checkpoint_path_best}')\n                # with torch.set_grad_enabled(False):\n                #     traced_model_1 = torch.jit.trace(model, (torch.rand(1, 56, 840).to(device), torch.randint(1, 56).to(device), torch.rand(1, 56).to(device), torch.rand(1, 56).to(device)))\n                #     traced_model_32 = torch.jit.trace(model, (torch.rand(32, 56, 840).to(device), torch.randint(32, 56).to(device), torch.rand(32, 56).to(device), 
torch.rand(32, 56).to(device)))\n                # torch.jit.save(traced_model_1, checkpoint_path_best[:-3]+'_1_jit.pt')\n                # torch.jit.save(traced_model_32, checkpoint_path_best[:-3]+'_32_jit.pt')\n        print()\n\n    time_elapsed = time.time() - since\n    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n    print('Best val loss: {:4f}'.format(best_loss))\n    torch.save(model.state_dict(), checkpoint_path_last)\n    print(f'update last checkpoint: {checkpoint_path_last}')\n\n    # load best model weights\n    model.load_state_dict(best_model_wts)\n    return model\n\ndef show_require_grad_layers(model):\n    print()\n    print(' require_grad layers:')\n    # sanity check\n    for name, param in model.named_parameters():\n        if param.requires_grad:\n            print(' ', name)\n\nif __name__ == '__main__':\n    args = get_config('train_decoding')\n\n    ''' config param'''\n    dataset_setting = 'unique_sent'\n    \n    num_epochs_step1 = args['num_epoch_step1']\n    num_epochs_step2 = args['num_epoch_step2']\n    step1_lr = args['learning_rate_step1']\n    step2_lr = args['learning_rate_step2']\n    \n    batch_size = args['batch_size']\n    \n    model_name = args['model_name']\n    # model_name = 'BrainTranslatorNaive' # with no additional transformers\n    # model_name = 'BrainTranslator' \n    \n    # task_name = 'task1'\n    # task_name = 'task1_task2'\n    # task_name = 'task1_task2_task3'\n    # task_name = 'task1_task2_taskNRv2'\n    task_name = args['task_name']\n    train_input = args['train_input']\n    print(\"train_input is:\", train_input)   \n    save_path = args['save_path']\n    if not os.path.exists(save_path):\n        os.makedirs(save_path)\n\n    skip_step_one = args['skip_step_one']\n    load_step1_checkpoint = args['load_step1_checkpoint']\n    use_random_init = args['use_random_init']\n    device_ids = [0] # device setting\n\n    if use_random_init and skip_step_one:\n     
   step2_lr = 5*1e-4\n        \n    print(f'[INFO]using model: {model_name}')\n    \n    if skip_step_one:\n        save_name = f'{task_name}_finetune_{model_name}_skipstep1_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}_{train_input}'\n    else:\n        save_name = f'{task_name}_finetune_{model_name}_2steptraining_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}_{train_input}'\n    \n    if use_random_init:\n        save_name = 'randinit_' + save_name\n\n    save_path_best = os.path.join(save_path, 'best')\n    if not os.path.exists(save_path_best):\n        os.makedirs(save_path_best)\n\n    output_checkpoint_name_best = os.path.join(save_path_best, f'{save_name}.pt')\n\n    save_path_last = os.path.join(save_path, 'last')\n    if not os.path.exists(save_path_last):\n        os.makedirs(save_path_last)\n\n    output_checkpoint_name_last = os.path.join(save_path_last, f'{save_name}.pt')\n\n    # subject_choice = 'ALL'\n    subject_choice = args['subjects']\n    print(f'[INFO]using subjects {subject_choice}')\n    # eeg_type_choice = 'GD'\n    eeg_type_choice = args['eeg_type']\n    print(f'[INFO]eeg type {eeg_type_choice}')\n    # bands_choice = ['_t1'] \n    # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] \n    bands_choice = args['eeg_bands']\n    print(f'[INFO]using bands {bands_choice}')\n\n\n    \n    ''' set random seeds '''\n    seed_val = 312\n    np.random.seed(seed_val)\n    torch.manual_seed(seed_val)\n    torch.cuda.manual_seed_all(seed_val)\n\n\n    ''' set up device '''\n    # use cuda\n    if torch.cuda.is_available():  \n        # dev = \"cuda:3\" \n        dev = args['cuda'] \n    else:  \n        dev = \"cpu\"\n    # CUDA_VISIBLE_DEVICES=0,1,2,3  \n    device = torch.device(dev)\n    print(f'[INFO]using device {dev}')\n    print()\n\n    ''' set up dataloader '''\n    whole_dataset_dicts = []\n    if 'task1' in task_name:\n        dataset_path_task1 
= '/data/johj/ZuCo_data/task1-SR/task1_source.pkl'\n        with open(dataset_path_task1, 'rb') as handle:\n            whole_dataset_dicts.append(pickle.load(handle))\n    if 'task2' in task_name:\n        dataset_path_task2 = '/data/johj/ZuCo_data/task2-NR/task2_source.pkl' \n        with open(dataset_path_task2, 'rb') as handle:\n            whole_dataset_dicts.append(pickle.load(handle))\n    if 'task3' in task_name:\n        dataset_path_task3 = '/data/johj/ZuCo_data/task3-TSR/task3_source.pkl' \n        with open(dataset_path_task3, 'rb') as handle:\n            whole_dataset_dicts.append(pickle.load(handle))\n    if 'taskNRv2' in task_name:\n        dataset_path_taskNRv2 = '/data/johj/ZuCo_data/task2-NR-2.0/taskNRv2_source.pkl' \n        with open(dataset_path_taskNRv2, 'rb') as handle:\n            whole_dataset_dicts.append(pickle.load(handle))\n\n    print()\n\n    \"\"\"save config\"\"\"\n    cfg_dir = './config/decoding/'\n\n    if not os.path.exists(cfg_dir):\n        os.makedirs(cfg_dir)\n\n    with open(os.path.join(cfg_dir,f'{save_name}.json'), 'w') as out_config:\n        json.dump(args, out_config, indent = 4)\n\n    if model_name in ['BrainTranslator','BrainTranslatorNaive']:\n        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')\n\n    elif model_name == 'PegasusTranslator':\n        tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-xsum')\n    \n    elif model_name == 'T5Translator':\n        tokenizer = T5Tokenizer.from_pretrained(\"t5-large\")\n        #tokenizer.set_prefix_tokens(language='english')\n\n    # train dataset\n    train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, test_input=train_input)\n    # dev dataset\n    dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, 
test_input=train_input)\n    # test dataset\n    # test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)\n\n    \n    dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}\n    print('[INFO]train_set size: ', len(train_set))\n    print('[INFO]dev_set size: ', len(dev_set))\n    # print('[INFO]test_set size: ', len(test_set))\n    \n    # train dataloader\n    train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4)\n    # dev dataloader\n    val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4)\n    # dataloaders\n    dataloaders = {'train':train_dataloader, 'dev':val_dataloader}\n\n    ''' set up model '''\n    if model_name == 'BrainTranslator':\n        if use_random_init:\n            config = BartConfig.from_pretrained('facebook/bart-large')\n            pretrained = BartForConditionalGeneration(config)\n        else:\n            pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')\n    \n        model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n    \n    elif model_name == 'BrainTranslatorNaive':\n        pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')\n        model = BrainTranslatorNaive(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n\n    elif model_name == 'PegasusTranslator':\n        pretrained = PegasusForConditionalGeneration.from_pretrained('google/pegasus-xsum')\n        model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n    \n    elif model_name == 
'T5Translator':\n        pretrained = T5ForConditionalGeneration.from_pretrained(\"t5-large\")\n        model = T5Translator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)\n    \n    model.to(device)\n    model = torch.nn.DataParallel(model, device_ids=device_ids)\n    \n    ''' training loop '''\n\n    ######################################################\n    '''step one training: freeze most of BART params'''\n    ######################################################\n\n    # closely follow BART paper\n    if model_name in ['BrainTranslator','BrainTranslatorNaive', 'PegasusTranslator', 'T5Translator']:\n        for name, param in model.named_parameters():\n            if param.requires_grad and 'pretrained' in name:\n                if ('shared' in name) or ('embed_positions' in name) or ('encoder.layers.0' in name):\n                    continue\n                else:\n                    param.requires_grad = False\n\n    elif model_name == 'BertGeneration':\n        for name, param in model.named_parameters():\n            if param.requires_grad and 'pretrained' in name:\n                if ('embeddings' in name) or ('encoder.layer.0' in name):\n                    continue\n                else:\n                    param.requires_grad = False\n \n\n    if skip_step_one:\n        if load_step1_checkpoint:\n            stepone_checkpoint = 'path_to_step_1_checkpoint.pt'\n            print(f'skip step one, load checkpoint: {stepone_checkpoint}')\n            model.load_state_dict(torch.load(stepone_checkpoint))\n        else:\n            print('skip step one, start from scratch at step two')\n    else:\n\n        ''' set up optimizer and scheduler'''\n        optimizer_step1 = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=step1_lr, momentum=0.9)\n\n        exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=20, 
gamma=0.1)\n\n        ''' set up loss function '''\n        criterion = nn.CrossEntropyLoss()\n\n        print('=== start Step1 training ... ===')\n        # print training layers\n        show_require_grad_layers(model)\n        # return best loss model from step1 training\n        model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs_step1, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)\n\n    ######################################################\n    '''step two training: update the whole model for a few iterations'''\n    ######################################################\n    for name, param in model.named_parameters():\n        param.requires_grad = True\n\n    ''' set up optimizer and scheduler'''\n    optimizer_step2 = optim.SGD(model.parameters(), lr=step2_lr, momentum=0.9)\n\n    exp_lr_scheduler_step2 = lr_scheduler.StepLR(optimizer_step2, step_size=30, gamma=0.1)\n\n    ''' set up loss function '''\n    criterion = nn.CrossEntropyLoss()\n    \n    print()\n    print('=== start Step2 training ... ===')\n    # print training layers\n    show_require_grad_layers(model)\n    \n    '''main loop'''\n    trained_model = train_model(dataloaders, device, model, criterion, optimizer_step2, exp_lr_scheduler_step2, num_epochs=num_epochs_step2, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)\n\n    # '''save checkpoint'''\n    # torch.save(trained_model.state_dict(), os.path.join(save_path,output_checkpoint_name))\n"
  },
  {
    "path": "train_sentiment_baseline.py",
    "content": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\nfrom torch.nn.utils.rnn import pack_padded_sequence \nimport pickle\nimport json\nimport matplotlib.pyplot as plt\nfrom glob import glob\nimport time\nimport copy\nfrom tqdm import tqdm\n\nfrom transformers import BertTokenizer, BertLMHeadModel, BertConfig\nfrom data import ZuCo_dataset\nfrom model_sentiment import BaselineMLPSentence, BaselineLSTM, NaiveFineTunePretrainedBert\nfrom config import get_config\n# Function to calculate the accuracy of our predictions vs labels\ndef flat_accuracy(preds, labels):\n    # preds: numpy array: N * 3 \n    # labels: numpy array: N \n    pred_flat = np.argmax(preds, axis=1).flatten()  \n    \n    labels_flat = labels.flatten()\n    \n    return np.sum(pred_flat == labels_flat) / len(labels_flat)\n\ndef flat_accuracy_top_k(preds, labels,k):\n    topk_preds = []\n    for pred in preds:\n        topk = pred.argsort()[-k:][::-1]\n        topk_preds.append(list(topk))\n    # print(topk_preds)\n    topk_preds = list(topk_preds)\n    right_count = 0\n    # print(len(labels))\n    for i in range(len(labels)):\n        l = labels[i][0]\n        if l in topk_preds[i]:\n            right_count+=1\n    return right_count/len(labels)\n\ndef train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/eeg_sentiment/best/test.pt', checkpoint_path_last = './checkpoints/eeg_sentiment/last/test.pt'):\n    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html\n    since = time.time()\n      \n    best_model_wts = copy.deepcopy(model.state_dict())\n    best_loss = 100000000000\n    best_acc = 0.0\n    \n\n    for epoch in range(num_epochs):\n        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n        print('-' 
* 10)\n\n        # Each epoch has a training and validation phase\n        for phase in ['train', 'dev']:\n            total_accuracy = 0.0\n            if phase == 'train':\n                model.train()  # Set model to training mode\n            else:\n                model.eval()   # Set model to evaluate mode\n\n            running_loss = 0.0\n\n            # Iterate over data.\n            for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in tqdm(dataloaders[phase]):\n                \n                input_word_eeg_features = input_word_eeg_features.to(device).float()\n                sent_level_EEG = sent_level_EEG.to(device)\n                input_masks = input_masks.to(device)\n                sentiment_labels = sentiment_labels.to(device)\n\n                # zero the parameter gradients\n                optimizer.zero_grad()\n\n                if isinstance(model, BaselineMLPSentence):\n                    # forward\n                    logits = model(sent_level_EEG) # before softmax\n                    # calculate loss\n                    loss = criterion(logits, sentiment_labels)\n                \n                elif isinstance(model, BaselineLSTM):\n                    x_packed = pack_padded_sequence(input_word_eeg_features, seq_lens, batch_first=True, enforce_sorted=False)\n                    logits = model(x_packed)\n                    # calculate loss\n                    loss = criterion(logits, sentiment_labels)\n                \n                elif isinstance(model, NaiveFineTunePretrainedBert):\n                    output = model(input_word_eeg_features, input_masks, sentiment_labels)\n                    logits = output.logits\n                    loss = output.loss\n\n\n                # backward + optimize only if in training phase\n                if phase == 'train':\n                    # with torch.autograd.detect_anomaly():\n                    
loss.backward()\n                    optimizer.step()\n\n                # calculate accuracy\n                preds_cpu = logits.detach().cpu().numpy()\n                label_cpu = sentiment_labels.cpu().numpy()\n\n                total_accuracy += flat_accuracy(preds_cpu, label_cpu)\n\n                # statistics\n                running_loss += loss.item() * sent_level_EEG.size()[0] # batch loss\n                # print('[DEBUG]loss:',loss.item())\n                # print('#################################')\n                \n\n            if phase == 'train':\n                scheduler.step()\n\n            epoch_loss = running_loss / dataset_sizes[phase]\n            epoch_acc = total_accuracy / len(dataloaders[phase])\n            print('{} Loss: {:.4f}'.format(phase, epoch_loss))\n            print('{} Acc: {:.4f}'.format(phase, epoch_acc))\n\n            # deep copy the model\n            if phase == 'dev' and (epoch_acc > best_acc):\n                best_loss = epoch_loss\n                best_acc = epoch_acc\n                best_model_wts = copy.deepcopy(model.state_dict())\n                '''save checkpoint'''\n                torch.save(model.state_dict(), checkpoint_path_best)\n                print(f'update best on dev checkpoint: {checkpoint_path_best}')\n        print()\n\n    time_elapsed = time.time() - since\n    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n    print('Best val loss: {:4f}'.format(best_loss))\n    print('Best val acc: {:4f}'.format(best_acc))\n    torch.save(model.state_dict(), checkpoint_path_last)\n    print(f'update last checkpoint: {checkpoint_path_last}')\n\n\n    # load best model weights\n    model.load_state_dict(best_model_wts)\n    return model\n\nif __name__ == '__main__':\n    args = get_config('train_sentiment_baseline')\n    \n    ''' config param'''\n    num_epochs = args['num_epoch']\n    step_lr = args['learning_rate']\n\n    '''dataset division'''\n    
dataset_setting = 'unique_sent'\n\n    # subject_choice = 'ALL\n    subject_choice = args['subjects']\n    print(f'![Debug]using {subject_choice}')\n    # eeg_type_choice = 'GD\n    eeg_type_choice = args['eeg_type']\n    print(f'[INFO]eeg type {eeg_type_choice}')\n    # bands_choice = ['_t1'] \n    # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] \n    bands_choice = args['eeg_bands']\n    print(f'[INFO]using bands {bands_choice}')\n    \n    '''model name'''\n    # model_name = 'BaselineMLP'\n    # model_name = 'BaselineLSTM'\n    # model_name = 'NaiveFinetuneBert'\n    model_name = args['model_name']\n\n    batch_size = 32\n    save_path = args['save_path']\n    save_name = f'{model_name}_{step_lr}_b{batch_size}_{dataset_setting}_{eeg_type_choice}'\n\n    if model_name == 'BaselineLSTM':\n        num_layers = 4\n        save_name = f'{model_name}_numLayers-{num_layers}_{step_lr}_b{batch_size}_{dataset_setting}_{eeg_type_choice}'\n\n    output_checkpoint_name_best = save_path + f'/best/{save_name}.pt' \n    output_checkpoint_name_last = save_path + f'/last/{save_name}.pt' \n\n\n    \n    ''' set random seeds '''\n    seed_val = 312\n    np.random.seed(seed_val)\n    torch.manual_seed(seed_val)\n    torch.cuda.manual_seed_all(seed_val)\n\n\n    ''' set up device '''\n    # use cuda\n    if torch.cuda.is_available():  \n        dev = args['cuda']\n    else:  \n        dev = \"cpu\"\n    # CUDA_VISIBLE_DEVICES=0,1,2,3  \n    device = torch.device(dev)\n    print(f'[INFO]using device {dev}')\n\n\n    ''' load pickle'''\n    whole_dataset_dict = []\n    dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle' \n    with open(dataset_path_task1, 'rb') as handle:\n        whole_dataset_dict.append(pickle.load(handle))\n    \n    '''set up tokenizer'''\n    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n\n    ''' set up dataloader '''\n    # train dataset\n    train_set = ZuCo_dataset(whole_dataset_dict, 'train', 
tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)\n    # dev dataset\n    dev_set = ZuCo_dataset(whole_dataset_dict, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)\n    # test dataset\n    # test_set = ZuCo_dataset(whole_dataset_dict, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice)\n\n    dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}\n    print('[INFO]train_set size: ', len(train_set))\n    print('[INFO]dev_set size: ', len(dev_set))\n    \n    # train dataloader\n    train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4)\n    # dev dataloader\n    val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4)\n    # dataloaders\n    dataloaders = {'train':train_dataloader, 'dev':val_dataloader}\n\n    ''' set up model '''\n    if model_name == 'BaselineMLP':\n        print('[INFO]Model: BaselineMLP')\n        model = BaselineMLPSentence(input_dim = 105*len(bands_choice), hidden_dim = 128, output_dim = 3)\n    elif model_name == 'BaselineLSTM':\n        print('[INFO]Model: BaselineLSTM')\n        model = BaselineLSTM(input_dim = 105*len(bands_choice), hidden_dim = 256, output_dim = 3, num_layers = num_layers)\n    elif model_name == 'NaiveFinetuneBert':\n        print('[INFO]Model: NaiveFinetuneBert')\n        model = NaiveFineTunePretrainedBert(input_dim = 105*len(bands_choice), hidden_dim = 768, output_dim = 3)\n    \n    model.to(device)\n    \n\n    \"\"\"save config\"\"\"\n    with open(f'./config/eeg_sentiment/{save_name}.json', 'w') as out_config:\n        json.dump(args, out_config, indent = 4)\n    \n    \n    ''' training loop '''\n\n    ''' set up optimizer and scheduler'''\n    optimizer_step1 = optim.SGD(model.parameters(), lr=step_lr, momentum=0.9)\n    exp_lr_scheduler_step1 = 
lr_scheduler.StepLR(optimizer_step1, step_size=20, gamma=0.5)\n\n    ''' set up loss function '''\n    criterion = nn.CrossEntropyLoss()\n\n    print('=== start training ... ===')\n    # return best loss model from step1 training\n    model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)\n"
  },
  {
    "path": "train_sentiment_textbased.py",
    "content": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, random_split\nimport pickle\nimport json\nimport matplotlib.pyplot as plt\nfrom glob import glob\nimport time\nimport copy\nfrom tqdm import tqdm\n\nfrom transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification\nfrom data import ZuCo_dataset, SST_tenary_dataset\nfrom model_sentiment import FineTunePretrainedTwoStep\nfrom config import get_config\n# Function to calculate the accuracy of our predictions vs labels\ndef flat_accuracy(preds, labels):\n    # preds: numpy array: N * 3 \n    # labels: numpy array: N \n    pred_flat = np.argmax(preds, axis=1).flatten()  \n    \n    labels_flat = labels.flatten()\n    \n    return np.sum(pred_flat == labels_flat) / len(labels_flat)\n\ndef flat_accuracy_top_k(preds, labels,k):\n    topk_preds = []\n    for pred in preds:\n        topk = pred.argsort()[-k:][::-1]\n        topk_preds.append(list(topk))\n    # print(topk_preds)\n    topk_preds = list(topk_preds)\n    right_count = 0\n    # print(len(labels))\n    for i in range(len(labels)):\n        l = labels[i][0]\n        if l in topk_preds[i]:\n            right_count+=1\n    return right_count/len(labels)\n\ndef train_model_ZuCo(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/text_sentiment_classifier/best/test.pt', checkpoint_path_last = './checkpoints/text_sentiment_classifier/last/test.pt'):\n    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html\n    since = time.time()\n      \n    best_model_wts = copy.deepcopy(model.state_dict())\n    best_loss = 100000000000\n    best_acc = 
0.0\n    \n\n    for epoch in range(num_epochs):\n        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n        print('-' * 10)\n\n        # Each epoch has a training and validation phase\n        for phase in ['train', 'dev']:\n            total_accuracy = 0.0\n            if phase == 'train':\n                model.train()  # Set model to training mode\n            else:\n                model.eval()   # Set model to evaluate mode\n\n            running_loss = 0.0\n\n            # Iterate over data.\n            for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in tqdm(dataloaders[phase]):\n                \n                # input_word_eeg_features = input_word_eeg_features.to(device).float()\n                # input_masks = input_masks.to(device)\n                # input_mask_invert = input_mask_invert.to(device)\n                target_ids = target_ids.to(device)\n                target_mask = target_mask.to(device)\n                sentiment_labels = sentiment_labels.to(device)\n\n                # zero the parameter gradients\n                optimizer.zero_grad()\n\n                # forward\n                output = model(input_ids = target_ids, attention_mask = target_mask, return_dict = True, labels = sentiment_labels)\n                logits = output.logits\n                loss = output.loss\n\n                # backward + optimize only if in training phase\n                if phase == 'train':\n                    # with torch.autograd.detect_anomaly():\n                    loss.backward()\n                    optimizer.step()\n\n                # calculate accuracy\n                preds_cpu = logits.detach().cpu().numpy()\n                label_cpu = sentiment_labels.cpu().numpy()\n\n                total_accuracy += flat_accuracy(preds_cpu, label_cpu)\n\n                # statistics\n                running_loss += loss.item() * sent_level_EEG.size()[0] # batch 
loss\n                # print('[DEBUG]loss:',loss.item())\n                # print('#################################')\n                \n\n            if phase == 'train':\n                scheduler.step()\n\n            epoch_loss = running_loss / dataset_sizes[phase]\n            epoch_acc = total_accuracy / len(dataloaders[phase])\n            print('{} Loss: {:.4f}'.format(phase, epoch_loss))\n            print('{} Acc: {:.4f}'.format(phase, epoch_acc))\n\n            # deep copy the model\n            if phase == 'dev' and (epoch_acc > best_acc):\n                best_loss = epoch_loss\n                best_acc = epoch_acc\n                best_model_wts = copy.deepcopy(model.state_dict())\n                '''save checkpoint'''\n                torch.save(model.state_dict(), checkpoint_path_best)\n                print(f'update best on dev checkpoint: {checkpoint_path_best}')\n        print()\n\n    time_elapsed = time.time() - since\n    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n    print('Best val loss: {:.4f}'.format(best_loss))\n    print('Best val acc: {:.4f}'.format(best_acc))\n    torch.save(model.state_dict(), checkpoint_path_last)\n    print(f'update last checkpoint: {checkpoint_path_last}')\n    \n    # write to log: a .log file next to the last checkpoint\n    output_log_file_name = os.path.splitext(checkpoint_path_last)[0] + '.log'\n    with open(output_log_file_name, 'w') as outlog:\n        outlog.write(f'best val loss: {best_loss}\\n')\n        outlog.write('Best val acc: {:.4f}'.format(best_acc))\n    # load best model weights\n    model.load_state_dict(best_model_wts)\n    return model\n\ndef train_model_SST(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/text_sentiment_classifier/best/test.pt', checkpoint_path_last = './checkpoints/text_sentiment_classifier/last/test.pt'):\n    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html\n    since = time.time()\n      \n    best_model_wts = 
copy.deepcopy(model.state_dict())\n    best_loss = 100000000000\n    best_acc = 0.0\n    \n\n    for epoch in range(num_epochs):\n        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n        print('-' * 10)\n\n        # Each epoch has a training and validation phase\n        for phase in ['train', 'dev']:\n            total_accuracy = 0.0\n            if phase == 'train':\n                model.train()  # Set model to training mode\n            else:\n                model.eval()   # Set model to evaluate mode\n\n            running_loss = 0.0\n\n            # Iterate over data.\n            for input_ids,input_masks,sentiment_labels in tqdm(dataloaders[phase]):\n                \n                input_ids = input_ids.to(device)\n                input_masks = input_masks.to(device)\n                sentiment_labels = sentiment_labels.to(device)\n\n                # zero the parameter gradients\n                optimizer.zero_grad()\n\n                # forward\n                output = model(input_ids = input_ids, attention_mask = input_masks, return_dict = True, labels = sentiment_labels)\n                logits = output.logits\n                loss = output.loss\n\n                # backward + optimize only if in training phase\n                if phase == 'train':\n                    # with torch.autograd.detect_anomaly():\n                    loss.backward()\n                    optimizer.step()\n\n                # calculate accuracy\n                preds_cpu = logits.detach().cpu().numpy()\n                label_cpu = sentiment_labels.cpu().numpy()\n\n                total_accuracy += flat_accuracy(preds_cpu, label_cpu)\n\n                # statistics\n                running_loss += loss.item() * input_ids.size()[0] # batch loss\n                # print('[DEBUG]loss:',loss.item())\n                # print('#################################')\n                \n\n            if phase == 'train':\n                scheduler.step()\n\n            
epoch_loss = running_loss / dataset_sizes[phase]\n            epoch_acc = total_accuracy / len(dataloaders[phase])\n            print('{} Loss: {:.4f}'.format(phase, epoch_loss))\n            print('{} Acc: {:.4f}'.format(phase, epoch_acc))\n\n            # deep copy the model\n            if phase == 'dev' and (epoch_acc > best_acc):\n                best_loss = epoch_loss\n                best_acc = epoch_acc\n                best_model_wts = copy.deepcopy(model.state_dict())\n                '''save checkpoint'''\n                torch.save(model.state_dict(), checkpoint_path_best)\n                print(f'update best on dev checkpoint: {checkpoint_path_best}')\n        print()\n\n    time_elapsed = time.time() - since\n    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n    print('Best val loss: {:.4f}'.format(best_loss))\n    print('Best val acc: {:.4f}'.format(best_acc))\n    torch.save(model.state_dict(), checkpoint_path_last)\n    print(f'update last checkpoint: {checkpoint_path_last}')\n    \n    # load best model weights\n    model.load_state_dict(best_model_wts)\n    return model\n\n\nif __name__ == '__main__':\n    args = get_config('train_sentiment_textbased')\n\n    ''' config param'''\n\n    num_epoch = args['num_epoch']\n    # lr = 1e-3 # Bert, RoBerta\n    # lr = 1e-4 # Bart\n    lr = args['learning_rate']\n\n    dataset_name = args['dataset_name'] # zero-shot setting: pass in 'SST' to train on the external Stanford Sentiment Treebank, or pass in 'ZuCo' to train on ZuCo's text-sentiment pairs\n\n    dataset_setting = 'unique_sent'\n\n    batch_size = args['batch_size']\n    \n    # model_name = 'pretrain_Bert'\n    # model_name = 'pretrain_RoBerta'\n    # model_name = 'pretrain_Bart'\n    model_name = args['model_name']\n    print(f'[INFO]model name: {model_name}')\n\n    save_path = args['save_path'] \n\n    if dataset_name == 'ZuCo':\n        # subject_choice = 'ALL\n        subject_choice = 
args['subjects']\n        print(f'![Debug]using {subject_choice}')\n        # eeg_type_choice = 'GD\n        eeg_type_choice = args['eeg_type']\n        print(f'[INFO]eeg type {eeg_type_choice}')\n        # bands_choice = ['_t1'] \n        # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] \n        bands_choice = args['eeg_bands']\n        print(f'[INFO]using bands {bands_choice}')\n        save_name = f'Textbased_ZuCo_{model_name}_b{batch_size}_{num_epoch}_{lr}_{dataset_setting}_{eeg_type_choice}'\n    elif dataset_name == 'SST':\n        save_name = f'Textbased_StanfordSentitmentTreeband_{model_name}_b{batch_size}_{num_epoch}_{lr}'\n\n    output_checkpoint_name_best = save_path + f'/best/{save_name}.pt' \n    output_checkpoint_name_last = save_path + f'/last/{save_name}.pt' \n\n\n    ''' set random seeds '''\n    seed_val = 312\n    np.random.seed(seed_val)\n    torch.manual_seed(seed_val)\n    torch.cuda.manual_seed_all(seed_val)\n\n    ''' set up device '''\n    # use cuda\n    if torch.cuda.is_available():  \n        dev = args['cuda']\n    else:  \n        dev = \"cpu\"\n    # CUDA_VISIBLE_DEVICES=0,1,2,3  \n    device = torch.device(dev)\n    print(f'[INFO]using device {dev}')\n\n\n    ''' load pickle '''\n    if dataset_name == 'ZuCo':\n        whole_dataset_dict = []\n        dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle' \n        with open(dataset_path_task1, 'rb') as handle:\n            whole_dataset_dict.append(pickle.load(handle))\n    \n    '''tokenizer'''\n    if model_name == 'pretrain_Bert':\n        print('[INFO]pretrained checkpoint: bert-base-cased')\n        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n    elif model_name == 'pretrain_RoBerta':\n        print('[INFO]pretrained checkpoint: roberta-base')\n        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')\n    elif model_name == 'pretrain_Bart':\n        print('[INFO]pretrained checkpoint: bart-large')\n        
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')\n\n    ''' set up dataloader '''\n    if dataset_name == 'ZuCo':\n        # train dataset\n        train_set = ZuCo_dataset(whole_dataset_dict, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)\n        # dev dataset\n        dev_set = ZuCo_dataset(whole_dataset_dict, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)\n    \n    elif dataset_name == 'SST':\n        SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json'))\n\n        SST_dataset = SST_tenary_dataset(SST_SENTIMENT_LABELS, tokenizer)  \n        \n        train_size = int(0.9 * len(SST_dataset))\n        val_size = len(SST_dataset) - train_size\n\n        train_set, dev_set = random_split(SST_dataset, [train_size, val_size])\n        print('{:>5,} training samples'.format(len(train_set)))\n        print('{:>5,} validation samples'.format(len(dev_set)))\n\n\n    dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}\n    print('[INFO]train_set size: ', len(train_set))\n    print('[INFO]dev_set size: ', len(dev_set))\n    \n    # train dataloader\n    train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4)\n    # dev dataloader\n    val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4)\n    # dataloaders\n    dataloaders = {'train':train_dataloader, 'dev':val_dataloader}\n\n    ''' set up model '''\n    if model_name == 'pretrain_Bert':\n        model = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)\n    elif model_name == 'pretrain_RoBerta':\n        model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)\n    elif model_name == 'pretrain_Bart':\n        model = BartForSequenceClassification.from_pretrained('facebook/bart-large', 
num_labels = 3)\n    \n    model.to(device)\n    \n\n    \"\"\"save config\"\"\"\n    with open(f'./config/text_sentiment_classifier/{save_name}.json', 'w') as out_config:\n        json.dump(args, out_config, indent = 4)\n\n\n    ''' training loop '''\n    ######################################################\n    '''step one training: all model parameters are updated (none are frozen)'''\n    ######################################################\n\n    ''' set up optimizer and scheduler'''\n    optimizer_step1 = optim.SGD(model.parameters(), lr=lr, momentum=0.9)\n\n    exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=10, gamma=0.1)\n\n    # TODO: rethink about the loss function\n    # NOTE: train_model_ZuCo/train_model_SST do not call criterion; they use output.loss, which the model computes from the labels\n    ''' set up loss function '''\n    criterion = nn.CrossEntropyLoss()\n\n    # return best loss model from step1 training\n    print(f'=== start training {dataset_name} ... ===')\n    if dataset_name == 'ZuCo':\n        model = train_model_ZuCo(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epoch, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)\n    elif dataset_name == 'SST':\n        model = train_model_SST(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epoch, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)\n        "
  },
  {
    "path": "util/construct_dataset_mat_to_pickle_v1.py",
    "content": "import scipy.io as io\nimport h5py\nimport os\nimport json\nfrom glob import glob\nfrom tqdm import tqdm\nimport numpy as np\nimport pickle\nimport argparse\n\nparser = argparse.ArgumentParser(description='Specify task name for converting ZuCo v1.0 Mat file to Pickle')\nparser.add_argument('-t', '--task_name', help='name of the task in /dataset/ZuCo, choose from {task1-SR,task2-NR,task3-TSR}', required=True)\nargs = vars(parser.parse_args())\n\n\n\"\"\"config\"\"\"\nversion = 'v1' # 'old'\n# version = 'v2' # 'new'\n\ntask_name = args['task_name']\n# task_name = 'task1-SR'\n# task_name = 'task2-NR'\n# task_name = 'task3-TSR'\n\n\nprint('##############################')\nprint(f'start processing ZuCo {task_name}...')\n\n\nif version == 'v1':\n    # old version \n    input_mat_files_dir = f'./dataset/ZuCo/{task_name}/Matlab_files' \nelif version == 'v2':\n    # new version, mat73 \n    input_mat_files_dir = f'./dataset/ZuCo/{task_name}/Matlab_files' \n\noutput_dir = f'./dataset/ZuCo/{task_name}/pickle'\nif not os.path.exists(output_dir):\n    os.makedirs(output_dir)\n\n\"\"\"load files\"\"\"\nmat_files = glob(os.path.join(input_mat_files_dir,'*.mat'))\nmat_files = sorted(mat_files)\n\nif len(mat_files) == 0:\n    print(f'No mat files found for {task_name}')\n    quit()\n\ndataset_dict = {}\nfor mat_file in tqdm(mat_files):\n    subject_name = os.path.basename(mat_file).split('_')[0].replace('results','').strip()\n    dataset_dict[subject_name] = []\n    \n    if version == 'v1':\n        matdata = io.loadmat(mat_file, squeeze_me=True, struct_as_record=False)['sentenceData']\n    elif version == 'v2':\n        matdata = h5py.File(mat_file,'r')\n        print(matdata)\n\n    for sent in matdata:\n        word_data = sent.word\n        if not isinstance(word_data, float):\n            # sentence level:\n            sent_obj = {'content':sent.content}\n            sent_obj['sentence_level_EEG'] = {'mean_t1':sent.mean_t1, 'mean_t2':sent.mean_t2, 
'mean_a1':sent.mean_a1, 'mean_a2':sent.mean_a2, 'mean_b1':sent.mean_b1, 'mean_b2':sent.mean_b2, 'mean_g1':sent.mean_g1, 'mean_g2':sent.mean_g2}\n\n            if task_name == 'task1-SR':\n                sent_obj['answer_EEG'] = {'answer_mean_t1':sent.answer_mean_t1, 'answer_mean_t2':sent.answer_mean_t2, 'answer_mean_a1':sent.answer_mean_a1, 'answer_mean_a2':sent.answer_mean_a2, 'answer_mean_b1':sent.answer_mean_b1, 'answer_mean_b2':sent.answer_mean_b2, 'answer_mean_g1':sent.answer_mean_g1, 'answer_mean_g2':sent.answer_mean_g2}\n            \n            # word level:\n            sent_obj['word'] = []\n            \n            word_tokens_has_fixation = [] \n            word_tokens_with_mask = []\n            word_tokens_all = []\n\n            for word in word_data:\n                word_obj = {'content':word.content}\n                word_tokens_all.append(word.content)\n                # TODO: add more version of word level eeg: GD, SFD, GPT\n                word_obj['nFixations'] = word.nFixations\n                if word.nFixations > 0:    \n                    word_obj['word_level_EEG'] = {'FFD':{'FFD_t1':word.FFD_t1, 'FFD_t2':word.FFD_t2, 'FFD_a1':word.FFD_a1, 'FFD_a2':word.FFD_a2, 'FFD_b1':word.FFD_b1, 'FFD_b2':word.FFD_b2, 'FFD_g1':word.FFD_g1, 'FFD_g2':word.FFD_g2}}\n                    word_obj['word_level_EEG']['TRT'] = {'TRT_t1':word.TRT_t1, 'TRT_t2':word.TRT_t2, 'TRT_a1':word.TRT_a1, 'TRT_a2':word.TRT_a2, 'TRT_b1':word.TRT_b1, 'TRT_b2':word.TRT_b2, 'TRT_g1':word.TRT_g1, 'TRT_g2':word.TRT_g2}\n                    word_obj['word_level_EEG']['GD'] = {'GD_t1':word.GD_t1, 'GD_t2':word.GD_t2, 'GD_a1':word.GD_a1, 'GD_a2':word.GD_a2, 'GD_b1':word.GD_b1, 'GD_b2':word.GD_b2, 'GD_g1':word.GD_g1, 'GD_g2':word.GD_g2}\n                    sent_obj['word'].append(word_obj)\n                    word_tokens_has_fixation.append(word.content)\n                    word_tokens_with_mask.append(word.content)\n                else:\n                    
word_tokens_with_mask.append('[MASK]')\n                    # if a word has no fixation, use sentence level feature\n                    # word_obj['word_level_EEG'] = {'FFD':{'FFD_t1':sent.mean_t1, 'FFD_t2':sent.mean_t2, 'FFD_a1':sent.mean_a1, 'FFD_a2':sent.mean_a2, 'FFD_b1':sent.mean_b1, 'FFD_b2':sent.mean_b2, 'FFD_g1':sent.mean_g1, 'FFD_g2':sent.mean_g2}}\n                    # word_obj['word_level_EEG']['TRT'] = {'TRT_t1':sent.mean_t1, 'TRT_t2':sent.mean_t2, 'TRT_a1':sent.mean_a1, 'TRT_a2':sent.mean_a2, 'TRT_b1':sent.mean_b1, 'TRT_b2':sent.mean_b2, 'TRT_g1':sent.mean_g1, 'TRT_g2':sent.mean_g2}\n                    \n                    # NOTE: if a word has no fixation, simply skip it\n                    continue\n            \n            sent_obj['word_tokens_has_fixation'] = word_tokens_has_fixation\n            sent_obj['word_tokens_with_mask'] = word_tokens_with_mask\n            sent_obj['word_tokens_all'] = word_tokens_all\n            \n            dataset_dict[subject_name].append(sent_obj)\n\n        else:\n            print(f'missing sent: subj:{subject_name} content:{sent.content}, append None')\n            dataset_dict[subject_name].append(None)\n\n            continue\n    # print(dataset_dict.keys())\n    # print(dataset_dict[subject_name][0].keys())\n    # print(dataset_dict[subject_name][0]['content'])\n    # print(dataset_dict[subject_name][0]['word'][0].keys())\n    # print(dataset_dict[subject_name][0]['word'][0]['word_level_EEG']['FFD'])\n\n\"\"\"output\"\"\"\noutput_name = f'{task_name}-dataset.pickle'\n# with open(os.path.join(output_dir,'task1-SR-dataset.json'), 'w') as out:\n#     json.dump(dataset_dict,out,indent = 4)\n\nwith open(os.path.join(output_dir,output_name), 'wb') as handle:\n    pickle.dump(dataset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)\n    print('write to:', os.path.join(output_dir,output_name))\n\n\n\"\"\"sanity check\"\"\"\n# check dataset\nwith open(os.path.join(output_dir,output_name), 'rb') as handle:\n    
whole_dataset = pickle.load(handle)\nprint('subjects:', whole_dataset.keys())\n\nif version == 'v1':\n    print('num of sent:', len(whole_dataset['ZAB']))\n    print()\n\n\n"
  },
  {
    "path": "util/construct_dataset_mat_to_pickle_v2.py",
    "content": "import os\nimport numpy as np\nimport h5py\nimport data_loading_helpers_modified as dh\nfrom glob import glob\nfrom tqdm import tqdm\nimport pickle\n\n\ntask = \"NR\"\n\nrootdir = \"./dataset/ZuCo/task2-NR-2.0/Matlab_files/\"\n\nprint('##############################')\nprint(f'start processing ZuCo task2-NR-2.0...')\n\ndataset_dict = {}\n\nfor file in tqdm(os.listdir(rootdir)):\n    if file.endswith(task+\".mat\"):\n\n        file_name = rootdir + file\n\n        # print('file name:', file_name)\n        subject = file_name.split(\"ts\")[1].split(\"_\")[0]\n        # print('subject: ', subject)\n\n        # exclude YMH due to incomplete data because of dyslexia\n        if subject != 'YMH':\n            assert subject not in dataset_dict\n            dataset_dict[subject] = []\n\n            f = h5py.File(file_name,'r')\n            # print('keys in f:', list(f.keys()))\n            sentence_data = f['sentenceData']\n            # print('keys in sentence_data:', list(sentence_data.keys()))\n            \n            # sent level eeg \n            # mean_t1 = np.squeeze(f[sentence_data['mean_t1'][0][0]][()])\n            mean_t1_objs = sentence_data['mean_t1']\n            mean_t2_objs = sentence_data['mean_t2']\n            mean_a1_objs = sentence_data['mean_a1']\n            mean_a2_objs = sentence_data['mean_a2']\n            mean_b1_objs = sentence_data['mean_b1']\n            mean_b2_objs = sentence_data['mean_b2']\n            mean_g1_objs = sentence_data['mean_g1']\n            mean_g2_objs = sentence_data['mean_g2']\n            \n            rawData = sentence_data['rawData']\n            contentData = sentence_data['content']\n            # print('contentData shape:', contentData.shape, 'dtype:', contentData.dtype)\n            omissionR = sentence_data['omissionRate']\n            wordData = sentence_data['word']\n\n\n            for idx in range(len(rawData)):\n                # get sentence string\n                obj_reference_content = 
contentData[idx][0]\n                sent_string = dh.load_matlab_string(f[obj_reference_content])\n                # print('sentence string:', sent_string)\n                \n                sent_obj = {'content':sent_string}\n                \n                # get sentence level EEG\n                sent_obj['sentence_level_EEG'] = {\n                    'mean_t1':np.squeeze(f[mean_t1_objs[idx][0]][()]), \n                    'mean_t2':np.squeeze(f[mean_t2_objs[idx][0]][()]), \n                    'mean_a1':np.squeeze(f[mean_a1_objs[idx][0]][()]), \n                    'mean_a2':np.squeeze(f[mean_a2_objs[idx][0]][()]), \n                    'mean_b1':np.squeeze(f[mean_b1_objs[idx][0]][()]), \n                    'mean_b2':np.squeeze(f[mean_b2_objs[idx][0]][()]), \n                    'mean_g1':np.squeeze(f[mean_g1_objs[idx][0]][()]), \n                    'mean_g2':np.squeeze(f[mean_g2_objs[idx][0]][()])\n                }\n                # print(sent_obj)\n                sent_obj['word'] = []\n\n                # get word level data\n                word_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask = dh.extract_word_level_data(f, f[wordData[idx][0]])\n                \n                if word_data == {}:\n                    print(f'missing sent: subj:{subject} content:{sent_string}, append None')\n                    dataset_dict[subject].append(None)\n                    continue\n                elif len(word_tokens_all) == 0:\n                    print(f'no word level features: subj:{subject} content:{sent_string}, append None')\n                    dataset_dict[subject].append(None)\n                    continue\n\n                else:                    \n                    for widx in range(len(word_data)):\n                        data_dict = word_data[widx]\n                        word_obj = {'content':data_dict['content'], 'nFixations': data_dict['nFix']}\n                        if 'GD_EEG' in data_dict:\n               
             # print('has fixation: ', data_dict['content'])\n                            gd = data_dict[\"GD_EEG\"]\n                            ffd = data_dict[\"FFD_EEG\"]\n                            trt = data_dict[\"TRT_EEG\"]\n                            assert len(gd) == len(trt) == len(ffd) == 8\n                            word_obj['word_level_EEG'] = {\n                                'GD':{'GD_t1':gd[0], 'GD_t2':gd[1], 'GD_a1':gd[2], 'GD_a2':gd[3], 'GD_b1':gd[4], 'GD_b2':gd[5], 'GD_g1':gd[6], 'GD_g2':gd[7]},\n                                'FFD':{'FFD_t1':ffd[0], 'FFD_t2':ffd[1], 'FFD_a1':ffd[2], 'FFD_a2':ffd[3], 'FFD_b1':ffd[4], 'FFD_b2':ffd[5], 'FFD_g1':ffd[6], 'FFD_g2':ffd[7]},\n                                'TRT':{'TRT_t1':trt[0], 'TRT_t2':trt[1], 'TRT_a1':trt[2], 'TRT_a2':trt[3], 'TRT_b1':trt[4], 'TRT_b2':trt[5], 'TRT_g1':trt[6], 'TRT_g2':trt[7]}\n                            }\n                            sent_obj['word'].append(word_obj)\n                        \n                    sent_obj['word_tokens_has_fixation'] = word_tokens_has_fixation\n                    sent_obj['word_tokens_with_mask'] = word_tokens_with_mask\n                    sent_obj['word_tokens_all'] = word_tokens_all     \n                    \n                    # print(sent_obj.keys())\n                    # print(len(sent_obj['word']))\n                    # print(sent_obj['word'][0])\n\n                    dataset_dict[subject].append(sent_obj)\n\n\"\"\"output\"\"\"\ntask_name = 'task2-NR-2.0'\n\nif dataset_dict == {}:\n    print(f'No mat file found for {task_name}')\n    quit()\n\noutput_dir = f'./dataset/ZuCo/{task_name}/pickle'\nif not os.path.exists(output_dir):\n    os.makedirs(output_dir)\n\noutput_name = f'{task_name}-dataset.pickle'\n# with open(os.path.join(output_dir,'task1-SR-dataset.json'), 'w') as out:\n#     json.dump(dataset_dict,out,indent = 4)\n\nwith open(os.path.join(output_dir,output_name), 'wb') as handle:\n    pickle.dump(dataset_dict, handle, 
protocol=pickle.HIGHEST_PROTOCOL)\n    print('write to:', os.path.join(output_dir,output_name))\n\n\"\"\"sanity check\"\"\"\nprint('subjects:', dataset_dict.keys())\nprint('num of sent:', len(dataset_dict['YAC']))"
  },
  {
    "path": "util/data_loading_helpers_modified.py",
    "content": "import numpy as np\nimport re\n\neeg_float_resolution=np.float16\n\nAlpha_ffd_names = ['FFD_a1', 'FFD_a1_diff', 'FFD_a2', 'FFD_a2_diff']\nBeta_ffd_names = ['FFD_b1', 'FFD_b1_diff', 'FFD_b2', 'FFD_b2_diff']\nGamma_ffd_names = ['FFD_g1', 'FFD_g1_diff', 'FFD_g2', 'FFD_g2_diff']\nTheta_ffd_names = ['FFD_t1', 'FFD_t1_diff', 'FFD_t2', 'FFD_t2_diff']\nAlpha_gd_names = ['GD_a1', 'GD_a1_diff', 'GD_a2', 'GD_a2_diff']\nBeta_gd_names = ['GD_b1', 'GD_b1_diff', 'GD_b2', 'GD_b2_diff']\nGamma_gd_names = ['GD_g1', 'GD_g1_diff', 'GD_g2', 'GD_g2_diff']\nTheta_gd_names = ['GD_t1', 'GD_t1_diff', 'GD_t2', 'GD_t2_diff']\nAlpha_gpt_names = ['GPT_a1', 'GPT_a1_diff', 'GPT_a2', 'GPT_a2_diff']\nBeta_gpt_names = ['GPT_b1', 'GPT_b1_diff', 'GPT_b2', 'GPT_b2_diff']\nGamma_gpt_names = ['GPT_g1', 'GPT_g1_diff', 'GPT_g2', 'GPT_g2_diff']\nTheta_gpt_names = ['GPT_t1', 'GPT_t1_diff', 'GPT_t2', 'GPT_t2_diff']\nAlpha_sfd_names = ['SFD_a1', 'SFD_a1_diff', 'SFD_a2', 'SFD_a2_diff']\nBeta_sfd_names = ['SFD_b1', 'SFD_b1_diff', 'SFD_b2', 'SFD_b2_diff']\nGamma_sfd_names = ['SFD_g1', 'SFD_g1_diff', 'SFD_g2', 'SFD_g2_diff']\nTheta_sfd_names = ['SFD_t1', 'SFD_t1_diff', 'SFD_t2', 'SFD_t2_diff']\nAlpha_trt_names = ['TRT_a1', 'TRT_a1_diff', 'TRT_a2', 'TRT_a2_diff']\nBeta_trt_names = ['TRT_b1', 'TRT_b1_diff', 'TRT_b2', 'TRT_b2_diff']\nGamma_trt_names = ['TRT_g1', 'TRT_g1_diff', 'TRT_g2', 'TRT_g2_diff']\nTheta_trt_names = ['TRT_t1', 'TRT_t1_diff', 'TRT_t2', 'TRT_t2_diff']\n\n# IF YOU CHANGE THOSE YOU MUST ALSO CHANGE CONSTANTS\nAlpha_features = Alpha_ffd_names + Alpha_gd_names + Alpha_gpt_names + Alpha_trt_names# + Alpha_sfd_names\nBeta_features = Beta_ffd_names + Beta_gd_names + Beta_gpt_names + Beta_trt_names# + Beta_sfd_names\nGamma_features = Gamma_ffd_names + Gamma_gd_names + Gamma_gpt_names + Gamma_trt_names# + Gamma_sfd_names\nTheta_features = Theta_ffd_names + Theta_gd_names + Theta_gpt_names + Theta_trt_names# + Theta_sfd_names\n# print(Alpha_features)\n\n# GD_EEG_features\n\n\ndef 
extract_all_fixations(data_container, word_data_object, float_resolution = np.float16):\n    \"\"\"\n    Extracts all fixations from a word data object\n    :param data_container:      (h5py)  Container of the whole data, h5py object\n    :param word_data_object:    (h5py)  Container of fixation objects, h5py object\n    :param float_resolution:    (type)  Resolution to which data are to be converted, used for data compression\n    :return:\n        fixations_data  (list)  Data arrays representing each fixation\n    \"\"\"\n    word_data = data_container[word_data_object]\n    fixations_data = []\n    if len(word_data.shape) > 1:\n        for fixation_idx in range(word_data.shape[0]):\n            fixations_data.append(np.array(data_container[word_data[fixation_idx][0]]).astype(float_resolution))\n    return fixations_data\n\n\ndef is_real_word(word):\n    \"\"\"\n    Check if the word is a real word, i.e. contains at least one alphanumeric character\n    :param word:    (str)   word string\n    :return:\n        is_word (bool)  True if it is a real word\n    \"\"\"\n    # re.search returns a Match object or None; convert it to the documented bool\n    is_word = re.search('[a-zA-Z0-9]', word) is not None\n    return is_word\n\n\ndef load_matlab_string(matlab_extracted_object):\n    \"\"\"\n    Converts a string loaded from h5py into a python string\n    :param matlab_extracted_object:     (h5py)  matlab string object\n    :return:\n        extracted_string    (str)   translated string\n    \"\"\"\n    extracted_string = u''.join(chr(c[0]) for c in matlab_extracted_object)\n    return extracted_string\n\n\ndef extract_word_level_data(data_container, word_objects, eeg_float_resolution = np.float16):\n    \"\"\"\n    Extracts word level data for a specific sentence\n    :param data_container:          (h5py)  Container of the whole data, h5py object\n    :param word_objects:            (h5py)  Container of all word data for a specific sentence\n    :param eeg_float_resolution:    (type)  Resolution with which to save EEG, used for data compression\n    :return:\n        word_level_data     (dict)  Contains all 
word level data, keyed by the position of each real word in the sentence\n        word_tokens_all             (list)  All real word tokens of the sentence\n        word_tokens_has_fixation    (list)  Real word tokens for which fixation data exist\n        word_tokens_with_mask       (list)  All real word tokens, with words lacking fixation data replaced by '[MASK]'\n    \"\"\"\n    available_objects = list(word_objects)\n    #print(available_objects)\n    #print(len(available_objects))\n    # print('available_objects:', available_objects)\n\n    if isinstance(available_objects[0], str):\n\n        contentData = word_objects['content']\n        #fixations_order_per_word = []\n        if \"rawEEG\" in available_objects:\n\n            rawData = word_objects['rawEEG']\n            etData = word_objects['rawET']\n\n            ffdData = word_objects['FFD']\n            gdData = word_objects['GD']\n            gptData = word_objects['GPT']\n            trtData = word_objects['TRT']\n\n            try:\n                sfdData = word_objects['SFD']\n            except KeyError:\n                print(\"no SFD!\")\n                # one placeholder per word, so that zip() below does not silently drop every word\n                sfdData = [None] * len(contentData)\n            nFixData = word_objects['nFixations']\n            fixPositions = word_objects[\"fixPositions\"]\n\n            Alpha_features_data = [word_objects[feature] for feature in Alpha_features]\n            Beta_features_data = [word_objects[feature] for feature in Beta_features]\n            Gamma_features_data = [word_objects[feature] for feature in Gamma_features]\n            Theta_features_data = [word_objects[feature] for feature in Theta_features]\n            #### \n            GD_EEG_features = [word_objects[feature] for feature in ['GD_t1','GD_t2','GD_a1','GD_a2','GD_b1','GD_b2','GD_g1','GD_g2']]\n            FFD_EEG_features = [word_objects[feature] for feature in ['FFD_t1','FFD_t2','FFD_a1','FFD_a2','FFD_b1','FFD_b2','FFD_g1','FFD_g2']]\n            TRT_EEG_features = [word_objects[feature] for feature in ['TRT_t1','TRT_t2','TRT_a1','TRT_a2','TRT_b1','TRT_b2','TRT_g1','TRT_g2']]\n            #### \n            assert len(contentData) == len(etData) == len(rawData), \"different amounts of different data!!\"\n\n  
          zipped_data = zip(rawData, etData, contentData, ffdData, gdData, gptData, trtData, sfdData, nFixData, fixPositions)\n            \n            word_level_data = {}\n            word_idx = 0\n\n            word_tokens_has_fixation = [] \n            word_tokens_with_mask = []\n            word_tokens_all = []\n            # raw_idx indexes every word row (real or not) and is used to look up the per-word band features;\n            # word_idx only counts real words and is used as the key of word_level_data\n            for raw_idx, (raw_eegs_obj, ets_obj, word_obj, ffd, gd, gpt, trt, sfd, nFix, fixPos) in enumerate(zipped_data):\n                word_string = load_matlab_string(data_container[word_obj[0]])\n                if is_real_word(word_string):\n                    data_dict = {}\n                    data_dict[\"RAW_EEG\"] = extract_all_fixations(data_container, raw_eegs_obj[0], eeg_float_resolution)\n                    data_dict[\"RAW_ET\"] = extract_all_fixations(data_container, ets_obj[0], np.float32)\n\n                    data_dict[\"FFD\"] = data_container[ffd[0]][()][0, 0] if len(data_container[ffd[0]][()].shape) == 2 else None\n                    data_dict[\"GD\"] = data_container[gd[0]][()][0, 0] if len(data_container[gd[0]][()].shape) == 2 else None\n                    data_dict[\"GPT\"] = data_container[gpt[0]][()][0, 0] if len(data_container[gpt[0]][()].shape) == 2 else None\n                    data_dict[\"TRT\"] = data_container[trt[0]][()][0, 0] if len(data_container[trt[0]][()].shape) == 2 else None\n                    data_dict[\"SFD\"] = data_container[sfd[0]][()][0, 0] if sfd is not None and len(data_container[sfd[0]][()].shape) == 2 else None\n                    data_dict[\"nFix\"] = data_container[nFix[0]][()][0, 0] if len(data_container[nFix[0]][()].shape) == 2 else None\n\n                    #fixations_order_per_word.append(np.array(data_container[fixPos[0]]))\n\n                    #print([data_container[obj[raw_idx][0]][()] for obj in Alpha_features_data])\n\n\n                    data_dict[\"ALPHA_EEG\"] = np.concatenate([data_container[obj[raw_idx][0]][()]\n                                                             if 
len(data_container[obj[raw_idx][0]][()].shape) == 2 else []\n                                                             for obj in Alpha_features_data], 0)\n\n                    data_dict[\"BETA_EEG\"] = np.concatenate([data_container[obj[raw_idx][0]][()]\n                                                            if len(data_container[obj[raw_idx][0]][()].shape) == 2 else []\n                                                            for obj in Beta_features_data], 0)\n\n                    data_dict[\"GAMMA_EEG\"] = np.concatenate([data_container[obj[raw_idx][0]][()]\n                                                             if len(data_container[obj[raw_idx][0]][()].shape) == 2 else []\n                                                             for obj in Gamma_features_data], 0)\n\n                    data_dict[\"THETA_EEG\"] = np.concatenate([data_container[obj[raw_idx][0]][()]\n                                                             if len(data_container[obj[raw_idx][0]][()].shape) == 2 else []\n                                                             for obj in Theta_features_data], 0)\n\n\n\n\n                    data_dict[\"word_idx\"] = word_idx\n                    data_dict[\"content\"] = word_string\n                    ####################################\n                    word_tokens_all.append(word_string)\n                    if data_dict[\"nFix\"] is not None:\n                        ####################################\n                        data_dict[\"GD_EEG\"] = [np.squeeze(data_container[obj[raw_idx][0]][()]) if len(data_container[obj[raw_idx][0]][()].shape) == 2 else [] for obj in GD_EEG_features]\n                        data_dict[\"FFD_EEG\"] = [np.squeeze(data_container[obj[raw_idx][0]][()]) if len(data_container[obj[raw_idx][0]][()].shape) == 2 else [] for obj in FFD_EEG_features]\n                        data_dict[\"TRT_EEG\"] = [np.squeeze(data_container[obj[raw_idx][0]][()]) if 
len(data_container[obj[raw_idx][0]][()].shape) == 2 else [] for obj in TRT_EEG_features]\n                        ####################################\n                        word_tokens_has_fixation.append(word_string)\n                        word_tokens_with_mask.append(word_string)\n                    else:\n                        word_tokens_with_mask.append('[MASK]')\n                        \n\n                    word_level_data[word_idx] = data_dict\n                    word_idx += 1\n                else:\n                    print(word_string + \" is not a real word.\")\n        else:\n            # If there are no word-level data it will be word embeddings alone\n            word_level_data = {}\n            word_idx = 0\n            word_tokens_has_fixation = [] \n            word_tokens_with_mask = []\n            word_tokens_all = []\n\n            for word_obj in contentData:\n                word_string = load_matlab_string(data_container[word_obj[0]])\n                if is_real_word(word_string):\n                    data_dict = {}\n                    data_dict[\"RAW_EEG\"] = []\n                    data_dict[\"ICA_EEG\"] = []\n                    data_dict[\"RAW_ET\"] = []\n                    data_dict[\"FFD\"] = None\n                    data_dict[\"GD\"] = None\n                    data_dict[\"GPT\"] = None\n                    data_dict[\"TRT\"] = None\n                    data_dict[\"SFD\"] = None\n                    data_dict[\"nFix\"] = None\n                    data_dict[\"ALPHA_EEG\"] = []\n                    data_dict[\"BETA_EEG\"] = []\n                    data_dict[\"GAMMA_EEG\"] = []\n                    data_dict[\"THETA_EEG\"] = []\n\n                    data_dict[\"word_idx\"] = word_idx\n                    data_dict[\"content\"] = word_string\n                    word_level_data[word_idx] = data_dict\n                    word_idx += 1\n                else:\n                    print(word_string + \" is not a real 
word.\")\n\n            sentence = \" \".join([load_matlab_string(data_container[word_obj[0]]) for word_obj in word_objects['content']])\n            #print(\"Only available objects for the sentence '{}' are {}.\".format(sentence, available_objects))\n            #word_level_data[\"word_reading_order\"] = extract_word_order_from_fixations(fixations_order_per_word)\n    else:\n        word_tokens_has_fixation = [] \n        word_tokens_with_mask = []\n        word_tokens_all = []\n        word_level_data = {}\n    return word_level_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask"
  },
  {
    "path": "util/get_SST_ternary_dataset.py",
    "content": "import os\nimport numpy as np\nimport torch\nimport pickle\nfrom torch.utils.data import Dataset, DataLoader\nimport json\nimport matplotlib.pyplot as plt\nfrom glob import glob\nfrom transformers import BartTokenizer\nfrom tqdm import tqdm\nfrom fuzzy_match import match\nfrom fuzzy_match import algorithims\n\n\ndef get_SST_dataset(SST_dir_path, ZuCo_used_sentences, ZUCO_SENTIMENT_LABELS):\n    \n    def get_sentiment_label_dict(SST_labels_file_path):\n        '''\n            return {phrase_id:sentiment_score(0-1)}\n        '''\n        ret_dict = {}\n        with open(SST_labels_file_path) as f:\n            for line in f:\n                if line.startswith('phrase'):\n                    continue\n                else:\n                    phrase_id = int(line.split('|')[0])\n                    label = float(line.split('|')[1].strip())\n                    assert phrase_id not in ret_dict\n                    ret_dict[phrase_id] = label\n        return ret_dict\n\n    def get_phrasestr_phrase_dict(SST_dictionary_file_path):\n        '''\n            return {phrase_str: phrase_id}\n        '''\n        ret_dict = {}\n        with open(SST_dictionary_file_path) as f:\n            for line in f:\n                phrase_str = line.split('|')[0]\n                phrase_id = int(line.split('|')[1].strip())\n                assert phrase_str not in ret_dict\n                ret_dict[phrase_str] = phrase_id\n        return ret_dict\n\n    def get_sentence_label_dict(SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path):\n        '''\n            return ({sentence_str: label(0-1)}, {sentence_str: ternary_label(0, 1 or 2)})\n        '''\n        phraseID_2_label = get_sentiment_label_dict(SST_labels_file_path)\n        phraseStr_2_phraseID = get_phrasestr_phrase_dict(SST_dictionary_file_path)\n\n        sentence_2_label_all = {}\n        sentence_2_label_ternary = {}\n        with open(SST_sentences_file_path) as f:\n            for line in f:\n                if 
line.startswith('sentence_index'):\n                    continue\n                else:\n                    parsed_line = line.split('\\t')\n                    assert len(parsed_line) == 2\n                    sentence = parsed_line[1].strip()\n                    # convert -LRB- to (, -RRB- to ):\n                    sentence = sentence.replace('-LRB-','(').replace('-RRB-',')').replace('Ã©','é')\n                    if sentence not in phraseStr_2_phraseID:\n                        # print(f'[ERROR]sentence-phrase match not found in dictionary, skipped: {sentence}')\n                        # print()\n                        continue\n                    sent_phrase_id = phraseStr_2_phraseID[sentence]\n                    label = phraseID_2_label[sent_phrase_id]\n                    \n                    # add to all dict\n                    if sentence not in sentence_2_label_all:\n                        sentence_2_label_all[sentence] = label\n\n                    # add to ternary dict\n                    if sentence not in sentence_2_label_ternary:\n                        if label<=0.2:\n                            label = 0\n                            sentence_2_label_ternary[sentence] = label\n                        elif (label > 0.4) and (label<=0.6): \n                            label = 1\n                            sentence_2_label_ternary[sentence] = label\n                        elif label>0.8:\n                            label = 2\n                            sentence_2_label_ternary[sentence] = label\n\n        return sentence_2_label_all, sentence_2_label_ternary\n\n\n    SST_sentences_file_path = os.path.join(SST_dir_path,'datasetSentences.txt')\n    if not os.path.isfile(SST_sentences_file_path):\n        print(f'NOT FOUND file: {SST_sentences_file_path}')\n    SST_labels_file_path = os.path.join(SST_dir_path,'sentiment_labels.txt')\n    if not os.path.isfile(SST_labels_file_path):\n        print(f'NOT FOUND file: 
{SST_labels_file_path}')\n    SST_dictionary_file_path = os.path.join(SST_dir_path,'dictionary.txt')\n    if not os.path.isfile(SST_dictionary_file_path):\n        print(f'NOT FOUND file: {SST_dictionary_file_path}')\n\n    sentence_2_label_all, sentence_2_label_ternary = get_sentence_label_dict(SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path)\n    print('original ternary dataset size:', len(sentence_2_label_ternary))\n\n    # the ZuCo_used_sentences argument is overridden here: the sentences (keys) of ZUCO_SENTIMENT_LABELS are used instead\n    ZuCo_used_sentences = list(ZUCO_SENTIMENT_LABELS)\n\n    filtered_ternary_dataset = {}\n    filtered_pairs = []\n    for key,value in sentence_2_label_ternary.items():\n        add_instance = True\n        for used_sent in ZuCo_used_sentences:\n            if algorithims.trigram(used_sent, key) > 0.7:\n                # print(f'Filter match: \\n\\t{used_sent}\\n\\t{key}')\n                # print('###########################')\n                filtered_pairs.append((used_sent, key))\n                ZuCo_used_sentences.remove(used_sent)\n                add_instance = False\n                break\n        if add_instance:\n            filtered_ternary_dataset[key] = value\n    \n    print('filtered instance number:', len(filtered_pairs))\n    print('filtered ternary dataset size:', len(filtered_ternary_dataset))\n    print('unmatched remaining sentences:', ZuCo_used_sentences)\n    print('unmatched remaining sentences length:', len(ZuCo_used_sentences))\n    with open('temp.txt','w') as temp:\n        for matched_pair in filtered_pairs:\n            temp.write('#######\\n')\n            temp.write('\\t'+matched_pair[0]+'\\n')\n            temp.write('\\t'+matched_pair[1]+'\\n')\n            temp.write('\\n')\n\n    with open('./dataset/stanfordsentiment/ternary_dataset.json', 'w') as out:\n        json.dump(filtered_ternary_dataset,out, indent = 4)\n    print('write json to ./dataset/stanfordsentiment/ternary_dataset.json')\n\nif __name__ == '__main__':\n    print('##############################')\n    print('start 
generating stanfordSentimentTreebank ternary sentiment dataset...')\n    SST_dir_path = './dataset/stanfordsentiment/stanfordSentimentTreebank'\n    ZuCo_task1_csv_path = './dataset/ZuCo/task_materials/sentiment_labels_task1.csv'\n    ZUCO_SENTIMENT_LABELS = json.load(open('./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json'))\n\n    get_SST_dataset(SST_dir_path, ZuCo_task1_csv_path, ZUCO_SENTIMENT_LABELS)"
  },
  {
    "path": "util/get_sentiment_labels.py",
    "content": "import os\nfrom glob import glob\nimport json\n\nprint('##############################')\nprint('start generating ZuCo task1-SR sentiment labels...')\n\n\nsentiment_labels_task1_csv_path = './dataset/ZuCo/task_materials/sentiment_labels_task1.csv'\n\nsentiment_labels = {}\nwith open(sentiment_labels_task1_csv_path, 'r') as f:\n    for line in f:\n        if line.startswith('sentence_id') or line.startswith('#'):\n            continue\n        else:\n            parsed_line = line.split(';')\n            # handle edge case:\n            if '\\\";' in line:\n                sent_text = line.split('\\\";')[0].split('\\\"')[1]\n            else:\n                sent_text = parsed_line[1]\n            label = int(parsed_line[-1].strip())\n            sentiment_labels[sent_text] = label\n\noutput_dir = f'./dataset/ZuCo/task1-SR/sentiment_labels'\nif not os.path.exists(output_dir):\n    os.makedirs(output_dir)\n\nwith open(os.path.join(output_dir, 'sentiment_labels.json'), 'w') as out:\n    json.dump(sentiment_labels,out,indent = 4)\n    print('write to ./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json')\n\n"
  }
]