Full Code of Yueeeeeeee/LlamaRec for AI

Repository: Yueeeeeeee/LlamaRec
Branch: main
Commit: 48b288b23197
Files: 32
Total size: 159.5 KB

Directory structure:
gitextract_v751wz7n/

├── .gitignore
├── README.md
├── config.py
├── dataloader/
│   ├── __init__.py
│   ├── base.py
│   ├── llm.py
│   ├── lru.py
│   ├── templates/
│   │   ├── README.md
│   │   ├── alpaca.json
│   │   ├── alpaca_legacy.json
│   │   ├── alpaca_short.json
│   │   └── vigogne.json
│   └── utils.py
├── datasets/
│   ├── __init__.py
│   ├── base.py
│   ├── beauty.py
│   ├── games.py
│   ├── ml_100k.py
│   └── utils.py
├── model/
│   ├── __init__.py
│   ├── llm.py
│   └── lru.py
├── requirements.txt
├── train_ranker.py
├── train_retriever.py
└── trainer/
    ├── __init__.py
    ├── base.py
    ├── llm.py
    ├── loggers.py
    ├── lru.py
    ├── utils.py
    └── verb.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.pyc
*.p
*.pt
*.pth
.DS_Store

/.ipynb_checkpoints/*
/.vscode/*
/wandb/*

/data/*
/retrieved/*
/experiments/*
/archive/*

================================================
FILE: README.md
================================================
# LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking

This repository is the PyTorch implementation of the PGAI@CIKM 2023 paper **LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking [[Paper](https://arxiv.org/abs/2311.02089)]**.

<img src=media/method.png width=1000>

We propose a two-stage framework using large language models for ranking-based recommendation (LlamaRec). In particular, we use small-scale sequential recommenders to retrieve candidates based on the user interaction history. Then, both the history and the retrieved items are fed to the LLM as text via a carefully designed prompt template. Instead of generating next-item titles, we adopt a verbalizer-based approach that transforms output logits into probability distributions over the candidate items. Therefore, LlamaRec can efficiently rank items without generating long text and achieves superior recommendation performance and efficiency.
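
The verbalizer idea described above can be illustrated with a minimal sketch. It assumes each candidate is tagged with an index letter whose token id is known, and that only the logits of the single next-token position are needed. The function name `verbalizer_rank` and the toy vocabulary are hypothetical and are not taken from this repository's trainer code.

```python
import math


def verbalizer_rank(last_token_logits, candidate_token_ids):
    """Rank candidates from the logits of one next-token position.

    last_token_logits: list of floats, one per vocabulary entry.
    candidate_token_ids: vocabulary ids of the index letters (A, B, C, ...).
    """
    # Keep only the logits that belong to the candidate index letters.
    scores = [last_token_logits[t] for t in candidate_token_ids]
    # Numerically stable softmax restricted to the candidates.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Candidate positions sorted from most to least likely.
    ranking = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return probs, ranking


# Toy vocabulary of 8 tokens; ids 3, 5 and 6 stand in for letters A, B and C.
logits = [0.1, -1.0, 0.3, 2.0, 0.0, 3.5, 1.0, -0.5]
probs, ranking = verbalizer_rank(logits, candidate_token_ids=[3, 5, 6])
letters = [chr(ord('A') + i) for i in ranking]
print(letters)  # ['B', 'A', 'C']
```

Because a single forward pass yields a score for every candidate, no autoregressive decoding of item titles is required.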


## Requirements

PyTorch, transformers, peft, bitsandbytes, etc. See requirements.txt for our detailed running environment.


## How to run LlamaRec
The command below starts the training of the retriever model LRURec
```bash
python train_retriever.py
```
You can pass additional arguments such as weight_decay to change the hyperparameters. After you run the command, you will be prompted to select a dataset from ML-100k, Beauty, and Games. Once training finishes, evaluation is performed automatically with the best retriever model.

Then, run the following command to train the ranker model based on Llama 2
```bash
python train_ranker.py --llm_retrieved_path PATH_TO_RETRIEVER
```
Replace PATH_TO_RETRIEVER with the retriever output path from the previous step. To run this command, you need access to meta-llama/Llama-2-7b-hf on the Hugging Face Hub. As with the retriever, evaluation is performed once training finishes. All weights and results are saved under ./experiments.
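
For orientation, the ranker's validation and test subsets are built from the retriever scores stored in `retrieved.pkl` under the retriever path (see `dataloader/llm.py`): a user is kept only if the ground-truth item appears among the top `llm_negative_sample_size + 1` retrieved items. The sketch below mirrors that filtering in plain Python. The helper name `build_candidates` is hypothetical, and it assumes one integer label per user.

```python
def build_candidates(scores_per_user, labels, k):
    """Return {user_id: top-k item indices} for users whose label was retrieved."""
    kept = {}
    # User ids start at 1, mirroring enumerate(..., start=1) in dataloader/llm.py.
    for user, (scores, label) in enumerate(zip(scores_per_user, labels), start=1):
        # Indices of the k highest-scoring items, best first.
        top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        if label in top_k:
            kept[user] = top_k
    return kept


scores = [
    [0.1, 0.7, 0.2, 0.0],  # user 1: label 1 is ranked first, so the user is kept
    [0.4, 0.1, 0.3, 0.2],  # user 2: label 3 is outside the top 2, so the user is dropped
]
labels = [1, 3]
candidates = build_candidates(scores, labels, k=2)
print(candidates)  # {1: [1, 2]}
```

With the default `--llm_negative_sample_size 19`, `k` would be 20 in this sketch.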


## Performance

The table below reports our main performance results, with best results marked in bold and second best results underlined. For training and evaluation details, please refer to our paper.

<img src=media/performance.png width=1000>


## Citation
Please consider citing the following papers if you use our methods in your research:
```
@article{yue2023linear,
  title={Linear Recurrent Units for Sequential Recommendation},
  author={Yue, Zhenrui and Wang, Yueqi and He, Zhankui and Zeng, Huimin and McAuley, Julian and Wang, Dong},
  journal={arXiv preprint arXiv:2310.02367},
  year={2023}
}

@article{yue2023llamarec,
  title={LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking},
  author={Yue, Zhenrui and Rabhi, Sara and Moreira, Gabriel de Souza Pereira and Wang, Dong and Oldridge, Even},
  journal={arXiv preprint arXiv:2311.02089},
  year={2023}
}
```


================================================
FILE: config.py
================================================
import numpy as np
import random
import torch
import argparse


RAW_DATASET_ROOT_FOLDER = 'data'
EXPERIMENT_ROOT = 'experiments'
STATE_DICT_KEY = 'model_state_dict'
OPTIMIZER_STATE_DICT_KEY = 'optimizer_state_dict'
PROJECT_NAME = 'llmrec'


def set_template(args):
    if args.dataset_code is None:
        print('******************** Dataset Selection ********************')
        dataset_code = {'1': 'ml-100k', 'b': 'beauty', 'g': 'games'}
        args.dataset_code = dataset_code[input('Input 1 for ml-100k, b for beauty and g for games: ')]

    if args.dataset_code == 'ml-100k':
        args.bert_max_len = 200
    else:
        args.bert_max_len = 50

    if 'llm' in args.model_code: 
        batch = 16 if args.dataset_code == 'ml-100k' else 12
        args.lora_micro_batch_size = batch
    else: 
        batch = 16 if args.dataset_code == 'ml-100k' else 64

    args.train_batch_size = batch
    args.val_batch_size = batch
    args.test_batch_size = batch

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.optimizer = 'AdamW'
    args.lr = 0.001
    args.weight_decay = 0.01
    args.enable_lr_schedule = False
    args.decay_step = 10000
    args.gamma = 1.
    args.enable_lr_warmup = False
    args.warmup_steps = 100

    args.metric_ks = [1, 5, 10, 20, 50]
    args.rerank_metric_ks = [1, 5, 10]
    args.best_metric = 'Recall@10'
    args.rerank_best_metric = 'NDCG@10'

    args.bert_num_blocks = 2
    args.bert_num_heads = 2
    args.bert_head_size = None


parser = argparse.ArgumentParser()

################
# Dataset
################
parser.add_argument('--dataset_code', type=str, default=None)
parser.add_argument('--min_rating', type=int, default=0)
parser.add_argument('--min_uc', type=int, default=5)
parser.add_argument('--min_sc', type=int, default=5)
parser.add_argument('--seed', type=int, default=42)

################
# Dataloader
################
parser.add_argument('--train_batch_size', type=int, default=64)
parser.add_argument('--val_batch_size', type=int, default=64)
parser.add_argument('--test_batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--sliding_window_size', type=float, default=1.0)
parser.add_argument('--negative_sample_size', type=int, default=10)

################
# Trainer
################
# optimization #
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
parser.add_argument('--num_epochs', type=int, default=500)
parser.add_argument('--optimizer', type=str, default='AdamW', choices=['AdamW', 'Adam'])
parser.add_argument('--weight_decay', type=float, default=None)
parser.add_argument('--adam_epsilon', type=float, default=1e-9)
parser.add_argument('--momentum', type=float, default=None)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--max_grad_norm', type=float, default=5.0)
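# NOTE: argparse's type=bool calls bool() on the raw string, so any non-empty
# command-line value (even 'False') parses as True; this applies to every type=bool flag.
parser.add_argument('--enable_lr_schedule', type=bool, default=True)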
parser.add_argument('--decay_step', type=int, default=10000)
parser.add_argument('--gamma', type=float, default=1)
parser.add_argument('--enable_lr_warmup', type=bool, default=True)
parser.add_argument('--warmup_steps', type=int, default=100)

# evaluation #
parser.add_argument('--val_strategy', type=str, default='iteration', choices=['epoch', 'iteration'])
parser.add_argument('--val_iterations', type=int, default=500)  # only for iteration val_strategy
parser.add_argument('--early_stopping', type=bool, default=True)
parser.add_argument('--early_stopping_patience', type=int, default=20)
parser.add_argument('--metric_ks', nargs='+', type=int, default=[1, 5, 10, 20, 50])
parser.add_argument('--rerank_metric_ks', nargs='+', type=int, default=[1, 5, 10])
parser.add_argument('--best_metric', type=str, default='Recall@10')
parser.add_argument('--rerank_best_metric', type=str, default='NDCG@10')
parser.add_argument('--use_wandb', type=bool, default=False)

################
# Retriever Model
################
parser.add_argument('--model_code', type=str, default=None)
parser.add_argument('--bert_max_len', type=int, default=50)
parser.add_argument('--bert_hidden_units', type=int, default=64)
parser.add_argument('--bert_num_blocks', type=int, default=2)
parser.add_argument('--bert_num_heads', type=int, default=2)
parser.add_argument('--bert_head_size', type=int, default=32)
parser.add_argument('--bert_dropout', type=float, default=0.2)
parser.add_argument('--bert_attn_dropout', type=float, default=0.2)
parser.add_argument('--bert_mask_prob', type=float, default=0.25)

################
# LLM Model
################
parser.add_argument('--llm_base_model', type=str, default='meta-llama/Llama-2-7b-hf')
parser.add_argument('--llm_base_tokenizer', type=str, default='meta-llama/Llama-2-7b-hf')
parser.add_argument('--llm_max_title_len', type=int, default=32)
parser.add_argument('--llm_max_text_len', type=int, default=1536)
parser.add_argument('--llm_max_history', type=int, default=20)
parser.add_argument('--llm_train_on_inputs', type=bool, default=False)
parser.add_argument('--llm_negative_sample_size', type=int, default=19)  # 19 negative & 1 positive
parser.add_argument('--llm_system_template', type=str,  # instruction
    default="Given user history in chronological order, recommend an item from the candidate pool with its index letter.")
parser.add_argument('--llm_input_template', type=str,
    default='User history: {}; \n Candidate pool: {}')
parser.add_argument('--llm_load_in_4bit', type=bool, default=True)
parser.add_argument('--llm_retrieved_path', type=str, default=None)
parser.add_argument('--llm_cache_dir', type=str, default=None)

################
# Lora
################
parser.add_argument('--lora_r', type=int, default=8)
parser.add_argument('--lora_alpha', type=int, default=32)
parser.add_argument('--lora_dropout', type=float, default=0.05)
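# nargs='+' keeps each module name intact; type=list would split one string into characters
parser.add_argument('--lora_target_modules', nargs='+', type=str, default=['q_proj', 'v_proj'])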
parser.add_argument('--lora_num_epochs', type=int, default=1)
parser.add_argument('--lora_val_iterations', type=int, default=100)
parser.add_argument('--lora_early_stopping_patience', type=int, default=20)
parser.add_argument('--lora_lr', type=float, default=1e-4)
parser.add_argument('--lora_micro_batch_size', type=int, default=16)

################


args = parser.parse_args()


================================================
FILE: dataloader/__init__.py
================================================
from datasets import dataset_factory

from .lru import *
from .llm import *
from .utils import *


def dataloader_factory(args):
    dataset = dataset_factory(args)
    if args.model_code == 'lru':
        dataloader = LRUDataloader(args, dataset)
    elif args.model_code == 'llm':
        dataloader = LLMDataloader(args, dataset)
    else:
        raise ValueError('Unknown model_code: {}'.format(args.model_code))
    
    train, val, test = dataloader.get_pytorch_dataloaders()
    if 'llm' in args.model_code:
        tokenizer = dataloader.tokenizer
        test_retrieval = dataloader.test_retrieval
        return train, val, test, tokenizer, test_retrieval
    else:
        return train, val, test


def test_subset_dataloader_loader(args):
    dataset = dataset_factory(args)
    if args.model_code == 'lru':
        dataloader = LRUDataloader(args, dataset)
    elif args.model_code == 'llm':
        dataloader = LLMDataloader(args, dataset)
    else:
        raise ValueError('Unknown model_code: {}'.format(args.model_code))

    return dataloader.get_pytorch_test_subset_dataloader()


================================================
FILE: dataloader/base.py
================================================
from abc import *
import random


class AbstractDataloader(metaclass=ABCMeta):
    def __init__(self, args, dataset):
        self.args = args
        self.save_folder = dataset._get_preprocessed_folder_path()
        dataset = dataset.load_dataset()
        self.train = dataset['train']
        self.val = dataset['val']
        self.test = dataset['test']
        self.meta = dataset['meta']
        self.umap = dataset['umap']
        self.smap = dataset['smap']
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)

    @classmethod
    @abstractmethod
    def code(cls):
        pass

    @abstractmethod
    def get_pytorch_dataloaders(self):
        pass


================================================
FILE: dataloader/llm.py
================================================
from .base import AbstractDataloader
from .utils import Prompter

import torch
import random
import numpy as np
import torch.utils.data as data_utils

import os
import pickle
import transformers
from transformers import AutoTokenizer
from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT
from trainer import absolute_recall_mrr_ndcg_for_ks


def worker_init_fn(worker_id):
    random.seed(np.random.get_state()[1][0] + worker_id)
    np.random.seed(np.random.get_state()[1][0] + worker_id)


# the following prompting is based on alpaca
def generate_and_tokenize_eval(args, data_point, tokenizer, prompter):
    in_prompt = prompter.generate_prompt(data_point["system"],
                                         data_point["input"])
    tokenized_full_prompt = tokenizer(in_prompt,
                                      truncation=True,
                                      max_length=args.llm_max_text_len,
                                      padding=False,
                                      return_tensors=None)
    tokenized_full_prompt["labels"] = ord(data_point["output"]) - ord('A')
    
    return tokenized_full_prompt


def generate_and_tokenize_train(args, data_point, tokenizer, prompter):
    def tokenize(prompt, add_eos_token=True):
        result = tokenizer(prompt,
                           truncation=True,
                           max_length=args.llm_max_text_len,
                           padding=False,
                           return_tensors=None)
        if (result["input_ids"][-1] != tokenizer.eos_token_id and add_eos_token):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()
        return result

    full_prompt = prompter.generate_prompt(data_point["system"],
                                           data_point["input"],
                                           data_point["output"])
    tokenized_full_prompt = tokenize(full_prompt, add_eos_token=True)
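    if not args.llm_train_on_inputs:
        # Mask everything except the last two tokens (expected to be the answer
        # letter and EOS) with -100 so the loss ignores the prompt tokens.
        tokenized_full_prompt["labels"][:-2] = [-100] * len(tokenized_full_prompt["labels"][:-2])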
    
    return tokenized_full_prompt


def seq_to_token_ids(args, seq, candidates, label, text_dict, tokenizer, prompter, eval=False):
    def truncate_title(title):
        title_ = tokenizer.tokenize(title)[:args.llm_max_title_len]
        title = tokenizer.convert_tokens_to_string(title_)
        return title

    seq_t = ' \n '.join(['(' + str(idx + 1) + ') ' + truncate_title(text_dict[item]) 
                       for idx, item in enumerate(seq)])
    can_t = ' \n '.join(['(' + chr(ord('A') + idx) + ') ' + truncate_title(text_dict[item])
                       for idx, item in enumerate(candidates)])
    output = chr(ord('A') + candidates.index(label))  # ranking only
    
    data_point = {}
    data_point['system'] = args.llm_system_template if args.llm_system_template is not None else DEFAULT_SYSTEM_PROMPT
    data_point['input'] = args.llm_input_template.format(seq_t, can_t)
    data_point['output'] = output
    
    if eval:
        return generate_and_tokenize_eval(args, data_point, tokenizer, prompter)
    else:
        return generate_and_tokenize_train(args, data_point, tokenizer, prompter)


class LLMDataloader():
    def __init__(self, args, dataset):
        self.args = args
        self.rng = np.random
        self.save_folder = dataset._get_preprocessed_folder_path()
        seq_dataset = dataset.load_dataset()
        self.train = seq_dataset['train']
        self.val = seq_dataset['val']
        self.test = seq_dataset['test']
        self.umap = seq_dataset['umap']
        self.smap = seq_dataset['smap']
        self.text_dict = seq_dataset['meta']
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)
        
        args.num_items = self.item_count
        self.max_len = args.llm_max_history
        
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.llm_base_tokenizer, cache_dir=args.llm_cache_dir)
        self.tokenizer.pad_token = self.tokenizer.unk_token
        self.tokenizer.padding_side = 'left'
        self.tokenizer.truncation_side = 'left'
        self.tokenizer.clean_up_tokenization_spaces = True
        self.prompter = Prompter()
        
        self.llm_retrieved_path = args.llm_retrieved_path
        print('Loading retrieved file from {}'.format(self.llm_retrieved_path))
        # Use a context manager so the file handle is closed after loading
        with open(os.path.join(args.llm_retrieved_path, 'retrieved.pkl'), 'rb') as f:
            retrieved_file = pickle.load(f)
        
        print('******************** Constructing Validation Subset ********************')
        self.val_probs = retrieved_file['val_probs']
        self.val_labels = retrieved_file['val_labels']
        self.val_metrics = retrieved_file['val_metrics']
        self.val_users = [u for u, (p, l) in enumerate(zip(self.val_probs, self.val_labels), start=1) \
                          if l in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices]
        self.val_candidates = [torch.topk(torch.tensor(self.val_probs[u-1]), 
                                self.args.llm_negative_sample_size+1).indices.tolist() for u in self.val_users]

        print('******************** Constructing Test Subset ********************')
        self.test_probs = retrieved_file['test_probs']
        self.test_labels = retrieved_file['test_labels']
        self.test_metrics = retrieved_file['test_metrics']
        self.test_users = [u for u, (p, l) in enumerate(zip(self.test_probs, self.test_labels), start=1) \
                          if l in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices]
        self.test_candidates = [torch.topk(torch.tensor(self.test_probs[u-1]), 
                                self.args.llm_negative_sample_size+1).indices.tolist() for u in self.test_users]
        self.non_test_users = [u for u, (p, l) in enumerate(zip(self.test_probs, self.test_labels), start=1) \
                               if l not in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices]
        self.test_retrieval = {
            'original_size': len(self.test_probs),
            'retrieval_size': len(self.test_candidates),
            'original_metrics': self.test_metrics,
            'retrieval_metrics': absolute_recall_mrr_ndcg_for_ks(
                torch.tensor(self.test_probs)[torch.tensor(self.test_users)-1],
                torch.tensor(self.test_labels)[torch.tensor(self.test_users)-1],
                self.args.metric_ks,
            ),
            'non_retrieval_metrics': absolute_recall_mrr_ndcg_for_ks(
                torch.tensor(self.test_probs)[torch.tensor(self.non_test_users)-1],
                torch.tensor(self.test_labels)[torch.tensor(self.non_test_users)-1],
                self.args.metric_ks,
            ),
        }

    @classmethod
    def code(cls):
        return 'llm'

    def get_pytorch_dataloaders(self):
        train_loader = self._get_train_loader()
        val_loader = self._get_val_loader()
        test_loader = self._get_test_loader()
        return train_loader, val_loader, test_loader

    def _get_train_loader(self):
        dataset = self._get_train_dataset()
        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.lora_micro_batch_size,
                                           shuffle=True, pin_memory=True, num_workers=self.args.num_workers,
                                           worker_init_fn=worker_init_fn)
        return dataloader

    def _get_train_dataset(self):
        dataset = LLMTrainDataset(self.args, self.train, self.max_len, self.rng,
                                  self.text_dict, self.tokenizer, self.prompter)
        return dataset

    def _get_val_loader(self):
        return self._get_eval_loader(mode='val')

    def _get_test_loader(self):
        return self._get_eval_loader(mode='test')

    def _get_eval_loader(self, mode):
        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size
        dataset = self._get_eval_dataset(mode)
        dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                           pin_memory=True, num_workers=self.args.num_workers)
        return dataloader

    def _get_eval_dataset(self, mode):
        if mode == 'val':
            dataset = LLMValidDataset(self.args, self.train, self.val, self.max_len, self.rng, \
                                      self.text_dict, self.tokenizer, self.prompter, self.val_users, \
                                      self.val_candidates)
        elif mode == 'test':
            dataset = LLMTestDataset(self.args, self.train, self.val, self.test, self.max_len, \
                                     self.rng, self.text_dict, self.tokenizer, self.prompter, self.test_users, \
                                     self.test_candidates)
        return dataset


class LLMTrainDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, max_len, rng, text_dict, tokenizer, prompter):
        self.args = args
        self.max_len = max_len
        self.num_items = args.num_items
        self.rng = rng
        self.text_dict = text_dict
        self.tokenizer = tokenizer
        self.prompter = prompter

        self.all_seqs = []
        for u in sorted(u2seq.keys()):
            seq = u2seq[u]
            for i in range(2, len(seq)+1):
                self.all_seqs += [seq[:i]]

    def __len__(self):
        return len(self.all_seqs)

    def __getitem__(self, index):
        tokens = self.all_seqs[index]
        answer = tokens[-1]
        original_seq = tokens[:-1]
        
        seq = original_seq[-self.max_len:]
        cur_idx, candidates = 0, [answer]
        samples = self.rng.randint(1, self.args.num_items+1, size=5*self.args.llm_negative_sample_size)
        while len(candidates) < self.args.llm_negative_sample_size + 1:
            item = samples[cur_idx]
            cur_idx += 1
            if item in original_seq or item == answer: continue
            else: candidates.append(item)
        self.rng.shuffle(candidates)

        return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, \
                                self.tokenizer, self.prompter, eval=False)


class LLMValidDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, u2answer, max_len, rng, text_dict, tokenizer, prompter, val_users, val_candidates):
        self.args = args
        self.u2seq = u2seq
        self.u2answer = u2answer
        self.users = sorted(self.u2seq.keys())
        self.max_len = max_len
        self.rng = rng
        self.text_dict = text_dict
        self.tokenizer = tokenizer
        self.prompter = prompter
        self.val_users = val_users
        self.val_candidates = val_candidates

    def __len__(self):
        return len(self.val_users)

    def __getitem__(self, index):
        user = self.val_users[index]
        seq = self.u2seq[user]
        answer = self.u2answer[user][0]
        
        seq = seq[-self.max_len:]
        candidates = self.val_candidates[index]
        assert answer in candidates
        # self.rng.shuffle(candidates)
        
        return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, self.tokenizer, self.prompter, eval=True)


class LLMTestDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, u2val, u2answer, max_len, rng, text_dict, tokenizer, prompter, test_users, test_candidates):
        self.args = args
        self.u2seq = u2seq
        self.u2val = u2val
        self.u2answer = u2answer
        self.users = sorted(u2seq.keys())
        self.max_len = max_len
        self.rng = rng
        self.text_dict = text_dict
        self.tokenizer = tokenizer
        self.prompter = prompter
        self.test_users = test_users
        self.test_candidates = test_candidates
    
    def __len__(self):
        return len(self.test_users)
    
    def __getitem__(self, index):
        user = self.test_users[index]
        seq = self.u2seq[user] + self.u2val[user]
        answer = self.u2answer[user][0]

        seq = seq[-self.max_len:]
        candidates = self.test_candidates[index]
        assert answer in candidates
        # self.rng.shuffle(candidates)

        return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, self.tokenizer, self.prompter, eval=True)

================================================
FILE: dataloader/lru.py
================================================
from .base import AbstractDataloader

import os
import torch
import random
import pickle
import numpy as np
import torch.utils.data as data_utils


def worker_init_fn(worker_id):
    random.seed(np.random.get_state()[1][0] + worker_id)
    np.random.seed(np.random.get_state()[1][0] + worker_id)


class LRUDataloader():
    def __init__(self, args, dataset):
        self.args = args
        self.rng = np.random
        self.save_folder = dataset._get_preprocessed_folder_path()
        dataset = dataset.load_dataset()
        self.train = dataset['train']
        self.val = dataset['val']
        self.test = dataset['test']
        self.umap = dataset['umap']
        self.smap = dataset['smap']
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)

        args.num_users = self.user_count
        args.num_items = self.item_count
        self.max_len = args.bert_max_len
        self.sliding_size = args.sliding_window_size

    @classmethod
    def code(cls):
        return 'lru'

    def get_pytorch_dataloaders(self):
        train_loader = self._get_train_loader()
        val_loader = self._get_val_loader()
        test_loader = self._get_test_loader()
        return train_loader, val_loader, test_loader
    
    def get_pytorch_test_subset_dataloader(self):
        retrieved_file_path = self.args.llm_retrieved_path
        print('Loading retrieved file from {}'.format(retrieved_file_path))
        # Use a context manager so the file handle is closed after loading
        with open(os.path.join(retrieved_file_path, 'retrieved.pkl'), 'rb') as f:
            retrieved_file = pickle.load(f)
        
        test_probs = retrieved_file['test_probs']
        test_labels = retrieved_file['test_labels']
        test_users = [u for u, (p, l) in enumerate(zip(test_probs, test_labels), start=1) \
                      if l in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices]

        dataset = LRUTestDataset(self.args, self.train, self.val, self.test, self.max_len,
                                 self.rng, subset_users=test_users)
        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.val_batch_size, shuffle=False,
                                           pin_memory=True, num_workers=self.args.num_workers)
        return dataloader

    def _get_train_loader(self):
        dataset = self._get_train_dataset()
        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,
                        shuffle=True, pin_memory=True, num_workers=self.args.num_workers,
                        worker_init_fn=worker_init_fn)
        return dataloader

    def _get_train_dataset(self):
        dataset = LRUTrainDataset(
            self.args, self.train, self.max_len, self.sliding_size, self.rng)
        return dataset

    def _get_val_loader(self):
        return self._get_eval_loader(mode='val')

    def _get_test_loader(self):
        return self._get_eval_loader(mode='test')

    def _get_eval_loader(self, mode):
        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size
        dataset = self._get_eval_dataset(mode)
        dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        pin_memory=True, num_workers=self.args.num_workers)
        return dataloader

    def _get_eval_dataset(self, mode):
        if mode == 'val':
            dataset = LRUValidDataset(self.args, self.train, self.val, self.max_len, self.rng)
        elif mode == 'test':
            dataset = LRUTestDataset(self.args, self.train, self.val, self.test, self.max_len, self.rng)
        return dataset


class LRUTrainDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, max_len, sliding_size, rng):
        self.args = args
        self.max_len = max_len
        self.sliding_step = int(sliding_size * max_len)
        self.num_items = args.num_items
        self.rng = rng
        
        assert self.sliding_step > 0
        self.all_seqs = []
        for u in sorted(u2seq.keys()):
            seq = u2seq[u]
            if len(seq) < self.max_len + self.sliding_step:
                self.all_seqs.append(seq)
            else:
                start_idx = range(len(seq) - max_len, -1, -self.sliding_step)
                self.all_seqs = self.all_seqs + [seq[i:i + max_len] for i in start_idx]

    def __len__(self):
        return len(self.all_seqs)

    def __getitem__(self, index):
        seq = self.all_seqs[index]
        labels = seq[-self.max_len:]
        tokens = seq[:-1][-self.max_len:]

        mask_len = self.max_len - len(tokens)
        tokens = [0] * mask_len + tokens

        mask_len = self.max_len - len(labels)
        labels = [0] * mask_len + labels

        return torch.LongTensor(tokens), torch.LongTensor(labels)


class LRUValidDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, u2answer, max_len, rng):
        self.args = args
        self.u2seq = u2seq
        self.u2answer = u2answer
        users = sorted(self.u2seq.keys())
        self.users = [u for u in users if len(u2answer[u]) > 0]
        self.max_len = max_len
        self.rng = rng
    
    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        seq = self.u2seq[user]
        answer = self.u2answer[user]

        seq = seq[-self.max_len:]
        padding_len = self.max_len - len(seq)
        seq = [0] * padding_len + seq

        return torch.LongTensor(seq), torch.LongTensor(answer)


class LRUTestDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, u2val, u2answer, max_len, rng, subset_users=None):
        self.args = args
        self.u2seq = u2seq
        self.u2val = u2val
        self.u2answer = u2answer
        users = sorted(self.u2seq.keys())
        self.users = [u for u in users if len(u2val[u]) > 0 and len(u2answer[u]) > 0]
        self.max_len = max_len
        self.rng = rng
        
        if subset_users is not None:
            self.users = subset_users

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        seq = self.u2seq[user] + self.u2val[user]
        answer = self.u2answer[user]

        seq = seq[-self.max_len:]
        padding_len = self.max_len - len(seq)
        seq = [0] * padding_len + seq

        return torch.LongTensor(seq), torch.LongTensor(answer)

================================================
FILE: dataloader/templates/README.md
================================================
# Prompt templates

This directory contains template styles for the prompts used to finetune LoRA models.

## Format

A template is described via a JSON file with the following keys:

- `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
- `prompt_no_input`: The template to use when input is None. Uses the `{instruction}` placeholder.
- `description`: A short description of the template, with possible use cases.
- `response_split`: The text to use as separator when cutting real response from the model output.

No `{response}` placeholder was used, since the response is always the last element of the template and is just to be concatenated to the rest.
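
As a minimal sketch of how these keys fit together, the snippet below inlines the contents of `alpaca_short.json` and fills it in. The helpers `build_prompt` and `extract_response` are hypothetical names chosen for illustration; they are not the repository's `Prompter` class, although they follow the same idea.

```python
# Inlined copy of alpaca_short.json so the sketch is self-contained.
TEMPLATE = {
    "description": "A shorter template to experiment with.",
    "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:",
}


def build_prompt(template, instruction, input_text=None, label=None):
    """Hypothetical helper: fill the template, then append the response if given."""
    if input_text:
        prompt = template["prompt_input"].format(instruction=instruction, input=input_text)
    else:
        prompt = template["prompt_no_input"].format(instruction=instruction)
    # There is no {response} placeholder: the response is simply concatenated.
    return prompt + label if label else prompt


def extract_response(template, output):
    """Hypothetical helper: keep the text after the first response separator."""
    _, _, response = output.partition(template["response_split"])
    return response.strip()


full_text = build_prompt(TEMPLATE, "Rank the candidates.", "A, B, C", label="B")
print(full_text)
print(extract_response(TEMPLATE, full_text))
```

Using `str.partition` here is a design choice of the sketch: it returns an empty response instead of raising when the separator is missing from the output.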

## Example template

The upstream default template is `alpaca.json`, shown below. Note that `Prompter` in `dataloader/utils.py` falls back to `alpaca_short` when no template name is given.

```json
{
    "description": "Template used by Alpaca-LoRA.",
    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"    
}

```

## Current templates

### alpaca

Default template used for generic LoRA fine tunes so far.

### alpaca_legacy

Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments.
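
To make the difference concrete, the sketch below compares the end of the two prompts once a response is appended. The variable names are illustrative, and the `split(...)[1].strip()` step is one assumed way of applying `response_split`, not necessarily how every consumer of these templates does it.

```python
# Tails of `prompt_no_input` in alpaca.json and alpaca_legacy.json.
DEFAULT_TAIL = "### Response:\n"
LEGACY_TAIL = "### Response:"
# `response_split` is the same string in both template files.
RESPONSE_SPLIT = "### Response:"


def extract_response(output: str) -> str:
    """Illustrative helper: take the text after the separator and trim whitespace."""
    return output.split(RESPONSE_SPLIT)[1].strip()


label = "B"
default_text = DEFAULT_TAIL + label  # response starts on a new line
legacy_text = LEGACY_TAIL + label    # response follows the colon directly

# The raw training strings differ by exactly one newline character...
print(repr(default_text), repr(legacy_text))
# ...but the extracted response is identical once whitespace is stripped.
print(extract_response(default_text) == extract_response(legacy_text))
```

The extracted text is the same either way, but the model is trained on slightly different token sequences, which is presumably why the legacy variant is kept for experiments.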

### alpaca_short

A trimmed-down alpaca template which seems to perform just as well while saving some tokens. Models created with the default template seem to be queryable with the short template as well. More experiments are welcome.
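
As a rough illustration of what is saved, the sketch below measures the preamble that the short template drops from `prompt_no_input`. Character and word counts are only a proxy here; the actual number of tokens saved depends on the tokenizer, which this sketch does not assume.

```python
# `prompt_no_input` from alpaca.json and alpaca_short.json, inlined.
ALPACA = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)
ALPACA_SHORT = "### Instruction:\n{instruction}\n\n### Response:\n"

# The short template is the default one minus the leading preamble sentence.
assert ALPACA.endswith(ALPACA_SHORT)
preamble = ALPACA[: len(ALPACA) - len(ALPACA_SHORT)]

saved_chars = len(preamble)
saved_words = len(preamble.split())
print(f"Preamble dropped by alpaca_short: {saved_chars} characters, {saved_words} words")
```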

### vigogne

The default alpaca template, translated to French. This template was used to train the "Vigogne" LoRA and is to be used to query it, or for extra fine tuning.


================================================
FILE: dataloader/templates/alpaca.json
================================================
{
    "description": "Template used by Alpaca-LoRA.",
    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"    
}


================================================
FILE: dataloader/templates/alpaca_legacy.json
================================================
{
    "description": "Legacy template, used by Original Alpaca repository.",
    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
    "response_split": "### Response:"    
}


================================================
FILE: dataloader/templates/alpaca_short.json
================================================
{
    "description": "A shorter template to experiment with.",
    "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"    
}


================================================
FILE: dataloader/templates/vigogne.json
================================================
{
    "description": "French template, used by Vigogne for finetuning.",
    "prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
    "prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
    "response_split": "### Réponse:"
}


================================================
FILE: dataloader/utils.py
================================================
import json
import os.path as osp
from typing import Union


class Prompter(object):
    __slots__ = ("template", "_verbose")

    def __init__(self, template_name: str = "", verbose: bool = False):
        self._verbose = verbose
        if not template_name:
            # Fall back to the short Alpaca template (the upstream default is "alpaca").
            template_name = "alpaca_short"
        file_name = osp.join("dataloader", "templates", f"{template_name}.json")
        if not osp.exists(file_name):
            raise ValueError(f"Can't read {file_name}")
        with open(file_name) as fp:
            self.template = json.load(fp)
        if self._verbose:
            print(
                f"Using prompt template {template_name}: {self.template['description']}"
            )

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        if input:
            res = self.template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            res = self.template["prompt_no_input"].format(
                instruction=instruction
            )
        if label:
            res = f"{res}{label}"
        if self._verbose:
            print(res)
        return res

    def get_response(self, output: str) -> str:
        return output.split(self.template["response_split"])[1].strip()

================================================
FILE: datasets/__init__.py
================================================
from .ml_100k import ML100KDataset
from .beauty import BeautyDataset
from .games import GamesDataset

DATASETS = {
    ML100KDataset.code(): ML100KDataset,
    BeautyDataset.code(): BeautyDataset,
    GamesDataset.code(): GamesDataset,
}


def dataset_factory(args):
    dataset = DATASETS[args.dataset_code]
    return dataset(args)


================================================
FILE: datasets/base.py
================================================
import pickle
import shutil
import tempfile
import os
from pathlib import Path
import gzip
from abc import *
from .utils import *
from config import RAW_DATASET_ROOT_FOLDER

import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()


class AbstractDataset(metaclass=ABCMeta):
    def __init__(self, args):
        self.args = args
        self.min_rating = args.min_rating
        self.min_uc = args.min_uc
        self.min_sc = args.min_sc

        assert self.min_uc >= 2, 'Need at least 2 ratings per user for validation and test'

    @classmethod
    @abstractmethod
    def code(cls):
        pass

    @classmethod
    def raw_code(cls):
        return cls.code()

    @classmethod
    def zip_file_content_is_folder(cls):
        return True

    @classmethod
    def all_raw_file_names(cls):
        return []

    @classmethod
    @abstractmethod
    def url(cls):
        pass

    @abstractmethod
    def preprocess(self):
        pass

    @abstractmethod
    def load_ratings_df(self):
        pass

    @abstractmethod
    def maybe_download_raw_dataset(self):
        pass

    def load_dataset(self):
        self.preprocess()
        dataset_path = self._get_preprocessed_dataset_path()
        # Use a context manager so the file handle is closed after loading.
        with dataset_path.open('rb') as f:
            dataset = pickle.load(f)
        return dataset

    def filter_triplets(self, df):
        print('Filtering triplets')
        if self.min_sc > 1 or self.min_uc > 1:
            item_sizes = df.groupby('sid').size()
            good_items = item_sizes.index[item_sizes >= self.min_sc]
            user_sizes = df.groupby('uid').size()
            good_users = user_sizes.index[user_sizes >= self.min_uc]
            while len(good_items) < len(item_sizes) or len(good_users) < len(user_sizes):
                if self.min_sc > 1:
                    item_sizes = df.groupby('sid').size()
                    good_items = item_sizes.index[item_sizes >= self.min_sc]
                    df = df[df['sid'].isin(good_items)]

                if self.min_uc > 1:
                    user_sizes = df.groupby('uid').size()
                    good_users = user_sizes.index[user_sizes >= self.min_uc]
                    df = df[df['uid'].isin(good_users)]

                item_sizes = df.groupby('sid').size()
                good_items = item_sizes.index[item_sizes >= self.min_sc]
                user_sizes = df.groupby('uid').size()
                good_users = user_sizes.index[user_sizes >= self.min_uc]
        return df
    
    def densify_index(self, df):
        print('Densifying index')
        umap = {u: i for i, u in enumerate(set(df['uid']), start=1)}
        smap = {s: i for i, s in enumerate(set(df['sid']), start=1)}
        df['uid'] = df['uid'].map(umap)
        df['sid'] = df['sid'].map(smap)
        return df, umap, smap

    def split_df(self, df, user_count):
        print('Splitting')
        user_group = df.groupby('uid')
        user2items = user_group.progress_apply(
            lambda d: list(d.sort_values(by=['timestamp', 'sid'])['sid']))
        train, val, test = {}, {}, {}
        for i in range(user_count):
            user = i + 1
            items = user2items[user]
            train[user], val[user], test[user] = items[:-2], items[-2:-1], items[-1:]
        return train, val, test

    def _get_rawdata_root_path(self):
        return Path(RAW_DATASET_ROOT_FOLDER)

    def _get_rawdata_folder_path(self):
        root = self._get_rawdata_root_path()
        return root.joinpath(self.raw_code())

    def _get_preprocessed_root_path(self):
        root = self._get_rawdata_root_path()
        return root.joinpath('preprocessed')

    def _get_preprocessed_folder_path(self):
        preprocessed_root = self._get_preprocessed_root_path()
        folder_name = '{}_min_rating{}-min_uc{}-min_sc{}' \
            .format(self.code(), self.min_rating, self.min_uc, self.min_sc)
        return preprocessed_root.joinpath(folder_name)

    def _get_preprocessed_dataset_path(self):
        folder = self._get_preprocessed_folder_path()
        return folder.joinpath('dataset.pkl')


================================================
FILE: datasets/beauty.py
================================================
from .base import AbstractDataset
from .utils import *

from datetime import date
from pathlib import Path
import pickle
import shutil
import tempfile
import os

import gzip
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()


class BeautyDataset(AbstractDataset):
    @classmethod
    def code(cls):
        return 'beauty'

    @classmethod
    def url(cls):
        return ['http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Beauty.csv',
                'http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Beauty.json.gz']

    @classmethod
    def zip_file_content_is_folder(cls):
        return True

    @classmethod
    def all_raw_file_names(cls):
        return ['beauty.csv', 'beauty_meta.json.gz']

    def maybe_download_raw_dataset(self):
        folder_path = self._get_rawdata_folder_path()
        if folder_path.is_dir() and\
           all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):
            print('Raw data already exists. Skip downloading')
            return
        
        print("Raw file doesn't exist. Downloading...")
        for idx, url in enumerate(self.url()):
            tmproot = Path(tempfile.mkdtemp())
            tmpfile = tmproot.joinpath('file')
            download(url, tmpfile)
            os.makedirs(folder_path, exist_ok=True)
            shutil.move(tmpfile, folder_path.joinpath(self.all_raw_file_names()[idx]))
            shutil.rmtree(tmproot)  # remove the now-empty temporary directory
            print()

    def preprocess(self):
        dataset_path = self._get_preprocessed_dataset_path()
        if dataset_path.is_file():
            print('Already preprocessed. Skip preprocessing')
            return
        if not dataset_path.parent.is_dir():
            dataset_path.parent.mkdir(parents=True)
        self.maybe_download_raw_dataset()
        df = self.load_ratings_df()
        meta_raw = self.load_meta_dict()
        df = df[df['sid'].isin(meta_raw)]  # filter items without meta info
        df = self.filter_triplets(df)
        df, umap, smap = self.densify_index(df)
        train, val, test = self.split_df(df, len(umap))
        meta = {smap[k]: v for k, v in meta_raw.items() if k in smap}
        dataset = {'train': train,
                   'val': val,
                   'test': test,
                   'meta': meta,
                   'umap': umap,
                   'smap': smap}
        with dataset_path.open('wb') as f:
            pickle.dump(dataset, f)

    def load_ratings_df(self):
        folder_path = self._get_rawdata_folder_path()
        file_path = folder_path.joinpath(self.all_raw_file_names()[0])
        df = pd.read_csv(file_path, header=None)
        df.columns = ['uid', 'sid', 'rating', 'timestamp']
        return df
    
    def load_meta_dict(self):
        folder_path = self._get_rawdata_folder_path()
        file_path = folder_path.joinpath(self.all_raw_file_names()[1])

        meta_dict = {}
        with gzip.open(file_path, 'rb') as f:
            for line in f:
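                # NOTE: each line is a Python dict literal; eval() executes arbitrary
                # code, so only run this on metadata files from a trusted source.
                item = eval(line)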
                if 'title' in item and len(item['title']) > 0:
                    meta_dict[item['asin'].strip()] = item['title'].strip()
        
        return meta_dict


================================================
FILE: datasets/games.py
================================================
from .base import AbstractDataset
from .utils import *

from datetime import date
from pathlib import Path
import pickle
import shutil
import tempfile
import os

import gzip
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()


class GamesDataset(AbstractDataset):
    @classmethod
    def code(cls):
        return 'games'

    @classmethod
    def url(cls):
        # meta_Video_Games.json.gz from snap.stanford.edu does not contain full meta info
        return ['http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Video_Games.csv',
                'https://datarepo.eng.ucsd.edu/mcauley_group/data/amazon_v2/metaFiles2/meta_Video_Games.json.gz']

    @classmethod
    def zip_file_content_is_folder(cls):
        return True

    @classmethod
    def all_raw_file_names(cls):
        return ['games.csv', 'games_meta.json.gz']

    def maybe_download_raw_dataset(self):
        folder_path = self._get_rawdata_folder_path()
        if folder_path.is_dir() and\
           all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):
            print('Raw data already exists. Skip downloading')
            return
        
        print("Raw file doesn't exist. Downloading...")
        for idx, url in enumerate(self.url()):
            tmproot = Path(tempfile.mkdtemp())
            tmpfile = tmproot.joinpath('file')
            download(url, tmpfile)
            os.makedirs(folder_path, exist_ok=True)
            shutil.move(tmpfile, folder_path.joinpath(self.all_raw_file_names()[idx]))
            shutil.rmtree(tmproot)  # remove the now-empty temporary directory
            print()

    def preprocess(self):
        dataset_path = self._get_preprocessed_dataset_path()
        if dataset_path.is_file():
            print('Already preprocessed. Skip preprocessing')
            return
        if not dataset_path.parent.is_dir():
            dataset_path.parent.mkdir(parents=True)
        self.maybe_download_raw_dataset()
        df = self.load_ratings_df()
        meta_raw = self.load_meta_dict()
        df = df[df['sid'].isin(meta_raw)]  # filter items without meta info
        df = self.filter_triplets(df)
        df, umap, smap = self.densify_index(df)
        train, val, test = self.split_df(df, len(umap))
        meta = {smap[k]: v for k, v in meta_raw.items() if k in smap}
        dataset = {'train': train,
                   'val': val,
                   'test': test,
                   'meta': meta,
                   'umap': umap,
                   'smap': smap}
        with dataset_path.open('wb') as f:
            pickle.dump(dataset, f)

    def load_ratings_df(self):
        folder_path = self._get_rawdata_folder_path()
        file_path = folder_path.joinpath(self.all_raw_file_names()[0])
        df = pd.read_csv(file_path, header=None)
        df.columns = ['uid', 'sid', 'rating', 'timestamp']
        return df
    
    def load_meta_dict(self):
        folder_path = self._get_rawdata_folder_path()
        file_path = folder_path.joinpath(self.all_raw_file_names()[1])

        meta_dict = {}
        with gzip.open(file_path, 'rb') as f:
            for line in f:
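                # NOTE: eval() executes arbitrary code, so only run this on
                # metadata files from a trusted source.
                item = eval(line)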
                if 'title' in item and len(item['title']) > 0:
                    meta_dict[item['asin'].strip()] = item['title'].strip()
        
        return meta_dict


================================================
FILE: datasets/ml_100k.py
================================================
from .base import AbstractDataset
from .utils import *

from datetime import date
from pathlib import Path
import pickle
import shutil
import tempfile
import os

import re
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()


class ML100KDataset(AbstractDataset):
    @classmethod
    def code(cls):
        return 'ml-100k'

    @classmethod
    def url(cls):  # as of Sep 2023
        return 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'

    @classmethod
    def zip_file_content_is_folder(cls):
        return True

    @classmethod
    def all_raw_file_names(cls):
        return ['README',
                'movies.csv',
                'ratings.csv',
                'users.csv']

    def maybe_download_raw_dataset(self):
        folder_path = self._get_rawdata_folder_path()
        if folder_path.is_dir() and\
           all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):
            print('Raw data already exists. Skip downloading')
            return

        print("Raw file doesn't exist. Downloading...")
        tmproot = Path(tempfile.mkdtemp())
        tmpzip = tmproot.joinpath('file.zip')
        tmpfolder = tmproot.joinpath('folder')
        download(self.url(), tmpzip)
        unzip(tmpzip, tmpfolder)
        if self.zip_file_content_is_folder():
            tmpfolder = tmpfolder.joinpath(os.listdir(tmpfolder)[0])
        shutil.move(tmpfolder, folder_path)
        shutil.rmtree(tmproot)
        print()

    def preprocess(self):
        dataset_path = self._get_preprocessed_dataset_path()
        if dataset_path.is_file():
            print('Already preprocessed. Skip preprocessing')
            return
        if not dataset_path.parent.is_dir():
            dataset_path.parent.mkdir(parents=True)
        self.maybe_download_raw_dataset()
        df = self.load_ratings_df()
        meta_raw = self.load_meta_dict()
        df = df[df['sid'].isin(meta_raw)]  # filter items without meta info
        df = self.filter_triplets(df)
        df, umap, smap = self.densify_index(df)
        train, val, test = self.split_df(df, len(umap))
        meta = {smap[k]: v for k, v in meta_raw.items() if k in smap}
        dataset = {'train': train,
                   'val': val,
                   'test': test,
                   'meta': meta,
                   'umap': umap,
                   'smap': smap}
        with dataset_path.open('wb') as f:
            pickle.dump(dataset, f)

    def load_ratings_df(self):
        folder_path = self._get_rawdata_folder_path()
        file_path = folder_path.joinpath('ratings.csv')
        df = pd.read_csv(file_path)
        df.columns = ['uid', 'sid', 'rating', 'timestamp']
        return df

    def load_meta_dict(self):
        folder_path = self._get_rawdata_folder_path()
        file_path = folder_path.joinpath('movies.csv')
        df = pd.read_csv(file_path, encoding="ISO-8859-1")
        meta_dict = {}
        for row in df.itertuples():
            title = row[2][:-7]  # remove year (optional)
            year = row[2][-7:]

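            # Raw string avoids invalid-escape warnings for the regex pattern.
            title = re.sub(r'\(.*?\)', '', title).strip()
            # Move a trailing article to the front, e.g. "Matrix, The" -> "The Matrix";
            # other articles and remaining parentheses are not handled here.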
            if any(', '+x in title.lower()[-5:] for x in ['a', 'an', 'the']):
                title_pre = title.split(', ')[:-1]
                title_post = title.split(', ')[-1]
                title_pre = ', '.join(title_pre)
                title = title_post + ' ' + title_pre

            meta_dict[row[1]] = title + year
        return meta_dict


================================================
FILE: datasets/utils.py
================================================
import numpy as np
import pandas as pd
from tqdm import tqdm
import urllib.request


from pathlib import Path
import zipfile
import tarfile
import sys


def download(url, savepath):
    urllib.request.urlretrieve(url, str(savepath))
    print()


def unzip(zippath, savepath):
    print("Extracting data...")
    # Context manager closes the archive and avoids shadowing the built-in `zip`.
    with zipfile.ZipFile(zippath) as zf:
        zf.extractall(savepath)


def unziptargz(zippath, savepath):
    print("Extracting data...")
    # Context manager closes the archive even if extraction fails.
    with tarfile.open(zippath) as tf:
        tf.extractall(savepath)


================================================
FILE: model/__init__.py
================================================
from .lru import *
from .llm import *

================================================
FILE: model/llm.py
================================================
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.models.llama.configuration_llama import LlamaConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for uni-directional (causal) self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.pretraining_tp = config.pretraining_tp
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.pretraining_tp > 1:
            slice = self.intermediate_size // self.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
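

# Illustrative sketch, not part of the upstream file: after `repeat_kv`, query head `h` attends
# over a copy of key/value head `h // n_rep`, because each key/value head is repeated `n_rep`
# times consecutively along the head dimension. `_kv_head_for_query_head_sketch` is a
# hypothetical helper name used only for this example.
def _kv_head_for_query_head_sketch(query_head, num_heads, num_key_value_heads):
    """Index of the key/value head shared by `query_head` under grouped-query attention."""
    if num_heads % num_key_value_heads != 0:
        raise ValueError("num_heads must be a multiple of num_key_value_heads")
    n_rep = num_heads // num_key_value_heads
    return query_head // n_rep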


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.pretraining_tp = config.pretraining_tp
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
            query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # prepend the cached keys and values along the sequence dimension
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value


LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple holding the
            cached key and value tensors of one layer, each of shape
            `(batch_size, num_key_value_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed key and value states from the self-attention blocks that can be used (see
            `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # default to an all-ones mask over past and current positions, then build the causal mask
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # `inputs` already carries None for past_key_value; the trailing None is use_cache
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class LlamaForCausalLM(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the causal language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
                ignored (masked); the loss is only computed for the tokens with labels in `[0, ...,
                config.vocab_size - 1]`. The loss is only computed in training mode; in evaluation mode a placeholder
                loss of `-1.` is returned when labels are given.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if self.training and labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        elif labels is not None:
            loss = torch.tensor(-1.)  # placeholder: the loss is not computed outside training mode

        logits = logits[:, -1]  # only the last position's logits are needed downstream, so drop the rest

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
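

# Illustrative sketch, not part of the upstream file: which (position, target) pairs enter the
# training loss in `LlamaForCausalLM.forward`. The logits at position i are scored against
# labels[i + 1], and targets equal to -100 are skipped, matching the default `ignore_index` of
# `CrossEntropyLoss`. `_shifted_targets_sketch` is a hypothetical helper working on a plain list
# of label ids for one sequence.
def _shifted_targets_sketch(labels, ignore_index=-100):
    """Return (logit_position, target_token) pairs that contribute to the loss."""
    return [
        (position, target)
        for position, target in enumerate(labels[1:])
        if target != ignore_index
    ]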

================================================
FILE: model/lru.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np


class LRURec(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embedding = LRUEmbedding(self.args)
        self.model = LRUModel(self.args)
        self.truncated_normal_init()

    def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):
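        with torch.no_grad():
            # CDF values of the truncation bounds under N(mean, std); sampling uniformly between them and
            # applying the inverse error function below draws from the truncated normal (inverse-CDF method)
            l = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2.
            u = (1. + math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2.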

            for n, p in self.named_parameters():
                if 'layer_norm' not in n and 'params_log' not in n:
                    if torch.is_complex(p):
                        p.real.uniform_(2 * l - 1, 2 * u - 1)
                        p.imag.uniform_(2 * l - 1, 2 * u - 1)
                        p.real.erfinv_()
                        p.imag.erfinv_()
                        p.real.mul_(std * math.sqrt(2.))
                        p.imag.mul_(std * math.sqrt(2.))
                        p.real.add_(mean)
                        p.imag.add_(mean)
                    else:
                        p.uniform_(2 * l - 1, 2 * u - 1)
                        p.erfinv_()
                        p.mul_(std * math.sqrt(2.))
                        p.add_(mean)

    def forward(self, x):
        x, mask = self.embedding(x)
        scores = self.model(x, self.embedding.token.weight, mask)
        return scores
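

# Illustrative sketch, not part of the original model. truncated_normal_init above
# appears to rely on inverse-CDF sampling: draw a uniform value between the CDF of
# the lower and upper bounds, then map it back through the inverse normal CDF.
# The hypothetical helper below (its name and signature are assumptions made for
# illustration) shows the same idea for a single scalar using only the standard library.
def _truncated_normal_sample(mean=0.0, std=0.02, lower=-0.04, upper=0.04, rng=None):
    import random
    from statistics import NormalDist

    rng = rng or random
    dist = NormalDist(mean, std)
    # Probability mass to the left of each bound.
    p_low, p_high = dist.cdf(lower), dist.cdf(upper)
    # A uniform draw in [p_low, p_high] maps to a normal draw in [lower, upper].
    value = dist.inv_cdf(rng.uniform(p_low, p_high))
    # Clamp to guard against floating-point round-off at the bounds.
    return min(max(value, lower), upper)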


class LRUEmbedding(nn.Module):
    def __init__(self, args):
        super().__init__()
        vocab_size = args.num_items + 1
        embed_size = args.bert_hidden_units
        
        self.token = nn.Embedding(vocab_size, embed_size)
        self.layer_norm = nn.LayerNorm(embed_size)
        self.embed_dropout = nn.Dropout(args.bert_dropout)

    def get_mask(self, x):
        return (x > 0)

    def forward(self, x):
        mask = self.get_mask(x)
        x = self.token(x)
        return self.layer_norm(self.embed_dropout(x)), mask


class LRUModel(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.hidden_size = args.bert_hidden_units
        layers = args.bert_num_blocks

        self.lru_blocks = nn.ModuleList([LRUBlock(self.args) for _ in range(layers)])
        self.bias = torch.nn.Parameter(torch.zeros(args.num_items + 1))

    def forward(self, x, embedding_weight, mask):
        # left-pad the sequence so its length is a power of 2 (needed by the parallel scan)
        seq_len = x.size(1)
        log2_L = int(np.ceil(np.log2(seq_len)))
        x = F.pad(x, (0, 0, 2 ** log2_L - x.size(1), 0, 0, 0))
        mask_ = F.pad(mask, (2 ** log2_L - mask.size(1), 0, 0, 0))

        # LRU blocks with pffn
        for lru_block in self.lru_blocks:
            x = lru_block.forward(x, mask_)
        x = x[:, -seq_len:]  # B x L x D (64)

        scores = torch.matmul(x, embedding_weight.permute(1, 0)) + self.bias
        return scores


class LRUBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        hidden_size = args.bert_hidden_units
        self.lru_layer = LRULayer(
            d_model=hidden_size, dropout=args.bert_attn_dropout)
        self.feed_forward = PositionwiseFeedForward(
            d_model=hidden_size, d_ff=hidden_size*4, dropout=args.bert_dropout)
    
    def forward(self, x, mask):
        x = self.lru_layer(x, mask)
        x = self.feed_forward(x)
        return x
    

class LRULayer(nn.Module):
    def __init__(self,
                 d_model,
                 dropout=0.1,
                 use_bias=True,
                 r_min=0.8,
                 r_max=0.99):
        super().__init__()
        self.embed_size = d_model
        self.hidden_size = 2 * d_model
        self.use_bias = use_bias

        # init nu, theta, gamma
        u1 = torch.rand(self.hidden_size)
        u2 = torch.rand(self.hidden_size)
        nu_log = torch.log(-0.5 * torch.log(u1 * (r_max ** 2 - r_min ** 2) + r_min ** 2))
        theta_log = torch.log(u2 * torch.tensor(np.pi) * 2)
        diag_lambda = torch.exp(torch.complex(-torch.exp(nu_log), torch.exp(theta_log)))
        gamma_log = torch.log(torch.sqrt(1 - torch.abs(diag_lambda) ** 2))
        self.params_log = nn.Parameter(torch.vstack((nu_log, theta_log, gamma_log)))

        # Init B, C, D
        self.in_proj = nn.Linear(self.embed_size, self.hidden_size, bias=use_bias).to(torch.cfloat)
        self.out_proj = nn.Linear(self.hidden_size, self.embed_size, bias=use_bias).to(torch.cfloat)
        # self.out_vector = nn.Parameter(torch.rand(self.embed_size))
        self.out_vector = nn.Identity()
        
        # Dropout and layer norm
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(self.embed_size)

    def lru_parallel(self, i, h, lamb, mask, B, L, D):
        # Parallel algorithm, see: https://kexue.fm/archives/9554#%E5%B9%B6%E8%A1%8C%E5%8C%96
        # The original implementation is slightly slower and does not consider 0 padding
        l = 2 ** i
        h = h.reshape(B * L // l, l, D)  # (B, L, D) -> (B * L // l, l, D)
        mask_ = mask.reshape(B * L // l, l)  # (B, L) -> (B * L // l, l)
        h1, h2 = h[:, :l // 2], h[:, l // 2:]  # Divide data in half

        if i > 1: lamb = torch.cat((lamb, lamb * lamb[-1]), 0)
        h2 = h2 + lamb * h1[:, -1:] * mask_[:, l // 2 - 1:l // 2].unsqueeze(-1)
        h = torch.cat([h1, h2], dim=1)
        return h, lamb

    def forward(self, x, mask):
        # compute bu and lambda
        nu, theta, gamma = torch.exp(self.params_log).split((1, 1, 1))
        lamb = torch.exp(torch.complex(-nu, theta))
        h = self.in_proj(x.to(torch.cfloat)) * gamma  # bu
        
        # compute h in parallel
        log2_L = int(np.ceil(np.log2(h.size(1))))
        B, L, D = h.size(0), h.size(1), h.size(2)
        for i in range(log2_L):
            h, lamb = self.lru_parallel(i + 1, h, lamb, mask, B, L, D)
        # residual connection: add the (identity-mapped) input back before layer norm
        x = self.dropout(self.out_proj(h).real) + self.out_vector(x)
        return self.layer_norm(x)
    

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        x_ = self.dropout(self.activation(self.w_1(x)))
        return self.layer_norm(self.dropout(self.w_2(x_)) + x)
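

# Illustrative sketch, not part of the original model. LRULayer.lru_parallel above
# evaluates the linear recurrence h_t = lamb * h_{t-1} + x_t block by block, doubling
# the block size at every step. The hypothetical helpers below (names are assumptions)
# show the same idea on a plain Python list with a scalar lamb. They ignore the
# padding mask and assume the sequence length is a power of two.
def _sequential_recurrence(xs, lamb):
    """Reference loop: h_t = lamb * h_{t-1} + x_t, starting from a zero state."""
    states, h = [], 0j
    for x in xs:
        h = lamb * h + x
        states.append(h)
    return states


def _doubling_recurrence(xs, lamb):
    """Same states, computed by merging neighbouring blocks of doubling size."""
    states = list(xs)
    n = len(states)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    size = 2
    while size <= n:
        half = size // 2
        for start in range(0, n, size):
            # Final state of the left half, carried into the right half.
            carry = states[start + half - 1]
            for j in range(half):
                # Position j of the right half is j + 1 steps after the carry.
                states[start + half + j] += lamb ** (j + 1) * carry
        size *= 2
    return states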

================================================
FILE: requirements.txt
================================================
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=5.1=1_gnu
absl-py=1.4.0=pypi_0
accelerate=0.21.0=pypi_0
aiofiles=23.1.0=pypi_0
aiohttp=3.8.4=pypi_0
aiosignal=1.3.1=pypi_0
alabaster=0.7.13=pypi_0
altair=4.2.2=pypi_0
aniso8601=9.0.1=pypi_0
antlr4-python3-runtime=4.9.3=pypi_0
anyio=3.6.2=pypi_0
appdirs=1.4.4=pypi_0
asttokens=2.2.1=pypi_0
async-timeout=4.0.2=pypi_0
attrdict=2.0.1=pypi_0
attrs=23.1.0=pypi_0
audioread=3.0.0=pypi_0
babel=2.12.1=pypi_0
backcall=0.2.0=pypi_0
beautifulsoup4=4.12.2=pypi_0
bitsandbytes=0.41.1=pypi_0
black=19.10b0=pypi_0
blas=1.0=mkl
boto3=1.26.160=pypi_0
botocore=1.29.160=pypi_0
braceexpand=0.1.7=pypi_0
brotlipy=0.7.0=py310h7f8727e_1002
bzip2=1.0.8=h7b6447c_0
ca-certificates=2023.08.22=h06a4308_0
cachetools=5.3.0=pypi_0
cdifflib=1.2.6=pypi_0
certifi=2022.12.7=py310h06a4308_0
cffi=1.15.1=py310h5eee18b_3
charset-normalizer=2.0.4=pyhd3eb1b0_0
click=8.0.2=pypi_0
clip=1.0=pypi_0
colorama=0.4.6=pypi_0
comm=0.1.3=pypi_0
contourpy=1.0.7=pypi_0
cryptography=39.0.1=py310h9ce1e76_0
cuda-cudart=11.8.89=0
cuda-cupti=11.8.87=0
cuda-libraries=11.8.0=0
cuda-nvrtc=11.8.89=0
cuda-nvtx=11.8.86=0
cuda-runtime=11.8.0=0
cudatoolkit=11.8.0=h6a678d5_0
cycler=0.11.0=pypi_0
cython=0.29.35=pypi_0
datasets=2.14.5=pypi_0
debugpy=1.6.7=pypi_0
decorator=5.1.1=pypi_0
dill=0.3.6=pypi_0
distance=0.1.3=pypi_0
docker-pycreds=0.4.0=pypi_0
docopt=0.6.2=pypi_0
docutils=0.20.1=pypi_0
editdistance=0.6.2=pypi_0
einops=0.6.1=pypi_0
emoji=2.8.0=pypi_0
entrypoints=0.4=pypi_0
evaluate=0.3.0=pypi_0
exceptiongroup=1.1.1=pypi_0
executing=1.2.0=pypi_0
fairscale=0.4.13=pypi_0
faiss-cpu=1.7.4=pypi_0
faiss-gpu=1.7.2=pypi_0
fastapi=0.95.1=pypi_0
fasttext=0.9.2=pypi_0
ffmpeg=4.3=hf484d3e_0
ffmpy=0.3.0=pypi_0
filelock=3.9.0=py310h06a4308_0
fire=0.5.0=pypi_0
flask=2.2.5=pypi_0
flask-restful=0.3.10=pypi_0
flit-core=3.8.0=py310h06a4308_0
fonttools=4.39.3=pypi_0
freetype=2.12.1=h4a9f257_0
frozenlist=1.3.3=pypi_0
fsspec=2023.4.0=pypi_0
ftfy=6.1.1=pypi_0
future=0.18.3=pypi_0
g2p-en=2.1.0=pypi_0
gdown=4.7.1=pypi_0
giflib=5.2.1=h5eee18b_3
gitdb=4.0.10=pypi_0
gitpython=3.1.31=pypi_0
gmp=6.2.1=h295c915_3
gmpy2=2.1.2=py310heeb90bb_0
gnutls=3.6.15=he1e5248_0
google-auth=2.17.3=pypi_0
google-auth-oauthlib=1.0.0=pypi_0
gradio=3.28.3=pypi_0
gradio-client=0.1.3=pypi_0
grpcio=1.54.0=pypi_0
h11=0.14.0=pypi_0
h5py=3.9.0=pypi_0
httpcore=0.17.0=pypi_0
httpx=0.24.0=pypi_0
huggingface-hub=0.16.4=pypi_0
hydra-core=1.2.0=pypi_0
idna=3.4=py310h06a4308_0
ijson=3.2.2=pypi_0
imageio=2.28.0=pypi_0
imagesize=1.4.1=pypi_0
inflect=6.0.4=pypi_0
iniconfig=2.0.0=pypi_0
intel-openmp=2021.4.0=h06a4308_3561
ipadic=1.0.0=pypi_0
ipykernel=6.23.3=pypi_0
ipython=8.12.0=pypi_0
ipywidgets=8.0.6=pypi_0
isort=5.12.0=pypi_0
itsdangerous=2.1.2=pypi_0
jedi=0.18.2=pypi_0
jieba=0.42.1=pypi_0
jinja2=3.1.2=py310h06a4308_0
jiwer=2.5.2=pypi_0
jmespath=1.0.1=pypi_0
joblib=1.2.0=pypi_0
jpeg=9e=h5eee18b_1
jsonschema=4.17.3=pypi_0
jupyter-client=8.3.0=pypi_0
jupyter-core=5.3.1=pypi_0
jupyterlab-widgets=3.0.7=pypi_0
kaldi-python-io=1.2.2=pypi_0
kaldiio=2.18.0=pypi_0
kiwisolver=1.4.4=pypi_0
kornia=0.6.12=pypi_0
lame=3.100=h7b6447c_0
latexcodec=2.0.1=pypi_0
lazy-loader=0.2=pypi_0
lcms2=2.12=h3be6417_0
ld_impl_linux-64=2.38=h1181459_1
lerc=3.0=h295c915_0
levenshtein=0.21.1=pypi_0
libcublas=11.11.3.6=0
libcufft=10.9.0.58=0
libcufile=1.6.0.25=0
libcurand=10.3.2.56=0
libcusolver=11.4.1.48=0
libcusparse=11.7.5.86=0
libdeflate=1.17=h5eee18b_0
libffi=3.4.2=h6a678d5_6
libgcc-ng=11.2.0=h1234567_1
libgomp=11.2.0=h1234567_1
libiconv=1.16=h7f8727e_2
libidn2=2.3.2=h7f8727e_0
libnpp=11.8.0.86=0
libnvjpeg=11.9.0.86=0
libpng=1.6.39=h5eee18b_0
librosa=0.10.0.post2=pypi_0
libstdcxx-ng=11.2.0=h1234567_1
libtasn1=4.16.0=h27cfd23_0
libtiff=4.5.0=h6a678d5_2
libunistring=0.9.10=h27cfd23_0
libuuid=1.41.5=h5eee18b_0
libwebp=1.2.4=h11a3e52_1
libwebp-base=1.2.4=h5eee18b_1
lightning-utilities=0.8.0=pypi_0
linkify-it-py=2.0.0=pypi_0
llvmlite=0.40.1=pypi_0
loguru=0.7.0=pypi_0
loralib=0.1.1=pypi_0
lxml=4.9.2=pypi_0
lz4-c=1.9.4=h6a678d5_0
markdown=3.4.3=pypi_0
markdown-it-py=2.2.0=pypi_0
markdown2=2.4.9=pypi_0
markupsafe=2.1.1=py310h7f8727e_0
marshmallow=3.19.0=pypi_0
matplotlib=3.7.1=pypi_0
matplotlib-inline=0.1.6=pypi_0
mdit-py-plugins=0.3.3=pypi_0
mdurl=0.1.2=pypi_0
mecab-python3=1.0.5=pypi_0
megatron-core=0.2.0=pypi_0
mkl=2021.4.0=h06a4308_640
mkl-service=2.4.0=py310h7f8727e_0
mkl_fft=1.3.1=py310hd6ae3a3_0
mkl_random=1.2.2=py310h00e6091_0
mpc=1.1.0=h10f8cd9_1
mpfr=4.0.2=hb69a4c5_1
mpmath=1.2.1=pypi_0
msgpack=1.0.5=pypi_0
multidict=6.0.4=pypi_0
multiprocess=0.70.14=pypi_0
mypy-extensions=1.0.0=pypi_0
ncurses=6.4=h6a678d5_0
nest-asyncio=1.5.6=pypi_0
nettle=3.7.3=hbbd107a_1
networkx=2.8.4=py310h06a4308_1
nltk=3.8=pypi_0
numba=0.57.1=pypi_0
numpy=1.23.4=pypi_0
numpy-base=1.24.3=py310h8e6c178_0
oauthlib=3.2.2=pypi_0
omegaconf=2.2.3=pypi_0
onnx=1.14.0=pypi_0
openai=0.27.1=pypi_0
opencc=1.1.6=pypi_0
opencv-python=4.7.0.72=pypi_0
openh264=2.1.1=h4ff587b_0
openprompt=1.0.1=pypi_0
openssl=1.1.1w=h7f8727e_0
orjson=3.8.10=pypi_0
packaging=23.1=pypi_0
pandas=2.0.1=pypi_0
pangu=4.0.6.1=pypi_0
parameterized=0.9.0=pypi_0
parso=0.8.3=pypi_0
pathspec=0.11.1=pypi_0
pathtools=0.1.2=pypi_0
peft=0.5.0=pypi_0
pexpect=4.8.0=pypi_0
pickleshare=0.7.5=pypi_0
pillow=9.4.0=py310h6a678d5_0
pip=23.2.1=pypi_0
plac=1.3.5=pypi_0
platformdirs=3.4.0=pypi_0
pluggy=1.2.0=pypi_0
pooch=1.6.0=pypi_0
portalocker=2.7.0=pypi_0
progress=1.6=pypi_0
prompt-toolkit=3.0.38=pypi_0
protobuf=3.20.3=pypi_0
psutil=5.9.5=pypi_0
ptyprocess=0.7.0=pypi_0
pure-eval=0.2.2=pypi_0
pyannote-core=5.0.0=pypi_0
pyannote-database=5.0.1=pypi_0
pyannote-metrics=3.2.1=pypi_0
pyarrow=11.0.0=pypi_0
pyasn1=0.5.0=pypi_0
pyasn1-modules=0.3.0=pypi_0
pybind11=2.10.4=pypi_0
pybtex=0.24.0=pypi_0
pybtex-docutils=1.0.2=pypi_0
pycparser=2.21=pyhd3eb1b0_0
pydantic=1.10.7=pypi_0
pydeprecate=0.3.1=pypi_0
pydub=0.25.1=pypi_0
pygments=2.15.1=pypi_0
pynini=2.1.5=pypi_0
pyopenssl=23.0.0=py310h06a4308_0
pyparsing=3.0.9=pypi_0
pypinyin=0.49.0=pypi_0
pypinyin-dict=0.6.0=pypi_0
pyrsistent=0.19.3=pypi_0
pysocks=1.7.1=py310h06a4308_0
pytest=7.4.0=pypi_0
pytest-runner=6.0.0=pypi_0
python=3.10.10=h7a1cb2a_2
python-dateutil=2.8.2=pypi_0
python-multipart=0.0.6=pypi_0
pytorch=2.0.0=py3.10_cuda11.8_cudnn8.7.0_0
pytorch-cuda=11.8=h7e8668a_3
pytorch-lightning=1.9.4=pypi_0
pytorch-mutex=1.0=cuda
pytz=2023.3=pypi_0
pyyaml=5.4.1=pypi_0
pyzmq=25.1.0=pypi_0
rank-bm25=0.2.2=pypi_0
rapidfuzz=2.13.7=pypi_0
readline=8.2=h5eee18b_0
regex=2023.3.23=pypi_0
requests=2.28.1=py310h06a4308_1
requests-oauthlib=1.3.1=pypi_0
responses=0.18.0=pypi_0
rich=13.4.2=pypi_0
risparser=0.4.4=pypi_0
rouge=1.0.0=pypi_0
rouge-score=0.1.2=pypi_0
rsa=4.9=pypi_0
ruamel-yaml=0.17.32=pypi_0
ruamel-yaml-clib=0.2.7=pypi_0
s3transfer=0.6.1=pypi_0
sacrebleu=2.3.1=pypi_0
sacremoses=0.0.53=pypi_0
safetensors=0.3.1=pypi_0
scikit-learn=1.2.1=pypi_0
scipy=1.10.1=pypi_0
seaborn=0.12.2=pypi_0
semantic-version=2.10.0=pypi_0
sentence-transformers=2.2.2=pypi_0
sentencepiece=0.1.96=pypi_0
sentry-sdk=1.25.1=pypi_0
setproctitle=1.3.2=pypi_0
setuptools=65.5.1=pypi_0
shellingham=1.5.0.post1=pypi_0
six=1.16.0=pyhd3eb1b0_1
smmap=5.0.0=pypi_0
sniffio=1.3.0=pypi_0
snowballstemmer=2.2.0=pypi_0
sortedcontainers=2.4.0=pypi_0
soundfile=0.12.1=pypi_0
soupsieve=2.4.1=pypi_0
sox=1.4.1=pypi_0
soxr=0.3.5=pypi_0
sphinx=7.0.1=pypi_0
sphinxcontrib-applehelp=1.0.4=pypi_0
sphinxcontrib-bibtex=2.5.0=pypi_0
sphinxcontrib-devhelp=1.0.2=pypi_0
sphinxcontrib-htmlhelp=2.0.1=pypi_0
sphinxcontrib-jsmath=1.0.1=pypi_0
sphinxcontrib-qthelp=1.0.3=pypi_0
sphinxcontrib-serializinghtml=1.1.5=pypi_0
sqlite=3.41.1=h5eee18b_0
stack-data=0.6.2=pypi_0
starlette=0.26.1=pypi_0
sympy=1.11.1=py310h06a4308_0
tabulate=0.9.0=pypi_0
taming-transformers=0.0.1=pypi_0
taming-transformers-rom1504=0.0.6=pypi_0
tensorboard=2.12.2=pypi_0
tensorboard-data-server=0.7.0=pypi_0
tensorboard-plugin-wit=1.8.1=pypi_0
tensorboardx=2.6.2.2=pypi_0
termcolor=2.2.0=pypi_0
test-tube=0.7.5=pypi_0
text-unidecode=1.3=pypi_0
textdistance=4.5.0=pypi_0
texterrors=0.4.4=pypi_0
threadpoolctl=3.1.0=pypi_0
tk=8.6.12=h1ccaba5_0
tokenize-rt=5.0.0=pypi_0
tokenizers=0.13.3=pypi_0
toml=0.10.2=pypi_0
tomli=2.0.1=pypi_0
toolz=0.12.0=pypi_0
torchaudio=2.0.0=py310_cu118
torchmetrics=0.11.4=pypi_0
torchtriton=2.0.0=py310
torchvision=0.15.0=py310_cu118
tornado=6.3.2=pypi_0
tqdm=4.64.1=pypi_0
traitlets=5.9.0=pypi_0
transformers=4.33.3=pypi_0
trl=0.7.1=pypi_0
tweet-preprocessor=0.6.0=pypi_0
typed-ast=1.5.4=pypi_0
typer=0.9.0=pypi_0
typing_extensions=4.4.0=py310h06a4308_0
tzdata=2023.3=pypi_0
uc-micro-py=1.0.1=pypi_0
unidecode=1.3.7=pypi_0
urllib3=1.26.15=py310h06a4308_0
uvicorn=0.21.1=pypi_0
wandb=0.15.4=pypi_0
wcwidth=0.2.6=pypi_0
webdataset=0.1.62=pypi_0
websockets=11.0.2=pypi_0
werkzeug=2.3.0=pypi_0
wget=3.2=pypi_0
wheel=0.38.4=py310h06a4308_0
widgetsnbextension=4.0.7=pypi_0
wrapt=1.15.0=pypi_0
xxhash=3.2.0=pypi_0
xz=5.2.10=h5eee18b_1
yacs=0.1.8=pypi_0
yarl=1.9.2=pypi_0
youtokentome=1.0.6=pypi_0
zlib=1.2.13=h5eee18b_0
zstd=1.5.4=hc292b87_0


================================================
FILE: train_ranker.py
================================================
import os
import torch
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import argparse
from datasets import DATASETS
from config import *
from model import *
from dataloader import *
from trainer import *

from transformers import BitsAndBytesConfig
from pytorch_lightning import seed_everything
from model import LlamaForCausalLM
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    prepare_model_for_kbit_training,
)


try:
    os.environ['WANDB_PROJECT'] = PROJECT_NAME
except (NameError, TypeError):  # PROJECT_NAME is undefined or not a string
    print('WANDB_PROJECT not available, please set it in config.py')


def main(args, export_root=None):
    seed_everything(args.seed)
    if export_root is None:
        export_root = EXPERIMENT_ROOT + '/' + args.llm_base_model.split('/')[-1] + '/' + args.dataset_code

    train_loader, val_loader, test_loader, tokenizer, test_retrieval = dataloader_factory(args)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = LlamaForCausalLM.from_pretrained(
        args.llm_base_model,
        quantization_config=bnb_config,
        device_map='auto',
        cache_dir=args.llm_cache_dir,
    )
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)
    config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_target_modules,
        lora_dropout=args.lora_dropout,
        bias='none',
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    model.print_trainable_parameters()

    model.config.use_cache = False
    trainer = LLMTrainer(args, model, train_loader, val_loader, test_loader, tokenizer, export_root, args.use_wandb)
    
    trainer.train()
    trainer.test(test_retrieval)


if __name__ == "__main__":
    args.model_code = 'llm'
    set_template(args)
    main(args, export_root=None)


================================================
FILE: train_retriever.py
================================================
import os
import torch
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import wandb
import argparse

from config import *
from model import *
from dataloader import *
from trainer import *

from pytorch_lightning import seed_everything

try:
    os.environ['WANDB_PROJECT'] = PROJECT_NAME
except (NameError, TypeError):  # PROJECT_NAME is undefined or not a string
    print('WANDB_PROJECT not available, please set it in config.py')


def main(args, export_root=None):
    seed_everything(args.seed)
    train_loader, val_loader, test_loader = dataloader_factory(args)
    model = LRURec(args)
    if export_root is None:
        export_root = EXPERIMENT_ROOT + '/' + args.model_code + '/' + args.dataset_code
    
    trainer = LRUTrainer(args, model, train_loader, val_loader, test_loader, export_root, args.use_wandb)
    trainer.train()
    trainer.test()

    # the next line generates val / test candidates for reranking
    trainer.generate_candidates(os.path.join(export_root, 'retrieved.pkl'))


if __name__ == "__main__":
    args.model_code = 'lru'
    set_template(args)
    main(args, export_root=None)

    # # searching best hyperparameters
    # for decay in [0, 0.01]:
    #     for dropout in [0, 0.1, 0.2, 0.3, 0.4, 0.5]:
    #         args.weight_decay = decay
    #         args.bert_dropout = dropout
    #         args.bert_attn_dropout = dropout
    #         export_root = EXPERIMENT_ROOT + '/' + args.model_code + '/' + args.dataset_code + '/' + str(decay) + '_' + str(dropout)
    #         main(args, export_root=export_root)

================================================
FILE: trainer/__init__.py
================================================
from .lru import *
from .llm import *
from .utils import *

================================================
FILE: trainer/base.py
================================================
from model import *
from config import *
from .utils import *
from .loggers import *

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm

import os
import json
import numpy as np
from abc import ABCMeta, abstractmethod
from pathlib import Path
from collections import OrderedDict


class BaseTrainer(metaclass=ABCMeta):
    def __init__(self, args, model, train_loader, val_loader, test_loader, export_root, use_wandb=True):
        self.args = args
        self.device = args.device
        self.model = model.to(self.device)

        self.num_epochs = args.num_epochs
        self.metric_ks = args.metric_ks
        self.best_metric = args.best_metric
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.optimizer = self._create_optimizer()
        if args.enable_lr_schedule:
            if args.enable_lr_warmup:
                self.lr_scheduler = self.get_linear_schedule_with_warmup(
                    self.optimizer, args.warmup_steps, len(self.train_loader) * self.num_epochs)
            else:
                self.lr_scheduler = optim.lr_scheduler.StepLR(
                    self.optimizer, step_size=args.decay_step, gamma=args.gamma)
            
        self.export_root = export_root
        if not os.path.exists(self.export_root):
            Path(self.export_root).mkdir(parents=True)
        self.use_wandb = use_wandb
        if use_wandb:
            import wandb
            wandb.init(
                name=self.args.model_code+'_'+self.args.dataset_code,
                project=PROJECT_NAME,
                config=args,
            )
            writer = wandb
        else:
            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter(
                log_dir=Path(self.export_root).joinpath('logs'),
                comment=self.args.model_code+'_'+self.args.dataset_code,
            )
        self.val_loggers, self.test_loggers = self._create_loggers()
        self.logger_service = LoggerService(
            self.args, writer, self.val_loggers, self.test_loggers, use_wandb)
        
        print(args)

    def train(self):
        accum_iter = 0
        self.exit_training = self.validate(0, accum_iter)
        for epoch in range(self.num_epochs):
            accum_iter = self.train_one_epoch(epoch, accum_iter)
            if self.args.val_strategy == 'epoch':
                self.exit_training = self.validate(epoch, accum_iter)  # val after every epoch
            if self.exit_training:
                print('Early stopping triggered. Exit training')
                break
        self.logger_service.complete()

    def train_one_epoch(self, epoch, accum_iter):
        average_meter_set = AverageMeterSet()
        tqdm_dataloader = tqdm(self.train_loader)

        for batch_idx, batch in enumerate(tqdm_dataloader):
            self.model.train()
            batch = self.to_device(batch)

            self.optimizer.zero_grad()
            loss = self.calculate_loss(batch)
            loss.backward()
            self.clip_gradients(self.args.max_grad_norm)
            self.optimizer.step()
            if self.args.enable_lr_schedule:
                self.lr_scheduler.step()

            average_meter_set.update('loss', loss.item())
            tqdm_dataloader.set_description(
                'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))

            accum_iter += 1
            if self.args.val_strategy == 'iteration' and accum_iter % self.args.val_iterations == 0:
                self.exit_training = self.validate(epoch, accum_iter)  # val after certain iterations
                if self.exit_training: break

        return accum_iter

    def validate(self, epoch, accum_iter):
        self.model.eval()
        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.val_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = self.to_device(batch)
                metrics = self.calculate_metrics(batch, exclude_history=False)  # faster validation
                self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)

            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch+1,
                'accum_iter': accum_iter,
            }
            log_data.update(average_meter_set.averages())
        
        return self.logger_service.log_val(log_data)  # early stopping

    def test(self, epoch=-1, accum_iter=-1, save_name=None):
        print('******************** Testing Best Model ********************')
        best_model_dict = torch.load(os.path.join(
            self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
        self.model.load_state_dict(best_model_dict)
        self.model.eval()

        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = self.to_device(batch)
                metrics = self.calculate_metrics(batch)
                self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)

            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch+1,
                'accum_iter': accum_iter,
            }
            average_metrics = average_meter_set.averages()
            log_data.update(average_metrics)
            self.logger_service.log_test(log_data)

            print('******************** Testing Metrics ********************')
            print(average_metrics)
            file_name = 'test_metrics.json' if save_name is None else save_name
            with open(os.path.join(self.export_root, file_name), 'w') as f:
                json.dump(average_metrics, f, indent=4)
        
        return average_metrics
    
    def to_device(self, batch):
        return [x.to(self.device) for x in batch]

    @abstractmethod
    def calculate_loss(self, batch):
        pass
    
    @abstractmethod
    def calculate_metrics(self, batch):
        pass
    
    def clip_gradients(self, limit=1.0):
        nn.utils.clip_grad_norm_(self.model.parameters(), limit)

    def _update_meter_set(self, meter_set, metrics):
        for k, v in metrics.items():
            meter_set.update(k, v)

    def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):
        description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]
                               ] + ['Recall@%d' % k for k in self.metric_ks[:3]]
        description = 'Eval: ' + \
            ', '.join(s + ' {:.4f}' for s in description_metrics)
        description = description.replace('NDCG', 'N').replace('Recall', 'R')
        description = description.format(
            *(meter_set[k].avg for k in description_metrics))
        tqdm_dataloader.set_description(description)

    def _create_optimizer(self):
        args = self.args
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'layer_norm']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': args.weight_decay,
            },
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.},
        ]
        if args.optimizer.lower() == 'adamw':
            return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
        elif args.optimizer.lower() == 'adam':
            return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

    def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
        def lr_lambda(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            return max(
                0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
            )

        return LambdaLR(optimizer, lr_lambda, last_epoch)

    def _create_loggers(self):
        root = Path(self.export_root)
        model_checkpoint = root.joinpath('models')

        val_loggers, test_loggers = [], []
        for k in self.metric_ks:
            val_loggers.append(
                MetricGraphPrinter(key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation', use_wandb=self.use_wandb))
            val_loggers.append(
                MetricGraphPrinter(key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation', use_wandb=self.use_wandb))
            val_loggers.append(
                MetricGraphPrinter(key='MRR@%d' % k, graph_name='MRR@%d' % k, group_name='Validation', use_wandb=self.use_wandb))

        val_loggers.append(RecentModelLogger(self.args, model_checkpoint))
        val_loggers.append(BestModelLogger(self.args, model_checkpoint, metric_key=self.best_metric))

        for k in self.metric_ks:
            test_loggers.append(
                MetricGraphPrinter(key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Test', use_wandb=self.use_wandb))
            test_loggers.append(
                MetricGraphPrinter(key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Test', use_wandb=self.use_wandb))
            test_loggers.append(
                MetricGraphPrinter(key='MRR@%d' % k, graph_name='MRR@%d' % k, group_name='Test', use_wandb=self.use_wandb))

        return val_loggers, test_loggers

    def _create_state_dict(self):
        return {
            STATE_DICT_KEY: self.model.state_dict(),
            OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),
        }
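

# Illustrative sketch, not part of the original trainer. The lr_lambda closure in
# get_linear_schedule_with_warmup above returns a multiplier that rises linearly
# from 0 to 1 during warmup and then decays linearly to 0. The hypothetical helper
# below (its name is an assumption) exposes that multiplier as a plain function so
# individual steps can be inspected without building an optimizer.
def _linear_warmup_decay_factor(step, num_warmup_steps, num_training_steps):
    if step < num_warmup_steps:
        # Warmup phase: fraction of the warmup completed so far.
        return step / max(1, num_warmup_steps)
    # Decay phase: fraction of the post-warmup steps still remaining, floored at 0.
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))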

================================================
FILE: trainer/llm.py
================================================
from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY
from .verb import ManualVerbalizer
from .utils import *
from .loggers import *
from .base import *

import re
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import json
import numpy as np
from abc import *
from pathlib import Path

import bitsandbytes as bnb
from transformers.trainer import *
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback


def llama_collate_fn_w_truncation(llm_max_length, eval=False):
    def llama_collate_fn(batch):
        all_input_ids = []
        all_attention_mask = []
        all_labels = []
        example_max_length = max(len(example['input_ids']) for example in batch)
        max_length = min(llm_max_length, example_max_length)
        
        for i in range(len(batch)):
            input_ids = batch[i]['input_ids']
            attention_mask = batch[i]['attention_mask']
            labels = batch[i]['labels']
            if len(input_ids) > max_length:
                input_ids = input_ids[-max_length:]
                attention_mask = attention_mask[-max_length:]
                if not eval: labels = labels[-max_length:]
            elif len(input_ids) < max_length:
                padding_length = max_length - len(input_ids)
                input_ids = [0] * padding_length + input_ids
                attention_mask = [0] * padding_length + attention_mask
                if not eval: labels = [-100] * padding_length + labels

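            # Sanity checks on the prompt layout; assuming the standard Llama tokenizer,
            # id 13 is the newline token and id 2 is the end-of-sequence token.
            if eval: assert input_ids[-1] == 13
            else:
                assert input_ids[-3] == 13 and input_ids[-1] == 2
                assert labels[-3] == -100 and labels[-2] != -100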
            
            all_input_ids.append(torch.tensor(input_ids).long())
            all_attention_mask.append(torch.tensor(attention_mask).long())
            all_labels.append(torch.tensor(labels).long())
        
        return {
            'input_ids': torch.vstack(all_input_ids),
            'attention_mask': torch.vstack(all_attention_mask),
            'labels': torch.vstack(all_labels)
        }
    return llama_collate_fn
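

# Illustrative sketch, not part of the original trainer. The collate function above
# keeps the last max_length tokens of long examples and left-pads short ones, so the
# final tokens of every prompt stay aligned at the right edge of the batch. The
# hypothetical helper below (its name is an assumption) isolates that rule for one
# list of token ids and assumes max_length >= 1.
def _left_pad_or_truncate(token_ids, max_length, pad_value=0):
    if len(token_ids) >= max_length:
        # Keep the most recent tokens; the prompt ending must survive truncation.
        return token_ids[-max_length:]
    # Pad on the left so the last real token stays in the final position.
    return [pad_value] * (max_length - len(token_ids)) + token_ids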


def compute_metrics_for_ks(ks, verbalizer):
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        logits = torch.tensor(logits)
        labels = torch.tensor(labels).view(-1)
        scores = verbalizer.process_logits(logits)
        metrics = absolute_recall_mrr_ndcg_for_ks(scores, labels, ks)
        return metrics
    return compute_metrics


class LLMTrainer(Trainer):
    def __init__(
            self,
            args,
            model,
            train_loader,
            val_loader,
            test_loader,
            tokenizer,
            export_root,
            use_wandb,
            **kwargs
        ):
        self.original_args = args
        self.export_root = export_root
        self.use_wandb = use_wandb
        self.llm_max_text_len = args.llm_max_text_len
        self.rerank_metric_ks = args.rerank_metric_ks
        self.verbalizer = ManualVerbalizer(
            tokenizer=tokenizer,
            prefix='',
            post_log_softmax=False,
            classes=list(range(args.llm_negative_sample_size+1)),
            label_words={i: chr(ord('A')+i) for i in range(args.llm_negative_sample_size+1)},
        )

        hf_args = TrainingArguments(
            per_device_train_batch_size=args.lora_micro_batch_size,
            gradient_accumulation_steps=args.train_batch_size//args.lora_micro_batch_size,
            warmup_steps=args.warmup_steps,
            num_train_epochs=args.lora_num_epochs,
            learning_rate=args.lora_lr,
            bf16=True,
            logging_steps=10,
            optim="paged_adamw_32bit",
            evaluation_strategy="steps",
            save_strategy="steps",
            eval_steps=args.lora_val_iterations,
            save_steps=args.lora_val_iterations,
            output_dir=export_root,
            save_total_limit=3,
            load_best_model_at_end=True,
            ddp_find_unused_parameters=None,
            group_by_length=False,
            report_to="wandb" if use_wandb else "none",  # "none" disables all reporting integrations
            run_name=args.model_code+'_'+args.dataset_code if use_wandb else None,
            metric_for_best_model=args.rerank_best_metric,
            greater_is_better=True,
        )
        super().__init__(
            model=model,
            args=hf_args,
            callbacks=[EarlyStoppingCallback(args.lora_early_stopping_patience)],
            **kwargs)  # hf_args is now args

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.tokenizer = tokenizer
        
        self.train_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=False)
        self.val_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=True)
        self.test_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=True)
        self.compute_metrics = compute_metrics_for_ks(self.rerank_metric_ks, self.verbalizer)

        if len(self.label_names) == 0:
            self.label_names = ['labels']  # for some reason label name is not set
    
    def test(self, test_retrieval):
        average_metrics = self.predict(test_dataset=None).metrics
        print('Ranking Performance on Subset:', average_metrics)
        print('************************************************************')
        with open(os.path.join(self.export_root, 'subset_metrics.json'), 'w') as f:
            json.dump(average_metrics, f, indent=4)

        print('Original Performance:', test_retrieval['original_metrics'])
        print('************************************************************')
        original_size = test_retrieval['original_size']
        retrieval_size = test_retrieval['retrieval_size']
        
        overall_metrics = {}
        for key in test_retrieval['non_retrieval_metrics'].keys():
            if 'test_' + key in average_metrics:
                # Size-weighted average of the re-ranked subset and the remaining samples.
                overall_metrics['test_' + key] = (
                    average_metrics['test_' + key] * retrieval_size
                    + test_retrieval['non_retrieval_metrics'][key] * (original_size - retrieval_size)
                ) / original_size
        print('Overall Performance of Our Framework:', overall_metrics)
        with open(os.path.join(self.export_root, 'overall_metrics.json'), 'w') as f:
            json.dump(overall_metrics, f, indent=4)
        
        return average_metrics

    def get_train_dataloader(self):
        return self.train_loader
    
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        return self.val_loader
    
    def get_test_dataloader(self, test_dataset: Optional[Dataset] = None) -> DataLoader:
        return self.test_loader
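

# --- Illustrative sketch (not part of the original training code) ---
# `LLMTrainer.test` blends the metric measured on the re-ranked subset with the
# metric of the remaining samples, weighting each group by its sample count.
# The helper below restates that arithmetic in isolation. Its name and
# parameters are hypothetical and chosen only for illustration.
def _blend_metric_sketch(subset_value, rest_value, subset_size, total_size):
    """Size-weighted average of two per-group metric values."""
    rest_size = total_size - subset_size
    return (subset_value * subset_size + rest_value * rest_size) / total_size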

================================================
FILE: trainer/loggers.py
================================================
import os
import torch
from abc import ABCMeta, abstractmethod


def save_state_dict(state_dict, path, filename):
    torch.save(state_dict, os.path.join(path, filename))


class LoggerService(object):
    def __init__(self, args, writer, val_loggers, test_loggers, use_wandb):
        self.args = args
        self.writer = writer
        self.val_loggers = val_loggers if val_loggers else []
        self.test_loggers = test_loggers if test_loggers else []
        self.use_wandb = use_wandb

    def complete(self):
        if self.use_wandb:
            self.writer.finish()
        else:
            self.writer.close()

    def log_val(self, log_data):
        criteria_met = False
        for logger in self.val_loggers:
            logger.log(self.writer, **log_data)
            if self.args.early_stopping and isinstance(logger, BestModelLogger):
                criteria_met = logger.patience_counter >= self.args.early_stopping_patience
        return criteria_met
    
    def log_test(self, log_data):
        for logger in self.test_loggers:
            logger.log(self.writer, **log_data)


class AbstractBaseLogger(metaclass=ABCMeta):
    @abstractmethod
    def log(self, *args, **kwargs):
        raise NotImplementedError

    def complete(self, *args, **kwargs):
        pass


class MetricGraphPrinter(AbstractBaseLogger):
    def __init__(self, key, graph_name, group_name, use_wandb):
        self.key = key
        self.graph_label = graph_name
        self.group_name = group_name
        self.use_wandb = use_wandb
        
    def log(self, writer, *args, **kwargs):
        if self.key in kwargs:
            if self.use_wandb:
                writer.log({self.group_name+'/'+self.graph_label: kwargs[self.key], 'batch': kwargs['accum_iter']})
            else:
                writer.add_scalar(self.group_name+'/'+ self.graph_label, kwargs[self.key], kwargs['accum_iter'])
        else:
            print('Metric {} not found...'.format(self.key))

    def complete(self, writer, *args, **kwargs):
        self.log(writer, *args, **kwargs)


class RecentModelLogger(AbstractBaseLogger):
    def __init__(self, args, checkpoint_path, filename='checkpoint-recent.pth'):
        self.args = args
        self.checkpoint_path = checkpoint_path
        if not os.path.exists(self.checkpoint_path):
            self.checkpoint_path.mkdir(parents=True)
        self.recent_epoch = None
        self.filename = filename

    def log(self, *args, **kwargs):
        epoch = kwargs['epoch']

        if self.recent_epoch != epoch:
            self.recent_epoch = epoch
            state_dict = kwargs['state_dict']
            state_dict['epoch'] = kwargs['epoch']
            save_state_dict(state_dict, self.checkpoint_path, self.filename)

    def complete(self, *args, **kwargs):
        save_state_dict(kwargs['state_dict'],
                        self.checkpoint_path, self.filename + '.final')


class BestModelLogger(AbstractBaseLogger):
    def __init__(self, args, checkpoint_path, metric_key, filename='best_acc_model.pth'):
        self.args = args
        self.checkpoint_path = checkpoint_path
        if not os.path.exists(self.checkpoint_path):
            self.checkpoint_path.mkdir(parents=True)

        self.best_metric = 0.
        self.metric_key = metric_key
        self.filename = filename
        self.patience_counter = 0

    def log(self, *args, **kwargs):
        current_metric = kwargs[self.metric_key]
        if self.best_metric < current_metric:  # assumes the higher the better
            print("Update Best {} Model at {}".format(
                self.metric_key, kwargs['epoch']))
            self.best_metric = current_metric
            save_state_dict(kwargs['state_dict'],
                            self.checkpoint_path, self.filename)
            if self.args.early_stopping:
                self.patience_counter = 0
        elif self.args.early_stopping:
            self.patience_counter += 1

================================================
FILE: trainer/lru.py
================================================
from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY
from .utils import *
from .loggers import *
from .base import *

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import json
import pickle
import numpy as np
from abc import *
from pathlib import Path


class LRUTrainer(BaseTrainer):
    def __init__(self, args, model, train_loader, val_loader, test_loader, export_root, use_wandb):
        super().__init__(args, model, train_loader, val_loader, test_loader, export_root, use_wandb)
        self.ce = nn.CrossEntropyLoss(ignore_index=0)
    
    def calculate_loss(self, batch):
        seqs, labels = batch
        logits = self.model(seqs)
        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)
        loss = self.ce(logits, labels)
        return loss

    def calculate_metrics(self, batch, exclude_history=True):
        seqs, labels = batch
        
        scores = self.model(seqs)[:, -1, :]
        B, L = seqs.shape
        if exclude_history:
            for i in range(L):
                scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9
            scores[:, 0] = -1e9  # padding
        metrics = absolute_recall_mrr_ndcg_for_ks(scores, labels.view(-1), self.metric_ks)
        return metrics
    
    def generate_candidates(self, retrieved_data_path):
        self.model.eval()
        val_probs, val_labels = [], []
        test_probs, test_labels = [], []
        with torch.no_grad():
            print('*************** Generating Candidates for Validation Set ***************')
            tqdm_dataloader = tqdm(self.val_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = self.to_device(batch)
                seqs, labels = batch
        
                scores = self.model(seqs)[:, -1, :]
                B, L = seqs.shape
                for i in range(L):
                    scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9
                scores[:, 0] = -1e9  # padding
                val_probs.extend(scores.tolist())
                val_labels.extend(labels.view(-1).tolist())
            val_metrics = absolute_recall_mrr_ndcg_for_ks(torch.tensor(val_probs), 
                                                          torch.tensor(val_labels).view(-1), self.metric_ks)
            print(val_metrics)

            print('****************** Generating Candidates for Test Set ******************')
            tqdm_dataloader = tqdm(self.test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = self.to_device(batch)
                seqs, labels = batch
        
                scores = self.model(seqs)[:, -1, :]
                B, L = seqs.shape
                for i in range(L):
                    scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9
                scores[:, 0] = -1e9  # padding
                test_probs.extend(scores.tolist())
                test_labels.extend(labels.view(-1).tolist())
            test_metrics = absolute_recall_mrr_ndcg_for_ks(torch.tensor(test_probs), 
                                                           torch.tensor(test_labels).view(-1), self.metric_ks)
            print(test_metrics)

        with open(retrieved_data_path, 'wb') as f:
            pickle.dump({'val_probs': val_probs,
                         'val_labels': val_labels,
                         'val_metrics': val_metrics,
                         'test_probs': test_probs,
                         'test_labels': test_labels,
                         'test_metrics': test_metrics}, f)
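

# --- Illustrative sketch (not part of the original training code) ---
# Both `calculate_metrics` and `generate_candidates` push the scores of items
# already in the user's history, and of the padding index 0, down to -1e9 so
# that they can never be ranked first. The helper below shows the same masking
# for a single user with plain Python lists. Its name is hypothetical.
def _mask_history_sketch(scores, history, fill_value=-1e9):
    """Return a copy of `scores` with history items and index 0 masked out."""
    masked = list(scores)
    for item_id in history:
        masked[item_id] = fill_value  # already-seen items are not recommended again
    masked[0] = fill_value  # index 0 is the padding item
    return masked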

================================================
FILE: trainer/utils.py
================================================
from config import *

import json
import os
import pprint as pp
import random
from datetime import date
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch import optim as optim


def ndcg(scores, labels, k):
    scores = scores.cpu()
    labels = labels.cpu()
    rank = (-scores).argsort(dim=1)
    cut = rank[:, :k]
    hits = labels.gather(1, cut)
    position = torch.arange(2, 2+k)
    weights = 1 / torch.log2(position.float())
    dcg = (hits.float() * weights).sum(1)
    idcg = torch.Tensor([weights[:min(int(n), k)].sum()
                         for n in labels.sum(1)])
    ndcg = dcg / idcg
    return ndcg.mean()


def absolute_recall_mrr_ndcg_for_ks(scores, labels, ks):
    metrics = {}
    labels = F.one_hot(labels, num_classes=scores.size(1))
    answer_count = labels.sum(1)

    labels_float = labels.float()
    rank = (-scores).argsort(dim=1)

    cut = rank
    for k in sorted(ks, reverse=True):
        cut = cut[:, :k]
        hits = labels_float.gather(1, cut)
        metrics['Recall@%d' % k] = \
            (hits.sum(1) / torch.min(torch.Tensor([k]).to(
                labels.device), labels.sum(1).float())).mean().cpu().item()
        
        metrics['MRR@%d' % k] = \
            (hits / torch.arange(1, k+1).unsqueeze(0).to(
                labels.device)).sum(1).mean().cpu().item()

        position = torch.arange(2, 2+k)
        weights = 1 / torch.log2(position.float())
        dcg = (hits * weights.to(hits.device)).sum(1)
        idcg = torch.Tensor([weights[:min(int(n), k)].sum()
                             for n in answer_count]).to(dcg.device)
        ndcg = (dcg / idcg).mean()
        metrics['NDCG@%d' % k] = ndcg.cpu().item()

    return metrics
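

# --- Illustrative sketch (not part of the original metric code) ---
# With exactly one relevant item per sample, as produced by the one-hot labels
# above, the three metrics reduce to simple closed forms of the 1-based rank of
# that item. The helper below spells them out for a single sample. Its name and
# return format are hypothetical.
import math


def _single_target_metrics_sketch(rank, k):
    """Recall@k, MRR@k and NDCG@k for one sample whose target sits at `rank`."""
    if rank > k:
        return {'Recall': 0.0, 'MRR': 0.0, 'NDCG': 0.0}
    # The ideal DCG for a single target is 1 / log2(2) = 1, so NDCG equals the DCG.
    return {'Recall': 1.0, 'MRR': 1.0 / rank, 'NDCG': 1.0 / math.log2(rank + 1)}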


class AverageMeterSet(object):
    def __init__(self, meters=None):
        self.meters = meters if meters else {}

    def __getitem__(self, key):
        if key not in self.meters:
            meter = AverageMeter()
            meter.update(0)
            return meter
        return self.meters[key]

    def update(self, name, value, n=1):
        if name not in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def reset(self):
        for meter in self.meters.values():
            meter.reset()

    def values(self, format_string='{}'):
        return {format_string.format(name): meter.val for name, meter in self.meters.items()}

    def averages(self, format_string='{}'):
        return {format_string.format(name): meter.avg for name, meter in self.meters.items()}

    def sums(self, format_string='{}'):
        return {format_string.format(name): meter.sum for name, meter in self.meters.items()}

    def counts(self, format_string='{}'):
        return {format_string.format(name): meter.count for name, meter in self.meters.items()}


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n  # weight the value by its sample count so that avg stays a true mean
        self.count += n
        self.avg = self.sum / self.count

    def __format__(self, format):
        return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format)


================================================
FILE: trainer/verb.py
================================================
from abc import abstractmethod
import json

from transformers.file_utils import ModelOutput
from transformers.data.processors.utils import InputFeatures

import torch
import torch.nn as nn
import torch.nn.functional as F
from yacs.config import CfgNode
from transformers.tokenization_utils import PreTrainedTokenizer

import numpy as np
from collections import namedtuple

import inspect
from typing import *

_VALID_TYPES = {tuple, list, str, int, float, bool, type(None)}


def convert_cfg_to_dict(cfg_node, key_list=[]):
    """ Convert a config node to dictionary """
    if not isinstance(cfg_node, CfgNode):
        if type(cfg_node) not in _VALID_TYPES:
            print("Key {} with value {} is not a valid type; valid types: {}".format(
                ".".join(key_list), type(cfg_node), _VALID_TYPES), )
        return cfg_node
    else:
        cfg_dict = dict(cfg_node)
        for k, v in cfg_dict.items():
            cfg_dict[k] = convert_cfg_to_dict(v, key_list + [k])
        return cfg_dict


def signature(f):
    r"""Get the input arguments of the function ``f``. A useful gadget
    when some function slot might be instantiated into multiple functions.

    Args:
        f (:obj:`function`) : the function to get the input arguments.

    Returns:
        namedtuple : of args, defaults, varargs, keywords, respectively.

    """
    sig = inspect.signature(f)
    args = [
        p.name for p in sig.parameters.values()
        if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    varargs = [
        p.name for p in sig.parameters.values()
        if p.kind == inspect.Parameter.VAR_POSITIONAL
    ]
    varargs = varargs[0] if varargs else None
    keywords = [
        p.name for p in sig.parameters.values()
        if p.kind == inspect.Parameter.VAR_KEYWORD
    ]
    keywords = keywords[0] if keywords else None
    defaults = [
        p.default for p in sig.parameters.values()
        if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
        and p.default is not p.empty
    ] or None
    argspec = namedtuple('Signature', ['args', 'defaults',
                                        'varargs', 'keywords'])
    return argspec(args, defaults, varargs, keywords) 


class Verbalizer(nn.Module):
    r'''
    Base class for all the verbalizers.

    Args:
        tokenizer (:obj:`PreTrainedTokenizer`): A tokenizer to appoint the vocabulary and the tokenization strategy.
        classes (:obj:`Sequence[str]`): A sequence of classes that need to be projected.
    '''
    def __init__(self,
                 tokenizer: Optional[PreTrainedTokenizer] = None,
                 classes: Optional[Sequence[str]] = None,
                 num_classes: Optional[int] = None,
                ):
        super().__init__()
        self.tokenizer = tokenizer
        self.classes = classes
        if classes is not None and num_classes is not None:
            assert len(classes) == num_classes, "len(classes) != num_classes, check your config."
            self.num_classes = num_classes
        elif num_classes is not None:
            self.num_classes = num_classes
        elif classes is not None:
            self.num_classes = len(classes)
        else:
            self.num_classes = None
            # raise AttributeError("No able to configure num_classes")
        self._in_on_label_words_set = False

    @property
    def label_words(self,):
        r'''
        Label words means the words in the vocabulary projected by the labels.
        E.g. if we want to establish a projection in sentiment classification: positive :math:`\rightarrow` {`wonderful`, `good`},
        in this case, `wonderful` and `good` are label words.
        '''
        if not hasattr(self, "_label_words"):
            raise RuntimeError("label words haven't been set.")
        return self._label_words

    @label_words.setter
    def label_words(self, label_words):
        if label_words is None:
            return
        self._label_words = self._match_label_words_to_label_ids(label_words)
        if not self._in_on_label_words_set:
            self.safe_on_label_words_set()

    def _match_label_words_to_label_ids(self, label_words): # TODO newly add function after docs written # TODO rename this function
        """
        sort label words dict of verbalizer to match the label order of the classes
        """
        if isinstance(label_words, dict):
            if self.classes is None:
                raise ValueError("""
                classes attribute of the Verbalizer should be set since your given label words is a dict.
                Since we will match the label word with respect to class A, to A's index in classes
                """)
            if set(label_words.keys()) != set(self.classes):
                raise ValueError("name of classes in verbalizer are different from those of dataset")
            label_words = [ # sort the dict to match dataset
                label_words[c]
                for c in self.classes
            ] # length: label_size of the whole task
        elif isinstance(label_words, list) or isinstance(label_words, tuple):
            pass
        else:
            raise ValueError("Verbalizer label words must be list, tuple or dict")
        return label_words

    def safe_on_label_words_set(self,):
        self._in_on_label_words_set = True
        self.on_label_words_set()
        self._in_on_label_words_set = False

    def on_label_words_set(self,):
        r"""A hook to do something when textual label words were set.
        """
        pass

    @property
    def vocab(self,) -> Dict:
        if not hasattr(self, '_vocab'):
            self._vocab = self.tokenizer.convert_ids_to_tokens(np.arange(self.vocab_size).tolist())
        return self._vocab

    @property
    def vocab_size(self,) -> int:
        return self.tokenizer.vocab_size

    @abstractmethod
    def generate_parameters(self, **kwargs) -> List:
        r"""
        The verbalizer can be seen as an extra layer on top of the original
        pre-trained models. In manual verbalizer, it is a fixed one-hot vector of dimension
        ``vocab_size``, with the position of the label word being 1 and 0 everywhere else.
        In other situation, the parameters may be a continuous vector over the
        vocab, with each dimension representing a weight of that token.
        Moreover, the parameters may be set to trainable to allow label words selection.

        Therefore, this function serves as an abstract method for generating the parameters
        of the verbalizer, and must be implemented in any derived class.

        Note that the parameters need to be registered as a part of PyTorch's module,
        which can be achieved by wrapping a tensor using ``nn.Parameter()``.
        """
        raise NotImplementedError

    def register_calibrate_logits(self, logits: torch.Tensor):
        r"""
        This function aims to register logits that need to be calibrated, and detach the original logits from the current graph.
        """
        if logits.requires_grad:
            logits = logits.detach()
        self._calibrate_logits = logits

    def process_outputs(self,
                       outputs: torch.Tensor,
                       batch: Union[Dict, InputFeatures],
                       **kwargs):
        r"""By default, the verbalizer will process the logits of the PLM's
        output.

        Args:
            outputs (:obj:`torch.Tensor`): The current logits generated by pre-trained language models.
            batch (:obj:`Union[Dict, InputFeatures]`): The input features of the data.
        """

        return self.process_logits(outputs, batch=batch, **kwargs)

    def gather_outputs(self, outputs: ModelOutput):
        r""" retrieve useful output for the verbalizer from the whole model output
        By default, it will only retrieve the logits

        Args:
            outputs (:obj:`ModelOutput`) The output from the pretrained language model.

        Return:
            :obj:`torch.Tensor` The gathered output, should be of shape (``batch_size``,
            ``seq_len``, ``any``)
        """
        return outputs.logits

    @staticmethod
    def aggregate(label_words_logits: torch.Tensor) -> torch.Tensor:
        r"""Aggregate the logits of multiple label words into the label's logits.
        Basic aggregator: the mean of a label's word logits becomes the label's logits.
        Can be re-implemented in advanced verbalizers.

        Args:
            label_words_logits (:obj:`torch.Tensor`): The logits of the label words only.

        Return:
            :obj:`torch.Tensor`: The final logits calculated by the label words.
        """
        if label_words_logits.dim()>2:
            return label_words_logits.mean(dim=-1)
        else:
            return label_words_logits


    def normalize(self, logits: torch.Tensor) -> torch.Tensor:
        r"""
        Given logits regarding the entire vocab, calculate the probs over the label words set by softmax.

        Args:
            logits(:obj:`Tensor`): The logits of the entire vocab.

        Returns:
            :obj:`Tensor`: The probability distribution over the label words set.
        """
        batch_size = logits.shape[0]
        return F.softmax(logits.reshape(batch_size, -1), dim=-1).reshape(*logits.shape)

    @abstractmethod
    def project(self,
                logits: torch.Tensor,
                **kwargs) -> torch.Tensor:
        r"""This method receives input logits of shape ``[batch_size, vocab_size]``, and uses the
        parameters of this verbalizer to project the logits over the entire vocab into the
        logits of the label words.

        Args:
            logits (:obj:`Tensor`): The logits over entire vocab generated by the pre-trained language model with shape [``batch_size``, ``max_seq_length``, ``vocab_size``]

        Returns:
            :obj:`Tensor`: The normalized probs (sum to 1) of each label.
        """
        raise NotImplementedError

    def handle_multi_token(self, label_words_logits, mask):
        r"""
        Support multiple methods to handle the multi tokens produced by the tokenizer.
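        We suggest using 'first' or 'max' if some parts of the tokenization are not meaningful.
        Can broadcast to 3-d tensor.

        Args:
            label_words_logits (:obj:`torch.Tensor`): The logits of the label words, with the tokens of each label word along the last dimension.
            mask (:obj:`torch.Tensor`): A mask that is 1 for valid token positions of each label word and 0 elsewhere.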

        Returns:
            :obj:`torch.Tensor`
        """
        if self.multi_token_handler == "first":
            label_words_logits = label_words_logits.select(dim=-1, index=0)
        elif self.multi_token_handler == "max":
            label_words_logits = label_words_logits - 1000*(1-mask.unsqueeze(0))
            label_words_logits = label_words_logits.max(dim=-1).values
        elif self.multi_token_handler == "mean":
            label_words_logits = (label_words_logits*mask.unsqueeze(0)).sum(dim=-1)/(mask.unsqueeze(0).sum(dim=-1)+1e-15)
        else:
            raise ValueError("multi_token_handler {} not configured".format(self.multi_token_handler))
        return label_words_logits

    @classmethod
    def from_config(cls,
                    config: CfgNode,
                    **kwargs):
        r"""load a verbalizer from verbalizer's configuration node.

        Args:
            config (:obj:`CfgNode`): the sub-configuration of verbalizer, i.e. ``config[config.verbalizer]``
                        if config is a global config node.
            kwargs: Other kwargs that might be used to initialize the verbalizer.
                    The actual value should match the arguments of ``__init__`` functions.
        """

        init_args = signature(cls.__init__).args
        _init_dict = {**convert_cfg_to_dict(config), **kwargs} if config is not None else kwargs
        init_dict = {key: _init_dict[key] for key in _init_dict if key in init_args}
        verbalizer = cls(**init_dict)
        if hasattr(verbalizer, "from_file"):
            if not hasattr(config, "file_path"):
                pass
            else:
                if (not hasattr(config, "label_words") or config.label_words is None) and config.file_path is not None:
                    if config.choice is None:
                        config.choice = 0
                    verbalizer.from_file(config.file_path, config.choice)
                elif (hasattr(config, "label_words") and config.label_words is not None) and config.file_path is not None:
                    raise RuntimeError("The label words can't be set from both `label_words` and `file_path`.")
        return verbalizer

    def from_file(self,
                  path: str,
                  choice: Optional[int] = 0 ):
        r"""Load the predefined label words from a verbalizer file.
        Three file formats are currently supported:

        1. a .jsonl or .json file containing a single verbalizer in dict format.
        2. a .jsonl or .json file containing a list of verbalizers in dict format.
        3. a .txt or .csv file in which the label words of each class are listed on one line,
           separated by commas. Begin a new verbalizer with an empty line.
           This format is recommended when you don't know the name of each class.

        The details of verbalizer format can be seen in :ref:`How_to_write_a_verbalizer`.

        Args:
            path (:obj:`str`): The path of the local verbalizer file.
            choice (:obj:`int`): The choice of verbalizer in a file containing
                             multiple verbalizers.

        Returns:
            Verbalizer : `self` object
        """
        if path.endswith(".txt") or path.endswith(".csv"):
            with open(path, 'r') as f:
                lines = f.readlines()
                label_words_all = []
                label_words_single_group = []
                for line in lines:
                    line = line.strip().strip(" ")
                    if line == "":
                        if len(label_words_single_group)>0:
                            label_words_all.append(label_words_single_group)
                        label_words_single_group = []
                    else:
                        label_words_single_group.append(line)
                if len(label_words_single_group) > 0: # if no empty line in the last
                    label_words_all.append(label_words_single_group)
                if choice >= len(label_words_all):
                    raise RuntimeError("choice {} exceed the number of verbalizers {}"
                                .format(choice, len(label_words_all)))

                label_words = label_words_all[choice]
                label_words = [label_words_per_label.strip().split(",") \
                            for label_words_per_label in label_words]

        elif path.endswith(".jsonl") or path.endswith(".json"):
            with open(path, "r") as f:
                label_words_all = json.load(f)
                # if it is a file containing multiple verbalizers
                if isinstance(label_words_all, list):
                    if choice >= len(label_words_all):
                        raise RuntimeError("choice {} exceed the number of verbalizers {}"
                                .format(choice, len(label_words_all)))
                    label_words = label_words_all[choice]
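                elif isinstance(label_words_all, dict):
                    label_words = label_words_all
                    if choice>0:
                        print("Choice of verbalizer is {}, but the file only contains one verbalizer.".format(choice))
        else:
            # Fail early instead of hitting an undefined `label_words` below.
            raise ValueError("Unsupported verbalizer file format: {}".format(path))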

        self.label_words = label_words
        if self.num_classes is not None:
            num_classes = len(self.label_words)
            assert num_classes == self.num_classes, \
                'number of classes in the verbalizer file does not match the predefined num_classes.'
        return self
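

# --- Illustrative sketch (not part of the original verbalizer code) ---
# `Verbalizer.from_config` keeps only those config entries whose keys match a
# positional-or-keyword argument of the class constructor. The helper below
# isolates that filtering step with the standard `inspect` module. Its name is
# hypothetical, and `inspect` is imported again only to keep the sketch
# self-contained.
import inspect


def _filter_init_kwargs_sketch(cls, candidates):
    """Return the entries of `candidates` that `cls.__init__` accepts by name."""
    accepted = [
        p.name for p in inspect.signature(cls.__init__).parameters.values()
        if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    return {key: value for key, value in candidates.items() if key in accepted}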


class ManualVerbalizer(Verbalizer):
    r"""
    The basic manually defined verbalizer class, which inherits from the :obj:`Verbalizer` class.

    Args:
        tokenizer (:obj:`PreTrainedTokenizer`): The tokenizer of the current pre-trained model to point out the vocabulary.
        classes (:obj:`List[Any]`): The classes (or labels) of the current task.
        num_classes (:obj:`int`, optional): The number of classes; checked against the label words when they are loaded from a file.
        label_words (:obj:`Union[List[str], List[List[str]], Dict[List[str]]]`, optional): The label words that are projected by the labels.
        prefix (:obj:`str`, optional): The prefix string of the verbalizer (used in PLMs like RoBERTa, which are sensitive to prefix spaces).
        multi_token_handler (:obj:`str`, optional): The handling strategy for multiple tokens produced by the tokenizer.
        post_log_softmax (:obj:`bool`, optional): Whether to apply log softmax post processing on label_logits. Defaults to True.
    """
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 classes: Optional[List] = None,
                 num_classes: Optional[int] = None,
                 label_words: Optional[Union[Sequence[str], Mapping[str, str]]] = None,
                 prefix: Optional[str] = " ",
                 multi_token_handler: Optional[str] = "first",
                 post_log_softmax: Optional[bool] = True,
                ):
        super().__init__(tokenizer=tokenizer, num_classes=num_classes, classes=classes)
        self.prefix = prefix
        self.multi_token_handler = multi_token_handler
        self.label_words = label_words
        self.post_log_softmax = post_log_softmax

    def on_label_words_set(self):
        super().on_label_words_set()
        self.label_words = self.add_prefix(self.label_words, self.prefix)

        # TODO should the Verbalizer base class have a label_words property and setter?
        # It has no label_words init argument and no label-words from_file option at all.

        self.generate_parameters()

    @staticmethod
    def add_prefix(label_words, prefix):
        r"""Add prefix to label words. For example, if a label word is in the middle of a template,
        the prefix should be ``' '``.

        Args:
            label_words (:obj:`Union[Sequence[str], Mapping[str, str]]`, optional): The label words that are projected by the labels.
            prefix (:obj:`str`, optional): The prefix string of the verbalizer.

        Returns:
            :obj:`Sequence[str]`: New label words with prefix.
        """
        new_label_words = []
        if isinstance(label_words[0], str):
            label_words = [[w] for w in label_words]  # wrap into a list of lists of label words

        for label_words_per_label in label_words:
            new_label_words_per_label = []
            for word in label_words_per_label:
                if word.startswith("<!>"):
                    new_label_words_per_label.append(word.split("<!>")[1])
                else:
                    new_label_words_per_label.append(prefix + word)
            new_label_words.append(new_label_words_per_label)
        return new_label_words

    def generate_parameters(self) -> None:
        r"""In basic manual template, the parameters are generated from label words directly.
        Label words are ideally single tokens; words that tokenize into several tokens are
        padded to a common length and later reduced by ``handle_multi_token``.
        """
        all_ids = []
        for words_per_label in self.label_words:
            ids_per_label = []
            for word in words_per_label:
                ids = self.tokenizer.encode(word, add_special_tokens=False)
                ids_per_label.append(ids)
            all_ids.append(ids_per_label)

        max_len = max([max([len(ids) for ids in ids_per_label]) for ids_per_label in all_ids])
        max_num_label_words = max([len(ids_per_label) for ids_per_label in all_ids])
        words_ids_mask = [[[1]*len(ids) + [0]*(max_len-len(ids)) for ids in ids_per_label]
                             + [[0]*max_len]*(max_num_label_words-len(ids_per_label))
                             for ids_per_label in all_ids]
        words_ids = [[ids + [0]*(max_len-len(ids)) for ids in ids_per_label]
                             + [[0]*max_len]*(max_num_label_words-len(ids_per_label))
                             for ids_per_label in all_ids]

        words_ids_tensor = torch.tensor(words_ids)
        words_ids_mask = torch.tensor(words_ids_mask)
        self.label_words_ids = nn.Parameter(words_ids_tensor, requires_grad=False)
        self.words_ids_mask = nn.Parameter(words_ids_mask, requires_grad=False) # A 3-d mask
        self.label_words_mask = nn.Parameter(torch.clamp(words_ids_mask.sum(dim=-1), max=1), requires_grad=False)

    def project(self,
                logits: torch.Tensor,
                **kwargs,
                ) -> torch.Tensor:
        r"""
        Project the vocabulary logits onto the label words. The returned values are
        unnormalized logits; padded label-word slots are pushed down by a large negative offset.

        Args:
            logits (:obj:`torch.Tensor`): The original logits over the vocabulary.

        Returns:
            :obj:`torch.Tensor`: The logits of the label words.
        """

        label_words_logits = logits[:, self.label_words_ids]
        label_words_logits = self.handle_multi_token(label_words_logits, self.words_ids_mask)
        label_words_logits -= 10000*(1-self.label_words_mask)
        return label_words_logits

    def process_logits(self, logits: torch.Tensor, **kwargs):
        r"""A whole framework to process the original logits over the vocabulary, which contains four steps:

        (1) Project the logits into logits of label words

        if self.post_log_softmax is True:

            (2) Normalize over all label words

            (3) Calibrate (optional)

        (4) Aggregate (for multiple label words)

        Args:
            logits (:obj:`torch.Tensor`): The original logits.

        Returns:
            (:obj:`torch.Tensor`): The final processed logits over the labels (classes).
        """
        # project
        label_words_logits = self.project(logits, **kwargs)  #Output: (batch_size, num_classes) or  (batch_size, num_classes, num_label_words_per_label)


        if self.post_log_softmax:
            # normalize
            label_words_probs = self.normalize(label_words_logits)

            # calibrate
            if  hasattr(self, "_calibrate_logits") and self._calibrate_logits is not None:
                label_words_probs = self.calibrate(label_words_probs=label_words_probs)

            # convert to logits
            label_words_logits = torch.log(label_words_probs+1e-15)

        # aggregate
        label_logits = self.aggregate(label_words_logits)
        return label_logits

    def normalize(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Given the logits of the label words, return the probabilities over the label words set.

        Args:
            logits (:obj:`Tensor`): The logits of the label words.

        Returns:
            :obj:`Tensor`: The probabilities over the label words set.

        """
        batch_size = logits.shape[0]
        return F.softmax(logits.reshape(batch_size, -1), dim=-1).reshape(*logits.shape)


    def aggregate(self, label_words_logits: torch.Tensor) -> torch.Tensor:
        r"""Use weight to aggregate the logits of label words.

        Args:
            label_words_logits(:obj:`torch.Tensor`): The logits of the label words.

        Returns:
            :obj:`torch.Tensor`: The aggregated logits from the label words.
        """
        label_words_logits = (label_words_logits * self.label_words_mask).sum(-1)/self.label_words_mask.sum(-1)
        return label_words_logits

    def calibrate(self, label_words_probs: torch.Tensor, **kwargs) -> torch.Tensor:
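        r"""
        Calibrate the label-word probabilities by dividing them by the probabilities obtained
        from ``self._calibrate_logits`` and renormalizing.
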
        Args:
            label_words_probs (:obj:`torch.Tensor`): The probability distribution of the label words with the shape of [``batch_size``, ``num_classes``, ``num_label_words_per_class``]

        Returns:
            :obj:`torch.Tensor`: The calibrated probability of label words.
        """
        shape = label_words_probs.shape
        assert self._calibrate_logits.dim() == 1, "self._calibrate_logits is not a 1-d tensor"
        calibrate_label_words_probs = self.normalize(self.project(self._calibrate_logits.unsqueeze(0), **kwargs))
        assert calibrate_label_words_probs.shape[1:] == label_words_probs.shape[1:] \
             and calibrate_label_words_probs.shape[0]==1, "shape not match"
        label_words_probs /= (calibrate_label_words_probs+1e-15)
        # normalize # TODO Test the performance
        norm = label_words_probs.reshape(shape[0], -1).sum(dim=-1,keepdim=True) # TODO Test the performance of detaching()
        label_words_probs = label_words_probs.reshape(shape[0], -1) / norm
        label_words_probs = label_words_probs.reshape(*shape)
        return label_words_probs
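
The padding logic in `generate_parameters` and the project / normalize / aggregate steps in `process_logits` can be hard to follow in tensor form. The following is a minimal plain-Python sketch of the same arithmetic for a single example, not part of the repository. The helper names `pad_label_word_ids` and `process_logits` and the toy token ids are hypothetical, and it assumes the `"first"` multi-token strategy simply takes the logit of a word's first token (the base-class `handle_multi_token` is not shown in this section).

```python
import math


def pad_label_word_ids(all_ids):
    """all_ids[c][w] is the token-id list of word w for class c (hypothetical ids)."""
    max_len = max(len(ids) for per_label in all_ids for ids in per_label)
    max_words = max(len(per_label) for per_label in all_ids)
    words_ids, token_mask = [], []
    for per_label in all_ids:
        missing = max_words - len(per_label)
        # Pad every word to max_len tokens and every class to max_words words.
        words_ids.append(
            [ids + [0] * (max_len - len(ids)) for ids in per_label]
            + [[0] * max_len for _ in range(missing)]
        )
        token_mask.append(
            [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in per_label]
            + [[0] * max_len for _ in range(missing)]
        )
    # 1 for a real label word, 0 for a padded slot (mirrors the clamp(sum) step).
    word_mask = [[min(sum(tok), 1) for tok in per_label] for per_label in token_mask]
    return words_ids, token_mask, word_mask


def process_logits(vocab_logits, words_ids, word_mask):
    """Score each class from one vocabulary-sized list of logits."""
    # Project: look up the first token of every label word (assumed "first" strategy).
    projected = [[vocab_logits[ids[0]] for ids in per_label] for per_label in words_ids]
    # Push padded slots far down so they receive almost no probability mass.
    projected = [
        [x - 10000 * (1 - m) for x, m in zip(row, mask_row)]
        for row, mask_row in zip(projected, word_mask)
    ]
    # Normalize: one softmax over all label words of all classes.
    flat = [x for row in projected for x in row]
    peak = max(flat)
    exps = [math.exp(x - peak) for x in flat]
    total = sum(exps)
    width = len(projected[0])
    # Convert back to log space, with the same small epsilon as the original code.
    log_probs = [
        [math.log(exps[r * width + c] / total + 1e-15) for c in range(width)]
        for r in range(len(projected))
    ]
    # Aggregate: masked mean over the label words of each class.
    return [
        sum(lp * m for lp, m in zip(row, mask_row)) / sum(mask_row)
        for row, mask_row in zip(log_probs, word_mask)
    ]


# Class 0 has one single-token word; class 1 has a two-token word and a one-token word.
all_ids = [[[5]], [[7, 2], [3]]]
words_ids, token_mask, word_mask = pad_label_word_ids(all_ids)
vocab_logits = [0.0, 0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 1.0]
scores = process_logits(vocab_logits, words_ids, word_mask)
print(scores)
```

In this toy setup class 0 ends up with the higher score because its only label word has the largest logit, and the padded slot of class 0 does not affect the result because the word mask excludes it from the mean.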
SYMBOL INDEX (279 symbols across 22 files)

FILE: config.py
  function set_template (line 14) | def set_template(args):

FILE: dataloader/__init__.py
  function dataloader_factory (line 8) | def dataloader_factory(args):
  function test_subset_dataloader_loader (line 24) | def test_subset_dataloader_loader(args):

FILE: dataloader/base.py
  class AbstractDataloader (line 5) | class AbstractDataloader(metaclass=ABCMeta):
    method __init__ (line 6) | def __init__(self, args, dataset):
    method code (line 21) | def code(cls):
    method get_pytorch_dataloaders (line 25) | def get_pytorch_dataloaders(self):

FILE: dataloader/llm.py
  function worker_init_fn (line 17) | def worker_init_fn(worker_id):
  function generate_and_tokenize_eval (line 23) | def generate_and_tokenize_eval(args, data_point, tokenizer, prompter):
  function generate_and_tokenize_train (line 36) | def generate_and_tokenize_train(args, data_point, tokenizer, prompter):
  function seq_to_token_ids (line 60) | def seq_to_token_ids(args, seq, candidates, label, text_dict, tokenizer,...
  class LLMDataloader (line 83) | class LLMDataloader():
    method __init__ (line 84) | def __init__(self, args, dataset):
    method code (line 150) | def code(cls):
    method get_pytorch_dataloaders (line 153) | def get_pytorch_dataloaders(self):
    method _get_train_loader (line 159) | def _get_train_loader(self):
    method _get_train_dataset (line 166) | def _get_train_dataset(self):
    method _get_val_loader (line 171) | def _get_val_loader(self):
    method _get_test_loader (line 174) | def _get_test_loader(self):
    method _get_eval_loader (line 177) | def _get_eval_loader(self, mode):
    method _get_eval_dataset (line 184) | def _get_eval_dataset(self, mode):
  class LLMTrainDataset (line 196) | class LLMTrainDataset(data_utils.Dataset):
    method __init__ (line 197) | def __init__(self, args, u2seq, max_len, rng, text_dict, tokenizer, pr...
    method __len__ (line 212) | def __len__(self):
    method __getitem__ (line 215) | def __getitem__(self, index):
  class LLMValidDataset (line 234) | class LLMValidDataset(data_utils.Dataset):
    method __init__ (line 235) | def __init__(self, args, u2seq, u2answer, max_len, rng, text_dict, tok...
    method __len__ (line 248) | def __len__(self):
    method __getitem__ (line 251) | def __getitem__(self, index):
  class LLMTestDataset (line 264) | class LLMTestDataset(data_utils.Dataset):
    method __init__ (line 265) | def __init__(self, args, u2seq, u2val, u2answer, max_len, rng, text_di...
    method __len__ (line 279) | def __len__(self):
    method __getitem__ (line 282) | def __getitem__(self, index):

FILE: dataloader/lru.py
  function worker_init_fn (line 11) | def worker_init_fn(worker_id):
  class LRUDataloader (line 16) | class LRUDataloader():
    method __init__ (line 17) | def __init__(self, args, dataset):
    method code (line 36) | def code(cls):
    method get_pytorch_dataloaders (line 39) | def get_pytorch_dataloaders(self):
    method get_pytorch_test_subset_dataloader (line 45) | def get_pytorch_test_subset_dataloader(self):
    method _get_train_loader (line 62) | def _get_train_loader(self):
    method _get_train_dataset (line 69) | def _get_train_dataset(self):
    method _get_val_loader (line 74) | def _get_val_loader(self):
    method _get_test_loader (line 77) | def _get_test_loader(self):
    method _get_eval_loader (line 80) | def _get_eval_loader(self, mode):
    method _get_eval_dataset (line 87) | def _get_eval_dataset(self, mode):
  class LRUTrainDataset (line 95) | class LRUTrainDataset(data_utils.Dataset):
    method __init__ (line 96) | def __init__(self, args, u2seq, max_len, sliding_size, rng):
    method __len__ (line 113) | def __len__(self):
    method __getitem__ (line 116) | def __getitem__(self, index):
  class LRUValidDataset (line 130) | class LRUValidDataset(data_utils.Dataset):
    method __init__ (line 131) | def __init__(self, args, u2seq, u2answer, max_len, rng):
    method __len__ (line 140) | def __len__(self):
    method __getitem__ (line 143) | def __getitem__(self, index):
  class LRUTestDataset (line 155) | class LRUTestDataset(data_utils.Dataset):
    method __init__ (line 156) | def __init__(self, args, u2seq, u2val, u2answer, max_len, rng, subset_...
    method __len__ (line 169) | def __len__(self):
    method __getitem__ (line 172) | def __getitem__(self, index):

FILE: dataloader/utils.py
  class Prompter (line 6) | class Prompter(object):
    method __init__ (line 9) | def __init__(self, template_name: str = "", verbose: bool = False):
    method generate_prompt (line 24) | def generate_prompt(
    method get_response (line 44) | def get_response(self, output: str) -> str:

FILE: datasets/__init__.py
  function dataset_factory (line 12) | def dataset_factory(args):

FILE: datasets/base.py
  class AbstractDataset (line 17) | class AbstractDataset(metaclass=ABCMeta):
    method __init__ (line 18) | def __init__(self, args):
    method code (line 28) | def code(cls):
    method raw_code (line 32) | def raw_code(cls):
    method zip_file_content_is_folder (line 36) | def zip_file_content_is_folder(cls):
    method all_raw_file_names (line 40) | def all_raw_file_names(cls):
    method url (line 45) | def url(cls):
    method preprocess (line 49) | def preprocess(self):
    method load_ratings_df (line 53) | def load_ratings_df(self):
    method maybe_download_raw_dataset (line 57) | def maybe_download_raw_dataset(self):
    method load_dataset (line 60) | def load_dataset(self):
    method filter_triplets (line 66) | def filter_triplets(self, df):
    method densify_index (line 90) | def densify_index(self, df):
    method split_df (line 98) | def split_df(self, df, user_count):
    method _get_rawdata_root_path (line 110) | def _get_rawdata_root_path(self):
    method _get_rawdata_folder_path (line 113) | def _get_rawdata_folder_path(self):
    method _get_preprocessed_root_path (line 117) | def _get_preprocessed_root_path(self):
    method _get_preprocessed_folder_path (line 121) | def _get_preprocessed_folder_path(self):
    method _get_preprocessed_dataset_path (line 127) | def _get_preprocessed_dataset_path(self):

FILE: datasets/beauty.py
  class BeautyDataset (line 19) | class BeautyDataset(AbstractDataset):
    method code (line 21) | def code(cls):
    method url (line 25) | def url(cls):
    method zip_file_content_is_folder (line 30) | def zip_file_content_is_folder(cls):
    method all_raw_file_names (line 34) | def all_raw_file_names(cls):
    method maybe_download_raw_dataset (line 37) | def maybe_download_raw_dataset(self):
    method preprocess (line 53) | def preprocess(self):
    method load_ratings_df (line 77) | def load_ratings_df(self):
    method load_meta_dict (line 84) | def load_meta_dict(self):

FILE: datasets/games.py
  class GamesDataset (line 19) | class GamesDataset(AbstractDataset):
    method code (line 21) | def code(cls):
    method url (line 25) | def url(cls):
    method zip_file_content_is_folder (line 31) | def zip_file_content_is_folder(cls):
    method all_raw_file_names (line 35) | def all_raw_file_names(cls):
    method maybe_download_raw_dataset (line 38) | def maybe_download_raw_dataset(self):
    method preprocess (line 54) | def preprocess(self):
    method load_ratings_df (line 78) | def load_ratings_df(self):
    method load_meta_dict (line 85) | def load_meta_dict(self):

FILE: datasets/ml_100k.py
  class ML100KDataset (line 18) | class ML100KDataset(AbstractDataset):
    method code (line 20) | def code(cls):
    method url (line 24) | def url(cls):  # as of Sep 2023
    method zip_file_content_is_folder (line 28) | def zip_file_content_is_folder(cls):
    method all_raw_file_names (line 32) | def all_raw_file_names(cls):
    method maybe_download_raw_dataset (line 38) | def maybe_download_raw_dataset(self):
    method preprocess (line 57) | def preprocess(self):
    method load_ratings_df (line 81) | def load_ratings_df(self):
    method load_meta_dict (line 88) | def load_meta_dict(self):

FILE: datasets/utils.py
  function download (line 13) | def download(url, savepath):
  function unzip (line 18) | def unzip(zippath, savepath):
  function unziptargz (line 25) | def unziptargz(zippath, savepath):

FILE: model/llm.py
  function _make_causal_mask (line 43) | def _make_causal_mask(
  function _expand_mask (line 61) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option...
  class LlamaRMSNorm (line 75) | class LlamaRMSNorm(nn.Module):
    method __init__ (line 76) | def __init__(self, hidden_size, eps=1e-6):
    method forward (line 84) | def forward(self, hidden_states):
  class LlamaRotaryEmbedding (line 92) | class LlamaRotaryEmbedding(torch.nn.Module):
    method __init__ (line 93) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method _set_cos_sin_cache (line 107) | def _set_cos_sin_cache(self, seq_len, device, dtype):
    method forward (line 117) | def forward(self, x, seq_len=None):
  class LlamaLinearScalingRotaryEmbedding (line 128) | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    method __init__ (line 131) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method _set_cos_sin_cache (line 135) | def _set_cos_sin_cache(self, seq_len, device, dtype):
  class LlamaDynamicNTKScalingRotaryEmbedding (line 147) | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    method __init__ (line 150) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method _set_cos_sin_cache (line 154) | def _set_cos_sin_cache(self, seq_len, device, dtype):
  function rotate_half (line 173) | def rotate_half(x):
  function apply_rotary_pos_emb (line 180) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
  class LlamaMLP (line 191) | class LlamaMLP(nn.Module):
    method __init__ (line 192) | def __init__(self, config):
    method forward (line 202) | def forward(self, x):
  function repeat_kv (line 221) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  class LlamaAttention (line 233) | class LlamaAttention(nn.Module):
    method __init__ (line 236) | def __init__(self, config: LlamaConfig):
    method _init_rope (line 258) | def _init_rope(self):
    method _shape (line 275) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 278) | def forward(
  class LlamaDecoderLayer (line 371) | class LlamaDecoderLayer(nn.Module):
    method __init__ (line 372) | def __init__(self, config: LlamaConfig):
    method forward (line 380) | def forward(
  class LlamaPreTrainedModel (line 456) | class LlamaPreTrainedModel(PreTrainedModel):
    method _init_weights (line 463) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 474) | def _set_gradient_checkpointing(self, module, value=False):
  class LlamaModel (line 547) | class LlamaModel(LlamaPreTrainedModel):
    method __init__ (line 555) | def __init__(self, config: LlamaConfig):
    method get_input_embeddings (line 568) | def get_input_embeddings(self):
    method set_input_embeddings (line 571) | def set_input_embeddings(self, value):
    method _prepare_decoder_attention_mask (line 575) | def _prepare_decoder_attention_mask(self, attention_mask, input_shape,...
    method forward (line 599) | def forward(
  class LlamaForCausalLM (line 727) | class LlamaForCausalLM(LlamaPreTrainedModel):
    method __init__ (line 730) | def __init__(self, config):
    method get_input_embeddings (line 740) | def get_input_embeddings(self):
    method set_input_embeddings (line 743) | def set_input_embeddings(self, value):
    method get_output_embeddings (line 746) | def get_output_embeddings(self):
    method set_output_embeddings (line 749) | def set_output_embeddings(self, new_embeddings):
    method set_decoder (line 752) | def set_decoder(self, decoder):
    method get_decoder (line 755) | def get_decoder(self):
    method forward (line 760) | def forward(
    method prepare_inputs_for_generation (line 856) | def prepare_inputs_for_generation(
    method _reorder_cache (line 887) | def _reorder_cache(past_key_values, beam_idx):

FILE: model/lru.py
  class LRURec (line 8) | class LRURec(nn.Module):
    method __init__ (line 9) | def __init__(self, args):
    method truncated_normal_init (line 16) | def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0...
    method forward (line 38) | def forward(self, x):
  class LRUEmbedding (line 44) | class LRUEmbedding(nn.Module):
    method __init__ (line 45) | def __init__(self, args):
    method get_mask (line 54) | def get_mask(self, x):
    method forward (line 57) | def forward(self, x):
  class LRUModel (line 63) | class LRUModel(nn.Module):
    method __init__ (line 64) | def __init__(self, args):
    method forward (line 73) | def forward(self, x, embedding_weight, mask):
  class LRUBlock (line 89) | class LRUBlock(nn.Module):
    method __init__ (line 90) | def __init__(self, args):
    method forward (line 99) | def forward(self, x, mask):
  class LRULayer (line 105) | class LRULayer(nn.Module):
    method __init__ (line 106) | def __init__(self,
    method lru_parallel (line 136) | def lru_parallel(self, i, h, lamb, mask, B, L, D):
    method forward (line 149) | def forward(self, x, mask):
  class PositionwiseFeedForward (line 164) | class PositionwiseFeedForward(nn.Module):
    method __init__ (line 165) | def __init__(self, d_model, d_ff, dropout=0.1):
    method forward (line 173) | def forward(self, x):

FILE: train_ranker.py
  function main (line 30) | def main(args, export_root=None):

FILE: train_retriever.py
  function main (line 21) | def main(args, export_root=None):

FILE: trainer/base.py
  class BaseTrainer (line 20) | class BaseTrainer(metaclass=ABCMeta):
    method __init__ (line 21) | def __init__(self, args, model, train_loader, val_loader, test_loader,...
    method train (line 65) | def train(self):
    method train_one_epoch (line 77) | def train_one_epoch(self, epoch, accum_iter):
    method validate (line 104) | def validate(self, epoch, accum_iter):
    method test (line 125) | def test(self, epoch=-1, accum_iter=-1, save_name=None):
    method to_device (line 159) | def to_device(self, batch):
    method calculate_loss (line 163) | def calculate_loss(self, batch):
    method calculate_metrics (line 167) | def calculate_metrics(self, batch):
    method clip_gradients (line 170) | def clip_gradients(self, limit=1.0):
    method _update_meter_set (line 173) | def _update_meter_set(self, meter_set, metrics):
    method _update_dataloader_metrics (line 177) | def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):
    method _create_optimizer (line 187) | def _create_optimizer(self):
    method get_linear_schedule_with_warmup (line 205) | def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps,...
    method _create_loggers (line 215) | def _create_loggers(self):
    method _create_state_dict (line 241) | def _create_state_dict(self):

FILE: trainer/llm.py
  function llama_collate_fn_w_truncation (line 23) | def llama_collate_fn_w_truncation(llm_max_length, eval=False):
  function compute_metrics_for_ks (line 62) | def compute_metrics_for_ks(ks, verbalizer):
  class LLMTrainer (line 73) | class LLMTrainer(Trainer):
    method __init__ (line 74) | def __init__(
    method test (line 141) | def test(self, test_retrieval):
    method get_train_dataloader (line 164) | def get_train_dataloader(self):
    method get_eval_dataloader (line 167) | def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) ...
    method get_test_dataloader (line 170) | def get_test_dataloader(self, test_dataset: Optional[Dataset] = None) ...

FILE: trainer/loggers.py
  function save_state_dict (line 6) | def save_state_dict(state_dict, path, filename):
  class LoggerService (line 10) | class LoggerService(object):
    method __init__ (line 11) | def __init__(self, args, writer, val_loggers, test_loggers, use_wandb):
    method complete (line 18) | def complete(self):
    method log_val (line 24) | def log_val(self, log_data):
    method log_test (line 32) | def log_test(self, log_data):
  class AbstractBaseLogger (line 37) | class AbstractBaseLogger(metaclass=ABCMeta):
    method log (line 39) | def log(self, *args, **kwargs):
    method complete (line 42) | def complete(self, *args, **kwargs):
  class MetricGraphPrinter (line 46) | class MetricGraphPrinter(AbstractBaseLogger):
    method __init__ (line 47) | def __init__(self, key, graph_name, group_name, use_wandb):
    method log (line 53) | def log(self, writer, *args, **kwargs):
    method complete (line 62) | def complete(self, writer, *args, **kwargs):
  class RecentModelLogger (line 66) | class RecentModelLogger(AbstractBaseLogger):
    method __init__ (line 67) | def __init__(self, args, checkpoint_path, filename='checkpoint-recent....
    method log (line 75) | def log(self, *args, **kwargs):
    method complete (line 84) | def complete(self, *args, **kwargs):
  class BestModelLogger (line 89) | class BestModelLogger(AbstractBaseLogger):
    method __init__ (line 90) | def __init__(self, args, checkpoint_path, metric_key, filename='best_a...
    method log (line 101) | def log(self, *args, **kwargs):

FILE: trainer/lru.py
  class LRUTrainer (line 18) | class LRUTrainer(BaseTrainer):
    method __init__ (line 19) | def __init__(self, args, model, train_loader, val_loader, test_loader,...
    method calculate_loss (line 23) | def calculate_loss(self, batch):
    method calculate_metrics (line 31) | def calculate_metrics(self, batch, exclude_history=True):
    method generate_candidates (line 43) | def generate_candidates(self, retrieved_data_path):

FILE: trainer/utils.py
  function ndcg (line 17) | def ndcg(scores, labels, k):
  function absolute_recall_mrr_ndcg_for_ks (line 32) | def absolute_recall_mrr_ndcg_for_ks(scores, labels, ks):
  class AverageMeterSet (line 63) | class AverageMeterSet(object):
    method __init__ (line 64) | def __init__(self, meters=None):
    method __getitem__ (line 67) | def __getitem__(self, key):
    method update (line 74) | def update(self, name, value, n=1):
    method reset (line 79) | def reset(self):
    method values (line 83) | def values(self, format_string='{}'):
    method averages (line 86) | def averages(self, format_string='{}'):
    method sums (line 89) | def sums(self, format_string='{}'):
    method counts (line 92) | def counts(self, format_string='{}'):
  class AverageMeter (line 96) | class AverageMeter(object):
    method __init__ (line 99) | def __init__(self):
    method reset (line 105) | def reset(self):
    method update (line 111) | def update(self, val, n=1):
    method __format__ (line 117) | def __format__(self, format):

FILE: trainer/verb.py
  function convert_cfg_to_dict (line 22) | def convert_cfg_to_dict(cfg_node, key_list=[]):
  function signature (line 36) | def signature(f):
  class Verbalizer (line 72) | class Verbalizer(nn.Module):
    method __init__ (line 80) | def __init__(self,
    method label_words (line 101) | def label_words(self,):
    method label_words (line 112) | def label_words(self, label_words):
    method _match_label_words_to_label_ids (line 119) | def _match_label_words_to_label_ids(self, label_words): # TODO newly a...
    method safe_on_label_words_set (line 141) | def safe_on_label_words_set(self,):
    method on_label_words_set (line 146) | def on_label_words_set(self,):
    method vocab (line 152) | def vocab(self,) -> Dict:
    method vocab_size (line 158) | def vocab_size(self,) -> int:
    method generate_parameters (line 162) | def generate_parameters(self, **kwargs) -> List:
    method register_calibrate_logits (line 179) | def register_calibrate_logits(self, logits: torch.Tensor):
    method process_outputs (line 187) | def process_outputs(self,
    method gather_outputs (line 201) | def gather_outputs(self, outputs: ModelOutput):
    method aggregate (line 215) | def aggregate(label_words_logits: torch.Tensor) -> torch.Tensor:
    method normalize (line 232) | def normalize(self, logits: torch.Tensor) -> torch.Tensor:
    method project (line 246) | def project(self,
    method handle_multi_token (line 261) | def handle_multi_token(self, label_words_logits, mask):
    method from_config (line 285) | def from_config(cls,
    method from_file (line 313) | def from_file(self,
  class ManualVerbalizer (line 381) | class ManualVerbalizer(Verbalizer):
    method __init__ (line 393) | def __init__(self,
    method on_label_words_set (line 408) | def on_label_words_set(self):
    method add_prefix (line 418) | def add_prefix(label_words, prefix):
    method generate_parameters (line 443) | def generate_parameters(self) -> List:
    method project (line 471) | def project(self,
    method process_logits (line 490) | def process_logits(self, logits: torch.Tensor, **kwargs):
    method normalize (line 528) | def normalize(self, logits: torch.Tensor) -> torch.Tensor:
    method aggregate (line 543) | def aggregate(self, label_words_logits: torch.Tensor) -> torch.Tensor:
    method calibrate (line 555) | def calibrate(self, label_words_probs: torch.Tensor, **kwargs) -> torc...
Condensed preview: 32 files, each showing path, character count, and a content snippet (170K chars in full).
[
  {
    "path": ".gitignore",
    "chars": 121,
    "preview": "*.pyc\n*.p\n*.pt\n*.pth\n.DS_Store\n\n/.ipynb_checkpoints/*\n/.vscode/*\n/wandb/*\n\n/data/*\n/retrieved/*\n/experiments/*\n/archive/"
  },
  {
    "path": "README.md",
    "chars": 2834,
    "preview": "# LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking\n\nThis repository is the PyTorch impelementa"
  },
  {
    "path": "config.py",
    "chars": 6344,
    "preview": "import numpy as np\nimport random\nimport torch\nimport argparse\n\n\nRAW_DATASET_ROOT_FOLDER = 'data'\nEXPERIMENT_ROOT = 'expe"
  },
  {
    "path": "dataloader/__init__.py",
    "chars": 931,
    "preview": "from datasets import dataset_factory\n\nfrom .lru import *\nfrom .llm import *\nfrom .utils import *\n\n\ndef dataloader_factor"
  },
  {
    "path": "dataloader/base.py",
    "chars": 692,
    "preview": "from abc import *\nimport random\n\n\nclass AbstractDataloader(metaclass=ABCMeta):\n    def __init__(self, args, dataset):\n  "
  },
  {
    "path": "dataloader/llm.py",
    "chars": 12598,
    "preview": "from .base import AbstractDataloader\nfrom .utils import Prompter\n\nimport torch\nimport random\nimport numpy as np\nimport t"
  },
  {
    "path": "dataloader/lru.py",
    "chars": 6519,
    "preview": "from .base import AbstractDataloader\n\nimport os\nimport torch\nimport random\nimport pickle\nimport numpy as np\nimport torch"
  },
  {
    "path": "dataloader/templates/README.md",
    "chars": 2007,
    "preview": "# Prompt templates\n\nThis directory contains template styles for the prompts used to finetune LoRA models.\n\n## Format\n\nA "
  },
  {
    "path": "dataloader/templates/alpaca.json",
    "chars": 542,
    "preview": "{\n    \"description\": \"Template used by Alpaca-LoRA.\",\n    \"prompt_input\": \"Below is an instruction that describes a task"
  },
  {
    "path": "dataloader/templates/alpaca_legacy.json",
    "chars": 561,
    "preview": "{\n    \"description\": \"Legacy template, used by Original Alpaca repository.\",\n    \"prompt_input\": \"Below is an instructio"
  },
  {
    "path": "dataloader/templates/alpaca_short.json",
    "chars": 281,
    "preview": "{\n    \"description\": \"A shorter template to experiment with.\",\n    \"prompt_input\": \"### Instruction:\\n{instruction}\\n\\n#"
  },
  {
    "path": "dataloader/templates/vigogne.json",
    "chars": 587,
    "preview": "{\n    \"description\": \"French template, used by Vigogne for finetuning.\",\n    \"prompt_input\": \"Ci-dessous se trouve une i"
  },
  {
    "path": "dataloader/utils.py",
    "chars": 1412,
    "preview": "import json\nimport os.path as osp\nfrom typing import Union\n\n\nclass Prompter(object):\n    __slots__ = (\"template\", \"_verb"
  },
  {
    "path": "datasets/__init__.py",
    "chars": 334,
    "preview": "from .ml_100k import ML100KDataset\nfrom .beauty import BeautyDataset\nfrom .games import GamesDataset\n\nDATASETS = {\n    M"
  },
  {
    "path": "datasets/base.py",
    "chars": 4076,
    "preview": "import pickle\nimport shutil\nimport tempfile\nimport os\nfrom pathlib import Path\nimport gzip\nfrom abc import *\nfrom .utils"
  },
  {
    "path": "datasets/beauty.py",
    "chars": 3233,
    "preview": "from .base import AbstractDataset\nfrom .utils import *\n\nfrom datetime import date\nfrom pathlib import Path\nimport pickle"
  },
  {
    "path": "datasets/games.py",
    "chars": 3335,
    "preview": "from .base import AbstractDataset\nfrom .utils import *\n\nfrom datetime import date\nfrom pathlib import Path\nimport pickle"
  },
  {
    "path": "datasets/ml_100k.py",
    "chars": 3587,
    "preview": "from .base import AbstractDataset\nfrom .utils import *\n\nfrom datetime import date\nfrom pathlib import Path\nimport pickle"
  },
  {
    "path": "datasets/utils.py",
    "chars": 529,
    "preview": "import numpy as np\nimport pandas as pd\nfrom tqdm import tqdm\nimport urllib.request\n\n\nfrom pathlib import Path\nimport zip"
  },
  {
    "path": "model/__init__.py",
    "chars": 37,
    "preview": "from .lru import *\nfrom .llm import *"
  },
  {
    "path": "model/llm.py",
    "chars": 40446,
    "preview": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on"
  },
  {
    "path": "model/lru.py",
    "chars": 6636,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nimport numpy as np\n\n\nclass LRURec(nn.Modu"
  },
  {
    "path": "requirements.txt",
    "chars": 9063,
    "preview": "# This file may be used to create an environment using:\n# $ conda create --name <env> --file <this file>\n# platform: lin"
  },
  {
    "path": "train_ranker.py",
    "chars": 2029,
    "preview": "import os\nimport torch\nos.environ['TOKENIZERS_PARALLELISM'] = 'false'\n\nimport argparse\nfrom datasets import DATASETS\nfro"
  },
  {
    "path": "train_retriever.py",
    "chars": 1486,
    "preview": "import os\nimport torch\nos.environ['TOKENIZERS_PARALLELISM'] = 'false'\n\nimport wandb\nimport argparse\n\nfrom config import "
  },
  {
    "path": "trainer/__init__.py",
    "chars": 58,
    "preview": "from .lru import *\nfrom .llm import *\nfrom .utils import *"
  },
  {
    "path": "trainer/base.py",
    "chars": 10220,
    "preview": "from model import *\nfrom config import *\nfrom .utils import *\nfrom .loggers import *\n\nimport torch\nimport torch.nn as nn"
  },
  {
    "path": "trainer/llm.py",
    "chars": 6866,
    "preview": "from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY\nfrom .verb import ManualVerbalizer\nfrom .utils import *\nfrom"
  },
  {
    "path": "trainer/loggers.py",
    "chars": 3963,
    "preview": "import os\nimport torch\nfrom abc import ABCMeta, abstractmethod\n\n\ndef save_state_dict(state_dict, path, filename):\n    to"
  },
  {
    "path": "trainer/lru.py",
    "chars": 3636,
    "preview": "from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY\nfrom .utils import *\nfrom .loggers import *\nfrom .base impor"
  },
  {
    "path": "trainer/utils.py",
    "chars": 3482,
    "preview": "from config import *\n\nimport json\nimport os\nimport pprint as pp\nimport random\nfrom datetime import date\nfrom pathlib imp"
  },
  {
    "path": "trainer/verb.py",
    "chars": 24849,
    "preview": "from abc import abstractmethod\nimport json\n\nfrom transformers.file_utils import ModelOutput\nfrom transformers.data.proce"
  }
]

