Repository: stanford-crfm/pubmedgpt
Branch: main
Commit: 9e35fddada3e
Files: 36
Total size: 214.5 KB

Directory structure:
gitextract_dlrbq2_y/

├── README.md
├── demo.py
├── finetune/
│   ├── README.md
│   ├── deepspeed/
│   │   └── cpu_offload.json
│   ├── mc/
│   │   ├── README.md
│   │   ├── data/
│   │   │   └── medqa_usmle_hf/
│   │   │       ├── dev.json
│   │   │       ├── test.json
│   │   │       └── train.json
│   │   ├── preprocess_medqa.py
│   │   ├── run_experiments.py
│   │   └── run_multiple_choice.py
│   ├── seqcls/
│   │   ├── README.md
│   │   ├── data/
│   │   │   ├── bioasq_hf/
│   │   │   │   ├── dev.json
│   │   │   │   ├── test.json
│   │   │   │   └── train.json
│   │   │   └── pubmedqa_hf/
│   │   │       ├── dev.json
│   │   │       ├── test.json
│   │   │       └── train.json
│   │   ├── preprocess_blurb_seqcls.py
│   │   └── run_seqcls_gpt.py
│   ├── setup/
│   │   └── requirements.txt
│   ├── textgen/
│   │   ├── data/
│   │   │   └── meqsum/
│   │   │       ├── test.source
│   │   │       ├── test.target
│   │   │       ├── train.source
│   │   │       ├── train.target
│   │   │       ├── val.source
│   │   │       └── val.target
│   │   └── gpt2/
│   │       ├── finetune_for_summarization.py
│   │       ├── generate_demo.py
│   │       ├── run_generation_batch.py
│   │       ├── sum_data_collator.py
│   │       └── sum_dataset.py
│   └── utils/
│       ├── custom_modeling_gpt2.py
│       ├── custom_modeling_gpt_neo.py
│       └── hf_flash_gpt_2.py
└── tokenize/
    └── train_bpe.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# BioMedLM

Code used for pre-training and fine-tuning the [BioMedLM](https://huggingface.co/stanford-crfm/pubmedgpt) model.

Note: This model was previously known as PubMedGPT. The NIH holds the trademark on "PubMed" and asked us to change the name, so the model is now called BioMedLM.

### Links

[Blog](https://crfm.stanford.edu/2022/12/15/pubmedgpt.html)

[Model](https://huggingface.co/stanford-crfm/pubmedgpt/tree/main)

[MosaicML Composer](https://github.com/mosaicml/composer)

### Example Usage

```python
import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda")

tokenizer = GPT2Tokenizer.from_pretrained("stanford-crfm/BioMedLM")

model = GPT2LMHeadModel.from_pretrained("stanford-crfm/BioMedLM").to(device)

input_ids = tokenizer.encode(
    "Photosynthesis is ", return_tensors="pt"
).to(device)

sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=50)

print("Output:\n" + 100 * "-")
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
```
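The example calls `model.generate(..., do_sample=True, top_k=50)`. To show roughly what `top_k` restricts, the stdlib sketch below samples one token index from only the `k` highest logits. `top_k_sample` is a hypothetical helper working on a plain Python list; it is an illustration of top-k sampling, not the tensor-based code path inside `transformers`.

```python
import math
import random


def top_k_sample(logits, k, temperature=1.0, rng=random):
    """Sample a token index from the k highest-scoring logits (illustrative only)."""
    if k <= 0:
        raise ValueError("k must be positive")
    # Keep only the indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the surviving logits; subtracting the max keeps exp() stable.
    scaled = [logits[i] / temperature for i in top]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    # random.choices normalizes the weights itself.
    return rng.choices(top, weights=weights, k=1)[0]


if __name__ == "__main__":
    vocab_logits = [2.0, 0.5, 3.1, -1.0, 1.7]
    rng = random.Random(0)
    # With k=2, only indices 2 and 0 (the two largest logits) can ever be drawn.
    print({top_k_sample(vocab_logits, k=2, rng=rng) for _ in range(100)})
```

Every token outside the top `k` gets probability zero, while the relative odds among the survivors are unchanged.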


================================================
FILE: demo.py
================================================
import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda")

tokenizer = GPT2Tokenizer.from_pretrained("stanford-crfm/pubmed_gpt_tokenizer")

model = GPT2LMHeadModel.from_pretrained("stanford-crfm/pubmedgpt").to(device)

input_ids = tokenizer.encode(
    "Photosynthesis is ", return_tensors="pt"
).to(device)

sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=50)

print("Output:\n" + 100 * "-")
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))


================================================
FILE: finetune/README.md
================================================
# Biomedical downstream evaluation

## NLU
### Dependencies
```bash
conda create -n pubmedgpt python=3.8.12 pytorch=1.12.1 torchdata cudatoolkit=11.3 -c pytorch
conda activate pubmedgpt
pip install -r setup/requirements.txt
```

### Usage

Note that we do not provide the data. The demo `train.json`, `dev.json`, and `test.json` files under each task's `data/` directory only show the expected format: JSON Lines, with one JSON object per line for each example in the respective dataset.
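Each example therefore has to be serialized as a single-line JSON object. A minimal stdlib sketch of writing and reading that format follows. `write_jsonl` and `read_jsonl` are hypothetical helpers, and the record uses the multiple-choice field names from the MedQA demo files with placeholder values.

```python
import json
import os
import tempfile


def write_jsonl(records, path):
    """Write each record as one JSON object on its own line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")


def read_jsonl(path):
    """Read a JSON Lines file back into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


if __name__ == "__main__":
    # Hypothetical placeholder example in the multiple-choice layout.
    example = {
        "id": "train-00000",
        "sent1": "Question text goes here.",
        "sent2": "",
        "ending0": "option A",
        "ending1": "option B",
        "ending2": "option C",
        "ending3": "option D",
        "label": 2,
    }
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "train.json")
        write_jsonl([example, example], path)
        print(read_jsonl(path) == [example, example])  # True
```

`json.dumps` escapes embedded newlines, so multi-line question text still occupies exactly one line in the file.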

For PubMedQA and BioASQ, go to `seqcls/` and run the following command (set `task` to `pubmedqa_hf` or `bioasq_hf` and fill in the `{...}` placeholders):
```bash
task=pubmedqa_hf
datadir=data/$task
outdir=runs/$task/GPT2
mkdir -p $outdir
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 run_seqcls_gpt.py \
  --tokenizer_name stanford-crfm/pubmed_gpt_tokenizer --model_name_or_path {checkpoint} --train_file \
  $datadir/train.json --validation_file $datadir/dev.json --test_file $datadir/test.json --do_train \
  --do_eval --do_predict --per_device_train_batch_size 1 --gradient_accumulation_steps \
  {grad_accum} --learning_rate {lr} --warmup_ratio 0.5 --num_train_epochs {num_epochs} --max_seq_length \
  {seq_len} --logging_steps 100 --save_strategy no --evaluation_strategy no --output_dir \
  {run_dir} --overwrite_output_dir --bf16 \
  --seed {seed} --run_name {name}
```
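With `--per_device_train_batch_size 1`, the batch size each optimizer update sees is set by the number of processes and `--gradient_accumulation_steps`, and `--warmup_ratio 0.5` spends half of all updates warming up the learning rate. The sketch below is a rough approximation of the Hugging Face Trainer's step counting under data-parallel training; the function names are hypothetical and exact rounding may differ across `transformers` versions.

```python
import math


def effective_batch_size(num_devices, per_device_batch_size, grad_accum):
    # Same product that run_experiments.py puts in the run name as batch_size.
    return num_devices * per_device_batch_size * grad_accum


def approx_schedule(num_examples, num_devices, per_device_batch_size,
                    grad_accum, num_epochs, warmup_ratio):
    """Approximate (total optimizer steps, warmup steps); an assumption, not Trainer code."""
    # Each rank sees roughly ceil(N / world_size) examples per epoch.
    per_rank_examples = math.ceil(num_examples / num_devices)
    batches_per_rank = math.ceil(per_rank_examples / per_device_batch_size)
    # One optimizer update every grad_accum micro-batches, at least one per epoch.
    updates_per_epoch = max(batches_per_rank // grad_accum, 1)
    total_steps = math.ceil(num_epochs * updates_per_epoch)
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return total_steps, warmup_steps


if __name__ == "__main__":
    # Hypothetical numbers: 1,000 training examples, 8 GPUs, 1 example per GPU,
    # 4 accumulation steps, 3 epochs, and the --warmup_ratio 0.5 used above.
    print(effective_batch_size(8, 1, 4))           # 32
    print(approx_schedule(1000, 8, 1, 4, 3, 0.5))  # (93, 47)
```

Because warmup is a ratio, changing `{grad_accum}` or the number of GPUs also changes how many steps are spent warming up.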


For MedQA-USMLE, go to `mc/` and run the following command:
```bash
task=medqa_usmle_hf
datadir=data/$task
outdir=runs/$task/GPT2
mkdir -p $outdir
python -m torch.distributed.launch --nproc_per_node={num_devices} --nnodes=1 --node_rank=0 \
  run_multiple_choice.py --tokenizer_name stanford-crfm/pubmed_gpt_tokenizer --model_name_or_path \
  {checkpoint} --train_file data/medqa_usmle_hf/train.json --validation_file data/medqa_usmle_hf/dev.json \
  --test_file data/medqa_usmle_hf/test.json --do_train --do_eval --do_predict --per_device_train_batch_size \
  {train_per_device_batch_size} --per_device_eval_batch_size 1 --gradient_accumulation_steps {grad_accum} \
  --learning_rate {lr} --warmup_ratio 0.5 --num_train_epochs {epochs} --max_seq_length 512 \
  --{numerical_format} --seed {seed} --data_seed {seed} --logging_first_step --logging_steps 20 \
  --save_strategy no --evaluation_strategy steps --eval_steps 500 --run_name {run_name} \
  --output_dir trash/ \
  --overwrite_output_dir 
```
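Each MedQA question is scored as several separate (context, answer) sequences, one per option, each truncated to `--max_seq_length`. The stdlib sketch below reproduces on plain strings the pairing that `preprocess_function` in `run_multiple_choice.py` builds before tokenization when the model type is `gpt2`. The function name `expand_choices` is hypothetical, and tokenization and truncation are left out.

```python
def expand_choices(example, num_choices, sep_token="[SEP]"):
    """Return one (first, second) text pair per answer option.

    The context (sent1) is repeated for every option, the option text is
    prefixed with sent2 and a space, and both halves end with sep_token.
    """
    pairs = []
    for i in range(num_choices):
        ending = example["ending" + str(i)]
        first = example["sent1"] + sep_token
        second = f"{example['sent2']} {ending}" + sep_token
        pairs.append((first, second))
    return pairs


if __name__ == "__main__":
    # Hypothetical two-option example; MedQA-USMLE uses four options.
    demo = {"sent1": "Which drug is first-line?", "sent2": "", "ending0": "A", "ending1": "B"}
    for first, second in expand_choices(demo, num_choices=2):
        print(repr(first), repr(second))
```

Since `sent2` is empty in the MedQA files, every answer string starts with a single space, which matches how the GPT-2 tokenizer encodes a word that follows other text.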

## NLG
Go to `./textgen`.

### Usage (seq2seq tasks)
Make sure the task dataset is in `./textgen/data`; see `meqsum` (a medical question summarization task) for an example. The dataset folder should contain `<split>.source` and `<split>.target` files. The `.source` file holds the original text, one example per line (e.g., the user's full original question in MeQSum), and the `.target` file holds the desired output, one example per line in the same order (e.g., the summary of that question). This setup can be adapted to new tasks: for instance, you could place biomedical articles in the source files and brief summaries in the target files.
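Because examples are matched by line number, the two files of a split must have the same number of lines and no example may contain a raw newline. A minimal stdlib sketch of loading and checking one split follows; `load_seq2seq_split` is a hypothetical helper, not the loader in `sum_dataset.py`.

```python
import os
import tempfile


def load_seq2seq_split(data_dir, split):
    """Pair line i of <split>.source with line i of <split>.target."""
    def read_lines(path):
        with open(path, encoding="utf-8") as f:
            # File iteration splits on line endings only, not on the extra
            # Unicode separators that str.splitlines() also honors.
            return [line.rstrip("\n") for line in f]

    sources = read_lines(os.path.join(data_dir, f"{split}.source"))
    targets = read_lines(os.path.join(data_dir, f"{split}.target"))
    if len(sources) != len(targets):
        raise ValueError(f"{split}: {len(sources)} source lines vs {len(targets)} target lines")
    return list(zip(sources, targets))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        with open(os.path.join(tmp, "val.source"), "w", encoding="utf-8") as f:
            f.write("first long question\nsecond long question\n")
        with open(os.path.join(tmp, "val.target"), "w", encoding="utf-8") as f:
            f.write("first summary\nsecond summary\n")
        for source, target in load_seq2seq_split(tmp, "val"):
            print(source, "->", target)
```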

Go to `./textgen/gpt2`.
To finetune, run:
```bash
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 \
  finetune_for_summarization.py --output_dir {run_dir} --model_name_or_path {checkpoint} \
  --tokenizer_name stanford-crfm/pubmed_gpt_tokenizer --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 1 --save_strategy no --do_eval --train_data_file \
  data/meqsum/train.source --eval_data_file data/meqsum/val.source --save_total_limit 2 \
  --overwrite_output_dir --gradient_accumulation_steps {grad_accum} --learning_rate {lr} \
  --warmup_ratio 0.5 --weight_decay 0.0 --seed 7 --evaluation_strategy steps --eval_steps 200 \
  --bf16 --num_train_epochs {num_epochs} --logging_steps 100 --logging_first_step
```

After finetuning, run generation on the test set by:

```bash
CUDA_VISIBLE_DEVICES=0 python -u run_generation_batch.py --fp16 --max_source_length -1 --length 400 \
  --model_name_or_path={finetune_checkpoint} --num_return_sequences 5 --stop_token '[SEP]' \
  --tokenizer_name={finetune_checkpoint} --task_mode=meqsum --control_mode=no --tuning_mode finetune \
  --gen_dir gen_results__tgtlen400__no_repeat_ngram_size6 --batch_size 9 --temperature 1.0 \
  --no_repeat_ngram_size 6 --length_penalty -0.5 --wandb_entity=None --wandb_project=None \
  --wandb_run_name=None
```
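The `--no_repeat_ngram_size 6` flag refers to the standard no-repeat n-gram constraint: at each step, any token that would recreate an n-gram already present in the output is disallowed. A minimal stdlib sketch over integer token IDs follows. `banned_next_tokens` is hypothetical, and this is not the code `run_generation_batch.py` executes; that script is assumed to forward the flag to its generation call.

```python
def banned_next_tokens(tokens, n):
    """Return the token IDs that would complete an n-gram already in `tokens`."""
    # Too few tokens for any full n-gram to repeat yet.
    if n <= 0 or len(tokens) + 1 < n:
        return set()
    # The last n-1 tokens form the prefix of the n-gram we are about to complete.
    prefix = tuple(tokens[len(tokens) - (n - 1):])
    banned = set()
    for i in range(len(tokens) - n + 1):
        # Every earlier occurrence of the prefix bans the token that followed it.
        if tuple(tokens[i:i + n - 1]) == prefix:
            banned.add(tokens[i + n - 1])
    return banned


if __name__ == "__main__":
    generated = [5, 6, 7, 5, 6]
    # "5 6 7" already occurred, so emitting 7 now would repeat a trigram.
    print(banned_next_tokens(generated, 3))  # {7}
```

A large `n` such as 6 only blocks long verbatim repeats, which leaves room for the short medical phrases a summary legitimately reuses.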


### Acknowledgement
The NLG part of the code was built on https://github.com/XiangLi1999/PrefixTuning


================================================
FILE: finetune/deepspeed/cpu_offload.json
================================================
{
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 2e-06,
      "betas": [
        0.9,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 0.0
    }
  },

  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "total_num_steps": "auto",
      "warmup_max_lr": 2e-06,
      "warmup_num_steps": "auto"
    }
  },

  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "cpu_offload": true
  },
  
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",

  "fp16": {
   "enabled": true
  }

}


================================================
FILE: finetune/mc/README.md
================================================
## Setting Up MedQA

1.) Download the data from the [MedQA GitHub repository](https://github.com/jind11/MedQA), which links to a Google Drive folder. Download its contents to `raw_data/medqa` in this directory. Review `preprocess_medqa.py` for the exact paths the preprocessing script expects; for example, `raw_data/medqa/data_clean/questions/US/4_options` should exist when the original data is set up properly.

2.) Run `preprocess_medqa.py` in this directory to convert the data into the format expected by our fine-tuning code. It writes `train.json`, `dev.json`, and `test.json` (JSON Lines, one example per line) to `data/medqa_usmle_hf`.
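For reference, the per-question conversion that `preprocess_medqa.py` performs can be sketched as a standalone function. The function name is hypothetical; the raw keys (`question`, `options`, `answer_idx`), the output fields, and the letter-to-index mapping come from the script.

```python
def to_multiple_choice_record(raw, split, index):
    """Convert one raw MedQA question into the layout of data/medqa_usmle_hf/*.json."""
    endings = [raw["options"][key] for key in "ABCD"]
    record = {"id": f"{split}-{index:05d}", "sent1": raw["question"], "sent2": ""}
    for i, ending in enumerate(endings):
        record[f"ending{i}"] = ending
    # The answer letter "A".."D" becomes the integer label 0..3.
    record["label"] = ord(raw["answer_idx"]) - ord("A")
    return record


if __name__ == "__main__":
    # Hypothetical raw record with the keys the script reads.
    raw = {
        "question": "Question text goes here.",
        "options": {"A": "first", "B": "second", "C": "third", "D": "fourth"},
        "answer_idx": "C",
    }
    print(to_multiple_choice_record(raw, "dev", 7))
```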


================================================
FILE: finetune/mc/data/medqa_usmle_hf/dev.json
================================================
{"id": "id", "sent1": "passage and question ...", "sent2": "", "ending0": "answer 0", "ending1": "answer 1", "ending2": "answer 2", "ending3": "answer 3", "label": "int of correct answer"}


================================================
FILE: finetune/mc/data/medqa_usmle_hf/test.json
================================================
{"id": "id", "sent1": "passage and question ...", "sent2": "", "ending0": "answer 0", "ending1": "answer 1", "ending2": "answer 2", "ending3": "answer 3", "label": "int of correct answer"}


================================================
FILE: finetune/mc/data/medqa_usmle_hf/train.json
================================================
{"id": "id", "sent1": "passage and question ...", "sent2": "", "ending0": "answer 0", "ending1": "answer 1", "ending2": "answer 2", "ending3": "answer 3", "label": "int of correct answer"}


================================================
FILE: finetune/mc/preprocess_medqa.py
================================================
import os
import json
import numpy as np
from tqdm import tqdm


root = "data"
os.makedirs(root, exist_ok=True)


def dump_jsonl(data, fpath):
    with open(fpath, "w") as outf:
        for d in data:
            print (json.dumps(d), file=outf)

def process_medqa(fname):
    dname = "medqa_usmle"
    with open(f"raw_data/medqa/data_clean/questions/US/4_options/phrases_no_exclude_{fname}.jsonl") as f:
        lines = f.readlines()
    outs, lens = [], []
    for i, line in enumerate(tqdm(lines)):
        stmt = json.loads(line)
        sent1 = stmt["question"]
        ends = [stmt["options"][key] for key in "ABCD"]
        outs.append({"id": f"{fname}-{i:05d}",
                      "sent1": sent1,
                      "sent2": "",
                      "ending0": ends[0],
                      "ending1": ends[1],
                      "ending2": ends[2],
                      "ending3": ends[3],
                      "label": ord(stmt["answer_idx"]) - ord("A")
                    })
        lens.append(len(sent1) + max([len(ends[0]),len(ends[1]), len(ends[2]), len(ends[3])]))
    print ("total", len(outs), "seqlen mean", int(np.mean(lens)), "median", int(np.median(lens)), "95th", int(np.percentile(lens, 95)), "max", np.max(lens))
    os.makedirs(f"{root}/{dname}_hf", exist_ok=True)
    dump_jsonl(outs, f"{root}/{dname}_hf/{fname}.json")


process_medqa("train")
process_medqa("test")
process_medqa("dev")


================================================
FILE: finetune/mc/run_experiments.py
================================================
import json
import os
import subprocess
import sys

env_setup_cmd = "task=medqa_usmle_hf ; datadir=data/$task ; export WANDB_PROJECT='biomedical-nlp-eval'"

experiments = [json.loads(line) for line in open(sys.argv[1]).read().split("\n") if line]

for experiment in experiments:
    checkpoint = experiment["checkpoint"]
    lr = experiment["lr"]
    epochs = experiment["epochs"]
    grad_accum = experiment["grad_accum"]
    train_per_device_batch_size = experiment["train_per_device_batch_size"]
    num_devices = experiment["num_devices"] if "num_devices" in experiment else 8
    batch_size = int(num_devices) * int(grad_accum) * int(train_per_device_batch_size)
    tokenizer = experiment["tokenizer"]
    numerical_format = experiment["numerical"] if "numerical" in experiment else "bf16"
    seed = experiment["seed"]
    use_flash = experiment["use_flash"]
    run_name = f"{os.path.basename(checkpoint)}-lr={lr}-batch_size={batch_size}-epochs={epochs}-seed={seed}-task=medqa"
    exp_cmd = (
        f"python -m torch.distributed.launch --nproc_per_node={num_devices} --nnodes=1 --node_rank=0"
        f" run_multiple_choice.py --use_flash {use_flash} --tokenizer_name {tokenizer} --model_name_or_path"
        f" {checkpoint} --train_file data/medqa_usmle_hf/train.json --validation_file data/medqa_usmle_hf/dev.json"
        " --test_file data/medqa_usmle_hf/test.json --do_train --do_eval --do_predict --per_device_train_batch_size"
        f" {train_per_device_batch_size} --per_device_eval_batch_size 1 --gradient_accumulation_steps {grad_accum}"
        f" --learning_rate {lr} --warmup_ratio 0.5 --num_train_epochs {epochs} --max_seq_length 512"
        f" --{numerical_format} --seed {seed} --data_seed {seed} --logging_first_step --logging_steps 20"
        f" --save_strategy no --evaluation_strategy steps --eval_steps 500 --run_name {run_name} "
        " --output_dir trash/"
        " --overwrite_output_dir"
    )
    if "sharded_ddp" in experiment and experiment["sharded_ddp"].lower() == "true":
        exp_cmd += " --sharded_ddp zero_dp_2 "
    print("---")
    print(exp_cmd)
    subprocess.call(f"{env_setup_cmd} ; {exp_cmd}", shell=True)


================================================
FILE: finetune/mc/run_multiple_choice.py
================================================
#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for multiple choice.

https://github.com/huggingface/transformers/blob/bff1c71e84e392af9625c345f9ea71f7b6d75fb3/examples/pytorch/multiple-choice/run_swag.py
"""
# You can also adapt this script on your own multiple choice task. Pointers for this are left as comments.

import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional, Union

import datasets
import numpy as np
import torch
from datasets import load_dataset

import transformers
from transformers import (
    AutoConfig,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version

sys.path.insert(0, '..')
from utils.custom_modeling_gpt2 import GPT2ForMultipleChoice


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
# check_min_version("4.9.0")

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
    use_flash: bool = field(
        default=False,
        metadata={"help": "Whether to use the flash-attention GPT-2 variant from utils/hf_flash_gpt_2.py."},
    )
    use_gpt_neo: bool = field(
        default=False,
        metadata={"help": "Whether to load the model through the GPT-Neo code in utils/custom_modeling_gpt_neo.py."},
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a csv or json file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file (a csv or json file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file (a csv or json file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    # num_choices: int = field(
    #     default=4,
    #     metadata={"help": "Number of choices in multiple-choice QA."},
    # )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total input sequence length after tokenization. If passed, sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to the maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )

    def __post_init__(self):
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        if self.test_file is not None:
            extension = self.test_file.split(".")[-1]
            assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."

@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.
    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [int(feature.pop(label_name)) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        flattened_features = [
            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
        ]
        flattened_features = sum(flattened_features, [])

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Un-flatten
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        # Add back labels
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: provide your own CSV/JSON training, validation and test files (see below), or, if no files
    # are given, the SWAG dataset is downloaded automatically from the Hugging Face datasets Hub.

    # For CSV/JSON files, this script expects the columns `sent1`, `sent2`, `ending0`, `ending1`, ... and `label`
    # (see preprocess_function below).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.train_file is not None or data_args.validation_file is not None:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
        extension = (data_args.train_file or data_args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    else:
        # Downloading and loading the swag dataset from the hub.
        raw_datasets = load_dataset("swag", "regular", cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    config.use_flash = model_args.use_flash
    config.use_gpt_neo = model_args.use_gpt_neo
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    #Added for GPT2
    if config.model_type in ("gpt2", "gpt_neo"):
        model_class = GPT2ForMultipleChoice
    else:
        model_class = AutoModelForMultipleChoice

    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    #Added for GPT2
    if tokenizer.pad_token_id is None:
        print('Adding [PAD], [CLS] and [SEP] tokens to the tokenizer and resizing the model word embeddings.')
        num_added_tokens = tokenizer.add_special_tokens({'pad_token': '[PAD]', 'cls_token': '[CLS]', 'sep_token': '[SEP]'})
        embedding_layer = model.resize_token_embeddings(len(tokenizer))
        config.pad_token_id = tokenizer.pad_token_id



    # When using your own dataset or a different dataset from swag, you will probably need to change this.
    _num_choices = len([elm for elm in raw_datasets['train'].features.keys() if elm.startswith('ending')])
    print ('\nnum_choices according to dataset:', _num_choices, '\n')
    # raw_datasets['train'].features: {'id': Value(dtype='int64', id=None), 'sent1': Value(dtype='string', id=None), 'sent2': Value(dtype='string', id=None), 'ending0': Value(dtype='string', id=None), 'ending1': Value(dtype='string', id=None), 'ending2': Value(dtype='string', id=None), 'ending3': Value(dtype='string', id=None), 'label': Value(dtype='string', id=None)}
    ending_names = [f"ending{i}" for i in range(_num_choices)]
    context_name = "sent1"
    question_header_name = "sent2"

    if data_args.max_seq_length is None:
        max_seq_length = tokenizer.model_max_length
        if max_seq_length > 1024:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
            )
            max_seq_length = 1024
    else:
        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Preprocessing the datasets.
    def preprocess_function(examples):
        first_sentences = [[context] * _num_choices for context in examples[context_name]]
        question_headers = examples[question_header_name]
        second_sentences = [
            [f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers)
        ]

        # Flatten out
        first_sentences = sum(first_sentences, [])
        second_sentences = sum(second_sentences, [])

        #Added for GPT2
        if config.model_type == "gpt2":
            first_sentences  = [s + tokenizer.sep_token for s in first_sentences]
            second_sentences = [s + tokenizer.sep_token for s in second_sentences]

        # Tokenize
        tokenized_examples = tokenizer(
            first_sentences,
            second_sentences,
            truncation=True,
            max_length=max_seq_length,
            padding="max_length" if data_args.pad_to_max_length else False,
        )
        # Un-flatten
        return {k: [v[i : i + _num_choices] for i in range(0, len(v), _num_choices)] for k, v in tokenized_examples.items()}


    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
            )

    if training_args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
            )

    if training_args.do_predict: #Added
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        with training_args.main_process_first(desc="test dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
            )

    # Data collator
    data_collator = (
        default_data_collator
        if data_args.pad_to_max_length
        else DataCollatorForMultipleChoice(tokenizer=tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
    )

    # Metric
    def compute_metrics(eval_predictions):
        predictions, label_ids = eval_predictions
        preds = np.argmax(predictions, axis=1)
        return {"accuracy": (preds == label_ids).astype(np.float32).mean().item()}

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload
        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict: #Added
        logger.info("*** Predict ***")
        results = trainer.predict(predict_dataset)
        metrics = results.metrics
        metrics["predict_samples"] = len(predict_dataset)

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)
        trainer.log(metrics) #Added

        #Added: save raw predictions and label ids for offline analysis
        import json
        output_dir = training_args.output_dir
        with open(f"{output_dir}/predict_outputs.json", "w") as outf:
            json.dump({"predictions": results.predictions.tolist(), "label_ids": results.label_ids.tolist()}, outf)


    if training_args.push_to_hub:
        trainer.push_to_hub(
            finetuned_from=model_args.model_name_or_path,
            tasks="multiple-choice",
            dataset_tags="swag",
            dataset_args="regular",
            dataset="SWAG",
            language="en",
        )


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()


================================================
FILE: finetune/seqcls/README.md
================================================
## Setting Up BLURB (PubMedQA and BioASQ)

1.) Download the original [BioASQ](http://www.bioasq.org/) and [PubMedQA](https://pubmedqa.github.io/) data. When downloading and extracting it, place it so that it matches these paths in this directory: `raw_data/blurb/data_generation/data/pubmedqa` and `raw_data/blurb/data_generation/data/BioASQ`. See the `preprocess_blurb_seqcls.py` script for the exact file paths it expects. For example, the path `raw_data/blurb/data_generation/data/pubmedqa/pqal_fold0` should exist when the data has been set up properly.

2.) Run the `preprocess_blurb_seqcls.py` script in this directory to convert the raw data into the format expected by our fine-tuning code. It writes JSON Lines files (one JSON object per line, saved with a `.json` extension) named `train.json`, `dev.json`, and `test.json` to both `data/pubmedqa_hf` and `data/bioasq_hf`.
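A quick sanity check before and after preprocessing can save a failed run. The sketch below is not part of this repository: `missing_raw_files`, `read_jsonl`, and `label_counts` are hypothetical helper names. It assumes the input paths that `preprocess_blurb_seqcls.py` opens and the output format written by its `dump_jsonl` function (one JSON object per line with `id`, `sentence1`, `sentence2`, and `label` keys).

```python
import json
import os
from collections import Counter

# Raw input files opened by preprocess_blurb_seqcls.py (paths copied from the script).
EXPECTED_RAW_FILES = [
    "raw_data/blurb/data_generation/data/pubmedqa/pqal_fold0/train_set.json",
    "raw_data/blurb/data_generation/data/pubmedqa/pqal_fold0/dev_set.json",
    "raw_data/blurb/data_generation/data/pubmedqa/test_set.json",
    "raw_data/blurb/data_generation/data/BioASQ/train.tsv",
    "raw_data/blurb/data_generation/data/BioASQ/dev.tsv",
    "raw_data/blurb/data_generation/data/BioASQ/test.tsv",
]


def missing_raw_files(root="."):
    """Hypothetical helper: list expected raw input files that do not exist under `root`."""
    return [p for p in EXPECTED_RAW_FILES if not os.path.isfile(os.path.join(root, p))]


def read_jsonl(fpath):
    """Hypothetical inverse of dump_jsonl: parse one JSON object per non-empty line."""
    with open(fpath) as f:
        return [json.loads(line) for line in f if line.strip()]


def label_counts(records):
    """Count labels (yes/no/maybe for PubMedQA, yes/no for BioASQ)."""
    return Counter(r["label"] for r in records)


if __name__ == "__main__":
    # Step 1 check: report any raw files that are not where the script expects them.
    for path in missing_raw_files():
        print("missing:", path)
    # Step 2 check: summarize whichever processed splits already exist.
    for dname in ("pubmedqa", "bioasq"):
        for split in ("train", "dev", "test"):
            path = f"data/{dname}_hf/{split}.json"
            if os.path.isfile(path):
                print(dname, split, dict(label_counts(read_jsonl(path))))
```

Running it from this directory before step 2 lists any misplaced raw files; running it afterwards prints the label distribution of each processed split.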


================================================
FILE: finetune/seqcls/data/bioasq_hf/dev.json
================================================
{"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"}


================================================
FILE: finetune/seqcls/data/bioasq_hf/test.json
================================================
{"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"}


================================================
FILE: finetune/seqcls/data/bioasq_hf/train.json
================================================
{"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"}


================================================
FILE: finetune/seqcls/data/pubmedqa_hf/dev.json
================================================
{"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"}


================================================
FILE: finetune/seqcls/data/pubmedqa_hf/test.json
================================================
{"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"}


================================================
FILE: finetune/seqcls/data/pubmedqa_hf/train.json
================================================
{"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"}


================================================
FILE: finetune/seqcls/preprocess_blurb_seqcls.py
================================================
import os
import csv
import json
import random
import shutil
import numpy as np
import pandas as pd
from tqdm import tqdm


def dump_jsonl(data, fpath):
    with open(fpath, "w") as outf:
        for d in data:
            print (json.dumps(d), file=outf)


######################### BLURB sequence classification #########################
root = "data"
os.makedirs(root, exist_ok=True)


def process_pubmedqa(fname):
    dname = "pubmedqa"
    print (dname, fname)
    if fname in ["train", "dev"]:
        data = json.load(open(f"raw_data/blurb/data_generation/data/pubmedqa/pqal_fold0/{fname}_set.json"))
    elif fname == "test":
        data = json.load(open(f"raw_data/blurb/data_generation/data/pubmedqa/{fname}_set.json"))
    else:
        raise ValueError(f"Unknown split: {fname}")
    outs, lens = [], []
    for id in data:
        obj = data[id]
        context = " ".join([c.strip() for c in obj["CONTEXTS"] if c.strip()])
        question = obj["QUESTION"].strip()
        label = obj["final_decision"].strip()
        assert label in ["yes", "no", "maybe"]
        outs.append({"id": id, "sentence1": question, "sentence2": context, "label": label})
        lens.append(len(question) + len(context))
    print ("total", len(outs), "seqlen mean", int(np.mean(lens)), "median", int(np.median(lens)), "95th", int(np.percentile(lens, 95)), "max", np.max(lens))
    #
    os.makedirs(f"{root}/{dname}_hf", exist_ok=True)
    dump_jsonl(outs, f"{root}/{dname}_hf/{fname}.json")

process_pubmedqa("test")
process_pubmedqa("train")
process_pubmedqa("dev")


def process_bioasq(fname):
    dname = "bioasq"
    print (dname, fname)
    df = pd.read_csv(open(f"raw_data/blurb/data_generation/data/BioASQ/{fname}.tsv"), sep="\t", header=None)
    outs, lens = [], []
    for _, row in df.iterrows():
        id       = row[0].strip()
        question = row[1].strip()
        context  = row[2].strip()
        label    = row[3].strip()
        assert label in ["yes", "no"]
        outs.append({"id": id, "sentence1": question, "sentence2": context, "label": label})
        lens.append(len(question) + len(context))
    print ("total", len(outs), "seqlen mean", int(np.mean(lens)), "median", int(np.median(lens)), "95th", int(np.percentile(lens, 95)), "max", np.max(lens))
    #
    os.makedirs(f"{root}/{dname}_hf", exist_ok=True)
    dump_jsonl(outs, f"{root}/{dname}_hf/{fname}.json")

process_bioasq("test")
process_bioasq("dev")
process_bioasq("train")


================================================
FILE: finetune/seqcls/run_seqcls_gpt.py
================================================
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for sequence classification.

Adapted from
https://github.com/huggingface/transformers/blob/72aee83ced5f31302c5e331d896412737287f976/examples/pytorch/text-classification/run_glue.py
"""
# You can also adapt this script on your own text classification task. Pointers for this are left as comments.

import logging
import os
import random
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import load_dataset, load_metric

import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

sys.path.insert(0, '..')
from utils.custom_modeling_gpt2 import GPT2ForSequenceClassification
from utils.custom_modeling_gpt_neo import GPTNeoForSequenceClassification


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.9.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

logger = logging.getLogger(__name__)


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the
    command line.
    """

    task_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
    )
    metric_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the metric"},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    gpt2_append_eos_tok: int = field(
        default=0, metadata={"help": "Append EOS token after input sequence or not"}
    )

    def __post_init__(self):
        if self.task_name is not None:
            self.task_name = self.task_name.lower()
            if self.task_name not in task_to_keys.keys():
                raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
        elif self.dataset_name is not None:
            pass
        elif self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
        else:
            train_extension = self.train_file.split(".")[-1]
            assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            validation_extension = self.validation_file.split(".")[-1]
            assert (
                validation_extension == train_extension
            ), "`validation_file` should have the same extension (csv or json) as `train_file`."


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use a fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
    use_flash: bool = field(
        default=False, metadata={"help": "Use flash attention."}
    )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.
    #
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
    elif data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
        )
    else:
        # Loading a dataset from your local files.
        # CSV/JSON training and evaluation files are needed.
        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}

        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
        # when you use `do_predict` without specifying a GLUE benchmark task.
        if training_args.do_predict:
            if data_args.test_file is not None:
                train_extension = data_args.train_file.split(".")[-1]
                test_extension = data_args.test_file.split(".")[-1]
                assert (
                    test_extension == train_extension
                ), "`test_file` should have the same extension (csv or json) as `train_file`."
                data_files["test"] = data_args.test_file
            else:
                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")

        for key in data_files.keys():
            logger.info(f"load a local file for {key}: {data_files[key]}")

        if data_args.train_file.endswith(".csv"):
            # Loading a dataset from local csv files
            raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir)
        else:
            # Loading a dataset from local json files
            raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            print ('is_regression', is_regression)
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            print ('\nlabel_list', label_list)
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    config.use_flash = model_args.use_flash
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    if config.model_type == "gpt2":
        model_class = GPT2ForSequenceClassification
    elif config.model_type == "gpt_neo":
        model_class = GPTNeoForSequenceClassification
    else:
        model_class = AutoModelForSequenceClassification
    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    #Added for GPT
    if tokenizer.pad_token_id is None:
        print('Adding [PAD] token to tokenizer and model word embeddings.')
        num_added_tokens = tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        tokenizer.add_tokens(["<|CONTEXT|>", "<|QUESTION1|>", "<|QUESTION2|>", "<|ANSWER|>"])
        embedding_layer = model.resize_token_embeddings(len(tokenizer))
        config.pad_token_id = tokenizer.pad_token_id

    # Preprocessing the raw_datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        elif "sentence" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence", None
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
        else:
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None and not is_regression:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    if label_to_id is not None:
        model.config.label2id = label_to_id
        model.config.id2label = {id: label for label, id in config.label2id.items()}

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    #def modify_sentence1(text):
        #return "<|CONTEXT|>" + text

    #def modify_sentence2(text):
        #return "<|QUESTION|>" + text + "<|ANSWER|>"

    def preprocess_function(examples):
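        # Tokenize the texts. For sentence pairs, the context (sentence2) is passed first
        # and the question (sentence1) second. Only index sentence2_key when it exists.
        if sentence2_key is None:
            args = (examples[sentence1_key],)
        else:
            args = (examples[sentence2_key], examples[sentence1_key])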

        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)

        #Added for GPT2
        if config.model_type in ["gpt2"] and data_args.gpt2_append_eos_tok:
            assert padding == "max_length"
            assert sorted(result.keys()) == sorted(["input_ids", "attention_mask"])
            input_ids = torch.tensor(result["input_ids"])
            attention_mask = torch.tensor(result["attention_mask"])
            sequence_lengths = torch.clamp(input_ids.ne(tokenizer.pad_token_id).sum(-1), max=max_seq_length-1)
            input_ids[range(len(input_ids)), sequence_lengths] = tokenizer.eos_token_id
            attention_mask[range(len(input_ids)), sequence_lengths] = 1
            result["input_ids"] = input_ids.tolist()
            result["attention_mask"] = attention_mask.tolist()

        # Map labels to IDs (not necessary for GLUE tasks)
        if label_to_id is not None and "label" in examples:
            result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
        return result

    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

    if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
        if "test" not in raw_datasets and "test_matched" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"]
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))

    # Log a few random samples from the training set:
    # if training_args.do_train:
    #     for index in random.sample(range(len(train_dataset)), 3):
    #         logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")



    # You can define a custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
    # predictions and label_ids fields) and must return a dictionary mapping metric names to floats.
    def compute_metrics(p: EvalPrediction):
        # Get the metric function
        if data_args.task_name is not None:
            metric = load_metric("glue", data_args.task_name)
        else:
            metric = load_metric("accuracy")

        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        if data_args.task_name is not None:
            result = metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif data_args.metric_name == "pearsonr":
            from scipy.stats import pearsonr as scipy_pearsonr
            pearsonr = float(scipy_pearsonr(p.label_ids, preds)[0])
            return {"pearsonr": pearsonr}
        elif data_args.metric_name == "PRF1":
            TP = ((preds == p.label_ids) & (preds != 0)).astype(int).sum().item()
            P_total = (preds != 0).astype(int).sum().item()
            L_total = (p.label_ids != 0).astype(int).sum().item()
            P = TP / P_total if P_total else 0
            R = TP / L_total if L_total else 0
            F1 = 2 * P*R/(P+R) if (P+R) else 0
            return {"precision": P, "recall": R, "F1": F1}
        elif is_regression:
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        #trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            eval_datasets.append(raw_datasets["validation_mismatched"])

        for eval_dataset, task in zip(eval_datasets, tasks):
            metrics = trainer.evaluate(eval_dataset=eval_dataset)

            max_eval_samples = (
                data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
            )
            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        predict_datasets = [predict_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            predict_datasets.append(raw_datasets["test_mismatched"])

        for predict_dataset, task in zip(predict_datasets, tasks):
            metrics = trainer.evaluate(eval_dataset=predict_dataset, metric_key_prefix="test")

            max_test_samples = (
                data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
            )
            metrics["test_samples"] = min(max_test_samples, len(predict_dataset))

            trainer.log_metrics("test", metrics)
            trainer.save_metrics("test", metrics)
            trainer.log(metrics)


    if training_args.push_to_hub:
        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
        if data_args.task_name is not None:
            kwargs["language"] = "en"
            kwargs["dataset_tags"] = "glue"
            kwargs["dataset_args"] = data_args.task_name
            kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"

        trainer.push_to_hub(**kwargs)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
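The GPT-2 branch of `preprocess_function` relies on right padding: counting the non-pad tokens in a row gives the index of the first pad slot, where an EOS token is written and unmasked. If the row is already full, the clamp to `max_seq_length - 1` overwrites the last real token instead. A dependency-free sketch of that indexing on plain lists; the helper name `append_eos` and the toy ids (pad 0, EOS 9) are hypothetical:

```python
from typing import List, Tuple


def append_eos(
    input_ids: List[List[int]],
    attention_mask: List[List[int]],
    pad_id: int,
    eos_id: int,
    max_len: int,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: write EOS at the first pad slot (or the last slot if full) and unmask it."""
    new_ids, new_mask = [], []
    for ids, mask in zip(input_ids, attention_mask):
        ids, mask = list(ids), list(mask)
        # Count non-pad tokens (assumes right padding), clamped so the index stays inside the row.
        pos = min(sum(tok != pad_id for tok in ids), max_len - 1)
        ids[pos] = eos_id
        mask[pos] = 1
        new_ids.append(ids)
        new_mask.append(mask)
    return new_ids, new_mask


ids, mask = append_eos([[5, 6, 0, 0], [5, 6, 7, 8]], [[1, 1, 0, 0], [1, 1, 1, 1]], pad_id=0, eos_id=9, max_len=4)
print(ids)   # [[5, 6, 9, 0], [5, 6, 7, 9]]
print(mask)  # [[1, 1, 1, 0], [1, 1, 1, 1]]
```

The second row shows the truncation side effect: a full sequence loses its final token to the EOS.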

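The `PRF1` branch of `compute_metrics` computes micro-averaged precision, recall, and F1 while excluding label 0 from both the predicted and the gold counts. As a result, a wrong non-zero prediction on a label-0 example lowers precision only, and predicting 0 for a non-zero gold label lowers recall only. A plain-Python sketch of the same arithmetic, with a hypothetical function name and toy labels:

```python
from typing import Dict, Sequence


def prf1_ignore_zero(preds: Sequence[int], golds: Sequence[int]) -> Dict[str, float]:
    """Hypothetical helper: micro precision/recall/F1 ignoring label 0, as in the PRF1 branch."""
    # A true positive is a correct prediction of any non-zero label.
    tp = sum(1 for p, g in zip(preds, golds) if p == g and p != 0)
    # Precision divides by non-zero predictions, recall by non-zero gold labels.
    pred_total = sum(1 for p in preds if p != 0)
    gold_total = sum(1 for g in golds if g != 0)
    precision = tp / pred_total if pred_total else 0.0
    recall = tp / gold_total if gold_total else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"precision": precision, "recall": recall, "F1": f1}


print(prf1_ignore_zero([0, 1, 2, 2, 0, 1], [0, 1, 1, 2, 2, 0]))
# {'precision': 0.5, 'recall': 0.5, 'F1': 0.5}
```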

================================================
FILE: finetune/setup/requirements.txt
================================================
datasets==2.6.1
fairscale==0.4.12
huggingface-hub==0.10.1
rouge-score==0.0.4
sacrebleu==2.0.0
transformers==4.24.0
wandb==0.13.5


================================================
FILE: finetune/textgen/data/meqsum/test.source
================================================
The source text for an example. For instance, this could be the full article that is supposed to be summarized. There should be one example per line. The corresponding test.target file would have the gold generations for each example. So the Nth line of this file would correspond to the Nth line of the *.target file.


================================================
FILE: finetune/textgen/data/meqsum/test.target
================================================
The gold sequence for this example. Each line should be a new example. The corresponding line in the *.source file contains the original text. This text is the desired generation for that source. So if this were a summarization task, the *.source file would have the full article, and this would be the summary. The Nth line of this file corresponds to the Nth line of the *.source file.


================================================
FILE: finetune/textgen/data/meqsum/train.source
================================================
The source text for an example. For instance, this could be the full article that is supposed to be summarized. There should be one example per line. The corresponding train.target file would have the gold generations for each example. So the Nth line of this file would correspond to the Nth line of the *.target file.


================================================
FILE: finetune/textgen/data/meqsum/train.target
================================================
The gold sequence for this example. Each line should be a new example. The corresponding line in the *.source file contains the original text. This text is the desired generation for that source. So if this were a summarization task, the *.source file would have the full article, and this would be the summary. The Nth line of this file corresponds to the Nth line of the *.source file.


================================================
FILE: finetune/textgen/data/meqsum/val.source
================================================
The source text for an example. For instance, this could be the full article that is supposed to be summarized. There should be one example per line. The corresponding val.target file would have the gold generations for each example. So the Nth line of this file would correspond to the Nth line of the *.target file.


================================================
FILE: finetune/textgen/data/meqsum/val.target
================================================
The gold sequence for this example. Each line should be a new example. The corresponding line in the *.source file contains the original text. This text is the desired generation for that source. So if this were a summarization task, the *.source file would have the full article, and this would be the summary. The Nth line of this file corresponds to the Nth line of the *.source file.
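Because each *.source/*.target pair is aligned purely by line number, a consistency check before training can catch files that drifted out of sync. A minimal sketch with a hypothetical helper `read_parallel`; it is not the repo's `LineByLineSumTextDataset`, whose implementation is not shown in this excerpt:

```python
from pathlib import Path
from typing import List, Tuple


def read_parallel(source_path: str, target_path: str) -> List[Tuple[str, str]]:
    """Hypothetical helper: pair line N of a *.source file with line N of its *.target file."""
    sources = Path(source_path).read_text(encoding="utf-8").splitlines()
    targets = Path(target_path).read_text(encoding="utf-8").splitlines()
    # Line-number alignment only works if both files have the same number of examples.
    if len(sources) != len(targets):
        raise ValueError(
            f"{source_path} has {len(sources)} lines but {target_path} has {len(targets)}"
        )
    return list(zip(sources, targets))
```

For example, `read_parallel("finetune/textgen/data/meqsum/val.source", "finetune/textgen/data/meqsum/val.target")` would return one (source, target) tuple per line, or raise if the counts differ.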


================================================
FILE: finetune/textgen/gpt2/finetune_for_summarization.py
================================================
import torch
from typing import Optional
from dataclasses import dataclass, field
from transformers import (
    CONFIG_MAPPING,
    MODEL_WITH_LM_HEAD_MAPPING,
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
    HfArgumentParser,
    PreTrainedTokenizer,
    TextDataset,
    Trainer,
    TrainingArguments,
    set_seed,
    GPT2LMHeadModel,
    AutoModelForCausalLM,
)

from sum_data_collator import DataCollatorForSumLanguageModeling
from sum_dataset import LineByLineSumTextDataset

import torch.distributed as dist

import json

import sys

sys.path.insert(0, "../..")

@dataclass
class ModelArguments:
    """
    Arguments for the model
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Required: this script loads both the config"
                " and the weights with from_pretrained, so it cannot train a model from scratch."
            )
        },
    )

    tokenizer_name: Optional[str] = field(
        default="gpt2", metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )

    use_flash: bool = field(
        default=False, metadata={"help": "Use flash attention."}
    )

@dataclass
class DataArguments:
    """
    Arguments for data
    """

    train_data_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )
    eval_data_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_source_length: Optional[int] = field(
        default=510, metadata={"help": "The maximum source length for summarization data."}
    )
    train_max_target_length: Optional[int] = field(
        default=510, metadata={"help": "The maximum target length for training data."}
    )
    eval_max_target_length: Optional[int] = field(
        default=510, metadata={"help": "The maximum target length for dev data."}
    )
    seq_prefix: Optional[str] = field(
        default="",
        metadata={"help": "A string to begin every sequence with."},
    )
    no_sep: bool = field(
        default=False, metadata={"help": "Don't use a separator token."}
    )
    block_size: int = field(
        default=-1,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in blocks of this size for training. "
                "Defaults to the model max input length for single sentence inputs (taking special tokens into account). "
                "Note: get_dataset currently ignores this value and passes block_size=1024."
            )
        },
    )


def get_dataset(
    args: DataArguments,
    tokenizer: PreTrainedTokenizer,
    evaluate: bool = False,
    cache_dir: Optional[str] = None,
    training_args: TrainingArguments = None,
):
    file_path = args.eval_data_file if evaluate else args.train_data_file
    max_source_length = args.max_source_length
    max_target_length = args.train_max_target_length if not evaluate else args.eval_max_target_length
    dataset = LineByLineSumTextDataset(
        tokenizer=tokenizer,
        file_path=file_path,
        block_size=1024,
        bos_tok=tokenizer.bos_token,
        eos_tok=tokenizer.eos_token,
        max_source_length=max_source_length,
        max_target_length=max_target_length,
        seq_prefix=args.seq_prefix,
        no_sep=args.no_sep
    )

    return dataset


def finetune():
    # parse args
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # set seed
    set_seed(training_args.seed)
    # set up model
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    if model_args.use_flash:
        from utils.hf_flash_gpt_2 import GPT2FlashLMHeadModel
        model = GPT2FlashLMHeadModel.from_pretrained(
            model_args.model_name_or_path,
            config=config,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
        )
    # set up tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name)
    # add extra pad token
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    tokenizer.add_special_tokens({"bos_token": "<|startoftext|>"})
    tokenizer.add_special_tokens({"eos_token": "<|endoftext|>"})
    embedding_layer = model.resize_token_embeddings(len(tokenizer))
    # set up data collator
    data_collator = DataCollatorForSumLanguageModeling(tokenizer=tokenizer)
    # set up data sets
    train_dataset = get_dataset(data_args, tokenizer=tokenizer, training_args=training_args)
    eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True)
    # set up trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator
    )
    # launch fine tuning
    trainer.train()
    # save final model
    trainer.save_model()
    trainer.save_state()

if __name__ == "__main__":
    finetune()
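`LineByLineSumTextDataset` and `DataCollatorForSumLanguageModeling` are not shown in this excerpt, so their exact layout is unknown here. The demo script appends `<|startoftext|>` after the prompt, which suggests a source + BOS + target + EOS layout. The sketch below shows one common way to build such a causal-LM summarization example, masking source positions with -100 (the default `ignore_index` of PyTorch's `CrossEntropyLoss`, which Hugging Face causal-LM heads use for the loss). It is an assumption for illustration, not necessarily what the repo does; the names and toy ids are hypothetical:

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value skipped by PyTorch's CrossEntropyLoss by default


def build_example(
    source_ids: List[int],
    target_ids: List[int],
    bos_id: int,
    eos_id: int,
    max_source_length: int,
    max_target_length: int,
) -> Dict[str, List[int]]:
    """Hypothetical helper: lay out source + BOS + target + EOS and mask the source from the loss."""
    src = source_ids[:max_source_length] + [bos_id]
    tgt = target_ids[:max_target_length] + [eos_id]
    input_ids = src + tgt
    # Labels stay aligned with input_ids; the model shifts them internally, so the
    # BOS position is trained to predict the first target token.
    labels = [IGNORE_INDEX] * len(src) + tgt
    return {"input_ids": input_ids, "labels": labels}


print(build_example([1, 2, 3], [4, 5], bos_id=7, eos_id=8, max_source_length=2, max_target_length=5))
# {'input_ids': [1, 2, 7, 4, 5, 8], 'labels': [-100, -100, -100, 4, 5, 8]}
```

Masking the source keeps the loss focused on generating the summary rather than reproducing the article.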


================================================
FILE: finetune/textgen/gpt2/generate_demo.py
================================================
import sys
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = sys.argv[1]
device = torch.device("cuda")

# load tokenizer
print("Loading tokenizer ...")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# load model
print("Loading model ...")
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)

# run model
print("Generating text ...")
prompt = sys.argv[2]
prompt_w_start = f"{prompt}<|startoftext|>"
encoding = tokenizer.encode(prompt_w_start, return_tensors='pt').to(device)
generated_ids = model.generate(encoding, max_new_tokens=100, eos_token_id=28895)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Input: {prompt}")
print(f"Output: {generated_text[len(prompt):]}")
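`generate_demo.py` strips the prompt by slicing the decoded string at `len(prompt)` characters, which assumes decoding reproduces the prompt character for character; whitespace cleanup during decoding can break that assumption. `test_step` in run_generation_batch.py slices token ids instead (`generated_ids[:, seqlen:]`). A small sketch of the id-level approach that also cuts the continuation at the first EOS; the helper name and toy ids are hypothetical:

```python
from typing import List, Sequence


def continuation_ids(generated: Sequence[int], prompt_len: int, eos_id: int) -> List[int]:
    """Hypothetical helper: return only the newly generated ids, stopping before the first EOS."""
    # generate() on a decoder-only model returns prompt + continuation in one sequence.
    new_ids = list(generated[prompt_len:])
    if eos_id in new_ids:
        new_ids = new_ids[: new_ids.index(eos_id)]
    return new_ids


print(continuation_ids([10, 11, 12, 5, 6, 99, 0], prompt_len=3, eos_id=99))
# [5, 6]
```

In the demo, this would correspond to passing `generated_ids[0].tolist()`, `encoding.shape[1]`, and the script's EOS id (28895), then decoding the result with `tokenizer.decode(..., skip_special_tokens=True)`.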


================================================
FILE: finetune/textgen/gpt2/run_generation_batch.py
================================================

#!/usr/bin/env python3
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
"""


import argparse
import logging

import numpy as np
import torch
import json
import os
from tqdm import tqdm
from torch.utils.data import DataLoader
import time
from rouge_score import rouge_scorer, scoring
import itertools
from typing import Dict, List
from transformers import (
    CTRLLMHeadModel,
    CTRLTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    XLMTokenizer,
    XLMWithLMHeadModel,
    XLNetLMHeadModel,
    XLNetTokenizer,
    BertForMaskedLM, BertModel,
    BertTokenizer, BertTokenizerFast, AutoConfig,
    set_seed,
    #GPT2LMHeadModelAdapter,
    #LineByLineSumBatchGenTextDataset,
    #DataCollatorForSumBatchGenLanguageModeling,
    AutoModelWithLMHead,
    AutoTokenizer,
)

from sum_data_collator import DataCollatorForSumBatchGenLanguageModeling
from sum_dataset import LineByLineSumBatchGenTextDataset

import sys, os
sys.path.insert(1, '/u/scr/xlisali/contrast_LM/transformers/examples/control')
from train_control import PrefixTuning, PrefixEmbTuning

# imports for wandb
from datetime import datetime
import wandb


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "gpt_neo": (AutoModelWithLMHead, AutoTokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""


# def set_seed(args):
#     np.random.seed(args.seed)
#     torch.manual_seed(args.seed)
#     if args.n_gpu > 0:
#         torch.cuda.manual_seed_all(args.seed)


#
# Functions to prepare models' input
#


def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
        logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
    return prompt_text


def prepare_xlm_input(args, model, tokenizer, prompt_text):
    # kwargs = {"language": None, "mask_token_id": None}

    # Set the language
    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
    if hasattr(model.config, "lang2id") and use_lang_emb:
        available_languages = model.config.lang2id.keys()
        if args.xlm_language in available_languages:
            language = args.xlm_language
        else:
            language = None
            while language not in available_languages:
                language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")

        model.config.lang_id = model.config.lang2id[language]
        # kwargs["language"] = tokenizer.lang2id[language]

    # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
    # XLM masked-language modeling (MLM) models need masked token
    # is_xlm_mlm = "mlm" in args.model_name_or_path
    # if is_xlm_mlm:
    #     kwargs["mask_token_id"] = tokenizer.mask_token_id

    return prompt_text


def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
    prompt_text = prefix + prompt_text
    return prompt_text


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
    prompt_text = prefix + prompt_text
    return prompt_text


PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}

def read_e2e_files(path, tokenizer, lowdata_token=None):
    file_dict = {}
    with open(path, 'r') as f:
        for line in f:
            src, tgt = line.strip().split('||')
            # URGENT CHANGE
            # src =  src + ' {}'.format(' summarize :')
            if lowdata_token is None:
                src = ' {} {}'.format(src, tokenizer.bos_token)
                # src =  src + ' {}'.format(tokenizer.bos_token)
            else:
                src = ' {} {} {}'.format(lowdata_token, src, tokenizer.bos_token)
            if src not in file_dict:
                file_dict[src] = []
            file_dict[src].append(tgt)
    return file_dict

def read_wp_files(path, tokenizer):
    file_dict = {}
    with open(path, 'r') as f:
        for line in f:
            src, tgt = line.strip().split('|||')
            src = src + ' {}'.format(tokenizer.bos_token)
            if src not in file_dict:
                file_dict[src] = []
            file_dict[src].append(tgt)
    return file_dict


def read_classifySentiment_files(path, tokenizer):
    file_dict = []
    with open(path, 'r') as f:
        for line in f:
            tgt, src = line.strip().split('|||')
            src = src.replace("< br / >", "\n")
            src = ' {} {}'.format(src, tokenizer.bos_token)
            file_dict.append((src, tgt))
    return file_dict

def read_classifyTopic_files(path, tokenizer):
    file_dict = []
    with open(path, 'r') as f:
        for line in f:
            if (len(line) > 0 and not line.isspace()
                    and len(line.split('||')) == 2):
                tgt, src = line.strip().split('||')
            else:
                continue
            src = ' {} {}'.format(src, tokenizer.bos_token)
            file_dict.append((src, tgt))
    return file_dict


# def ids_to_text_without_prompt(tokenizer, generated_ids, prompt):
#     gen_text = tokenizer.batch_decode(
#         generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True
#     )
#     for idx, text in enumerate(gen_text):
#         text_output = text[len(tokenizer.decode(prompt[idx], clean_up_tokenization_spaces=True)):]
#         idx = text_output.find(tokenizer.eos_token)
#     return lmap(str.strip, gen_text)

def lmap(f, x):
    """list(map(f, x))"""
    return list(map(f, x))

def ids_to_clean_text(tokenizer, generated_ids):
    gen_text = tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return lmap(str.strip, gen_text)

ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]

def flatten_list(summary_ids):
    return [x for x in itertools.chain.from_iterable(summary_ids)]

def calculate_rouge(output_lns, reference_lns, use_stemmer=True):
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}

def test_epoch_end(outputs, prefix="test"):
    # losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
    # loss = losses["loss"]
    # print(loss)
    metric_names = ROUGE_KEYS
    generative_metrics = {
        k: np.array([x[k] for x in outputs]).mean() for k in metric_names + ["gen_time", "gen_len"]
    }
    # metric_val = (
    #     generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
    # )
    # metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
    # generative_metrics.update({k: v.item() for k, v in losses.items()})
    losses = {}
    losses.update(generative_metrics)
    all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
    preds = flatten_list([x["preds"] for x in outputs])
    return {
        "log": all_metrics,
        "preds": preds,
        # f"{prefix}_loss": loss,
        # f"{prefix}_{self.val_metric}": metric_tensor,
    }

def test_step(model, gpt2, batch, batch_idx, args, tokenizer, beam_handle, gold_handle, tuning_mode):
    t0 = time.time()
    # TODO(LISA)
    # write the prompt generation from self.model.
    # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
    # get the prompt:
    bsz = batch["input_ids"].size(0)
    # prefix_prompt = model.get_prompt(bsz=bsz,)
    # expand to get bsz * sample_size.
    control_code = None
    print('control code is ', control_code)
    # prompt = model.get_prompt(control_code, gpt2=gpt2, bsz=1)



    # print('the max length of the model is {}'.format(model.config.max_length))

    input_ids = batch["input_ids"] #bsz, seqlen
    seqlen = len(input_ids[0])
    # bos_seq = torch.ones(bsz, 1).fill_(tokenizer.bos_token_id)
    input_attn  = batch["src_attn"].to(gpt2.device)

    if tuning_mode == "prefixtune":
        prompt = model.get_prompt(bsz=1)
        num_beamsize = 5
        prompt = [x.expand(-1, num_beamsize*bsz, -1, -1, -1) for x in prompt]
        prefix_attn = torch.ones(bsz, model.config.preseqlen).long().to(gpt2.device)
        input_attn = torch.cat([prefix_attn, input_attn], dim=-1)
    elif tuning_mode == "finetune":
        prompt = None
    else:
        raise NotImplementedError

    # input_ids = torch.cat([input_ids, bos_seq], dim=-1)
    # print(input_ids.shape)
    # print(input_ids.shape, input_attn.shape)

    # torch.set_printoptions(profile="full")
    # print(input_ids)
    # print(input_attn)
    # torch.set_printoptions(profile="default")
    # print(prompt[5][0][0][0])
    if args.fp16:
        prompt = [p.half() for p in prompt] if prompt is not None else None
        # input_attn = input_attn.half()

    with torch.cuda.amp.autocast(args.fp16):
        generated_ids = gpt2.generate(
            input_ids=input_ids.to(gpt2.device),
            emb_match=None,
            control_code=None,
            past_key_values=prompt,
            attention_mask=input_attn,
            #use_prefix_test=True,
            max_length=args.length + seqlen,  # prompt length plus at most args.length new tokens
            min_length=5,
            temperature=args.temperature,
            top_k=args.k,
            top_p=0.9,  # top_p=0.5,
            no_repeat_ngram_size=args.no_repeat_ngram_size, #add
            length_penalty=args.length_penalty, #add
            repetition_penalty=args.repetition_penalty,  ##args.repetition_penalty,
            do_sample=False,
            num_beams=5,
            bad_words_ids=[[628], [198]],  # ban "\n\n" (628) and "\n" (198), ids from the GPT-2 vocabulary
            num_return_sequences=1,

        )
    # clean up generated_ids
    bsz, seqlen = input_ids.shape
    generated_ids = generated_ids[:,seqlen:]
    # print(generated_ids)

    # generated_ids = gpt2.generate(
    #     batch["input_ids"],
    #     past_key_values=prefix_prompt,
    #     attention_mask=batch["attention_mask"],
    #     use_cache=True,
    #     use_prefix=True,
    #     decoder_start_token_id=self.decoder_start_token_id,
    #     num_beams=self.eval_beams,
    #     max_length=self.eval_max_length,
    # )
    gen_time = (time.time() - t0) / batch["input_ids"].shape[0]

    preds: List[str] = ids_to_clean_text(tokenizer, generated_ids)
    # src: List[str] = ids_to_clean_text(tokenizer, input_ids)
    # print(src)
    target: List[str] = ids_to_clean_text(tokenizer, batch["labels"])
    # print(preds)
    # print(target)
    # loss_tensors = self._step(batch)
    # base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
    # print('INPUT:', self.ids_to_clean_text(batch["input_ids"]))
    # print(preds, target)

    for predd in preds:
        print(predd, file=beam_handle)

    for tgtt in target:
        print(tgtt, file=gold_handle)
    beam_handle.flush()
    gold_handle.flush()

    base_metrics = {}
    rouge: Dict = calculate_rouge(preds, target)
    summ_len = np.mean(lmap(len, generated_ids))
    base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
    return base_metrics


def read_webnlg_files(path, tokenizer):
    file_dict = {}

    with open(path) as f:
        lines_dict = json.load(f)

    full_rela_lst = []
    full_src_lst = []
    # full_tgt_lst = []
    total_count = 0
    for i, example in enumerate(lines_dict['entries']):
        sents = example[str(i + 1)]['lexicalisations']
        triples = example[str(i + 1)]['modifiedtripleset']

        rela_lst = []
        temp_triples = ''
        for j, tripleset in enumerate(triples):
            subj, rela, obj = tripleset['subject'], tripleset['property'], tripleset['object']
            rela_lst.append(rela)
            if j > 0:  # separate triples within the same example
                temp_triples += ' | '
            temp_triples += '{} : {} : {}'.format(subj, rela, obj)

        temp_triples = ' {} {}'.format(temp_triples, tokenizer.bos_token)


        for sent in sents:
            if True: #sent["comment"] == 'good'
                if (temp_triples,tuple(rela_lst)) not in file_dict:
                    file_dict[(temp_triples,tuple(rela_lst))] = []
                    full_src_lst.append(temp_triples)
                    full_rela_lst.append(tuple(rela_lst))
                file_dict[(temp_triples,tuple(rela_lst))].append(sent["lex"])


    print(len(file_dict), len(full_src_lst))
    assert len(full_rela_lst) == len(full_src_lst)
    assert len(full_rela_lst) == len(file_dict)

    return file_dict


def read_triples_files2(path, tokenizer):
    file_dict = {}
    file_src = []
    file_tgt = []

    with open(path) as f:
        lines_dict = json.load(f)

    print(len(lines_dict))
    full_rela_lst = []
    full_src_lst = []
    for example in lines_dict:
        rela_lst = []
        temp_triples = ''
        for i, tripleset in enumerate(example['tripleset']):
            subj, rela, obj = tripleset
            rela = rela.lower()
            rela_lst.append(rela)
            if i > 0:
                temp_triples += ' | '
            temp_triples += '{} : {} : {}'.format(subj, rela, obj)

        temp_triples = ' {} {}'.format(temp_triples, tokenizer.bos_token)

        file_src.append((temp_triples, tuple(rela_lst)))
        # file_tgt

        for sent in example['annotations']:
            if (temp_triples, tuple(rela_lst)) not in file_dict:
                file_dict[(temp_triples, tuple(rela_lst))] = []
                full_src_lst.append(temp_triples)
                full_rela_lst.append(tuple(rela_lst))
            file_dict[(temp_triples, tuple(rela_lst))].append(sent['text'])

    print(len(file_dict), len(full_src_lst))
    assert len(full_rela_lst) == len(full_src_lst)
    assert len(full_rela_lst) == len(file_dict)
    return file_dict

def read_triples_files(path, tokenizer):
    file_dict = {}

    with open(path) as f:
        lines_dict = json.load(f)

    print(len(lines_dict))
    full_rela_lst = []
    full_src_lst = []
    for example in lines_dict:
        rela_lst = []
        temp_triples = ''
        for i, tripleset in enumerate(example['tripleset']):
            subj, rela, obj = tripleset
            rela = rela.lower()
            rela_lst.append(rela)
            if i > 0:
                temp_triples += ' | '
            temp_triples += '{} : {} : {}'.format(subj, rela, obj)

        temp_triples = ' {} {}'.format(temp_triples, tokenizer.bos_token)

        for sent in example['annotations']:
            if (temp_triples, tuple(rela_lst)) not in file_dict:
                file_dict[(temp_triples, tuple(rela_lst))] = []
                full_src_lst.append(temp_triples)
                full_rela_lst.append(tuple(rela_lst))
            file_dict[(temp_triples, tuple(rela_lst))].append(sent['text'])

    print(len(file_dict), len(full_src_lst))
    assert len(full_rela_lst) == len(full_src_lst)
    assert len(full_rela_lst) == len(file_dict)
    return file_dict

# def write_e2e_corr(prompt_lst, file_dict, corr_path):
#     with open(corr_path, 'w') as f:
#         for x in prompt_lst:
#             for line in file_dict[x]:
#                 print(line, file=f)
#             print('', file=f)
#     return

def write_e2e_corr(prompt_lst, file_dict, corr_path):
    print(len(prompt_lst))
    with open(corr_path, 'w') as f:
        for x in prompt_lst:
            for line in file_dict[x]:
                if not line.strip():
                    print('PROBLEM', line, 'PROBLEM', file_dict[x])
                else:
                    print(line, file=f)
            print('', file=f)

    # buf = [[]]
    # with open(corr_path, 'r') as fh:
    #     for line in fh:
    #         line = line.strip()
    #         if True:
    #             # print(line)
    #             if not line:
    #                 buf.append([])
    #             else:
    #                 buf[-1].append(line)
    #         else:
    #             buf.append(line)
    # if not buf[-1]:
    #     del buf[-1]
    #
    # print(buf[:3])
    #
    # print(len(buf))

    return

def write_e2e_src(prompt_lst, corr_path):
    with open(corr_path, 'w') as f:
        for x in prompt_lst:
            print(x, file=f)
    return
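

# Illustrative sketch (hypothetical helper, not called anywhere in this script):
# write_e2e_corr writes each reference for a prompt on its own line and closes each
# prompt's group with a blank line. This reads such a file back into one list of
# references per prompt, mirroring the commented-out reader inside write_e2e_corr.
def _example_read_e2e_corr(corr_path):
    groups = [[]]
    with open(corr_path) as fh:
        for line in fh:
            line = line.strip()
            if not line:
                groups.append([])  # a blank line closes the current group
            else:
                groups[-1].append(line)
    if not groups[-1]:
        del groups[-1]  # drop the empty group opened by the final blank line
    return groups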



def get_emb(sent_lst, word_lst, num_layer=1):
    # load bert
    tokenizer_bert = BertTokenizerFast.from_pretrained('bert-large-uncased')
    model = BertModel.from_pretrained('bert-large-uncased', return_dict=True).cuda()
    for param in model.parameters():
        param.requires_grad = False

    device = model.device

    edited_sent = []
    chosen_word = []
    with torch.no_grad():
        computed_ = 0
        mid_ = 300
        full_score = []
        while computed_ < len(sent_lst):
            temp_sent = sent_lst[computed_:computed_ + mid_]
            temp_word = word_lst[computed_:computed_ + mid_]
            temp_input = tokenizer_bert(temp_sent, return_tensors="pt", padding=True,
                                        is_split_into_words=False, return_offsets_mapping=True, add_special_tokens=True)
            input_ids = temp_input["input_ids"]
            # print(temp_input.keys())
            mask_input = temp_input['attention_mask']
            bsz, seqlen = input_ids.shape

            # print(input_ids.shape)

            cand_idx = tokenizer_bert(temp_word, add_special_tokens=False)['input_ids']
            # print(cand_idx)
            # If a word is split into several WordPiece subwords, keep only the id of the last one.
            cand_idx = torch.tensor([i[-1] for i in cand_idx])  # bsz
            # print(cand_idx)
            cand_idx2 = cand_idx.unsqueeze(1).expand(bsz, seqlen)

            mask = (input_ids == cand_idx2)
            # print(mask.sum(dim=1))
            # print(mask.nonzero())

            # What if the subword also occurs somewhere other than the target word?

            # If it occurs multiple times, only the first occurrence is kept.
            mask = (mask.cumsum(dim=1) == 1) & mask
            # print(mask)
            # print(mask.sum(dim=1))
            # print(mask.nonzero())
            mask_idx = mask.nonzero()

            # print(input_ids.shape)

            edit_temp = []
            keep_mask = []
            word_temp = []
            for i, (sent1, word1) in enumerate(zip(temp_sent, temp_word)):
                # TODO: could check against the offsets and make final changes!
                temp_idx1 = temp_input["offset_mapping"][i][mask_idx[i, 1]]
                # print(word1, sent1)
                # print(sent1[temp_idx1[0]:temp_idx1[1]])
                sent1 = sent1.split()
                widx = sent1.index(word1)
                by_tokenl = sum([len(l) + 1 for l in sent1[:widx]])
                by_tokenr = sum([len(l) + 1 for l in sent1[:widx + 1]]) - 1
                # print(by_tokenl, by_tokenr, temp_idx1)
                if by_tokenl != temp_idx1[0].item() and by_tokenr != temp_idx1[1].item():
                    # print('dangerous')
                    # print(sent1, word1, by_tokenl, by_tokenr, temp_idx1)
                    # simple option: delete it from input_ids
                    keep_mask.append(False)
                    continue
                else:
                    keep_mask.append(True)
                new_sent = [word1, '[BOS]'] + sent1[:widx] + ['[', sent1[widx], ']'] + sent1[widx + 1:] + ['[EOS]']
                assert len(new_sent) == len(sent1) + 5
                edit_temp.append(new_sent)
                word_temp.append(word1)

            keep_mask = torch.tensor(keep_mask)
            # print(keep_mask.shape, input_ids.shape, mask.shape, 'hi')
            input_ids = input_ids[keep_mask]
            mask = mask[keep_mask]
            mask_input = mask_input[keep_mask]
            # print(input_ids.shape, mask.shape, len(edit_temp))
            assert input_ids.size(0) == len(edit_temp)

            edited_sent += edit_temp
            chosen_word += word_temp
            # print(len(edited_sent), len(chosen_word))

            outputs = model(input_ids.to(device), attention_mask=mask_input.to(device), output_hidden_states=True)

            if num_layer > 1:
                all_hidden_states = outputs.hidden_states
                selected_all_hidden_states = [ii[mask] for ii in all_hidden_states[-num_layer:]]
                # print([ii.shape for ii in selected_all_hidden_states])
                hidden_layer = torch.stack(selected_all_hidden_states, dim=1)
                # print(hidden_layer.shape, selected_all_hidden_states[0].shape)
                # print('all hidden', selected_all_hidden_states.shape)

            else:
                last_hidden_states = outputs.last_hidden_state
                hidden_layer = last_hidden_states[mask].unsqueeze(1)


            computed_ += mid_
            full_score.append(hidden_layer.cpu())

        full_score = torch.cat(full_score, dim=0)

    return full_score, edited_sent, chosen_word
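

# Illustrative sketch (hypothetical helper, pure Python for clarity): get_emb keeps
# only the first position in each row where input_ids equals the candidate id, via
# `(mask.cumsum(dim=1) == 1) & mask`. The running count of matches equals 1 from the
# first match up to (but not including) the second, so AND-ing with the original mask
# leaves exactly the first match. A row with no match keeps no position at all, so it
# contributes no entry to mask.nonzero(); indexing mask_idx by row number therefore
# assumes every row has a match.
def _example_first_occurrence_mask(mask_rows):
    from itertools import accumulate
    result = []
    for row in mask_rows:
        running = accumulate(int(x) for x in row)  # cumulative match count, like cumsum
        result.append([c == 1 and bool(x) for c, x in zip(running, row)])
    return result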

def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length


def read_doc_for_embmatch(file_name, num_layer):
    word_lst = []
    sent_lst = []
    with open(file_name, 'r') as f:
        for line in f:
            word, sent = line.strip().split('||')
            word_lst.append(word)
            sent_lst.append(sent)

    emb_match, sent_cleaned_lst, chosen_word = get_emb(sent_lst, word_lst, num_layer=num_layer)
    prompt_text_lst = [word + ' [BOS]' for word in chosen_word]
    return prompt_text_lst, emb_match.split(1), sent_cleaned_lst
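

# Illustrative sketch (hypothetical helper, not called anywhere in this script): each
# line of the embMatch input file read above is "<word>||<sentence>". Splitting with
# maxsplit=1 keeps any later "||" inside the sentence, whereas the plain split('||')
# in read_doc_for_embmatch would raise a ValueError when unpacking more than two
# fields. The prompt built from a word has the form "<word> [BOS]".
def _example_parse_embmatch_line(line):
    word, sent = line.strip().split('||', 1)
    return word, sent, word + ' [BOS]'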


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=False,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        required=False,
        help="Path to pre-trained tokenizer or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument(
        "--prefixModel_name_or_path",
        default=None,
        type=str,
        required=False,
        help="Path to pre-trained PrefixTuning Model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--task_mode", type=str, default="embMatch")
    parser.add_argument("--control_mode", type=str, default="yes")
    parser.add_argument("--prefix_mode", type=str, default="activation")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--gen_dir", type=str, default="e2e_results_conv")
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect; lower values tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )

    parser.add_argument("--no_repeat_ngram_size", type=int, default=0)
    parser.add_argument("--length_penalty", type=float, default=1.0)
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--batch_size", type=int, default=4)

    parser.add_argument("--tuning_mode", type=str, default="finetune", help="prefixtune or finetune")
    parser.add_argument("--objective_mode", type=int, default=2)
    parser.add_argument("--format_mode", type=str, default="peek", help="peek, cat, nopeek, or infix")
    parser.add_argument("--optim_prefix", type=str, default="no", help="optim_prefix")
    parser.add_argument("--preseqlen", type=int, default=5, help="preseqlen")

    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--control_dataless", type=str, default="no", help="control dataless mode")
    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )

    parser.add_argument("--use_task_instruction", type=int, default=0, help="")
    parser.add_argument("--max_source_length", type=int, default=-1, help="")
    parser.add_argument("--wandb_entity", type=str, default=None)
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--wandb_run_name", type=str, default=None)

    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    logger.warning(
        "device: %s, n_gpu: %s, 16-bits training: %s",
        args.device,
        args.n_gpu,
        args.fp16,
    )

    # initialize wandb run
    if args.wandb_entity and args.wandb_project and args.wandb_run_name:
        wandb_run = wandb.init(
            entity=args.wandb_entity,
            project=args.wandb_project,
            name=args.wandb_run_name,
        )
        wandb_run.summary["start_time"] = str(datetime.now())
    else:
        wandb_run = None

    set_seed(args.seed)

    # Initialize the model and tokenizer
    if args.model_type is None:
        from transformers import AutoConfig
        _config = AutoConfig.from_pretrained(args.model_name_or_path)
        args.model_type = _config.model_type

    if args.tuning_mode == 'finetune':
        print(args.tuning_mode, args.model_type, args.model_name_or_path)
        try:
            args.model_type = args.model_type.lower()
            model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
        except KeyError:
            raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(args.model_type))

        if args.model_name_or_path:
            print('loading the trained tokenizer')
            tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
        elif args.tokenizer_name:
            print('loading from the init tokenizer')
            tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)

        # tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

        print(len(tokenizer), tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token)
        config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
        config.use_cache = True
        print(config)
        model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=args.cache_dir)
        model.to(args.device)
        gpt2 = model

    elif args.tuning_mode == 'adaptertune':
        print(args.tuning_mode, args.model_name_or_path)

        try:
            args.model_type = args.model_type.lower()
            _, tokenizer_class = MODEL_CLASSES[args.model_type]
        except KeyError:
            raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(args.model_type))

        if args.model_name_or_path:
            print('loading the trained tokenizer')
            tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
        elif args.tokenizer_name:
            print('loading from the init tokenizer')
            tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)

        print(len(tokenizer), tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token)
        config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
        config.use_cache = True
        print(config)
        model = GPT2LMHeadModelAdapter.from_pretrained(
            args.model_name_or_path,
            config=config,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            cache_dir=args.cache_dir,
        )

        model.to(args.device)
        args.tuning_mode = 'finetune'

    elif args.tuning_mode == 'bothtune':
        print(args.tuning_mode, args.model_name_or_path, args.prefixModel_name_or_path)
        try:
            args.model_type = args.model_type.lower()
            model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
        except KeyError:
            raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(args.model_type))

        if args.prefixModel_name_or_path:
            print('loading the trained tokenizer')
            tokenizer = tokenizer_class.from_pretrained(args.prefixModel_name_or_path, cache_dir=args.cache_dir)
        elif args.tokenizer_name:
            print('loading from the init tokenizer')
            assert False, "should load from the prefixModel_name_or_path tokenizer"
            tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)

            # tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

        print(len(tokenizer), tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token)
        config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
        config.use_cache = True
        print(config)
        model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=args.cache_dir)
        model.to(args.device)
        gpt2 = model


        print('loading from PrefixTuning.', args.prefixModel_name_or_path, )
        if args.optim_prefix == 'yes':
            optim_prefix_bool = True
        elif args.optim_prefix == 'no':
            optim_prefix_bool = False
        else:
            assert False, "model_args.optim_prefix should be either yes or no"

        if args.prefixModel_name_or_path is not None:
            config = AutoConfig.from_pretrained(args.prefixModel_name_or_path, cache_dir=args.cache_dir)
            config.use_cache = True
            print(config)

            if args.prefix_mode == 'embedding':
                model = PrefixEmbTuning.from_pretrained(
                    args.prefixModel_name_or_path,
                    from_tf=bool(".ckpt" in args.prefixModel_name_or_path, ),
                    config=config,
                    model_gpt2=gpt2, optim_prefix=optim_prefix_bool, preseqlen=args.preseqlen,
                    use_infix=(args.format_mode == 'infix')
                )

            elif args.prefix_mode == 'activation':

                model = PrefixTuning.from_pretrained(
                    args.prefixModel_name_or_path,
                    from_tf=bool(".ckpt" in args.prefixModel_name_or_path, ),
                    config=config,
                    model_gpt2=gpt2, optim_prefix=optim_prefix_bool, preseqlen=args.preseqlen,
                    use_infix=(args.format_mode == 'infix')
                )

            model.to(args.device)




    elif args.tuning_mode == 'prefixtune':

        print('loading from PrefixTuning.', args.prefixModel_name_or_path,)
        if args.model_name_or_path:
            config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
            config.use_cache = True
        else:
            assert False, 'should not init config from scratch.'
            config = CONFIG_MAPPING[args.model_type]()
            config.use_cache = True
            logger.warning("You are instantiating a new config instance from scratch.")

        try:
            args.model_type = args.model_type.lower()
            model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
        except KeyError:
            raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(args.model_type))

        if args.model_name_or_path:
            print('loading the trained tokenizer')
            tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
        elif args.tokenizer_name:
            print('loading from the init tokenizer')
            tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)

        # TODAYFIX.
        config._my_arg_tune_mode = args.tuning_mode
        config._my_arg_task_mode = args.task_mode
        config._objective_mode = args.objective_mode
        model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=args.cache_dir)
        model.to(args.device)

        print(len(tokenizer), tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token)

        # TODO LISA
        add_pad = False

        if args.model_name_or_path == 'gpt2-medium':
            if args.task_mode == 'dataless':
                print(args.tuning_mode, 'dataless setting, so no new tokens at all.')
                print('We do not add special tokens to the tokenizer, instead, we just finetune on <|endoftext|>')

                print(tokenizer.eos_token_id)
                print(tokenizer.eos_token)
                print(tokenizer.pad_token_id)
                tokenizer.pad_token = tokenizer.eos_token
                print(tokenizer.pad_token, tokenizer.pad_token_id)

            elif add_pad:
                print('extending the size of word embeddings. to include the [PAD] ')
                num_added_tokens = tokenizer.add_special_tokens(
                    {'pad_token': '[PAD]'})
                embedding_layer = model.resize_token_embeddings(len(tokenizer))
            else:
                print(tokenizer.eos_token_id)
                print(tokenizer.eos_token)
                print(tokenizer.pad_token_id)
                tokenizer.pad_token = tokenizer.eos_token
                print(tokenizer.pad_token, tokenizer.pad_token_id)


            ########################################

        print(len(tokenizer), tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token)


        gpt2 = model

        # config._my_arg_task_mode = args.task_mode
        # config._my_arg_control = True
        # config.train_weights = 'no'
        print(config)
        if args.optim_prefix == 'yes':
            optim_prefix_bool = True
        elif args.optim_prefix == 'no':
            optim_prefix_bool = False
        else:
            assert False, "model_args.optim_prefix should be either yes or no"

        if args.prefixModel_name_or_path is not None:

            #################
            #
            config = AutoConfig.from_pretrained(args.prefixModel_name_or_path, cache_dir=args.cache_dir )
            config.use_cache = True
            print(config)

            if args.prefix_mode == 'embedding':
                model = PrefixEmbTuning.from_pretrained(
                    args.prefixModel_name_or_path,
                    from_tf=bool(".ckpt" in args.prefixModel_name_or_path, ),
                    config=config,
                    model_gpt2=gpt2, optim_prefix=optim_prefix_bool, preseqlen=args.preseqlen,
                    use_infix=(args.format_mode == 'infix')
                )

            elif args.prefix_mode == 'activation':

                model = PrefixTuning.from_pretrained(
                    args.prefixModel_name_or_path,
                    from_tf=bool(".ckpt" in args.prefixModel_name_or_path, ),
                    config=config,
                    model_gpt2=gpt2, optim_prefix=optim_prefix_bool, preseqlen=args.preseqlen,
                    use_infix=(args.format_mode == 'infix')
                )
            #
            ######################

            # model = PrefixTuning.from_pretrained(
            #     args.prefixModel_name_or_path,
            #     from_tf=bool(".ckpt" in args.prefixModel_name_or_path,),
            #     config=config,
            #     model_gpt2=gpt2, optim_prefix=optim_prefix_bool, preseqlen=args.preseqlen,
            # )
            model.to(args.device)

            # print('-'*100)
            # print(model.training)
            # print(gpt2.training)
            # model.train()
            # gpt2.train()
            # print(model.training)
            # print(gpt2.training)
            # model.eval()
            # gpt2.eval()
            # print(model.training)
            # print(gpt2.training)
            # print('-' * 100)

        else:
            assert False, "prefixModel_name_or_path is NONE."



    # if args.fp16:
    #     model.half()

    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
    logger.info(args)

    if args.task_mode == 'data2text':

        QUICK_CHECK = False

        if QUICK_CHECK:

            prompt_text_lst = [
                "name : Blue Spice | Type : coffee shop | area : city centre {}".format(tokenizer.bos_token),
                "name : Blue Spice | Type : coffee shop | customer rating : 5 out of 5 {}".format(tokenizer.bos_token),
                "name : Blue Spice | Type : pub | food : Chinese | area : city centre | family friendly : no {}".format(tokenizer.bos_token),
                "name : Blue Spice | Type : restaurant | food : Chinese | area : city centre | family friendly : yes | near : Rainbow Vegetarian Café {}".format(tokenizer.bos_token),
                "name : Giraffe | Type : restaurant | food : Fast food | area : riverside | family friendly : no | near : Rainbow Vegetarian Café {}".format(tokenizer.bos_token),
                "name : The Cricketers | Type : coffee shop | customer rating : 1 out of 5 | family friendly : yes | near : Avalon {}".format(tokenizer.bos_token),
                "name : The Cricketers | Type : restaurant | food : Chinese | price : high | customer rating : 1 out of 5 | area : city centre | family friendly : no {}".format(tokenizer.bos_token),
                "name : The Mill | Type : restaurant | food : English | price : moderate | area : riverside | family friendly : yes | near : Raja Indian Cuisine {}".format(tokenizer.bos_token),

            ]
            decode_mode = 'beam'

        else:
            # TODO.LISA
            # test_path = '/u/scr/xlisali/e2e_data/contain_near_Type_src1_test.txt'
            if ('lowdata' in args.model_name_or_path) or (args.prefixModel_name_or_path is not None and 'lowdata' in args.prefixModel_name_or_path):
                test_path = '/u/scr/xlisali/e2e_data/src1_valid.txt'
            else:
                test_path = '/u/scr/xlisali/e2e_data/src1_test.txt'

            print('using the test path ', test_path)
            # test_path = '/u/scr/xlisali/e2e_data/src1_valid.txt'
            if args.prefixModel_name_or_path is not None:
                temp = os.path.basename(args.prefixModel_name_or_path)
            else:
                temp = os.path.basename(args.model_name_or_path)

            if 'lowdata' in temp and 'finetune' in temp:
                lowdata_token = temp.split('_t=')[1].split('-checkpoint-')[0]
                print('the LOWDATA token is {}'.format(lowdata_token))
            else:
                lowdata_token = None
            prompt_text_dict = read_e2e_files(test_path, tokenizer, lowdata_token)

            # print(prompt_text_dict)
            prompt_text_lst = list(prompt_text_dict.keys())
            split_file = 'valid'
            decode_mode = 'beam'
            curr_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                    args.gen_dir,
                                    '{}_{}_{}'.format(temp, split_file, decode_mode))
            print(curr_dir)
            gold_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                    args.gen_dir,
                                    '{}_{}_{}'.format(temp, split_file,'gold'))
            print(gold_dir)
            write_e2e_corr(prompt_text_lst, prompt_text_dict, gold_dir)
            src_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                   args.gen_dir,
                                   '{}_{}_{}'.format(temp,split_file, 'src'))
            write_e2e_src(prompt_text_lst, src_dir)
            out_handle = open(curr_dir, 'w')


    elif args.task_mode == 'webnlg' or args.task_mode == 'triples':
        QUICK_CHECK = False
        if args.task_mode == 'webnlg':
            # test_path = "/u/scr/xlisali/WebNLG/webnlg-dataset/release_v2/json/webnlg_release_v2_test.json"
            test_path = "/u/scr/xlisali/WebNLG/webnlg-dataset/webnlg_challenge_2017/test.json"
            prompt_text_dict = read_webnlg_files(test_path, tokenizer)
        elif args.task_mode == 'triples':
            test_path = "/u/scr/xlisali/DART/dart/data/v1.1.1/dart-v1.1.1-full-test.json"
            prompt_text_dict = read_triples_files(test_path, tokenizer)

        if QUICK_CHECK:
            prompt_text_pair = list(prompt_text_dict.keys())[:20]
            prompt_text_lst, prompt_rela_lst = zip(*prompt_text_pair)
            decode_mode = 'beam'

        else:
            prompt_text_pair = list(prompt_text_dict.keys())
            prompt_text_lst, prompt_rela_lst = zip(*prompt_text_pair)
            if args.prefixModel_name_or_path is not None:
                temp = os.path.basename(args.prefixModel_name_or_path)
            else:
                temp = os.path.basename(args.model_name_or_path)
            # print(prompt_text_dict)
            split_file = 'test' # test
            decode_mode = 'beam'
            curr_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                    args.gen_dir,
                                    '{}_{}_{}'.format(temp, split_file, decode_mode))

            print(curr_dir)
            gold_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                    args.gen_dir,
                                    '{}_{}_{}'.format(temp, split_file, 'gold'))

            print(gold_dir)
            write_e2e_corr(prompt_text_pair, prompt_text_dict, gold_dir)
            src_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                    args.gen_dir,
                                    '{}_{}_{}'.format(temp, split_file, 'src'))

            write_e2e_src(prompt_text_pair, src_dir)


            out_handle = open(curr_dir, 'w')

    elif args.task_mode == 'writingPrompts':
        QUICK_CHECK = True
        test_path = "/juice/u/xlisali/WritingPrompts/writingPrompts/test_small.txt"
        prompt_text_dict = read_wp_files(test_path, tokenizer)
        args.num_return_sequences = 1

        if QUICK_CHECK:
            prompt_text_lst = list(prompt_text_dict.keys())[:20]
            print(prompt_text_lst)
            decode_mode = 'nucleus'

        else:
            prompt_text_pair = list(prompt_text_dict.keys())
            prompt_text_lst, prompt_rela_lst = zip(*prompt_text_pair)
            if args.prefixModel_name_or_path is not None:
                temp = os.path.basename(args.prefixModel_name_or_path)
            else:
                temp = os.path.basename(args.model_name_or_path)
            # print(prompt_text_dict)
            split_file = 'test' # test
            decode_mode = 'beam'
            curr_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                    args.gen_dir,
                                    '{}_{}_{}'.format(temp, split_file, decode_mode))

            print(curr_dir)
            gold_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                    args.gen_dir,
                                    '{}_{}_{}'.format(temp, split_file, 'gold'))

            print(gold_dir)
            write_e2e_corr(prompt_text_pair, prompt_text_dict, gold_dir)

            src_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                   args.gen_dir,
                                   '{}_{}_{}'.format(temp, split_file, 'src'))

            write_e2e_src(prompt_text_pair, src_dir)
            out_handle = open(curr_dir, 'w')


    elif args.task_mode == 'sentiment' or args.task_mode == 'topic':
        QUICK_CHECK = False
        args.num_return_sequences = 3

        if QUICK_CHECK:
            prompt_text_lst = [" positive {}".format(tokenizer.bos_token)] * 10  + [" negative {}".format(tokenizer.bos_token)] * 10
            print(prompt_text_lst)
            decode_mode = 'nucleus'

        else:
            #UNCHECKED
            topic_prompt_pplm_lst = ['In summary', 'This essay discusses', 'Views on', 'The connection',
                               'Foundational to this is', 'To review', 'In brief', 'An illustration of', 'Furthermore',
                               'The central theme', 'To conclude', 'The key aspect', 'Prior to this', 'Emphasised are',
                               'To summarize', 'The relationship', 'More importantly', 'It has been shown',
                               'The issue focused on', 'In this essay']

            sent_prompt_pplm_lst = ['Once upon a time', 'The book', 'The chicken', 'The city', 'The country', 'The horse',
                               'The lake', 'The last time']

            if args.task_mode == 'topic':
                pplm_lst = topic_prompt_pplm_lst
                prompt_text_lst = []
                for i in range(len(pplm_lst)):
                    prompt_text_lst.append(" business {} {}".format(tokenizer.bos_token, pplm_lst[i]))
                    prompt_text_lst.append(" sports {} {}".format(tokenizer.bos_token, pplm_lst[i]))
                    prompt_text_lst.append(" science {} {}".format(tokenizer.bos_token, pplm_lst[i]))
                    prompt_text_lst.append(" world {} {}".format(tokenizer.bos_token, pplm_lst[i]))
            else:
                pplm_lst = sent_prompt_pplm_lst
                prompt_text_lst = []
                for i in range(len(pplm_lst)):
                    prompt_text_lst.append(" positive {} {}".format(tokenizer.bos_token, pplm_lst[i]))
                    prompt_text_lst.append(" negative {} {}".format(tokenizer.bos_token, pplm_lst[i]))

            if args.prefixModel_name_or_path is not None:
                temp = os.path.basename(args.prefixModel_name_or_path)
            else:
                temp = os.path.basename(args.model_name_or_path)
            split_file = 'test' # test
            decode_mode = 'nucleus'

            curr_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                    args.gen_dir,
                                    '{}_{}_{}'.format(temp, split_file, decode_mode))
            print(curr_dir)

            src_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                   args.gen_dir,
                                   '{}_{}_{}'.format(temp, split_file, 'src'))


            write_e2e_src(prompt_text_lst, src_dir)
            out_handle = open(curr_dir, 'w')


    elif args.task_mode == 'classify-sentiment' or args.task_mode == 'classify-topic':
        QUICK_CHECK = False
        if args.task_mode == 'classify-sentiment':
            test_path = "/u/scr/xlisali/IMDB/test.txt"
            prompt_text_dict = read_classifySentiment_files(test_path, tokenizer)
        elif args.task_mode == 'classify-topic':
            test_path = "/u/scr/xlisali/contrast_LM/transformers/examples/text-classification/glue_data/AG-news/dev1.tsv"
            prompt_text_dict = read_classifyTopic_files(test_path, tokenizer)

        args.num_return_sequences = 1

        if QUICK_CHECK:
            prompt_text_lst, prompt_text_tgt = zip(*prompt_text_dict)
            prompt_text_lst = prompt_text_lst[:20]
            print(prompt_text_lst)
            decode_mode = 'greedy'

        else:
            #UNCHECKED
            prompt_text_lst, prompt_text_tgt = zip(*prompt_text_dict)
            if args.prefixModel_name_or_path is not None:
                temp = os.path.basename(args.prefixModel_name_or_path)
            else:
                temp = os.path.basename(args.model_name_or_path)
            # print(prompt_text_dict)
            split_file = 'test' # test
            decode_mode = 'greedy'
            curr_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                    args.gen_dir,
                                    '{}_{}_{}'.format(temp, split_file, decode_mode))

            print(curr_dir)
            gold_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                    args.gen_dir,
                                    '{}_{}_{}'.format(temp, split_file, 'gold'))

            print(gold_dir)
            write_e2e_src(prompt_text_tgt, gold_dir)
            src_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
                                   args.gen_dir,
                                   '{}_{}_{}'.format(temp, split_file, 'src'))

            write_e2e_src(prompt_text_lst, src_dir)
            out_handle = open(curr_dir, 'w')

            print('the total length of generation should be {}'.format(len(prompt_text_lst)))




    else: #elif args.task_mode in ['cnndm', 'xsum', 'bioleaflets', 'medparasimp']:
        QUICK_CHECK = False
        if args.task_mode == 'cnndm':
            # test_path = "/u/scr/xlisali/WebNLG/webnlg-dataset/release_v2/json/webnlg_release_v2_test.json"
            test_path = "/u/scr/xlisali/contrast_LM/transformers/examples/seq2seq/cnn_dm/test.source"
            max_source_length = 512
            max_target_length = 142
            args.length = max_target_length
            # prompt_text_dict = read_sum_files(test_path, tokenizer, max_source_len, max_target_len)
        elif args.task_mode == 'xsum':
            test_path = "../data/xsum/test.source"
            max_source_length = 512
            max_target_length = 100
            args.length = max_target_length
            # prompt_text_dict = read_sum_files(test_path, tokenizer, max_source_len, max_target_len)
        elif args.task_mode == 'bioleaflets':
            test_path = "../data/bioleaflets/test.source"
            max_source_length = 512 - 2 - args.preseqlen//2
            max_target_length = 512
            # args.length = max_target_length
        elif args.task_mode == 'medparasimp' or args.task_mode == 'meqsum':
            # NOTE: these tasks are decoded on the validation split; outputs are still written as test_*.txt.
            test_path = f"data/{args.task_mode}/val.source"
            if args.max_source_length < 0:
                max_source_length = 512
            else:
                max_source_length = args.max_source_length
            max_target_length = 512
            # args.length = max_target_length
        else:
            test_path = f"../data/{args.task_mode}/test.source"
            assert os.path.exists(test_path)
            if args.max_source_length < 0:
                max_source_length = 512
            else:
                max_source_length = args.max_source_length
            max_target_length = 1024


        # '<split>.source' -> '<split>.target'
        test_tgt_path = test_path[:-6] + "target"

        # Decoder-only generation needs left padding so that every prompt ends at the same position.
        tokenizer.padding_side = "left"

        print(tokenizer.eos_token_id)
        print(tokenizer.eos_token)
        print(tokenizer.pad_token_id)
        tokenizer.pad_token = tokenizer.eos_token
        print(tokenizer.pad_token, tokenizer.pad_token_id)

        # NOTE: LineByLineSumBatchGenTextDataset.__init__ overrides the pad token again
        # (pad_token = "[PAD]", pad_token_id = 28896).
        dataset = LineByLineSumBatchGenTextDataset(tokenizer=tokenizer, file_path=test_path,
                                                   block_size=1024, bos_tok=tokenizer.bos_token,
                                                   eos_tok=tokenizer.eos_token, max_source_length=max_source_length,
                                                   max_target_length=max_target_length,
                                                   use_task_instruction=args.use_task_instruction)

        data_collator = DataCollatorForSumBatchGenLanguageModeling(
            tokenizer=tokenizer, mlm=False, mlm_probability=0.0, max_source_length=max_source_length,
            max_target_length=max_target_length,
        )

        # prompt_text_pair = list(prompt_text_dict.keys())
        # prompt_text_lst, prompt_rela_lst = zip(*prompt_text_pair)
        if args.prefixModel_name_or_path is not None:
            # temp = os.path.basename(args.prefixModel_name_or_path)
            temp = args.prefixModel_name_or_path
        else:
            # temp = os.path.basename(args.model_name_or_path)
            temp = args.model_name_or_path
        # # print(prompt_text_dict)
        split_file = 'test'  # test
        decode_mode = 'beam'
        # curr_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
        #                         args.gen_dir,
        #                         '{}_{}_{}_batch'.format(temp, split_file, decode_mode))
        os.makedirs(os.path.join(temp, args.gen_dir), exist_ok=True)
        curr_dir = os.path.join(temp, args.gen_dir, '{}_{}.txt'.format(split_file, decode_mode))
        #
        # print(curr_dir)
        # gold_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
        #                         args.gen_dir,
        #                         '{}_{}_{}_batch'.format(temp, split_file, 'gold'))
        gold_dir = os.path.join(temp, args.gen_dir, '{}_{}.txt'.format(split_file, 'gold'))
        #
        # print(gold_dir)
        # write_e2e_corr(prompt_text_pair, prompt_text_dict, gold_dir)
        # src_dir = os.path.join('/u/scr/xlisali/contrast_LM/transformers/examples/text-generation/',
        #                        args.gen_dir,
        #                        '{}_{}_{}'.format(temp, split_file, 'src'))
        #
        # write_e2e_src(prompt_text_pair, src_dir)
        #
        out_handle_beam = open(curr_dir, 'w')
        out_handle_gold = open(gold_dir, 'w')



    if args.control_mode == 'yes':
        print('processing control codes')


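    # Batch processing: iterate over a DataLoader instead of looping over prompts one at a time.
    # dataset, data_collator and the output handles are only defined in the final `else`
    # (summarization) branch above; shuffle=False keeps predictions in reference-file order.
    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        shuffle=False,
        num_workers=4,
        sampler=None,
    )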

    out_lst = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(data_loader)):
            # print(batch)
            # batch = model.transfer_batch_to_device(batch, model.device)
            print(batch_idx)
            # if batch_idx >= 5:
            #     break
            # print(batch['input_ids'].device, model.device)
            out = test_step(model, gpt2, batch, batch_idx, args, tokenizer, beam_handle=out_handle_beam, gold_handle=out_handle_gold, tuning_mode=args.tuning_mode)
            out_lst.append(out)
            for x in out['preds']:
                print(x)
            # batch = model.transfer_batch_to_device(batch, 'cpu')
        result = test_epoch_end(out_lst)

    out_handle_beam.close()
    out_handle_gold.close()

    print('writing the test results to ', curr_dir)
    print('writing the gold results to ', gold_dir)


    # print(result)
    for k, v in result.items():
        if k != 'preds':
            print(k, v)

    import sys
    sys.path.insert(0, '../eval')
    from utils import calculate_rouge, chunks, parse_numeric_n_bool_cl_kwargs, use_task_specific_params

    try:
        print('test_tgt_path', test_tgt_path)
        with open(curr_dir) as f:
            output_lns = [x.rstrip() for x in f]
        with open(test_tgt_path) as f:
            reference_lns = [x.rstrip() for x in f]
        assert len(output_lns) == len(reference_lns), (
            f"{len(output_lns)} predictions vs {len(reference_lns)} references")
        scores = calculate_rouge(output_lns, reference_lns)
        if wandb_run:
            wandb_scores = {f"eval/{k}": scores[k] for k in scores}
            wandb_run.log(wandb_scores)
            wandb_run.summary["finish_time"] = str(datetime.now())
        print(scores)
    except Exception as e:
        # Scoring is best-effort: report the failure instead of silently swallowing it.
        print(f'ROUGE evaluation skipped: {e}')

    return
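

# Sketch (hypothetical helper, not part of the original script): the 'sentiment'
# and 'topic' branches of main() build control-code prompts of the form
# " <label> <bos_token> <prefix>", looping over prefixes in the outer loop and
# labels in the inner loop. build_control_prompts is an illustrative name; it
# only reproduces that string layout and ordering.
from typing import List, Sequence


def build_control_prompts(labels: Sequence[str], prefixes: Sequence[str],
                          bos_token: str) -> List[str]:
    """Return one prompt per (prefix, label) pair, prefix-major order."""
    prompts = []
    for prefix in prefixes:
        for label in labels:
            # Leading space mirrors the original " positive {} {}" format string.
            prompts.append(" {} {} {}".format(label, bos_token, prefix))
    return prompts


# Example: build_control_prompts(["positive", "negative"], ["The book"], "<|endoftext|>")
# gives [" positive <|endoftext|> The book", " negative <|endoftext|> The book"].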


if __name__ == "__main__":
    main()


================================================
FILE: finetune/textgen/gpt2/sum_data_collator.py
================================================
import torch

from dataclasses import dataclass
from torch.nn.utils.rnn import pad_sequence
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy
from transformers.tokenization_utils import PreTrainedTokenizer
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

@dataclass
class DataCollatorForSumLanguageModeling:
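    """
    Data collator for summarization fine-tuning of a causal LM.
    - collates (input_ids, labels, src, tgt) tuples into padded tensors, honoring the tokenizer's pad_token
    - masks padded label positions with -100 so they are ignored by the loss
    - masked language modeling (mlm=True) is not supported
    """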
    tokenizer: PreTrainedTokenizer
    mlm: bool = False
    format_mode: str = 'cat'
    mlm_probability: float = 0.15

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        input_ids, labels, src, tgt = zip(*examples)
        if self.mlm:
            raise NotImplementedError("DataCollatorForSumLanguageModeling only supports mlm=False.")

        # Only the 'cat'/'peek' layout (src + separator + tgt + eos in one sequence) is supported;
        # 'nopeek' and 'infix' are not implemented.
        assert self.format_mode in ('peek', 'cat'), 'should use format_mode = peek or cat.'

        batch = self._tensorize_batch(input_ids)
        labels = self._tensorize_batch(labels)

        # Ignore padded label positions in the loss. NOTE: if pad_token_id equals the EOS id,
        # this also masks the genuine EOS label at the end of each target.
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {"input_ids": batch, "labels": labels}


    def _tensorize_batch(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> torch.Tensor:
        # In order to accept both lists of lists and lists of Tensors
        if isinstance(examples[0], (list, tuple)):
            examples = [torch.tensor(e, dtype=torch.long) for e in examples]
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            return torch.stack(examples, dim=0)
        else:
            if self.tokenizer._pad_token is None:
                raise ValueError(
                    "You are attempting to pad samples but the tokenizer you are using"
                    f" ({self.tokenizer.__class__.__name__}) does not have one."
                )
            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
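

# Sketch (hypothetical helper, not part of the original collator): pad the labels
# directly with -100 instead of padding them with pad_token_id and then replacing
# every pad_token_id with -100, as DataCollatorForSumLanguageModeling.__call__ does.
# The two are equivalent unless the pad token shares its id with a real target
# token (for example when pad_token is set to eos_token); in that case the
# replace-after-padding approach also masks the genuine EOS label.
from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_lm_batch(input_ids: List[torch.Tensor], labels: List[torch.Tensor],
                 pad_token_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad inputs with pad_token_id and labels with -100 (ignored by the loss)."""
    batch = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    label_batch = pad_sequence(labels, batch_first=True, padding_value=-100)
    return batch, label_batch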


@dataclass
class DataCollatorForSumBatchGenLanguageModeling:
    """
    Data collator for batched generation with a causal LM.
    - left-pads the raw source strings, truncates them to max_source_length and appends the BOS token as separator
    - tokenizes the reference targets (truncated to max_target_length) and returns them as labels
    """
    tokenizer: PreTrainedTokenizer
    mlm: bool = True
    format_mode: str = 'cat'
    mlm_probability: float = 0.15
    max_source_length: int = 512
    max_target_length: int = 100


    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        # 1: collate raw (source, target) strings; 0: collate pre-tokenized (input_ids, labels, src, tgt) tuples
        mode_gen = 1

        if mode_gen == 0:
            input_ids, labels, src, tgt = zip(*examples)

            src = self._tensorize_batch(src)  # src
            tgt = self._tensorize_batch(tgt)  # tgt

            src_attn = (src != self.tokenizer.pad_token_id)  # src
            tgt_attn = (tgt != self.tokenizer.pad_token_id)  # tgt

            return {"input_ids": src, "labels": tgt, 'src_attn': src_attn, 'tgt_attn': tgt_attn,
                    'src': src}

        else:
            src, tgt = zip(*examples)
            bsz = len(src)
            # Left padding so every source ends at the same position before the separator is appended.
            self.tokenizer.padding_side = "left"
            src = self.tokenizer(src, return_tensors="pt", padding=True, truncation=True, max_length=self.max_source_length)
            tgt = self.tokenizer(tgt, return_tensors="pt", padding=True, truncation=True, max_length=self.max_target_length)
            # Append the BOS token as the source/target separator; generation continues after it.
            bos_seq = torch.full((bsz, 1), self.tokenizer.bos_token_id, dtype=torch.long)
            src_input_ids = torch.cat([src['input_ids'], bos_seq], dim=-1)
            bos_mask = torch.ones(bsz, 1, dtype=torch.long)
            src_mask = torch.cat([src["attention_mask"], bos_mask], dim=-1)

            return {"input_ids": src_input_ids, "labels": tgt['input_ids'], 'src_attn': src_mask,
                    'tgt_attn': tgt["attention_mask"]}




    def _tensorize_batch(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> torch.Tensor:
        # In order to accept both lists of lists and lists of Tensors
        if isinstance(examples[0], (list, tuple)):
            examples = [torch.tensor(e, dtype=torch.long) for e in examples]
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            return torch.stack(examples, dim=0)
        else:
            if self.tokenizer._pad_token is None:
                raise ValueError(
                    "You are attempting to pad samples but the tokenizer you are using"
                    f" ({self.tokenizer.__class__.__name__}) does not have one."
                )
            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
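

# Sketch (hypothetical helper): what DataCollatorForSumBatchGenLanguageModeling does
# to the source side, written without a tokenizer so it works on plain id lists.
# For batched generation with a decoder-only model, sources are padded on the left
# so every row ends at the same position; a separator id (the BOS token in the
# collator above) is then appended so generation starts right after it. The
# attention mask is 0 on padding and 1 on real tokens and on the separator.
from typing import List, Tuple

import torch


def left_pad_and_append_sep(seqs: List[List[int]], pad_id: int,
                            sep_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    max_len = max(len(s) for s in seqs)
    input_ids = torch.full((len(seqs), max_len + 1), pad_id, dtype=torch.long)
    attn = torch.zeros((len(seqs), max_len + 1), dtype=torch.long)
    for row, s in enumerate(seqs):
        start = max_len - len(s)  # number of left-pad positions in this row
        input_ids[row, start:max_len] = torch.tensor(s, dtype=torch.long)
        attn[row, start:] = 1  # real tokens plus the separator column
    input_ids[:, -1] = sep_id
    return input_ids, attn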



================================================
FILE: finetune/textgen/gpt2/sum_dataset.py
================================================
import os
import pickle
import random
import time
import copy
import json
from typing import Dict, List, Optional
import ast
import torch
from torch.utils.data.dataset import Dataset

from filelock import FileLock

from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging

from pathlib import Path
import linecache

# from transformers import BertTokenizer, BertForMaskedLM, BertModel, BertTokenizerFast
# from transformers import BertTokenizer,  BertTokenizerFast
logger = logging.get_logger(__name__)
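

# Sketch (hypothetical helper, not used by the classes below): both datasets derive
# the target file from the source file with file_path[:-6] + 'target', which is only
# correct when the path ends in 'source' (e.g. val.source -> val.target). This
# variant checks the suffix explicitly instead of silently producing a wrong path.
def derive_target_path(source_path: str) -> str:
    suffix = "source"
    if not source_path.endswith(suffix):
        raise ValueError(f"expected a path ending in '{suffix}', got {source_path!r}")
    return source_path[: -len(suffix)] + "target"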


class LineByLineSumTextDataset(Dataset):
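    """
    Summarization dataset for causal-LM fine-tuning, read line by line from paired
    `<split>.source` / `<split>.target` files. Each item is (input_ids, labels, src, tgt),
    where labels mask the source (and the separator, if used) with -100.
    """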

    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, bos_tok:str, eos_tok:str,
                 max_source_length:int, max_target_length:int, seq_prefix:str="", no_sep:bool=False, use_task_instruction:int=0, use_stream_mode:bool=True):
        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
        logger.info("Creating features from dataset file at %s", file_path)

        self.src_file = file_path
        self.tgt_file = file_path[:-6] + 'target'
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        if use_task_instruction:
            self.instruction = "Summarize the following text: "
        else:
            self.instruction = None
        print (f'Task instruction: "{self.instruction}"')

        separator = tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0]
        eos_idx = tokenizer(eos_tok, add_special_tokens=False)['input_ids'][0]

        self.bos_idx = separator
        self.eos_idx = eos_idx

        with open(self.tgt_file, encoding="utf-8") as f:
            self.length = [len(x) for x in f]  # one entry per target line; only len(self.length) is used
        self.tokenizer = tokenizer

        self.use_stream_mode = use_stream_mode

        self.seq_prefix = seq_prefix
        self.no_sep = no_sep

        if self.use_stream_mode:
            return
        else:
            src_lines = []
            with open(self.src_file, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    # Skip blank lines before prepending the instruction, mirroring the target-side filter.
                    if len(line) > 0 and not line.isspace():
                        line = self.instruction + line if self.instruction else line
                        src_lines.append(line)
            print(len(src_lines))
            with open(self.tgt_file, encoding="utf-8") as f:
                tgt_lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

            print(self.tgt_file, len(tgt_lines), '\n', self.src_file, len(src_lines))

            assert len(tgt_lines) == len(src_lines)

            src_encoding = tokenizer(src_lines, add_special_tokens=True, truncation=True, max_length=max_source_length,
                                     is_split_into_words=False)['input_ids']

            tgt_encoding = tokenizer(tgt_lines, add_special_tokens=True, truncation=True, max_length=max_target_length,
                                     is_split_into_words=False)['input_ids']

            # separator and eos_idx were already computed at the top of __init__.
            assert len(src_encoding) == len(tgt_encoding)

            edited_sents = []
            for src, tgt in zip(src_encoding, tgt_encoding):
                sent = src + [separator] + tgt + [eos_idx]
                # sent = ' {} {} '.format(src, bos_tok) + tgt + ' {}'.format(eos_tok)
                edited_sents.append(sent)

            # batch_encoding = tokenizer(edited_sents, add_special_tokens=True, truncation=True, max_length=block_size,
            #                                                       is_split_into_words=False)

            self.examples = edited_sents

            self.labels = copy.deepcopy(self.examples)



            self.src_sent = []
            self.tgt_sent = []
            for i, elem in enumerate(self.labels):
                # Position just after the separator; everything before it is masked with -100.
                sep_idx = elem.index(separator) + 1
                self.src_sent.append(self.examples[i][:sep_idx-1])
                self.tgt_sent.append(self.examples[i][sep_idx-1:])
                self.labels[i][:sep_idx] = [-100] * sep_idx


            print(self.labels[0])
            print(self.examples[0])
            print(edited_sents[0])
            print(self.src_sent[0])
            print(self.tgt_sent[0])
            # assert len(self.src_cat) == len(self.examples)




    def __len__(self):
        return len(self.length)


    def __getitem__(self, i):
        if not self.use_stream_mode:
            return (torch.tensor(self.examples[i], dtype=torch.long),
                    torch.tensor(self.labels[i], dtype=torch.long),
                    torch.tensor(self.src_sent[i], dtype=torch.long),
                    torch.tensor(self.tgt_sent[i], dtype=torch.long),
                    )
        else:
            index = i + 1  # linecache starts at 1
            source_line = linecache.getline(str(self.src_file), index).rstrip("\n")
            tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
            assert source_line, f"empty source line for index {index}"
            assert tgt_line, f"empty tgt line for index {index}"

            source_line = self.instruction + source_line if self.instruction else self.seq_prefix + source_line

            src = self.tokenizer(source_line, add_special_tokens=True, truncation=True, max_length=self.max_source_length,
                                     is_split_into_words=False)['input_ids']

            tgt = self.tokenizer(tgt_line, add_special_tokens=True, truncation=True, max_length=self.max_target_length,
                                     is_split_into_words=False)['input_ids']

            if self.no_sep:
                sent = src + tgt + [self.eos_idx]
                label = copy.deepcopy(sent)
                label[:len(src)] = [-100] * len(src)
                src_sent = sent[:len(src)]
                tgt_sent = sent[len(src):]
            else:
                sent = src + [self.bos_idx] + tgt + [self.eos_idx]
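                # Position just after the separator; unlike sent.index(self.bos_idx), this stays
                # correct even if the separator id also occurs inside src.
                sep_idx = len(src) + 1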
                label = copy.deepcopy(sent)
                label[:sep_idx] = [-100] * sep_idx
                src_sent = sent[:sep_idx - 1]
                tgt_sent = sent[sep_idx - 1:]

            return (torch.tensor(sent, dtype=torch.long),
                    torch.tensor(label, dtype=torch.long),
                    torch.tensor(src_sent, dtype=torch.long),
                    torch.tensor(tgt_sent, dtype=torch.long),
                    )
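

# Sketch (hypothetical helper): the example layout built in
# LineByLineSumTextDataset.__getitem__ when no_sep is False. The sequence is
# src + [sep] + tgt + [eos]; labels copy it but replace the source and the
# separator with -100, so the loss only covers the target tokens and the EOS token.
from typing import List, Tuple


def build_lm_example(src: List[int], tgt: List[int], sep_id: int,
                     eos_id: int) -> Tuple[List[int], List[int]]:
    sent = src + [sep_id] + tgt + [eos_id]
    n_masked = len(src) + 1  # source tokens plus the separator
    labels = [-100] * n_masked + sent[n_masked:]
    return sent, labels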


class LineByLineSumBatchGenTextDataset(Dataset):
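    """
    Generation-time dataset: reads paired `<split>.source` / `<split>.target` files line by line and,
    with modegen = 1, returns raw (source, target) strings for DataCollatorForSumBatchGenLanguageModeling.
    """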

    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, bos_tok:str, eos_tok:str,
                 max_source_length:int, max_target_length:int, use_task_instruction:int=0):
        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
        logger.info("Creating features from dataset file at %s", file_path)

        self.src_file = file_path
        self.tgt_file = file_path[:-6] + 'target'
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        if use_task_instruction:
            self.instruction = "Summarize the following text: "
        else:
            self.instruction = None
        print (f'Task instruction: "{self.instruction}"')

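        separator = tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0]
        # NOTE: the eos_tok argument is overridden here; "[SEP]" is used as the end-of-sequence token.
        eos_tok = "[SEP]"
        eos_idx = tokenizer(eos_tok, add_special_tokens=False)['input_ids'][0]

        self.bos_idx = separator
        self.eos_idx = eos_idx

        # Hard-coded padding token: id 28896 is specific to the tokenizer this script was written for.
        # This mutates the tokenizer passed in and overrides any pad_token set by the caller.
        tokenizer.pad_token = "[PAD]"
        tokenizer.pad_token_id = 28896

        with open(self.tgt_file, encoding="utf-8") as f:
            self.length = [len(x) for x in f]  # one entry per target line; only len(self.length) is used
        self.tokenizer = tokenizer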




    def __len__(self):
        return len(self.length)

    # def __getitem__(self, i) -> torch.Tensor:
    def __getitem__(self, i):
        # return (torch.tensor(self.examples[i], dtype=torch.long),
        #         torch.tensor(self.labels[i], dtype=torch.long),
        #         torch.tensor(self.src_sent[i], dtype=torch.long),
        #         torch.tensor(self.tgt_sent[i], dtype=torch.long),
        #         )

        modegen = 1  # 1: return raw strings (tokenized later by the collator); 0: return token-id tensors
        index = i + 1  # linecache starts at 1
        source_line = linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"

        source_line = self.instruction + source_line if self.instruction else source_line

        if modegen == 0:

            src = self.tokenizer(source_line, add_special_tokens=True, truncation=True, max_length=self.max_source_length,
                                     is_split_into_words=False)['input_ids']

            tgt = self.tokenizer(tgt_line, add_special_tokens=True, truncation=True, max_length=self.max_target_length,
                                     is_split_into_words=False)['input_ids']

            sent = src + [self.bos_idx] + tgt + [self.eos_idx]

            sep_idx = sent.index(self.bos_idx) + 1

            label = copy.deepcopy(sent)
            label[:sep_idx] = [-100] * sep_idx
            src_sent = sent[:sep_idx - 1]
            tgt_sent = sent[sep_idx - 1:]

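            # Four tensors, matching what the collator's mode_gen == 0 branch unpacks.
            return (torch.tensor(sent, dtype=torch.long),
                    torch.tensor(label, dtype=torch.long),
                    torch.tensor(src_sent, dtype=torch.long),
                    torch.tensor(tgt_sent, dtype=torch.long),
                    )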

        else:
            return (source_line, tgt_line)
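

# Sketch (hypothetical helper): the "stream mode" reading used by both datasets.
# linecache.getline returns line i (1-based) of a file and keeps the file's lines
# cached after the first read, so __getitem__ can tokenize one source/target pair
# at a time instead of tokenizing the whole corpus up front. getline returns ''
# for out-of-range indices, which is why the datasets assert both lines are non-empty.
import linecache
from typing import Tuple


def read_pair(src_file: str, tgt_file: str, i: int) -> Tuple[str, str]:
    index = i + 1  # dataset indices are 0-based, linecache is 1-based
    source_line = linecache.getline(src_file, index).rstrip("\n")
    tgt_line = linecache.getline(tgt_file, index).rstrip("\n")
    if not source_line or not tgt_line:
        raise IndexError(f"missing or empty line {index} in {src_file} / {tgt_file}")
    return source_line, tgt_line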



================================================
FILE: finetune/utils/custom_modeling_gpt2.py
================================================
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss


from transformers.activations import ACT2FN
from transformers.file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
    MultipleChoiceModelOutput,
)
from transformers.modeling_utils import (
    Conv1D,
    PreTrainedModel,
    SequenceSummary,
    find_pruneable_heads_and_indices,
    prune_conv1d_layer,
)
from transformers.utils import logging
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers.models.gpt2.configuration_gpt2 import GPT2Config


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "gpt2"
_CONFIG_FOR_DOC = "GPT2Config"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"

GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "distilgpt2",
    # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
]
from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2PreTrainedModel
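

# Sketch (hypothetical helper, not part of this module): the head computation of
# GPT2ForTokenClassification below, written with plain torch and assuming its loss
# follows the standard Hugging Face token-classification pattern. Dropout and a
# linear layer map every hidden state to num_labels logits, and cross-entropy is
# taken over the flattened (batch * seq_len) positions. CrossEntropyLoss ignores
# targets equal to -100 by default, so positions labelled -100 do not contribute.
from typing import Tuple

import torch
from torch import nn


def token_classification_loss(hidden_states: torch.Tensor, dropout: nn.Dropout,
                              classifier: nn.Linear,
                              labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    logits = classifier(dropout(hidden_states))  # (batch, seq_len, num_labels)
    loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))
    return loss, logits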


class GPT2ForTokenClassification(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = GPT2Model(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. A cross-entropy loss is computed over all positions; positions labeled `-100`
            are skipped by `CrossEntropyLoss`'s default `ignore_index`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
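

# Illustrative sketch (hypothetical helper, not used by the classes in this file).
# GPT2ForTokenClassification flattens logits of shape (batch, seq_len, num_labels)
# to (batch * seq_len, num_labels), and labels of shape (batch, seq_len) to
# (batch * seq_len,), before it applies CrossEntropyLoss. The plain-Python version
# below uses nested lists in place of tensors. It mirrors CrossEntropyLoss's
# default ignore_index (-100) and its default "mean" reduction over the positions
# that are not ignored. If every position is ignored, this sketch returns NaN.
def _example_token_classification_loss(logits, labels, ignore_index=-100):
    import math

    # Flatten (batch, seq_len, num_labels) -> (batch * seq_len, num_labels).
    flat_logits = [row for sequence in logits for row in sequence]
    # Flatten (batch, seq_len) -> (batch * seq_len,).
    flat_labels = [label for sequence in labels for label in sequence]

    total, count = 0.0, 0
    for row, label in zip(flat_logits, flat_labels):
        if label == ignore_index:
            continue  # ignored positions add neither loss nor count
        # Numerically stable log-sum-exp for the softmax denominator.
        max_logit = max(row)
        log_norm = max_logit + math.log(sum(math.exp(x - max_logit) for x in row))
        total += log_norm - row[label]
        count += 1
    return total / count if count else float("nan")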


class GPT2ForMultipleChoice(GPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

    def __init__(self, config):
        super().__init__(config)
        # self.num_labels = config.num_labels
        if config.use_flash:
            print("GPT2ForMultipleChoice using Flash !!")
            from .hf_flash_gpt_2 import GPT2FlashModel
            self.transformer = GPT2FlashModel(config)
        elif config.use_gpt_neo:
            print("Using GPT2Neo Model !!")
            from .custom_modeling_gpt_neo import GPTNeoModel
            self.transformer = GPTNeoModel(config)
        else:
            self.transformer = GPT2Model(config)
            print("GPT2ForMultipleChoice not using Flash !!")
        # self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
        hidden_size = config.hidden_size if config.use_gpt_neo else config.n_embd
        self.classifier = nn.Linear(hidden_size, 1)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in :obj:`[0, ...,
            num_choices - 1]`, where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None:
            batch_size, num_choices, sequence_length = input_ids.shape[:3]
        else:
            batch_size, num_choices, sequence_length = inputs_embeds.shape[:3]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.classifier(hidden_states)  # [batch * num_choices, seq_len, 1]

        assert (
            self.config.pad_token_id is not None
        ), "GPT2ForMultipleChoice requires config.pad_token_id to locate the last non-padding token of each choice."
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[range(batch_size * num_choices), sequence_lengths]  # [batch * num_choices, 1]
        reshaped_logits = pooled_logits.view(-1, num_choices) #[batch, num_choices]

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            # hidden_states=transformer_outputs.hidden_states,
            # attentions=transformer_outputs.attentions,
        )
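

# Illustrative sketch (hypothetical helpers, not used by the classes in this file).
# GPT2ForMultipleChoice works in four steps:
#   1. It flattens input_ids from (batch, num_choices, seq_len) to
#      (batch * num_choices, seq_len).
#   2. It scores every position with a 1-unit linear head.
#   3. It keeps the score at the last non-padding token of each row.
#   4. It reshapes the result back to (batch, num_choices).
# Assumptions: nested lists stand in for tensors, and padding is on the right.
# Right padding matters because the index is computed as
# "number of non-pad tokens - 1", as in the forward pass above.
def _example_last_token_indices(flat_input_ids, pad_token_id):
    # One index per flattened row: count of non-pad tokens minus one.
    return [sum(1 for tok in row if tok != pad_token_id) - 1 for row in flat_input_ids]


def _example_pool_and_reshape(per_token_scores, last_indices, num_choices):
    # per_token_scores: (batch * num_choices, seq_len) scalar score per position.
    pooled = [scores[idx] for scores, idx in zip(per_token_scores, last_indices)]
    # Regroup consecutive rows into (batch, num_choices).
    return [pooled[i : i + num_choices] for i in range(0, len(pooled), num_choices)]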


class GPT2ForSequenceClassification(GPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        if config.use_flash:
            print("GPT2ForSequenceClassification using Flash !!")
            from .hf_flash_gpt_2 import GPT2FlashModel
            self.transformer = GPT2FlashModel(config)
        else:
            self.transformer = GPT2Model(config)

        self.classifier = nn.Linear(config.n_embd, self.num_labels, bias=False)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.classifier(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        assert (
            self.config.pad_token_id is not None or batch_size == 1
        ), "Cannot handle batch sizes > 1 if no padding token is defined."
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[range(batch_size), sequence_lengths]

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            # past_key_values=transformer_outputs.past_key_values,
            # hidden_states=transformer_outputs.hidden_states,
            # attentions=transformer_outputs.attentions,
        )
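

# Illustrative sketch (hypothetical helper, not used by the classes in this file).
# GPT2ForSequenceClassification pools the logits at the last non-padding token.
# It then chooses the loss from config.num_labels: mean-squared error when
# num_labels == 1 (regression), and cross-entropy otherwise. Plain lists stand
# in for tensors. Both branches use a mean reduction, matching the MSELoss and
# CrossEntropyLoss defaults.
def _example_sequence_classification_loss(pooled_logits, labels, num_labels):
    import math

    if num_labels == 1:
        # Regression: each pooled_logits row holds a single score.
        errors = [(row[0] - float(label)) ** 2 for row, label in zip(pooled_logits, labels)]
        return sum(errors) / len(errors)

    # Classification: mean negative log-likelihood of the gold label.
    total = 0.0
    for row, label in zip(pooled_logits, labels):
        max_logit = max(row)
        log_norm = max_logit + math.log(sum(math.exp(x - max_logit) for x in row))
        total += log_norm - row[label]
    return total / len(labels)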


================================================
FILE: finetune/utils/custom_modeling_gpt_neo.py
================================================
# coding=utf-8
# Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPT Neo model. torch==4.9.0 """


import os
from typing import Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.models.gpt_neo.configuration_gpt_neo import GPTNeoConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "GPTNeoConfig"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"

GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "EleutherAI/gpt-neo-1.3B",
    # See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo
]

_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B"


def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
    """Load tf checkpoints in a pytorch model"""
    try:
        import re

        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt_neo_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        if "global_step" not in name and "adam" not in name:
            array = tf.train.load_variable(tf_path, name)
            array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy()
            name = name.replace("attn/q", "attn/attention/q_proj/w")
            name = name.replace("attn/k", "attn/attention/k_proj/w")
            name = name.replace("attn/v", "attn/attention/v_proj/w")
            name = name.replace("attn/o", "attn/attention/out_proj/w")
            name = name.replace("norm_1", "ln_1")
            name = name.replace("norm_2", "ln_2")
            name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b")
            name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w")
            name = name.replace("conv1d_main/c_fc/bias", "c_fc/b")
            name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w")
            name = name.replace("conv1d_main/c_proj/bias", "c_proj/b")

            names.append(name)
            arrays.append(array)

    for name, array in zip(names, arrays):
        name = name[5:]  # skip "gpt2/"
        name = name.split("/")
        pointer = model.transformer
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]:
            array = array.transpose()

        if name == ["wte"]:
            # if vocab is padded, then trim off the padding embeddings
            array = array[: config.vocab_size]

        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)

    # init the final linear layer using word embeddings
    embs = model.transformer.wte.weight
    lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False)
    lin.weight = embs
    model.set_output_embeddings(lin)
    return model
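

# Illustrative sketch (hypothetical helper, not used by the code in this file).
# load_tf_weights_in_gpt_neo translates each TensorFlow variable name in four steps:
#   1. It renames the variable.
#   2. It drops the leading "gpt2/" scope.
#   3. It walks the PyTorch module tree one path component at a time, splitting
#      a name such as "h3" into the attribute "h" and the index 3.
#   4. It maps the leaf names "w"/"g" to "weight" and "b" to "bias".
# This helper reproduces only that name-to-path translation, returned as a list
# of attribute names and integer indices. It does not touch TensorFlow or any
# model.
def _example_tf_name_to_module_path(tf_name):
    import re

    replacements = [
        ("attn/q", "attn/attention/q_proj/w"),
        ("attn/k", "attn/attention/k_proj/w"),
        ("attn/v", "attn/attention/v_proj/w"),
        ("attn/o", "attn/attention/out_proj/w"),
        ("norm_1", "ln_1"),
        ("norm_2", "ln_2"),
        ("attn/compute_output_bias/o_b", "attn/attention/out_proj/b"),
        ("conv1d_main/c_fc/kernel", "c_fc/w"),
        ("conv1d_main/c_fc/bias", "c_fc/b"),
        ("conv1d_main/c_proj/kernel", "c_proj/w"),
        ("conv1d_main/c_proj/bias", "c_proj/b"),
    ]
    name = tf_name
    for old, new in replacements:
        name = name.replace(old, new)

    path = []
    for part in name[5:].split("/"):  # skip "gpt2/"
        if re.fullmatch(r"[A-Za-z]+\d+", part):
            # e.g. "h3" -> attribute "h", then index 3
            prefix, index = re.split(r"(\d+)", part)[:2]
            path.extend([prefix, int(index)])
        elif part in ("w", "g"):
            path.append("weight")
        elif part == "b":
            path.append("bias")
        elif part in ("wpe", "wte"):
            path.extend([part, "weight"])
        else:
            path.append(part)
    return path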


class GPTNeoAttentionMixin:
    """
    A few attention related utilities for attention modules in GPT Neo, to be used as a mixin.
    """

    @staticmethod
    def _get_block_length_and_num_blocks(seq_length, window_size):
        """
        Computes ``block_length`` and ``num_blocks`` such that ``seq_length`` becomes evenly divisible by
        ``block_length``.
        """
        block_length = window_size
        while seq_length % block_length != 0:
            block_length -= 1
        num_blocks = seq_length // block_length
        return block_length, num_blocks

    @staticmethod
    def _look_back(tensor, block_length, window_size, pad_value=0, is_key_value=True):
        """
        Used to implement attention between consecutive blocks. This method assumes that dim 1 of :obj:`tensor`
        represents the :obj:`seq_length` dimension. It splits :obj:`seq_length` dimension into :obj:`num_blocks` and
        :obj:`window_size` + :obj:`block_length`. It pads the :obj:`seq_length` dimension if necessary.

        Example::

            tensor: torch.tensor([[[ 0.4983], [ 2.6918], [-0.0071], [ 1.0492], [-1.8348], [ 0.7672], [ 0.2986], [ 0.0285]]])
            with shape (1, 8, 1)
            block_length = window_size = 4
            _look_back =>
            torch.tensor([[[[ 0.0000], [ 0.0000], [ 0.0000], [ 0.0000], [ 0.4983], [ 2.6918], [-0.0071], [ 1.0492]],
                           [[ 0.4983], [ 2.6918], [-0.0071], [ 1.0492], [-1.8348], [ 0.7672], [ 0.2986], [ 0.0285]]]])

        Args:
            tensor (:obj:`torch.Tensor`): tensor of shape :obj:`[batch_size, seq_length, hidden_dim]` or :obj:`[batch_size, seq_length]`
            block_length (:obj:`int`): An integer specifying the length of each block, used as a step size when creating the blocks.
            window_size (:obj:`int`): An integer specifying the size of attention window, used to calculate the final block size when creating the block.
            pad_value (:obj:`int`): An integer specifying the value to use when padding the :obj:`tensor`.
            is_key_value (:obj:`bool`): A boolean indicating if the :obj:`tensor` is a key/value tensor.

        Returns:
            tensor of shape :obj:`[batch_size, num_blocks, window_size + block_length, ...]` if :obj:`is_key_value` is
            :obj:`True` else a tensor of shape :obj:`[batch_size, num_blocks, ..., window_size + block_length]`
        """
        if len(tensor.shape) == 3:
            padding_side = (0, 0, window_size, 0)
        elif len(tensor.shape) == 2:
            padding_side = (window_size, 0)
        else:
            raise ValueError(f"Input tensor rank should be one of [2, 3], but is: {len(tensor.shape)}")

        padded_tensor = nn.functional.pad(tensor, padding_side, value=pad_value)
        padded_tensor = padded_tensor.unfold(dimension=1, size=window_size + block_length, step=block_length)

        if is_key_value:
            padded_tensor = padded_tensor.transpose(-2, -1)
        return padded_tensor

    @staticmethod
    def _split_seq_length_dim_to(tensors, dim_factor_1, dim_factor_2):
        """
        Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims
        """
        batch_size = tensors.shape[0]
        split_dim_shape = (batch_size, dim_factor_1, dim_factor_2)

        if len(tensors.shape) == 3:
            return torch.reshape(tensors, split_dim_shape + (-1,))
        elif len(tensors.shape) == 2:
            return torch.reshape(tensors, split_dim_shape)
        else:
            raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}")

    @staticmethod
    def create_local_attention_mask(batch_size, seq_length, window_size, device, attention_mask=None):
        block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
        indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1)

        query_indices = GPTNeoAttentionMixin._split_seq_length_dim_to(indices, num_blocks, block_length)
        key_indices = GPTNeoAttentionMixin._look_back(indices, block_length, window_size, is_key_value=False)

        # create mask tensor such that each block contains a causal_mask for that block
        causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2))

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)

        # A block can also be padded because of the _look_back operation
        # look back into the attention_block such that it will also get padded the same way
        # and have 0s in the padded position
        attention_mask = GPTNeoAttentionMixin._look_back(attention_mask, block_length, window_size, is_key_value=False)
        attention_mask = attention_mask.unsqueeze(-2)  # Add an extra dimension to account for hidden_dim

        # Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation)
        # will contain 0s.
        # This also makes sure that other positions ignored by the attention_mask will also be ignored
        # in the causal_mask.
        causal_mask = causal_mask * attention_mask

        # In GPT Neo's local attention each query can attend to at most window_size tokens
        # (itself included); all other tokens are masked out.
        relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1)
        visible = torch.gt(relative_position, -window_size)

        causal_mask = causal_mask * visible
        causal_mask = causal_mask.unsqueeze(-3).bool()  # Add an extra dimension to account for num_heads

        return causal_mask

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(*new_shape)
        if len(tensor.shape) == 5:
            return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
        elif len(tensor.shape) == 4:
            return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        if len(tensor.shape) == 5:
            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
        elif len(tensor.shape) == 4:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(self, query, key, value, causal_mask, masked_bias, attn_dropout, attention_mask=None, head_mask=None):
        # Keep the attention weights computation in fp32 to avoid overflow issues
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        with torch.cuda.amp.autocast(enabled=False):
            attn_weights = torch.matmul(query, key.transpose(-1, -2))
        attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights
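

# Illustrative sketch (hypothetical helpers, not used by the classes in this file).
# These are plain-Python versions of three GPTNeoAttentionMixin ideas, using
# lists instead of tensors and a single unbatched sequence:
#   1. _get_block_length_and_num_blocks shrinks the block from window_size
#      until it divides seq_length evenly.
#   2. _look_back left-pads by window_size, then takes overlapping windows of
#      length window_size + block_length with stride block_length. This is the
#      effect of Tensor.unfold along the sequence dimension.
#   3. create_local_attention_mask lets query i see key j only if j <= i
#      (causal) and i - j < window_size (local window). The helper returns that
#      visibility pattern over the whole sequence rather than per block.
def _example_block_length_and_num_blocks(seq_length, window_size):
    block_length = window_size
    while seq_length % block_length != 0:
        block_length -= 1
    return block_length, seq_length // block_length


def _example_look_back(values, block_length, window_size, pad_value=0):
    padded = [pad_value] * window_size + list(values)
    size = window_size + block_length
    return [padded[start : start + size] for start in range(0, len(padded) - size + 1, block_length)]


def _example_local_visibility(seq_length, window_size):
    return [[(j <= i) and (i - j < window_size) for j in range(seq_length)] for i in range(seq_length)]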


class GPTNeoSelfAttention(nn.Module, GPTNeoAttentionMixin):
    def __init__(self, config):
        super().__init__()

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                1, 1, max_positions, max_positions
            ),
        )
        self.register_buffer("masked_bias", torch.tensor(-1e9))

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.resid_dropout)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        layer_past=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()

        attn_output, attn_weights = self._attn(
            query, key, value, causal_mask, self.masked_bias, self.attn_dropout, attention_mask, head_mask
        )

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
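

# Illustrative sketch (hypothetical helpers, not used by the classes in this file).
# GPTNeoSelfAttention slices rows key_length - query_length .. key_length of its
# lower-triangular "bias" buffer. With a key/value cache of length
# key_length - query_length, each new query therefore still sees every earlier
# key. _attn then replaces masked scores with masked_bias (-1e9) before the
# softmax, which drives their probabilities to zero. Plain lists stand in for
# tensors.
def _example_causal_mask_rows(query_length, key_length):
    past_length = key_length - query_length
    return [[j <= past_length + i for j in range(key_length)] for i in range(query_length)]


def _example_masked_softmax(scores, mask, masked_value=-1e9):
    import math

    # Same effect as torch.where(mask, scores, masked_value) followed by a softmax.
    filled = [s if keep else masked_value for s, keep in zip(scores, mask)]
    max_score = max(filled)
    exps = [math.exp(s - max_score) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]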


class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
    def __init__(self, config):
        super().__init__()

        self.register_buffer("masked_bias", torch.tensor(-1e9))

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.resid_dropout)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.window_size = config.window_size

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        query = self.q_proj(hidden_states)

        if layer_past is not None:
            past = layer_past[0]
            key_value_hidden_states = torch.cat([past, hidden_states], dim=1)
            past_length = past.size()[1]
        else:
            key_value_hidden_states = hidden_states
            past_length = 0

        key = self.k_proj(key_value_hidden_states)
        value = self.v_proj(key_value_hidden_states)

        # compute block length and num_blocks
        batch_size, seq_length = hidden_states.shape[:2]
        full_seq_length = seq_length + past_length
        block_length, num_blocks = self._get_block_length_and_num_blocks(full_seq_length, self.window_size)

        # create buckets
        if layer_past is not None:
            # we just need 1 block with block_length 1 when caching is enabled
            query = self._split_seq_length_dim_to(query, 1, 1)
        else:
            query = self._split_seq_length_dim_to(query, num_blocks, block_length)

        key = self._look_back(key, block_length, self.window_size)
        value = self._look_back(value, block_length, self.window_size)

        # select key/value vectors only for the last block
        if layer_past is not None:
            key = key[:, -1:, ...]
            value = value[:, -1:, ...]

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            # only take the mask for the last block
            attention_mask = attention_mask[:, -1:, :, -1:, :]

        # attn
        attn_output, attn_weights = self._attn(
            query,
            key,
            value,
            causal_mask=attention_mask,
            masked_bias=self.masked_bias,
            attn_dropout=self.attn_dropout,
            head_mask=head_mask,
        )

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = attn_output.reshape(batch_size, seq_length, self.embed_dim)

        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output,)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, (attentions)


class GPTNeoAttention(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.layer_id = layer_id
        self.attention_layers = config.attention_layers
        self.attention_type = self.attention_layers[layer_id]

        if self.attention_type == "global":
            self.attention = GPTNeoSelfAttention(config)
        elif self.attention_type == "local":
            self.attention = GPTNeoLocalSelfAttention(config)
        else:
            raise NotImplementedError(
                "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
                f"{config.attention_layers}. Select attn layer types from ['global', 'local'] only."
            )

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            layer_past=layer_past,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        # cache the hidden_states instead of key_value_states
        # for local attention layer
        if self.attention_type == "local":
            if layer_past is None:
                past = hidden_states
            else:
                past = torch.cat([layer_past[0], hidden_states], dim=1)
            outputs = (outputs[0], (past,)) + outputs[1:]
        return outputs
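

# --- Illustrative sketch, not part of the upstream file ---------------------
# GPTNeoAttention selects its sub-layer with `config.attention_layers[layer_id]`.
# This assumes the upstream Hugging Face GPTNeoConfig convention, in which that
# per-layer list is expanded from a compact spec such as
# [[["global", "local"], 12]] by repeating each pattern `count` times.
# `_sketch_expand_attention_types` is a hypothetical helper that reproduces
# this expansion in plain Python.
def _sketch_expand_attention_types(attention_types):
    layers = []
    for pattern, count in attention_types:
        for _ in range(count):
            layers.extend(pattern)
    return layers


# Example: [[["global", "local"], 2]] expands to
# ["global", "local", "global", "local"], so layer_id=1 builds a local layer.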


class GPTNeoMLP(nn.Module):
    def __init__(self, intermediate_size, config):  # in MLP: intermediate_size= 4 * hidden_size
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = nn.Linear(embed_dim, intermediate_size)
        self.c_proj = nn.Linear(intermediate_size, embed_dim)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_dropout)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
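

# --- Illustrative arithmetic, not part of the upstream file -----------------
# GPTNeoMLP is two nn.Linear layers with biases (c_fc: hidden -> inner,
# c_proj: inner -> hidden). As in GPTNeoBlock, the inner size falls back to
# 4 * hidden_size when `config.intermediate_size` is None.
# `_sketch_mlp_param_count` is a hypothetical helper that counts those weights.
def _sketch_mlp_param_count(hidden_size, intermediate_size=None):
    inner = intermediate_size if intermediate_size is not None else 4 * hidden_size
    c_fc = hidden_size * inner + inner  # weight + bias
    c_proj = inner * hidden_size + hidden_size  # weight + bias
    return c_fc + c_proj


# Example: hidden_size=768 (inner size 3072) gives 4,722,432 parameters per MLP.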


class GPTNeoBlock(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTNeoAttention(config, layer_id)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTNeoMLP(inner_dim, config)

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present (only if use_cache), (attentions)

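
# --- Illustrative sketch, not part of the upstream file ---------------------
# Tuple layout returned by GPTNeoBlock.forward above. This assumes, as in the
# upstream code, that the attention sub-layer always fills a "present" slot at
# index 1 (None for global attention without caching). The block keeps that
# slot only when use_cache=True. GPT-Neo has no cross-attention, so no
# cross_attentions entry appears. `_sketch_block_output_fields` is a
# hypothetical helper that returns field names instead of tensors.
def _sketch_block_output_fields(use_cache, output_attentions):
    attn_tail = ("present",) + (("attentions",) if output_attentions else ())
    if use_cache:
        return ("hidden_states",) + attn_tail
    return ("hidden_states",) + attn_tail[1:]
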

class GPTNeoPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTNeoConfig
    load_tf_weights = load_tf_weights_in_gpt_neo
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear,)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
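

# --- Illustrative sketch, not part of the upstream file ---------------------
# Plain-Python sketch of two branches of _init_weights above. These are
# hypothetical helpers. std=0.02 and eps=1e-5 are assumed to match the
# GPTNeoConfig defaults for initializer_range and layer_norm_epsilon.
# - nn.Embedding: entries ~ N(0, std), with the padding row zeroed.
# - nn.LayerNorm: weight=1 and bias=0, so at initialization it only normalizes
#   each vector to zero mean and (almost) unit variance. The real module
#   holds per-feature weight/bias vectors; scalars are used here for brevity.
def _sketch_init_embedding(num_embeddings, dim, std=0.02, padding_idx=None, seed=0):
    import random

    rng = random.Random(seed)
    table = [[rng.gauss(0.0, std) for _ in range(dim)] for _ in range(num_embeddings)]
    if padding_idx is not None:
        table[padding_idx] = [0.0] * dim
    return table


def _sketch_layer_norm(x, eps=1e-5, weight=1.0, bias=0.0):
    import math

    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance, like nn.LayerNorm
    return [weight * (v - mean) / math.sqrt(var + eps) + bias for v in x]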


GPT_NEO_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
    general usage and behavior.

    Parameters:
        config (:class:`~transformers.GPTNeoConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

GPT_NEO_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
            ``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
            passed as ``input_ids``.

            Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.num_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
            have their past given to this model should not be passed as ``input_ids`` as they have already been
            computed.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.

            If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` need to be passed (see
            :obj:`past_key_values`).
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

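
# --- Illustrative sketch, not part of the upstream file ---------------------
# Two input conventions from GPT_NEO_INPUTS_DOCSTRING, in plain Python. These
# are hypothetical helpers.
# 1) A 0/1 attention_mask (1 = attend, 0 = masked) is commonly turned into an
#    additive bias on the attention scores. The -10000.0 constant is an
#    assumption modeled on the upstream Hugging Face GPT-Neo code of this era.
# 2) With past_key_values, only tokens that are not yet cached are passed, so
#    their position ids are assumed to continue from the cached length.
def _sketch_additive_mask(attention_mask, large_negative=-10000.0):
    return [(1.0 - m) * large_negative for m in attention_mask]


def _sketch_position_ids(past_length, num_new_tokens):
    return list(range(past_length, past_length + num_new_tokens))


# Example: one new token after 5 cached ones gets position id 5.
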

@add_start_docstrings(
    "The bare GPT Neo Model transformer outputting raw hidden-states without any specific head on top.",
    GPT_NEO_START_DOCSTRING,
)
class GPTNeoModel(GPTNeoPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = co
SYMBOL INDEX (118 symbols across 11 files)

FILE: finetune/mc/preprocess_medqa.py
  function dump_jsonl (line 13) | def dump_jsonl(data, fpath):
  function process_medqa (line 18) | def process_medqa(fname):

FILE: finetune/mc/run_multiple_choice.py
  class ModelArguments (line 61) | class ModelArguments:
  class DataTrainingArguments (line 105) | class DataTrainingArguments:
    method __post_init__ (line 160) | def __post_init__(self):
  class DataCollatorForMultipleChoice (line 172) | class DataCollatorForMultipleChoice:
    method __call__ (line 200) | def __call__(self, features):
  function main (line 225) | def main():
  function _mp_fn (line 520) | def _mp_fn(index):

FILE: finetune/seqcls/preprocess_blurb_seqcls.py
  function dump_jsonl (line 11) | def dump_jsonl(data, fpath):
  function process_pubmedqa (line 22) | def process_pubmedqa(fname):
  function process_bioasq (line 50) | def process_bioasq(fname):

FILE: finetune/seqcls/run_seqcls_gpt.py
  class DataTrainingArguments (line 79) | class DataTrainingArguments:
    method __post_init__ (line 156) | def __post_init__(self):
  class ModelArguments (line 175) | class ModelArguments:
  function main (line 213) | def main():
  function _mp_fn (line 632) | def _mp_fn(index):

FILE: finetune/textgen/gpt2/finetune_for_summarization.py
  class ModelArguments (line 32) | class ModelArguments:
  class DataArguments (line 56) | class DataArguments:
  function get_dataset (line 96) | def get_dataset(
  function finetune (line 121) | def finetune():

FILE: finetune/textgen/gpt2/run_generation_batch.py
  function prepare_ctrl_input (line 115) | def prepare_ctrl_input(args, _, tokenizer, prompt_text):
  function prepare_xlm_input (line 125) | def prepare_xlm_input(args, model, tokenizer, prompt_text):
  function prepare_xlnet_input (line 151) | def prepare_xlnet_input(args, _, tokenizer, prompt_text):
  function prepare_transfoxl_input (line 157) | def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
  function read_e2e_files (line 170) | def read_e2e_files(path, tokenizer, lowdata_token=None):
  function read_wp_files (line 187) | def read_wp_files(path, tokenizer):
  function read_classifySentiment_files (line 199) | def read_classifySentiment_files(path, tokenizer):
  function read_classifyTopic_files (line 209) | def read_classifyTopic_files(path, tokenizer):
  function lmap (line 232) | def lmap(f, x):
  function ids_to_clean_text (line 236) | def ids_to_clean_text(tokenizer, generated_ids):
  function flatten_list (line 244) | def flatten_list(summary_ids):
  function calculate_rouge (line 247) | def calculate_rouge(output_lns, reference_lns, use_stemmer=True):
  function test_epoch_end (line 258) | def test_epoch_end(outputs, prefix="test"):
  function test_step (line 282) | def test_step(model, gpt2, batch, batch_idx, args, tokenizer, beam_handl...
  function read_webnlg_files (line 393) | def read_webnlg_files(path, tokenizer):
  function read_triples_files2 (line 435) | def read_triples_files2(path, tokenizer):
  function read_triples_files (line 473) | def read_triples_files(path, tokenizer):
  function write_e2e_corr (line 515) | def write_e2e_corr(prompt_lst, file_dict, corr_path):
  function write_e2e_src (line 547) | def write_e2e_src(prompt_lst, corr_path):
  function get_emb (line 555) | def get_emb(sent_lst, word_lst, num_layer=1):
  function adjust_length_to_model (line 664) | def adjust_length_to_model(length, max_sequence_length):
  function read_doc_for_embmatch (line 674) | def read_doc_for_embmatch(file_name, num_layer):
  function main (line 688) | def main():

FILE: finetune/textgen/gpt2/sum_data_collator.py
  class DataCollatorForSumLanguageModeling (line 10) | class DataCollatorForSumLanguageModeling:
    method __call__ (line 21) | def __call__(
    method _tensorize_batch (line 65) | def _tensorize_batch(
  class DataCollatorForSumBatchGenLanguageModeling (line 85) | class DataCollatorForSumBatchGenLanguageModeling:
    method __call__ (line 99) | def __call__(
    method _tensorize_batch (line 141) | def _tensorize_batch(

FILE: finetune/textgen/gpt2/sum_dataset.py
  class LineByLineSumTextDataset (line 25) | class LineByLineSumTextDataset(Dataset):
    method __init__ (line 31) | def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, blo...
    method __len__ (line 130) | def __len__(self):
    method __getitem__ (line 134) | def __getitem__(self, i):
  class LineByLineSumBatchGenTextDataset (line 177) | class LineByLineSumBatchGenTextDataset(Dataset):
    method __init__ (line 183) | def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, blo...
    method __len__ (line 218) | def __len__(self):
    method __getitem__ (line 222) | def __getitem__(self, i):

FILE: finetune/utils/custom_modeling_gpt2.py
  class GPT2ForTokenClassification (line 57) | class GPT2ForTokenClassification(GPT2PreTrainedModel):
    method __init__ (line 58) | def __init__(self, config):
    method forward (line 79) | def forward(
  class GPT2ForMultipleChoice (line 137) | class GPT2ForMultipleChoice(GPT2PreTrainedModel):
    method __init__ (line 140) | def __init__(self, config):
    method forward (line 164) | def forward(
  class GPT2ForSequenceClassification (line 253) | class GPT2ForSequenceClassification(GPT2PreTrainedModel):
    method __init__ (line 256) | def __init__(self, config):
    method forward (line 274) | def forward(

FILE: finetune/utils/custom_modeling_gpt_neo.py
  function load_tf_weights_in_gpt_neo (line 53) | def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
  class GPTNeoAttentionMixin (line 137) | class GPTNeoAttentionMixin:
    method _get_block_length_and_num_blocks (line 143) | def _get_block_length_and_num_blocks(seq_length, window_size):
    method _look_back (line 155) | def _look_back(tensor, block_length, window_size, pad_value=0, is_key_...
    method _split_seq_length_dim_to (line 196) | def _split_seq_length_dim_to(tensors, dim_factor_1, dim_factor_2):
    method create_local_attention_mask (line 211) | def create_local_attention_mask(batch_size, seq_length, window_size, d...
    method _split_heads (line 246) | def _split_heads(self, tensor, num_heads, attn_head_size):
    method _merge_heads (line 259) | def _merge_heads(self, tensor, num_heads, attn_head_size):
    method _attn (line 272) | def _attn(self, query, key, value, causal_mask, masked_bias, attn_drop...
  class GPTNeoSelfAttention (line 298) | class GPTNeoSelfAttention(nn.Module, GPTNeoAttentionMixin):
    method __init__ (line 299) | def __init__(self, config):
    method forward (line 327) | def forward(
  class GPTNeoLocalSelfAttention (line 374) | class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
    method __init__ (line 375) | def __init__(self, config):
    method forward (line 398) | def forward(
  class GPTNeoAttention (line 472) | class GPTNeoAttention(nn.Module):
    method __init__ (line 473) | def __init__(self, config, layer_id=0):
    method forward (line 489) | def forward(
  class GPTNeoMLP (line 518) | class GPTNeoMLP(nn.Module):
    method __init__ (line 519) | def __init__(self, intermediate_size, config):  # in MLP: intermediate...
    method forward (line 527) | def forward(self, hidden_states):
  class GPTNeoBlock (line 535) | class GPTNeoBlock(nn.Module):
    method __init__ (line 536) | def __init__(self, config, layer_id):
    method forward (line 545) | def forward(
  class GPTNeoPreTrainedModel (line 583) | class GPTNeoPreTrainedModel(PreTrainedModel):
    method __init__ (line 593) | def __init__(self, *inputs, **kwargs):
    method _init_weights (line 596) | def _init_weights(self, module):
  class GPTNeoModel (line 701) | class GPTNeoModel(GPTNeoPreTrainedModel):
    method __init__ (line 702) | def __init__(self, config):
    method get_input_embeddings (line 714) | def get_input_embeddings(self):
    method set_input_embeddings (line 717) | def set_input_embeddings(self, new_embeddings):
    method forward (line 727) | def forward(
  class GPTNeoForCausalLM (line 900) | class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
    method __init__ (line 904) | def __init__(self, config):
    method get_output_embeddings (line 911) | def get_output_embeddings(self):
    method set_output_embeddings (line 914) | def set_output_embeddings(self, new_embeddings):
    method prepare_inputs_for_generation (line 917) | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
    method forward (line 952) | def forward(
    method _reorder_cache (line 1021) | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.T...
  class GPTNeoForSequenceClassification (line 1048) | class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
    method __init__ (line 1051) | def __init__(self, config):
    method forward (line 1066) | def forward(

FILE: finetune/utils/hf_flash_gpt_2.py
  class GPT2FlashAttention (line 37) | class GPT2FlashAttention(GPT2Attention):
    method __init__ (line 38) | def __init__(self, config, is_cross_attention=False, layer_idx=None):
    method _attn (line 42) | def _attn(self, query, key, value, attention_mask=None, head_mask=None):
  class GPT2FlashBlock (line 72) | class GPT2FlashBlock(GPT2Block):
    method __init__ (line 73) | def __init__(self, config, layer_idx=None):
  class GPT2FlashModel (line 89) | class GPT2FlashModel(GPT2Model):
    method __init__ (line 90) | def __init__(self, config):
  class GPT2FlashLMHeadModel (line 111) | class GPT2FlashLMHeadModel(GPT2LMHeadModel):
    method __init__ (line 112) | def __init__(self, config):

About this extraction

This page contains the full source code of the stanford-crfm/pubmedgpt GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 36 files (214.5 KB), approximately 51.7k tokens, and a symbol index with 118 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
