Full Code of ArvinZhuang/DSI-transformers for AI

Repository: ArvinZhuang/DSI-transformers
Branch: main
Commit: 653b9d1fad20
Files: 7
Total size: 15.6 KB

Directory structure:
gitextract_c4yyz4mf/

├── .gitignore
├── LICENSE
├── README.md
├── data/
│   └── NQ/
│       └── create_NQ_train_vali.py
├── data.py
├── train.py
└── trainer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.DS_Store


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2022 Shengyao Zhuang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# DSI-transformers
A Hugging Face Transformers implementation of [Transformer Memory as a Differentiable Search Index](https://arxiv.org/abs/2202.06991), Yi Tay, Vinh Q. Tran, Mostafa Dehghani, Jianmo Ni, Dara Bahri, Harsh Mehta, Zhen Qin, Kai Hui, Zhe Zhao, Jai Gupta, Tal Schuster, William W. Cohen, Donald Metzler

Requirements: `python=3.8` `transformers=4.17.0` `datasets=1.18.3` `wandb`
> Note: This is not the official repository.

## Updates
- Check out our new repository for DSI training: https://github.com/ArvinZhuang/DSI-QG

## Goal of this repository
Reproduce the results of DSI Large with Naive String Docids on NQ10K. According to Table 3 of the original paper, the target scores are `Hits@1=0.347` and `Hits@10=0.605`.
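Hits@k is the fraction of validation queries whose gold docid appears among the top-k generated docids. `QueryEvalCallback.on_epoch_end` in `train.py` computes it from the 10 beams returned for each query. The sketch below reproduces that counting on plain lists of decoded docid strings; the function name `hits_at_k` and the toy rankings are illustrative only and are not part of this repository.

```python
from typing import Dict, Iterable, Sequence


def hits_at_k(rankings: Sequence[Sequence[str]],
              gold_docids: Sequence[str],
              ks: Iterable[int] = (1, 10)) -> Dict[str, float]:
    """Fraction of queries whose gold docid is ranked within the top k."""
    ks = tuple(ks)
    counts = {k: 0 for k in ks}
    for ranking, gold in zip(rankings, gold_docids):
        try:
            # position of the first occurrence of the gold docid in the beam list
            rank = list(ranking).index(gold)
        except ValueError:
            continue  # gold docid was not generated at all
        for k in ks:
            if rank < k:
                counts[k] += 1
    return {f"Hits@{k}": counts[k] / len(gold_docids) for k in ks}


# Toy example: three queries, each with a (short) ranked list of generated docids.
scores = hits_at_k([["3", "7", "1"], ["5", "2", "9"], ["4", "4", "8"]], ["3", "9", "0"])
```

Because `list.index` returns the first match, a ranking that repeats the gold docid (which the comment in `train.py` says can happen with the T5 tokenizer) is counted once at its best rank, mirroring how the callback uses `hits[0]`.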

### Step 1: Create the NQ10K training (indexing) and validation datasets

```
cd data/NQ
python3 create_NQ_train_vali.py
```
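Step 1 writes JSON Lines files: for each unique title, the non-HTML document tokens become one `document:` example in `NQ_10k_multi_task_train.json`, and the question becomes a `question:` example in either the training or the validation file. The sketch below replays that per-example logic on a made-up record; `build_jsonl_lines` is a hypothetical helper, and the real script reads `tokens`, `is_html` and the question text from the `natural_questions` dataset.

```python
import json
from typing import List, Sequence, Tuple


def build_jsonl_lines(docid: int,
                      tokens: Sequence[str],
                      is_html: Sequence[bool],
                      question: str,
                      is_train: bool) -> Tuple[List[str], List[str]]:
    """Return (train_lines, valid_lines) for one NQ example."""
    # drop HTML markup tokens and join the rest with spaces to form the document text
    doc_text = " ".join(tok for tok, html in zip(tokens, is_html) if not html)
    text_id = str(docid)
    # every document is indexed in the training file
    train_lines = [json.dumps({"text_id": text_id, "text": "document: " + doc_text})]
    query_line = json.dumps({"text_id": text_id, "text": "question: " + question})
    valid_lines: List[str] = []
    if is_train:
        train_lines.append(query_line)
    else:
        valid_lines.append(query_line)
    return train_lines, valid_lines


# Made-up toy record in the shape of an NQ example.
train, valid = build_jsonl_lines(
    0,
    ["<P>", "Paris", "is", "the", "capital", "</P>"],
    [True, False, False, False, False, True],
    "what is the capital of france",
    is_train=True,
)
```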

### Step 2: Run the training script
`cd` back to the repository root and run:

```
python3 train.py
```
Training fits on a single Tesla V100 32GB GPU. We use [wandb](https://wandb.ai/site) to log the Hits scores during training:

![.im](hits_plots.png)
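The Hits scores in these plots come from beam search in `QueryEvalCallback`, where `train.py` passes a `prefix_allowed_tokens_fn` that only permits tokens which can form an integer docid (plus EOS). The sketch below applies the same filter to a toy vocabulary; the vocabulary is hypothetical, whereas in `train.py` it comes from `tokenizer.get_vocab()` and the EOS id from `tokenizer.eos_token_id`.

```python
from typing import Callable, Dict, List

SPIECE_UNDERLINE = "\u2581"  # SentencePiece word-boundary marker


def integer_token_ids(vocab: Dict[str, int], eos_token_id: int) -> List[int]:
    """Collect ids of tokens that can appear in a purely numeric docid."""
    ids = []
    for token, idx in vocab.items():
        if token == SPIECE_UNDERLINE:
            ids.append(idx)  # bare word-boundary piece
        elif token.startswith(SPIECE_UNDERLINE) and token[1:].isdigit():
            ids.append(idx)  # digits at the start of a word, e.g. "\u258112"
        elif token.isdigit():
            ids.append(idx)  # digits continuing a word, e.g. "4"
    ids.append(eos_token_id)  # let the docid terminate
    return ids


def make_restrict_fn(allowed_ids: List[int]) -> Callable[[int, object], List[int]]:
    # generate() calls this with (batch_id, input_ids) at each step; returning the
    # same list every time gives a static, prefix-independent constraint
    def restrict_decode_vocab(batch_idx, prefix_beam):
        return allowed_ids
    return restrict_decode_vocab


# Hypothetical toy vocabulary.
toy_vocab = {"<pad>": 0, "</s>": 1, "\u2581": 2, "\u258112": 3, "4": 4, "\u2581the": 5, "ing": 6}
allowed = integer_token_ids(toy_vocab, eos_token_id=1)  # -> [2, 3, 4, 1]
```

Since the allowed set never depends on the prefix, this constraint keeps generation numeric but does not force it to produce a docid that actually exists; a prefix tree over valid docids would be a stricter alternative.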


### Discussion

As the plots above show, the current implementation performs worse than what is reported in the original paper. There are many possible reasons: the ratio of training to indexing examples (we use 1:1), the number of training steps, the way the document text is constructed, etc. That said, the scores already seem to be on par with BM25.
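To make the 1:1 ratio concrete: in `create_NQ_train_vali.py`, every unique document contributes one `document:` example to the training file, and its question goes to the training file for the first `NUM_TRAIN` documents and to the validation file afterwards. The arithmetic sketch below replays that loop without loading any data; `split_counts` is a hypothetical name.

```python
from typing import Dict


def split_counts(num_train: int, num_eval: int) -> Dict[str, int]:
    """Count examples per file, one iteration per unique title."""
    counts = {"train_indexing": 0, "train_queries": 0, "valid_queries": 0}
    for n_seen in range(1, num_train + num_eval + 1):  # n_seen mirrors len(title_set)
        counts["train_indexing"] += 1  # every document is indexed in the train file
        if n_seen <= num_train:
            counts["train_queries"] += 1  # the first num_train questions are for training
        else:
            counts["valid_queries"] += 1  # the remaining questions are held out
    return counts


print(split_counts(8000, 2000))
# {'train_indexing': 10000, 'train_queries': 8000, 'valid_queries': 2000}
```

So the ratio is 1:1 per training document, but the validation documents are indexed as well, which gives 10,000 indexing examples against 8,000 training queries in the training file.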

If you can identify the cause or find a bug, PRs with a fix are welcome!

#### Indexing or overfitting?
The training script also logs hit@1 on the training set during training. This is meant to check whether the model can memorize all training data points; the authors call this 'indexing', which I believe simply amounts to letting the model overfit the training data. It turns out the model reaches 99.99% hit@1 on the training set very quickly (i.e., it quickly overfits), yet the Hits scores on the test set keep going up. T5-large seems to generalize well on this task.
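The training-set hit@1 is what `compute_metrics` in `train.py` reports as `accuracy`: `IndexingTrainer.prediction_step` greedily decodes one docid per training example, and `compute_metrics` checks for an exact token match against the label. Below is a list-based sketch of that check, assuming T5's token ids (0 is both `<pad>` and the decoder start token, 1 is `</s>`); the docid token ids in the example are made up.

```python
from typing import Sequence

PAD_ID, EOS_ID = 0, 1  # T5: <pad> (also decoder start) = 0, </s> = 1


def docid_exact_match(predictions: Sequence[Sequence[int]],
                      labels: Sequence[Sequence[int]]) -> float:
    """Fraction of greedy predictions whose docid tokens equal the label tokens."""
    num_correct = 0
    for predict, label in zip(predictions, labels):
        predict, label = list(predict), list(label)
        if EOS_ID not in predict:
            continue  # generation never emitted EOS: counted as wrong
        # predicted docid: tokens after the decoder start token, up to the first EOS
        start = predict.index(PAD_ID) + 1
        pred_tokens = predict[start:predict.index(EOS_ID)]
        # gold docid: label tokens before EOS (padding after EOS is -100)
        gold_tokens = label[:label.index(EOS_ID)]
        if pred_tokens == gold_tokens:
            num_correct += 1
    return num_correct / len(predictions)


# Made-up token ids: one match, one wrong docid, one prediction without EOS.
preds = [[0, 209, 204, 1, 0], [0, 305, 1, -100, -100], [0, 209, 0, 0, 0]]
golds = [[209, 204, 1], [306, 1, -100], [209, 1, -100]]
accuracy = docid_exact_match(preds, golds)  # one of three predictions matches
```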


================================================
FILE: data/NQ/create_NQ_train_vali.py
================================================
import datasets
import random
import numpy as np
import json

random.seed(313)

NUM_TRAIN = 8000
NUM_EVAL = 2000

data = datasets.load_dataset('natural_questions', cache_dir='cache')['train']
rand_inds = list(range(len(data)))
random.shuffle(rand_inds)

title_set = set()
current_docid = 0

with open('NQ_10k_multi_task_train.json', 'w') as tf, \
        open('NQ_10k_valid.json', 'w') as vf:
    for ind in rand_inds:
        title = data[ind]['document']['title']  # use the title as the doc identifier so that no two docs with the same text are added
        if title not in title_set:
            title_set.add(title)
            token_inds = np.where(np.array(data[ind]['document']['tokens']['is_html']) == False)[0]
            tokens = np.array(data[ind]['document']['tokens']['token'])[token_inds]
            doc_text = " ".join(tokens)
            question_text = data[ind]['question']['text']

            jitem = json.dumps({'text_id': str(current_docid), 'text': 'document: ' + doc_text})
            tf.write(jitem + '\n')
            jitem = json.dumps({'text_id': str(current_docid), 'text': 'question: ' + question_text})
            if len(title_set) <= NUM_TRAIN:
                tf.write(jitem + '\n')
            else:
                vf.write(jitem + '\n')
            current_docid += 1
            if len(title_set) == NUM_TRAIN + NUM_EVAL:
                break
        print(f"Creating training and validation dataset: {'{:.1%}'.format(len(title_set)/(NUM_TRAIN + NUM_EVAL))}", end='\r')

================================================
FILE: data.py
================================================
from dataclasses import dataclass

import datasets
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, DataCollatorWithPadding


class IndexingTrainDataset(Dataset):
    def __init__(
            self,
            path_to_data,
            max_length: int,
            cache_dir: str,
            tokenizer: PreTrainedTokenizer,
    ):
        self.train_data = datasets.load_dataset(
            'json',
            data_files=path_to_data,
            ignore_verifications=False,
            cache_dir=cache_dir
        )['train']

        self.max_length = max_length
        self.tokenizer = tokenizer
        self.total_len = len(self.train_data)


    def __len__(self):
        return self.total_len

    def __getitem__(self, item):
        data = self.train_data[item]

        input_ids = self.tokenizer(data['text'],
                                   return_tensors="pt",
                                   truncation='only_first',
                                   max_length=self.max_length).input_ids[0]
        return input_ids, str(data['text_id'])


@dataclass
class IndexingCollator(DataCollatorWithPadding):
    def __call__(self, features):
        input_ids = [{'input_ids': x[0]} for x in features]
        docids = [x[1] for x in features]
        inputs = super().__call__(input_ids)

        labels = self.tokenizer(
            docids, padding="longest", return_tensors="pt"
        ).input_ids

        # replace padding token id's of the labels by -100 according to https://huggingface.co/docs/transformers/model_doc/t5#training
        labels[labels == self.tokenizer.pad_token_id] = -100
        inputs['labels'] = labels
        return inputs


@dataclass
class QueryEvalCollator(DataCollatorWithPadding):
    def __call__(self, features):
        input_ids = [{'input_ids': x[0]} for x in features]
        labels = [x[1] for x in features]
        inputs = super().__call__(input_ids)

        return inputs, labels

================================================
FILE: train.py
================================================
from data import IndexingTrainDataset, IndexingCollator, QueryEvalCollator
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, TrainerCallback
from trainer import IndexingTrainer
import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm


class QueryEvalCallback(TrainerCallback):
    def __init__(self, test_dataset, logger, restrict_decode_vocab, args: TrainingArguments, tokenizer: T5Tokenizer):
        self.tokenizer = tokenizer
        self.logger = logger
        self.args = args
        self.test_dataset = test_dataset
        self.restrict_decode_vocab = restrict_decode_vocab
        self.dataloader = DataLoader(
            test_dataset,
            batch_size=self.args.per_device_eval_batch_size,
            collate_fn=QueryEvalCollator(
                self.tokenizer,
                padding='longest'
            ),
            shuffle=False,
            drop_last=False,
            num_workers=self.args.dataloader_num_workers,
        )

    def on_epoch_end(self, args, state, control, **kwargs):
        hit_at_1 = 0
        hit_at_10 = 0
        model = kwargs['model'].eval()
        for batch in tqdm(self.dataloader, desc='Evaluating dev queries'):
            inputs, labels = batch
            with torch.no_grad():
                batch_beams = model.generate(
                    inputs['input_ids'].to(model.device),
                    max_length=20,
                    num_beams=10,
                    prefix_allowed_tokens_fn=self.restrict_decode_vocab,
                    num_return_sequences=10,
                    early_stopping=True, ).reshape(inputs['input_ids'].shape[0], 10, -1)
                for beams, label in zip(batch_beams, labels):
                    # beam search should not return repeated docids, but some repeats occur, apparently due to the T5 tokenizer.
                    rank_list = self.tokenizer.batch_decode(beams, skip_special_tokens=True)
                    hits = np.where(np.array(rank_list)[:10] == label)[0]
                    if len(hits) != 0:
                        hit_at_10 += 1
                        if hits[0] == 0:
                            hit_at_1 += 1
        self.logger.log({"Hits@1": hit_at_1 / len(self.test_dataset), "Hits@10": hit_at_10 / len(self.test_dataset)})


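def compute_metrics(eval_preds):  # exact-match accuracy of greedily decoded docids (T5 ids: 0 = pad/decoder start, 1 = eos)
    num_predict = 0
    num_correct = 0
    for predict, label in zip(eval_preds.predictions, eval_preds.label_ids):
        num_predict += 1
        if len(np.where(predict == 1)[0]) == 0:
            continue  # no eos generated: counted as incorrect
        if np.array_equal(label[:np.where(label == 1)[0].item()],  # label tokens before eos
                          predict[np.where(predict == 0)[0][0].item() + 1:np.where(predict == 1)[0].item()]):  # predicted tokens between decoder start and eos
            num_correct += 1

    return {'accuracy': num_correct / num_predict}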


def main():
    model_name = "t5-large"
    L = 32  # only use the first 32 tokens of documents (including title)

    # We use wandb to log Hits scores after each epoch. Note, this script does not save model checkpoints.
    wandb.login()
    wandb.init(project="DSI", name='NQ-10k-t5-large')

    tokenizer = T5Tokenizer.from_pretrained(model_name, cache_dir='cache')
    model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir='cache')

    train_dataset = IndexingTrainDataset(path_to_data='data/NQ/NQ_10k_multi_task_train.json',
                                         max_length=L,
                                         cache_dir='cache',
                                         tokenizer=tokenizer)
    
    # This 'eval' set is actually the training data; it is only used to report whether the model can memorise (index) all training data points.
    eval_dataset = IndexingTrainDataset(path_to_data='data/NQ/NQ_10k_multi_task_train.json',
                                        max_length=L,
                                        cache_dir='cache',
                                        tokenizer=tokenizer)
    
    # This is the actual eval set.
    test_dataset = IndexingTrainDataset(path_to_data='data/NQ/NQ_10k_valid.json',
                                        max_length=L,
                                        cache_dir='cache',
                                        tokenizer=tokenizer)

    ################################################################
    # docid generation constraint: we only generate integer docids.
    SPIECE_UNDERLINE = "▁"
    INT_TOKEN_IDS = []
    for token, id in tokenizer.get_vocab().items():
        if token[0] == SPIECE_UNDERLINE:
            if token[1:].isdigit():
                INT_TOKEN_IDS.append(id)
        if token == SPIECE_UNDERLINE:
            INT_TOKEN_IDS.append(id)
        elif token.isdigit():
            INT_TOKEN_IDS.append(id)
    INT_TOKEN_IDS.append(tokenizer.eos_token_id)

    def restrict_decode_vocab(batch_idx, prefix_beam):
        return INT_TOKEN_IDS
    ################################################################

    training_args = TrainingArguments(
        output_dir="./results",
        learning_rate=0.0005,
        warmup_steps=10000,
        # weight_decay=0.01,
        per_device_train_batch_size=128,
        per_device_eval_batch_size=128,
        evaluation_strategy='steps',
        eval_steps=1000,
        max_steps=1000000,
        dataloader_drop_last=False,  # necessary
        report_to='wandb',
        logging_steps=50,
        save_strategy='no',
        # fp16=True,  # gives 0/nan loss at some point during training, seems this is a transformers bug.
        dataloader_num_workers=10,
        # gradient_accumulation_steps=2
    )

    trainer = IndexingTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=IndexingCollator(
            tokenizer,
            padding='longest',
        ),
        compute_metrics=compute_metrics,
        callbacks=[QueryEvalCallback(test_dataset, wandb, restrict_decode_vocab, training_args, tokenizer)],
        restrict_decode_vocab=restrict_decode_vocab
    )
    trainer.train()


if __name__ == "__main__":
    main()


================================================
FILE: trainer.py
================================================
from typing import Dict, List, Tuple, Optional, Any, Union
from transformers.trainer import Trainer
from torch import nn
import torch


class IndexingTrainer(Trainer):
    def __init__(self, restrict_decode_vocab, **kwds):
        super().__init__(**kwds)
        self.restrict_decode_vocab = restrict_decode_vocab

    def compute_loss(self, model, inputs, return_outputs=False):
        loss = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['labels']).loss
        if return_outputs:
            return loss, [None, None]  # fake outputs
        return loss

    def prediction_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            prediction_loss_only: bool,
            ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        model.eval()
        # eval_loss = super().prediction_step(model, inputs, True, ignore_keys)[0]
        with torch.no_grad():
            # greedy search
            doc_ids = model.generate(
                inputs['input_ids'].to(self.args.device),
                max_length=20,
                prefix_allowed_tokens_fn=self.restrict_decode_vocab,
                early_stopping=True,)
        return (None, doc_ids, inputs['labels'])

SYMBOL INDEX (17 symbols across 3 files)

FILE: data.py
  class IndexingTrainDataset (line 8) | class IndexingTrainDataset(Dataset):
    method __init__ (line 9) | def __init__(
    method __len__ (line 28) | def __len__(self):
    method __getitem__ (line 31) | def __getitem__(self, item):
  class IndexingCollator (line 42) | class IndexingCollator(DataCollatorWithPadding):
    method __call__ (line 43) | def __call__(self, features):
  class QueryEvalCollator (line 59) | class QueryEvalCollator(DataCollatorWithPadding):
    method __call__ (line 60) | def __call__(self, features):

FILE: train.py
  class QueryEvalCallback (line 11) | class QueryEvalCallback(TrainerCallback):
    method __init__ (line 12) | def __init__(self, test_dataset, logger, restrict_decode_vocab, args: ...
    method on_epoch_end (line 30) | def on_epoch_end(self, args, state, control, **kwargs):
  function compute_metrics (line 55) | def compute_metrics(eval_preds):
  function main (line 69) | def main():

FILE: trainer.py
  class IndexingTrainer (line 7) | class IndexingTrainer(Trainer):
    method __init__ (line 8) | def __init__(self, restrict_decode_vocab, **kwds):
    method compute_loss (line 12) | def compute_loss(self, model, inputs, return_outputs=False):
    method prediction_step (line 18) | def prediction_step(

About this extraction

This page contains the full source code of the ArvinZhuang/DSI-transformers GitHub repository, extracted and formatted as plain text. The extraction includes 7 files (15.6 KB), approximately 3.9k tokens, and a symbol index with 17 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.