[
  {
    "path": ".gitignore",
    "content": "*.pyc\n*.p\n*.pt\n*.pth\n.DS_Store\n\n/.ipynb_checkpoints/*\n/.vscode/*\n/wandb/*\n\n/data/*\n/retrieved/*\n/experiments/*\n/archive/*"
  },
  {
    "path": "README.md",
    "content": "# LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking\n\nThis repository is the PyTorch implementation for the PGAI@CIKM 2023 paper **LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking [[Paper](https://arxiv.org/abs/2311.02089)]**.\n\n<img src=media/method.png width=1000>\n\nWe propose a two-stage framework using large language models for ranking-based recommendation (LlamaRec). In particular, we use small-scale sequential recommenders to retrieve candidates based on the user interaction history. Then, both history and retrieved items are fed to the LLM in text via a carefully designed prompt template. Instead of generating next-item titles, we adopt a verbalizer-based approach that transforms output logits into probability distributions over the candidate items. Therefore, LlamaRec can efficiently rank items without generating long text and achieves superior recommendation performance and efficiency.\n\n\n## Requirements\n\nPyTorch, transformers, peft, bitsandbytes, etc. For our detailed running environment, see requirements.txt.\n\n\n## How to run LlamaRec\nThe command below starts the training of the retriever model LRURec:\n```bash\npython train_retriever.py\n```\nYou can set additional arguments like weight_decay to change the hyperparameters. After running the command, you will be prompted to select a dataset from ML-100k, Beauty and Games. Once training is finished, evaluation is automatically performed with the best retriever model.\n\nThen, run the following command to train the ranker model based on Llama 2:\n```bash\npython train_ranker.py --llm_retrieved_path PATH_TO_RETRIEVER\n```\nPlease replace PATH_TO_RETRIEVER with the retriever path from the previous step. To run this command, you will need access to meta-llama/Llama-2-7b-hf on the Hugging Face Hub. Similarly, evaluation is performed after training is finished. 
All weights and results are saved under ./experiments.\n\n\n## Performance\n\nThe table below reports our main performance results, with best results marked in bold and second best results underlined. For training and evaluation details, please refer to our paper.\n\n<img src=media/performance.png width=1000>\n\n\n## Citation\nPlease consider citing the following papers if you use our methods in your research:\n```\n@article{yue2023linear,\n  title={Linear Recurrent Units for Sequential Recommendation},\n  author={Yue, Zhenrui and Wang, Yueqi and He, Zhankui and Zeng, Huimin and McAuley, Julian and Wang, Dong},\n  journal={arXiv preprint arXiv:2310.02367},\n  year={2023}\n}\n\n@article{yue2023llamarec,\n  title={LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking},\n  author={Yue, Zhenrui and Rabhi, Sara and Moreira, Gabriel de Souza Pereira and Wang, Dong and Oldridge, Even},\n  journal={arXiv preprint arXiv:2311.02089},\n  year={2023}\n}\n```\n"
  },
  {
    "path": "config.py",
    "content": "import numpy as np\nimport random\nimport torch\nimport argparse\n\n\nRAW_DATASET_ROOT_FOLDER = 'data'\nEXPERIMENT_ROOT = 'experiments'\nSTATE_DICT_KEY = 'model_state_dict'\nOPTIMIZER_STATE_DICT_KEY = 'optimizer_state_dict'\nPROJECT_NAME = 'llmrec'\n\n\ndef set_template(args):\n    if args.dataset_code == None:\n        print('******************** Dataset Selection ********************')\n        dataset_code = {'1': 'ml-100k', 'b': 'beauty', 'g': 'games'}\n        args.dataset_code = dataset_code[input('Input 1 for ml-100k, b for beauty and g for games: ')]\n\n    if args.dataset_code == 'ml-100k':\n        args.bert_max_len = 200\n    else:\n        args.bert_max_len = 50\n\n    if 'llm' in args.model_code: \n        batch = 16 if args.dataset_code == 'ml-100k' else 12\n        args.lora_micro_batch_size = batch\n    else: \n        batch = 16 if args.dataset_code == 'ml-100k' else 64\n\n    args.train_batch_size = batch\n    args.val_batch_size = batch\n    args.test_batch_size = batch\n\n    if torch.cuda.is_available(): args.device = 'cuda'\n    else: args.device = 'cpu'\n    args.optimizer = 'AdamW'\n    args.lr = 0.001\n    args.weight_decay = 0.01\n    args.enable_lr_schedule = False\n    args.decay_step = 10000\n    args.gamma = 1.\n    args.enable_lr_warmup = False\n    args.warmup_steps = 100\n\n    args.metric_ks = [1, 5, 10, 20, 50]\n    args.rerank_metric_ks = [1, 5, 10]\n    args.best_metric = 'Recall@10'\n    args.rerank_best_metric = 'NDCG@10'\n\n    args.bert_num_blocks = 2\n    args.bert_num_heads = 2\n    args.bert_head_size = None\n\n\nparser = argparse.ArgumentParser()\n\n################\n# Dataset\n################\nparser.add_argument('--dataset_code', type=str, default=None)\nparser.add_argument('--min_rating', type=int, default=0)\nparser.add_argument('--min_uc', type=int, default=5)\nparser.add_argument('--min_sc', type=int, default=5)\nparser.add_argument('--seed', type=int, default=42)\n\n################\n# 
Dataloader\n################\nparser.add_argument('--train_batch_size', type=int, default=64)\nparser.add_argument('--val_batch_size', type=int, default=64)\nparser.add_argument('--test_batch_size', type=int, default=64)\nparser.add_argument('--num_workers', type=int, default=8)\nparser.add_argument('--sliding_window_size', type=float, default=1.0)\nparser.add_argument('--negative_sample_size', type=int, default=10)\n\n################\n# Trainer\n################\n# optimization #\nparser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])\nparser.add_argument('--num_epochs', type=int, default=500)\nparser.add_argument('--optimizer', type=str, default='AdamW', choices=['AdamW', 'Adam'])\nparser.add_argument('--weight_decay', type=float, default=None)\nparser.add_argument('--adam_epsilon', type=float, default=1e-9)\nparser.add_argument('--momentum', type=float, default=None)\nparser.add_argument('--lr', type=float, default=0.001)\nparser.add_argument('--max_grad_norm', type=float, default=5.0)\nparser.add_argument('--enable_lr_schedule', type=bool, default=True)\nparser.add_argument('--decay_step', type=int, default=10000)\nparser.add_argument('--gamma', type=float, default=1)\nparser.add_argument('--enable_lr_warmup', type=bool, default=True)\nparser.add_argument('--warmup_steps', type=int, default=100)\n\n# evaluation #\nparser.add_argument('--val_strategy', type=str, default='iteration', choices=['epoch', 'iteration'])\nparser.add_argument('--val_iterations', type=int, default=500)  # only for iteration val_strategy\nparser.add_argument('--early_stopping', type=bool, default=True)\nparser.add_argument('--early_stopping_patience', type=int, default=20)\nparser.add_argument('--metric_ks', nargs='+', type=int, default=[1, 5, 10, 20, 50])\nparser.add_argument('--rerank_metric_ks', nargs='+', type=int, default=[1, 5, 10])\nparser.add_argument('--best_metric', type=str, default='Recall@10')\nparser.add_argument('--rerank_best_metric', type=str, 
default='NDCG@10')\nparser.add_argument('--use_wandb', type=bool, default=False)\n\n################\n# Retriever Model\n################\nparser.add_argument('--model_code', type=str, default=None)\nparser.add_argument('--bert_max_len', type=int, default=50)\nparser.add_argument('--bert_hidden_units', type=int, default=64)\nparser.add_argument('--bert_num_blocks', type=int, default=2)\nparser.add_argument('--bert_num_heads', type=int, default=2)\nparser.add_argument('--bert_head_size', type=int, default=32)\nparser.add_argument('--bert_dropout', type=float, default=0.2)\nparser.add_argument('--bert_attn_dropout', type=float, default=0.2)\nparser.add_argument('--bert_mask_prob', type=float, default=0.25)\n\n################\n# LLM Model\n################\nparser.add_argument('--llm_base_model', type=str, default='meta-llama/Llama-2-7b-hf')\nparser.add_argument('--llm_base_tokenizer', type=str, default='meta-llama/Llama-2-7b-hf')\nparser.add_argument('--llm_max_title_len', type=int, default=32)\nparser.add_argument('--llm_max_text_len', type=int, default=1536)\nparser.add_argument('--llm_max_history', type=int, default=20)\nparser.add_argument('--llm_train_on_inputs', type=bool, default=False)\nparser.add_argument('--llm_negative_sample_size', type=int, default=19)  # 19 negative & 1 positive\nparser.add_argument('--llm_system_template', type=str,  # instruction\n    default=\"Given user history in chronological order, recommend an item from the candidate pool with its index letter.\")\nparser.add_argument('--llm_input_template', type=str, \\\n    default='User history: {}; \\n Candidate pool: {}')\nparser.add_argument('--llm_load_in_4bit', type=bool, default=True)\nparser.add_argument('--llm_retrieved_path', type=str, default=None)\nparser.add_argument('--llm_cache_dir', type=str, default=None)\n\n################\n# Lora\n################\nparser.add_argument('--lora_r', type=int, default=8)\nparser.add_argument('--lora_alpha', type=int, 
default=32)\nparser.add_argument('--lora_dropout', type=float, default=0.05)\nparser.add_argument('--lora_target_modules', type=list, default=['q_proj', 'v_proj'])\nparser.add_argument('--lora_num_epochs', type=int, default=1)\nparser.add_argument('--lora_val_iterations', type=int, default=100)\nparser.add_argument('--lora_early_stopping_patience', type=int, default=20)\nparser.add_argument('--lora_lr', type=float, default=1e-4)\nparser.add_argument('--lora_micro_batch_size', type=int, default=16)\n\n################\n\n\nargs = parser.parse_args()\n"
  },
  {
    "path": "dataloader/__init__.py",
    "content": "from datasets import dataset_factory\n\nfrom .lru import *\nfrom .llm import *\nfrom .utils import *\n\n\ndef dataloader_factory(args):\n    # Build the dataloaders for the model selected by args.model_code.\n    dataset = dataset_factory(args)\n    if args.model_code == 'lru':\n        dataloader = LRUDataloader(args, dataset)\n    elif args.model_code == 'llm':\n        dataloader = LLMDataloader(args, dataset)\n    else:\n        # Fail early instead of hitting an unbound 'dataloader' below.\n        raise ValueError(f'Unknown model_code: {args.model_code!r}')\n\n    train, val, test = dataloader.get_pytorch_dataloaders()\n    if 'llm' in args.model_code:\n        # The ranker also needs the tokenizer and the retrieval statistics.\n        tokenizer = dataloader.tokenizer\n        test_retrieval = dataloader.test_retrieval\n        return train, val, test, tokenizer, test_retrieval\n    else:\n        return train, val, test\n\n\ndef test_subset_dataloader_loader(args):\n    # Build the test loader restricted to users whose label was retrieved.\n    dataset = dataset_factory(args)\n    if args.model_code == 'lru':\n        dataloader = LRUDataloader(args, dataset)\n    elif args.model_code == 'llm':\n        dataloader = LLMDataloader(args, dataset)\n    else:\n        raise ValueError(f'Unknown model_code: {args.model_code!r}')\n\n    return dataloader.get_pytorch_test_subset_dataloader()\n"
  },
  {
    "path": "dataloader/base.py",
    "content": "from abc import ABCMeta, abstractmethod\n\n\nclass AbstractDataloader(metaclass=ABCMeta):\n    # Shared base: loads the preprocessed splits, metadata and ID maps.\n    def __init__(self, args, dataset):\n        self.args = args\n        self.save_folder = dataset._get_preprocessed_folder_path()\n        dataset = dataset.load_dataset()\n        self.train = dataset['train']\n        self.val = dataset['val']\n        self.test = dataset['test']\n        self.meta = dataset['meta']\n        self.umap = dataset['umap']\n        self.smap = dataset['smap']\n        self.user_count = len(self.umap)\n        self.item_count = len(self.smap)\n\n    @classmethod\n    @abstractmethod\n    def code(cls):\n        pass\n\n    @abstractmethod\n    def get_pytorch_dataloaders(self):\n        pass\n"
  },
  {
    "path": "dataloader/llm.py",
    "content": "from .base import AbstractDataloader\nfrom .utils import Prompter\n\nimport torch\nimport random\nimport numpy as np\nimport torch.utils.data as data_utils\n\nimport os\nimport pickle\nimport transformers\nfrom transformers import AutoTokenizer\nfrom transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT\nfrom trainer import absolute_recall_mrr_ndcg_for_ks\n\n\ndef worker_init_fn(worker_id):\n    random.seed(np.random.get_state()[1][0] + worker_id)                                                      \n    np.random.seed(np.random.get_state()[1][0] + worker_id)\n\n\n# the following prompting is based on alpaca\ndef generate_and_tokenize_eval(args, data_point, tokenizer, prompter):\n    in_prompt = prompter.generate_prompt(data_point[\"system\"],\n                                         data_point[\"input\"])\n    tokenized_full_prompt = tokenizer(in_prompt,\n                                      truncation=True,\n                                      max_length=args.llm_max_text_len,\n                                      padding=False,\n                                      return_tensors=None)\n    tokenized_full_prompt[\"labels\"] = ord(data_point[\"output\"]) - ord('A')\n    \n    return tokenized_full_prompt\n\n\ndef generate_and_tokenize_train(args, data_point, tokenizer, prompter):\n    def tokenize(prompt, add_eos_token=True):\n        result = tokenizer(prompt,\n                           truncation=True,\n                           max_length=args.llm_max_text_len,\n                           padding=False,\n                           return_tensors=None)\n        if (result[\"input_ids\"][-1] != tokenizer.eos_token_id and add_eos_token):\n            result[\"input_ids\"].append(tokenizer.eos_token_id)\n            result[\"attention_mask\"].append(1)\n\n        result[\"labels\"] = result[\"input_ids\"].copy()\n        return result\n\n    full_prompt = prompter.generate_prompt(data_point[\"system\"],\n                
                           data_point[\"input\"],\n                                           data_point[\"output\"])\n    tokenized_full_prompt = tokenize(full_prompt, add_eos_token=True)\n    if not args.llm_train_on_inputs:\n        tokenized_full_prompt[\"labels\"][:-2] = [-100] * len(tokenized_full_prompt[\"labels\"][:-2])\n    \n    return tokenized_full_prompt\n\n\ndef seq_to_token_ids(args, seq, candidates, label, text_dict, tokenizer, prompter, eval=False):\n    def truncate_title(title):\n        title_ = tokenizer.tokenize(title)[:args.llm_max_title_len]\n        title = tokenizer.convert_tokens_to_string(title_)\n        return title\n\n    seq_t = ' \\n '.join(['(' + str(idx + 1) + ') ' + truncate_title(text_dict[item]) \n                       for idx, item in enumerate(seq)])\n    can_t = ' \\n '.join(['(' + chr(ord('A') + idx) + ') ' + truncate_title(text_dict[item])\n                       for idx, item in enumerate(candidates)])\n    output = chr(ord('A') + candidates.index(label))  # ranking only\n    \n    data_point = {}\n    data_point['system'] = args.llm_system_template if args.llm_system_template is not None else DEFAULT_SYSTEM_PROMPT\n    data_point['input'] = args.llm_input_template.format(seq_t, can_t)\n    data_point['output'] = output\n    \n    if eval:\n        return generate_and_tokenize_eval(args, data_point, tokenizer, prompter)\n    else:\n        return generate_and_tokenize_train(args, data_point, tokenizer, prompter)\n\n\nclass LLMDataloader():\n    def __init__(self, args, dataset):\n        self.args = args\n        self.rng = np.random\n        self.save_folder = dataset._get_preprocessed_folder_path()\n        seq_dataset = dataset.load_dataset()\n        self.train = seq_dataset['train']\n        self.val = seq_dataset['val']\n        self.test = seq_dataset['test']\n        self.umap = seq_dataset['umap']\n        self.smap = seq_dataset['smap']\n        self.text_dict = seq_dataset['meta']\n        self.user_count = 
len(self.umap)\n        self.item_count = len(self.smap)\n        \n        args.num_items = self.item_count\n        self.max_len = args.llm_max_history\n        \n        self.tokenizer = AutoTokenizer.from_pretrained(\n            args.llm_base_tokenizer, cache_dir=args.llm_cache_dir)\n        self.tokenizer.pad_token = self.tokenizer.unk_token\n        self.tokenizer.padding_side = 'left'\n        self.tokenizer.truncation_side = 'left'\n        self.tokenizer.clean_up_tokenization_spaces = True\n        self.prompter = Prompter()\n        \n        self.llm_retrieved_path = args.llm_retrieved_path\n        print('Loading retrieved file from {}'.format(self.llm_retrieved_path))\n        retrieved_file = pickle.load(open(os.path.join(args.llm_retrieved_path,\n                                                       'retrieved.pkl'), 'rb'))\n        \n        print('******************** Constructing Validation Subset ********************')\n        self.val_probs = retrieved_file['val_probs']\n        self.val_labels = retrieved_file['val_labels']\n        self.val_metrics = retrieved_file['val_metrics']\n        self.val_users = [u for u, (p, l) in enumerate(zip(self.val_probs, self.val_labels), start=1) \\\n                          if l in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices]\n        self.val_candidates = [torch.topk(torch.tensor(self.val_probs[u-1]), \n                                self.args.llm_negative_sample_size+1).indices.tolist() for u in self.val_users]\n\n        print('******************** Constructing Test Subset ********************')\n        self.test_probs = retrieved_file['test_probs']\n        self.test_labels = retrieved_file['test_labels']\n        self.test_metrics = retrieved_file['test_metrics']\n        self.test_users = [u for u, (p, l) in enumerate(zip(self.test_probs, self.test_labels), start=1) \\\n                          if l in torch.topk(torch.tensor(p), 
self.args.llm_negative_sample_size+1).indices]\n        self.test_candidates = [torch.topk(torch.tensor(self.test_probs[u-1]), \n                                self.args.llm_negative_sample_size+1).indices.tolist() for u in self.test_users]\n        self.non_test_users = [u for u, (p, l) in enumerate(zip(self.test_probs, self.test_labels), start=1) \\\n                               if l not in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices]\n        self.test_retrieval = {\n            'original_size': len(self.test_probs),\n            'retrieval_size': len(self.test_candidates),\n            'original_metrics': self.test_metrics,\n            'retrieval_metrics': absolute_recall_mrr_ndcg_for_ks(\n                torch.tensor(self.test_probs)[torch.tensor(self.test_users)-1],\n                torch.tensor(self.test_labels)[torch.tensor(self.test_users)-1],\n                self.args.metric_ks,\n            ),\n            'non_retrieval_metrics': absolute_recall_mrr_ndcg_for_ks(\n                torch.tensor(self.test_probs)[torch.tensor(self.non_test_users)-1],\n                torch.tensor(self.test_labels)[torch.tensor(self.non_test_users)-1],\n                self.args.metric_ks,\n            ),\n        }\n\n    @classmethod\n    def code(cls):\n        return 'llm'\n\n    def get_pytorch_dataloaders(self):\n        train_loader = self._get_train_loader()\n        val_loader = self._get_val_loader()\n        test_loader = self._get_test_loader()\n        return train_loader, val_loader, test_loader\n\n    def _get_train_loader(self):\n        dataset = self._get_train_dataset()\n        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.lora_micro_batch_size,\n                                           shuffle=True, pin_memory=True, num_workers=self.args.num_workers,\n                                           worker_init_fn=worker_init_fn)\n        return dataloader\n\n    def _get_train_dataset(self):\n        
dataset = LLMTrainDataset(self.args, self.train, self.max_len, self.rng,\n                                  self.text_dict, self.tokenizer, self.prompter)\n        return dataset\n\n    def _get_val_loader(self):\n        return self._get_eval_loader(mode='val')\n\n    def _get_test_loader(self):\n        return self._get_eval_loader(mode='test')\n\n    def _get_eval_loader(self, mode):\n        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size\n        dataset = self._get_eval_dataset(mode)\n        dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=True,\n                                           pin_memory=True, num_workers=self.args.num_workers)\n        return dataloader\n\n    def _get_eval_dataset(self, mode):\n        if mode == 'val':\n            dataset = LLMValidDataset(self.args, self.train, self.val, self.max_len, self.rng, \\\n                                      self.text_dict, self.tokenizer, self.prompter, self.val_users, \\\n                                      self.val_candidates)\n        elif mode == 'test':\n            dataset = LLMTestDataset(self.args, self.train, self.val, self.test, self.max_len, \\\n                                     self.rng, self.text_dict, self.tokenizer, self.prompter, self.test_users, \\\n                                     self.test_candidates)\n        return dataset\n\n\nclass LLMTrainDataset(data_utils.Dataset):\n    def __init__(self, args, u2seq, max_len, rng, text_dict, tokenizer, prompter):\n        self.args = args\n        self.max_len = max_len\n        self.num_items = args.num_items\n        self.rng = rng\n        self.text_dict = text_dict\n        self.tokenizer = tokenizer\n        self.prompter = prompter\n\n        self.all_seqs = []\n        for u in sorted(u2seq.keys()):\n            seq = u2seq[u]\n            for i in range(2, len(seq)+1):\n                self.all_seqs += [seq[:i]]\n\n    def __len__(self):\n        return 
len(self.all_seqs)\n\n    def __getitem__(self, index):\n        tokens = self.all_seqs[index]\n        answer = tokens[-1]\n        original_seq = tokens[:-1]\n        \n        seq = original_seq[-self.max_len:]\n        cur_idx, candidates = 0, [answer]\n        samples = self.rng.randint(1, self.args.num_items+1, size=5*self.args.llm_negative_sample_size)\n        while len(candidates) < self.args.llm_negative_sample_size + 1:\n            item = samples[cur_idx]\n            cur_idx += 1\n            if item in original_seq or item == answer: continue\n            else: candidates.append(item)\n        self.rng.shuffle(candidates)\n\n        return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, \\\n                                self.tokenizer, self.prompter, eval=False)\n\n\nclass LLMValidDataset(data_utils.Dataset):\n    def __init__(self, args, u2seq, u2answer, max_len, rng, text_dict, tokenizer, prompter, val_users, val_candidates):\n        self.args = args\n        self.u2seq = u2seq\n        self.u2answer = u2answer\n        self.users = sorted(self.u2seq.keys())\n        self.max_len = max_len\n        self.rng = rng\n        self.text_dict = text_dict\n        self.tokenizer = tokenizer\n        self.prompter = prompter\n        self.val_users = val_users\n        self.val_candidates = val_candidates\n\n    def __len__(self):\n        return len(self.val_users)\n\n    def __getitem__(self, index):\n        user = self.val_users[index]\n        seq = self.u2seq[user]\n        answer = self.u2answer[user][0]\n        \n        seq = seq[-self.max_len:]\n        candidates = self.val_candidates[index]\n        assert answer in candidates\n        # self.rng.shuffle(candidates)\n        \n        return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, self.tokenizer, self.prompter, eval=True)\n\n\nclass LLMTestDataset(data_utils.Dataset):\n    def __init__(self, args, u2seq, u2val, u2answer, max_len, rng, 
text_dict, tokenizer, prompter, test_users, test_candidates):\n        self.args = args\n        self.u2seq = u2seq\n        self.u2val = u2val\n        self.u2answer = u2answer\n        self.users = sorted(u2seq.keys())\n        self.max_len = max_len\n        self.rng = rng\n        self.text_dict = text_dict\n        self.tokenizer = tokenizer\n        self.prompter = prompter\n        self.test_users = test_users\n        self.test_candidates = test_candidates\n    \n    def __len__(self):\n        return len(self.test_users)\n    \n    def __getitem__(self, index):\n        user = self.test_users[index]\n        seq = self.u2seq[user] + self.u2val[user]\n        answer = self.u2answer[user][0]\n\n        seq = seq[-self.max_len:]\n        candidates = self.test_candidates[index]\n        assert answer in candidates\n        # self.rng.shuffle(candidates)\n\n        return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, self.tokenizer, self.prompter, eval=True)"
  },
  {
    "path": "dataloader/lru.py",
    "content": "from .base import AbstractDataloader\n\nimport os\nimport torch\nimport random\nimport pickle\nimport numpy as np\nimport torch.utils.data as data_utils\n\n\ndef worker_init_fn(worker_id):\n    random.seed(np.random.get_state()[1][0] + worker_id)                                                      \n    np.random.seed(np.random.get_state()[1][0] + worker_id)\n\n\nclass LRUDataloader():\n    def __init__(self, args, dataset):\n        self.args = args\n        self.rng = np.random\n        self.save_folder = dataset._get_preprocessed_folder_path()\n        dataset = dataset.load_dataset()\n        self.train = dataset['train']\n        self.val = dataset['val']\n        self.test = dataset['test']\n        self.umap = dataset['umap']\n        self.smap = dataset['smap']\n        self.user_count = len(self.umap)\n        self.item_count = len(self.smap)\n\n        args.num_users = self.user_count\n        args.num_items = self.item_count\n        self.max_len = args.bert_max_len\n        self.sliding_size = args.sliding_window_size\n\n    @classmethod\n    def code(cls):\n        return 'lru'\n\n    def get_pytorch_dataloaders(self):\n        train_loader = self._get_train_loader()\n        val_loader = self._get_val_loader()\n        test_loader = self._get_test_loader()\n        return train_loader, val_loader, test_loader\n    \n    def get_pytorch_test_subset_dataloader(self):\n        retrieved_file_path = self.args.llm_retrieved_path\n        print('Loading retrieved file from {}'.format(retrieved_file_path))\n        retrieved_file = pickle.load(open(os.path.join(retrieved_file_path,\n                                                       'retrieved.pkl'), 'rb'))\n        \n        test_probs = retrieved_file['test_probs']\n        test_labels = retrieved_file['test_labels']\n        test_users = [u for u, (p, l) in enumerate(zip(test_probs, test_labels), start=1) \\\n                      if l in torch.topk(torch.tensor(p), 
self.args.llm_negative_sample_size+1).indices]\n\n        dataset = dataset = LRUTestDataset(self.args, self.train, self.val, self.test, self.max_len, \n                                           self.rng, subset_users=test_users)\n        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.val_batch_size, shuffle=False,\n                                           pin_memory=True, num_workers=self.args.num_workers)\n        return dataloader\n\n    def _get_train_loader(self):\n        dataset = self._get_train_dataset()\n        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,\n                        shuffle=True, pin_memory=True, num_workers=self.args.num_workers,\n                        worker_init_fn=worker_init_fn)\n        return dataloader\n\n    def _get_train_dataset(self):\n        dataset = LRUTrainDataset(\n            self.args, self.train, self.max_len, self.sliding_size, self.rng)\n        return dataset\n\n    def _get_val_loader(self):\n        return self._get_eval_loader(mode='val')\n\n    def _get_test_loader(self):\n        return self._get_eval_loader(mode='test')\n\n    def _get_eval_loader(self, mode):\n        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size\n        dataset = self._get_eval_dataset(mode)\n        dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False,\n                        pin_memory=True, num_workers=self.args.num_workers)\n        return dataloader\n\n    def _get_eval_dataset(self, mode):\n        if mode == 'val':\n            dataset = LRUValidDataset(self.args, self.train, self.val, self.max_len, self.rng)\n        elif mode == 'test':\n            dataset = LRUTestDataset(self.args, self.train, self.val, self.test, self.max_len, self.rng)\n        return dataset\n\n\nclass LRUTrainDataset(data_utils.Dataset):\n    def __init__(self, args, u2seq, max_len, sliding_size, rng):\n        self.args = args\n   
     self.max_len = max_len\n        self.sliding_step = int(sliding_size * max_len)\n        self.num_items = args.num_items\n        self.rng = rng\n        \n        assert self.sliding_step > 0\n        self.all_seqs = []\n        for u in sorted(u2seq.keys()):\n            seq = u2seq[u]\n            if len(seq) < self.max_len + self.sliding_step:\n                self.all_seqs.append(seq)\n            else:\n                start_idx = range(len(seq) - max_len, -1, -self.sliding_step)\n                self.all_seqs = self.all_seqs + [seq[i:i + max_len] for i in start_idx]\n\n    def __len__(self):\n        return len(self.all_seqs)\n\n    def __getitem__(self, index):\n        seq = self.all_seqs[index]\n        labels = seq[-self.max_len:]\n        tokens = seq[:-1][-self.max_len:]\n\n        mask_len = self.max_len - len(tokens)\n        tokens = [0] * mask_len + tokens\n\n        mask_len = self.max_len - len(labels)\n        labels = [0] * mask_len + labels\n\n        return torch.LongTensor(tokens), torch.LongTensor(labels)\n\n\nclass LRUValidDataset(data_utils.Dataset):\n    def __init__(self, args, u2seq, u2answer, max_len, rng):\n        self.args = args\n        self.u2seq = u2seq\n        self.u2answer = u2answer\n        users = sorted(self.u2seq.keys())\n        self.users = [u for u in users if len(u2answer[u]) > 0]\n        self.max_len = max_len\n        self.rng = rng\n    \n    def __len__(self):\n        return len(self.users)\n\n    def __getitem__(self, index):\n        user = self.users[index]\n        seq = self.u2seq[user]\n        answer = self.u2answer[user]\n\n        seq = seq[-self.max_len:]\n        padding_len = self.max_len - len(seq)\n        seq = [0] * padding_len + seq\n\n        return torch.LongTensor(seq), torch.LongTensor(answer)\n\n\nclass LRUTestDataset(data_utils.Dataset):\n    def __init__(self, args, u2seq, u2val, u2answer, max_len, rng, subset_users=None):\n        self.args = args\n        self.u2seq = u2seq\n     
   self.u2val = u2val\n        self.u2answer = u2answer\n        users = sorted(self.u2seq.keys())\n        self.users = [u for u in users if len(u2val[u]) > 0 and len(u2answer[u]) > 0]\n        self.max_len = max_len\n        self.rng = rng\n        \n        if subset_users is not None:\n            self.users = subset_users\n\n    def __len__(self):\n        return len(self.users)\n\n    def __getitem__(self, index):\n        user = self.users[index]\n        seq = self.u2seq[user] + self.u2val[user]\n        answer = self.u2answer[user]\n\n        seq = seq[-self.max_len:]\n        padding_len = self.max_len - len(seq)\n        seq = [0] * padding_len + seq\n\n        return torch.LongTensor(seq), torch.LongTensor(answer)"
  },
  {
    "path": "dataloader/templates/README.md",
    "content": "# Prompt templates\n\nThis directory contains template styles for the prompts used to finetune LoRA models.\n\n## Format\n\nA template is described via a JSON file with the following keys:\n\n- `prompt_input`: The template to use when input is not None. Uses the `{instruction}` and `{input}` placeholders.\n- `prompt_no_input`: The template to use when input is None. Uses the `{instruction}` placeholder.\n- `description`: A short description of the template, with possible use cases.\n- `response_split`: The text to use as the separator when cutting the actual response from the model output.\n\nNo `{response}` placeholder is used, since the response is always the last element of the template and is simply concatenated to the rest.\n\n## Example template\n\nThe default template, used unless otherwise specified, is `alpaca.json`.\n\n```json\n{\n    \"description\": \"Template used by Alpaca-LoRA.\",\n    \"prompt_input\": \"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\\n\\n### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\\n\",\n    \"prompt_no_input\": \"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\n{instruction}\\n\\n### Response:\\n\",\n    \"response_split\": \"### Response:\"    \n}\n\n```\n\n## Current templates\n\n### alpaca\n\nDefault template used for generic LoRA fine tunes so far.\n\n### alpaca_legacy\n\nLegacy template used by the original alpaca repo, with no `\\n` after the response field. Kept for reference and experiments.\n\n### alpaca_short\n\nA trimmed-down alpaca template which seems to perform just as well and saves some tokens. Models created with the default template seem to be queryable by the short template as well. More experiments are welcome.\n\n### vigogne\n\nThe default alpaca template, translated to French. 
This template was used to train the \"Vigogne\" LoRA and is to be used to query it, or for extra fine tuning.\n"
  },
  {
    "path": "dataloader/templates/alpaca.json",
    "content": "{\n    \"description\": \"Template used by Alpaca-LoRA.\",\n    \"prompt_input\": \"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\\n\\n### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\\n\",\n    \"prompt_no_input\": \"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\n{instruction}\\n\\n### Response:\\n\",\n    \"response_split\": \"### Response:\"    \n}\n"
  },
  {
    "path": "dataloader/templates/alpaca_legacy.json",
    "content": "{\n    \"description\": \"Legacy template, used by Original Alpaca repository.\",\n    \"prompt_input\": \"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\\n\\n### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\",\n    \"prompt_no_input\": \"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\n{instruction}\\n\\n### Response:\",\n    \"response_split\": \"### Response:\"    \n}\n"
  },
  {
    "path": "dataloader/templates/alpaca_short.json",
    "content": "{\n    \"description\": \"A shorter template to experiment with.\",\n    \"prompt_input\": \"### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\\n\",\n    \"prompt_no_input\": \"### Instruction:\\n{instruction}\\n\\n### Response:\\n\",\n    \"response_split\": \"### Response:\"    \n}\n"
  },
  {
    "path": "dataloader/templates/vigogne.json",
    "content": "{\n    \"description\": \"French template, used by Vigogne for finetuning.\",\n    \"prompt_input\": \"Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\\n\\n### Instruction:\\n{instruction}\\n\\n### Entrée:\\n{input}\\n\\n### Réponse:\\n\",\n    \"prompt_no_input\": \"Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\\n\\n### Instruction:\\n{instruction}\\n\\n### Réponse:\\n\",\n    \"response_split\": \"### Réponse:\"\n}\n"
  },
  {
    "path": "dataloader/utils.py",
    "content": "import json\nimport os.path as osp\nfrom typing import Union\n\n\nclass Prompter(object):\n    __slots__ = (\"template\", \"_verbose\")\n\n    def __init__(self, template_name: str = \"\", verbose: bool = False):\n        self._verbose = verbose\n        if not template_name:\n            # template_name = \"alpaca\"\n            template_name = \"alpaca_short\"\n        file_name = osp.join(\"dataloader\", \"templates\", f\"{template_name}.json\")\n        if not osp.exists(file_name):\n            raise ValueError(f\"Can't read {file_name}\")\n        with open(file_name) as fp:\n            self.template = json.load(fp)\n        if self._verbose:\n            print(\n                f\"Using prompt template {template_name}: {self.template['description']}\"\n            )\n\n    def generate_prompt(\n        self,\n        instruction: str,\n        input: Union[None, str] = None,\n        label: Union[None, str] = None,\n    ) -> str:\n        if input:\n            res = self.template[\"prompt_input\"].format(\n                instruction=instruction, input=input\n            )\n        else:\n            res = self.template[\"prompt_no_input\"].format(\n                instruction=instruction\n            )\n        if label:\n            res = f\"{res}{label}\"\n        if self._verbose:\n            print(res)\n        return res\n\n    def get_response(self, output: str) -> str:\n        return output.split(self.template[\"response_split\"])[1].strip()"
  },
  {
    "path": "datasets/__init__.py",
    "content": "from .ml_100k import ML100KDataset\nfrom .beauty import BeautyDataset\nfrom .games import GamesDataset\n\nDATASETS = {\n    ML100KDataset.code(): ML100KDataset,\n    BeautyDataset.code(): BeautyDataset,\n    GamesDataset.code(): GamesDataset,\n}\n\n\ndef dataset_factory(args):\n    dataset = DATASETS[args.dataset_code]\n    return dataset(args)\n"
  },
  {
    "path": "datasets/base.py",
    "content": "import pickle\nimport shutil\nimport tempfile\nimport os\nfrom pathlib import Path\nimport gzip\nfrom abc import *\nfrom .utils import *\nfrom config import RAW_DATASET_ROOT_FOLDER\n\nimport numpy as np\nimport pandas as pd\nfrom tqdm import tqdm\ntqdm.pandas()\n\n\nclass AbstractDataset(metaclass=ABCMeta):\n    def __init__(self, args):\n        self.args = args\n        self.min_rating = args.min_rating\n        self.min_uc = args.min_uc\n        self.min_sc = args.min_sc\n\n        assert self.min_uc >= 2, 'Need at least 2 ratings per user for validation and test'\n\n    @classmethod\n    @abstractmethod\n    def code(cls):\n        pass\n\n    @classmethod\n    def raw_code(cls):\n        return cls.code()\n\n    @classmethod\n    def zip_file_content_is_folder(cls):\n        return True\n\n    @classmethod\n    def all_raw_file_names(cls):\n        return []\n\n    @classmethod\n    @abstractmethod\n    def url(cls):\n        pass\n\n    @abstractmethod\n    def preprocess(self):\n        pass\n\n    @abstractmethod\n    def load_ratings_df(self):\n        pass\n\n    @abstractmethod\n    def maybe_download_raw_dataset(self):\n        pass\n\n    def load_dataset(self):\n        self.preprocess()\n        dataset_path = self._get_preprocessed_dataset_path()\n        dataset = pickle.load(dataset_path.open('rb'))\n        return dataset\n\n    def filter_triplets(self, df):\n        print('Filtering triplets')\n        if self.min_sc > 1 or self.min_uc > 1:\n            item_sizes = df.groupby('sid').size()\n            good_items = item_sizes.index[item_sizes >= self.min_sc]\n            user_sizes = df.groupby('uid').size()\n            good_users = user_sizes.index[user_sizes >= self.min_uc]\n            while len(good_items) < len(item_sizes) or len(good_users) < len(user_sizes):\n                if self.min_sc > 1:\n                    item_sizes = df.groupby('sid').size()\n                    good_items = item_sizes.index[item_sizes >= 
self.min_sc]\n                    df = df[df['sid'].isin(good_items)]\n\n                if self.min_uc > 1:\n                    user_sizes = df.groupby('uid').size()\n                    good_users = user_sizes.index[user_sizes >= self.min_uc]\n                    df = df[df['uid'].isin(good_users)]\n\n                item_sizes = df.groupby('sid').size()\n                good_items = item_sizes.index[item_sizes >= self.min_sc]\n                user_sizes = df.groupby('uid').size()\n                good_users = user_sizes.index[user_sizes >= self.min_uc]\n        return df\n    \n    def densify_index(self, df):\n        print('Densifying index')\n        umap = {u: i for i, u in enumerate(set(df['uid']), start=1)}\n        smap = {s: i for i, s in enumerate(set(df['sid']), start=1)}\n        df['uid'] = df['uid'].map(umap)\n        df['sid'] = df['sid'].map(smap)\n        return df, umap, smap\n\n    def split_df(self, df, user_count):\n        print('Splitting')\n        user_group = df.groupby('uid')\n        user2items = user_group.progress_apply(\n            lambda d: list(d.sort_values(by=['timestamp', 'sid'])['sid']))\n        train, val, test = {}, {}, {}\n        for i in range(user_count):\n            user = i + 1\n            items = user2items[user]\n            train[user], val[user], test[user] = items[:-2], items[-2:-1], items[-1:]\n        return train, val, test\n\n    def _get_rawdata_root_path(self):\n        return Path(RAW_DATASET_ROOT_FOLDER)\n\n    def _get_rawdata_folder_path(self):\n        root = self._get_rawdata_root_path()\n        return root.joinpath(self.raw_code())\n\n    def _get_preprocessed_root_path(self):\n        root = self._get_rawdata_root_path()\n        return root.joinpath('preprocessed')\n\n    def _get_preprocessed_folder_path(self):\n        preprocessed_root = self._get_preprocessed_root_path()\n        folder_name = '{}_min_rating{}-min_uc{}-min_sc{}' \\\n            .format(self.code(), self.min_rating, 
self.min_uc, self.min_sc)\n        return preprocessed_root.joinpath(folder_name)\n\n    def _get_preprocessed_dataset_path(self):\n        folder = self._get_preprocessed_folder_path()\n        return folder.joinpath('dataset.pkl')\n"
  },
  {
    "path": "datasets/beauty.py",
    "content": "from .base import AbstractDataset\nfrom .utils import *\n\nfrom datetime import date\nfrom pathlib import Path\nimport pickle\nimport shutil\nimport tempfile\nimport os\n\nimport gzip\nimport json\nimport numpy as np\nimport pandas as pd\nfrom tqdm import tqdm\ntqdm.pandas()\n\n\nclass BeautyDataset(AbstractDataset):\n    @classmethod\n    def code(cls):\n        return 'beauty'\n\n    @classmethod\n    def url(cls):\n        return ['http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Beauty.csv',\n                'http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Beauty.json.gz']\n\n    @classmethod\n    def zip_file_content_is_folder(cls):\n        return True\n\n    @classmethod\n    def all_raw_file_names(cls):\n        return ['beauty.csv', 'beauty_meta.json.gz']\n\n    def maybe_download_raw_dataset(self):\n        folder_path = self._get_rawdata_folder_path()\n        if folder_path.is_dir() and\\\n           all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):\n            print('Raw data already exists. Skip downloading')\n            return\n        \n        print(\"Raw file doesn't exist. Downloading...\")\n        for idx, url in enumerate(self.url()):\n            tmproot = Path(tempfile.mkdtemp())\n            tmpfile = tmproot.joinpath('file')\n            download(url, tmpfile)\n            os.makedirs(folder_path, exist_ok=True)\n            shutil.move(tmpfile, folder_path.joinpath(self.all_raw_file_names()[idx]))\n            print()\n\n    def preprocess(self):\n        dataset_path = self._get_preprocessed_dataset_path()\n        if dataset_path.is_file():\n            print('Already preprocessed. 
Skip preprocessing')\n            return\n        if not dataset_path.parent.is_dir():\n            dataset_path.parent.mkdir(parents=True)\n        self.maybe_download_raw_dataset()\n        df = self.load_ratings_df()\n        meta_raw = self.load_meta_dict()\n        df = df[df['sid'].isin(meta_raw)]  # filter items without meta info\n        df = self.filter_triplets(df)\n        df, umap, smap = self.densify_index(df)\n        train, val, test = self.split_df(df, len(umap))\n        meta = {smap[k]: v for k, v in meta_raw.items() if k in smap}\n        dataset = {'train': train,\n                   'val': val,\n                   'test': test,\n                   'meta': meta,\n                   'umap': umap,\n                   'smap': smap}\n        with dataset_path.open('wb') as f:\n            pickle.dump(dataset, f)\n\n    def load_ratings_df(self):\n        folder_path = self._get_rawdata_folder_path()\n        file_path = folder_path.joinpath(self.all_raw_file_names()[0])\n        df = pd.read_csv(file_path, header=None)\n        df.columns = ['uid', 'sid', 'rating', 'timestamp']\n        return df\n    \n    def load_meta_dict(self):\n        folder_path = self._get_rawdata_folder_path()\n        file_path = folder_path.joinpath(self.all_raw_file_names()[1])\n\n        meta_dict = {}\n        with gzip.open(file_path, 'rb') as f:\n            for line in f:\n                item = eval(line)\n                if 'title' in item and len(item['title']) > 0:\n                    meta_dict[item['asin'].strip()] = item['title'].strip()\n        \n        return meta_dict\n"
  },
  {
    "path": "datasets/games.py",
    "content": "from .base import AbstractDataset\nfrom .utils import *\n\nfrom datetime import date\nfrom pathlib import Path\nimport pickle\nimport shutil\nimport tempfile\nimport os\n\nimport gzip\nimport json\nimport numpy as np\nimport pandas as pd\nfrom tqdm import tqdm\ntqdm.pandas()\n\n\nclass GamesDataset(AbstractDataset):\n    @classmethod\n    def code(cls):\n        return 'games'\n\n    @classmethod\n    def url(cls):\n        # meta_Video_Games.json.gz from snap.stanford.edu does not contain full meta info\n        return ['http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Video_Games.csv',\n                'https://datarepo.eng.ucsd.edu/mcauley_group/data/amazon_v2/metaFiles2/meta_Video_Games.json.gz']\n\n    @classmethod\n    def zip_file_content_is_folder(cls):\n        return True\n\n    @classmethod\n    def all_raw_file_names(cls):\n        return ['games.csv', 'games_meta.json.gz']\n\n    def maybe_download_raw_dataset(self):\n        folder_path = self._get_rawdata_folder_path()\n        if folder_path.is_dir() and\\\n           all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):\n            print('Raw data already exists. Skip downloading')\n            return\n        \n        print(\"Raw file doesn't exist. Downloading...\")\n        for idx, url in enumerate(self.url()):\n            tmproot = Path(tempfile.mkdtemp())\n            tmpfile = tmproot.joinpath('file')\n            download(url, tmpfile)\n            os.makedirs(folder_path, exist_ok=True)\n            shutil.move(tmpfile, folder_path.joinpath(self.all_raw_file_names()[idx]))\n            print()\n\n    def preprocess(self):\n        dataset_path = self._get_preprocessed_dataset_path()\n        if dataset_path.is_file():\n            print('Already preprocessed. 
Skip preprocessing')\n            return\n        if not dataset_path.parent.is_dir():\n            dataset_path.parent.mkdir(parents=True)\n        self.maybe_download_raw_dataset()\n        df = self.load_ratings_df()\n        meta_raw = self.load_meta_dict()\n        df = df[df['sid'].isin(meta_raw)]  # filter items without meta info\n        df = self.filter_triplets(df)\n        df, umap, smap = self.densify_index(df)\n        train, val, test = self.split_df(df, len(umap))\n        meta = {smap[k]: v for k, v in meta_raw.items() if k in smap}\n        dataset = {'train': train,\n                   'val': val,\n                   'test': test,\n                   'meta': meta,\n                   'umap': umap,\n                   'smap': smap}\n        with dataset_path.open('wb') as f:\n            pickle.dump(dataset, f)\n\n    def load_ratings_df(self):\n        folder_path = self._get_rawdata_folder_path()\n        file_path = folder_path.joinpath(self.all_raw_file_names()[0])\n        df = pd.read_csv(file_path, header=None)\n        df.columns = ['uid', 'sid', 'rating', 'timestamp']\n        return df\n    \n    def load_meta_dict(self):\n        folder_path = self._get_rawdata_folder_path()\n        file_path = folder_path.joinpath(self.all_raw_file_names()[1])\n\n        meta_dict = {}\n        with gzip.open(file_path, 'rb') as f:\n            for line in f:\n                item = eval(line)\n                if 'title' in item and len(item['title']) > 0:\n                    meta_dict[item['asin'].strip()] = item['title'].strip()\n        \n        return meta_dict\n"
  },
  {
    "path": "datasets/ml_100k.py",
    "content": "from .base import AbstractDataset\nfrom .utils import *\n\nfrom datetime import date\nfrom pathlib import Path\nimport pickle\nimport shutil\nimport tempfile\nimport os\n\nimport re\nimport numpy as np\nimport pandas as pd\nfrom tqdm import tqdm\ntqdm.pandas()\n\n\nclass ML100KDataset(AbstractDataset):\n    @classmethod\n    def code(cls):\n        return 'ml-100k'\n\n    @classmethod\n    def url(cls):  # as of Sep 2023\n        return 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'\n\n    @classmethod\n    def zip_file_content_is_folder(cls):\n        return True\n\n    @classmethod\n    def all_raw_file_names(cls):\n        return ['README',\n                'movies.csv',\n                'ratings.csv',\n                'users.csv']\n\n    def maybe_download_raw_dataset(self):\n        folder_path = self._get_rawdata_folder_path()\n        if folder_path.is_dir() and\\\n           all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):\n            print('Raw data already exists. Skip downloading')\n            return\n\n        print(\"Raw file doesn't exist. Downloading...\")\n        tmproot = Path(tempfile.mkdtemp())\n        tmpzip = tmproot.joinpath('file.zip')\n        tmpfolder = tmproot.joinpath('folder')\n        download(self.url(), tmpzip)\n        unzip(tmpzip, tmpfolder)\n        if self.zip_file_content_is_folder():\n            tmpfolder = tmpfolder.joinpath(os.listdir(tmpfolder)[0])\n        shutil.move(tmpfolder, folder_path)\n        shutil.rmtree(tmproot)\n        print()\n\n    def preprocess(self):\n        dataset_path = self._get_preprocessed_dataset_path()\n        if dataset_path.is_file():\n            print('Already preprocessed. 
Skip preprocessing')\n            return\n        if not dataset_path.parent.is_dir():\n            dataset_path.parent.mkdir(parents=True)\n        self.maybe_download_raw_dataset()\n        df = self.load_ratings_df()\n        meta_raw = self.load_meta_dict()\n        df = df[df['sid'].isin(meta_raw)]  # filter items without meta info\n        df = self.filter_triplets(df)\n        df, umap, smap = self.densify_index(df)\n        train, val, test = self.split_df(df, len(umap))\n        meta = {smap[k]: v for k, v in meta_raw.items() if k in smap}\n        dataset = {'train': train,\n                   'val': val,\n                   'test': test,\n                   'meta': meta,\n                   'umap': umap,\n                   'smap': smap}\n        with dataset_path.open('wb') as f:\n            pickle.dump(dataset, f)\n\n    def load_ratings_df(self):\n        folder_path = self._get_rawdata_folder_path()\n        file_path = folder_path.joinpath('ratings.csv')\n        df = pd.read_csv(file_path)\n        df.columns = ['uid', 'sid', 'rating', 'timestamp']\n        return df\n\n    def load_meta_dict(self):\n        folder_path = self._get_rawdata_folder_path()\n        file_path = folder_path.joinpath('movies.csv')\n        df = pd.read_csv(file_path, encoding=\"ISO-8859-1\")\n        meta_dict = {}\n        for row in df.itertuples():\n            title = row[2][:-7]  # remove year (optional)\n            year = row[2][-7:]\n\n            # raw string avoids the invalid '\\(' escape-sequence warning\n            title = re.sub(r'\\(.*?\\)', '', title).strip()\n            # the rest articles and parentheses are not considered here\n            if any(', '+x in title.lower()[-5:] for x in ['a', 'an', 'the']):\n                title_pre = title.split(', ')[:-1]\n                title_post = title.split(', ')[-1]\n                title_pre = ', '.join(title_pre)\n                title = title_post + ' ' + title_pre\n\n            meta_dict[row[1]] = title + year\n        return meta_dict\n"
  },
  {
    "path": "datasets/utils.py",
    "content": "import numpy as np\nimport pandas as pd\nfrom tqdm import tqdm\nimport urllib.request\n\n\nfrom pathlib import Path\nimport zipfile\nimport tarfile\nimport sys\n\n\ndef download(url, savepath):\n    urllib.request.urlretrieve(url, str(savepath))\n    print()\n\n\ndef unzip(zippath, savepath):\n    print(\"Extracting data...\")\n    zip = zipfile.ZipFile(zippath)\n    zip.extractall(savepath)\n    zip.close()\n\n\ndef unziptargz(zippath, savepath):\n    print(\"Extracting data...\")\n    f = tarfile.open(zippath)\n    f.extractall(savepath)\n    f.close()\n"
  },
  {
    "path": "model/__init__.py",
    "content": "from .lru import *\nfrom .llm import *"
  },
  {
    "path": "model/llm.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch LLaMA model.\"\"\"\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings\nfrom transformers.models.llama.configuration_llama import LlamaConfig\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"LlamaConfig\"\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(\n    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0\n):\n    \"\"\"\n    Make causal mask used for 
bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\nclass LlamaRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nclass LlamaRotaryEmbedding(torch.nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = 
dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n        )\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :].to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :].to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n        )\n\n\nclass LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n        t = t / self.scaling_factor\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :].to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :].to(dtype), persistent=False)\n\n\nclass LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * (\n                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)\n            ) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n            self.register_buffer(\"inv_freq\", inv_freq)\n\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :].to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :].to(dtype), persistent=False)\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.\n    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]\n    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]\n    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * 
sin)\n    return q_embed, k_embed\n\n\nclass LlamaMLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.pretraining_tp = config.pretraining_tp\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        if self.pretraining_tp > 1:\n            slice = self.intermediate_size // self.pretraining_tp\n            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)\n            up_proj_slices = self.up_proj.weight.split(slice, dim=0)\n            down_proj_slices = self.down_proj.weight.split(slice, dim=1)\n\n            gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)\n            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)\n\n            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)\n            down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]\n            down_proj = sum(down_proj)\n        else:\n            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n        return down_proj\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass LlamaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: LlamaConfig):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.pretraining_tp = config.pretraining_tp\n        self.max_position_embeddings = config.max_position_embeddings\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)\n        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n        self._init_rope()\n\n    def _init_rope(self):\n        if self.config.rope_scaling is None:\n            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, 
max_position_embeddings=self.max_position_embeddings)\n        else:\n            scaling_type = self.config.rope_scaling[\"type\"]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if scaling_type == \"linear\":\n                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(\n                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor\n                )\n            elif scaling_type == \"dynamic\":\n                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(\n                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor\n                )\n            else:\n                raise ValueError(f\"Unknown RoPE scaling type {scaling_type}\")\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.pretraining_tp > 1:\n            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp\n            query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)\n            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)\n            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)\n\n            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]\n            query_states = 
torch.cat(query_states, dim=-1)\n\n            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]\n            key_states = torch.cat(key_states, dim=-1)\n\n            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]\n            value_states = torch.cat(value_states, dim=-1)\n\n        else:\n            query_states = self.q_proj(hidden_states)\n            key_states = self.k_proj(hidden_states)\n            value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value[0].shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        if past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n        past_key_value = (key_states, value_states) if use_cache else None\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, 
q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n        if self.pretraining_tp > 1:\n            attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)\n            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)\n            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])\n        else:\n            attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass LlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = LlamaAttention(config=config)\n        self.mlp = LlamaMLP(config)\n        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = 
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = 
self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nLLAMA_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`LlamaConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. 
Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaPreTrainedModel(PreTrainedModel):\n    config_class = LlamaConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"LlamaDecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, LlamaModel):\n            module.gradient_checkpointing = value\n\n\nLLAMA_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape\n            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.\n\n            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that\n            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all\n            `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaModel(LlamaPreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])\n        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                
input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n                past_key_values_length=past_key_values_length,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n                inputs_embeds.device\n            )\n            combined_attention_mask = (\n                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n            )\n\n        return combined_attention_mask\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\")\n        elif 
input_ids is not None:\n            batch_size, seq_length = input_ids.shape\n        elif inputs_embeds is not None:\n            batch_size, seq_length, _ = inputs_embeds.shape\n        else:\n            raise ValueError(\"You have to specify either decoder_input_ids or decoder_inputs_embeds\")\n\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n\n        if past_key_values is not None:\n            past_key_values_length = past_key_values[0][0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n\n        if position_ids is None:\n            device = input_ids.device if input_ids is not None else inputs_embeds.device\n            position_ids = torch.arange(\n                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device\n            )\n            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n        else:\n            position_ids = position_ids.view(-1, seq_length).long()\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n        # embed positions\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device\n            )\n        attention_mask = self._prepare_decoder_attention_mask(\n            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length\n        )\n\n        hidden_states = inputs_embeds\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = () if use_cache else None\n\n        for idx, decoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n            if self.gradient_checkpointing and self.training:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        # None for past_key_value\n                        return module(*inputs, output_attentions, None)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(decoder_layer),\n                    hidden_states,\n                    attention_mask,\n                    position_ids,\n                    None,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_value,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n       
 next_cache = next_decoder_cache if use_cache else None\n        if not return_dict:\n            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\nclass LlamaForCausalLM(LlamaPreTrainedModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = LlamaModel(config)\n        self.pretraining_tp = config.pretraining_tp\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n     
   output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, LlamaForCausalLM\n\n        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        if self.pretraining_tp > 1:\n            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)\n            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]\n            logits = torch.cat(logits, dim=-1)\n        else:\n            logits = self.lm_head(hidden_states)\n        logits = logits.float()\n\n        loss = None\n        if self.training and labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n 
       elif labels is not None:\n            loss = torch.tensor(-1.)  # loss cannot be directly computed in inference\n\n        logits = logits[:, -1]  # we only need last position logits for inference \n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ):\n        if past_key_values:\n            input_ids = input_ids[:, -1:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -1].unsqueeze(-1)\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in 
past_key_values:\n            reordered_past += (\n                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),\n            )\n        return reordered_past"
  },
  {
    "path": "model/lru.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nimport numpy as np\n\n\nclass LRURec(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n        self.embedding = LRUEmbedding(self.args)\n        self.model = LRUModel(self.args)\n        self.truncated_normal_init()\n\n    def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):\n        with torch.no_grad():\n            l = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2.\n            u = (1. + math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2.\n\n            for n, p in self.named_parameters():\n                if not 'layer_norm' in n and 'params_log' not in n:\n                    if torch.is_complex(p):\n                        p.real.uniform_(2 * l - 1, 2 * u - 1)\n                        p.imag.uniform_(2 * l - 1, 2 * u - 1)\n                        p.real.erfinv_()\n                        p.imag.erfinv_()\n                        p.real.mul_(std * math.sqrt(2.))\n                        p.imag.mul_(std * math.sqrt(2.))\n                        p.real.add_(mean)\n                        p.imag.add_(mean)\n                    else:\n                        p.uniform_(2 * l - 1, 2 * u - 1)\n                        p.erfinv_()\n                        p.mul_(std * math.sqrt(2.))\n                        p.add_(mean)\n\n    def forward(self, x):\n        x, mask = self.embedding(x)\n        scores = self.model(x, self.embedding.token.weight, mask)\n        return scores\n\n\nclass LRUEmbedding(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        vocab_size = args.num_items + 1\n        embed_size = args.bert_hidden_units\n        \n        self.token = nn.Embedding(vocab_size, embed_size)\n        self.layer_norm = nn.LayerNorm(embed_size)\n        self.embed_dropout = nn.Dropout(args.bert_dropout)\n\n    def get_mask(self, x):\n        return (x > 0)\n\n 
   def forward(self, x):\n        mask = self.get_mask(x)\n        x = self.token(x)\n        return self.layer_norm(self.embed_dropout(x)), mask\n\n\nclass LRUModel(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n        self.hidden_size = args.bert_hidden_units\n        layers = args.bert_num_blocks\n\n        self.lru_blocks = nn.ModuleList([LRUBlock(self.args) for _ in range(layers)])\n        self.bias = torch.nn.Parameter(torch.zeros(args.num_items + 1))\n\n    def forward(self, x, embedding_weight, mask):\n        # left padding to the power of 2\n        seq_len = x.size(1)\n        log2_L = int(np.ceil(np.log2(seq_len)))\n        x = F.pad(x, (0, 0, 2 ** log2_L - x.size(1), 0, 0, 0))\n        mask_ = F.pad(mask, (2 ** log2_L - mask.size(1), 0, 0, 0))\n\n        # LRU blocks with pffn\n        for lru_block in self.lru_blocks:\n            x = lru_block.forward(x, mask_)\n        x = x[:, -seq_len:]  # B x L x D (64)\n\n        scores = torch.matmul(x, embedding_weight.permute(1, 0)) + self.bias\n        return scores\n\n\nclass LRUBlock(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n        hidden_size = args.bert_hidden_units\n        self.lru_layer = LRULayer(\n            d_model=hidden_size, dropout=args.bert_attn_dropout)\n        self.feed_forward = PositionwiseFeedForward(\n            d_model=hidden_size, d_ff=hidden_size*4, dropout=args.bert_dropout)\n    \n    def forward(self, x, mask):\n        x = self.lru_layer(x, mask)\n        x = self.feed_forward(x)\n        return x\n    \n\nclass LRULayer(nn.Module):\n    def __init__(self,\n                 d_model,\n                 dropout=0.1,\n                 use_bias=True,\n                 r_min=0.8,\n                 r_max=0.99):\n        super().__init__()\n        self.embed_size = d_model\n        self.hidden_size = 2 * d_model\n        self.use_bias = use_bias\n\n        # init nu, theta, 
gamma\n        u1 = torch.rand(self.hidden_size)\n        u2 = torch.rand(self.hidden_size)\n        nu_log = torch.log(-0.5 * torch.log(u1 * (r_max ** 2 - r_min ** 2) + r_min ** 2))\n        theta_log = torch.log(u2 * torch.tensor(np.pi) * 2)\n        diag_lambda = torch.exp(torch.complex(-torch.exp(nu_log), torch.exp(theta_log)))\n        gamma_log = torch.log(torch.sqrt(1 - torch.abs(diag_lambda) ** 2))\n        self.params_log = nn.Parameter(torch.vstack((nu_log, theta_log, gamma_log)))\n\n        # Init B, C, D\n        self.in_proj = nn.Linear(self.embed_size, self.hidden_size, bias=use_bias).to(torch.cfloat)\n        self.out_proj = nn.Linear(self.hidden_size, self.embed_size, bias=use_bias).to(torch.cfloat)\n        # self.out_vector = nn.Parameter(torch.rand(self.embed_size))\n        self.out_vector = nn.Identity()\n        \n        # Dropout and layer norm\n        self.dropout = nn.Dropout(p=dropout)\n        self.layer_norm = nn.LayerNorm(self.embed_size)\n\n    def lru_parallel(self, i, h, lamb, mask, B, L, D):\n        # Parallel algorithm, see: https://kexue.fm/archives/9554#%E5%B9%B6%E8%A1%8C%E5%8C%96\n        # The original implementation is slightly slower and does not consider 0 padding\n        l = 2 ** i\n        h = h.reshape(B * L // l, l, D)  # (B, L, D) -> (B * L // 2, 2, D)\n        mask_ = mask.reshape(B * L // l, l)  # (B, L) -> (B * L // 2, 2)\n        h1, h2 = h[:, :l // 2], h[:, l // 2:]  # Divide data in half\n\n        if i > 1: lamb = torch.cat((lamb, lamb * lamb[-1]), 0)\n        h2 = h2 + lamb * h1[:, -1:] * mask_[:, l // 2 - 1:l // 2].unsqueeze(-1)\n        h = torch.cat([h1, h2], axis=1)\n        return h, lamb\n\n    def forward(self, x, mask):\n        # compute bu and lambda\n        nu, theta, gamma = torch.exp(self.params_log).split((1, 1, 1))\n        lamb = torch.exp(torch.complex(-nu, theta))\n        h = self.in_proj(x.to(torch.cfloat)) * gamma  # bu\n        \n        # compute h in parallel\n        log2_L = 
int(np.ceil(np.log2(h.size(1))))\n        B, L, D = h.size(0), h.size(1), h.size(2)\n        for i in range(log2_L):\n            h, lamb = self.lru_parallel(i + 1, h, lamb, mask, B, L, D)\n        x = self.dropout(self.out_proj(h).real) + self.out_vector(x)\n        return self.layer_norm(x)  # layer norm over the residual sum computed on the previous line\n    \n\nclass PositionwiseFeedForward(nn.Module):\n    def __init__(self, d_model, d_ff, dropout=0.1):\n        super().__init__()\n        self.w_1 = nn.Linear(d_model, d_ff)\n        self.w_2 = nn.Linear(d_ff, d_model)\n        self.activation = nn.GELU()\n        self.dropout = nn.Dropout(dropout)\n        self.layer_norm = nn.LayerNorm(d_model)\n\n    def forward(self, x):\n        x_ = self.dropout(self.activation(self.w_1(x)))\n        return self.layer_norm(self.dropout(self.w_2(x_)) + x)"
  },
  {
    "path": "requirements.txt",
    "content": "# This file may be used to create an environment using:\n# $ conda create --name <env> --file <this file>\n# platform: linux-64\n_libgcc_mutex=0.1=main\n_openmp_mutex=5.1=1_gnu\nabsl-py=1.4.0=pypi_0\naccelerate=0.21.0=pypi_0\naiofiles=23.1.0=pypi_0\naiohttp=3.8.4=pypi_0\naiosignal=1.3.1=pypi_0\nalabaster=0.7.13=pypi_0\naltair=4.2.2=pypi_0\naniso8601=9.0.1=pypi_0\nantlr4-python3-runtime=4.9.3=pypi_0\nanyio=3.6.2=pypi_0\nappdirs=1.4.4=pypi_0\nasttokens=2.2.1=pypi_0\nasync-timeout=4.0.2=pypi_0\nattrdict=2.0.1=pypi_0\nattrs=23.1.0=pypi_0\naudioread=3.0.0=pypi_0\nbabel=2.12.1=pypi_0\nbackcall=0.2.0=pypi_0\nbeautifulsoup4=4.12.2=pypi_0\nbitsandbytes=0.41.1=pypi_0\nblack=19.10b0=pypi_0\nblas=1.0=mkl\nboto3=1.26.160=pypi_0\nbotocore=1.29.160=pypi_0\nbraceexpand=0.1.7=pypi_0\nbrotlipy=0.7.0=py310h7f8727e_1002\nbzip2=1.0.8=h7b6447c_0\nca-certificates=2023.08.22=h06a4308_0\ncachetools=5.3.0=pypi_0\ncdifflib=1.2.6=pypi_0\ncertifi=2022.12.7=py310h06a4308_0\ncffi=1.15.1=py310h5eee18b_3\ncharset-normalizer=2.0.4=pyhd3eb1b0_0\nclick=8.0.2=pypi_0\nclip=1.0=pypi_0\ncolorama=0.4.6=pypi_0\ncomm=0.1.3=pypi_0\ncontourpy=1.0.7=pypi_0\ncryptography=39.0.1=py310h9ce1e76_0\ncuda-cudart=11.8.89=0\ncuda-cupti=11.8.87=0\ncuda-libraries=11.8.0=0\ncuda-nvrtc=11.8.89=0\ncuda-nvtx=11.8.86=0\ncuda-runtime=11.8.0=0\ncudatoolkit=11.8.0=h6a678d5_0\ncycler=0.11.0=pypi_0\ncython=0.29.35=pypi_0\ndatasets=2.14.5=pypi_0\ndebugpy=1.6.7=pypi_0\ndecorator=5.1.1=pypi_0\ndill=0.3.6=pypi_0\ndistance=0.1.3=pypi_0\ndocker-pycreds=0.4.0=pypi_0\ndocopt=0.6.2=pypi_0\ndocutils=0.20.1=pypi_0\neditdistance=0.6.2=pypi_0\neinops=0.6.1=pypi_0\nemoji=2.8.0=pypi_0\nentrypoints=0.4=pypi_0\nevaluate=0.3.0=pypi_0\nexceptiongroup=1.1.1=pypi_0\nexecuting=1.2.0=pypi_0\nfairscale=0.4.13=pypi_0\nfaiss-cpu=1.7.4=pypi_0\nfaiss-gpu=1.7.2=pypi_0\nfastapi=0.95.1=pypi_0\nfasttext=0.9.2=pypi_0\nffmpeg=4.3=hf484d3e_0\nffmpy=0.3.0=pypi_0\nfilelock=3.9.0=py310h06a4308_0\nfire=0.5.0=pypi_0\nflask=2.2.5=pypi_0\nflask-restful=0.3.
10=pypi_0\nflit-core=3.8.0=py310h06a4308_0\nfonttools=4.39.3=pypi_0\nfreetype=2.12.1=h4a9f257_0\nfrozenlist=1.3.3=pypi_0\nfsspec=2023.4.0=pypi_0\nftfy=6.1.1=pypi_0\nfuture=0.18.3=pypi_0\ng2p-en=2.1.0=pypi_0\ngdown=4.7.1=pypi_0\ngiflib=5.2.1=h5eee18b_3\ngitdb=4.0.10=pypi_0\ngitpython=3.1.31=pypi_0\ngmp=6.2.1=h295c915_3\ngmpy2=2.1.2=py310heeb90bb_0\ngnutls=3.6.15=he1e5248_0\ngoogle-auth=2.17.3=pypi_0\ngoogle-auth-oauthlib=1.0.0=pypi_0\ngradio=3.28.3=pypi_0\ngradio-client=0.1.3=pypi_0\ngrpcio=1.54.0=pypi_0\nh11=0.14.0=pypi_0\nh5py=3.9.0=pypi_0\nhttpcore=0.17.0=pypi_0\nhttpx=0.24.0=pypi_0\nhuggingface-hub=0.16.4=pypi_0\nhydra-core=1.2.0=pypi_0\nidna=3.4=py310h06a4308_0\nijson=3.2.2=pypi_0\nimageio=2.28.0=pypi_0\nimagesize=1.4.1=pypi_0\ninflect=6.0.4=pypi_0\niniconfig=2.0.0=pypi_0\nintel-openmp=2021.4.0=h06a4308_3561\nipadic=1.0.0=pypi_0\nipykernel=6.23.3=pypi_0\nipython=8.12.0=pypi_0\nipywidgets=8.0.6=pypi_0\nisort=5.12.0=pypi_0\nitsdangerous=2.1.2=pypi_0\njedi=0.18.2=pypi_0\njieba=0.42.1=pypi_0\njinja2=3.1.2=py310h06a4308_0\njiwer=2.5.2=pypi_0\njmespath=1.0.1=pypi_0\njoblib=1.2.0=pypi_0\njpeg=9e=h5eee18b_1\njsonschema=4.17.3=pypi_0\njupyter-client=8.3.0=pypi_0\njupyter-core=5.3.1=pypi_0\njupyterlab-widgets=3.0.7=pypi_0\nkaldi-python-io=1.2.2=pypi_0\nkaldiio=2.18.0=pypi_0\nkiwisolver=1.4.4=pypi_0\nkornia=0.6.12=pypi_0\nlame=3.100=h7b6447c_0\nlatexcodec=2.0.1=pypi_0\nlazy-loader=0.2=pypi_0\nlcms2=2.12=h3be6417_0\nld_impl_linux-64=2.38=h1181459_1\nlerc=3.0=h295c915_0\nlevenshtein=0.21.1=pypi_0\nlibcublas=11.11.3.6=0\nlibcufft=10.9.0.58=0\nlibcufile=1.6.0.25=0\nlibcurand=10.3.2.56=0\nlibcusolver=11.4.1.48=0\nlibcusparse=11.7.5.86=0\nlibdeflate=1.17=h5eee18b_0\nlibffi=3.4.2=h6a678d5_6\nlibgcc-ng=11.2.0=h1234567_1\nlibgomp=11.2.0=h1234567_1\nlibiconv=1.16=h7f8727e_2\nlibidn2=2.3.2=h7f8727e_0\nlibnpp=11.8.0.86=0\nlibnvjpeg=11.9.0.86=0\nlibpng=1.6.39=h5eee18b_0\nlibrosa=0.10.0.post2=pypi_0\nlibstdcxx-ng=11.2.0=h1234567_1\nlibtasn1=4.16.0=h27cfd23_0\nlibtiff=4.5.0=h6a678d5_2\nl
ibunistring=0.9.10=h27cfd23_0\nlibuuid=1.41.5=h5eee18b_0\nlibwebp=1.2.4=h11a3e52_1\nlibwebp-base=1.2.4=h5eee18b_1\nlightning-utilities=0.8.0=pypi_0\nlinkify-it-py=2.0.0=pypi_0\nllvmlite=0.40.1=pypi_0\nloguru=0.7.0=pypi_0\nloralib=0.1.1=pypi_0\nlxml=4.9.2=pypi_0\nlz4-c=1.9.4=h6a678d5_0\nmarkdown=3.4.3=pypi_0\nmarkdown-it-py=2.2.0=pypi_0\nmarkdown2=2.4.9=pypi_0\nmarkupsafe=2.1.1=py310h7f8727e_0\nmarshmallow=3.19.0=pypi_0\nmatplotlib=3.7.1=pypi_0\nmatplotlib-inline=0.1.6=pypi_0\nmdit-py-plugins=0.3.3=pypi_0\nmdurl=0.1.2=pypi_0\nmecab-python3=1.0.5=pypi_0\nmegatron-core=0.2.0=pypi_0\nmkl=2021.4.0=h06a4308_640\nmkl-service=2.4.0=py310h7f8727e_0\nmkl_fft=1.3.1=py310hd6ae3a3_0\nmkl_random=1.2.2=py310h00e6091_0\nmpc=1.1.0=h10f8cd9_1\nmpfr=4.0.2=hb69a4c5_1\nmpmath=1.2.1=pypi_0\nmsgpack=1.0.5=pypi_0\nmultidict=6.0.4=pypi_0\nmultiprocess=0.70.14=pypi_0\nmypy-extensions=1.0.0=pypi_0\nncurses=6.4=h6a678d5_0\nnest-asyncio=1.5.6=pypi_0\nnettle=3.7.3=hbbd107a_1\nnetworkx=2.8.4=py310h06a4308_1\nnltk=3.8=pypi_0\nnumba=0.57.1=pypi_0\nnumpy=1.23.4=pypi_0\nnumpy-base=1.24.3=py310h8e6c178_0\noauthlib=3.2.2=pypi_0\nomegaconf=2.2.3=pypi_0\nonnx=1.14.0=pypi_0\nopenai=0.27.1=pypi_0\nopencc=1.1.6=pypi_0\nopencv-python=4.7.0.72=pypi_0\nopenh264=2.1.1=h4ff587b_0\nopenprompt=1.0.1=pypi_0\nopenssl=1.1.1w=h7f8727e_0\norjson=3.8.10=pypi_0\npackaging=23.1=pypi_0\npandas=2.0.1=pypi_0\npangu=4.0.6.1=pypi_0\nparameterized=0.9.0=pypi_0\nparso=0.8.3=pypi_0\npathspec=0.11.1=pypi_0\npathtools=0.1.2=pypi_0\npeft=0.5.0=pypi_0\npexpect=4.8.0=pypi_0\npickleshare=0.7.5=pypi_0\npillow=9.4.0=py310h6a678d5_0\npip=23.2.1=pypi_0\nplac=1.3.5=pypi_0\nplatformdirs=3.4.0=pypi_0\npluggy=1.2.0=pypi_0\npooch=1.6.0=pypi_0\nportalocker=2.7.0=pypi_0\nprogress=1.6=pypi_0\nprompt-toolkit=3.0.38=pypi_0\nprotobuf=3.20.3=pypi_0\npsutil=5.9.5=pypi_0\nptyprocess=0.7.0=pypi_0\npure-eval=0.2.2=pypi_0\npyannote-core=5.0.0=pypi_0\npyannote-database=5.0.1=pypi_0\npyannote-metrics=3.2.1=pypi_0\npyarrow=11.0.0=pypi_0\npyasn1=0.5.0=pypi_0\n
pyasn1-modules=0.3.0=pypi_0\npybind11=2.10.4=pypi_0\npybtex=0.24.0=pypi_0\npybtex-docutils=1.0.2=pypi_0\npycparser=2.21=pyhd3eb1b0_0\npydantic=1.10.7=pypi_0\npydeprecate=0.3.1=pypi_0\npydub=0.25.1=pypi_0\npygments=2.15.1=pypi_0\npynini=2.1.5=pypi_0\npyopenssl=23.0.0=py310h06a4308_0\npyparsing=3.0.9=pypi_0\npypinyin=0.49.0=pypi_0\npypinyin-dict=0.6.0=pypi_0\npyrsistent=0.19.3=pypi_0\npysocks=1.7.1=py310h06a4308_0\npytest=7.4.0=pypi_0\npytest-runner=6.0.0=pypi_0\npython=3.10.10=h7a1cb2a_2\npython-dateutil=2.8.2=pypi_0\npython-multipart=0.0.6=pypi_0\npytorch=2.0.0=py3.10_cuda11.8_cudnn8.7.0_0\npytorch-cuda=11.8=h7e8668a_3\npytorch-lightning=1.9.4=pypi_0\npytorch-mutex=1.0=cuda\npytz=2023.3=pypi_0\npyyaml=5.4.1=pypi_0\npyzmq=25.1.0=pypi_0\nrank-bm25=0.2.2=pypi_0\nrapidfuzz=2.13.7=pypi_0\nreadline=8.2=h5eee18b_0\nregex=2023.3.23=pypi_0\nrequests=2.28.1=py310h06a4308_1\nrequests-oauthlib=1.3.1=pypi_0\nresponses=0.18.0=pypi_0\nrich=13.4.2=pypi_0\nrisparser=0.4.4=pypi_0\nrouge=1.0.0=pypi_0\nrouge-score=0.1.2=pypi_0\nrsa=4.9=pypi_0\nruamel-yaml=0.17.32=pypi_0\nruamel-yaml-clib=0.2.7=pypi_0\ns3transfer=0.6.1=pypi_0\nsacrebleu=2.3.1=pypi_0\nsacremoses=0.0.53=pypi_0\nsafetensors=0.3.1=pypi_0\nscikit-learn=1.2.1=pypi_0\nscipy=1.10.1=pypi_0\nseaborn=0.12.2=pypi_0\nsemantic-version=2.10.0=pypi_0\nsentence-transformers=2.2.2=pypi_0\nsentencepiece=0.1.96=pypi_0\nsentry-sdk=1.25.1=pypi_0\nsetproctitle=1.3.2=pypi_0\nsetuptools=65.5.1=pypi_0\nshellingham=1.5.0.post1=pypi_0\nsix=1.16.0=pyhd3eb1b0_1\nsmmap=5.0.0=pypi_0\nsniffio=1.3.0=pypi_0\nsnowballstemmer=2.2.0=pypi_0\nsortedcontainers=2.4.0=pypi_0\nsoundfile=0.12.1=pypi_0\nsoupsieve=2.4.1=pypi_0\nsox=1.4.1=pypi_0\nsoxr=0.3.5=pypi_0\nsphinx=7.0.1=pypi_0\nsphinxcontrib-applehelp=1.0.4=pypi_0\nsphinxcontrib-bibtex=2.5.0=pypi_0\nsphinxcontrib-devhelp=1.0.2=pypi_0\nsphinxcontrib-htmlhelp=2.0.1=pypi_0\nsphinxcontrib-jsmath=1.0.1=pypi_0\nsphinxcontrib-qthelp=1.0.3=pypi_0\nsphinxcontrib-serializinghtml=1.1.5=pypi_0\nsqlite=3.41.1=h5eee18b_0\n
stack-data=0.6.2=pypi_0\nstarlette=0.26.1=pypi_0\nsympy=1.11.1=py310h06a4308_0\ntabulate=0.9.0=pypi_0\ntaming-transformers=0.0.1=pypi_0\ntaming-transformers-rom1504=0.0.6=pypi_0\ntensorboard=2.12.2=pypi_0\ntensorboard-data-server=0.7.0=pypi_0\ntensorboard-plugin-wit=1.8.1=pypi_0\ntensorboardx=2.6.2.2=pypi_0\ntermcolor=2.2.0=pypi_0\ntest-tube=0.7.5=pypi_0\ntext-unidecode=1.3=pypi_0\ntextdistance=4.5.0=pypi_0\ntexterrors=0.4.4=pypi_0\nthreadpoolctl=3.1.0=pypi_0\ntk=8.6.12=h1ccaba5_0\ntokenize-rt=5.0.0=pypi_0\ntokenizers=0.13.3=pypi_0\ntoml=0.10.2=pypi_0\ntomli=2.0.1=pypi_0\ntoolz=0.12.0=pypi_0\ntorchaudio=2.0.0=py310_cu118\ntorchmetrics=0.11.4=pypi_0\ntorchtriton=2.0.0=py310\ntorchvision=0.15.0=py310_cu118\ntornado=6.3.2=pypi_0\ntqdm=4.64.1=pypi_0\ntraitlets=5.9.0=pypi_0\ntransformers=4.33.3=pypi_0\ntrl=0.7.1=pypi_0\ntweet-preprocessor=0.6.0=pypi_0\ntyped-ast=1.5.4=pypi_0\ntyper=0.9.0=pypi_0\ntyping_extensions=4.4.0=py310h06a4308_0\ntzdata=2023.3=pypi_0\nuc-micro-py=1.0.1=pypi_0\nunidecode=1.3.7=pypi_0\nurllib3=1.26.15=py310h06a4308_0\nuvicorn=0.21.1=pypi_0\nwandb=0.15.4=pypi_0\nwcwidth=0.2.6=pypi_0\nwebdataset=0.1.62=pypi_0\nwebsockets=11.0.2=pypi_0\nwerkzeug=2.3.0=pypi_0\nwget=3.2=pypi_0\nwheel=0.38.4=py310h06a4308_0\nwidgetsnbextension=4.0.7=pypi_0\nwrapt=1.15.0=pypi_0\nxxhash=3.2.0=pypi_0\nxz=5.2.10=h5eee18b_1\nyacs=0.1.8=pypi_0\nyarl=1.9.2=pypi_0\nyoutokentome=1.0.6=pypi_0\nzlib=1.2.13=h5eee18b_0\nzstd=1.5.4=hc292b87_0\n"
  },
  {
    "path": "train_ranker.py",
    "content": "import os\nimport torch\nos.environ['TOKENIZERS_PARALLELISM'] = 'false'\n\nimport argparse\nfrom datasets import DATASETS\nfrom config import *\nfrom model import *\nfrom dataloader import *\nfrom trainer import *\n\nfrom transformers import BitsAndBytesConfig\nfrom pytorch_lightning import seed_everything\nfrom model import LlamaForCausalLM\nfrom peft import (\n    LoraConfig,\n    get_peft_model,\n    get_peft_model_state_dict,\n    prepare_model_for_int8_training,\n    prepare_model_for_kbit_training,\n)\n\n\ntry:\n    os.environ['WANDB_PROJECT'] = PROJECT_NAME\nexcept:\n    print('WANDB_PROJECT not available, please set it in config.py')\n\n\ndef main(args, export_root=None):\n    seed_everything(args.seed)\n    if export_root == None:\n        export_root = EXPERIMENT_ROOT + '/' + args.llm_base_model.split('/')[-1] + '/' + args.dataset_code\n\n    train_loader, val_loader, test_loader, tokenizer, test_retrieval = dataloader_factory(args)\n    bnb_config = BitsAndBytesConfig(\n        load_in_4bit=True,\n        bnb_4bit_use_double_quant=True,\n        bnb_4bit_quant_type=\"nf4\",\n        bnb_4bit_compute_dtype=torch.bfloat16\n    )\n    model = LlamaForCausalLM.from_pretrained(\n        args.llm_base_model,\n        quantization_config=bnb_config,\n        device_map='auto',\n        cache_dir=args.llm_cache_dir,\n    )\n    model.gradient_checkpointing_enable()\n    model = prepare_model_for_kbit_training(model)\n    config = LoraConfig(\n        r=args.lora_r,\n        lora_alpha=args.lora_alpha,\n        target_modules=args.lora_target_modules,\n        lora_dropout=args.lora_dropout,\n        bias='none',\n        task_type=\"CAUSAL_LM\",\n    )\n    model = get_peft_model(model, config)\n    model.print_trainable_parameters()\n\n    model.config.use_cache = False\n    trainer = LLMTrainer(args, model, train_loader, val_loader, test_loader, tokenizer, export_root, args.use_wandb)\n    \n    trainer.train()\n    
trainer.test(test_retrieval)\n\n\nif __name__ == \"__main__\":\n    args.model_code = 'llm'\n    set_template(args)\n    main(args, export_root=None)\n"
  },
  {
    "path": "train_retriever.py",
    "content": "import os\nimport torch\nos.environ['TOKENIZERS_PARALLELISM'] = 'false'\n\nimport wandb\nimport argparse\n\nfrom config import *\nfrom model import *\nfrom dataloader import *\nfrom trainer import *\n\nfrom pytorch_lightning import seed_everything\n\ntry:\n    os.environ['WANDB_PROJECT'] = PROJECT_NAME\nexcept:\n    print('WANDB_PROJECT not available, please set it in config.py')\n\n\ndef main(args, export_root=None):\n    seed_everything(args.seed)\n    train_loader, val_loader, test_loader = dataloader_factory(args)\n    model = LRURec(args)\n    if export_root == None:\n        export_root = EXPERIMENT_ROOT + '/' + args.model_code + '/' + args.dataset_code\n    \n    trainer = LRUTrainer(args, model, train_loader, val_loader, test_loader, export_root, args.use_wandb)\n    trainer.train()\n    trainer.test()\n\n    # the next line generates val / test candidates for reranking\n    trainer.generate_candidates(os.path.join(export_root, 'retrieved.pkl'))\n\n\nif __name__ == \"__main__\":\n    args.model_code = 'lru'\n    set_template(args)\n    main(args, export_root=None)\n\n    # # searching best hyperparameters\n    # for decay in [0, 0.01]:\n    #     for dropout in [0, 0.1, 0.2, 0.3, 0.4, 0.5]:\n    #         args.weight_decay = decay\n    #         args.bert_dropout = dropout\n    #         args.bert_attn_dropout = dropout\n    #         export_root = EXPERIMENT_ROOT + '/' + args.model_code + '/' + args.dataset_code + '/' + str(decay) + '_' + str(dropout)\n    #         main(args, export_root=export_root)"
  },
  {
    "path": "trainer/__init__.py",
    "content": "from .lru import *\nfrom .llm import *\nfrom .utils import *"
  },
  {
    "path": "trainer/base.py",
    "content": "from model import *\nfrom config import *\nfrom .utils import *\nfrom .loggers import *\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom tqdm import tqdm\n\nimport json\nimport numpy as np\nfrom abc import ABCMeta\nfrom pathlib import Path\nfrom collections import OrderedDict\n\n\nclass BaseTrainer(metaclass=ABCMeta):\n    def __init__(self, args, model, train_loader, val_loader, test_loader, export_root, use_wandb=True):\n        self.args = args\n        self.device = args.device\n        self.model = model.to(self.device)\n\n        self.num_epochs = args.num_epochs\n        self.metric_ks = args.metric_ks\n        self.best_metric = args.best_metric\n        self.train_loader = train_loader\n        self.val_loader = val_loader\n        self.test_loader = test_loader\n        self.optimizer = self._create_optimizer()\n        if args.enable_lr_schedule:\n            if args.enable_lr_warmup:\n                self.lr_scheduler = self.get_linear_schedule_with_warmup(\n                    self.optimizer, args.warmup_steps, len(self.train_loader) * self.num_epochs)\n            else:\n                self.lr_scheduler = optim.lr_scheduler.StepLR(\n                    self.optimizer, step_size=args.decay_step, gamma=args.gamma)\n            \n        self.export_root = export_root\n        if not os.path.exists(self.export_root):\n            Path(self.export_root).mkdir(parents=True)\n        self.use_wandb = use_wandb\n        if use_wandb:\n            import wandb\n            wandb.init(\n                name=self.args.model_code+'_'+self.args.dataset_code,\n                project=PROJECT_NAME,\n                config=args,\n            )\n            writer = wandb\n        else:\n            from torch.utils.tensorboard import SummaryWriter\n            writer = SummaryWriter(\n                
log_dir=Path(self.export_root).joinpath('logs'),\n                comment=self.args.model_code+'_'+self.args.dataset_code,\n            )\n        self.val_loggers, self.test_loggers = self._create_loggers()\n        self.logger_service = LoggerService(\n            self.args, writer, self.val_loggers, self.test_loggers, use_wandb)\n        \n        print(args)\n\n    def train(self):\n        accum_iter = 0\n        self.exit_training = self.validate(0, accum_iter)\n        for epoch in range(self.num_epochs):\n            accum_iter = self.train_one_epoch(epoch, accum_iter)\n            if self.args.val_strategy == 'epoch':\n                self.exit_training = self.validate(epoch, accum_iter)  # val after every epoch\n            if self.exit_training:\n                print('Early stopping triggered. Exit training')\n                break\n        self.logger_service.complete()\n\n    def train_one_epoch(self, epoch, accum_iter):\n        average_meter_set = AverageMeterSet()\n        tqdm_dataloader = tqdm(self.train_loader)\n\n        for batch_idx, batch in enumerate(tqdm_dataloader):\n            self.model.train()\n            batch = self.to_device(batch)\n\n            self.optimizer.zero_grad()\n            loss = self.calculate_loss(batch)\n            loss.backward()\n            self.clip_gradients(self.args.max_grad_norm)\n            self.optimizer.step()\n            if self.args.enable_lr_schedule:\n                self.lr_scheduler.step()\n\n            average_meter_set.update('loss', loss.item())\n            tqdm_dataloader.set_description(\n                'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))\n\n            accum_iter += 1\n            if self.args.val_strategy == 'iteration' and accum_iter % self.args.val_iterations == 0:\n                self.exit_training = self.validate(epoch, accum_iter)  # val after certain iterations\n                if self.exit_training: break\n\n        return accum_iter\n\n    
def validate(self, epoch, accum_iter):\n        self.model.eval()\n        average_meter_set = AverageMeterSet()\n        with torch.no_grad():\n            tqdm_dataloader = tqdm(self.val_loader)\n            for batch_idx, batch in enumerate(tqdm_dataloader):\n                batch = self.to_device(batch)\n                metrics = self.calculate_metrics(batch, exclude_history=False)  # faster validation\n                self._update_meter_set(average_meter_set, metrics)\n                self._update_dataloader_metrics(\n                    tqdm_dataloader, average_meter_set)\n\n            log_data = {\n                'state_dict': (self._create_state_dict()),\n                'epoch': epoch+1,\n                'accum_iter': accum_iter,\n            }\n            log_data.update(average_meter_set.averages())\n        \n        return self.logger_service.log_val(log_data)  # early stopping\n\n    def test(self, epoch=-1, accum_iter=-1, save_name=None):\n        print('******************** Testing Best Model ********************')\n        best_model_dict = torch.load(os.path.join(\n            self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)\n        self.model.load_state_dict(best_model_dict)\n        self.model.eval()\n\n        average_meter_set = AverageMeterSet()\n        with torch.no_grad():\n            tqdm_dataloader = tqdm(self.test_loader)\n            for batch_idx, batch in enumerate(tqdm_dataloader):\n                batch = self.to_device(batch)\n                metrics = self.calculate_metrics(batch)\n                self._update_meter_set(average_meter_set, metrics)\n                self._update_dataloader_metrics(\n                    tqdm_dataloader, average_meter_set)\n\n            log_data = {\n                'state_dict': (self._create_state_dict()),\n                'epoch': epoch+1,\n                'accum_iter': accum_iter,\n            }\n            average_metrics = average_meter_set.averages()\n            
log_data.update(average_metrics)\n            self.logger_service.log_test(log_data)\n\n            print('******************** Testing Metrics ********************')\n            print(average_metrics)\n            file_name = 'test_metrics.json' if save_name is None else save_name\n            with open(os.path.join(self.export_root, file_name), 'w') as f:\n                json.dump(average_metrics, f, indent=4)\n        \n        return average_metrics\n    \n    def to_device(self, batch):\n        return [x.to(self.device) for x in batch]\n\n    @abstractmethod\n    def calculate_loss(self, batch):\n        pass\n    \n    @abstractmethod\n    def calculate_metrics(self, batch):\n        pass\n    \n    def clip_gradients(self, limit=1.0):\n        nn.utils.clip_grad_norm_(self.model.parameters(), limit)\n\n    def _update_meter_set(self, meter_set, metrics):\n        for k, v in metrics.items():\n            meter_set.update(k, v)\n\n    def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):\n        description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]\n                               ] + ['Recall@%d' % k for k in self.metric_ks[:3]]\n        description = 'Eval: ' + \\\n            ', '.join(s + ' {:.4f}' for s in description_metrics)\n        description = description.replace('NDCG', 'N').replace('Recall', 'R')\n        description = description.format(\n            *(meter_set[k].avg for k in description_metrics))\n        tqdm_dataloader.set_description(description)\n\n    def _create_optimizer(self):\n        args = self.args\n        param_optimizer = list(self.model.named_parameters())\n        no_decay = ['bias', 'layer_norm']\n        optimizer_grouped_parameters = [\n            {\n                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n                'weight_decay': args.weight_decay,\n            },\n            {'params': [p for n, p in param_optimizer if any(nd in n for nd in 
no_decay)], 'weight_decay': 0.},\n        ]\n        if args.optimizer.lower() == 'adamw':\n            return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)\n        elif args.optimizer.lower() == 'adam':\n            return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)\n        else:\n            raise NotImplementedError\n\n    def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):\n        def lr_lambda(current_step: int):\n            if current_step < num_warmup_steps:\n                return float(current_step) / float(max(1, num_warmup_steps))\n            return max(\n                0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))\n            )\n\n        return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n    def _create_loggers(self):\n        root = Path(self.export_root)\n        model_checkpoint = root.joinpath('models')\n\n        val_loggers, test_loggers = [], []\n        for k in self.metric_ks:\n            val_loggers.append(\n                MetricGraphPrinter(key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation', use_wandb=self.use_wandb))\n            val_loggers.append(\n                MetricGraphPrinter(key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation', use_wandb=self.use_wandb))\n            val_loggers.append(\n                MetricGraphPrinter(key='MRR@%d' % k, graph_name='MRR@%d' % k, group_name='Validation', use_wandb=self.use_wandb))\n\n        val_loggers.append(RecentModelLogger(self.args, model_checkpoint))\n        val_loggers.append(BestModelLogger(self.args, model_checkpoint, metric_key=self.best_metric))\n\n        for k in self.metric_ks:\n            test_loggers.append(\n                MetricGraphPrinter(key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Test', use_wandb=self.use_wandb))\n            
test_loggers.append(\n                MetricGraphPrinter(key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Test', use_wandb=self.use_wandb))\n            test_loggers.append(\n                MetricGraphPrinter(key='MRR@%d' % k, graph_name='MRR@%d' % k, group_name='Test', use_wandb=self.use_wandb))\n\n        return val_loggers, test_loggers\n\n    def _create_state_dict(self):\n        return {\n            STATE_DICT_KEY: self.model.state_dict(),\n            OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),\n        }"
  },
  {
    "path": "trainer/llm.py",
    "content": "from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY\nfrom .verb import ManualVerbalizer\nfrom .utils import *\nfrom .loggers import *\nfrom .base import *\n\nimport re\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\n\nimport json\nimport numpy as np\nfrom abc import *\nfrom pathlib import Path\n\nimport bitsandbytes as bnb\nfrom transformers.trainer import *\nfrom transformers import Trainer, TrainingArguments, EarlyStoppingCallback\n\n\ndef llama_collate_fn_w_truncation(llm_max_length, eval=False):\n    def llama_collate_fn(batch):\n        all_input_ids = []\n        all_attention_mask = []\n        all_labels = []\n        example_max_length = max([len(batch[idx]['input_ids']) for idx in range(len(batch))])\n        max_length = min(llm_max_length, example_max_length)\n        \n        for i in range(len(batch)):\n            input_ids = batch[i]['input_ids']\n            attention_mask = batch[i]['attention_mask']\n            labels = batch[i]['labels']\n            if len(input_ids) > max_length:\n                input_ids = input_ids[-max_length:]\n                attention_mask = attention_mask[-max_length:]\n                if not eval: labels = labels[-max_length:]\n            elif len(input_ids) < max_length:\n                padding_length = max_length - len(input_ids)\n                input_ids = [0] * padding_length + input_ids\n                attention_mask = [0] * padding_length + attention_mask\n                if not eval: labels = [-100] * padding_length + labels\n\n            if eval: assert input_ids[-1] == 13\n            else:\n                assert input_ids[-3] == 13 and input_ids[-1] == 2\n                assert labels[-3] == -100 and labels[-2] != -100\n            \n            all_input_ids.append(torch.tensor(input_ids).long())\n            all_attention_mask.append(torch.tensor(attention_mask).long())\n            
all_labels.append(torch.tensor(labels).long())\n        \n        return {\n            'input_ids': torch.vstack(all_input_ids),\n            'attention_mask': torch.vstack(all_attention_mask),\n            'labels': torch.vstack(all_labels)\n        }\n    return llama_collate_fn\n\n\ndef compute_metrics_for_ks(ks, verbalizer):\n    def compute_metrics(eval_pred):\n        logits, labels = eval_pred\n        logits = torch.tensor(logits)\n        labels = torch.tensor(labels).view(-1)\n        scores = verbalizer.process_logits(logits)\n        metrics = absolute_recall_mrr_ndcg_for_ks(scores, labels, ks)\n        return metrics\n    return compute_metrics\n\n\nclass LLMTrainer(Trainer):\n    def __init__(\n            self,\n            args,\n            model,\n            train_loader,\n            val_loader,\n            test_loader,\n            tokenizer,\n            export_root,\n            use_wandb,\n            **kwargs\n        ):\n        self.original_args = args\n        self.export_root = export_root\n        self.use_wandb = use_wandb\n        self.llm_max_text_len = args.llm_max_text_len\n        self.rerank_metric_ks = args.rerank_metric_ks\n        self.verbalizer = ManualVerbalizer(\n            tokenizer=tokenizer,\n            prefix='',\n            post_log_softmax=False,\n            classes=list(range(args.llm_negative_sample_size+1)),\n            label_words={i: chr(ord('A')+i) for i in range(args.llm_negative_sample_size+1)},\n        )\n\n        hf_args = TrainingArguments(\n            per_device_train_batch_size=args.lora_micro_batch_size,\n            gradient_accumulation_steps=args.train_batch_size//args.lora_micro_batch_size,\n            warmup_steps=args.warmup_steps,\n            num_train_epochs=args.lora_num_epochs,\n            learning_rate=args.lora_lr,\n            bf16=True,\n            logging_steps=10,\n            optim=\"paged_adamw_32bit\",\n            evaluation_strategy=\"steps\",\n            
save_strategy=\"steps\",\n            eval_steps=args.lora_val_iterations,\n            save_steps=args.lora_val_iterations,\n            output_dir=export_root,\n            save_total_limit=3,\n            load_best_model_at_end=True,\n            ddp_find_unused_parameters=None,\n            group_by_length=False,\n            report_to=\"wandb\" if use_wandb else None,\n            run_name=args.model_code+'_'+args.dataset_code if use_wandb else None,\n            metric_for_best_model=args.rerank_best_metric,\n            greater_is_better=True,\n        )\n        super().__init__(\n            model=model,\n            args=hf_args,\n            callbacks=[EarlyStoppingCallback(args.lora_early_stopping_patience)],\n            **kwargs)  # hf_args is now args\n\n        self.train_loader = train_loader\n        self.val_loader = val_loader\n        self.test_loader = test_loader\n        self.tokenizer = tokenizer\n        \n        self.train_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=False)\n        self.val_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=True)\n        self.test_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=True)\n        self.compute_metrics = compute_metrics_for_ks(self.rerank_metric_ks, self.verbalizer)\n\n        if len(self.label_names) == 0:\n            self.label_names = ['labels']  # for some reason label name is not set\n    \n    def test(self, test_retrieval):\n        average_metrics = self.predict(test_dataset=None).metrics\n        print('Ranking Performance on Subset:', average_metrics)\n        print('************************************************************')\n        with open(os.path.join(self.export_root, 'subset_metrics.json'), 'w') as f:\n                json.dump(average_metrics, f, indent=4)\n\n        print('Original Performance:', test_retrieval['original_metrics'])\n        
print('************************************************************')\n        original_size = test_retrieval['original_size']\n        retrieval_size = test_retrieval['retrieval_size']\n        \n        # weighted average of the reranked-subset metrics and the non-retrieval metrics over the full test set\n        overall_metrics = {}\n        for key in test_retrieval['non_retrieval_metrics'].keys():\n            if 'test_' + key in average_metrics:\n                overall_metrics['test_' + key] = (average_metrics['test_' + key] * retrieval_size +\n                    test_retrieval['non_retrieval_metrics'][key] * (original_size - retrieval_size)) / original_size\n        print('Overall Performance of Our Framework:', overall_metrics)\n        with open(os.path.join(self.export_root, 'overall_metrics.json'), 'w') as f:\n            json.dump(overall_metrics, f, indent=4)\n        \n        return average_metrics\n\n    def get_train_dataloader(self):\n        return self.train_loader\n    \n    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:\n        return self.val_loader\n    \n    def get_test_dataloader(self, test_dataset: Optional[Dataset] = None) -> DataLoader:\n        return self.test_loader"
  },
  {
    "path": "trainer/loggers.py",
    "content": "import os\nimport torch\nfrom abc import ABCMeta, abstractmethod\n\n\ndef save_state_dict(state_dict, path, filename):\n    torch.save(state_dict, os.path.join(path, filename))\n\n\nclass LoggerService(object):\n    def __init__(self, args, writer, val_loggers, test_loggers, use_wandb):\n        self.args = args\n        self.writer = writer\n        self.val_loggers = val_loggers if val_loggers else []\n        self.test_loggers = test_loggers if test_loggers else []\n        self.use_wandb = use_wandb\n\n    def complete(self):\n        if self.use_wandb:\n            self.writer.finish()\n        else:\n            self.writer.close()\n\n    def log_val(self, log_data):\n        criteria_met = False\n        for logger in self.val_loggers:\n            logger.log(self.writer, **log_data)\n            if self.args.early_stopping and isinstance(logger, BestModelLogger):\n                criteria_met = logger.patience_counter >= self.args.early_stopping_patience\n        return criteria_met\n    \n    def log_test(self, log_data):\n        for logger in self.test_loggers:\n            logger.log(self.writer, **log_data)\n\n\nclass AbstractBaseLogger(metaclass=ABCMeta):\n    @abstractmethod\n    def log(self, *args, **kwargs):\n        raise NotImplementedError\n\n    def complete(self, *args, **kwargs):\n        pass\n\n\nclass MetricGraphPrinter(AbstractBaseLogger):\n    def __init__(self, key, graph_name, group_name, use_wandb):\n        self.key = key\n        self.graph_label = graph_name\n        self.group_name = group_name\n        self.use_wandb = use_wandb\n        \n    def log(self, writer, *args, **kwargs):\n        if self.key in kwargs:\n            if self.use_wandb:\n                writer.log({self.group_name+'/'+self.graph_label: kwargs[self.key], 'batch': kwargs['accum_iter']})\n            else:\n                writer.add_scalar(self.group_name+'/'+ self.graph_label, kwargs[self.key], kwargs['accum_iter'])\n        else:\n       
     print('Metric {} not found...'.format(self.key))\n\n    def complete(self, writer, *args, **kwargs):\n        self.log(writer, *args, **kwargs)\n\n\nclass RecentModelLogger(AbstractBaseLogger):\n    def __init__(self, args, checkpoint_path, filename='checkpoint-recent.pth'):\n        self.args = args\n        self.checkpoint_path = checkpoint_path\n        if not os.path.exists(self.checkpoint_path):\n            self.checkpoint_path.mkdir(parents=True)\n        self.recent_epoch = None\n        self.filename = filename\n\n    def log(self, *args, **kwargs):\n        epoch = kwargs['epoch']\n\n        if self.recent_epoch != epoch:\n            self.recent_epoch = epoch\n            state_dict = kwargs['state_dict']\n            state_dict['epoch'] = kwargs['epoch']\n            save_state_dict(state_dict, self.checkpoint_path, self.filename)\n\n    def complete(self, *args, **kwargs):\n        save_state_dict(kwargs['state_dict'],\n                        self.checkpoint_path, self.filename + '.final')\n\n\nclass BestModelLogger(AbstractBaseLogger):\n    def __init__(self, args, checkpoint_path, metric_key, filename='best_acc_model.pth'):\n        self.args = args\n        self.checkpoint_path = checkpoint_path\n        if not os.path.exists(self.checkpoint_path):\n            self.checkpoint_path.mkdir(parents=True)\n\n        self.best_metric = 0.\n        self.metric_key = metric_key\n        self.filename = filename\n        self.patience_counter = 0\n\n    def log(self, *args, **kwargs):\n        current_metric = kwargs[self.metric_key]\n        if self.best_metric < current_metric:  # assumes the higher the better\n            print(\"Update Best {} Model at {}\".format(\n                self.metric_key, kwargs['epoch']))\n            self.best_metric = current_metric\n            save_state_dict(kwargs['state_dict'],\n                            self.checkpoint_path, self.filename)\n            if self.args.early_stopping:\n                
self.patience_counter = 0\n        elif self.args.early_stopping:\n            self.patience_counter += 1"
  },
  {
    "path": "trainer/lru.py",
    "content": "from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY\nfrom .utils import *\nfrom .loggers import *\nfrom .base import *\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\n\nimport json\nimport pickle\nimport numpy as np\nfrom abc import *\nfrom pathlib import Path\n\n\nclass LRUTrainer(BaseTrainer):\n    def __init__(self, args, model, train_loader, val_loader, test_loader, export_root, use_wandb):\n        super().__init__(args, model, train_loader, val_loader, test_loader, export_root, use_wandb)\n        self.ce = nn.CrossEntropyLoss(ignore_index=0)\n    \n    def calculate_loss(self, batch):\n        seqs, labels = batch\n        logits = self.model(seqs)\n        logits = logits.view(-1, logits.size(-1))\n        labels = labels.view(-1)\n        loss = self.ce(logits, labels)\n        return loss\n\n    def calculate_metrics(self, batch, exclude_history=True):\n        seqs, labels = batch\n        \n        scores = self.model(seqs)[:, -1, :]\n        B, L = seqs.shape\n        if exclude_history:\n            for i in range(L):\n                scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9\n            scores[:, 0] = -1e9  # padding\n        metrics = absolute_recall_mrr_ndcg_for_ks(scores, labels.view(-1), self.metric_ks)\n        return metrics\n    \n    def generate_candidates(self, retrieved_data_path):\n        self.model.eval()\n        val_probs, val_labels = [], []\n        test_probs, test_labels = [], []\n        with torch.no_grad():\n            print('*************** Generating Candidates for Validation Set ***************')\n            tqdm_dataloader = tqdm(self.val_loader)\n            for batch_idx, batch in enumerate(tqdm_dataloader):\n                batch = self.to_device(batch)\n                seqs, labels = batch\n        \n                scores = self.model(seqs)[:, -1, :]\n                B, L = seqs.shape\n                for i in range(L):\n     
               scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9\n                scores[:, 0] = -1e9  # padding\n                val_probs.extend(scores.tolist())\n                val_labels.extend(labels.view(-1).tolist())\n            val_metrics = absolute_recall_mrr_ndcg_for_ks(torch.tensor(val_probs), \n                                                          torch.tensor(val_labels).view(-1), self.metric_ks)\n            print(val_metrics)\n\n            print('****************** Generating Candidates for Test Set ******************')\n            tqdm_dataloader = tqdm(self.test_loader)\n            for batch_idx, batch in enumerate(tqdm_dataloader):\n                batch = self.to_device(batch)\n                seqs, labels = batch\n        \n                scores = self.model(seqs)[:, -1, :]\n                B, L = seqs.shape\n                for i in range(L):\n                    scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9\n                scores[:, 0] = -1e9  # padding\n                test_probs.extend(scores.tolist())\n                test_labels.extend(labels.view(-1).tolist())\n            test_metrics = absolute_recall_mrr_ndcg_for_ks(torch.tensor(test_probs), \n                                                           torch.tensor(test_labels).view(-1), self.metric_ks)\n            print(test_metrics)\n\n        with open(retrieved_data_path, 'wb') as f:\n            pickle.dump({'val_probs': val_probs,\n                         'val_labels': val_labels,\n                         'val_metrics': val_metrics,\n                         'test_probs': test_probs,\n                         'test_labels': test_labels,\n                         'test_metrics': test_metrics}, f)"
  },
  {
    "path": "trainer/utils.py",
    "content": "from config import *\n\nimport json\nimport os\nimport pprint as pp\nimport random\nfrom datetime import date\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch import optim as optim\n\n\ndef ndcg(scores, labels, k):\n    scores = scores.cpu()\n    labels = labels.cpu()\n    rank = (-scores).argsort(dim=1)\n    cut = rank[:, :k]\n    hits = labels.gather(1, cut)\n    position = torch.arange(2, 2+k)\n    weights = 1 / torch.log2(position.float())\n    dcg = (hits.float() * weights).sum(1)\n    idcg = torch.Tensor([weights[:min(int(n), k)].sum()\n                         for n in labels.sum(1)])\n    ndcg = dcg / idcg\n    return ndcg.mean()\n\n\ndef absolute_recall_mrr_ndcg_for_ks(scores, labels, ks):\n    metrics = {}\n    labels = F.one_hot(labels, num_classes=scores.size(1))\n    answer_count = labels.sum(1)\n\n    labels_float = labels.float()\n    rank = (-scores).argsort(dim=1)\n\n    cut = rank\n    for k in sorted(ks, reverse=True):\n        cut = cut[:, :k]\n        hits = labels_float.gather(1, cut)\n        metrics['Recall@%d' % k] = \\\n            (hits.sum(1) / torch.min(torch.Tensor([k]).to(\n                labels.device), labels.sum(1).float())).mean().cpu().item()\n        \n        metrics['MRR@%d' % k] = \\\n            (hits / torch.arange(1, k+1).unsqueeze(0).to(\n                labels.device)).sum(1).mean().cpu().item()\n\n        position = torch.arange(2, 2+k)\n        weights = 1 / torch.log2(position.float())\n        dcg = (hits * weights.to(hits.device)).sum(1)\n        idcg = torch.Tensor([weights[:min(int(n), k)].sum()\n                             for n in answer_count]).to(dcg.device)\n        ndcg = (dcg / idcg).mean()\n        metrics['NDCG@%d' % k] = ndcg.cpu().item()\n\n    return metrics\n\n\nclass AverageMeterSet(object):\n    def __init__(self, meters=None):\n        self.meters = meters if meters else {}\n\n    def 
__getitem__(self, key):\n        if key not in self.meters:\n            meter = AverageMeter()\n            meter.update(0)\n            return meter\n        return self.meters[key]\n\n    def update(self, name, value, n=1):\n        if name not in self.meters:\n            self.meters[name] = AverageMeter()\n        self.meters[name].update(value, n)\n\n    def reset(self):\n        for meter in self.meters.values():\n            meter.reset()\n\n    def values(self, format_string='{}'):\n        return {format_string.format(name): meter.val for name, meter in self.meters.items()}\n\n    def averages(self, format_string='{}'):\n        return {format_string.format(name): meter.avg for name, meter in self.meters.items()}\n\n    def sums(self, format_string='{}'):\n        return {format_string.format(name): meter.sum for name, meter in self.meters.items()}\n\n    def counts(self, format_string='{}'):\n        return {format_string.format(name): meter.count for name, meter in self.meters.items()}\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n  # weight by n so that avg stays correct when n > 1\n        self.count += n\n        self.avg = self.sum / self.count\n\n    def __format__(self, format):\n        return \"{self.val:{format}} ({self.avg:{format}})\".format(self=self, format=format)\n"
  },
  {
    "path": "trainer/verb.py",
    "content": "from abc import abstractmethod\nimport json\n\nfrom transformers.file_utils import ModelOutput\nfrom transformers.data.processors.utils import InputFeatures\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom yacs.config import CfgNode\nfrom transformers.tokenization_utils import PreTrainedTokenizer\n\nimport numpy as np\nfrom collections import namedtuple\n\nimport inspect\nfrom typing import *\n\n_VALID_TYPES = {tuple, list, str, int, float, bool, type(None)}\n\n\ndef convert_cfg_to_dict(cfg_node, key_list=[]):\n    \"\"\" Convert a config node to dictionary \"\"\"\n    if not isinstance(cfg_node, CfgNode):\n        if type(cfg_node) not in _VALID_TYPES:\n            print(\"Key {} with value {} is not a valid type; valid types: {}\".format(\n                \".\".join(key_list), type(cfg_node), _VALID_TYPES), )\n        return cfg_node\n    else:\n        cfg_dict = dict(cfg_node)\n        for k, v in cfg_dict.items():\n            cfg_dict[k] = convert_cfg_to_dict(v, key_list + [k])\n        return cfg_dict\n\n\ndef signature(f):\n    r\"\"\"Get the function f 's input arguments. 
A useful gadget\n    when some function slot might be instantiated into multiple functions.\n    \n    Args:\n        f (:obj:`function`) : the function to get the input arguments.\n    \n    Returns:\n        namedtuple : of args, defaults, varargs, keywords, respectively.\n\n    \"\"\"\n    sig = inspect.signature(f)\n    args = [\n        p.name for p in sig.parameters.values()\n        if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD\n    ]\n    varargs = [\n        p.name for p in sig.parameters.values()\n        if p.kind == inspect.Parameter.VAR_POSITIONAL\n    ]\n    varargs = varargs[0] if varargs else None\n    keywords = [\n        p.name for p in sig.parameters.values()\n        if p.kind == inspect.Parameter.VAR_KEYWORD\n    ]\n    keywords = keywords[0] if keywords else None\n    defaults = [\n        p.default for p in sig.parameters.values()\n        if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD\n        and p.default is not p.empty\n    ] or None\n    argspec = namedtuple('Signature', ['args', 'defaults',\n                                        'varargs', 'keywords'])\n    return argspec(args, defaults, varargs, keywords)\n\n\nclass Verbalizer(nn.Module):\n    r'''\n    Base class for all the verbalizers.\n\n    Args:\n        tokenizer (:obj:`PreTrainedTokenizer`): A tokenizer that specifies the vocabulary and the tokenization strategy.\n        classes (:obj:`Sequence[str]`): A sequence of classes that need to be projected.\n    '''\n    def __init__(self,\n                 tokenizer: Optional[PreTrainedTokenizer] = None,\n                 classes: Optional[Sequence[str]] = None,\n                 num_classes: Optional[int] = None,\n                ):\n        super().__init__()\n        self.tokenizer = tokenizer\n        self.classes = classes\n        if classes is not None and num_classes is not None:\n            assert len(classes) == num_classes, \"len(classes) != num_classes, check your config.\"\n            self.num_classes = 
num_classes\n        elif num_classes is not None:\n            self.num_classes = num_classes\n        elif classes is not None:\n            self.num_classes = len(classes)\n        else:\n            self.num_classes = None\n            # raise AttributeError(\"No able to configure num_classes\")\n        self._in_on_label_words_set = False\n\n    @property\n    def label_words(self,):\n        r'''\n        Label words means the words in the vocabulary projected by the labels.\n        E.g. if we want to establish a projection in sentiment classification: positive :math:`\\rightarrow` {`wonderful`, `good`},\n        in this case, `wonderful` and `good` are label words.\n        '''\n        if not hasattr(self, \"_label_words\"):\n            raise RuntimeError(\"label words haven't been set.\")\n        return self._label_words\n\n    @label_words.setter\n    def label_words(self, label_words):\n        if label_words is None:\n            return\n        self._label_words = self._match_label_words_to_label_ids(label_words)\n        if not self._in_on_label_words_set:\n            self.safe_on_label_words_set()\n\n    def _match_label_words_to_label_ids(self, label_words): # TODO newly add function after docs written # TODO rename this function\n        \"\"\"\n        sort label words dict of verbalizer to match the label order of the classes\n        \"\"\"\n        if isinstance(label_words, dict):\n            if self.classes is None:\n                raise ValueError(\"\"\"\n                classes attribute of the Verbalizer should be set since your given label words is a dict.\n                Since we will match the label word with respect to class A, to A's index in classes\n                \"\"\")\n            if set(label_words.keys()) != set(self.classes):\n                raise ValueError(\"name of classes in verbalizer are different from those of dataset\")\n            label_words = [ # sort the dict to match dataset\n                
label_words[c]\n                for c in self.classes\n            ] # length: label_size of the whole task\n        elif isinstance(label_words, list) or isinstance(label_words, tuple):\n            pass\n        else:\n            raise ValueError(\"Verbalizer label words must be list, tuple or dict\")\n        return label_words\n\n    def safe_on_label_words_set(self,):\n        self._in_on_label_words_set = True\n        self.on_label_words_set()\n        self._in_on_label_words_set = False\n\n    def on_label_words_set(self,):\n        r\"\"\"A hook to do something when textual label words were set.\n        \"\"\"\n        pass\n\n    @property\n    def vocab(self,) -> Dict:\n        if not hasattr(self, '_vocab'):\n            self._vocab = self.tokenizer.convert_ids_to_tokens(np.arange(self.vocab_size).tolist())\n        return self._vocab\n\n    @property\n    def vocab_size(self,) -> int:\n        return self.tokenizer.vocab_size\n\n    @abstractmethod\n    def generate_parameters(self, **kwargs) -> List:\n        r\"\"\"\n        The verbalizer can be seen as an extra layer on top of the original\n        pre-trained models. 
In a manual verbalizer, it is a fixed one-hot vector of dimension\n        ``vocab_size``, with the position of the label word being 1 and 0 everywhere else.\n        In other situations, the parameters may be a continuous vector over the\n        vocab, with each dimension representing a weight of that token.\n        Moreover, the parameters may be set to trainable to allow label word selection.\n\n        Therefore, this function serves as an abstract method for generating the parameters\n        of the verbalizer, and must be implemented in any derived class.\n\n        Note that the parameters need to be registered as a part of the PyTorch module.\n        This can be achieved by wrapping a tensor using ``nn.Parameter()``.\n        \"\"\"\n        raise NotImplementedError\n\n    def register_calibrate_logits(self, logits: torch.Tensor):\n        r\"\"\"\n        This function aims to register logits that need to be calibrated, and detach the original logits from the current graph.\n        \"\"\"\n        if logits.requires_grad:\n            logits = logits.detach()\n        self._calibrate_logits = logits\n\n    def process_outputs(self,\n                       outputs: torch.Tensor,\n                       batch: Union[Dict, InputFeatures],\n                       **kwargs):\n        r\"\"\"By default, the verbalizer will process the logits of the PLM's\n        output.\n\n        Args:\n            outputs (:obj:`torch.Tensor`): The current logits generated by pre-trained language models.\n            batch (:obj:`Union[Dict, InputFeatures]`): The input features of the data.\n        \"\"\"\n\n        return self.process_logits(outputs, batch=batch, **kwargs)\n\n    def gather_outputs(self, outputs: ModelOutput):\n        r\"\"\" Retrieve useful output for the verbalizer from the whole model output.\n        By default, it will only retrieve the logits.\n\n        Args:\n            outputs (:obj:`ModelOutput`) The output from the pretrained language 
model.\n\n        Return:\n            :obj:`torch.Tensor` The gathered output, should be of shape (``batch_size``,\n            ``seq_len``, ``any``)\n        \"\"\"\n        return outputs.logits\n\n    @staticmethod\n    def aggregate(label_words_logits: torch.Tensor) -> torch.Tensor:\n        r\"\"\" Aggregate the logits of multiple label words into the label's logits.\n        Basic aggregator: the mean of each label's word logits becomes the label's logits.\n        Can be re-implemented in advanced verbalizers.\n\n        Args:\n            label_words_logits (:obj:`torch.Tensor`): The logits of the label words only.\n\n        Return:\n            :obj:`torch.Tensor`: The final logits calculated by the label words.\n        \"\"\"\n        if label_words_logits.dim()>2:\n            return label_words_logits.mean(dim=-1)\n        else:\n            return label_words_logits\n\n\n    def normalize(self, logits: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        Given logits regarding the entire vocab, calculate the probs over the label words set by softmax.\n\n        Args:\n            logits(:obj:`Tensor`): The logits of the entire vocab.\n\n        Returns:\n            :obj:`Tensor`: The probability distribution over the label words set.\n        \"\"\"\n        batch_size = logits.shape[0]\n        return F.softmax(logits.reshape(batch_size, -1), dim=-1).reshape(*logits.shape)\n\n    @abstractmethod\n    def project(self,\n                logits: torch.Tensor,\n                **kwargs) -> torch.Tensor:\n        r\"\"\"This method receives input logits of shape ``[batch_size, vocab_size]``, and uses the\n        parameters of this verbalizer to project the logits over the entire vocab into the\n        logits of label words.\n\n        Args:\n            logits (:obj:`Tensor`): The logits over the entire vocab generated by the pre-trained language model with shape [``batch_size``, ``vocab_size``]\n\n        Returns:\n            :obj:`Tensor`: The 
normalized probs (sum to 1) of each label.\n        \"\"\"\n        raise NotImplementedError\n\n    def handle_multi_token(self, label_words_logits, mask):\n        r\"\"\"\n        Support multiple methods to handle the multiple tokens produced by the tokenizer.\n        We suggest using 'first' or 'max' if some parts of the tokenization are not meaningful.\n        Can broadcast to 3-d tensor.\n\n        Args:\n            label_words_logits (:obj:`torch.Tensor`):\n\n        Returns:\n            :obj:`torch.Tensor`\n        \"\"\"\n        if self.multi_token_handler == \"first\":\n            label_words_logits = label_words_logits.select(dim=-1, index=0)\n        elif self.multi_token_handler == \"max\":\n            label_words_logits = label_words_logits - 1000*(1-mask.unsqueeze(0))\n            label_words_logits = label_words_logits.max(dim=-1).values\n        elif self.multi_token_handler == \"mean\":\n            label_words_logits = (label_words_logits*mask.unsqueeze(0)).sum(dim=-1)/(mask.unsqueeze(0).sum(dim=-1)+1e-15)\n        else:\n            raise ValueError(\"multi_token_handler {} not configured\".format(self.multi_token_handler))\n        return label_words_logits\n\n    @classmethod\n    def from_config(cls,\n                    config: CfgNode,\n                    **kwargs):\n        r\"\"\"Load a verbalizer from the verbalizer's configuration node.\n\n        Args:\n            config (:obj:`CfgNode`): the sub-configuration of verbalizer, i.e. 
``config[config.verbalizer]``\n                        if config is a global config node.\n            kwargs: Other kwargs that might be used to initialize the verbalizer.\n                    The actual value should match the arguments of the ``__init__`` function.\n        \"\"\"\n\n        init_args = signature(cls.__init__).args\n        _init_dict = {**convert_cfg_to_dict(config), **kwargs} if config is not None else kwargs\n        init_dict = {key: _init_dict[key] for key in _init_dict if key in init_args}\n        verbalizer = cls(**init_dict)\n        if hasattr(verbalizer, \"from_file\"):\n            if not hasattr(config, \"file_path\"):\n                pass\n            else:\n                if (not hasattr(config, \"label_words\") or config.label_words is None) and config.file_path is not None:\n                    if config.choice is None:\n                        config.choice = 0\n                    verbalizer.from_file(config.file_path, config.choice)\n                elif (hasattr(config, \"label_words\") and config.label_words is not None) and config.file_path is not None:\n                    raise RuntimeError(\"Label words can't be set from both `label_words` and `file_path`.\")\n        return verbalizer\n\n    def from_file(self,\n                  path: str,\n                  choice: Optional[int] = 0 ):\n        r\"\"\"Load the predefined label words from a verbalizer file.\n        Currently supports three types of file format:\n        1. a .jsonl or .json file containing a single verbalizer\n        in dict format.\n        2. a .jsonl or .json file containing a list of verbalizers in dict format.\n        3. a .txt or .csv file in which the label words of each class are listed on one line,\n        separated by commas. 
Begin a new verbalizer by an empty line.\n        This format is recommended when you don't know the name of each class.\n\n        The details of verbalizer format can be seen in :ref:`How_to_write_a_verbalizer`.\n\n        Args:\n            path (:obj:`str`): The path of the local template file.\n            choice (:obj:`int`): The choice of verbalizer in a file containing\n                             multiple verbalizers.\n\n        Returns:\n            Template : `self` object\n        \"\"\"\n        if path.endswith(\".txt\") or path.endswith(\".csv\"):\n            with open(path, 'r') as f:\n                lines = f.readlines()\n                label_words_all = []\n                label_words_single_group = []\n                for line in lines:\n                    line = line.strip().strip(\" \")\n                    if line == \"\":\n                        if len(label_words_single_group)>0:\n                            label_words_all.append(label_words_single_group)\n                        label_words_single_group = []\n                    else:\n                        label_words_single_group.append(line)\n                if len(label_words_single_group) > 0: # if no empty line in the last\n                    label_words_all.append(label_words_single_group)\n                if choice >= len(label_words_all):\n                    raise RuntimeError(\"choice {} exceed the number of verbalizers {}\"\n                                .format(choice, len(label_words_all)))\n\n                label_words = label_words_all[choice]\n                label_words = [label_words_per_label.strip().split(\",\") \\\n                            for label_words_per_label in label_words]\n\n        elif path.endswith(\".jsonl\") or path.endswith(\".json\"):\n            with open(path, \"r\") as f:\n                label_words_all = json.load(f)\n                # if it is a file containing multiple verbalizers\n                if isinstance(label_words_all, 
list):\n                    if choice >= len(label_words_all):\n                        raise RuntimeError(\"choice {} exceed the number of verbalizers {}\"\n                                .format(choice, len(label_words_all)))\n                    label_words = label_words_all[choice]\n                elif isinstance(label_words_all, dict):\n                    label_words = label_words_all\n                    if choice>0:\n                        print(\"Choice of verbalizer is 1, but the file  \\\n                        only contains one verbalizer.\")\n\n        self.label_words = label_words\n        if self.num_classes is not None:\n            num_classes = len(self.label_words)\n            assert num_classes==self.num_classes, 'number of classes in the verbalizer file\\\n                                            does not match the predefined num_classes.'\n        return self\n\n\nclass ManualVerbalizer(Verbalizer):\n    r\"\"\"\n    The basic manually defined verbalizer class, this class is inherited from the :obj:`Verbalizer` class.\n\n    Args:\n        tokenizer (:obj:`PreTrainedTokenizer`): The tokenizer of the current pre-trained model to point out the vocabulary.\n        classes (:obj:`List[Any]`): The classes (or labels) of the current task.\n        label_words (:obj:`Union[List[str], List[List[str]], Dict[List[str]]]`, optional): The label words that are projected by the labels.\n        prefix (:obj:`str`, optional): The prefix string of the verbalizer (used in PLMs like RoBERTa, which is sensitive to prefix space)\n        multi_token_handler (:obj:`str`, optional): The handling strategy for multiple tokens produced by the tokenizer.\n        post_log_softmax (:obj:`bool`, optional): Whether to apply log softmax post processing on label_logits. 
Defaults to True.\n    \"\"\"\n    def __init__(self,\n                 tokenizer: PreTrainedTokenizer,\n                 classes: Optional[List] = None,\n                 num_classes: Optional[int] = None,\n                 label_words: Optional[Union[Sequence[str], Mapping[str, str]]] = None,\n                 prefix: Optional[str] = \" \",\n                 multi_token_handler: Optional[str] = \"first\",\n                 post_log_softmax: Optional[bool] = True,\n                ):\n        super().__init__(tokenizer=tokenizer, num_classes=num_classes, classes=classes)\n        self.prefix = prefix\n        self.multi_token_handler = multi_token_handler\n        self.label_words = label_words\n        self.post_log_softmax = post_log_softmax\n\n    def on_label_words_set(self):\n        super().on_label_words_set()\n        self.label_words = self.add_prefix(self.label_words, self.prefix)\n\n        # TODO should the Verbalizer base class have a label_words property and setter?\n        # it doesn't have a label_words init argument or a label words from_file option at all\n\n        self.generate_parameters()\n\n    @staticmethod\n    def add_prefix(label_words, prefix):\n        r\"\"\"Add prefix to label words. 
For example, if a label word is in the middle of a template,\n        the prefix should be ``' '``.\n\n        Args:\n            label_words (:obj:`Union[Sequence[str], Mapping[str, str]]`, optional): The label words that are projected by the labels.\n            prefix (:obj:`str`, optional): The prefix string of the verbalizer.\n\n        Returns:\n            :obj:`Sequence[str]`: New label words with prefix.\n        \"\"\"\n        new_label_words = []\n        if isinstance(label_words[0], str):\n            label_words = [[w] for w in label_words]  # wrap it into a list of lists of label words.\n\n        for label_words_per_label in label_words:\n            new_label_words_per_label = []\n            for word in label_words_per_label:\n                if word.startswith(\"<!>\"):\n                    new_label_words_per_label.append(word.split(\"<!>\")[1])\n                else:\n                    new_label_words_per_label.append(prefix + word)\n            new_label_words.append(new_label_words_per_label)\n        return new_label_words\n\n    def generate_parameters(self) -> List:\n        r\"\"\"In the basic manual verbalizer, the parameters are generated from label words directly.\n        In this implementation, the label_words should not be tokenized into more than one token.\n        \"\"\"\n        all_ids = []\n        for words_per_label in self.label_words:\n            ids_per_label = []\n            for word in words_per_label:\n                ids = self.tokenizer.encode(word, add_special_tokens=False)\n                ids_per_label.append(ids)\n            all_ids.append(ids_per_label)\n\n        max_len  = max([max([len(ids) for ids in ids_per_label]) for ids_per_label in all_ids])\n        max_num_label_words = max([len(ids_per_label) for ids_per_label in all_ids])\n        words_ids_mask = [[[1]*len(ids) + [0]*(max_len-len(ids)) for ids in ids_per_label]\n                    
         + [[0]*max_len]*(max_num_label_words-len(ids_per_label))\n                             for ids_per_label in all_ids]\n        words_ids = [[ids + [0]*(max_len-len(ids)) for ids in ids_per_label]\n                             + [[0]*max_len]*(max_num_label_words-len(ids_per_label))\n                             for ids_per_label in all_ids]\n\n        words_ids_tensor = torch.tensor(words_ids)\n        words_ids_mask = torch.tensor(words_ids_mask)\n        self.label_words_ids = nn.Parameter(words_ids_tensor, requires_grad=False)\n        self.words_ids_mask = nn.Parameter(words_ids_mask, requires_grad=False) # A 3-d mask\n        self.label_words_mask = nn.Parameter(torch.clamp(words_ids_mask.sum(dim=-1), max=1), requires_grad=False)\n\n    def project(self,\n                logits: torch.Tensor,\n                **kwargs,\n                ) -> torch.Tensor:\n        r\"\"\"\n        Project the vocabulary logits onto the label words. The returned values are\n        unnormalized logits; padded label word slots are pushed down by a large negative offset.\n\n        Args:\n            logits (:obj:`torch.Tensor`): The original logits over the vocabulary.\n\n        Returns:\n            :obj:`torch.Tensor`: The logits of the label words.\n        \"\"\"\n\n        label_words_logits = logits[:, self.label_words_ids]\n        label_words_logits = self.handle_multi_token(label_words_logits, self.words_ids_mask)\n        label_words_logits -= 10000*(1-self.label_words_mask)\n        return label_words_logits\n\n    def process_logits(self, logits: torch.Tensor, **kwargs):\n        r\"\"\"The full pipeline for processing the original logits over the vocabulary, in four steps:\n\n        (1) Project the logits into logits of label words\n\n        if self.post_log_softmax is True:\n\n            (2) Normalize over all label words\n\n            (3) Calibrate (optional)\n\n        (4) Aggregate (for multiple label words)\n\n        Args:\n            logits (:obj:`torch.Tensor`): The original logits.\n\n        Returns:\n            
(:obj:`torch.Tensor`): The final processed logits over the labels (classes).\n        \"\"\"\n        # project\n        label_words_logits = self.project(logits, **kwargs)  # Output: (batch_size, num_classes) or (batch_size, num_classes, num_label_words_per_label)\n\n\n        if self.post_log_softmax:\n            # normalize\n            label_words_probs = self.normalize(label_words_logits)\n\n            # calibrate\n            if hasattr(self, \"_calibrate_logits\") and self._calibrate_logits is not None:\n                label_words_probs = self.calibrate(label_words_probs=label_words_probs)\n\n            # convert to logits\n            label_words_logits = torch.log(label_words_probs+1e-15)\n\n        # aggregate\n        label_logits = self.aggregate(label_words_logits)\n        return label_logits\n\n    def normalize(self, logits: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Given the logits of the label words, return the probabilities normalized over the label word set.\n\n        Args:\n            logits (:obj:`Tensor`): The logits of the label words.\n\n        Returns:\n            :obj:`Tensor`: The probabilities over the label word set, with the same shape as the input.\n\n        \"\"\"\n        batch_size = logits.shape[0]\n        return F.softmax(logits.reshape(batch_size, -1), dim=-1).reshape(*logits.shape)\n\n\n    def aggregate(self, label_words_logits: torch.Tensor) -> torch.Tensor:\n        r\"\"\"Average the logits of the label words of each label, using the label word mask as weights.\n\n        Args:\n            label_words_logits(:obj:`torch.Tensor`): The logits of the label words.\n\n        Returns:\n            :obj:`torch.Tensor`: The aggregated logits from the label words.\n        \"\"\"\n        label_words_logits = (label_words_logits * self.label_words_mask).sum(-1)/self.label_words_mask.sum(-1)\n        return label_words_logits\n\n    def calibrate(self, label_words_probs: torch.Tensor, **kwargs) -> torch.Tensor:\n        r\"\"\"Calibrate the label word probabilities by dividing them by the probabilities\n        obtained from ``self._calibrate_logits``, then renormalize.\n\n        Args:\n            label_words_probs 
(:obj:`torch.Tensor`): The probability distribution of the label words, with shape [``batch_size``, ``num_classes``, ``num_label_words_per_class``]\n\n        Returns:\n            :obj:`torch.Tensor`: The calibrated probability of label words.\n        \"\"\"\n        shape = label_words_probs.shape\n        assert self._calibrate_logits.dim() == 1, \"self._calibrate_logits is not a 1-d tensor\"\n        calibrate_label_words_probs = self.normalize(self.project(self._calibrate_logits.unsqueeze(0), **kwargs))\n        assert calibrate_label_words_probs.shape[1:] == label_words_probs.shape[1:] \\\n             and calibrate_label_words_probs.shape[0]==1, \"shape mismatch\"\n        label_words_probs /= (calibrate_label_words_probs+1e-15)\n        # normalize # TODO Test the performance\n        norm = label_words_probs.reshape(shape[0], -1).sum(dim=-1,keepdim=True) # TODO Test the performance of detaching()\n        label_words_probs = label_words_probs.reshape(shape[0], -1) / norm\n        label_words_probs = label_words_probs.reshape(*shape)\n        return label_words_probs"
  }
]