Repository: ArvinZhuang/DSI-transformers
Branch: main
Commit: 653b9d1fad20
Files: 7
Total size: 15.6 KB

Directory structure:
gitextract_c4yyz4mf/
├── .gitignore
├── LICENSE
├── README.md
├── data/
│   └── NQ/
│       └── create_NQ_train_vali.py
├── data.py
├── train.py
└── trainer.py


================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

.DS_Store


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2022 Shengyao Zhuang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# DSI-transformers
A huggingface transformers implementation of [Transformer Memory as a Differentiable Search Index](https://arxiv.org/abs/2202.06991), Yi Tay, Vinh Q. Tran, Mostafa Dehghani, Jianmo Ni, Dara Bahri, Harsh Mehta, Zhen Qin, Kai Hui, Zhe Zhao, Jai Gupta, Tal Schuster, William W.
Cohen, Donald Metzler

Requirements: `python=3.8` `transformers=4.17.0` `datasets=1.18.3` `wandb`

> Note: This is not the official repository.

## Updates
- Check out our new repository for DSI training: https://github.com/ArvinZhuang/DSI-QG

## Goal of this repository
Reproduce the results of DSI Large, Naive String Docid, NQ10K. According to Table 3 in the original paper, we should get `Hits@1=0.347` and `Hits@10=0.605`.

### Step 1: Create NQ10K training (indexing) and validation datasets

```
cd data/NQ
python3 create_NQ_train_vali.py
```

### Step 2: Run the training script
`cd` back to the root directory and run:

```
python3 train.py
```

Training can be run on a single 32GB Tesla V100 GPU. We use [wandb](https://wandb.ai/site) to log the Hits scores during training:

![Hits scores during training](hits_plots.png)

### Discussion
As the plots above show, the current implementation performs worse than what is reported in the original paper. There are many possible reasons: the ratio of training to indexing examples (we use 1:1), the number of training steps, the way the document text is constructed, etc. That said, the scores already seem to be on par with BM25. If you can identify the reason or find a bug, you are welcome to open a PR to fix it!

#### Indexing or overfitting?
The training script also logs Hits@1 on the training set during training (reported as `accuracy` by `compute_metrics` in `train.py`). This is meant to check whether the model can memorize all the training data points; the authors call this 'indexing', which I believe simply amounts to letting the model overfit the training data points. It turns out the model reaches 99.99% Hits@1 on the training set very quickly (i.e., it overfits quickly), but the Hits scores on the test set keep going up. T5-large seems to generalize well on this task.
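
For reference, Hits@k is the fraction of queries whose gold docid appears among the top-k generated docids. Below is a minimal, framework-free sketch of that computation, assuming the beam-search outputs have already been decoded into ranked lists of docid strings. The function name `hits_at_k` is hypothetical and not part of this repository; `QueryEvalCallback.on_epoch_end` in `train.py` does the equivalent check inline with NumPy.

```python
from typing import Dict, Sequence, Tuple


def hits_at_k(
    rank_lists: Sequence[Sequence[str]],
    labels: Sequence[str],
    ks: Tuple[int, ...] = (1, 10),
) -> Dict[str, float]:
    """Hypothetical helper: fraction of queries whose gold docid is in the top-k."""
    if len(rank_lists) != len(labels):
        raise ValueError("rank_lists and labels must have the same length")
    if not labels:
        return {f"Hits@{k}": 0.0 for k in ks}

    counts = {k: 0 for k in ks}
    for ranked, gold in zip(rank_lists, labels):
        ranked = list(ranked)
        if gold not in ranked:
            continue  # the gold docid was not generated at all
        # Use the first occurrence, so repeated docids in the beam never count twice.
        pos = ranked.index(gold)
        for k in ks:
            if pos < k:
                counts[k] += 1
    return {f"Hits@{k}": counts[k] / len(labels) for k in ks}


# Toy example: three queries, each with a ranked list of generated docids.
ranked = [["12", "7", "3"], ["5", "12", "12"], ["9", "4", "1"]]
gold = ["12", "12", "2"]
print(hits_at_k(ranked, gold, ks=(1, 3)))
# {'Hits@1': 0.3333333333333333, 'Hits@3': 0.6666666666666666}
```

Counting only the first occurrence mirrors the callback, which looks at `hits[0]` and is therefore unaffected by the repeated docids mentioned in its comment.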

================================================
FILE: data/NQ/create_NQ_train_vali.py
================================================
import datasets
import random
import numpy as np
import json

random.seed(313)

NUM_TRAIN = 8000
NUM_EVAL = 2000

data = datasets.load_dataset('natural_questions', cache_dir='cache')['train']
rand_inds = list(range(len(data)))
random.shuffle(rand_inds)

title_set = set()

current_docid = 0

with open('NQ_10k_multi_task_train.json', 'w') as tf, \
        open('NQ_10k_valid.json', 'w') as vf:
    for ind in rand_inds:
        # We use the title as the document identifier to prevent two docs from having the same text.
        title = data[ind]['document']['title']
        if title not in title_set:
            title_set.add(title)
            # Keep only the non-HTML tokens of the page.
            token_inds = np.where(np.array(data[ind]['document']['tokens']['is_html']) == False)[0]
            tokens = np.array(data[ind]['document']['tokens']['token'])[token_inds]
            doc_text = " ".join(tokens)
            question_text = data[ind]['question']['text']

            # Every document goes into the training (indexing) file.
            jitem = json.dumps({'text_id': str(current_docid), 'text': 'document: ' + doc_text})
            tf.write(jitem + '\n')
            # Questions for the first NUM_TRAIN docs are training examples; the rest are held out for validation.
            jitem = json.dumps({'text_id': str(current_docid), 'text': 'question: ' + question_text})
            if len(title_set) <= NUM_TRAIN:
                tf.write(jitem + '\n')
            else:
                vf.write(jitem + '\n')
            current_docid += 1
            if len(title_set) == NUM_TRAIN + NUM_EVAL:
                break
        print(f"Creating training and validation dataset: {len(title_set) / (NUM_TRAIN + NUM_EVAL):.1%}", end='\r')


================================================
FILE: data.py
================================================
from dataclasses import dataclass
import datasets
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, DataCollatorWithPadding


class IndexingTrainDataset(Dataset):
    """Reads a JSON-lines file of {'text_id', 'text'} records and yields (input_ids, docid string) pairs."""

    def __init__(
            self,
            path_to_data,
            max_length: int,
            cache_dir: str,
            tokenizer: PreTrainedTokenizer,
    ):
        self.train_data = datasets.load_dataset(
            'json',
            data_files=path_to_data,
            ignore_verifications=False,
            cache_dir=cache_dir
        )['train']

        self.max_length = max_length
        self.tokenizer = tokenizer
        self.total_len = len(self.train_data)

    def __len__(self):
        return self.total_len

    def __getitem__(self, item):
        data = self.train_data[item]

        input_ids = self.tokenizer(data['text'],
                                   return_tensors="pt",
                                   truncation='only_first',
                                   max_length=self.max_length).input_ids[0]
        return input_ids, str(data['text_id'])


@dataclass
class IndexingCollator(DataCollatorWithPadding):
    """Pads the inputs and tokenizes the docid strings as seq2seq target labels."""

    def __call__(self, features):
        input_ids = [{'input_ids': x[0]} for x in features]
        docids = [x[1] for x in features]
        inputs = super().__call__(input_ids)

        labels = self.tokenizer(
            docids, padding="longest", return_tensors="pt"
        ).input_ids

        # Replace padding token ids in the labels with -100, according to https://huggingface.co/docs/transformers/model_doc/t5#training
        labels[labels == self.tokenizer.pad_token_id] = -100
        inputs['labels'] = labels
        return inputs


@dataclass
class QueryEvalCollator(DataCollatorWithPadding):
    """Pads the query inputs and returns the raw docid strings alongside them for Hits@k evaluation."""

    def __call__(self, features):
        input_ids = [{'input_ids': x[0]} for x in features]
        labels = [x[1] for x in features]
        inputs = super().__call__(input_ids)

        return inputs, labels


================================================
FILE: train.py
================================================
from data import IndexingTrainDataset, IndexingCollator, QueryEvalCollator
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, TrainerCallback
from trainer import IndexingTrainer
import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm


class QueryEvalCallback(TrainerCallback):
    def __init__(self, test_dataset, logger, restrict_decode_vocab, args: TrainingArguments, tokenizer: T5Tokenizer):
        self.tokenizer = tokenizer
        self.logger = logger
        self.args = args
        self.test_dataset = test_dataset
        self.restrict_decode_vocab = restrict_decode_vocab
        self.dataloader = DataLoader(
            test_dataset,
            batch_size=self.args.per_device_eval_batch_size,
            collate_fn=QueryEvalCollator(
                self.tokenizer,
                padding='longest'
            ),
            shuffle=False,
            drop_last=False,
            num_workers=self.args.dataloader_num_workers,
        )

    def on_epoch_end(self, args, state, control, **kwargs):
        hit_at_1 = 0
        hit_at_10 = 0
        model = kwargs['model'].eval()
        for batch in tqdm(self.dataloader, desc='Evaluating dev queries'):
            inputs, labels = batch
            with torch.no_grad():
                # Constrained beam search returns the top-10 docids per query, reshaped to (batch, 10, seq_len).
                batch_beams = model.generate(
                    inputs['input_ids'].to(model.device),
                    max_length=20,
                    num_beams=10,
                    prefix_allowed_tokens_fn=self.restrict_decode_vocab,
                    num_return_sequences=10,
                    early_stopping=True,
                ).reshape(inputs['input_ids'].shape[0], 10, -1)
                for beams, label in zip(batch_beams, labels):
                    # Beam search should not return repeated docids, but, apparently due to the T5 tokenizer, there are some repeats.
                    rank_list = self.tokenizer.batch_decode(beams, skip_special_tokens=True)
                    hits = np.where(np.array(rank_list)[:10] == label)[0]
                    if len(hits) != 0:
                        hit_at_10 += 1
                        if hits[0] == 0:
                            hit_at_1 += 1
        self.logger.log({"Hits@1": hit_at_1 / len(self.test_dataset), "Hits@10": hit_at_10 / len(self.test_dataset)})


def compute_metrics(eval_preds):
    """Exact-match accuracy of the greedy docid predictions (i.e. Hits@1 under greedy decoding)."""
    num_predict = 0
    num_correct = 0
    # Predictions: [0 (decoder start), docid tokens..., 1 (EOS), padding...]
    # Labels:      [docid tokens..., 1 (EOS), -100...]
    for predict, label in zip(eval_preds.predictions, eval_preds.label_ids):
        num_predict += 1
        if len(np.where(predict == 1)[0]) == 0:
            continue  # no EOS generated: counted as wrong
        if np.array_equal(label[:np.where(label == 1)[0].item()],
                          predict[np.where(predict == 0)[0][0].item() + 1:np.where(predict == 1)[0].item()]):
            num_correct += 1

    return {'accuracy': num_correct / num_predict}


def main():
    model_name = "t5-large"
    L = 32  # only use the first 32 tokens of documents (including title)

    # We use wandb to log Hits scores after each epoch. Note: this script does not save model checkpoints.
    wandb.login()
    wandb.init(project="DSI", name='NQ-10k-t5-large')

    tokenizer = T5Tokenizer.from_pretrained(model_name, cache_dir='cache')
    model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir='cache')

    train_dataset = IndexingTrainDataset(path_to_data='data/NQ/NQ_10k_multi_task_train.json',
                                         max_length=L,
                                         cache_dir='cache',
                                         tokenizer=tokenizer)

    # This 'eval' set is actually the training file; it is only used to report whether the model can memorise (index) all training data points.
    eval_dataset = IndexingTrainDataset(path_to_data='data/NQ/NQ_10k_multi_task_train.json',
                                        max_length=L,
                                        cache_dir='cache',
                                        tokenizer=tokenizer)

    # This is the actual eval set.
    test_dataset = IndexingTrainDataset(path_to_data='data/NQ/NQ_10k_valid.json',
                                        max_length=L,
                                        cache_dir='cache',
                                        tokenizer=tokenizer)

    ################################################################
    # Docid generation constraint: we only generate integer docids.
    SPIECE_UNDERLINE = "▁"
    INT_TOKEN_IDS = []
    for token, token_id in tokenizer.get_vocab().items():
        # Word-initial digit pieces such as "▁123".
        if token[0] == SPIECE_UNDERLINE:
            if token[1:].isdigit():
                INT_TOKEN_IDS.append(token_id)
        # The bare "▁" piece, or digit-only continuation pieces such as "123".
        if token == SPIECE_UNDERLINE:
            INT_TOKEN_IDS.append(token_id)
        elif token.isdigit():
            INT_TOKEN_IDS.append(token_id)
    # EOS must be allowed so that generation can terminate.
    INT_TOKEN_IDS.append(tokenizer.eos_token_id)

    def restrict_decode_vocab(batch_idx, prefix_beam):
        # The same allowed set is returned at every step, regardless of the beam prefix.
        return INT_TOKEN_IDS
    ################################################################

    training_args = TrainingArguments(
        output_dir="./results",
        learning_rate=0.0005,
        warmup_steps=10000,
        # weight_decay=0.01,
        per_device_train_batch_size=128,
        per_device_eval_batch_size=128,
        evaluation_strategy='steps',
        eval_steps=1000,
        max_steps=1000000,
        dataloader_drop_last=False,  # necessary
        report_to='wandb',
        logging_steps=50,
        save_strategy='no',
        # fp16=True,  # gives 0/NaN loss at some point during training; seems to be a transformers bug.
        dataloader_num_workers=10,
        # gradient_accumulation_steps=2
    )

    trainer = IndexingTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=IndexingCollator(
            tokenizer,
            padding='longest',
        ),
        compute_metrics=compute_metrics,
        callbacks=[QueryEvalCallback(test_dataset, wandb, restrict_decode_vocab, training_args, tokenizer)],
        restrict_decode_vocab=restrict_decode_vocab
    )
    trainer.train()


if __name__ == "__main__":
    main()


================================================
FILE: trainer.py
================================================
from typing import Dict, List, Tuple, Optional, Any, Union
from transformers.trainer import Trainer
from torch import nn
import torch


class IndexingTrainer(Trainer):
    def __init__(self, restrict_decode_vocab, **kwds):
        super().__init__(**kwds)
        self.restrict_decode_vocab = restrict_decode_vocab

    def compute_loss(self, model, inputs, return_outputs=False):
        # Standard seq2seq cross-entropy loss on the docid labels.
        loss = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['labels']).loss
        if return_outputs:
            return loss, [None, None]  # fake outputs
        return loss

    def prediction_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            prediction_loss_only: bool,
            ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        model.eval()
        # eval_loss = super().prediction_step(model, inputs, True, ignore_keys)[0]
        with torch.no_grad():
            # Greedy search, restricted to docid tokens; compute_metrics compares the result with the labels.
            doc_ids = model.generate(
                inputs['input_ids'].to(self.args.device),
                max_length=20,
                prefix_allowed_tokens_fn=self.restrict_decode_vocab,
                early_stopping=True,
            )
        return (None, doc_ids, inputs['labels'])