Repository: NeuSpeech/EEG-To-Text
Branch: main
Commit: 01582402809a
Files: 28
Total size: 173.2 KB
Directory structure:
gitextract_ginxdd5t/
├── .gitignore
├── README.md
├── config.py
├── data.py
├── environment.yml
├── eval_decoding.py
├── eval_sentiment.py
├── model_decoding.py
├── model_sentiment.py
├── scripts/
│ ├── eval_decoding_1.sh
│ ├── eval_decoding_2.sh
│ ├── eval_decoding_3.sh
│ ├── eval_decoding_4.sh
│ ├── eval_sentiment_zeroshot_pipeline.sh
│ ├── prepare_dataset.sh
│ ├── train_decoding.sh
│ ├── train_decoding_1.sh
│ ├── train_eeg_sentiment_baseline.sh
│ ├── train_eval_zeroshot_pipeline.sh
│ └── train_text_sentiment_classifier.sh
├── train_decoding.py
├── train_sentiment_baseline.py
├── train_sentiment_textbased.py
└── util/
├── construct_dataset_mat_to_pickle_v1.py
├── construct_dataset_mat_to_pickle_v2.py
├── data_loading_helpers_modified.py
├── get_SST_ternary_dataset.py
└── get_sentiment_labels.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pt
*.pickle
*.mat
*.json
*.txt
# Byte-compiled / optimized / DLL files
__pycache__/
csv_results/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
================================================
FILE: README.md
================================================
The **main branch** contains the final code for our "Are EEG-to-Text Models Working?" paper.
Accepted at the [IJCAI 2024 workshop](https://github.com/user-attachments/files/16624318/IJCAI_hyejeongjo_poster_Final.pdf).
If you have any questions, you can write them in the Issues section or email Hyejeong Jo at girlsending0@khu.ac.kr.
Check out our new paper, which gives a detailed comparison of different models on this task: [https://arxiv.org/abs/2405.06459](https://arxiv.org/abs/2405.06459)
*[Figure: overview]*

*[Figure: performance]*

# Correction on [(AAAI 2022) Open Vocabulary EEG-To-Text Decoding and Zero-shot sentiment classification](https://arxiv.org/abs/2112.02690)
# Results and code are updated on the **master** branch
**First of all, we are not pointing fingers at anyone. This correction is made without any intent to offend; it is a kind reminder to be careful with the string generation process.
We respect Mr. Wang very much and appreciate his great contribution to this area.**
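After scrutinizing [the original code shared by Wang Zhenhailong](https://github.com/MikeWangWZHL/EEG-To-Text), we discovered that the evaluation method has an unintentional but very serious mistake in generating predicted strings: it implicitly uses teacher forcing. Because `target_ids_batch` is passed into the forward call, the decoder is fed the ground-truth tokens (shifted right by one position), so taking the argmax at each position yields a next-token guess given the correct prefix rather than a sentence the model actually generated.
The code in question is: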
```python
seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch)
logits = seq2seqLMoutput.logits # bs*seq_len*voc_sz
probs = logits[0].softmax(dim = 1)
values, predictions = probs.topk(1)
predictions = torch.squeeze(predictions)
predicted_string = tokenizer.decode(predictions)
```
This results in [predictions like the ones below](https://github.com/MikeWangWZHL/EEG-To-Text/blob/main/results/task1_task2_taskNRv2-BrainTranslator_skipstep1-all_generation_results-7_22.txt#L61):
```
target string: It isn't that Stealing Harvard is a horrible movie -- if only it were that grand a failure!
predicted string: was't a the. is was a bad place, it it it were a.. movie.
################################################
target string: It just doesn't have much else... especially in a moral sense.
predicted string: was so't work the to to and not the country sense.
################################################
target string: Those unfamiliar with Mormon traditions may find The Singles Ward occasionally bewildering.
predicted string: who with the history may be themselves Mormoning''s amusingering.
################################################
target string: Viewed as a comedy, a romance, a fairy tale, or a drama, there's nothing remotely triumphant about this motion picture.
predicted string: the from a whole, it film, and comedy tale, and a tragic, it is nothing quite romantic about it. picture.
################################################
target string: But the talented cast alone will keep you watching, as will the fight scenes.
predicted string: the most and of cannot not the entertained. and they the music against.
################################################
target string: It's solid and affecting and exactly as thought-provoking as it should be.
predicted string: was a, it, it what it.provoking as it is be.
################################################
target string: Thanks largely to Williams, all the interesting developments are processed in 60 minutes -- the rest is just an overexposed waste of film.
predicted string: to to the, the of films and in in in a minutes. and longest is a a afteragerposure, of time time
################################################
target string: Cantet perfectly captures the hotel lobbies, two-lane highways, and roadside cafes that permeate Vincent's days
predicted string: urtor was describes the spirit'sies and the ofstory streets, and the parking of areate the's life.</s>'sgggggggg,,,,,,,,,,,,,,</s>,,,,,
################################################
target string: An important movie, a reminder of the power of film to move us and to make us examine our values.
predicted string: nie part in " classic of the importance of the, shape people, our make us think our lives,
################################################
target string: Too much of this well-acted but dangerously slow thriller feels like a preamble to a bigger, more complicated story, one that never materializes.
predicted string: bad of a is-known film not over- is like a film-ble to a much, more dramatic story. which that is endsizes.
```
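To make the problem concrete, here is a toy, plain-Python illustration; it is not code from this repo. `BIGRAM_SCORES` is a made-up score table standing in for a decoder that ignores its EEG input entirely and looks only at the previous token. `teacher_forced_argmax` mimics the quoted evaluation code, where the decoder is fed the ground-truth sequence shifted right and the argmax is taken at every position. All names here are hypothetical.

```python
from typing import Dict, List

# Hypothetical next-token scores for a toy decoder that never sees any EEG input:
# it behaves like a bigram language model conditioned only on the previous token.
BIGRAM_SCORES: Dict[str, Dict[str, float]] = {
    "<s>": {"the": 0.6, "a": 0.4},
    "the": {"movie": 0.7, "film": 0.3},
    "a": {"movie": 0.2, "film": 0.8},
    "movie": {"is": 0.9, "</s>": 0.1},
    "film": {"is": 0.6, "</s>": 0.4},
    "is": {"good": 0.6, "bad": 0.4},
    "good": {"</s>": 1.0},
    "bad": {"</s>": 1.0},
}


def argmax_next(prev: str) -> str:
    """Return the highest-scoring next token given the previous token."""
    scores = BIGRAM_SCORES[prev]
    return max(scores, key=scores.get)


def teacher_forced_argmax(target: List[str]) -> List[str]:
    """Mimic argmax over logits from a forward pass that received the target.

    The decoder input is the gold sequence shifted right by one position,
    so every position is predicted from the *ground-truth* previous token.
    """
    decoder_inputs = ["<s>"] + target[:-1]
    return [argmax_next(prev) for prev in decoder_inputs]


def positional_matches(pred: List[str], target: List[str]) -> int:
    """Count positions where the prediction equals the target token."""
    return sum(p == t for p, t in zip(pred, target))


if __name__ == "__main__":
    # Same input-independent model, two different targets.
    for target in (["a", "film", "is", "bad", "</s>"],
                   ["the", "movie", "is", "good", "</s>"]):
        pred = teacher_forced_argmax(target)
        print(target, "->", pred, positional_matches(pred, target))
```

Because the toy model never sees any EEG, the fact that its "predictions" change with the target (`film` for the first target, `movie` for the second, at the same position) can only come from the gold prefix leaking in. Free-running decoding from `<s>` would always start with `the movie`, whatever the target is.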
In addition, we noticed that several works use this code as their code base, which raises concerns about the results they report. We are not condemning these researchers; we simply want to notify them, and perhaps we can work together to resolve this problem.
- [BELT: Bootstrapping Electroencephalography-to-Language Decoding and Zero-Shot Sentiment Classification by Natural Language Supervision](https://arxiv.org/pdf/2309.12056)
- [Aligning Semantic in Brain and Language: A Curriculum Contrastive Method for Electroencephalography-to-Text Generation](https://ieeexplore.ieee.org/iel7/7333/4359219/10248031.pdf)
- [UniCoRN: Unified Cognitive Signal ReconstructioN bridging cognitive signals and human language](https://arxiv.org/pdf/2307.05355)
- [Semantic-aware Contrastive Learning for Electroencephalography-to-Text Generation with Curriculum Learning](https://arxiv.org/pdf/2301.09237)
- [DeWave: Discrete EEG Waves Encoding for Brain Dynamics to Text Translation](https://arxiv.org/pdf/2309.14030)
We have written a corrected version that uses `model.generate` to evaluate the model, and the results are not as good.
Specifically, we modified `model_decoding.py` and `eval_decoding.py` to add a `generate` method to the model, which was originally a plain `nn.Module` class, and used `model.generate` to predict strings.
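By contrast, `generate`-style greedy decoding (one beam, no sampling) feeds the model's own previous prediction back in at each step and never looks at the target. The sketch below is a minimal, framework-free version under stated assumptions: `greedy_generate`, `StepFn`, and `toy_step` are hypothetical names, and the step function stands in for a decoder forward pass conditioned on the EEG encoder output. It is not the implementation in `model_decoding.py`.

```python
from typing import Callable, List, Sequence

# Hypothetical step function: given the tokens generated so far, return the
# next-token logits. In the real model this would be a decoder forward pass
# conditioned on the EEG encoder output.
StepFn = Callable[[Sequence[int]], Sequence[float]]


def greedy_generate(step_fn: StepFn, bos_id: int, eos_id: int, max_length: int) -> List[int]:
    """Greedy autoregressive decoding that feeds back the model's own predictions.

    The target sequence is never consulted, unlike argmax over teacher-forced logits.
    """
    generated = [bos_id]
    while len(generated) < max_length:
        logits = step_fn(generated)
        # pick the index of the largest logit (greedy choice)
        next_id = max(range(len(logits)), key=lambda i: logits[i])
        generated.append(next_id)
        if next_id == eos_id:
            break
    return generated


def toy_step(prefix: Sequence[int]) -> Sequence[float]:
    """Hypothetical 4-token vocabulary: 0=BOS, 1=EOS, 2 and 3 are words.

    Logits depend only on the last generated token.
    """
    table = {
        0: [0.0, 0.1, 0.9, 0.0],  # after BOS, prefer token 2
        2: [0.0, 0.2, 0.1, 0.7],  # after token 2, prefer token 3
        3: [0.0, 0.8, 0.1, 0.1],  # after token 3, prefer EOS
    }
    return table[prefix[-1]]


if __name__ == "__main__":
    print(greedy_generate(toy_step, bos_id=0, eos_id=1, max_length=10))
```

Since each step depends only on previously generated tokens, an early mistake can propagate to later steps; the teacher-forced evaluation hides this because every position is re-anchored on the gold prefix.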
**We invite everyone to scrutinize and run this corrected code. We will then report the final performance of this model in this repo and write it up as a technical paper.**
# We really appreciate the great contribution made by Mr. Wang; however, we should prevent others from perpetuating this misunderstanding.
This work was supported by the Culture, Sports and Tourism R&D Program through the Korea Creative Content Agency grant funded by the Ministry of Culture, Sports and Tourism (RS-2023-00226263), the Institute for Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. RS-2024-00509257, Global AI Frontier Lab), the Information Technology Research Center (ITRC) support program (IITP-2024-RS-2024-00438239) supervised by the IITP, and the IITP grant funded by the Korea government (MSIT) (No. RS-2022-00155911, Artificial Intelligence Convergence Innovation Human Resources Development, Kyung Hee University).
================================================
FILE: config.py
================================================
import argparse
def str2bool(v):
    # convert common textual spellings of booleans (e.g. 'yes', 'false', '1') to bool
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
def get_config(case):
    # note: argparse ignores `default` for any argument declared with required=True
    if case == 'train_decoding':
        # args config for training EEG-To-Text decoder
        parser = argparse.ArgumentParser(description='Specify config args for training EEG-To-Text decoder')
        parser.add_argument('-m', '--model_name', help='choose from {BrainTranslator, BrainTranslatorNaive}', default = "BrainTranslator" ,required=True)
        parser.add_argument('-t', '--task_name', help='choose from {task1,task1_task2, task1_task2_task3,task1_task2_taskNRv2}', default = "task1", required=True)
        parser.add_argument('-1step', '--one_step', dest='skip_step_one', action='store_true')
        parser.add_argument('-2step', '--two_step', dest='skip_step_one', action='store_false')
        parser.add_argument('-pre', '--pretrained', dest='use_random_init', action='store_false')
        parser.add_argument('-rand', '--rand_init', dest='use_random_init', action='store_true')
        parser.add_argument('-load1', '--load_step1_checkpoint', dest='load_step1_checkpoint', action='store_true')
        parser.add_argument('-no-load1', '--not_load_step1_checkpoint', dest='load_step1_checkpoint', action='store_false')
        parser.add_argument('-ne1', '--num_epoch_step1', type = int, help='num_epoch_step1', default = 20, required=True)
        parser.add_argument('-ne2', '--num_epoch_step2', type = int, help='num_epoch_step2', default = 30, required=True)
        parser.add_argument('-lr1', '--learning_rate_step1', type = float, help='learning_rate_step1', default = 0.00005, required=True)
        parser.add_argument('-lr2', '--learning_rate_step2', type = float, help='learning_rate_step2', default = 0.0000005, required=True)
        parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True)
        parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/decoding', required=True)
        parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False)
        parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False)
        parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify frequency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False)
        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')
        parser.add_argument('-train_input', '--train_input', help='input type, e.g. "noise" to use random noise instead of EEG' ,required=True)
        args = vars(parser.parse_args())
    elif case == 'train_sentiment_baseline':
        # args config for training EEG-based sentiment baselines
        parser = argparse.ArgumentParser(description='Specify config args for training EEG-based sentiment baselines')
        parser.add_argument('-m', '--model_name', help='choose from {BaselineMLP, BaselineLSTM, NaiveFinetuneBert}', default = "NaiveFinetuneBert" ,required=True)
        parser.add_argument('-ne', '--num_epoch', type = int, help='num_epoch', default = 30, required=True)
        parser.add_argument('-lr', '--learning_rate', type = float, help='learning_rate', default = 0.00001, required=True)
        parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True)
        parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/eeg_sentiment', required=True)
        parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False)
        parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False)
        parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify frequency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False)
        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')
        args = vars(parser.parse_args())
    elif case == 'train_sentiment_textbased':
        # args config for training text-based sentiment classification models
        parser = argparse.ArgumentParser(description='Specify config args for training text-based sentiment classifiers')
        parser.add_argument('-d', '--dataset_name', help='zero-shot setting: using external dataset from stanford sentiment treebank, pass in SST; to use ZuCo\'s own text-sentiment pairs, pass in ZuCo', default = "SST" ,required=True)
        parser.add_argument('-m', '--model_name', help='choose from {pretrain_Bert, pretrain_RoBerta, pretrain_Bart}', default = "pretrain_Bart" ,required=True)
        parser.add_argument('-ne', '--num_epoch', type = int, help='num_epoch', default = 20, required=True)
        parser.add_argument('-lr', '--learning_rate', type = float, help='learning_rate', default = 0.0001, required=True)
        parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True)
        parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/text_sentiment_classifier', required=True)
        parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False)
        parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False)
        parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify frequency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False)
        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')
        args = vars(parser.parse_args())
    elif case == 'eval_decoding':
        # args config for evaluating EEG-To-Text decoder
        parser = argparse.ArgumentParser(description='Specify config args for evaluating EEG-To-Text decoder')
        parser.add_argument('-checkpoint', '--checkpoint_path', help='specify model checkpoint' ,required=True)
        parser.add_argument('-conf', '--config_path', help='specify training config json' ,required=True)
        parser.add_argument('-test_input', '--test_input', help='input type, e.g. "noise" to use random noise instead of EEG' ,required=True)
        parser.add_argument('-train_input', '--train_input', help='input type, e.g. "noise" to use random noise instead of EEG' ,required=True)
        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')
        args = vars(parser.parse_args())
    elif case == 'eval_sentiment':
        # args config for sentiment classification models
        parser = argparse.ArgumentParser(description='Specify config args for evaluating EEG-based sentiment classification, including Zero-shot pipeline')
        # choose model_name = 'ZeroShotSentimentDiscovery' to evaluate Zero-shot pipeline
        parser.add_argument('-m', '--model_name', help='choose from {BaselineMLP, BaselineLSTM, NaiveFinetuneBert, FinetunedBertOnText, FinetunedRoBertaOnText, FinetunedBartOnText, ZeroShotSentimentDiscovery}', default = "ZeroShotSentimentDiscovery" ,required=True)
        parser.add_argument('-checkpoint', '--checkpoint_path', help='specify model checkpoint' ,required=False) # required if NOT evaluating Zero-shot pipeline
        parser.add_argument('-conf', '--config_path', help='specify model config json' ,required=False) # required if NOT evaluating Zero-shot pipeline
        parser.add_argument('-checkpoint_DEC', '--decoder_checkpoint_path', help='specify decoder checkpoint for Zero-shot pipeline ', required=False) # required if evaluating Zero-shot pipeline
        parser.add_argument('-checkpoint_CLS', '--classifier_checkpoint_path', help='specify classifier checkpoint for Zero-shot pipeline ', required=False) # required if evaluating Zero-shot pipeline
        parser.add_argument('-conf_DEC', '--decoder_config_path', help='specify decoder config json' ,required=False) # required if evaluating Zero-shot pipeline
        parser.add_argument('-conf_CLS', '--classifier_config_path', help='specify classifier config json' ,required=False) # required if evaluating Zero-shot pipeline
        parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False)
        parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False)
        parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify frequency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False)
        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')
        args = vars(parser.parse_args())
    return args
================================================
FILE: data.py
================================================
import os
import numpy as np
import torch
import pickle
from torch.utils.data import Dataset, DataLoader
import json
import matplotlib.pyplot as plt
from glob import glob
from transformers import BartTokenizer, BertTokenizer
from tqdm import tqdm
from fuzzy_match import match
from fuzzy_match import algorithims
from transformers import T5Tokenizer
# macro
#ZUCO_SENTIMENT_LABELS = json.load(open('./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json'))
#SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json'))
def normalize_1d(input_tensor):
    # normalize a 1d tensor (z-score: zero mean, unit std)
    mean = torch.mean(input_tensor)
    std = torch.std(input_tensor)
    input_tensor = (input_tensor - mean)/std
    return input_tensor
def get_input_sample(sent_obj, tokenizer, eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], max_len = 56, add_CLS_token = False, test_input="noise"):
    def get_word_embedding_eeg_tensor(word_obj, eeg_type, bands):
        frequency_features = []
        for band in bands:
            frequency_features.append(word_obj['word_level_EEG'][eeg_type][eeg_type+band])
        word_eeg_embedding = np.concatenate(frequency_features)
        if len(word_eeg_embedding) != 105*len(bands):
            print(f'expect word eeg embedding dim to be {105*len(bands)}, but got {len(word_eeg_embedding)}, return None')
            return None
        # assert len(word_eeg_embedding) == 105*len(bands)
        return_tensor = torch.from_numpy(word_eeg_embedding)
        return normalize_1d(return_tensor)
    def get_sent_eeg(sent_obj, bands):
        sent_eeg_features = []
        for band in bands:
            key = 'mean'+band
            sent_eeg_features.append(sent_obj['sentence_level_EEG'][key])
        sent_eeg_embedding = np.concatenate(sent_eeg_features)
        assert len(sent_eeg_embedding) == 105*len(bands)
        return_tensor = torch.from_numpy(sent_eeg_embedding)
        return normalize_1d(return_tensor)
    if sent_obj is None:
        # print(f' - skip bad sentence')
        return None
    input_sample = {}
    # get target label
    target_string = sent_obj['content']
    target_tokenized = tokenizer(target_string, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)
    input_sample['target_ids'] = target_tokenized['input_ids'][0]
    # get sentence level EEG features
    sent_level_eeg_tensor = get_sent_eeg(sent_obj, bands)
    # try:
    #     sent_level_eeg_tensor = torch.from_numpy(sent_obj['sentence_level_EEG']) # This gives a dictionary
    # except:
    #     return None
    if torch.isnan(sent_level_eeg_tensor).any():
        # print('[NaN sent level eeg]: ', target_string)
        return None
    # if sent_level_eeg_tensor.shape[1] < 30:
    #     return None
    input_sample['sent_level_EEG'] = sent_level_eeg_tensor
    #input_sample['sent_level_EEG'] = torch.randn(sent_level_eeg_tensor.size()) # random input code
    #print("NOISE:", input_sample['sent_level_EEG'])
    # get sentiment label
    # handle some weird cases
    if 'emp11111ty' in target_string:
        target_string = target_string.replace('emp11111ty','empty')
    if 'film.1' in target_string:
        target_string = target_string.replace('film.1','film.')
    #if target_string in ZUCO_SENTIMENT_LABELS:
    #    input_sample['sentiment_label'] = torch.tensor(ZUCO_SENTIMENT_LABELS[target_string]+1) # 0:Negative, 1:Neutral, 2:Positive
    #else:
    #    input_sample['sentiment_label'] = torch.tensor(-100) # dummy value
    input_sample['sentiment_label'] = torch.tensor(-100) # dummy value
    # get input embeddings
    word_embeddings = []
    """add CLS token embedding at the front"""
    if add_CLS_token:
        word_embeddings.append(torch.ones(105*len(bands)))
    for word in sent_obj['word']:
        # add each word's EEG embedding as Tensors
        word_level_eeg_tensor = get_word_embedding_eeg_tensor(word, eeg_type, bands = bands)
        # check none, for v2 dataset
        if word_level_eeg_tensor is None:
            return None
        # check nan:
        if torch.isnan(word_level_eeg_tensor).any():
            # print()
            # print('[NaN ERROR] problem sent:',sent_obj['content'])
            # print('[NaN ERROR] problem word:',word['content'])
            # print('[NaN ERROR] problem word feature:',word_level_eeg_tensor)
            # print()
            return None
        word_embeddings.append(word_level_eeg_tensor)
    # pad to max_len
    while len(word_embeddings) < max_len:
        word_embeddings.append(torch.zeros(105*len(bands)))
    if test_input=='noise':
        # replace the EEG features with standard Gaussian noise of the same shape
        rand_eeg= torch.randn(torch.stack(word_embeddings).size())
        input_sample['input_embeddings'] = rand_eeg # max_len * (105*num_bands)
        # print("rand_eeg:", rand_eeg)
        # print("input_embeddings:", input_sample['input_embeddings'].shape)
    else:
        input_sample['input_embeddings'] = torch.stack(word_embeddings) # max_len * (105*num_bands)
        print("EEG", input_sample['input_embeddings'])
    # mask out padding tokens
    input_sample['input_attn_mask'] = torch.zeros(max_len) # 0 is masked out
    if add_CLS_token:
        input_sample['input_attn_mask'][:len(sent_obj['word'])+1] = torch.ones(len(sent_obj['word'])+1) # 1 is not masked
    else:
        input_sample['input_attn_mask'][:len(sent_obj['word'])] = torch.ones(len(sent_obj['word'])) # 1 is not masked
    # mask out padding tokens reverted: handle different use case: this is for pytorch transformers
    input_sample['input_attn_mask_invert'] = torch.ones(max_len) # 1 is masked out
    if add_CLS_token:
        input_sample['input_attn_mask_invert'][:len(sent_obj['word'])+1] = torch.zeros(len(sent_obj['word'])+1) # 0 is not masked
    else:
        input_sample['input_attn_mask_invert'][:len(sent_obj['word'])] = torch.zeros(len(sent_obj['word'])) # 0 is not masked
    # mask out target padding for computing cross entropy loss
    input_sample['target_mask'] = target_tokenized['attention_mask'][0]
    input_sample['seq_len'] = len(sent_obj['word'])
    # clean 0 length data
    if input_sample['seq_len'] == 0:
        print('discard length zero instance: ', target_string)
        return None
    return input_sample
class ZuCo_dataset(Dataset):
    def __init__(self, input_dataset_dicts, phase, tokenizer, subject = 'ALL', eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], setting = 'unique_sent', is_add_CLS_token = False, test_input='noise'):
        self.inputs = []
        self.tokenizer = tokenizer
        if not isinstance(input_dataset_dicts,list):
            input_dataset_dicts = [input_dataset_dicts]
        print(f'[INFO]loading {len(input_dataset_dicts)} task datasets')
        for input_dataset_dict in input_dataset_dicts:
            if subject == 'ALL':
                subjects = list(input_dataset_dict.keys())
                print('[INFO]using subjects: ', subjects)
            else:
                subjects = [subject]
            total_num_sentence = len(input_dataset_dict[subjects[0]])
            train_divider = int(0.8*total_num_sentence)
            dev_divider = train_divider + int(0.1*total_num_sentence)
            print(f'train divider = {train_divider}')
            print(f'dev divider = {dev_divider}')
            if setting == 'unique_sent':
                # take first 80% as trainset, 10% as dev and 10% as test
                if phase == 'train':
                    print('[INFO]initializing a train set...')
                    for key in subjects:
                        for i in range(train_divider):
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                elif phase == 'dev':
                    print('[INFO]initializing a dev set...')
                    for key in subjects:
                        for i in range(train_divider,dev_divider):
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                elif phase == 'test':
                    print('[INFO]initializing a test set...')
                    for key in subjects:
                        for i in range(dev_divider,total_num_sentence):
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
            elif setting == 'unique_subj':
                print('WARNING!!! only implemented for SR v1 dataset ')
                # subject ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW'] for train
                # subject ['ZMG'] for dev
                # subject ['ZPH'] for test
                # pass test_input through as well; otherwise get_input_sample falls back to its 'noise' default
                if phase == 'train':
                    print(f'[INFO]initializing a train set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH','ZKW']:
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                if phase == 'dev':
                    print(f'[INFO]initializing a dev set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZMG']:
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                if phase == 'test':
                    print(f'[INFO]initializing a test set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZPH']:
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, test_input=test_input)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
            print('++ adding task to dataset, now we have:', len(self.inputs))
        print('[INFO]input tensor size:', self.inputs[0]['input_embeddings'].size())
        print()
    def __len__(self):
        return len(self.inputs)
    def __getitem__(self, idx):
        input_sample = self.inputs[idx]
        return (
            input_sample['input_embeddings'],
            input_sample['seq_len'],
            input_sample['input_attn_mask'],
            input_sample['input_attn_mask_invert'],
            input_sample['target_ids'],
            input_sample['target_mask'],
            input_sample['sentiment_label'],
            #input_sample['sent_level_EEG']
        )
        # keys: input_embeddings, input_attn_mask, input_attn_mask_invert, target_ids, target_mask,
"""for train classifier on stanford sentiment treebank text-sentiment pairs"""
class SST_tenary_dataset(Dataset):
    def __init__(self, ternary_labels_dict, tokenizer, max_len = 56, balance_class = True):
        self.inputs = []
        pos_samples = []
        neg_samples = []
        neu_samples = []
        for key,value in ternary_labels_dict.items():
            tokenized_inputs = tokenizer(key, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)
            input_ids = tokenized_inputs['input_ids'][0]
            attn_masks = tokenized_inputs['attention_mask'][0]
            label = torch.tensor(value)
            # count:
            if value == 0:
                neg_samples.append((input_ids,attn_masks,label))
            elif value == 1:
                neu_samples.append((input_ids,attn_masks,label))
            elif value == 2:
                pos_samples.append((input_ids,attn_masks,label))
        print(f'Original distribution:\n\tVery positive: {len(pos_samples)}\n\tNeutral: {len(neu_samples)}\n\tVery negative: {len(neg_samples)}')
        if balance_class:
            # keep the same number of samples per class, interleaved pos/neg/neu
            print(f'balance class to {min([len(pos_samples),len(neg_samples),len(neu_samples)])} each...')
            for i in range(min([len(pos_samples),len(neg_samples),len(neu_samples)])):
                self.inputs.append(pos_samples[i])
                self.inputs.append(neg_samples[i])
                self.inputs.append(neu_samples[i])
        else:
            self.inputs = pos_samples + neg_samples + neu_samples
    def __len__(self):
        return len(self.inputs)
    def __getitem__(self, idx):
        input_sample = self.inputs[idx]
        return input_sample
        # keys: input_embeddings, input_attn_mask, input_attn_mask_invert, target_ids, target_mask,
'''sanity test'''
if __name__ == '__main__':
check_dataset = 'stanford_sentiment'
if check_dataset == 'ZuCo':
whole_dataset_dicts = []
dataset_path_task1 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task1-SR/pickle/task1-SR-dataset-with-tokens_6-25.pickle'
with open(dataset_path_task1, 'rb') as handle:
whole_dataset_dicts.append(pickle.load(handle))
dataset_path_task2 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task2-NR/pickle/task2-NR-dataset-with-tokens_7-10.pickle'
with open(dataset_path_task2, 'rb') as handle:
whole_dataset_dicts.append(pickle.load(handle))
# dataset_path_task3 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task3-TSR/pickle/task3-TSR-dataset-with-tokens_7-10.pickle'
# with open(dataset_path_task3, 'rb') as handle:
# whole_dataset_dicts.append(pickle.load(handle))
dataset_path_task2_v2 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset-with-tokens_7-15.pickle'
with open(dataset_path_task2_v2, 'rb') as handle:
whole_dataset_dicts.append(pickle.load(handle))
print()
for key in whole_dataset_dicts[0]:
print(f'task2_v2, sentence num in {key}:',len(whole_dataset_dicts[0][key]))
print()
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
dataset_setting = 'unique_sent'
subject_choice = 'ALL'
print(f'![Debug]using {subject_choice}')
eeg_type_choice = 'GD'
print(f'[INFO]eeg type {eeg_type_choice}')
bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
print(f'[INFO]using bands {bands_choice}')
train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
print('trainset size:',len(train_set))
print('devset size:',len(dev_set))
print('testset size:',len(test_set))
elif check_dataset == 'stanford_sentiment':
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
SST_dataset = SST_tenary_dataset(SST_SENTIMENT_LABELS, tokenizer)
print('SST dataset size:',len(SST_dataset))
print(SST_dataset[0])
print(SST_dataset[1])
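With balance_class=True, SST_tenary_dataset keeps only the first min(#positive, #negative, #neutral) samples of each class, in the order they were read, and interleaves them as positive, negative, neutral. Samples past that count are dropped rather than randomly subsampled. A minimal standalone sketch of that logic (balance_by_interleaving is a hypothetical name, not part of the repo):

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def balance_by_interleaving(pos: Sequence[T], neg: Sequence[T], neu: Sequence[T]) -> List[T]:
    """Truncate every class to the smallest class size and interleave
    the samples in (positive, negative, neutral) order."""
    n = min(len(pos), len(neg), len(neu))
    out: List[T] = []
    for i in range(n):
        out.extend((pos[i], neg[i], neu[i]))
    return out


samples = balance_by_interleaving(["p1", "p2", "p3"], ["n1", "n2"], ["u1", "u2", "u3", "u4"])
print(samples)  # ['p1', 'n1', 'u1', 'p2', 'n2', 'u2']
```

Shuffling each class list before truncation would be one way to avoid always discarding the tail of the file; whether that matters depends on how the label file is ordered.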
================================================
FILE: environment.yml
================================================
name: EEGToText
channels:
- pytorch
- anaconda
- conda-forge
- huggingface
dependencies:
- pytorch=1.9.0
- torchaudio=0.9.0
- cudatoolkit=11.1
- scipy=1.6.2
- h5py=3.4.0
- tqdm=4.62.0
- matplotlib=3.3.2
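- transformers=4.6.1 # note: model_decoding.BrainTranslator.generate passes generate() arguments (generation_config, streamer, assistant_model, negative_prompt_ids) that 4.6.1 does not have; generation needs a much newer release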
- nltk=3.5
- pip=21.0.1
- pip:
- fuzzy-match==0.0.1
- rouge==1.0.0
- evaluate
- sacrebleu
- jiwer
- scikit-learn
================================================
FILE: eval_decoding.py
================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm
import torch.nn.functional as F
from transformers import BertLMHeadModel, BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification, PegasusForConditionalGeneration, PegasusTokenizer, T5Tokenizer, T5ForConditionalGeneration, BertGenerationDecoder
from data import ZuCo_dataset
from model_decoding import BrainTranslator, BrainTranslatorNaive, T5Translator
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
from rouge import Rouge
from config import get_config
import evaluate
from evaluate import load
metric = evaluate.load("sacrebleu")
cer_metric = load("cer")
wer_metric = load("wer")
def remove_text_after_token(text, token='</s>'):
# find the given token and drop everything after it
token_index = text.find(token)
if token_index != -1: # the token was found
return text[:token_index] # return the text before the token
return text # no token found: return the original text
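Both decoding paths in eval_model stop at the first end-of-sequence marker: once at string level (remove_text_after_token) and once at token-id level (the loop that breaks on tokenizer.eos_token_id). A compact standalone sketch of both cuts, using itertools.takewhile for the id-level loop; the token ids below are illustrative only:

```python
from itertools import takewhile
from typing import List


def remove_text_after_token(text: str, token: str = "</s>") -> str:
    # Keep everything before the first occurrence of `token` (or the whole text).
    index = text.find(token)
    return text[:index] if index != -1 else text


def truncate_at_eos(token_ids: List[int], eos_token_id: int) -> List[int]:
    # Keep ids up to, but not including, the first EOS id.
    return list(takewhile(lambda t: t != eos_token_id, token_ids))


print(remove_text_after_token("the cat sat</s></s>junk"))  # the cat sat
print(truncate_at_eos([0, 133, 4758, 2, 1, 1], eos_token_id=2))  # [0, 133, 4758]
```

A leading BOS id (as produced by tokenizer.encode) survives the id-level cut; in eval_model it is removed afterwards by convert_ids_to_tokens(..., skip_special_tokens=True).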
def eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path = './results/temp.txt' , score_results='./score_results/task.txt'):
# modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
start_time = time.time()
model.eval() # Set model to evaluate mode
target_tokens_list = []
target_string_list = []
pred_tokens_list = []
pred_string_list = []
pred_tokens_list_previous = []
pred_string_list_previous = []
with open(output_all_results_path,'w') as f:
for input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels in tqdm(dataloaders['test']):
# load in batch
input_embeddings_batch = input_embeddings.to(device).float() # B, 56, 840
input_masks_batch = input_masks.to(device) # B, 56
target_ids_batch = target_ids.to(device) # B, 56
input_mask_invert_batch = input_mask_invert.to(device) # B, 56
target_tokens = tokenizer.convert_ids_to_tokens(target_ids_batch[0].tolist(), skip_special_tokens = True)
target_string = tokenizer.decode(target_ids_batch[0], skip_special_tokens = True)
# print('target ids tensor:',target_ids_batch[0])
# print('target ids:',target_ids_batch[0].tolist())
# print('target tokens:',target_tokens)
# print('target string:',target_string)
f.write(f'target string: {target_string}\n')
# add to list for later calculate bleu metric
target_tokens_list.append([target_tokens])
target_string_list.append(target_string)
"""replace padding ids in target_ids with -100"""
target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100
# target_ids_batch_label = target_ids_batch.clone().detach()
# target_ids_batch_label[target_ids_batch_label == tokenizer.pad_token_id] = -100
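# Original code: teacher-forced decoding (the decoder is fed the ground-truth target tokens), then greedy argmax at each position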
seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch) # (batch, time, n_class)
logits_previous = seq2seqLMoutput.logits
probs_previous = logits_previous[0].softmax(dim = 1)
values_previous, predictions_previous = probs_previous.topk(1)
predictions_previous = torch.squeeze(predictions_previous)
predicted_string_previous = remove_text_after_token(tokenizer.decode(predictions_previous).split('</s></s>')[0].replace('<s>',''))
f.write(f'predicted string with tf: {predicted_string_previous}\n')
predictions_previous = predictions_previous.tolist()
truncated_prediction_previous = []
for t in predictions_previous:
if t != tokenizer.eos_token_id:
truncated_prediction_previous.append(t)
else:
break
pred_tokens_previous = tokenizer.convert_ids_to_tokens(truncated_prediction_previous, skip_special_tokens = True)
pred_tokens_list_previous.append(pred_tokens_previous)
pred_string_list_previous.append(predicted_string_previous)
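# Modified code: free-running generation without teacher forcing; num_beams=5 together with do_sample=True selects beam-search sampling, so outputs depend on the random seed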
predictions=model.generate(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch,
max_length=56,
num_beams=5,
do_sample=True,
repetition_penalty= 5.0,
no_repeat_ngram_size = 2
# num_beams=5,encoder_no_repeat_ngram_size =1,
# do_sample=True, top_k=15,temperature=0.5,num_return_sequences=5,
# early_stopping=True
)
predicted_string=tokenizer.batch_decode(predictions, skip_special_tokens=True)[0]
# predicted_string=predicted_string.squeeze()
predictions=tokenizer.encode(predicted_string)
# print('predicted string:',predicted_string)
f.write(f'predicted string: {predicted_string}\n')
f.write(f'################################################\n\n\n')
# convert to int list
# predictions = predictions.tolist() # not needed: tokenizer.encode already returns a list
truncated_prediction = []
for t in predictions:
if t != tokenizer.eos_token_id:
truncated_prediction.append(t)
else:
break
pred_tokens = tokenizer.convert_ids_to_tokens(truncated_prediction, skip_special_tokens = True)
# print('predicted tokens:',pred_tokens)
pred_tokens_list.append(pred_tokens)
pred_string_list.append(predicted_string)
# pred_tokens_list.extend(pred_tokens)
# pred_string_list.extend(predicted_string)
# print('################################################')
# print()
# print(f"pred_string_list : {pred_string_list}")
""" calculate corpus bleu score """
weights_list = [(1.0,),(0.5,0.5),(1./3.,1./3.,1./3.),(0.25,0.25,0.25,0.25)]
corpus_bleu_scores = []
corpus_bleu_scores_previous = []
for weight in weights_list:
# print('weight:',weight)
corpus_bleu_score = corpus_bleu(target_tokens_list, pred_tokens_list, weights = weight)
corpus_bleu_score_previous = corpus_bleu(target_tokens_list, pred_tokens_list_previous, weights = weight)
corpus_bleu_scores.append(corpus_bleu_score)
corpus_bleu_scores_previous.append(corpus_bleu_score_previous)
print(f'corpus BLEU-{len(list(weight))} score:', corpus_bleu_score)
print(f'corpus BLEU-{len(list(weight))} score with tf:', corpus_bleu_score_previous)
""" calculate sacre bleu score """
reference_list = [[item] for item in target_string_list]
#print(f'ref: {reference_list}')
#print(f'pred: {pred_string_list}')
sacre_blue = metric.compute(predictions=pred_string_list, references=reference_list)
sacre_blue_previous = metric.compute(predictions=pred_string_list_previous, references=reference_list)
print("SacreBLEU score: ", sacre_blue, '\n')
print("SacreBLEU score with tf: ", sacre_blue_previous)
print()
""" calculate rouge score """
rouge = Rouge()
# pred_string_list = [item for sublist in pred_string_list for item in sublist]
# pred_string_list_previous = [item for sublist in pred_string_list_previous for item in sublist]
# rouge_scores = rouge.get_scores(pred_string_list, target_string_list, avg = True, ignore_empty=True)
# rouge_scores_previous = rouge.get_scores(pred_string_list_previous, target_string_list, avg = True, ignore_empty=True)
# print('rouge_scores: ', rouge_scores)
# print('rouge_scores with tf:', rouge_scores_previous)
# print('rouge_scores', rouge_scores)
# print('previous rouge_scores', rouge_scores_previous)
try:
rouge_scores = rouge.get_scores(pred_string_list, target_string_list, avg = True, ignore_empty=True)
except ValueError as e:
rouge_scores = 'Hypothesis is empty'
try:
rouge_scores_previous = rouge.get_scores(pred_string_list_previous, target_string_list, avg = True, ignore_empty=True)
except ValueError as e:
rouge_scores_previous = 'Hypothesis is empty'
print()
print()
""" calculate WER score """
#wer = WordErrorRate()
wer_scores = wer_metric.compute(predictions=pred_string_list, references=target_string_list)
wer_scores_previous = wer_metric.compute(predictions=pred_string_list_previous, references=target_string_list)
print("WER score:", wer_scores)
print("WER score with tf:", wer_scores_previous)
""" calculate CER score """
cer_scores = cer_metric.compute(predictions=pred_string_list, references=target_string_list)
cer_scores_previous = cer_metric.compute(predictions=pred_string_list_previous, references=target_string_list)
print("CER score:", cer_scores)
print("CER score with tf:", cer_scores_previous)
end_time = time.time()
print(f"Evaluation took {(end_time-start_time)/60} minutes to execute.")
# score_results: append the generation metrics and the teacher-forced (tf) metrics to the score file
file_content = [
f'corpus_bleu_score = {corpus_bleu_scores}',
f'sacre_blue_score = {sacre_blue}',
f'rouge_scores = {rouge_scores}',
f'wer_scores = {wer_scores}',
f'cer_scores = {cer_scores}',
f'corpus_bleu_score_with_tf = {corpus_bleu_scores_previous}',
f'sacre_blue_score_with_tf = {sacre_blue_previous}',
f'rouge_scores_with_tf = {rouge_scores_previous}',
f'wer_scores_with_tf = {wer_scores_previous}',
f'cer_scores_with_tf = {cer_scores_previous}',
]
with open(score_results, "a") as file_results:
for line in file_content:
if isinstance(line, list):
for item in line:
file_results.write(str(item) + "\n")
else:
file_results.write(str(line) + "\n")
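The WER numbers above come from the Hugging Face evaluate "wer" metric. As far as we understand its implementation (built on jiwer), it is corpus-level: word substitutions, deletions and insertions are summed over all sentence pairs and divided by the total number of reference words. A pure-Python sketch of that definition, assuming plain whitespace tokenisation (jiwer applies its own normalisation, so results can differ slightly); word_edit_distance and corpus_wer are hypothetical helpers:

```python
from typing import List, Sequence


def word_edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    # Levenshtein distance over words; every edit operation costs 1.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(prev[j] + 1,              # reference word deleted
                          curr[j - 1] + 1,          # hypothesis word inserted
                          prev[j - 1] + (r != h))   # substitution (or match)
        prev = curr
    return prev[-1]


def corpus_wer(predictions: List[str], references: List[str]) -> float:
    # Total word edits divided by total reference words, over all pairs.
    errors = sum(word_edit_distance(r.split(), p.split())
                 for p, r in zip(predictions, references))
    total = sum(len(r.split()) for r in references)
    return errors / total


print(corpus_wer(["the cat sat", "a dog"], ["the cat sat down", "the dog"]))  # 2 edits / 6 words
```

Because insertions count as errors, WER can exceed 1.0 when predictions are much longer than the references. CER is the same computation over characters instead of words.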
if __name__ == '__main__':
batch_size = 1
''' get args'''
args = get_config('eval_decoding')
test_input = args['test_input']
print("test_input is:", test_input)
train_input = args['train_input']
print("train_input is:", train_input)
''' load training config'''
training_config = json.load(open(args['config_path']))
subject_choice = training_config['subjects']
print(f'[INFO]subjects: {subject_choice}')
eeg_type_choice = training_config['eeg_type']
print(f'[INFO]eeg type: {eeg_type_choice}')
bands_choice = training_config['eeg_bands']
print(f'[INFO]using bands: {bands_choice}')
dataset_setting = 'unique_sent'
task_name = training_config['task_name']
model_name = training_config['model_name']
if test_input == 'EEG' and train_input=='EEG':
print("EEG and EEG")
output_all_results_path = f'./results/{task_name}-{model_name}-all_decoding_results.txt'
score_results = f'./score_results/{task_name}-{model_name}.txt'
else:
output_all_results_path = f'./results/{task_name}-{model_name}-{train_input}_{test_input}-all_decoding_results.txt'
score_results = f'./score_results/{task_name}-{model_name}-{train_input}_{test_input}.txt'
''' set random seeds '''
seed_val = 20 #500
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
''' set up device '''
# use cuda
if torch.cuda.is_available():
dev = 0
else:
dev = "cpu"
# CUDA_VISIBLE_DEVICES=0,1,2,3
device = torch.device(dev)
print(f'[INFO]using device {dev}')
# task_name = 'task1_task2_task3'
''' set up dataloader '''
whole_dataset_dicts = []
if 'task1' in task_name:
dataset_path_task1 = '/data/johj/ZuCo_data/task1-SR/task1_source.pkl'
with open(dataset_path_task1, 'rb') as handle:
whole_dataset_dicts.append(pickle.load(handle))
if 'task2' in task_name:
dataset_path_task2 = '/data/johj/ZuCo_data/task2-NR/task2_source.pkl'
with open(dataset_path_task2, 'rb') as handle:
whole_dataset_dicts.append(pickle.load(handle))
if 'task3' in task_name:
dataset_path_task3 = '/data/johj/ZuCo_data/task3-TSR/task3_source.pkl'
with open(dataset_path_task3, 'rb') as handle:
whole_dataset_dicts.append(pickle.load(handle))
if 'taskNRv2' in task_name:
dataset_path_taskNRv2 = '/data/johj/ZuCo_data/task2-NR-2.0/taskNRv2_source.pkl'
with open(dataset_path_taskNRv2, 'rb') as handle:
whole_dataset_dicts.append(pickle.load(handle))
print()
if model_name in ['BrainTranslator','BrainTranslatorNaive']:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
elif model_name == 'PegasusTranslator':
tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-xsum')
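elif model_name == 'T5Translator':
tokenizer = T5Tokenizer.from_pretrained("t5-large")
# tokenizer.set_prefix_tokens(language='english')
elif model_name == 'BertGeneration':
# use the tokenizer of the same checkpoint the BertGeneration model is loaded from
tokenizer = BertTokenizer.from_pretrained('google-bert/bert-large-uncased')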
# test dataset
test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, test_input=test_input)
dataset_sizes = {"test_set":len(test_set)}
print('[INFO]test_set size: ', len(test_set))
# dataloaders
test_dataloader = DataLoader(test_set, batch_size = batch_size, shuffle=False, num_workers=4)
dataloaders = {'test':test_dataloader}
''' set up model '''
checkpoint_path = args['checkpoint_path']
if model_name == 'BrainTranslator':
pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
model = BrainTranslator(pretrained_bart, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
elif model_name == 'BrainTranslatorNaive':
pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
model = BrainTranslatorNaive(pretrained_bart, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
elif model_name == 'BertGeneration':
pretrained = BertGenerationDecoder.from_pretrained('google-bert/bert-large-uncased', is_decoder = True)
model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
elif model_name == 'PegasusTranslator':
pretrained = PegasusForConditionalGeneration.from_pretrained('google/pegasus-xsum')
model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
elif model_name == 'T5Translator':
pretrained = T5ForConditionalGeneration.from_pretrained("t5-large")
model = T5Translator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
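# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
state_dict = torch.load(checkpoint_path, map_location=device)
# checkpoints saved from nn.DataParallel prefix every key with 'module.'; strip it
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)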
'''
if isinstance(model, nn.DataParallel):
model.module.load_state_dict(torch.load(checkpoint_path))
else:
model.load_state_dict(torch.load(checkpoint_path))
'''
# model.load_state_dict(torch.load(checkpoint_path))
model.to(device)
criterion = nn.CrossEntropyLoss()
''' eval '''
eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path = output_all_results_path, score_results=score_results)
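Checkpoints saved from a model wrapped in nn.DataParallel store the real model under the attribute `module`, so every key starts with "module.". The loader above removes that text with str.replace, which also rewrites any later occurrence inside a key. A minimal sketch of a prefix-only variant (strip_prefix and the example keys are hypothetical):

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Dict[str, V], prefix: str = "module.") -> Dict[str, V]:
    # Drop the prefix only at the start of a key; str.replace would also
    # rewrite a later occurrence, e.g. inside a 'submodule.' name.
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}


saved = {"module.fc1.weight": 1, "module.encoder.submodule.weight": 2, "fc1.bias": 3}
print(strip_prefix(saved))
# {'fc1.weight': 1, 'encoder.submodule.weight': 2, 'fc1.bias': 3}
print({k.replace("module.", ""): v for k, v in saved.items()})
# {'fc1.weight': 1, 'encoder.subweight': 2, 'fc1.bias': 3}
```

For the BART, Pegasus and T5 parameter names used here the two approaches are likely to agree; the prefix-only form simply removes the edge case.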
================================================
FILE: eval_sentiment.py
================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pack_padded_sequence
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification
from data import ZuCo_dataset
from model_sentiment import BaselineMLPSentence, BaselineLSTM, FineTunePretrainedTwoStep, ZeroShotSentimentDiscovery, JointBrainTranslatorSentimentClassifier
from model_decoding import BrainTranslator, BrainTranslatorNaive
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
from config import get_config
# Function to calculate the accuracy of our predictions vs labels
def flat_accuracy(preds, labels):
# preds: numpy array: N * 3
# labels: numpy array: N
pred_flat = np.argmax(preds, axis=1).flatten()
labels_flat = labels.flatten()
return np.sum(pred_flat == labels_flat) / len(labels_flat)
def flat_accuracy_top_k(preds, labels,k):
topk_preds = []
for pred in preds:
topk = pred.argsort()[-k:][::-1]
topk_preds.append(list(topk))
# print(topk_preds)
topk_preds = list(topk_preds)
right_count = 0
# print(len(labels))
for i in range(len(labels)):
l = labels[i][0]
if l in topk_preds[i]:
right_count+=1
return right_count/len(labels)
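flat_accuracy is the k = 1 case of top-k accuracy: a sample counts as correct when its true label is among the k highest-scoring classes. flat_accuracy_top_k reads labels[i][0], so it expects labels shaped (N, 1). A stdlib sketch that takes a flat label list instead (top_k_accuracy is a hypothetical helper):

```python
from typing import Sequence


def top_k_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    # Correct when the true label is among the indices of the k largest scores.
    correct = 0
    for row, label in zip(scores, labels):
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        correct += label in top_k
    return correct / len(labels)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
print(top_k_accuracy(scores, labels, k=1))  # 1 of 3 correct
print(top_k_accuracy(scores, labels, k=2))  # 2 of 3 correct
```

With only three sentiment classes, k = 3 is trivially 1.0, so k = 2 is the only informative setting beyond plain accuracy.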
def eval_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')):
def logits2PredString(logits, tokenizer):
probs = logits[0].softmax(dim = 1)
# print('probs size:', probs.size())
values, predictions = probs.topk(1)
# print('predictions before squeeze:',predictions.size())
predictions = torch.squeeze(predictions)
predict_string = tokenizer.decode(predictions)
return predict_string
# modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
since = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = 100000000000
best_acc = 0.0
total_pred_labels = np.array([])
total_true_labels = np.array([])
for epoch in range(1):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
# Each epoch has a training and validation phase
for phase in ['test']:
total_accuracy = 0.0
if phase == 'train':
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode
running_loss = 0.0
# Iterate over data.
for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in dataloaders[phase]:
input_word_eeg_features = input_word_eeg_features.to(device).float()
input_masks = input_masks.to(device)
input_mask_invert = input_mask_invert.to(device)
sent_level_EEG = sent_level_EEG.to(device)
sentiment_labels = sentiment_labels.to(device)
target_ids = target_ids.to(device)
target_mask = target_mask.to(device)
## forward ###################
if isinstance(model, BaselineMLPSentence):
logits = model(sent_level_EEG) # before softmax
# calculate loss
loss = criterion(logits, sentiment_labels)
elif isinstance(model, BaselineLSTM):
x_packed = pack_padded_sequence(input_word_eeg_features, seq_lens, batch_first=True, enforce_sorted=False)
logits = model(x_packed)
# calculate loss
loss = criterion(logits, sentiment_labels)
elif isinstance(model, BertForSequenceClassification) or isinstance(model, RobertaForSequenceClassification) or isinstance(model, BartForSequenceClassification):
output = model(input_ids = target_ids, attention_mask = target_mask, return_dict = True, labels = sentiment_labels)
logits = output.logits
loss = output.loss
elif isinstance(model, FineTunePretrainedTwoStep):
output = model(input_word_eeg_features, input_masks, input_mask_invert, sentiment_labels)
logits = output.logits
loss = output.loss
elif isinstance(model, ZeroShotSentimentDiscovery):
print()
print('target string:',tokenizer.decode(target_ids[0]).replace('<pad>','').split('</s>')[0])
"""replace padding ids in target_ids with -100"""
target_ids[target_ids == tokenizer.pad_token_id] = -100
output = model(input_word_eeg_features, input_masks, input_mask_invert, target_ids, sentiment_labels)
logits = output.logits
loss = output.loss
elif isinstance(model, JointBrainTranslatorSentimentClassifier):
print()
print('target string:',tokenizer.decode(target_ids[0]).replace('<pad>','').split('</s>')[0])
"""replace padding ids in target_ids with -100"""
target_ids[target_ids == tokenizer.pad_token_id] = -100
LM_output, classification_output = model(input_word_eeg_features, input_masks, input_mask_invert, target_ids, sentiment_labels)
LM_logits = LM_output.logits
print('pred string:', logits2PredString(LM_logits, tokenizer).split('</s></s>')[0].replace('<s>',''))
classification_loss = classification_output['loss']
logits = classification_output['logits']
loss = classification_loss
###############################
# backward + optimize only if in training phase
if phase == 'train':
# with torch.autograd.detect_anomaly():
loss.backward()
optimizer.step()
# calculate accuracy
preds_cpu = logits.detach().cpu().numpy()
label_cpu = sentiment_labels.cpu().numpy()
total_accuracy += flat_accuracy(preds_cpu, label_cpu)
# add to total pred and label array, for cal F1, precision, recall
pred_flat = np.argmax(preds_cpu, axis=1).flatten()
labels_flat = label_cpu.flatten()
total_pred_labels = np.concatenate((total_pred_labels,pred_flat))
total_true_labels = np.concatenate((total_true_labels,labels_flat))
# statistics
running_loss += loss.item() * sent_level_EEG.size()[0] # batch loss
# print('[DEBUG]loss:',loss.item())
# print('#################################')
if phase == 'train':
scheduler.step()
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = total_accuracy / len(dataloaders[phase])
print('{} Loss: {:.4f}'.format(phase, epoch_loss))
print('{} Acc: {:.4f}'.format(phase, epoch_acc))
# deep copy the model
if phase == 'test' and epoch_loss < best_loss:
best_loss = epoch_loss
best_acc = epoch_acc
print()
time_elapsed = time.time() - since
print('Evaluation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best test loss: {:.4f}'.format(best_loss))
print('Best test acc: {:.4f}'.format(best_acc))
print()
print('test sample num:', len(total_pred_labels))
print('total preds:',total_pred_labels)
print('total truth:',total_true_labels)
print('sklearn macro: precision, recall, F1:')
print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='macro'))
print()
print('sklearn micro: precision, recall, F1:')
print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='micro'))
print()
print('sklearn accuracy:')
print(accuracy_score(total_true_labels,total_pred_labels))
print()
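precision_recall_fscore_support with average='macro' averages the per-class precision, recall and F1 with equal weight per class, whereas average='micro' first pools true positives, false positives and false negatives over all classes. For single-label multiclass predictions every misclassification is exactly one false positive and one false negative, so micro precision, recall and F1 all equal accuracy; the micro line and the accuracy_score line above should therefore agree. A stdlib sketch of both averages (macro_micro_prf is a hypothetical helper; like sklearn's default, 0/0 is treated as 0, without the warning):

```python
from typing import List, Sequence, Tuple

PRF = Tuple[float, float, float]


def _safe_div(num: float, den: float) -> float:
    # Treat 0/0 as 0.0 instead of raising ZeroDivisionError.
    return num / den if den else 0.0


def macro_micro_prf(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[PRF, PRF]:
    labels = sorted(set(y_true) | set(y_pred))
    per_class: List[PRF] = []
    tp_all = fp_all = fn_all = 0
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        precision, recall = _safe_div(tp, tp + fp), _safe_div(tp, tp + fn)
        per_class.append((precision, recall, _safe_div(2 * precision * recall, precision + recall)))
        tp_all, fp_all, fn_all = tp_all + tp, fp_all + fp, fn_all + fn
    # Macro: unweighted mean of the per-class scores.
    macro = tuple(sum(scores[i] for scores in per_class) / len(per_class) for i in range(3))
    # Micro: pool the counts over all classes, then compute a single score.
    micro_p, micro_r = _safe_div(tp_all, tp_all + fp_all), _safe_div(tp_all, tp_all + fn_all)
    micro = (micro_p, micro_r, _safe_div(2 * micro_p * micro_r, micro_p + micro_r))
    return macro, micro


macro, micro = macro_micro_prf([0, 1, 2, 2, 1], [0, 2, 2, 2, 1])
print(macro)  # approximately (0.889, 0.833, 0.822)
print(micro)  # all three approximately 0.8, the accuracy (4 of 5 correct)
```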
if __name__ == '__main__':
args = get_config('eval_sentiment')
''' config param'''
num_epochs = 1
dataset_setting = 'unique_sent'
'''model name'''
# model_name = 'BaselineMLP'
# model_name = 'BaselineLSTM'
# model_name = 'NaiveFinetuneBert'
# model_name = 'FinetunedBertOnText'
# model_name = 'FinetunedRoBertaOnText'
# model_name = 'FinetunedBartOnText'
# model_name = 'ZeroShotSentimentDiscovery'
model_name = args['model_name']
print(f'[INFO] eval {model_name}')
if model_name == 'ZeroShotSentimentDiscovery':
'''load decoder and classifier config'''
config_decoder = json.load(open(args['decoder_config_path']))
config_classifier = json.load(open(args['classifier_config_path']))
'''choose generator'''
# decoder_name = 'BrainTranslator'
# decoder_name = 'BrainTranslatorNaive'
decoder_name = config_decoder['model_name']
decoder_checkpoint = args['decoder_checkpoint_path']
print(f'[INFO] using decoder: {decoder_name}')
'''choose classifier'''
# pretrain_Bert, pretrain_RoBerta, pretrain_Bart
classifier_name = config_classifier['model_name']
classifier_checkpoint = args['classifier_checkpoint_path']
print(f'[INFO] using classifier: {classifier_name}')
else:
checkpoint_path = args['checkpoint_path']
print('[INFO] loading baseline:', checkpoint_path)
batch_size = 1
# subject_choice = 'ALL'
subject_choice = args['subjects']
print(f'[INFO]using subjects {subject_choice}')
# eeg_type_choice = 'GD'
eeg_type_choice = args['eeg_type']
print(f'[INFO]eeg type {eeg_type_choice}')
# bands_choice = ['_t1']
# bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
bands_choice = args['eeg_bands']
print(f'[INFO]using bands {bands_choice}')
''' set random seeds '''
seed_val = 312
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
''' set up device '''
# use cuda
if torch.cuda.is_available():
dev = args['cuda']
else:
dev = "cpu"
# CUDA_VISIBLE_DEVICES=0,1,2,3
device = torch.device(dev)
print(f'[INFO]using device {dev}')
''' load pickle'''
whole_dataset_dict = []
dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle'
with open(dataset_path_task1, 'rb') as handle:
whole_dataset_dict.append(pickle.load(handle))
'''set up tokenizer'''
if model_name in ['BaselineMLP','BaselineLSTM', 'NaiveFinetuneBert', 'FinetunedBertOnText']:
print('[INFO]using Bert tokenizer')
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
elif model_name == 'FinetunedBartOnText':
print('[INFO]using Bart tokenizer')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
elif model_name == 'FinetunedRoBertaOnText':
print('[INFO]using RoBerta tokenizer')
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
elif model_name == 'ZeroShotSentimentDiscovery':
decoder_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') # Bart
tokenizer = decoder_tokenizer
if classifier_name == 'pretrain_Bert':
sentiment_tokenizer = BertTokenizer.from_pretrained('bert-base-cased') # Bert
elif classifier_name == 'pretrain_Bart':
sentiment_tokenizer = decoder_tokenizer
elif classifier_name == 'pretrain_RoBerta':
sentiment_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
''' set up model '''
if model_name == 'BaselineMLP':
print('[INFO]Model: BaselineMLP')
model = BaselineMLPSentence(input_dim = 840, hidden_dim = 128, output_dim = 3)
elif model_name == 'BaselineLSTM':
print('[INFO]Model: BaselineLSTM')
# model = BaselineLSTM(input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 1)
model = BaselineLSTM(input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 4)
elif model_name == 'FinetunedBertOnText':
print('[INFO]Model: FinetunedBertOnText')
model = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)
elif model_name == 'FinetunedRoBertaOnText':
print('[INFO]Model: FinetunedRoBertaOnText')
model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)
elif model_name == 'FinetunedBartOnText':
print('[INFO]Model: FinetunedBartOnText')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels=3)
elif model_name == 'ZeroShotSentimentDiscovery':
print(f'[INFO]Model: ZeroShotSentimentDiscovery, using classifier: {classifier_name}, using generator: {decoder_name}')
pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
if decoder_name == 'BrainTranslator':
decoder = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
elif decoder_name == 'BrainTranslatorNaive':
decoder = BrainTranslatorNaive(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
decoder.load_state_dict(torch.load(decoder_checkpoint))
if classifier_name == 'pretrain_Bert':
classifier = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)
elif classifier_name == 'pretrain_Bart':
classifier = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels=3)
elif classifier_name == 'pretrain_RoBerta':
classifier = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)
classifier.load_state_dict(torch.load(classifier_checkpoint))
model = ZeroShotSentimentDiscovery(decoder, classifier, decoder_tokenizer, sentiment_tokenizer, device = device)
model.to(device)
if model_name != 'ZeroShotSentimentDiscovery':
# load model and send to device
model.load_state_dict(torch.load(checkpoint_path))
model.to(device)
''' set up dataloader '''
# test dataset
test_set = ZuCo_dataset(whole_dataset_dict, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = 'unique_sent')
dataset_sizes = {'test': len(test_set)}
# print('[INFO]train_set size: ', len(train_set))
print('[INFO]test_set size: ', len(test_set))
test_dataloader = DataLoader(test_set, batch_size = 1, shuffle=False, num_workers=4)
# dataloaders
dataloaders = {'test':test_dataloader}
''' set up optimizer and scheduler'''
optimizer_step1 = None
exp_lr_scheduler_step1 = None
''' set up loss function '''
criterion = nn.CrossEntropyLoss()
print('=== start evaluation ... ===')
# evaluate on the test split; eval_model prints the metrics and returns None
eval_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs, tokenizer = tokenizer)
================================================
FILE: model_decoding.py
================================================
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
import math
import numpy as np
""" main architecture for open vocabulary EEG-To-Text decoding"""
class BrainTranslator(nn.Module):
def __init__(self, pretrained_layers, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):
super(BrainTranslator, self).__init__()
self.pretrained = pretrained_layers
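# additional transformer encoder, following the BART paper's approach to machine translation (a new, randomly initialised encoder replaces BART's input embedding layer)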
self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
# print('[INFO]adding positional embedding')
# self.positional_embedding = PositionalEncoding(in_feature)
self.fc1 = nn.Linear(in_feature, decoder_embedding_size)
def addin_forward(self,input_embeddings_batch, input_masks_invert):
"""input_embeddings_batch: batch_size * seq_len * 840
input_mask: 1 is not masked, 0 is masked
input_masks_invert: 1 is masked, 0 is not masked
"""
# input_embeddings_batch = self.positional_embedding(input_embeddings_batch)
# use src_key_padding_masks
encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask=input_masks_invert)
# encoded_embedding = self.additional_encoder(input_embeddings_batch)
encoded_embedding = F.relu(self.fc1(encoded_embedding))
return encoded_embedding
@torch.no_grad()
def generate(
self,
input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted,
generation_config = None,
logits_processor = None,
stopping_criteria = None,
prefix_allowed_tokens_fn= None,
synced_gpus= None,
assistant_model = None,
streamer= None,
negative_prompt_ids= None,
negative_prompt_attention_mask = None,
**kwargs,
):
encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)
output=self.pretrained.generate(
inputs_embeds = encoded_embedding,
attention_mask = input_masks_batch[:,:encoded_embedding.shape[1]],
labels = target_ids_batch_converted,
return_dict = True,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
**kwargs,)
return output
def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)
# print(f'forward:{input_embeddings_batch.shape,input_masks_batch.shape,input_masks_invert.shape,target_ids_batch_converted.shape,encoded_embedding.shape}')
out = self.pretrained(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch,
return_dict = True, labels = target_ids_batch_converted)
return out
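`BrainTranslator` uses two padding masks with opposite conventions: `input_masks_batch` (1 = real position) goes to the pretrained seq2seq model as `attention_mask`, while `input_masks_invert` (1 = padding) goes to `nn.TransformerEncoder` as `src_key_padding_mask`, where a true value means "ignore this position". The repository's masks come from the dataset class in `data.py`, which is outside this section. The stdlib-only sketch below builds both masks from a list of per-sample lengths; `build_padding_masks` and `lengths` are hypothetical names used only for illustration.

```python
from typing import List, Tuple


def build_padding_masks(
    lengths: List[int], max_len: int
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Return (attention_mask, padding_mask) for a batch of sequence lengths.

    attention_mask: 1 for real positions, 0 for padding
        (the convention of input_masks_batch / attention_mask).
    padding_mask: True for padding, False for real positions
        (the convention of input_masks_invert / src_key_padding_mask).
    """
    attention_mask: List[List[int]] = []
    padding_mask: List[List[bool]] = []
    for n in lengths:
        if not 0 <= n <= max_len:
            raise ValueError(f"length {n} outside [0, {max_len}]")
        attention_mask.append([1] * n + [0] * (max_len - n))
        padding_mask.append([False] * n + [True] * (max_len - n))
    return attention_mask, padding_mask


attn, pad = build_padding_masks([3, 1], max_len=4)
print(attn)  # [[1, 1, 1, 0], [1, 0, 0, 0]]
print(pad)   # [[False, False, False, True], [False, True, True, True]]
```

The two masks are element-wise complements, so swapping them by mistake makes the encoder attend only to padding.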
from transformers import T5Tokenizer
""" T5 variant of the main architecture for open vocabulary EEG-To-Text decoding"""
class T5Translator(nn.Module):
def __init__(self, pretrained_layers, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):
super(T5Translator, self).__init__()
self.pretrained = pretrained_layers
self.tokenizer = T5Tokenizer.from_pretrained("t5-large")
# additional transformer encoder, following the BART paper's machine-translation setup (a new randomly initialized encoder in front of the pretrained model)
self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
# print('[INFO]adding positional embedding')
# self.positional_embedding = PositionalEncoding(in_feature)
self.fc1 = nn.Linear(in_feature, decoder_embedding_size)
def addin_forward(self,input_embeddings_batch, input_masks_invert):
"""input_embeddings_batch: batch_size*Seq_len*840
input_mask: 1 is not masked, 0 is masked
input_masks_invert: 1 is masked, 0 is not masked"""
# input_embeddings_batch = self.positional_embedding(input_embeddings_batch)
# use src_key_padding_masks
encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask=input_masks_invert)
# encoded_embedding = self.additional_encoder(input_embeddings_batch)
encoded_embedding = F.relu(self.fc1(encoded_embedding))
return encoded_embedding
@torch.no_grad()
def generate(
self,
input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted,
generation_config = None,
logits_processor = None,
stopping_criteria = None,
prefix_allowed_tokens_fn= None,
synced_gpus= None,
assistant_model = None,
streamer= None,
negative_prompt_ids= None,
negative_prompt_attention_mask = None,
**kwargs,
):
encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)
input_ids = self.tokenizer("transcribe in English: ", return_tensors="pt").input_ids.to(encoded_embedding.device)
self.task_embedding = self.pretrained.shared(input_ids).to(encoded_embedding.device)
task_embedding = self.task_embedding.repeat(encoded_embedding.size(0), 1, 1).to(encoded_embedding.device)
encoded_embedding = torch.cat((task_embedding, encoded_embedding), dim=1)
input_masks_batch = torch.cat((torch.ones(encoded_embedding.size(0), task_embedding.size(1)).to(encoded_embedding.device), input_masks_batch), dim=1)
output=self.pretrained.generate(
inputs_embeds = encoded_embedding,
attention_mask = input_masks_batch[:,:encoded_embedding.shape[1]],
labels = target_ids_batch_converted,
return_dict = True,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
**kwargs,)
return output
def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
encoded_embedding=self.addin_forward(input_embeddings_batch, input_masks_invert)
# task definition
input_ids = self.tokenizer("transcribe in English: ", return_tensors="pt").input_ids.to(encoded_embedding.device)
self.task_embedding = self.pretrained.shared(input_ids).to(encoded_embedding.device)
task_embedding = self.task_embedding.repeat(encoded_embedding.size(0), 1, 1).to(encoded_embedding.device)
encoded_embedding = torch.cat((task_embedding, encoded_embedding), dim=1)
input_masks_batch = torch.cat((torch.ones(encoded_embedding.size(0), task_embedding.size(1)).to(encoded_embedding.device), input_masks_batch), dim=1)
out = self.pretrained(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch,
return_dict = True, labels = target_ids_batch_converted)
return out
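`T5Translator` prepends the token embeddings of the prompt `"transcribe in English: "` to the encoded EEG sequence and prepends matching ones to `attention_mask`, so the prompt positions are always attended. The `src_key_padding_mask` does not need extending because it is applied inside `addin_forward`, before the prompt is concatenated. The sketch below reproduces that bookkeeping with plain lists instead of `torch.cat`; the prefix vectors and the helper name `prepend_task_prefix` are hypothetical.

```python
from typing import List, Tuple

Vector = List[float]


def prepend_task_prefix(
    prefix: List[Vector], batch: List[List[Vector]], masks: List[List[int]]
) -> Tuple[List[List[Vector]], List[List[int]]]:
    """Prepend the same prefix embeddings to every sample and extend its mask.

    Prefix positions get mask value 1; the original positions keep theirs.
    """
    if len(batch) != len(masks):
        raise ValueError("batch and masks must have the same number of samples")
    new_batch = [list(prefix) + list(sample) for sample in batch]
    new_masks = [[1] * len(prefix) + list(m) for m in masks]
    return new_batch, new_masks


prefix = [[0.1, 0.2], [0.3, 0.4]]                 # 2 hypothetical prompt embeddings
batch = [[[1.0, 1.0], [2.0, 2.0], [0.0, 0.0]]]    # 1 sample, 3 positions, last is padding
masks = [[1, 1, 0]]
emb, m = prepend_task_prefix(prefix, batch, masks)
print(len(emb[0]), m[0])  # 5 [1, 1, 1, 1, 0]
```

One detail worth checking: `T5Tokenizer` appends `</s>` by default, so the prompt as written most likely ends with an end-of-sequence token placed before the EEG embeddings. Calling the tokenizer with `add_special_tokens=False` would avoid that, though it would change the inputs that existing checkpoints were trained on.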
""" ablated open vocabulary EEG-To-Text decoding model without the additional MTE encoder"""
class BrainTranslatorNaive(nn.Module):
def __init__(self, pretrained_layers, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):
super(BrainTranslatorNaive, self).__init__()
'''no additional transformer encoder version'''
self.pretrained = pretrained_layers
self.fc1 = nn.Linear(in_feature, decoder_embedding_size)
def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
"""input_embeddings_batch: batch_size*Seq_len*840
input_mask: 1 is not masked, 0 is masked
input_masks_invert: 1 is masked, 0 is not masked"""
encoded_embedding = F.relu(self.fc1(input_embeddings_batch))
out = self.pretrained(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted)
return out
""" helper modules """
# modified from BertPooler
class Pooler(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
# print('[DEBUG] input size:', x.size())
# print('[DEBUG] positional embedding size:', self.pe.size())
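# expects sequence-first input (seq_len, batch, d_model): pe has shape (max_len, 1, d_model) and is sliced along dim 0
x = x + self.pe[:x.size(0), :]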
# print('[DEBUG] output x with pe size:', x.size())
return self.dropout(x)
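`PositionalEncoding` fills a table with the sinusoidal encoding, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)), and adds rows of it along dimension 0. Since the table is stored as `(max_len, 1, d_model)`, the module assumes sequence-first input; with `batch_first=True` tensors, dimension 0 is the batch. The stdlib sketch below builds the same table and adds it along the sequence axis of a batch-first nested list. `sinusoidal_encoding` and `add_positions_batch_first` are illustrative names, and the odd-`d_model` guard is an addition (the PyTorch module assumes an even `d_model`, which holds for 840).

```python
import math
from typing import List


def sinusoidal_encoding(max_len: int, d_model: int) -> List[List[float]]:
    """pe[pos][2i] = sin(angle), pe[pos][2i+1] = cos(angle), angle = pos / 10000**(2i/d_model)."""
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            # same value as position * div_term in the PyTorch module
            angle = pos * math.exp(i * (-math.log(10000.0) / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:  # guard for odd d_model
                pe[pos][i + 1] = math.cos(angle)
    return pe


def add_positions_batch_first(
    x: List[List[List[float]]], pe: List[List[float]]
) -> List[List[List[float]]]:
    """x is [batch][seq_len][d_model]; the position index is the sequence axis."""
    return [
        [[v + p for v, p in zip(token, pe[t])] for t, token in enumerate(sample)]
        for sample in x
    ]


pe = sinusoidal_encoding(max_len=8, d_model=4)
x = [[[0.0] * 4 for _ in range(3)] for _ in range(2)]  # 2 samples, 3 positions each
y = add_positions_batch_first(x, pe)
print(pe[0])  # [0.0, 1.0, 0.0, 1.0]
```

Every sample receives the same position rows, which is what a batch-first model needs; slicing by the batch dimension would instead give each sample one constant offset.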
""" Miscellaneous (not working well) """
class BrainTranslatorBert(nn.Module):
def __init__(self, pretrained_layers, in_feature = 840, hidden_size = 768):
super(BrainTranslatorBert, self).__init__()
self.pretrained_Bert = pretrained_layers
self.fc1 = nn.Linear(in_feature, hidden_size)
def forward(self, input_embeddings_batch, input_masks_batch, target_ids_batch):
embedding = F.relu(self.fc1(input_embeddings_batch))
out = self.pretrained_Bert(inputs_embeds = embedding, attention_mask = input_masks_batch, labels = target_ids_batch, return_dict = True)
return out
class EEG2BertMapping(nn.Module):
def __init__(self, in_feature = 840, hidden_size = 512, out_feature = 768):
super(EEG2BertMapping, self).__init__()
self.fc1 = nn.Linear(in_feature, hidden_size)
self.fc2 = nn.Linear(hidden_size, out_feature)
def forward(self, x):
out = F.relu(self.fc1(x))
out = self.fc2(out)
return out
class ContrastiveBrainTextEncoder(nn.Module):
def __init__(self, pretrained_text_encoder, in_feature = 840, eeg_encoder_nhead=8, eeg_encoder_dim_feedforward = 2048, embed_dim = 768):
super(ContrastiveBrainTextEncoder, self).__init__()
# EEG Encoder
self.positional_embedding = PositionalEncoding(in_feature)
self.encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=eeg_encoder_nhead, dim_feedforward = eeg_encoder_dim_feedforward, batch_first=True)
self.EEG_Encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
self.EEG_pooler = Pooler(in_feature)
self.ln_final = nn.LayerNorm(in_feature) # to be considered
# project to text embedding
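self.EEG_projection = nn.Parameter(torch.empty(in_feature, embed_dim))
# torch.empty leaves the memory uninitialized; initialize the projection the way CLIP initializes its projections
nn.init.normal_(self.EEG_projection, std = in_feature ** -0.5)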
# Text Encoder
self.TextEncoder = pretrained_text_encoder
# learned temperature parameter
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def forward(self, input_EEG_features, input_EEG_attn_mask, input_ids, input_text_attention_masks):
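# add positional embedding; PositionalEncoding indexes positions along dim 0, so transpose the batch-first input to (seq_len, N, 840) and back
input_EEG_features = self.positional_embedding(input_EEG_features.transpose(0, 1)).transpose(0, 1)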
# get EEG feature embedding
EEG_hiddenstates = self.EEG_Encoder(input_EEG_features, src_key_padding_mask = input_EEG_attn_mask)
EEG_hiddenstates = self.ln_final(EEG_hiddenstates)
EEG_features = self.EEG_pooler(EEG_hiddenstates) # [N, 840]
# project to text embed size
EEG_features = EEG_features @ self.EEG_projection # [N, 768]
# get text feature embedding
Text_features = self.TextEncoder(input_ids = input_ids, attention_mask = input_text_attention_masks, return_dict = True).pooler_output # [N, 768]
# normalized features
EEG_features = EEG_features / EEG_features.norm(dim=-1, keepdim=True) # [N, 768]
Text_features = Text_features / Text_features.norm(dim=-1, keepdim=True) # [N, 768]
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_EEG = logit_scale * EEG_features @ Text_features.t() # [N, N]
logits_per_text = logit_scale * Text_features @ EEG_features.t() # [N, N]
return logits_per_EEG, logits_per_text
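`ContrastiveBrainTextEncoder.forward` returns only the two similarity matrices; the loss is computed elsewhere and is not shown in this section. Assuming the usual CLIP recipe, where the i-th EEG sample and the i-th sentence in a batch form the positive pair, the loss is the average of a row-wise cross-entropy on `logits_per_EEG` and on its transpose `logits_per_text`, with targets on the diagonal. `logit_scale` starts at log(1/0.07), i.e. a temperature of 0.07, as in CLIP. The stdlib sketch below computes that symmetric loss; the function names are illustrative.

```python
import math
from typing import List


def cross_entropy_rows(logits: List[List[float]]) -> float:
    """Mean cross-entropy where the correct class of row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def clip_style_loss(logits_per_eeg: List[List[float]]) -> float:
    """Average of the EEG->text and text->EEG cross-entropies."""
    logits_per_text = [list(col) for col in zip(*logits_per_eeg)]  # transpose
    return (cross_entropy_rows(logits_per_eeg) + cross_entropy_rows(logits_per_text)) / 2


print(round(clip_style_loss([[0.0, 0.0], [0.0, 0.0]]), 4))  # 0.6931, i.e. log(2)
```

With uninformative (all-equal) logits the loss is log(N) for a batch of N pairs; it approaches 0 as the diagonal dominates each row and column.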
================================================
FILE: model_sentiment.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BertForSequenceClassification
import math
import numpy as np
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
"""MLP baseline using sentence level eeg"""
# using sent level EEG, MLP baseline for sentiment
class BaselineMLPSentence(nn.Module):
def __init__(self, input_dim = 840, hidden_dim = 128, output_dim = 3):
super(BaselineMLPSentence, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.relu2 = nn.ReLU()
self.fc3 = nn.Linear(hidden_dim, output_dim) # positive, negative, neutral
self.dropout = nn.Dropout(0.25)
def forward(self, x):
out = self.fc1(x)
out = self.relu1(out)
out = self.fc2(out)
out = self.relu2(out)
out = self.dropout(out)
out = self.fc3(out)
return out
"""bidirectional LSTM baseline using word level eeg"""
class BaselineLSTM(nn.Module):
def __init__(self, input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 1):
super(BaselineLSTM, self).__init__()
self.hidden_dim = hidden_dim
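self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers = num_layers, batch_first = True, bidirectional = True)
self.hidden2sentiment = nn.Linear(hidden_dim*2, output_dim)
def forward(self, x_packed):
# input: PackedSequence built from a (N, seq_len, input_dim) batch
# print(x_packed.data.size())
_, (h_n, _) = self.lstm(x_packed)
# h_n: (num_layers*2, N, hidden_dim); the last two entries are the final forward and backward states of the top layer.
# Reading pad_packed_sequence(...)[0][:,-1,:] instead would return a zero-padded position for every sequence shorter than the longest one.
last_hidden_state = torch.cat((h_n[-2], h_n[-1]), dim = 1)
# print(last_hidden_state.size())
out = self.hidden2sentiment(last_hidden_state)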
return out
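`pad_packed_sequence` fills positions past each sequence's length with zeros (its default `padding_value` is 0.0), so indexing `[:, -1, :]` on the padded output reads a padding position for every sequence shorter than the longest one. The stdlib sketch below contrasts that with picking index `length - 1` per sample; the data and the name `last_valid_outputs` are hypothetical. For a bidirectional LSTM this only fixes the forward half, because the backward direction finishes at position 0, which is why the final states in `h_n` are the more direct source.

```python
from typing import List, Sequence


def last_valid_outputs(
    padded: Sequence[Sequence[List[float]]], lengths: Sequence[int]
) -> List[List[float]]:
    """Return the output at index length-1 for each sample."""
    out = []
    for sample, n in zip(padded, lengths):
        if n < 1:
            raise ValueError("every sequence needs at least one step")
        out.append(list(sample[n - 1]))
    return out


padded = [
    [[1.0], [2.0], [3.0]],
    [[4.0], [0.0], [0.0]],  # length 1; the rest is zero padding
]
lengths = [3, 1]
naive = [list(sample[-1]) for sample in padded]
print(naive)                                # [[3.0], [0.0]]
print(last_valid_outputs(padded, lengths))  # [[3.0], [4.0]]
```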
""" Bert Baseline: Finetuning from a pretrained language model Bert"""
class NaiveFineTunePretrainedBert(nn.Module):
def __init__(self, input_dim = 840, hidden_dim = 768, output_dim = 3, pretrained_checkpoint = None):
super(NaiveFineTunePretrainedBert, self).__init__()
# map the EEG feature dimension to BERT's hidden size
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.pretrained_Bert = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)
if pretrained_checkpoint is not None:
self.pretrained_Bert.load_state_dict(torch.load(pretrained_checkpoint))
def forward(self, input_embeddings_batch, input_masks_batch, labels):
embedding = F.relu(self.fc1(input_embeddings_batch))
out = self.pretrained_Bert(inputs_embeds = embedding, attention_mask = input_masks_batch, labels = labels, return_dict = True)
return out
""" Finetuning from a pretrained language model BART, two step training"""
class FineTunePretrainedTwoStep(nn.Module):
def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):
super(FineTunePretrainedTwoStep, self).__init__()
self.pretrained_layers = pretrained_layers
# additional transformer encoder, following the BART paper's machine-translation setup (a new randomly initialized encoder in front of the pretrained model)
self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
# NOTE: add positional embedding?
# print('[INFO]adding positional embedding')
# self.positional_embedding = PositionalEncoding(in_feature)
self.fc1 = nn.Linear(in_feature, d_model)
def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, labels):
"""input_embeddings_batch: batch_size*Seq_len*840
input_mask: 1 is not masked, 0 is masked
input_masks_invert: 1 is masked, 0 is not masked
labels: sentiment labels 0,1,2"""
# NOTE: add positional embedding?
# input_embeddings_batch = self.positional_embedding(input_embeddings_batch)
# use src_key_padding_masks
encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert)
# encoded_embedding = self.additional_encoder(input_embeddings_batch)
encoded_embedding = F.relu(self.fc1(encoded_embedding))
out = self.pretrained_layers(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = labels)
return out
""" Zero-shot sentiment discovery using a finetuned generation model and a sentiment model pretrained on text """
class ZeroShotSentimentDiscovery(nn.Module):
def __init__(self, brain2text_translator, sentiment_classifier, translation_tokenizer, sentiment_tokenizer, device = 'cpu'):
# only for inference
super(ZeroShotSentimentDiscovery, self).__init__()
self.brain2text_translator = brain2text_translator
self.sentiment_classifier = sentiment_classifier
self.translation_tokenizer = translation_tokenizer
self.sentiment_tokenizer = sentiment_tokenizer
self.device = device
def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):
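"""input_embeddings_batch: batch_size*Seq_len*840
input_mask: 1 is not masked, 0 is masked
input_masks_invert: 1 is masked, 0 is not masked
labels: sentiment labels 0,1,2"""
def logits2PredString(logits):
# greedy per-position argmax (softmax followed by topk(1) equals argmax) over teacher-forced logits, since the translator receives the target ids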
probs = logits[0].softmax(dim = 1)
# print('probs size:', probs.size())
values, predictions = probs.topk(1)
# print('predictions before squeeze:',predictions.size())
predictions = torch.squeeze(predictions)
predict_string = self.translation_tokenizer.decode(predictions)
return predict_string
# only works when the batch size is one
assert input_embeddings_batch.size()[0] == 1
seq2seqLMoutput = self.brain2text_translator(input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted)
predict_string = logits2PredString(seq2seqLMoutput.logits)
predict_string = predict_string.split('</s></s>')[0]
predict_string = predict_string.replace('<s>','')
print('predict string:', predict_string)
re_tokenized = self.sentiment_tokenizer(predict_string, return_tensors='pt', return_attention_mask = True)
input_ids = re_tokenized['input_ids'].to(self.device) # batch = 1
attn_mask = re_tokenized['attention_mask'].to(self.device) # batch = 1
out = self.sentiment_classifier(input_ids = input_ids, attention_mask = attn_mask, return_dict = True, labels = sentiment_labels)
return out
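`ZeroShotSentimentDiscovery` turns the translator's logits into text by taking the most probable token at each position, decoding, cutting the string at the first `</s></s>`, and removing `<s>` before re-tokenizing it for the sentiment classifier. The stdlib sketch below mirrors those two steps with a hypothetical toy vocabulary in place of the BART tokenizer; `greedy_ids` and `clean_prediction` are illustrative names.

```python
from typing import Dict, List


def greedy_ids(logits: List[List[float]]) -> List[int]:
    """Index of the largest logit at each position (ties go to the lowest index)."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def clean_prediction(text: str, eos_pair: str = "</s></s>", bos: str = "<s>") -> str:
    """Keep the text before the first repeated end token and drop BOS markers."""
    return text.split(eos_pair)[0].replace(bos, "")


# hypothetical toy vocabulary standing in for the BART tokenizer
vocab: Dict[int, str] = {0: "<s>", 1: "the film", 2: " works", 3: "</s>"}
logits = [
    [5.0, 0.0, 0.0, 0.0],
    [0.0, 5.0, 0.0, 0.0],
    [0.0, 0.0, 5.0, 0.0],
    [0.0, 0.0, 0.0, 5.0],
    [0.0, 0.0, 0.0, 5.0],
    [0.0, 5.0, 0.0, 0.0],  # tokens after the end marker are discarded
]
ids = greedy_ids(logits)
decoded = "".join(vocab[i] for i in ids)
print(decoded)                    # <s>the film works</s></s>the film
print(clean_prediction(decoded))  # the film works
```

Because the logits come from a forward pass that is given the target ids, each position conditions on the gold prefix rather than on earlier predictions, so this is not the same as autoregressive generation with `generate`.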
""" Miscellaneous: jointly learn generation and classification (not working well) """
class BartClassificationHead(nn.Module):
# from transformers: https://huggingface.co/transformers/_modules/transformers/models/bart/modeling_bart.html
"""Head for sentence-level classification tasks."""
def __init__(
self,
input_dim: int,
inner_dim: int,
num_classes: int,
pooler_dropout: float,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class JointBrainTranslatorSentimentClassifier(nn.Module):
def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048, num_labels = 3):
super(JointBrainTranslatorSentimentClassifier, self).__init__()
self.pretrained_generator = pretrained_layers
# additional transformer encoder, following the BART paper's machine-translation setup (a new randomly initialized encoder in front of the pretrained model)
self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
self.fc1 = nn.Linear(in_feature, d_model)
self.num_labels = num_labels
self.pooler = Pooler(d_model)
self.classifier = BartClassificationHead(input_dim = d_model, inner_dim = d_model, num_classes = num_labels, pooler_dropout = pretrained_layers.config.classifier_dropout)
def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):
"""input_embeddings_batch: batch_size*Seq_len*840
input_mask: 1 is not masked, 0 is masked
input_masks_invert: 1 is masked, 0 is not masked"""
# NOTE: add positional embedding?
# input_embeddings_batch = self.positional_embedding(input_embeddings_batch)
# use src_key_padding_masks
encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert)
# encoded_embedding = self.additional_encoder(input_embeddings_batch)
encoded_embedding = F.relu(self.fc1(encoded_embedding))
LMoutput = self.pretrained_generator(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted, output_hidden_states = True)
hidden_states = LMoutput.decoder_hidden_states # tuple of (N, seq_len, hidden_dim) tensors: embedding output plus one per decoder layer
# print('hidden states len:', len(hidden_states))
last_hidden_states = hidden_states[-1]
# print('last hidden states size:', last_hidden_states.size())
sentence_representation = self.pooler(last_hidden_states)
classification_logits = self.classifier(sentence_representation)
loss_fct = nn.CrossEntropyLoss()
classification_loss = loss_fct(classification_logits.view(-1, self.num_labels), sentiment_labels.view(-1))
classification_output = {'loss':classification_loss,'logits':classification_logits}
# print('successful one forward!!!!')
return LMoutput, classification_output
""" helper modules """
# modified from BertPooler
class Pooler(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
# print('[DEBUG] input size:', x.size())
# print('[DEBUG] positional embedding size:', self.pe.size())
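# expects sequence-first input (seq_len, batch, d_model): pe has shape (max_len, 1, d_model) and is sliced along dim 0
x = x + self.pe[:x.size(0), :]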
# print('[DEBUG] output x with pe size:', x.size())
return self.dropout(x)
================================================
FILE: scripts/eval_decoding_1.sh
================================================
CUDA_VISIBLE_DEVICES=0 python3 eval_decoding.py \
--checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \
--config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \
--test_input EEG \
--train_input EEG \
-cuda cuda:0
CUDA_VISIBLE_DEVICES=0 python3 eval_decoding.py \
--checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \
--config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \
--test_input noise \
--train_input EEG \
-cuda cuda:0
================================================
FILE: scripts/eval_decoding_2.sh
================================================
CUDA_VISIBLE_DEVICES=1 python3 eval_decoding.py \
--checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \
--config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \
--test_input EEG \
--train_input EEG \
-cuda cuda:0
CUDA_VISIBLE_DEVICES=1 python3 eval_decoding.py \
--checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.pt \
--config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_EEG.json \
--test_input noise \
--train_input EEG \
-cuda cuda:0
================================================
FILE: scripts/eval_decoding_3.sh
================================================
CUDA_VISIBLE_DEVICES=2 python3 eval_decoding.py \
--checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \
--config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \
--test_input EEG \
--train_input noise \
-cuda cuda:0
CUDA_VISIBLE_DEVICES=2 python3 eval_decoding.py \
--checkpoint_path checkpoints/decoding/best/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \
--config_path config/decoding/task1_task2_task3_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \
--test_input noise \
--train_input noise \
-cuda cuda:0
================================================
FILE: scripts/eval_decoding_4.sh
================================================
CUDA_VISIBLE_DEVICES=3 python3 eval_decoding.py \
--checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \
--config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \
--test_input EEG \
--train_input noise \
-cuda cuda:0
CUDA_VISIBLE_DEVICES=3 python3 eval_decoding.py \
--checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.pt \
--config_path config/decoding/task1_task2_taskNRv2_finetune_T5Translator_skipstep1_b32_20_30_2e-05_2e-05_unique_sent_noise.json \
--test_input noise \
--train_input noise \
-cuda cuda:0
================================================
FILE: scripts/eval_sentiment_zeroshot_pipeline.sh
================================================
python3 eval_sentiment.py --model_name ZeroShotSentimentDiscovery \
--decoder_checkpoint_path ./checkpoints/decoding/best/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.pt \
--classifier_checkpoint_path ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt \
--decoder_config_path ./config/decoding/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.json \
--classifier_config_path ./config/text_sentiment_classifier/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.json \
--cuda cuda:0
================================================
FILE: scripts/prepare_dataset.sh
================================================
echo "This script constructs .pickle files from the .mat files of the ZuCo dataset."
echo "It also generates the ternary sentiment_labels.json file for ZuCo task1-SR v1.0 and ternary_dataset.json from the filtered Stanford Sentiment Treebank."
echo "Note: the sentences in ZuCo task1-SR do not overlap with the sentences in the filtered Stanford Sentiment Treebank."
echo "Note: this process can take a while, please be patient..."
python3 ./util/construct_dataset_mat_to_pickle_v1.py -t task1-SR
python3 ./util/construct_dataset_mat_to_pickle_v1.py -t task2-NR
python3 ./util/construct_dataset_mat_to_pickle_v1.py -t task3-TSR
python3 ./util/construct_dataset_mat_to_pickle_v2.py
python3 ./util/get_sentiment_labels.py
python3 ./util/get_SST_ternary_dataset.py
================================================
FILE: scripts/train_decoding.sh
================================================
CUDA_VISIBLE_DEVICES=0 python3 train_decoding.py --model_name BrainTranslator \
--task_name task1_task2_task3 \
--one_step \
--pretrained \
--not_load_step1_checkpoint \
--num_epoch_step1 20 \
--num_epoch_step2 30 \
--train_input noise \
-lr1 0.00002 \
-lr2 0.00002 \
-b 32 \
-s ./checkpoints/decoding
CUDA_VISIBLE_DEVICES=0,1 python3 train_decoding.py --model_name T5Translator \
--task_name task1_task2_task3 \
--one_step \
--pretrained \
--not_load_step1_checkpoint \
--num_epoch_step1 20 \
--num_epoch_step2 30 \
--train_input noise \
-lr1 0.00002 \
-lr2 0.00002 \
-b 32 \
-s ./checkpoints/decoding
================================================
FILE: scripts/train_decoding_1.sh
================================================
CUDA_VISIBLE_DEVICES=2,3 python3 train_decoding.py --model_name T5Translator \
--task_name task1_task2_taskNRv2 \
--one_step \
--pretrained \
--not_load_step1_checkpoint \
--num_epoch_step1 20 \
--num_epoch_step2 30 \
--train_input EEG \
-lr1 0.00002 \
-lr2 0.00002 \
-b 32 \
-s ./checkpoints/decoding
CUDA_VISIBLE_DEVICES=2,3 python3 train_decoding.py --model_name T5Translator \
--task_name task1_task2_task3 \
--one_step \
--pretrained \
--not_load_step1_checkpoint \
--num_epoch_step1 20 \
--num_epoch_step2 30 \
--train_input EEG \
-lr1 0.00002 \
-lr2 0.00002 \
-b 32 \
-s ./checkpoints/decoding
================================================
FILE: scripts/train_eeg_sentiment_baseline.sh
================================================
python3 train_sentiment_baseline.py --model_name BaselineMLP --num_epoch 20 -lr 0.00005 -b 32 -s ./checkpoints/eeg_sentiment -cuda cuda:0
================================================
FILE: scripts/train_eval_zeroshot_pipeline.sh
================================================
echo "###################################"
echo "Training decoder: BART, task1-SR..."
echo "###################################"
echo ""
python3 train_decoding.py --model_name BrainTranslator \
--task_name task1 \
--one_step \
--pretrained \
--not_load_step1_checkpoint \
--num_epoch_step1 20 \
--num_epoch_step2 30 \
-lr1 0.00005 \
-lr2 0.0000005 \
-b 32 \
-s ./checkpoints/decoding \
-cuda cuda:0
echo "###################################"
echo "Training classifier: BART, filtered Stanford Sentiment Treebank..."
echo "###################################"
echo ""
python3 train_sentiment_textbased.py \
--dataset_name SST \
--model_name pretrain_Bart \
--num_epoch 20 \
-lr 0.0001 \
-b 32 \
-s ./checkpoints/text_sentiment_classifier \
-cuda cuda:0
echo "###################################"
echo "Evaluating Zero-shot pipeline: DEC(BART) + CLS(BART)"
echo "###################################"
echo ""
python3 eval_sentiment.py --model_name ZeroShotSentimentDiscovery \
--decoder_checkpoint_path ./checkpoints/decoding/best/task1_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.pt \
--classifier_checkpoint_path ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt \
--decoder_config_path ./config/decoding/task1_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.json \
--classifier_config_path ./config/text_sentiment_classifier/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.json \
--cuda cuda:0
================================================
FILE: scripts/train_text_sentiment_classifier.sh
================================================
python3 train_sentiment_textbased.py \
--dataset_name SST \
--model_name pretrain_Bart \
--num_epoch 20 \
-lr 0.0001 \
-b 32 \
-s ./checkpoints/text_sentiment_classifier \
-cuda cuda:0
================================================
FILE: train_decoding.py
================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm
from transformers import BertLMHeadModel, BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification, PegasusForConditionalGeneration, PegasusTokenizer, T5Tokenizer, T5ForConditionalGeneration, BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderConfig, EncoderDecoderModel
from data import ZuCo_dataset
from model_decoding import BrainTranslator, BrainTranslatorNaive, T5Translator
from config import get_config
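`train_model` below replaces padding ids in the targets with -100 before the forward pass. Hugging Face seq2seq models compute their language-modeling loss with PyTorch's cross-entropy, whose default `ignore_index` is -100, so those positions contribute nothing to the loss. The stdlib sketch below shows the same masking and a mean token-level cross-entropy that skips ignored labels; the helper names are hypothetical, and the pad id 1 used in the usage line is BART's (T5 uses 0).

```python
import math
from typing import List

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy


def mask_padding(target_ids: List[int], pad_token_id: int) -> List[int]:
    """Replace every padding id with IGNORE_INDEX."""
    return [IGNORE_INDEX if t == pad_token_id else t for t in target_ids]


def token_cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX."""
    losses = []
    for row, y in zip(logits, labels):
        if y == IGNORE_INDEX:
            continue
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[y])
    if not losses:
        raise ValueError("all positions are ignored")
    return sum(losses) / len(losses)


labels = mask_padding([0, 5, 6, 2, 1, 1], pad_token_id=1)  # BART: <pad> has id 1
print(labels)  # [0, 5, 6, 2, -100, -100]
```

Only real target tokens are averaged, so the loss does not depend on how much padding a batch happens to contain.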
def train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/decoding/best/temp_decoding.pt', checkpoint_path_last = './checkpoints/decoding/last/temp_decoding.pt'):
# modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
since = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = float('inf')
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
# Each epoch has a training and validation phase
for phase in ['train', 'dev']:
if phase == 'train':
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode
running_loss = 0.0
# Iterate over data.
for input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels in tqdm(dataloaders[phase]):
# load in batch
input_embeddings_batch = input_embeddings.to(device).float()
input_masks_batch = input_masks.to(device)
input_mask_invert_batch = input_mask_invert.to(device)
target_ids_batch = target_ids.to(device)
# replace padding ids in target_ids with -100 so the language-modeling loss ignores them (relies on the module-level `tokenizer`)
target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100
# zero the parameter gradients
optimizer.zero_grad()
# forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch)
"""calculate loss"""
# logits = seq2seqLMoutput.logits # 8*48*50265
# logits = logits.permute(0,2,1) # 8*50265*48
# loss = criterion(logits, target_ids_batch_label) # calculate cross entropy loss only on encoded target parts
# NOTE: my criterion not used
loss = seq2seqLMoutput.loss # use the BART language modeling loss
# """check prediction, instance 0 of each batch"""
# print('target size:', target_ids_batch.size(), ',original logits size:', logits.size(), ',target_mask size', target_mask_batch.size())
# logits = logits.permute(0,2,1)
# for idx in [0]:
# print(f'-- instance {idx} --')
# # print('permuted logits size:', logits.size())
# probs = logits[idx].softmax(dim = 1)
# # print('probs size:', probs.size())
# values, predictions = probs.topk(1)
# # print('predictions before squeeze:',predictions.size())
# predictions = torch.squeeze(predictions)
# # print('predictions:',predictions)
# # print('target mask:', target_mask_batch[idx])
# # print('[DEBUG]target tokens:',tokenizer.decode(target_ids_batch_copy[idx]))
# print('[DEBUG]predicted tokens:',tokenizer.decode(predictions))
# backward + optimize only if in training phase
if phase == 'train':
# with torch.autograd.detect_anomaly():
loss.sum().backward()
optimizer.step()
# statistics
running_loss += loss.sum().item() * input_embeddings_batch.size()[0] # batch loss
# print('[DEBUG]loss:',loss.item())
# print('#################################')
if phase == 'train':
scheduler.step()
epoch_loss = running_loss / dataset_sizes[phase]
print('{} Loss: {:.4f}'.format(phase, epoch_loss))
# deep copy the model
if phase == 'dev' and epoch_loss < best_loss:
best_loss = epoch_loss
best_model_wts = copy.deepcopy(model.state_dict())
'''save checkpoint'''
torch.save(model.state_dict(), checkpoint_path_best)
print(f'update best on dev checkpoint: {checkpoint_path_best}')
# with torch.set_grad_enabled(False):
# traced_model_1 = torch.jit.trace(model, (torch.rand(1, 56, 840).to(device), torch.randint(1, 56).to(device), torch.rand(1, 56).to(device), torch.rand(1, 56).to(device)))
# traced_model_32 = torch.jit.trace(model, (torch.rand(32, 56, 840).to(device), torch.randint(32, 56).to(device), torch.rand(32, 56).to(device), torch.rand(32, 56).to(device)))
# torch.jit.save(traced_model_1, checkpoint_path_best[:-3]+'_1_jit.pt')
# torch.jit.save(traced_model_32, checkpoint_path_best[:-3]+'_32_jit.pt')
print()
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val loss: {:4f}'.format(best_loss))
torch.save(model.state_dict(), checkpoint_path_last)
print(f'update last checkpoint: {checkpoint_path_last}')
# load best model weights
model.load_state_dict(best_model_wts)
return model
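# Sketch of why train_model replaces padding ids with -100: the Hugging Face
# BART, Pegasus and T5 models compute their LM loss with a cross-entropy whose
# ignore_index is -100, so padded target positions contribute nothing to
# seq2seqLMoutput.loss. mask_padding_labels is a hypothetical helper; unlike the
# loop above it clones the ids instead of editing them in place.
import torch
import torch.nn.functional as F


def mask_padding_labels(target_ids, pad_token_id, ignore_index=-100):
    """Return a copy of target_ids with pad positions set to ignore_index."""
    labels = target_ids.clone()
    labels[labels == pad_token_id] = ignore_index
    return labels

# The mean loss over masked labels equals the mean over the non-pad positions only.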
def show_require_grad_layers(model):
    print()
    print(' require_grad layers:')
    # sanity check
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(' ', name)
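# Sketch of the step-one freezing rule applied in __main__ (closely following the
# BART fine-tuning recipe): a parameter of the wrapped pretrained model keeps
# requires_grad=True only if its name contains 'shared', 'embed_positions' or
# 'encoder.layers.0'; the newly added EEG-side layers stay trainable. The
# predicate name trainable_in_step_one is hypothetical. Worth checking: T5 names
# its layers encoder.block.N, so the 'encoder.layers.0' substring never matches
# for T5Translator.
def trainable_in_step_one(param_name):
    """Return True if a parameter should stay trainable during step one."""
    if 'pretrained' not in param_name:
        return True  # new (non-pretrained) layers are always trained
    keep = ('shared', 'embed_positions', 'encoder.layers.0')
    return any(key in param_name for key in keep)

# e.g. for name, p in model.named_parameters(): p.requires_grad = trainable_in_step_one(name)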
if __name__ == '__main__':
    args = get_config('train_decoding')
    ''' config param'''
    dataset_setting = 'unique_sent'
    num_epochs_step1 = args['num_epoch_step1']
    num_epochs_step2 = args['num_epoch_step2']
    step1_lr = args['learning_rate_step1']
    step2_lr = args['learning_rate_step2']
    batch_size = args['batch_size']
    model_name = args['model_name']
    # model_name = 'BrainTranslatorNaive' # with no additional transformers
    # model_name = 'BrainTranslator'
    # task_name = 'task1'
    # task_name = 'task1_task2'
    # task_name = 'task1_task2_task3'
    # task_name = 'task1_task2_taskNRv2'
    task_name = args['task_name']
    train_input = args['train_input']
    print("train_input is:", train_input)
    save_path = args['save_path']
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    skip_step_one = args['skip_step_one']
    load_step1_checkpoint = args['load_step1_checkpoint']
    use_random_init = args['use_random_init']
    # device setting: DataParallel expects the model on cuda:<device_ids[0]>,
    # so keep this consistent with args['cuda']
    device_ids = [0]
    if use_random_init and skip_step_one:
        step2_lr = 5*1e-4
    print(f'[INFO]using model: {model_name}')
    if skip_step_one:
        save_name = f'{task_name}_finetune_{model_name}_skipstep1_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}_{train_input}'
    else:
        save_name = f'{task_name}_finetune_{model_name}_2steptraining_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}_{train_input}'
    if use_random_init:
        save_name = 'randinit_' + save_name
    save_path_best = os.path.join(save_path, 'best')
    if not os.path.exists(save_path_best):
        os.makedirs(save_path_best)
    output_checkpoint_name_best = os.path.join(save_path_best, f'{save_name}.pt')
    save_path_last = os.path.join(save_path, 'last')
    if not os.path.exists(save_path_last):
        os.makedirs(save_path_last)
    output_checkpoint_name_last = os.path.join(save_path_last, f'{save_name}.pt')
    # subject_choice = 'ALL'
    subject_choice = args['subjects']
    print(f'![Debug]using {subject_choice}')
    # eeg_type_choice = 'GD'
    eeg_type_choice = args['eeg_type']
    print(f'[INFO]eeg type {eeg_type_choice}')
    # bands_choice = ['_t1']
    # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
    bands_choice = args['eeg_bands']
    print(f'[INFO]using bands {bands_choice}')
    ''' set random seeds '''
    seed_val = 312
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
    ''' set up device '''
    # use cuda
    if torch.cuda.is_available():
        # dev = "cuda:3"
        dev = args['cuda']
    else:
        dev = "cpu"
    # CUDA_VISIBLE_DEVICES=0,1,2,3
    device = torch.device(dev)
    print(f'[INFO]using device {dev}')
    print()
    ''' set up dataloader '''
    whole_dataset_dicts = []
    if 'task1' in task_name:
        dataset_path_task1 = '/data/johj/ZuCo_data/task1-SR/task1_source.pkl'
        with open(dataset_path_task1, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
    if 'task2' in task_name:
        dataset_path_task2 = '/data/johj/ZuCo_data/task2-NR/task2_source.pkl'
        with open(dataset_path_task2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
    if 'task3' in task_name:
        dataset_path_task3 = '/data/johj/ZuCo_data/task3-TSR/task3_source.pkl'
        with open(dataset_path_task3, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
    if 'taskNRv2' in task_name:
        dataset_path_taskNRv2 = '/data/johj/ZuCo_data/task2-NR-2.0/taskNRv2_source.pkl'
        with open(dataset_path_taskNRv2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
    print()
    """save config"""
    cfg_dir = './config/decoding/'
    if not os.path.exists(cfg_dir):
        os.makedirs(cfg_dir)
    with open(os.path.join(cfg_dir,f'{save_name}.json'), 'w') as out_config:
        json.dump(args, out_config, indent = 4)
    if model_name in ['BrainTranslator','BrainTranslatorNaive']:
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    elif model_name == 'PegasusTranslator':
        tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-xsum')
    elif model_name == 'T5Translator':
        tokenizer = T5Tokenizer.from_pretrained("t5-large")
    #tokenizer.set_prefix_tokens(language='english')
    # train dataset
    train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, test_input=train_input)
    # dev dataset
    dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, test_input=train_input)
    # test dataset
    # test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
    dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}
    print('[INFO]train_set size: ', len(train_set))
    print('[INFO]dev_set size: ', len(dev_set))
    # print('[INFO]test_set size: ', len(test_set))
    # train dataloader
    train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4)
    # dev dataloader
    val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4)
    # dataloaders
    dataloaders = {'train':train_dataloader, 'dev':val_dataloader}
    ''' set up model '''
    if model_name == 'BrainTranslator':
        if use_random_init:
            config = BartConfig.from_pretrained('facebook/bart-large')
            pretrained = BartForConditionalGeneration(config)
        else:
            pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
        model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
    elif model_name == 'BrainTranslatorNaive':
        pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
        model = BrainTranslatorNaive(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
    elif model_name == 'PegasusTranslator':
        pretrained = PegasusForConditionalGeneration.from_pretrained('google/pegasus-xsum')
        model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
    elif model_name == 'T5Translator':
        pretrained = T5ForConditionalGeneration.from_pretrained("t5-large")
        model = T5Translator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
    model.to(device)
    model = torch.nn.DataParallel(model, device_ids=device_ids)
    ''' training loop '''
    ######################################################
    '''step one training: freeze most of BART params'''
    ######################################################
    # closely follow BART paper
    if model_name in ['BrainTranslator','BrainTranslatorNaive', 'PegasusTranslator', 'T5Translator']:
        for name, param in model.named_parameters():
            if param.requires_grad and 'pretrained' in name:
                if ('shared' in name) or ('embed_positions' in name) or ('encoder.layers.0' in name):
                    continue
                else:
                    param.requires_grad = False
    elif model_name == 'BertGeneration':
        for name, param in model.named_parameters():
            if param.requires_grad and 'pretrained' in name:
                if ('embeddings' in name) or ('encoder.layer.0' in name):
                    continue
                else:
                    param.requires_grad = False
    if skip_step_one:
        if load_step1_checkpoint:
            stepone_checkpoint = 'path_to_step_1_checkpoint.pt'
            print(f'skip step one, load checkpoint: {stepone_checkpoint}')
            model.load_state_dict(torch.load(stepone_checkpoint))
        else:
            print('skip step one, start from scratch at step two')
    else:
        ''' set up optimizer and scheduler'''
        optimizer_step1 = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=step1_lr, momentum=0.9)
        exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=20, gamma=0.1)
        ''' set up loss function '''
        criterion = nn.CrossEntropyLoss()
        print('=== start Step1 training ... ===')
        # print training layers
        show_require_grad_layers(model)
        # return best loss model from step1 training
        model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs_step1, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)
    ######################################################
    '''step two training: update whole model for a few iterations'''
    ######################################################
    for name, param in model.named_parameters():
        param.requires_grad = True
    ''' set up optimizer and scheduler'''
    optimizer_step2 = optim.SGD(model.parameters(), lr=step2_lr, momentum=0.9)
    exp_lr_scheduler_step2 = lr_scheduler.StepLR(optimizer_step2, step_size=30, gamma=0.1)
    ''' set up loss function '''
    criterion = nn.CrossEntropyLoss()
    print()
    print('=== start Step2 training ... ===')
    # print training layers
    show_require_grad_layers(model)
    '''main loop'''
    trained_model = train_model(dataloaders, device, model, criterion, optimizer_step2, exp_lr_scheduler_step2, num_epochs=num_epochs_step2, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)
    # '''save checkpoint'''
    # torch.save(trained_model.state_dict(), os.path.join(save_path,output_checkpoint_name))
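# Sketch: the model above is wrapped in torch.nn.DataParallel before training, so
# every key in the saved state_dict starts with 'module.'. To load such a
# checkpoint into an unwrapped model (e.g. for single-device evaluation), the
# prefix can be stripped first. strip_dataparallel_prefix is a hypothetical
# helper; it works on any mapping of parameter names to tensors.
from collections import OrderedDict


def strip_dataparallel_prefix(state_dict, prefix='module.'):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

# e.g. model.load_state_dict(strip_dataparallel_prefix(torch.load(checkpoint_path)))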
================================================
FILE: train_sentiment_baseline.py
================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pack_padded_sequence
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm
from transformers import BertTokenizer, BertLMHeadModel, BertConfig
from data import ZuCo_dataset
from model_sentiment import BaselineMLPSentence, BaselineLSTM, NaiveFineTunePretrainedBert
from config import get_config
# Function to calculate the accuracy of our predictions vs labels
def flat_accuracy(preds, labels):
    # preds: numpy array: N * 3
    # labels: numpy array: N
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)
def flat_accuracy_top_k(preds, labels, k):
    topk_preds = []
    for pred in preds:
        topk = pred.argsort()[-k:][::-1]
        topk_preds.append(list(topk))
    # print(topk_preds)
    topk_preds = list(topk_preds)
    right_count = 0
    # print(len(labels))
    for i in range(len(labels)):
        l = labels[i][0]
        if l in topk_preds[i]:
            right_count += 1
    return right_count/len(labels)
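# Sketch: a vectorized NumPy equivalent of flat_accuracy_top_k, assuming preds is
# an (N, C) score array and labels is an (N, 1) integer array, as the loop above
# expects (it reads labels[i][0]). Ties between equal scores may be broken
# differently from the loop version. The function name is hypothetical.
import numpy as np


def flat_accuracy_top_k_vectorized(preds, labels, k):
    preds = np.asarray(preds)
    labels = np.asarray(labels).reshape(-1)
    # Column indices of the k largest scores in each row (order within the k is irrelevant).
    topk = np.argsort(preds, axis=1)[:, -k:]
    hits = (topk == labels[:, None]).any(axis=1)
    return float(hits.mean())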
def train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/eeg_sentiment/best/test.pt', checkpoint_path_last = './checkpoints/eeg_sentiment/last/test.pt'):
    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 100000000000
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'dev']:
            total_accuracy = 0.0
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            running_loss = 0.0
            # Iterate over data.
            for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in tqdm(dataloaders[phase]):
                input_word_eeg_features = input_word_eeg_features.to(device).float()
                sent_level_EEG = sent_level_EEG.to(device)
                input_masks = input_masks.to(device)
                sentiment_labels = sentiment_labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                if isinstance(model, BaselineMLPSentence):
                    # forward
                    logits = model(sent_level_EEG)  # before softmax
                    # calculate loss
                    loss = criterion(logits, sentiment_labels)
                elif isinstance(model, BaselineLSTM):
                    x_packed = pack_padded_sequence(input_word_eeg_features, seq_lens, batch_first=True, enforce_sorted=False)
                    logits = model(x_packed)
                    # calculate loss
                    loss = criterion(logits, sentiment_labels)
                elif isinstance(model, NaiveFineTunePretrainedBert):
                    output = model(input_word_eeg_features, input_masks, sentiment_labels)
                    logits = output.logits
                    loss = output.loss
                # backward + optimize only if in training phase
                if phase == 'train':
                    # with torch.autograd.detect_anomaly():
                    loss.backward()
                    optimizer.step()
                # calculate accuracy
                preds_cpu = logits.detach().cpu().numpy()
                label_cpu = sentiment_labels.cpu().numpy()
                total_accuracy += flat_accuracy(preds_cpu, label_cpu)
                # statistics
                running_loss += loss.item() * sent_level_EEG.size()[0]  # batch loss
                # print('[DEBUG]loss:',loss.item())
                # print('#################################')
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = total_accuracy / len(dataloaders[phase])
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            print('{} Acc: {:.4f}'.format(phase, epoch_acc))
            # deep copy the model
            if phase == 'dev' and (epoch_acc > best_acc):
                best_loss = epoch_loss
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                '''save checkpoint'''
                torch.save(model.state_dict(), checkpoint_path_best)
                print(f'update best on dev checkpoint: {checkpoint_path_best}')
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:.4f}'.format(best_loss))
    print('Best val acc: {:.4f}'.format(best_acc))
    torch.save(model.state_dict(), checkpoint_path_last)
    print(f'update last checkpoint: {checkpoint_path_last}')
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
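# Sketch: train_model averages per-batch accuracies, so when the last training
# batch is smaller than batch_size each of its samples is over-weighted (the dev
# loader uses batch_size=1, where both averages coincide). A sample-weighted
# alternative counts correct predictions directly. AccuracyMeter is a
# hypothetical helper, not part of this repository.
import numpy as np


class AccuracyMeter:
    """Accumulate correct/total counts across batches."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, logits, labels):
        preds = np.argmax(np.asarray(logits), axis=1)
        labels = np.asarray(labels).reshape(-1)
        self.correct += int((preds == labels).sum())
        self.total += labels.size

    @property
    def accuracy(self):
        return self.correct / self.total if self.total else 0.0

# With batches of 3 and 1 samples scoring 2/3 and 0/1, the batch mean is 1/3
# while the sample-weighted accuracy is 2/4.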
if __name__ == '__main__':
    args = get_config('train_sentiment_baseline')
    ''' config param'''
    num_epochs = args['num_epoch']
    step_lr = args['learning_rate']
    '''dataset division'''
    dataset_setting = 'unique_sent'
    # subject_choice = 'ALL'
    subject_choice = args['subjects']
    print(f'![Debug]using {subject_choice}')
    # eeg_type_choice = 'GD'
    eeg_type_choice = args['eeg_type']
    print(f'[INFO]eeg type {eeg_type_choice}')
    # bands_choice = ['_t1']
    # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
    bands_choice = args['eeg_bands']
    print(f'[INFO]using bands {bands_choice}')
    '''model name'''
    # model_name = 'BaselineMLP'
    # model_name = 'BaselineLSTM'
    # model_name = 'NaiveFinetuneBert'
    model_name = args['model_name']
    batch_size = 32
    save_path = args['save_path']
    save_name = f'{model_name}_{step_lr}_b{batch_size}_{dataset_setting}_{eeg_type_choice}'
    if model_name == 'BaselineLSTM':
        num_layers = 4
        save_name = f'{model_name}_numLayers-{num_layers}_{step_lr}_b{batch_size}_{dataset_setting}_{eeg_type_choice}'
    # create the checkpoint folders if they do not exist yet (torch.save does not create them)
    os.makedirs(os.path.join(save_path, 'best'), exist_ok=True)
    os.makedirs(os.path.join(save_path, 'last'), exist_ok=True)
    output_checkpoint_name_best = save_path + f'/best/{save_name}.pt'
    output_checkpoint_name_last = save_path + f'/last/{save_name}.pt'
    ''' set random seeds '''
    seed_val = 312
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
    ''' set up device '''
    # use cuda
    if torch.cuda.is_available():
        dev = args['cuda']
    else:
        dev = "cpu"
    # CUDA_VISIBLE_DEVICES=0,1,2,3
    device = torch.device(dev)
    print(f'[INFO]using device {dev}')
    ''' load pickle'''
    whole_dataset_dict = []
    dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle'
    with open(dataset_path_task1, 'rb') as handle:
        whole_dataset_dict.append(pickle.load(handle))
    '''set up tokenizer'''
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    ''' set up dataloader '''
    # train dataset
    train_set = ZuCo_dataset(whole_dataset_dict, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
    # dev dataset
    dev_set = ZuCo_dataset(whole_dataset_dict, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
    # test dataset
    # test_set = ZuCo_dataset(whole_dataset_dict, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice)
    dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}
    print('[INFO]train_set size: ', len(train_set))
    print('[INFO]dev_set size: ', len(dev_set))
    # train dataloader
    train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4)
    # dev dataloader
    val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4)
    # dataloaders
    dataloaders = {'train':train_dataloader, 'dev':val_dataloader}
    ''' set up model '''
    if model_name == 'BaselineMLP':
        print('[INFO]Model: BaselineMLP')
        model = BaselineMLPSentence(input_dim = 105*len(bands_choice), hidden_dim = 128, output_dim = 3)
    elif model_name == 'BaselineLSTM':
        print('[INFO]Model: BaselineLSTM')
        model = BaselineLSTM(input_dim = 105*len(bands_choice), hidden_dim = 256, output_dim = 3, num_layers = num_layers)
    elif model_name == 'NaiveFinetuneBert':
        print('[INFO]Model: NaiveFinetuneBert')
        model = NaiveFineTunePretrainedBert(input_dim = 105*len(bands_choice), hidden_dim = 768, output_dim = 3)
    model.to(device)
    """save config"""
    os.makedirs('./config/eeg_sentiment', exist_ok=True)
    with open(f'./config/eeg_sentiment/{save_name}.json', 'w') as out_config:
        json.dump(args, out_config, indent = 4)
    ''' training loop '''
    ''' set up optimizer and scheduler'''
    optimizer_step1 = optim.SGD(model.parameters(), lr=step_lr, momentum=0.9)
    exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=20, gamma=0.5)
    ''' set up loss function '''
    criterion = nn.CrossEntropyLoss()
    print('=== start training ... ===')
    # returns the model with the best dev accuracy
    model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)
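# Sketch of how a packed batch like x_packed in train_model can be consumed:
# nn.LSTM accepts a PackedSequence directly, and with enforce_sorted=False the
# returned h_n is restored to the original batch order, with h_n[-1] holding each
# sequence's top-layer state at its true (unpadded) length. TinyLSTMClassifier is
# a hypothetical stand-in; the actual BaselineLSTM lives in model_sentiment.py and
# may be built differently.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class TinyLSTMClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x_packed):
        _, (h_n, _) = self.lstm(x_packed)
        # h_n: (num_layers, batch, hidden_dim); classify from the top layer.
        return self.fc(h_n[-1])

# Padded time steps are never read, so their values do not change the logits.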
================================================
FILE: train_sentiment_textbased.py
================================================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, random_split
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification
from data import ZuCo_dataset, SST_tenary_dataset
from model_sentiment import FineTunePretrainedTwoStep
from config import get_config
# Function to calculate the accuracy of our predictions vs labels
def flat_accuracy(preds, labels):
    # preds: numpy array: N * 3
    # labels: numpy array: N
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)
def flat_accuracy_top_k(preds, labels, k):
    topk_preds = []
    for pred in preds:
        topk = pred.argsort()[-k:][::-1]
        topk_preds.append(list(topk))
    # print(topk_preds)
    topk_preds = list(topk_preds)
    right_count = 0
    # print(len(labels))
    for i in range(len(labels)):
        l = labels[i][0]
        if l in topk_preds[i]:
            right_count += 1
    return right_count/len(labels)
def train_model_ZuCo(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/text_sentiment_classifier/best/test.pt', checkpoint_path_last = './checkpoints/text_sentiment_classifier/last/test.pt'):
    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 100000000000
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'dev']:
            total_accuracy = 0.0
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            running_loss = 0.0
            # Iterate over data.
            for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in tqdm(dataloaders[phase]):
                # input_word_eeg_features = input_word_eeg_features.to(device).float()
                # input_masks = input_masks.to(device)
                # input_mask_invert = input_mask_invert.to(device)
                target_ids = target_ids.to(device)
                target_mask = target_mask.to(device)
                sentiment_labels = sentiment_labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                output = model(input_ids = target_ids, attention_mask = target_mask, return_dict = True, labels = sentiment_labels)
                logits = output.logits
                loss = output.loss
                # backward + optimize only if in training phase
                if phase == 'train':
                    # with torch.autograd.detect_anomaly():
                    loss.backward()
                    optimizer.step()
                # calculate accuracy
                preds_cpu = logits.detach().cpu().numpy()
                label_cpu = sentiment_labels.cpu().numpy()
                total_accuracy += flat_accuracy(preds_cpu, label_cpu)
                # statistics
                running_loss += loss.item() * sent_level_EEG.size()[0]  # batch loss
                # print('[DEBUG]loss:',loss.item())
                # print('#################################')
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = total_accuracy / len(dataloaders[phase])
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            print('{} Acc: {:.4f}'.format(phase, epoch_acc))
            # deep copy the model
            if phase == 'dev' and (epoch_acc > best_acc):
                best_loss = epoch_loss
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                '''save checkpoint'''
                torch.save(model.state_dict(), checkpoint_path_best)
                print(f'update best on dev checkpoint: {checkpoint_path_best}')
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:.4f}'.format(best_loss))
    print('Best val acc: {:.4f}'.format(best_acc))
    torch.save(model.state_dict(), checkpoint_path_last)
    print(f'update last checkpoint: {checkpoint_path_last}')
    # write to log
    # NOTE: output_log_file_name was never defined in this script, so writing the
    # log raised a NameError; deriving a (hypothetical) log path from the
    # last-checkpoint path is an assumption, not the authors' original choice.
    output_log_file_name = os.path.splitext(checkpoint_path_last)[0] + '_log.txt'
    with open(output_log_file_name, 'w') as outlog:
        outlog.write(f'best val loss: {best_loss}\n')
        outlog.write('Best val acc: {:.4f}\n'.format(best_acc))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
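# Sketch of the learning-rate schedule used in __main__: scheduler.step() is
# called once per training epoch, so StepLR(step_size=10, gamma=0.1) keeps the
# initial lr for epochs 0-9, multiplies it by 0.1 for epochs 10-19, and so on.
# lr_per_epoch is a hypothetical helper that trains nothing; it only records
# the lr in effect during each epoch.
import torch
from torch import nn, optim
from torch.optim import lr_scheduler


def lr_per_epoch(base_lr, num_epochs, step_size, gamma):
    """Return the learning rate in effect during each epoch."""
    param = nn.Parameter(torch.zeros(1))
    optimizer = optim.SGD([param], lr=base_lr, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    lrs = []
    for _ in range(num_epochs):
        lrs.append(optimizer.param_groups[0]['lr'])
        optimizer.step()  # step the optimizer before the scheduler, as PyTorch expects
        scheduler.step()
    return lrs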
def train_model_SST(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/text_sentiment_classifier/best/test.pt', checkpoint_path_last = './checkpoints/text_sentiment_classifier/last/test.pt'):
    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 100000000000
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'dev']:
            total_accuracy = 0.0
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            running_loss = 0.0
            # Iterate over data.
            for input_ids, input_masks, sentiment_labels in tqdm(dataloaders[phase]):
                input_ids = input_ids.to(device)
                input_masks = input_masks.to(device)
                sentiment_labels = sentiment_labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                output = model(input_ids = input_ids, attention_mask = input_masks, return_dict = True, labels = sentiment_labels)
                logits = output.logits
                loss = output.loss
                # backward + optimize only if in training phase
                if phase == 'train':
                    # with torch.autograd.detect_anomaly():
                    loss.backward()
                    optimizer.step()
                # calculate accuracy
                preds_cpu = logits.detach().cpu().numpy()
                label_cpu = sentiment_labels.cpu().numpy()
                total_accuracy += flat_accuracy(preds_cpu, label_cpu)
                # statistics
                running_loss += loss.item() * input_ids.size()[0]  # batch loss
                # print('[DEBUG]loss:',loss.item())
                # print('#################################')
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = total_accuracy / len(dataloaders[phase])
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            print('{} Acc: {:.4f}'.format(phase, epoch_acc))
            # deep copy the model
            if phase == 'dev' and (epoch_acc > best_acc):
                best_loss = epoch_loss
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                '''save checkpoint'''
                torch.save(model.state_dict(), checkpoint_path_best)
                print(f'update best on dev checkpoint: {checkpoint_path_best}')
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:.4f}'.format(best_loss))
    print('Best val acc: {:.4f}'.format(best_acc))
    torch.save(model.state_dict(), checkpoint_path_last)
    print(f'update last checkpoint: {checkpoint_path_last}')
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
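# Sketch: the 90/10 SST split in __main__ calls random_split with the global RNG,
# seeded via torch.manual_seed(seed_val), so the split is reproducible only as
# long as nothing else consumes random numbers before it. Passing an explicit
# torch.Generator pins the split independently. split_90_10 is a hypothetical
# helper, not part of this repository.
import torch
from torch.utils.data import random_split


def split_90_10(dataset, seed=312):
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_size, val_size], generator=generator)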
if __name__ == '__main__':
    args = get_config('train_sentiment_textbased')
    ''' config param'''
    num_epoch = args['num_epoch']
    # lr = 1e-3 # Bert, RoBerta
    # lr = 1e-4 # Bart
    lr = args['learning_rate']
    dataset_name = args['dataset_name']  # zero-shot setting: pass in 'SST' to train on the external Stanford Sentiment Treebank; or pass in 'ZuCo' to train on ZuCo's text-sentiment pairs
    dataset_setting = 'unique_sent'
    batch_size = args['batch_size']
    # model_name = 'pretrain_Bert'
    # model_name = 'pretrain_RoBerta'
    # model_name = 'pretrain_Bart'
    model_name = args['model_name']
    print(f'[INFO]model name: {model_name}')
    save_path = args['save_path']
    if dataset_name == 'ZuCo':
        # subject_choice = 'ALL'
        subject_choice = args['subjects']
        print(f'![Debug]using {subject_choice}')
        # eeg_type_choice = 'GD'
        eeg_type_choice = args['eeg_type']
        print(f'[INFO]eeg type {eeg_type_choice}')
        # bands_choice = ['_t1']
        # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
        bands_choice = args['eeg_bands']
        print(f'[INFO]using bands {bands_choice}')
        save_name = f'Textbased_ZuCo_{model_name}_b{batch_size}_{num_epoch}_{lr}_{dataset_setting}_{eeg_type_choice}'
    elif dataset_name == 'SST':
        save_name = f'Textbased_StanfordSentitmentTreeband_{model_name}_b{batch_size}_{num_epoch}_{lr}'
    # create the checkpoint folders if they do not exist yet (torch.save does not create them)
    os.makedirs(os.path.join(save_path, 'best'), exist_ok=True)
    os.makedirs(os.path.join(save_path, 'last'), exist_ok=True)
    output_checkpoint_name_best = save_path + f'/best/{save_name}.pt'
    output_checkpoint_name_last = save_path + f'/last/{save_name}.pt'
    ''' set random seeds '''
    seed_val = 312
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
    ''' set up device '''
    # use cuda
    if torch.cuda.is_available():
        dev = args['cuda']
    else:
        dev = "cpu"
    # CUDA_VISIBLE_DEVICES=0,1,2,3
    device = torch.device(dev)
    print(f'[INFO]using device {dev}')
    ''' load pickle '''
    if dataset_name == 'ZuCo':
        whole_dataset_dict = []
        dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle'
        with open(dataset_path_task1, 'rb') as handle:
            whole_dataset_dict.append(pickle.load(handle))
    '''tokenizer'''
    if model_name == 'pretrain_Bert':
        print('[INFO]pretrained checkpoint: bert-base-cased')
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    elif model_name == 'pretrain_RoBerta':
        print('[INFO]pretrained checkpoint: roberta-base')
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    elif model_name == 'pretrain_Bart':
        print('[INFO]pretrained checkpoint: bart-large')
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    ''' set up dataloader '''
    if dataset_name == 'ZuCo':
        # train dataset
        train_set = ZuCo_dataset(whole_dataset_dict, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
        # dev dataset
        dev_set = ZuCo_dataset(whole_dataset_dict, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
    elif dataset_name == 'SST':
        with open('./dataset/stanfordsentiment/ternary_dataset.json') as sst_file:
            SST_SENTIMENT_LABELS = json.load(sst_file)
        SST_dataset = SST_tenary_dataset(SST_SENTIMENT_LABELS, tokenizer)
        train_size = int(0.9 * len(SST_dataset))
        val_size = len(SST_dataset) - train_size
        train_set, dev_set = random_split(SST_dataset, [train_size, val_size])
        print('{:>5,} training samples'.format(len(train_set)))
        print('{:>5,} validation samples'.format(len(dev_set)))
    dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}
    print('[INFO]train_set size: ', len(train_set))
    print('[INFO]dev_set size: ', len(dev_set))
    # train dataloader
    train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4)
    # dev dataloader
    val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4)
    # dataloaders
    dataloaders = {'train':train_dataloader, 'dev':val_dataloader}
''' set up model '''
if model_name == 'pretrain_Bert':
model = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)
elif model_name == 'pretrain_RoBerta':
model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)
elif model_name == 'pretrain_Bart':
model = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels = 3)
model.to(device)
"""save config"""
with open(f'./config/text_sentiment_classifier/{save_name}.json', 'w') as out_config:
json.dump(args, out_config, indent = 4)
''' training loop '''
######################################################
'''step one trainig: freeze most of BART params'''
######################################################
''' set up optimizer and scheduler'''
optimizer_step1 = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=10, gamma=0.1)
# TODO: rethink about the loss function
''' set up loss function '''
criterion = nn.CrossEntropyLoss()
# return best loss model from step1 training
print(f'=== start training {dataset_name} ... ===')
if dataset_name == 'ZuCo':
model = train_model_ZuCo(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epoch, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)
elif dataset_name == 'SST':
model = train_model_SST(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epoch, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)
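In the script above, `StepLR(optimizer_step1, step_size=10, gamma=0.1)` multiplies the SGD learning rate by 0.1 every 10 scheduler steps, and the SST branch splits the data 90/10 before calling `random_split`. The sketch below reproduces both pieces of arithmetic in plain Python. `step_lr` and `split_sizes` are hypothetical helpers, not part of the repository, and the sketch assumes the scheduler is stepped once per epoch inside the training function.

```python
from typing import Tuple


def step_lr(base_lr: float, epoch: int, step_size: int = 10, gamma: float = 0.1) -> float:
    """Closed-form step decay: multiply base_lr by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


def split_sizes(n: int, train_frac: float = 0.9) -> Tuple[int, int]:
    """Train/dev sizes computed as before random_split: int() for train, remainder for dev."""
    train_size = int(train_frac * n)
    return train_size, n - train_size


if __name__ == "__main__":
    for epoch in (0, 9, 10, 19, 20):
        print(epoch, step_lr(1e-3, epoch))
    print(split_sizes(1000))
```

With a base rate of 1e-3, the rate stays at 1e-3 for epochs 0 to 9, drops to 1e-4 at epoch 10 and to 1e-5 at epoch 20. Because the dev size is computed as the remainder, the two parts always sum to the dataset size, which `random_split` requires.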
================================================
FILE: util/construct_dataset_mat_to_pickle_v1.py
================================================
import scipy.io as io
import h5py
import os
import json
from glob import glob
from tqdm import tqdm
import numpy as np
import pickle
import argparse
parser = argparse.ArgumentParser(description='Specify task name for converting ZuCo v1.0 Mat file to Pickle')
parser.add_argument('-t', '--task_name', help='name of the task in /dataset/ZuCo, choose from {task1-SR,task2-NR,task3-TSR}', required=True)
args = vars(parser.parse_args())
"""config"""
version = 'v1' # 'old'
# version = 'v2' # 'new'
task_name = args['task_name']
# task_name = 'task1-SR'
# task_name = 'task2-NR'
# task_name = 'task3-TSR'
print('##############################')
print(f'start processing ZuCo {task_name}...')
if version == 'v1':
    # old version
    input_mat_files_dir = f'./dataset/ZuCo/{task_name}/Matlab_files'
elif version == 'v2':
    # new version, mat73
    input_mat_files_dir = f'./dataset/ZuCo/{task_name}/Matlab_files'
output_dir = f'./dataset/ZuCo/{task_name}/pickle'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
"""load files"""
mat_files = glob(os.path.join(input_mat_files_dir,'*.mat'))
mat_files = sorted(mat_files)
if len(mat_files) == 0:
    print(f'No mat files found for {task_name}')
    quit()
dataset_dict = {}
for mat_file in tqdm(mat_files):
    subject_name = os.path.basename(mat_file).split('_')[0].replace('results','').strip()
    dataset_dict[subject_name] = []
    if version == 'v1':
        matdata = io.loadmat(mat_file, squeeze_me=True, struct_as_record=False)['sentenceData']
    elif version == 'v2':
        # NOTE: this branch is not functional as written: iterating an h5py.File yields key names,
        # not sentence structs, so `sent.word` below would fail. ZuCo 2.0 (mat73) files are
        # converted by construct_dataset_mat_to_pickle_v2.py instead.
        matdata = h5py.File(mat_file,'r')
        print(matdata)
    for sent in matdata:
        word_data = sent.word
        if not isinstance(word_data, float):
            # sentence level:
            sent_obj = {'content':sent.content}
            sent_obj['sentence_level_EEG'] = {'mean_t1':sent.mean_t1, 'mean_t2':sent.mean_t2, 'mean_a1':sent.mean_a1, 'mean_a2':sent.mean_a2, 'mean_b1':sent.mean_b1, 'mean_b2':sent.mean_b2, 'mean_g1':sent.mean_g1, 'mean_g2':sent.mean_g2}
            if task_name == 'task1-SR':
                sent_obj['answer_EEG'] = {'answer_mean_t1':sent.answer_mean_t1, 'answer_mean_t2':sent.answer_mean_t2, 'answer_mean_a1':sent.answer_mean_a1, 'answer_mean_a2':sent.answer_mean_a2, 'answer_mean_b1':sent.answer_mean_b1, 'answer_mean_b2':sent.answer_mean_b2, 'answer_mean_g1':sent.answer_mean_g1, 'answer_mean_g2':sent.answer_mean_g2}
            # word level:
            sent_obj['word'] = []
            word_tokens_has_fixation = []
            word_tokens_with_mask = []
            word_tokens_all = []
            for word in word_data:
                word_obj = {'content':word.content}
                word_tokens_all.append(word.content)
                # TODO: add more versions of word level EEG (e.g. SFD, GPT)
                word_obj['nFixations'] = word.nFixations
                if word.nFixations > 0:
                    word_obj['word_level_EEG'] = {'FFD':{'FFD_t1':word.FFD_t1, 'FFD_t2':word.FFD_t2, 'FFD_a1':word.FFD_a1, 'FFD_a2':word.FFD_a2, 'FFD_b1':word.FFD_b1, 'FFD_b2':word.FFD_b2, 'FFD_g1':word.FFD_g1, 'FFD_g2':word.FFD_g2}}
                    word_obj['word_level_EEG']['TRT'] = {'TRT_t1':word.TRT_t1, 'TRT_t2':word.TRT_t2, 'TRT_a1':word.TRT_a1, 'TRT_a2':word.TRT_a2, 'TRT_b1':word.TRT_b1, 'TRT_b2':word.TRT_b2, 'TRT_g1':word.TRT_g1, 'TRT_g2':word.TRT_g2}
                    word_obj['word_level_EEG']['GD'] = {'GD_t1':word.GD_t1, 'GD_t2':word.GD_t2, 'GD_a1':word.GD_a1, 'GD_a2':word.GD_a2, 'GD_b1':word.GD_b1, 'GD_b2':word.GD_b2, 'GD_g1':word.GD_g1, 'GD_g2':word.GD_g2}
                    sent_obj['word'].append(word_obj)
                    word_tokens_has_fixation.append(word.content)
                    word_tokens_with_mask.append(word.content)
                else:
                    word_tokens_with_mask.append('[MASK]')
                    # if a word has no fixation, use sentence level feature
                    # word_obj['word_level_EEG'] = {'FFD':{'FFD_t1':sent.mean_t1, 'FFD_t2':sent.mean_t2, 'FFD_a1':sent.mean_a1, 'FFD_a2':sent.mean_a2, 'FFD_b1':sent.mean_b1, 'FFD_b2':sent.mean_b2, 'FFD_g1':sent.mean_g1, 'FFD_g2':sent.mean_g2}}
                    # word_obj['word_level_EEG']['TRT'] = {'TRT_t1':sent.mean_t1, 'TRT_t2':sent.mean_t2, 'TRT_a1':sent.mean_a1, 'TRT_a2':sent.mean_a2, 'TRT_b1':sent.mean_b1, 'TRT_b2':sent.mean_b2, 'TRT_g1':sent.mean_g1, 'TRT_g2':sent.mean_g2}
                    # NOTE: if a word has no fixation, simply skip it
                    continue
            sent_obj['word_tokens_has_fixation'] = word_tokens_has_fixation
            sent_obj['word_tokens_with_mask'] = word_tokens_with_mask
            sent_obj['word_tokens_all'] = word_tokens_all
            dataset_dict[subject_name].append(sent_obj)
        else:
            print(f'missing sent: subj:{subject_name} content:{sent.content}, append None')
            dataset_dict[subject_name].append(None)
            continue
    # print(dataset_dict.keys())
    # print(dataset_dict[subject_name][0].keys())
    # print(dataset_dict[subject_name][0]['content'])
    # print(dataset_dict[subject_name][0]['word'][0].keys())
    # print(dataset_dict[subject_name][0]['word'][0]['word_level_EEG']['FFD'])
"""output"""
output_name = f'{task_name}-dataset.pickle'
# with open(os.path.join(output_dir,'task1-SR-dataset.json'), 'w') as out:
#     json.dump(dataset_dict,out,indent = 4)
with open(os.path.join(output_dir,output_name), 'wb') as handle:
    pickle.dump(dataset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
print('write to:', os.path.join(output_dir,output_name))
"""sanity check"""
# check dataset
with open(os.path.join(output_dir,output_name), 'rb') as handle:
    whole_dataset = pickle.load(handle)
print('subjects:', whole_dataset.keys())
if version == 'v1':
    print('num of sent:', len(whole_dataset['ZAB']))
print()
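For every sentence, the v1 converter builds three parallel token lists: all words, only the words that received at least one fixation, and all words with unfixated ones replaced by `'[MASK]'`. The sketch below isolates that bookkeeping. `build_token_lists` is a hypothetical helper, and the `(word, nFixations)` pairs are a simplified stand-in for the MATLAB word structs loaded by `scipy.io.loadmat`.

```python
from typing import List, Tuple


def build_token_lists(words: List[Tuple[str, int]]):
    """Mirror the per-sentence token lists built by the v1 converter."""
    word_tokens_all = []           # every word in the sentence
    word_tokens_has_fixation = []  # only words with at least one fixation
    word_tokens_with_mask = []     # every word, unfixated ones replaced by '[MASK]'
    for content, n_fixations in words:
        word_tokens_all.append(content)
        if n_fixations > 0:
            word_tokens_has_fixation.append(content)
            word_tokens_with_mask.append(content)
        else:
            word_tokens_with_mask.append('[MASK]')
    return word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask


if __name__ == "__main__":
    print(build_token_lists([("The", 2), ("cat", 0), ("sat", 1)]))
```

Only fixated words get an entry in `sent_obj['word']`, so `word_tokens_has_fixation` lines up one-to-one with the stored word-level EEG features, while `word_tokens_with_mask` keeps the original sentence length.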
================================================
FILE: util/construct_dataset_mat_to_pickle_v2.py
================================================
import os
import numpy as np
import h5py
import data_loading_helpers_modified as dh
from glob import glob
from tqdm import tqdm
import pickle
task = "NR"
rootdir = "./dataset/ZuCo/task2-NR-2.0/Matlab_files/"
print('##############################')
print(f'start processing ZuCo task2-NR-2.0...')
dataset_dict = {}
for file in tqdm(os.listdir(rootdir)):
    if file.endswith(task+".mat"):
        file_name = os.path.join(rootdir, file)
        # print('file name:', file_name)
        # parse the subject ID from the base file name (e.g. 'resultsYAC_NR.mat' -> 'YAC'),
        # so that a 'ts' elsewhere in rootdir cannot break the split
        subject = file.split("ts")[1].split("_")[0]
        # print('subject: ', subject)
        # exclude YMH due to incomplete data because of dyslexia
        if subject != 'YMH':
            assert subject not in dataset_dict
            dataset_dict[subject] = []
            f = h5py.File(file_name,'r')
            # print('keys in f:', list(f.keys()))
            sentence_data = f['sentenceData']
            # print('keys in sentence_data:', list(sentence_data.keys()))
            # sent level eeg
            # mean_t1 = np.squeeze(f[sentence_data['mean_t1'][0][0]][()])
            mean_t1_objs = sentence_data['mean_t1']
            mean_t2_objs = sentence_data['mean_t2']
            mean_a1_objs = sentence_data['mean_a1']
            mean_a2_objs = sentence_data['mean_a2']
            mean_b1_objs = sentence_data['mean_b1']
            mean_b2_objs = sentence_data['mean_b2']
            mean_g1_objs = sentence_data['mean_g1']
            mean_g2_objs = sentence_data['mean_g2']
            rawData = sentence_data['rawData']
            contentData = sentence_data['content']
            # print('contentData shape:', contentData.shape, 'dtype:', contentData.dtype)
            omissionR = sentence_data['omissionRate']
            wordData = sentence_data['word']
            for idx in range(len(rawData)):
                # get sentence string
                obj_reference_content = contentData[idx][0]
                sent_string = dh.load_matlab_string(f[obj_reference_content])
                # print('sentence string:', sent_string)
                sent_obj = {'content':sent_string}
                # get sentence level EEG
                sent_obj['sentence_level_EEG'] = {
                    'mean_t1':np.squeeze(f[mean_t1_objs[idx][0]][()]),
                    'mean_t2':np.squeeze(f[mean_t2_objs[idx][0]][()]),
                    'mean_a1':np.squeeze(f[mean_a1_objs[idx][0]][()]),
                    'mean_a2':np.squeeze(f[mean_a2_objs[idx][0]][()]),
                    'mean_b1':np.squeeze(f[mean_b1_objs[idx][0]][()]),
                    'mean_b2':np.squeeze(f[mean_b2_objs[idx][0]][()]),
                    'mean_g1':np.squeeze(f[mean_g1_objs[idx][0]][()]),
                    'mean_g2':np.squeeze(f[mean_g2_objs[idx][0]][()])
                }
                # print(sent_obj)
                sent_obj['word'] = []
                # get word level data
                word_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask = dh.extract_word_level_data(f, f[wordData[idx][0]])
                if word_data == {}:
                    print(f'missing sent: subj:{subject} content:{sent_string}, append None')
                    dataset_dict[subject].append(None)
                    continue
                elif len(word_tokens_all) == 0:
                    print(f'no word level features: subj:{subject} content:{sent_string}, append None')
                    dataset_dict[subject].append(None)
                    continue
                else:
                    for widx in range(len(word_data)):
                        data_dict = word_data[widx]
                        word_obj = {'content':data_dict['content'], 'nFixations': data_dict['nFix']}
                        if 'GD_EEG' in data_dict:
                            # print('has fixation: ', data_dict['content'])
                            gd = data_dict["GD_EEG"]
                            ffd = data_dict["FFD_EEG"]
                            trt = data_dict["TRT_EEG"]
                            assert len(gd) == len(trt) == len(ffd) == 8
                            word_obj['word_level_EEG'] = {
                                'GD':{'GD_t1':gd[0], 'GD_t2':gd[1], 'GD_a1':gd[2], 'GD_a2':gd[3], 'GD_b1':gd[4], 'GD_b2':gd[5], 'GD_g1':gd[6], 'GD_g2':gd[7]},
                                'FFD':{'FFD_t1':ffd[0], 'FFD_t2':ffd[1], 'FFD_a1':ffd[2], 'FFD_a2':ffd[3], 'FFD_b1':ffd[4], 'FFD_b2':ffd[5], 'FFD_g1':ffd[6], 'FFD_g2':ffd[7]},
                                'TRT':{'TRT_t1':trt[0], 'TRT_t2':trt[1], 'TRT_a1':trt[2], 'TRT_a2':trt[3], 'TRT_b1':trt[4], 'TRT_b2':trt[5], 'TRT_g1':trt[6], 'TRT_g2':trt[7]}
                            }
                            sent_obj['word'].append(word_obj)
                    sent_obj['word_tokens_has_fixation'] = word_tokens_has_fixation
                    sent_obj['word_tokens_with_mask'] = word_tokens_with_mask
                    sent_obj['word_tokens_all'] = word_tokens_all
                    # print(sent_obj.keys())
                    # print(len(sent_obj['word']))
                    # print(sent_obj['word'][0])
                    dataset_dict[subject].append(sent_obj)
            f.close()
"""output"""
task_name = 'task2-NR-2.0'
if dataset_dict == {}:
    print(f'No mat file found for {task_name}')
    quit()
output_dir = f'./dataset/ZuCo/{task_name}/pickle'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
output_name = f'{task_name}-dataset.pickle'
# with open(os.path.join(output_dir,'task1-SR-dataset.json'), 'w') as out:
#     json.dump(dataset_dict,out,indent = 4)
with open(os.path.join(output_dir,output_name), 'wb') as handle:
    pickle.dump(dataset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
print('write to:', os.path.join(output_dir,output_name))
"""sanity check"""
print('subjects:', dataset_dict.keys())
print('num of sent:', len(dataset_dict['YAC']))
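The v2 converter relies on two conventions: the subject ID is the text between `results` and the first `_` of the file name, and each 8-element `GD_EEG`/`FFD_EEG`/`TRT_EEG` list is ordered t1, t2, a1, a2, b1, b2, g1, g2. The sketch below makes both explicit. The file name `resultsYAC_NR.mat` is an assumed example of the naming scheme, and `subject_from_filename` and `band_dict` are hypothetical helpers, not repository functions.

```python
import os

# Band order used when unpacking the 8-element GD/FFD/TRT lists in the v2 converter
BAND_SUFFIXES = ['t1', 't2', 'a1', 'a2', 'b1', 'b2', 'g1', 'g2']


def subject_from_filename(path: str) -> str:
    """Extract the subject ID from a name such as 'resultsYAC_NR.mat' (assumed naming scheme)."""
    base = os.path.basename(path)  # ignore directories, which might contain 'ts'
    return base.split('ts', 1)[1].split('_')[0]


def band_dict(measure: str, values: list) -> dict:
    """Map 8 per-band values to keys like 'GD_t1', ..., 'GD_g2'."""
    assert len(values) == len(BAND_SUFFIXES)
    return {f'{measure}_{band}': v for band, v in zip(BAND_SUFFIXES, values)}


if __name__ == "__main__":
    print(subject_from_filename('./dataset/ZuCo/task2-NR-2.0/Matlab_files/resultsYAC_NR.mat'))
    print(band_dict('GD', list(range(8))))
```

Building the per-band dictionary from a single suffix list keeps the GD, FFD and TRT entries in the same order, instead of repeating eight hand-written keys per measure.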
================================================
FILE: util/data_loading_helpers_modified.py
================================================
import numpy as np
import re
eeg_float_resolution=np.float16
Alpha_ffd_names = ['FFD_a1', 'FFD_a1_diff', 'FFD_a2', 'FFD_a2_diff']
Beta_ffd_names = ['FFD_b1', 'FFD_b1_diff', 'FFD_b2', 'FFD_b2_diff']
Gamma_ffd_names = ['FFD_g1', 'FFD_g1_diff', 'FFD_g2', 'FFD_g2_diff']
Theta_ffd_names = ['FFD_t1', 'FFD_t1_diff', 'FFD_t2', 'FFD_t2_diff']
Alpha_gd_names = ['GD_a1', 'GD_a1_diff', 'GD_a2', 'GD_a2_diff']
Beta_gd_names = ['GD_b1', 'GD_b1_diff', 'GD_b2', 'GD_b2_diff']
Gamma_gd_names = ['GD_g1', 'GD_g1_diff', 'GD_g2', 'GD_g2_diff']
Theta_gd_names = ['GD_t1', 'GD_t1_diff', 'GD_t2', 'GD_t2_diff']
Alpha_gpt_names = ['GPT_a1', 'GPT_a1_diff', 'GPT_a2', 'GPT_a2_diff']
Beta_gpt_names = ['GPT_b1', 'GPT_b1_diff', 'GPT_b2', 'GPT_b2_diff']
Gamma_gpt_names = ['GPT_g1', 'GPT_g1_diff', 'GPT_g2', 'GPT_g2_diff']
Theta_gpt_names = ['GPT_t1', 'GPT_t1_diff', 'GPT_t2', 'GPT_t2_diff']
Alpha_sfd_names = ['SFD_a1', 'SFD_a1_diff', 'SFD_a2', 'SFD_a2_diff']
Beta_sfd_names = ['SFD_b1', 'SFD_b1_diff', 'SFD_b2', 'SFD_b2_diff']
Gamma_sfd_names = ['SFD_g1', 'SFD_g1_diff', 'SFD_g2', 'SFD_g2_diff']
Theta_sfd_names = ['SFD_t1', 'SFD_t1_diff', 'SFD_t2', 'SFD_t2_diff']
Alpha_trt_names = ['TRT_a1', 'TRT_a1_diff', 'TRT_a2', 'TRT_a2_diff']
Beta_trt_names = ['TRT_b1', 'TRT_b1_diff', 'TRT_b2', 'TRT_b2_diff']
Gamma_trt_names = ['TRT_g1', 'TRT_g1_diff', 'TRT_g2', 'TRT_g2_diff']
Theta_trt_names = ['TRT_t1', 'TRT_t1_diff', 'TRT_t2', 'TRT_t2_diff']
# IF YOU CHANGE THOSE YOU MUST ALSO CHANGE CONSTANTS
Alpha_features = Alpha_ffd_names + Alpha_gd_names + Alpha_gpt_names + Alpha_trt_names# + Alpha_sfd_names
Beta_features = Beta_ffd_names + Beta_gd_names + Beta_gpt_names + Beta_trt_names# + Beta_sfd_names
Gamma_features = Gamma_ffd_names + Gamma_gd_names + Gamma_gpt_names + Gamma_trt_names# + Gamma_sfd_names
Theta_features = Theta_ffd_names + Theta_gd_names + Theta_gpt_names + Theta_trt_names# + Theta_sfd_names
# print(Alpha_features)
# GD_EEG_features
def extract_all_fixations(data_container, word_data_object, float_resolution = np.float16):
    """
    Extracts all fixations from a word data object
    :param data_container: (h5py) Container of the whole data, h5py object
    :param word_data_object: (h5py) Container of fixation objects, h5py object
    :param float_resolution: (type) Resolution to which data are to be converted, used for data compression
    :return:
        fixations_data (list) Data arrays representing each fixation
    """
    word_data = data_container[word_data_object]
    fixations_data = []
    if len(word_data.shape) > 1:
        for fixation_idx in range(word_data.shape[0]):
            fixations_data.append(np.array(data_container[word_data[fixation_idx][0]]).astype(float_resolution))
    return fixations_data
def is_real_word(word):
    """
    Check if the word is a real word
    :param word: (str) word string
    :return:
        is_word (bool) True if the word contains at least one ASCII letter or digit
    """
    is_word = re.search('[a-zA-Z0-9]', word) is not None
    return is_word
def load_matlab_string(matlab_extracted_object):
    """
    Converts a string loaded from h5py into a python string
    :param matlab_extracted_object: (h5py) matlab string object
    :return:
        extracted_string (str) translated string
    """
    extracted_string = u''.join(chr(c[0]) for c in matlab_extracted_object)
    return extracted_string
def extract_word_level_data(data_container, word_objects, eeg_float_resolution = np.float16):
    """
    Extracts word level data for a specific sentence
    :param data_container: (h5py) Container of the whole data, h5py object
    :param word_objects: (h5py) Container of all word data for a specific sentence
    :param eeg_float_resolution: (type) Resolution with which to save EEG, used for data compression
    :return:
        word_level_data (dict) Word level data of the real words, indexed by their position (0, 1, ...) among the kept words
        word_tokens_all (list) All real words of the sentence (empty if no word-level EEG is available)
        word_tokens_has_fixation (list) Real words that received at least one fixation
        word_tokens_with_mask (list) All real words, with unfixated words replaced by '[MASK]'
    """
    available_objects = list(word_objects)
    #print(available_objects)
    #print(len(available_objects))
    # print('available_objects:', available_objects)
    if isinstance(available_objects[0], str):
        contentData = word_objects['content']
        #fixations_order_per_word = []
        if "rawEEG" in available_objects:
            rawData = word_objects['rawEEG']
            etData = word_objects['rawET']
            ffdData = word_objects['FFD']
            gdData = word_objects['GD']
            gptData = word_objects['GPT']
            trtData = word_objects['TRT']
            try:
                sfdData = word_objects['SFD']
            except KeyError:
                print("no SFD!")
                # one placeholder per word, so that zip() below is not truncated to zero items
                sfdData = [None] * len(contentData)
            nFixData = word_objects['nFixations']
            fixPositions = word_objects["fixPositions"]
            Alpha_features_data = [word_objects[feature] for feature in Alpha_features]
            Beta_features_data = [word_objects[feature] for feature in Beta_features]
            Gamma_features_data = [word_objects[feature] for feature in Gamma_features]
            Theta_features_data = [word_objects[feature] for feature in Theta_features]
            ####
            GD_EEG_features = [word_objects[feature] for feature in ['GD_t1','GD_t2','GD_a1','GD_a2','GD_b1','GD_b2','GD_g1','GD_g2']]
            FFD_EEG_features = [word_objects[feature] for feature in ['FFD_t1','FFD_t2','FFD_a1','FFD_a2','FFD_b1','FFD_b2','FFD_g1','FFD_g2']]
            TRT_EEG_features = [word_objects[feature] for feature in ['TRT_t1','TRT_t2','TRT_a1','TRT_a2','TRT_b1','TRT_b2','TRT_g1','TRT_g2']]
            ####
            assert len(contentData) == len(etData) == len(rawData), "different amounts of different data!!"
            zipped_data = zip(rawData, etData, contentData, ffdData, gdData, gptData, trtData, sfdData, nFixData, fixPositions)
            word_level_data = {}
            word_idx = 0
            word_tokens_has_fixation = []
            word_tokens_with_mask = []
            word_tokens_all = []
            # word_pos indexes every word of the sentence (used to look up the per-band features);
            # word_idx counts only the real words that are kept
            for word_pos, (raw_eegs_obj, ets_obj, word_obj, ffd, gd, gpt, trt, sfd, nFix, fixPos) in enumerate(zipped_data):
                word_string = load_matlab_string(data_container[word_obj[0]])
                if is_real_word(word_string):
                    data_dict = {}
                    data_dict["RAW_EEG"] = extract_all_fixations(data_container, raw_eegs_obj[0], eeg_float_resolution)
                    data_dict["RAW_ET"] = extract_all_fixations(data_container, ets_obj[0], np.float32)
                    data_dict["FFD"] = data_container[ffd[0]][()][0, 0] if len(data_container[ffd[0]][()].shape) == 2 else None
                    data_dict["GD"] = data_container[gd[0]][()][0, 0] if len(data_container[gd[0]][()].shape) == 2 else None
                    data_dict["GPT"] = data_container[gpt[0]][()][0, 0] if len(data_container[gpt[0]][()].shape) == 2 else None
                    data_dict["TRT"] = data_container[trt[0]][()][0, 0] if len(data_container[trt[0]][()].shape) == 2 else None
                    data_dict["SFD"] = data_container[sfd[0]][()][0, 0] if sfd is not None and len(data_container[sfd[0]][()].shape) == 2 else None
                    data_dict["nFix"] = data_container[nFix[0]][()][0, 0] if len(data_container[nFix[0]][()].shape) == 2 else None
                    #fixations_order_per_word.append(np.array(data_container[fixPos[0]]))
                    #print([data_container[obj[word_pos][0]][()] for obj in Alpha_features_data])
                    data_dict["ALPHA_EEG"] = np.concatenate([data_container[obj[word_pos][0]][()]
                                                             if len(data_container[obj[word_pos][0]][()].shape) == 2 else []
                                                             for obj in Alpha_features_data], 0)
                    data_dict["BETA_EEG"] = np.concatenate([data_container[obj[word_pos][0]][()]
                                                            if len(data_container[obj[word_pos][0]][()].shape) == 2 else []
                                                            for obj in Beta_features_data], 0)
                    data_dict["GAMMA_EEG"] = np.concatenate([data_container[obj[word_pos][0]][()]
                                                             if len(data_container[obj[word_pos][0]][()].shape) == 2 else []
                                                             for obj in Gamma_features_data], 0)
                    data_dict["THETA_EEG"] = np.concatenate([data_container[obj[word_pos][0]][()]
                                                             if len(data_container[obj[word_pos][0]][()].shape) == 2 else []
                                                             for obj in Theta_features_data], 0)
                    data_dict["word_idx"] = word_idx
                    data_dict["content"] = word_string
                    ####################################
                    word_tokens_all.append(word_string)
                    if data_dict["nFix"] is not None:
                        ####################################
                        data_dict["GD_EEG"] = [np.squeeze(data_container[obj[word_pos][0]][()]) if len(data_container[obj[word_pos][0]][()].shape) == 2 else [] for obj in GD_EEG_features]
                        data_dict["FFD_EEG"] = [np.squeeze(data_container[obj[word_pos][0]][()]) if len(data_container[obj[word_pos][0]][()].shape) == 2 else [] for obj in FFD_EEG_features]
                        data_dict["TRT_EEG"] = [np.squeeze(data_container[obj[word_pos][0]][()]) if len(data_container[obj[word_pos][0]][()].shape) == 2 else [] for obj in TRT_EEG_features]
                        ####################################
                        word_tokens_has_fixation.append(word_string)
                        word_tokens_with_mask.append(word_string)
                    else:
                        word_tokens_with_mask.append('[MASK]')
                    word_level_data[word_idx] = data_dict
                    word_idx += 1
                else:
                    print(word_string + " is not a real word.")
        else:
            # If there are no word-level data it will be word embeddings alone
            word_level_data = {}
            word_idx = 0
            word_tokens_has_fixation = []
            word_tokens_with_mask = []
            word_tokens_all = []
            for word_obj in contentData:
                word_string = load_matlab_string(data_container[word_obj[0]])
                if is_real_word(word_string):
                    data_dict = {}
                    data_dict["RAW_EEG"] = []
                    data_dict["ICA_EEG"] = []
                    data_dict["RAW_ET"] = []
                    data_dict["FFD"] = None
                    data_dict["GD"] = None
                    data_dict["GPT"] = None
                    data_dict["TRT"] = None
                    data_dict["SFD"] = None
                    data_dict["nFix"] = None
                    data_dict["ALPHA_EEG"] = []
                    data_dict["BETA_EEG"] = []
                    data_dict["GAMMA_EEG"] = []
                    data_dict["THETA_EEG"] = []
                    data_dict["word_idx"] = word_idx
                    data_dict["content"] = word_string
                    word_level_data[word_idx] = data_dict
                    word_idx += 1
                else:
                    print(word_string + " is not a real word.")
            sentence = " ".join([load_matlab_string(data_container[word_obj[0]]) for word_obj in word_objects['content']])
            #print("Only available objects for the sentence '{}' are {}.".format(sentence, available_objects))
            #word_level_data["word_reading_order"] = extract_word_order_from_fixations(fixations_order_per_word)
    else:
        word_tokens_has_fixation = []
        word_tokens_with_mask = []
        word_tokens_all = []
        word_level_data = {}
    return word_level_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask
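`load_matlab_string` turns a MATLAB character array read through h5py into a Python string by taking the single value in each row and passing it to `chr`, and `is_real_word` keeps only tokens containing at least one ASCII letter or digit. Below is a standalone copy of the two helpers (with `is_real_word` returning a bool) exercised on plain nested lists that stand in for an h5py dataset of shape `(n, 1)`. The assumption here is that each row holds one character code, as MATLAB v7.3 files store `char` data as 16-bit codes.

```python
import re


def load_matlab_string(matlab_extracted_object):
    # each row holds one character code; join the decoded characters
    return ''.join(chr(c[0]) for c in matlab_extracted_object)


def is_real_word(word: str) -> bool:
    # a token counts as a word if it contains at least one ASCII letter or digit
    return re.search('[a-zA-Z0-9]', word) is not None


if __name__ == "__main__":
    # stand-in for an h5py dataset of shape (n, 1): nested lists of character codes
    codes = [[ord(ch)] for ch in "film."]
    word = load_matlab_string(codes)
    print(word, is_real_word(word), is_real_word("--"))
```

Note that the character class is ASCII-only: a token made solely of punctuation or of non-ASCII letters (for example `'é'`) is treated as "not a real word" and skipped.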
================================================
FILE: util/get_SST_ternary_dataset.py
================================================
import os
import json
from fuzzy_match import algorithims
def get_SST_dataset(SST_dir_path, ZUCO_SENTIMENT_LABELS):
    def get_sentiment_label_dict(SST_labels_file_path):
        '''
        return {phrase_id:sentiment_score(0-1)}
        '''
        ret_dict = {}
        with open(SST_labels_file_path) as f:
            for line in f:
                if line.startswith('phrase'):
                    continue
                else:
                    phrase_id = int(line.split('|')[0])
                    label = float(line.split('|')[1].strip())
                    assert phrase_id not in ret_dict
                    ret_dict[phrase_id] = label
        return ret_dict
    def get_phrasestr_phrase_dict(SST_dictionary_file_path):
        '''
        return {phrase_str: phrase_id}
        '''
        ret_dict = {}
        with open(SST_dictionary_file_path) as f:
            for line in f:
                phrase_str = line.split('|')[0]
                phrase_id = int(line.split('|')[1].strip())
                assert phrase_str not in ret_dict
                ret_dict[phrase_str] = phrase_id
        return ret_dict
    def get_sentence_label_dict(SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path):
        '''
        return ({sentence_str: score(0-1)}, {sentence_str: ternary label in {0, 1, 2}})
        '''
        phraseID_2_label = get_sentiment_label_dict(SST_labels_file_path)
        phraseStr_2_phraseID = get_phrasestr_phrase_dict(SST_dictionary_file_path)
        sentence_2_label_all = {}
        sentence_2_label_ternary = {}
        with open(SST_sentences_file_path) as f:
            for line in f:
                if line.startswith('sentence_index'):
                    continue
                else:
                    parsed_line = line.split('\t')
                    assert len(parsed_line) == 2
                    sentence = parsed_line[1].strip()
                    # convert -LRB- to (, -RRB- to ):
                    sentence = sentence.replace('-LRB-','(').replace('-RRB-',')').replace('é','é')
                    if sentence not in phraseStr_2_phraseID:
                        # print(f'[ERROR]sentence-phrase match not found in dictionary, skipped: {sentence}')
                        # print()
                        continue
                    sent_phrase_id = phraseStr_2_phraseID[sentence]
                    label = phraseID_2_label[sent_phrase_id]
                    # add to all dict
                    if sentence not in sentence_2_label_all:
                        sentence_2_label_all[sentence] = label
                    # add to ternary dict; scores in (0.2, 0.4] and (0.6, 0.8] are dropped
                    if sentence not in sentence_2_label_ternary:
                        if label<=0.2:
                            label = 0
                            sentence_2_label_ternary[sentence] = label
                        elif (label > 0.4) and (label<=0.6):
                            label = 1
                            sentence_2_label_ternary[sentence] = label
                        elif label>0.8:
                            label = 2
                            sentence_2_label_ternary[sentence] = label
        return sentence_2_label_all, sentence_2_label_ternary
    SST_sentences_file_path = os.path.join(SST_dir_path,'datasetSentences.txt')
    if not os.path.isfile(SST_sentences_file_path):
        raise FileNotFoundError(f'NOT FOUND file: {SST_sentences_file_path}')
    SST_labels_file_path = os.path.join(SST_dir_path,'sentiment_labels.txt')
    if not os.path.isfile(SST_labels_file_path):
        raise FileNotFoundError(f'NOT FOUND file: {SST_labels_file_path}')
    SST_dictionary_file_path = os.path.join(SST_dir_path,'dictionary.txt')
    if not os.path.isfile(SST_dictionary_file_path):
        raise FileNotFoundError(f'NOT FOUND file: {SST_dictionary_file_path}')
    sentence_2_label_all, sentence_2_label_ternary = get_sentence_label_dict(SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path)
    print('original ternary dataset size:', len(sentence_2_label_ternary))
    # drop SST sentences that fuzzily match (trigram similarity > 0.7) a sentence already used in ZuCo task1-SR
    ZuCo_used_sentences = list(ZUCO_SENTIMENT_LABELS)
    filtered_ternary_dataset = {}
    filtered_pairs = []
    for key,value in sentence_2_label_ternary.items():
        add_instance = True
        for used_sent in ZuCo_used_sentences:
            if algorithims.trigram(used_sent, key) > 0.7:
                # print(f'Filter match: \n\t{used_sent}\n\t{key}')
                # print('###########################')
                filtered_pairs.append((used_sent, key))
                ZuCo_used_sentences.remove(used_sent)
                add_instance = False
                break
        if add_instance:
            filtered_ternary_dataset[key] = value
    print('filtered instance number:', len(filtered_pairs))
    print('filtered ternary dataset size:', len(filtered_ternary_dataset))
    print('unmatched remaining sentences:', ZuCo_used_sentences)
    print('unmatched remaining sentences length:', len(ZuCo_used_sentences))
    with open('temp.txt','w') as temp:
        for matched_pair in filtered_pairs:
            temp.write('#######\n')
            temp.write('\t'+matched_pair[0]+'\n')
            temp.write('\t'+matched_pair[1]+'\n')
            temp.write('\n')
    with open('./dataset/stanfordsentiment/ternary_dataset.json', 'w') as out:
        json.dump(filtered_ternary_dataset,out, indent = 4)
    print('write json to ./dataset/stanfordsentiment/ternary_dataset.json')
if __name__ == '__main__':
    print('##############################')
    print('start generating stanfordSentimentTreebank ternary sentiment dataset...')
    SST_dir_path = './dataset/stanfordsentiment/stanfordSentimentTreebank'
    with open('./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json') as f:
        ZUCO_SENTIMENT_LABELS = json.load(f)
    get_SST_dataset(SST_dir_path, ZUCO_SENTIMENT_LABELS)
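The ternary SST dataset keeps only sentences whose phrase-level sentiment score (0 = most negative, 1 = most positive) falls into one of three bands and discards the scores in between. The mapping in `get_sentence_label_dict` can be read as the small function below; `ternary_label` is a hypothetical helper written for illustration, not a repository function.

```python
from typing import Optional


def ternary_label(score: float) -> Optional[int]:
    """Map an SST score in [0, 1] to 0 (negative), 1 (neutral), 2 (positive), or None (dropped)."""
    if score <= 0.2:
        return 0
    if 0.4 < score <= 0.6:
        return 1
    if score > 0.8:
        return 2
    return None  # scores in (0.2, 0.4] and (0.6, 0.8] are excluded


if __name__ == "__main__":
    for s in (0.1, 0.3, 0.5, 0.7, 0.9):
        print(s, ternary_label(s))
```

Leaving gaps between the bands removes weakly polarised sentences, which makes the three classes easier to separate at the cost of a smaller dataset.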
================================================
FILE: util/get_sentiment_labels.py
================================================
import os
import json
print('##############################')
print('start generating ZuCo task1-SR sentiment labels...')
sentiment_labels_task1_csv_path = './dataset/ZuCo/task_materials/sentiment_labels_task1.csv'
sentiment_labels = {}
with open(sentiment_labels_task1_csv_path, 'r') as f:
    for line in f:
        if line.startswith('sentence_id') or line.startswith('#'):
            continue
        else:
            parsed_line = line.split(';')
            # handle edge case: a quoted sentence (which may itself contain ';')
            if '\";' in line:
                sent_text = line.split('\";')[0].split('\"')[1]
            else:
                sent_text = parsed_line[1]
            label = int(parsed_line[-1].strip())
            sentiment_labels[sent_text] = label
output_dir = './dataset/ZuCo/task1-SR/sentiment_labels'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
with open(os.path.join(output_dir, 'sentiment_labels.json'), 'w') as out:
    json.dump(sentiment_labels,out,indent = 4)
print('write to ./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json')
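The label script splits each line on `;`, takes the sentence from the second field and the label from the last one, and special-cases quoted sentences so that a `;` inside the quotes does not cut the text short. The per-line logic is isolated below as a hypothetical `parse_label_line` helper. The sample rows are invented and assume a layout of `id;sentence;...;label`, which is what the indexing `parsed_line[1]` / `parsed_line[-1]` implies.

```python
from typing import Optional, Tuple


def parse_label_line(line: str) -> Optional[Tuple[str, int]]:
    """Return (sentence, label) for one line of the labels file, or None for headers/comments."""
    if line.startswith('sentence_id') or line.startswith('#'):
        return None
    parsed_line = line.split(';')
    if '";' in line:
        # quoted sentence: take the text between the first '"' and the first '";'
        sent_text = line.split('";')[0].split('"')[1]
    else:
        sent_text = parsed_line[1]
    label = int(parsed_line[-1].strip())
    return sent_text, label


if __name__ == "__main__":
    for row in ('1;Nice movie.;1\n', '2;"Dark; but moving";0\n'):
        print(parse_label_line(row))
```

The label is still read from the last `;`-separated field, so it is unaffected by any `;` inside the quoted sentence.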
SYMBOL INDEX (86 symbols across 11 files)
FILE: config.py
function str2bool (line 3) | def str2bool(v):
function get_config (line 13) | def get_config(case):
FILE: data.py
function normalize_1d (line 18) | def normalize_1d(input_tensor):
function get_input_sample (line 25) | def get_input_sample(sent_obj, tokenizer, eeg_type = 'GD', bands = ['_t1...
class ZuCo_dataset (line 160) | class ZuCo_dataset(Dataset):
method __init__ (line 161) | def __init__(self, input_dataset_dicts, phase, tokenizer, subject = 'A...
method __len__ (line 237) | def __len__(self):
method __getitem__ (line 240) | def __getitem__(self, idx):
class SST_tenary_dataset (line 256) | class SST_tenary_dataset(Dataset):
method __init__ (line 257) | def __init__(self, ternary_labels_dict, tokenizer, max_len = 56, balan...
method __len__ (line 286) | def __len__(self):
method __getitem__ (line 289) | def __getitem__(self, idx):
FILE: eval_decoding.py
function remove_text_after_token (line 30) | def remove_text_after_token(text, token='</s>'):
function eval_model (line 37) | def eval_model(dataloaders, device, tokenizer, criterion, model, output_...
FILE: eval_sentiment.py
function flat_accuracy (line 26) | def flat_accuracy(preds, labels):
function flat_accuracy_top_k (line 35) | def flat_accuracy_top_k(preds, labels,k):
function eval_model (line 50) | def eval_model(dataloaders, device, model, criterion, optimizer, schedul...
FILE: model_decoding.py
class BrainTranslator (line 9) | class BrainTranslator(nn.Module):
method __init__ (line 10) | def __init__(self, pretrained_layers, in_feature = 840, decoder_embedd...
method addin_forward (line 23) | def addin_forward(self,input_embeddings_batch, input_masks_invert):
method generate (line 37) | def generate(
method forward (line 70) | def forward(self, input_embeddings_batch, input_masks_batch, input_mas...
class T5Translator (line 81) | class T5Translator(nn.Module):
method __init__ (line 82) | def __init__(self, pretrained_layers, in_feature = 840, decoder_embedd...
method addin_forward (line 98) | def addin_forward(self,input_embeddings_batch, input_masks_invert):
method generate (line 112) | def generate(
method forward (line 154) | def forward(self, input_embeddings_batch, input_masks_batch, input_mas...
class BrainTranslatorNaive (line 170) | class BrainTranslatorNaive(nn.Module):
method __init__ (line 171) | def __init__(self, pretrained_layers, in_feature = 840, decoder_embedd...
method forward (line 177) | def forward(self, input_embeddings_batch, input_masks_batch, input_mas...
class Pooler (line 188) | class Pooler(nn.Module):
method __init__ (line 189) | def __init__(self, hidden_size):
method forward (line 194) | def forward(self, hidden_states):
class PositionalEncoding (line 203) | class PositionalEncoding(nn.Module):
method __init__ (line 205) | def __init__(self, d_model, dropout=0.1, max_len=5000):
method forward (line 217) | def forward(self, x):
class BrainTranslatorBert (line 226) | class BrainTranslatorBert(nn.Module):
method __init__ (line 227) | def __init__(self, pretrained_layers, in_feature = 840, hidden_size = ...
method forward (line 233) | def forward(self, input_embeddings_batch, input_masks_batch, target_id...
class EEG2BertMapping (line 238) | class EEG2BertMapping(nn.Module):
method __init__ (line 239) | def __init__(self, in_feature = 840, hidden_size = 512, out_feature = ...
method forward (line 244) | def forward(self, x):
class ContrastiveBrainTextEncoder (line 249) | class ContrastiveBrainTextEncoder(nn.Module):
method __init__ (line 250) | def __init__(self, pretrained_text_encoder, in_feature = 840, eeg_enco...
method forward (line 268) | def forward(self, input_EEG_features, input_EEG_attn_mask, input_ids, ...
FILE: model_sentiment.py
class BaselineMLPSentence (line 11) | class BaselineMLPSentence(nn.Module):
method __init__ (line 12) | def __init__(self, input_dim = 840, hidden_dim = 128, output_dim = 3):
method forward (line 21) | def forward(self, x):
class BaselineLSTM (line 32) | class BaselineLSTM(nn.Module):
method __init__ (line 33) | def __init__(self, input_dim = 840, hidden_dim = 256, output_dim = 3, ...
method forward (line 42) | def forward(self, x_packed):
class NaiveFineTunePretrainedBert (line 52) | class NaiveFineTunePretrainedBert(nn.Module):
method __init__ (line 53) | def __init__(self, input_dim = 840, hidden_dim = 768, output_dim = 3, ...
method forward (line 62) | def forward(self, input_embeddings_batch, input_masks_batch, labels):
class FineTunePretrainedTwoStep (line 68) | class FineTunePretrainedTwoStep(nn.Module):
method __init__ (line 69) | def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024...
method forward (line 83) | def forward(self, input_embeddings_batch, input_masks_batch, input_mas...
class ZeroShotSentimentDiscovery (line 102) | class ZeroShotSentimentDiscovery(nn.Module):
method __init__ (line 103) | def __init__(self, brain2text_translator, sentiment_classifier, transl...
method forward (line 114) | def forward(self, input_embeddings_batch, input_masks_batch, input_mas...
class BartClassificationHead (line 147) | class BartClassificationHead(nn.Module):
method __init__ (line 150) | def __init__(
method forward (line 162) | def forward(self, hidden_states: torch.Tensor):
class JointBrainTranslatorSentimentClassifier (line 170) | class JointBrainTranslatorSentimentClassifier(nn.Module):
method __init__ (line 171) | def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024...
method forward (line 184) | def forward(self, input_embeddings_batch, input_masks_batch, input_mas...
class Pooler (line 214) | class Pooler(nn.Module):
method __init__ (line 215) | def __init__(self, hidden_size):
method forward (line 220) | def forward(self, hidden_states):
class PositionalEncoding (line 229) | class PositionalEncoding(nn.Module):
method __init__ (line 231) | def __init__(self, d_model, dropout=0.1, max_len=5000):
method forward (line 243) | def forward(self, x):
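Almost every model listed above defaults to `in_feature = 840` or `input_dim = 840`. This sketch assumes the original EEG-to-Text setup, in which each word's EEG feature concatenates 8 frequency bands (theta1, theta2, alpha1, alpha2, beta1, beta2, gamma1, gamma2) across 105 electrodes. Under that assumption, one 840-dimensional word vector could be assembled as below. The band keys, their order, and the per-word z-score normalization are illustrative assumptions, not code taken from `data.py`.

```python
from __future__ import annotations

import math

# Assumed band order; 8 bands x 105 electrodes = 840 features per word.
BANDS = ["t1", "t2", "a1", "a2", "b1", "b2", "g1", "g2"]
N_ELECTRODES = 105


def word_feature(band_values: dict[str, list[float]]) -> list[float]:
    """Concatenate per-band electrode vectors into one 840-dim word vector."""
    feature: list[float] = []
    for band in BANDS:
        values = band_values[band]
        if len(values) != N_ELECTRODES:
            raise ValueError(
                f"band {band!r} has {len(values)} electrodes, expected {N_ELECTRODES}"
            )
        feature.extend(values)
    return feature


def zscore(vec: list[float]) -> list[float]:
    """Normalize to zero mean and unit population std (assumed normalization)."""
    mean = sum(vec) / len(vec)
    std = math.sqrt(sum((v - mean) ** 2 for v in vec) / len(vec))
    return [(v - mean) / std for v in vec] if std > 0 else [0.0] * len(vec)


# Hypothetical word where band b has the constant value b on every electrode.
word = {band: [float(b)] * N_ELECTRODES for b, band in enumerate(BANDS)}
x = zscore(word_feature(word))
print(len(x))  # 840
```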
FILE: train_decoding.py
function train_model (line 20) | def train_model(dataloaders, device, model, criterion, optimizer, schedu...
function show_require_grad_layers (line 126) | def show_require_grad_layers(model):
FILE: train_sentiment_baseline.py
function flat_accuracy (line 22) | def flat_accuracy(preds, labels):
function flat_accuracy_top_k (line 31) | def flat_accuracy_top_k(preds, labels,k):
function train_model (line 46) | def train_model(dataloaders, device, model, criterion, optimizer, schedu...
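Both sentiment training scripts define `flat_accuracy(preds, labels)` and `flat_accuracy_top_k(preds, labels, k)`, but only their signatures appear here. The sketch below assumes that `preds` is a batch of per-class logits and `labels` holds integer class ids (three sentiment classes, matching `output_dim = 3`). It approximates both metrics in plain Python, using lists in place of the NumPy arrays the scripts likely receive.

```python
from __future__ import annotations


def flat_accuracy(preds: list[list[float]], labels: list[int]) -> float:
    """Fraction of rows whose argmax logit equals the gold label."""
    pred_ids = [max(range(len(row)), key=row.__getitem__) for row in preds]
    return sum(p == y for p, y in zip(pred_ids, labels)) / len(labels)


def flat_accuracy_top_k(preds: list[list[float]], labels: list[int], k: int) -> float:
    """Fraction of rows whose gold label is among the k highest logits."""
    hits = 0
    for row, y in zip(preds, labels):
        top_k = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        hits += y in top_k
    return hits / len(labels)


logits = [[2.0, 0.5, -1.0],  # argmax 0
          [0.1, 0.3, 0.2],   # argmax 1, runner-up 2
          [1.0, 0.0, 0.9]]   # argmax 0, runner-up 2
gold = [0, 2, 2]
print(flat_accuracy(logits, gold))           # 0.3333333333333333
print(flat_accuracy_top_k(logits, gold, 2))  # 1.0
```

With `k = 1`, the top-k metric reduces to plain accuracy. With three classes and `k = 2`, it only checks whether the gold label avoided last place.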
FILE: train_sentiment_textbased.py
function flat_accuracy (line 21) | def flat_accuracy(preds, labels):
function flat_accuracy_top_k (line 30) | def flat_accuracy_top_k(preds, labels,k):
function train_model_ZuCo (line 45) | def train_model_ZuCo(dataloaders, device, model, criterion, optimizer, s...
function train_model_SST (line 137) | def train_model_SST(dataloaders, device, model, criterion, optimizer, sc...
FILE: util/data_loading_helpers_modified.py
function extract_all_fixations (line 37) | def extract_all_fixations(data_container, word_data_object, float_resolu...
function is_real_word (line 54) | def is_real_word(word):
function load_matlab_string (line 65) | def load_matlab_string(matlab_extracted_object):
function extract_word_level_data (line 76) | def extract_word_level_data(data_container, word_objects, eeg_float_reso...
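`util/data_loading_helpers_modified.py` adapts the ZuCo loading helpers. The sketch below makes two assumptions about those helpers. First, it assumes `is_real_word` keeps any token that contains at least one ASCII letter or digit. Second, it assumes `load_matlab_string` decodes a MATLAB v7.3 (HDF5) string, which `h5py` exposes as integer character codes. Under those assumptions, the two helpers can be approximated as follows. The flat list of codes stands in for an `h5py` dataset purely for illustration.

```python
from __future__ import annotations

import re
from typing import Iterable


def is_real_word(word: str) -> bool:
    """Treat a token as a word if it contains at least one letter or digit."""
    return re.search(r"[a-zA-Z0-9]", word) is not None


def load_matlab_string(char_codes: Iterable[int]) -> str:
    """Decode MATLAB char data (stored as integer character codes) to str."""
    return "".join(chr(int(c)) for c in char_codes)


codes = [84, 104, 101]  # hypothetical character codes for "The"
print(load_matlab_string(codes))                          # The
print([is_real_word(w) for w in ["word", "--", "1990"]])  # [True, False, True]
```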
FILE: util/get_SST_ternary_dataset.py
function get_SST_dataset (line 15) | def get_SST_dataset(SST_dir_path, ZuCo_used_sentences, ZUCO_SENTIMENT_LA...
Condensed preview: 28 files, each showing path, character count, and a content snippet (the full structured content is 184K chars).
[
{
"path": ".gitignore",
"chars": 3125,
"preview": "*.pt\n*.pickle\n*.mat\n*.json\n*.txt\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\ncsv_results/\n*.py[cod]\n*$py.class"
},
{
"path": "README.md",
"chars": 7015,
"preview": "The **main branch** contains the final code for our \"Are EEG-to-Text Models Working?\" paper. \n\nAccepted by [IJCAI worksh"
},
{
"path": "config.py",
"chars": 9196,
"preview": "import argparse\n\ndef str2bool(v):\n if isinstance(v, bool):\n return v\n if v.lower() in ('yes', 'true', 't', "
},
{
"path": "data.py",
"chars": 16425,
"preview": "import os\nimport numpy as np\nimport torch\nimport pickle\nfrom torch.utils.data import Dataset, DataLoader\nimport json\nimp"
},
{
"path": "environment.yml",
"chars": 329,
"preview": "name: EEGToText\nchannels:\n - pytorch\n - anaconda\n - conda-forge\n - huggingface\ndependencies:\n - pytorch=1.9.0\n - t"
},
{
"path": "eval_decoding.py",
"chars": 16833,
"preview": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_s"
},
{
"path": "eval_sentiment.py",
"chars": 16220,
"preview": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_s"
},
{
"path": "model_decoding.py",
"chars": 14313,
"preview": "import torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.data\nfrom transformers import BartTokenizer, Ba"
},
{
"path": "model_sentiment.py",
"chars": 12293,
"preview": "import torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.data\nfrom transformers import BartTokenizer, Ba"
},
{
"path": "scripts/eval_decoding_1.sh",
"chars": 772,
"preview": "CUDA_VISIBLE_DEVICES=0 python3 eval_decoding.py \\\n --checkpoint_path checkpoints/decoding/best/task1_task2_task3_fine"
},
{
"path": "scripts/eval_decoding_2.sh",
"chars": 784,
"preview": "CUDA_VISIBLE_DEVICES=1 python3 eval_decoding.py \\\n --checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_f"
},
{
"path": "scripts/eval_decoding_3.sh",
"chars": 784,
"preview": "CUDA_VISIBLE_DEVICES=2 python3 eval_decoding.py \\\n --checkpoint_path checkpoints/decoding/best/task1_task2_task3_fine"
},
{
"path": "scripts/eval_decoding_4.sh",
"chars": 796,
"preview": "CUDA_VISIBLE_DEVICES=3 python3 eval_decoding.py \\\n --checkpoint_path checkpoints/decoding/best/task1_task2_taskNRv2_f"
},
{
"path": "scripts/eval_sentiment_zeroshot_pipeline.sh",
"chars": 663,
"preview": "python3 eval_sentiment.py --model_name ZeroShotSentimentDiscovery \\\n --decoder_checkpoint_path ./checkpoints/decoding"
},
{
"path": "scripts/prepare_dataset.sh",
"chars": 741,
"preview": "echo \"This script constructs .pickle files from .mat files from ZuCo dataset.\"\necho \"This script also generates ternary se"
},
{
"path": "scripts/train_decoding.sh",
"chars": 696,
"preview": "CUDA_VISIBLE_DEVICES=0 python3 train_decoding.py --model_name BrainTranslator \\\n --task_name task1_task2_task3 \\\n "
},
{
"path": "scripts/train_decoding_1.sh",
"chars": 694,
"preview": "CUDA_VISIBLE_DEVICES=2,3 python3 train_decoding.py --model_name T5Translator \\\n --task_name task1_task2_taskNRv2 \\\n "
},
{
"path": "scripts/train_eeg_sentiment_baseline.sh",
"chars": 137,
"preview": "python3 train_sentiment_baseline.py --model_name BaselineMLP --num_epoch 20 -lr 0.00005 -b 32 -s ./checkpoints/eeg_senti"
},
{
"path": "scripts/train_eval_zeroshot_pipeline.sh",
"chars": 1614,
"preview": "\necho \"###################################\"\necho \"Training decoder: BART, task1-SR...\"\necho \"###########################"
},
{
"path": "scripts/train_text_sentiment_classifier.sh",
"chars": 212,
"preview": "python3 train_sentiment_textbased.py \\\n --dataset_name SST \\\n --model_name pretrain_Bart \\\n --num_epoch 20 \\\n "
},
{
"path": "train_decoding.py",
"chars": 17153,
"preview": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_s"
},
{
"path": "train_sentiment_baseline.py",
"chars": 10137,
"preview": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_s"
},
{
"path": "train_sentiment_textbased.py",
"chars": 14976,
"preview": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_s"
},
{
"path": "util/construct_dataset_mat_to_pickle_v1.py",
"chars": 5964,
"preview": "import scipy.io as io\nimport h5py\nimport os\nimport json\nfrom glob import glob\nfrom tqdm import tqdm\nimport numpy as np\ni"
},
{
"path": "util/construct_dataset_mat_to_pickle_v2.py",
"chars": 6030,
"preview": "import os\nimport numpy as np\nimport h5py\nimport data_loading_helpers_modified as dh\nfrom glob import glob\nfrom tqdm impo"
},
{
"path": "util/data_loading_helpers_modified.py",
"chars": 12175,
"preview": "import numpy as np\nimport re\n\neeg_float_resolution=np.float16\n\nAlpha_ffd_names = ['FFD_a1', 'FFD_a1_diff', 'FFD_a2', 'FF"
},
{
"path": "util/get_SST_ternary_dataset.py",
"chars": 6211,
"preview": "import os\nimport numpy as np\nimport torch\nimport pickle\nfrom torch.utils.data import Dataset, DataLoader\nimport json\nimp"
},
{
"path": "util/get_sentiment_labels.py",
"chars": 1093,
"preview": "import os\nfrom glob import glob\nimport json\n\nprint('##############################')\nprint('start generating ZuCo task1-"
}
]
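The `config.py` preview opens with a `str2bool` helper, the common idiom that lets `argparse` boolean options accept values such as `yes`/`no`. Only its first lines are visible, so the exact set of accepted spellings below is an assumption. The `--example_flag` option is hypothetical and serves only to show how the helper plugs into `add_argument(type=...)`.

```python
import argparse


def str2bool(v):
    """Parse yes/no-style strings into bools for argparse options."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


parser = argparse.ArgumentParser()
# Hypothetical option name, used only to demonstrate str2bool.
parser.add_argument("--example_flag", type=str2bool, default=False)
args = parser.parse_args(["--example_flag", "Yes"])
print(args.example_flag)  # True
```

Using `type=bool` instead would be a common mistake. `bool("False")` is `True`, because any non-empty string is truthy, which is why scripts route the value through a parser like this.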
About this extraction
This page contains the full source code of the NeuSpeech/EEG-To-Text GitHub repository, extracted and formatted as plain text. The extraction includes 28 files (173.2 KB), approximately 42.6k tokens, and a symbol index with 86 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.