Repository: Yueeeeeeee/LlamaRec Branch: main Commit: 48b288b23197 Files: 32 Total size: 159.5 KB Directory structure: gitextract_v751wz7n/ ├── .gitignore ├── README.md ├── config.py ├── dataloader/ │ ├── __init__.py │ ├── base.py │ ├── llm.py │ ├── lru.py │ ├── templates/ │ │ ├── README.md │ │ ├── alpaca.json │ │ ├── alpaca_legacy.json │ │ ├── alpaca_short.json │ │ └── vigogne.json │ └── utils.py ├── datasets/ │ ├── __init__.py │ ├── base.py │ ├── beauty.py │ ├── games.py │ ├── ml_100k.py │ └── utils.py ├── model/ │ ├── __init__.py │ ├── llm.py │ └── lru.py ├── requirements.txt ├── train_ranker.py ├── train_retriever.py └── trainer/ ├── __init__.py ├── base.py ├── llm.py ├── loggers.py ├── lru.py ├── utils.py └── verb.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.pyc *.p *.pt *.pth .DS_Store /.ipynb_checkpoints/* /.vscode/* /wandb/* /data/* /retrieved/* /experiments/* /archive/* ================================================ FILE: README.md ================================================ # LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking This repository is the PyTorch implementation for the PGAI@CIKM 2023 paper **LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking [[Paper](https://arxiv.org/abs/2311.02089)]**. We propose a two-stage framework using large language models for ranking-based recommendation (LlamaRec). In particular, we use small-scale sequential recommenders to retrieve candidates based on the user interaction history. Then, both history and retrieved items are fed to the LLM in text via a carefully designed prompt template. Instead of generating next-item titles, we adopt a verbalizer-based approach that transforms output logits into probability distributions over the candidate items.
Therefore, LlamaRec can efficiently rank items without generating long text and achieve superior results in terms of both recommendation accuracy and efficiency. ## Requirements PyTorch, transformers, peft, bitsandbytes, etc. For our detailed running environment, see requirements.txt. ## How to run LlamaRec The command below starts the training of the retriever model LRURec ```bash python train_retriever.py ``` You can set additional arguments like weight_decay to change the hyperparameters. Upon the command, you will be prompted to select a dataset from ML-100k, Beauty and Games. Once training is finished, evaluation is automatically performed with the best retriever model. Then, run the following command to train the ranker model based on Llama 2 ```bash python train_ranker.py --llm_retrieved_path PATH_TO_RETRIEVER ``` Please replace PATH_TO_RETRIEVER with the retriever path from the previous step. To run this command, you will need access to meta-llama/Llama-2-7b-hf on the HF hub. Similarly, evaluation is performed after training is finished. All weights and results are saved under ./experiments. ## Performance The table below reports our main performance results, with best results marked in bold and second best results underlined. For training and evaluation details, please refer to our paper.
## Citation Please consider citing the following papers if you use our methods in your research: ``` @article{yue2023linear, title={Linear Recurrent Units for Sequential Recommendation}, author={Yue, Zhenrui and Wang, Yueqi and He, Zhankui and Zeng, Huimin and McAuley, Julian and Wang, Dong}, journal={arXiv preprint arXiv:2310.02367}, year={2023} } @article{yue2023llamarec, title={LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking}, author={Yue, Zhenrui and Rabhi, Sara and Moreira, Gabriel de Souza Pereira and Wang, Dong and Oldridge, Even}, journal={arXiv preprint arXiv:2311.02089}, year={2023} } ``` ================================================ FILE: config.py ================================================ import numpy as np import random import torch import argparse RAW_DATASET_ROOT_FOLDER = 'data' EXPERIMENT_ROOT = 'experiments' STATE_DICT_KEY = 'model_state_dict' OPTIMIZER_STATE_DICT_KEY = 'optimizer_state_dict' PROJECT_NAME = 'llmrec' def set_template(args): if args.dataset_code == None: print('******************** Dataset Selection ********************') dataset_code = {'1': 'ml-100k', 'b': 'beauty', 'g': 'games'} args.dataset_code = dataset_code[input('Input 1 for ml-100k, b for beauty and g for games: ')] if args.dataset_code == 'ml-100k': args.bert_max_len = 200 else: args.bert_max_len = 50 if 'llm' in args.model_code: batch = 16 if args.dataset_code == 'ml-100k' else 12 args.lora_micro_batch_size = batch else: batch = 16 if args.dataset_code == 'ml-100k' else 64 args.train_batch_size = batch args.val_batch_size = batch args.test_batch_size = batch if torch.cuda.is_available(): args.device = 'cuda' else: args.device = 'cpu' args.optimizer = 'AdamW' args.lr = 0.001 args.weight_decay = 0.01 args.enable_lr_schedule = False args.decay_step = 10000 args.gamma = 1. 
args.enable_lr_warmup = False args.warmup_steps = 100 args.metric_ks = [1, 5, 10, 20, 50] args.rerank_metric_ks = [1, 5, 10] args.best_metric = 'Recall@10' args.rerank_best_metric = 'NDCG@10' args.bert_num_blocks = 2 args.bert_num_heads = 2 args.bert_head_size = None parser = argparse.ArgumentParser() ################ # Dataset ################ parser.add_argument('--dataset_code', type=str, default=None) parser.add_argument('--min_rating', type=int, default=0) parser.add_argument('--min_uc', type=int, default=5) parser.add_argument('--min_sc', type=int, default=5) parser.add_argument('--seed', type=int, default=42) ################ # Dataloader ################ parser.add_argument('--train_batch_size', type=int, default=64) parser.add_argument('--val_batch_size', type=int, default=64) parser.add_argument('--test_batch_size', type=int, default=64) parser.add_argument('--num_workers', type=int, default=8) parser.add_argument('--sliding_window_size', type=float, default=1.0) parser.add_argument('--negative_sample_size', type=int, default=10) ################ # Trainer ################ # optimization # parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda']) parser.add_argument('--num_epochs', type=int, default=500) parser.add_argument('--optimizer', type=str, default='AdamW', choices=['AdamW', 'Adam']) parser.add_argument('--weight_decay', type=float, default=None) parser.add_argument('--adam_epsilon', type=float, default=1e-9) parser.add_argument('--momentum', type=float, default=None) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--max_grad_norm', type=float, default=5.0) parser.add_argument('--enable_lr_schedule', type=bool, default=True) parser.add_argument('--decay_step', type=int, default=10000) parser.add_argument('--gamma', type=float, default=1) parser.add_argument('--enable_lr_warmup', type=bool, default=True) parser.add_argument('--warmup_steps', type=int, default=100) # evaluation # 
parser.add_argument('--val_strategy', type=str, default='iteration', choices=['epoch', 'iteration']) parser.add_argument('--val_iterations', type=int, default=500) # only for iteration val_strategy parser.add_argument('--early_stopping', type=bool, default=True) parser.add_argument('--early_stopping_patience', type=int, default=20) parser.add_argument('--metric_ks', nargs='+', type=int, default=[1, 5, 10, 20, 50]) parser.add_argument('--rerank_metric_ks', nargs='+', type=int, default=[1, 5, 10]) parser.add_argument('--best_metric', type=str, default='Recall@10') parser.add_argument('--rerank_best_metric', type=str, default='NDCG@10') parser.add_argument('--use_wandb', type=bool, default=False) ################ # Retriever Model ################ parser.add_argument('--model_code', type=str, default=None) parser.add_argument('--bert_max_len', type=int, default=50) parser.add_argument('--bert_hidden_units', type=int, default=64) parser.add_argument('--bert_num_blocks', type=int, default=2) parser.add_argument('--bert_num_heads', type=int, default=2) parser.add_argument('--bert_head_size', type=int, default=32) parser.add_argument('--bert_dropout', type=float, default=0.2) parser.add_argument('--bert_attn_dropout', type=float, default=0.2) parser.add_argument('--bert_mask_prob', type=float, default=0.25) ################ # LLM Model ################ parser.add_argument('--llm_base_model', type=str, default='meta-llama/Llama-2-7b-hf') parser.add_argument('--llm_base_tokenizer', type=str, default='meta-llama/Llama-2-7b-hf') parser.add_argument('--llm_max_title_len', type=int, default=32) parser.add_argument('--llm_max_text_len', type=int, default=1536) parser.add_argument('--llm_max_history', type=int, default=20) parser.add_argument('--llm_train_on_inputs', type=bool, default=False) parser.add_argument('--llm_negative_sample_size', type=int, default=19) # 19 negative & 1 positive parser.add_argument('--llm_system_template', type=str, # instruction default="Given user 
history in chronological order, recommend an item from the candidate pool with its index letter.") parser.add_argument('--llm_input_template', type=str, \ default='User history: {}; \n Candidate pool: {}') parser.add_argument('--llm_load_in_4bit', type=bool, default=True) parser.add_argument('--llm_retrieved_path', type=str, default=None) parser.add_argument('--llm_cache_dir', type=str, default=None) ################ # Lora ################ parser.add_argument('--lora_r', type=int, default=8) parser.add_argument('--lora_alpha', type=int, default=32) parser.add_argument('--lora_dropout', type=float, default=0.05) parser.add_argument('--lora_target_modules', type=list, default=['q_proj', 'v_proj']) parser.add_argument('--lora_num_epochs', type=int, default=1) parser.add_argument('--lora_val_iterations', type=int, default=100) parser.add_argument('--lora_early_stopping_patience', type=int, default=20) parser.add_argument('--lora_lr', type=float, default=1e-4) parser.add_argument('--lora_micro_batch_size', type=int, default=16) ################ args = parser.parse_args() ================================================ FILE: dataloader/__init__.py ================================================ from datasets import dataset_factory from .lru import * from .llm import * from .utils import * def dataloader_factory(args): dataset = dataset_factory(args) if args.model_code == 'lru': dataloader = LRUDataloader(args, dataset) elif args.model_code == 'llm': dataloader = LLMDataloader(args, dataset) train, val, test = dataloader.get_pytorch_dataloaders() if 'llm' in args.model_code: tokenizer = dataloader.tokenizer test_retrieval = dataloader.test_retrieval return train, val, test, tokenizer, test_retrieval else: return train, val, test def test_subset_dataloader_loader(args): dataset = dataset_factory(args) if args.model_code == 'lru': dataloader = LRUDataloader(args, dataset) elif args.model_code == 'llm': dataloader = LLMDataloader(args, dataset) return 
dataloader.get_pytorch_test_subset_dataloader() ================================================ FILE: dataloader/base.py ================================================ from abc import * import random class AbstractDataloader(metaclass=ABCMeta): def __init__(self, args, dataset): self.args = args self.save_folder = dataset._get_preprocessed_folder_path() dataset = dataset.load_dataset() self.train = dataset['train'] self.val = dataset['val'] self.test = dataset['test'] self.meta = dataset['meta'] self.umap = dataset['umap'] self.smap = dataset['smap'] self.user_count = len(self.umap) self.item_count = len(self.smap) @classmethod @abstractmethod def code(cls): pass @abstractmethod def get_pytorch_dataloaders(self): pass ================================================ FILE: dataloader/llm.py ================================================ from .base import AbstractDataloader from .utils import Prompter import torch import random import numpy as np import torch.utils.data as data_utils import os import pickle import transformers from transformers import AutoTokenizer from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT from trainer import absolute_recall_mrr_ndcg_for_ks def worker_init_fn(worker_id): random.seed(np.random.get_state()[1][0] + worker_id) np.random.seed(np.random.get_state()[1][0] + worker_id) # the following prompting is based on alpaca def generate_and_tokenize_eval(args, data_point, tokenizer, prompter): in_prompt = prompter.generate_prompt(data_point["system"], data_point["input"]) tokenized_full_prompt = tokenizer(in_prompt, truncation=True, max_length=args.llm_max_text_len, padding=False, return_tensors=None) tokenized_full_prompt["labels"] = ord(data_point["output"]) - ord('A') return tokenized_full_prompt def generate_and_tokenize_train(args, data_point, tokenizer, prompter): def tokenize(prompt, add_eos_token=True): result = tokenizer(prompt, truncation=True, max_length=args.llm_max_text_len, padding=False, 
return_tensors=None) if (result["input_ids"][-1] != tokenizer.eos_token_id and add_eos_token): result["input_ids"].append(tokenizer.eos_token_id) result["attention_mask"].append(1) result["labels"] = result["input_ids"].copy() return result full_prompt = prompter.generate_prompt(data_point["system"], data_point["input"], data_point["output"]) tokenized_full_prompt = tokenize(full_prompt, add_eos_token=True) if not args.llm_train_on_inputs: tokenized_full_prompt["labels"][:-2] = [-100] * len(tokenized_full_prompt["labels"][:-2]) return tokenized_full_prompt def seq_to_token_ids(args, seq, candidates, label, text_dict, tokenizer, prompter, eval=False): def truncate_title(title): title_ = tokenizer.tokenize(title)[:args.llm_max_title_len] title = tokenizer.convert_tokens_to_string(title_) return title seq_t = ' \n '.join(['(' + str(idx + 1) + ') ' + truncate_title(text_dict[item]) for idx, item in enumerate(seq)]) can_t = ' \n '.join(['(' + chr(ord('A') + idx) + ') ' + truncate_title(text_dict[item]) for idx, item in enumerate(candidates)]) output = chr(ord('A') + candidates.index(label)) # ranking only data_point = {} data_point['system'] = args.llm_system_template if args.llm_system_template is not None else DEFAULT_SYSTEM_PROMPT data_point['input'] = args.llm_input_template.format(seq_t, can_t) data_point['output'] = output if eval: return generate_and_tokenize_eval(args, data_point, tokenizer, prompter) else: return generate_and_tokenize_train(args, data_point, tokenizer, prompter) class LLMDataloader(): def __init__(self, args, dataset): self.args = args self.rng = np.random self.save_folder = dataset._get_preprocessed_folder_path() seq_dataset = dataset.load_dataset() self.train = seq_dataset['train'] self.val = seq_dataset['val'] self.test = seq_dataset['test'] self.umap = seq_dataset['umap'] self.smap = seq_dataset['smap'] self.text_dict = seq_dataset['meta'] self.user_count = len(self.umap) self.item_count = len(self.smap) args.num_items = self.item_count 
self.max_len = args.llm_max_history self.tokenizer = AutoTokenizer.from_pretrained( args.llm_base_tokenizer, cache_dir=args.llm_cache_dir) self.tokenizer.pad_token = self.tokenizer.unk_token self.tokenizer.padding_side = 'left' self.tokenizer.truncation_side = 'left' self.tokenizer.clean_up_tokenization_spaces = True self.prompter = Prompter() self.llm_retrieved_path = args.llm_retrieved_path print('Loading retrieved file from {}'.format(self.llm_retrieved_path)) retrieved_file = pickle.load(open(os.path.join(args.llm_retrieved_path, 'retrieved.pkl'), 'rb')) print('******************** Constructing Validation Subset ********************') self.val_probs = retrieved_file['val_probs'] self.val_labels = retrieved_file['val_labels'] self.val_metrics = retrieved_file['val_metrics'] self.val_users = [u for u, (p, l) in enumerate(zip(self.val_probs, self.val_labels), start=1) \ if l in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices] self.val_candidates = [torch.topk(torch.tensor(self.val_probs[u-1]), self.args.llm_negative_sample_size+1).indices.tolist() for u in self.val_users] print('******************** Constructing Test Subset ********************') self.test_probs = retrieved_file['test_probs'] self.test_labels = retrieved_file['test_labels'] self.test_metrics = retrieved_file['test_metrics'] self.test_users = [u for u, (p, l) in enumerate(zip(self.test_probs, self.test_labels), start=1) \ if l in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices] self.test_candidates = [torch.topk(torch.tensor(self.test_probs[u-1]), self.args.llm_negative_sample_size+1).indices.tolist() for u in self.test_users] self.non_test_users = [u for u, (p, l) in enumerate(zip(self.test_probs, self.test_labels), start=1) \ if l not in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices] self.test_retrieval = { 'original_size': len(self.test_probs), 'retrieval_size': len(self.test_candidates), 'original_metrics': 
self.test_metrics, 'retrieval_metrics': absolute_recall_mrr_ndcg_for_ks( torch.tensor(self.test_probs)[torch.tensor(self.test_users)-1], torch.tensor(self.test_labels)[torch.tensor(self.test_users)-1], self.args.metric_ks, ), 'non_retrieval_metrics': absolute_recall_mrr_ndcg_for_ks( torch.tensor(self.test_probs)[torch.tensor(self.non_test_users)-1], torch.tensor(self.test_labels)[torch.tensor(self.non_test_users)-1], self.args.metric_ks, ), } @classmethod def code(cls): return 'llm' def get_pytorch_dataloaders(self): train_loader = self._get_train_loader() val_loader = self._get_val_loader() test_loader = self._get_test_loader() return train_loader, val_loader, test_loader def _get_train_loader(self): dataset = self._get_train_dataset() dataloader = data_utils.DataLoader(dataset, batch_size=self.args.lora_micro_batch_size, shuffle=True, pin_memory=True, num_workers=self.args.num_workers, worker_init_fn=worker_init_fn) return dataloader def _get_train_dataset(self): dataset = LLMTrainDataset(self.args, self.train, self.max_len, self.rng, self.text_dict, self.tokenizer, self.prompter) return dataset def _get_val_loader(self): return self._get_eval_loader(mode='val') def _get_test_loader(self): return self._get_eval_loader(mode='test') def _get_eval_loader(self, mode): batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size dataset = self._get_eval_dataset(mode) dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=self.args.num_workers) return dataloader def _get_eval_dataset(self, mode): if mode == 'val': dataset = LLMValidDataset(self.args, self.train, self.val, self.max_len, self.rng, \ self.text_dict, self.tokenizer, self.prompter, self.val_users, \ self.val_candidates) elif mode == 'test': dataset = LLMTestDataset(self.args, self.train, self.val, self.test, self.max_len, \ self.rng, self.text_dict, self.tokenizer, self.prompter, self.test_users, \ self.test_candidates) return 
dataset class LLMTrainDataset(data_utils.Dataset): def __init__(self, args, u2seq, max_len, rng, text_dict, tokenizer, prompter): self.args = args self.max_len = max_len self.num_items = args.num_items self.rng = rng self.text_dict = text_dict self.tokenizer = tokenizer self.prompter = prompter self.all_seqs = [] for u in sorted(u2seq.keys()): seq = u2seq[u] for i in range(2, len(seq)+1): self.all_seqs += [seq[:i]] def __len__(self): return len(self.all_seqs) def __getitem__(self, index): tokens = self.all_seqs[index] answer = tokens[-1] original_seq = tokens[:-1] seq = original_seq[-self.max_len:] cur_idx, candidates = 0, [answer] samples = self.rng.randint(1, self.args.num_items+1, size=5*self.args.llm_negative_sample_size) while len(candidates) < self.args.llm_negative_sample_size + 1: item = samples[cur_idx] cur_idx += 1 if item in original_seq or item == answer: continue else: candidates.append(item) self.rng.shuffle(candidates) return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, \ self.tokenizer, self.prompter, eval=False) class LLMValidDataset(data_utils.Dataset): def __init__(self, args, u2seq, u2answer, max_len, rng, text_dict, tokenizer, prompter, val_users, val_candidates): self.args = args self.u2seq = u2seq self.u2answer = u2answer self.users = sorted(self.u2seq.keys()) self.max_len = max_len self.rng = rng self.text_dict = text_dict self.tokenizer = tokenizer self.prompter = prompter self.val_users = val_users self.val_candidates = val_candidates def __len__(self): return len(self.val_users) def __getitem__(self, index): user = self.val_users[index] seq = self.u2seq[user] answer = self.u2answer[user][0] seq = seq[-self.max_len:] candidates = self.val_candidates[index] assert answer in candidates # self.rng.shuffle(candidates) return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, self.tokenizer, self.prompter, eval=True) class LLMTestDataset(data_utils.Dataset): def __init__(self, args, u2seq, u2val, 
u2answer, max_len, rng, text_dict, tokenizer, prompter, test_users, test_candidates): self.args = args self.u2seq = u2seq self.u2val = u2val self.u2answer = u2answer self.users = sorted(u2seq.keys()) self.max_len = max_len self.rng = rng self.text_dict = text_dict self.tokenizer = tokenizer self.prompter = prompter self.test_users = test_users self.test_candidates = test_candidates def __len__(self): return len(self.test_users) def __getitem__(self, index): user = self.test_users[index] seq = self.u2seq[user] + self.u2val[user] answer = self.u2answer[user][0] seq = seq[-self.max_len:] candidates = self.test_candidates[index] assert answer in candidates # self.rng.shuffle(candidates) return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, self.tokenizer, self.prompter, eval=True) ================================================ FILE: dataloader/lru.py ================================================ from .base import AbstractDataloader import os import torch import random import pickle import numpy as np import torch.utils.data as data_utils def worker_init_fn(worker_id): random.seed(np.random.get_state()[1][0] + worker_id) np.random.seed(np.random.get_state()[1][0] + worker_id) class LRUDataloader(): def __init__(self, args, dataset): self.args = args self.rng = np.random self.save_folder = dataset._get_preprocessed_folder_path() dataset = dataset.load_dataset() self.train = dataset['train'] self.val = dataset['val'] self.test = dataset['test'] self.umap = dataset['umap'] self.smap = dataset['smap'] self.user_count = len(self.umap) self.item_count = len(self.smap) args.num_users = self.user_count args.num_items = self.item_count self.max_len = args.bert_max_len self.sliding_size = args.sliding_window_size @classmethod def code(cls): return 'lru' def get_pytorch_dataloaders(self): train_loader = self._get_train_loader() val_loader = self._get_val_loader() test_loader = self._get_test_loader() return train_loader, val_loader, test_loader def 
get_pytorch_test_subset_dataloader(self): retrieved_file_path = self.args.llm_retrieved_path print('Loading retrieved file from {}'.format(retrieved_file_path)) retrieved_file = pickle.load(open(os.path.join(retrieved_file_path, 'retrieved.pkl'), 'rb')) test_probs = retrieved_file['test_probs'] test_labels = retrieved_file['test_labels'] test_users = [u for u, (p, l) in enumerate(zip(test_probs, test_labels), start=1) \ if l in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices] dataset = dataset = LRUTestDataset(self.args, self.train, self.val, self.test, self.max_len, self.rng, subset_users=test_users) dataloader = data_utils.DataLoader(dataset, batch_size=self.args.val_batch_size, shuffle=False, pin_memory=True, num_workers=self.args.num_workers) return dataloader def _get_train_loader(self): dataset = self._get_train_dataset() dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size, shuffle=True, pin_memory=True, num_workers=self.args.num_workers, worker_init_fn=worker_init_fn) return dataloader def _get_train_dataset(self): dataset = LRUTrainDataset( self.args, self.train, self.max_len, self.sliding_size, self.rng) return dataset def _get_val_loader(self): return self._get_eval_loader(mode='val') def _get_test_loader(self): return self._get_eval_loader(mode='test') def _get_eval_loader(self, mode): batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size dataset = self._get_eval_dataset(mode) dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=self.args.num_workers) return dataloader def _get_eval_dataset(self, mode): if mode == 'val': dataset = LRUValidDataset(self.args, self.train, self.val, self.max_len, self.rng) elif mode == 'test': dataset = LRUTestDataset(self.args, self.train, self.val, self.test, self.max_len, self.rng) return dataset class LRUTrainDataset(data_utils.Dataset): def __init__(self, args, u2seq, max_len, 
sliding_size, rng): self.args = args self.max_len = max_len self.sliding_step = int(sliding_size * max_len) self.num_items = args.num_items self.rng = rng assert self.sliding_step > 0 self.all_seqs = [] for u in sorted(u2seq.keys()): seq = u2seq[u] if len(seq) < self.max_len + self.sliding_step: self.all_seqs.append(seq) else: start_idx = range(len(seq) - max_len, -1, -self.sliding_step) self.all_seqs = self.all_seqs + [seq[i:i + max_len] for i in start_idx] def __len__(self): return len(self.all_seqs) def __getitem__(self, index): seq = self.all_seqs[index] labels = seq[-self.max_len:] tokens = seq[:-1][-self.max_len:] mask_len = self.max_len - len(tokens) tokens = [0] * mask_len + tokens mask_len = self.max_len - len(labels) labels = [0] * mask_len + labels return torch.LongTensor(tokens), torch.LongTensor(labels) class LRUValidDataset(data_utils.Dataset): def __init__(self, args, u2seq, u2answer, max_len, rng): self.args = args self.u2seq = u2seq self.u2answer = u2answer users = sorted(self.u2seq.keys()) self.users = [u for u in users if len(u2answer[u]) > 0] self.max_len = max_len self.rng = rng def __len__(self): return len(self.users) def __getitem__(self, index): user = self.users[index] seq = self.u2seq[user] answer = self.u2answer[user] seq = seq[-self.max_len:] padding_len = self.max_len - len(seq) seq = [0] * padding_len + seq return torch.LongTensor(seq), torch.LongTensor(answer) class LRUTestDataset(data_utils.Dataset): def __init__(self, args, u2seq, u2val, u2answer, max_len, rng, subset_users=None): self.args = args self.u2seq = u2seq self.u2val = u2val self.u2answer = u2answer users = sorted(self.u2seq.keys()) self.users = [u for u in users if len(u2val[u]) > 0 and len(u2answer[u]) > 0] self.max_len = max_len self.rng = rng if subset_users is not None: self.users = subset_users def __len__(self): return len(self.users) def __getitem__(self, index): user = self.users[index] seq = self.u2seq[user] + self.u2val[user] answer = self.u2answer[user] seq = 
seq[-self.max_len:] padding_len = self.max_len - len(seq) seq = [0] * padding_len + seq return torch.LongTensor(seq), torch.LongTensor(answer) ================================================ FILE: dataloader/templates/README.md ================================================ # Prompt templates This directory contains template styles for the prompts used to finetune LoRA models. ## Format A template is described via a JSON file with the following keys: - `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders. - `prompt_no_input`: The template to use when input is None. Uses `{instruction}` placeholders. - `description`: A short description of the template, with possible use cases. - `response_split`: The text to use as separator when cutting real response from the model output. No `{response}` placeholder was used, since the response is always the last element of the template and is just to be concatenated to the rest. ## Example template The default template, used unless otherwise specified, is `alpaca.json` ```json { "description": "Template used by Alpaca-LoRA.", "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", "response_split": "### Response:" } ``` ## Current templates ### alpaca Default template used for generic LoRA fine tunes so far. ### alpaca_legacy Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments. ### alpaca_short A trimmed down alpaca template which seems to perform just as well and spare some tokens. 
Models created with the default template seem to be queryable by the short template as well. More experiments are welcome. ### vigogne The default alpaca template, translated to French. This template was used to train the "Vigogne" LoRA and is to be used to query it, or for extra fine-tuning. ================================================ FILE: dataloader/templates/alpaca.json ================================================ { "description": "Template used by Alpaca-LoRA.", "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", "response_split": "### Response:" } ================================================ FILE: dataloader/templates/alpaca_legacy.json ================================================ { "description": "Legacy template, used by Original Alpaca repository.", "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:", "prompt_no_input": "Below is an instruction that describes a task.
Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:", "response_split": "### Response:" } ================================================ FILE: dataloader/templates/alpaca_short.json ================================================ { "description": "A shorter template to experiment with.", "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n", "response_split": "### Response:" } ================================================ FILE: dataloader/templates/vigogne.json ================================================ { "description": "French template, used by Vigogne for finetuning.", "prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n", "prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. 
Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n", "response_split": "### Réponse:" } ================================================ FILE: dataloader/utils.py ================================================ import json import os.path as osp from typing import Union class Prompter(object): __slots__ = ("template", "_verbose") def __init__(self, template_name: str = "", verbose: bool = False): self._verbose = verbose if not template_name: # template_name = "alpaca" template_name = "alpaca_short" file_name = osp.join("dataloader", "templates", f"{template_name}.json") if not osp.exists(file_name): raise ValueError(f"Can't read {file_name}") with open(file_name) as fp: self.template = json.load(fp) if self._verbose: print( f"Using prompt template {template_name}: {self.template['description']}" ) def generate_prompt( self, instruction: str, input: Union[None, str] = None, label: Union[None, str] = None, ) -> str: if input: res = self.template["prompt_input"].format( instruction=instruction, input=input ) else: res = self.template["prompt_no_input"].format( instruction=instruction ) if label: res = f"{res}{label}" if self._verbose: print(res) return res def get_response(self, output: str) -> str: return output.split(self.template["response_split"])[1].strip() ================================================ FILE: datasets/__init__.py ================================================ from .ml_100k import ML100KDataset from .beauty import BeautyDataset from .games import GamesDataset DATASETS = { ML100KDataset.code(): ML100KDataset, BeautyDataset.code(): BeautyDataset, GamesDataset.code(): GamesDataset, } def dataset_factory(args): dataset = DATASETS[args.dataset_code] return dataset(args) ================================================ FILE: datasets/base.py ================================================ import pickle import shutil import tempfile import os from pathlib import Path import gzip from abc 
import * from .utils import * from config import RAW_DATASET_ROOT_FOLDER import numpy as np import pandas as pd from tqdm import tqdm tqdm.pandas() class AbstractDataset(metaclass=ABCMeta): def __init__(self, args): self.args = args self.min_rating = args.min_rating self.min_uc = args.min_uc self.min_sc = args.min_sc assert self.min_uc >= 2, 'Need at least 2 ratings per user for validation and test' @classmethod @abstractmethod def code(cls): pass @classmethod def raw_code(cls): return cls.code() @classmethod def zip_file_content_is_folder(cls): return True @classmethod def all_raw_file_names(cls): return [] @classmethod @abstractmethod def url(cls): pass @abstractmethod def preprocess(self): pass @abstractmethod def load_ratings_df(self): pass @abstractmethod def maybe_download_raw_dataset(self): pass def load_dataset(self): self.preprocess() dataset_path = self._get_preprocessed_dataset_path() dataset = pickle.load(dataset_path.open('rb')) return dataset def filter_triplets(self, df): print('Filtering triplets') if self.min_sc > 1 or self.min_uc > 1: item_sizes = df.groupby('sid').size() good_items = item_sizes.index[item_sizes >= self.min_sc] user_sizes = df.groupby('uid').size() good_users = user_sizes.index[user_sizes >= self.min_uc] while len(good_items) < len(item_sizes) or len(good_users) < len(user_sizes): if self.min_sc > 1: item_sizes = df.groupby('sid').size() good_items = item_sizes.index[item_sizes >= self.min_sc] df = df[df['sid'].isin(good_items)] if self.min_uc > 1: user_sizes = df.groupby('uid').size() good_users = user_sizes.index[user_sizes >= self.min_uc] df = df[df['uid'].isin(good_users)] item_sizes = df.groupby('sid').size() good_items = item_sizes.index[item_sizes >= self.min_sc] user_sizes = df.groupby('uid').size() good_users = user_sizes.index[user_sizes >= self.min_uc] return df def densify_index(self, df): print('Densifying index') umap = {u: i for i, u in enumerate(set(df['uid']), start=1)} smap = {s: i for i, s in 
enumerate(set(df['sid']), start=1)} df['uid'] = df['uid'].map(umap) df['sid'] = df['sid'].map(smap) return df, umap, smap def split_df(self, df, user_count): print('Splitting') user_group = df.groupby('uid') user2items = user_group.progress_apply( lambda d: list(d.sort_values(by=['timestamp', 'sid'])['sid'])) train, val, test = {}, {}, {} for i in range(user_count): user = i + 1 items = user2items[user] train[user], val[user], test[user] = items[:-2], items[-2:-1], items[-1:] return train, val, test def _get_rawdata_root_path(self): return Path(RAW_DATASET_ROOT_FOLDER) def _get_rawdata_folder_path(self): root = self._get_rawdata_root_path() return root.joinpath(self.raw_code()) def _get_preprocessed_root_path(self): root = self._get_rawdata_root_path() return root.joinpath('preprocessed') def _get_preprocessed_folder_path(self): preprocessed_root = self._get_preprocessed_root_path() folder_name = '{}_min_rating{}-min_uc{}-min_sc{}' \ .format(self.code(), self.min_rating, self.min_uc, self.min_sc) return preprocessed_root.joinpath(folder_name) def _get_preprocessed_dataset_path(self): folder = self._get_preprocessed_folder_path() return folder.joinpath('dataset.pkl') ================================================ FILE: datasets/beauty.py ================================================ from .base import AbstractDataset from .utils import * from datetime import date from pathlib import Path import pickle import shutil import tempfile import os import gzip import json import numpy as np import pandas as pd from tqdm import tqdm tqdm.pandas() class BeautyDataset(AbstractDataset): @classmethod def code(cls): return 'beauty' @classmethod def url(cls): return ['http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Beauty.csv', 'http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Beauty.json.gz'] @classmethod def zip_file_content_is_folder(cls): return True @classmethod def all_raw_file_names(cls): return ['beauty.csv', 
'beauty_meta.json.gz'] def maybe_download_raw_dataset(self): folder_path = self._get_rawdata_folder_path() if folder_path.is_dir() and\ all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()): print('Raw data already exists. Skip downloading') return print("Raw file doesn't exist. Downloading...") for idx, url in enumerate(self.url()): tmproot = Path(tempfile.mkdtemp()) tmpfile = tmproot.joinpath('file') download(url, tmpfile) os.makedirs(folder_path, exist_ok=True) shutil.move(tmpfile, folder_path.joinpath(self.all_raw_file_names()[idx])) print() def preprocess(self): dataset_path = self._get_preprocessed_dataset_path() if dataset_path.is_file(): print('Already preprocessed. Skip preprocessing') return if not dataset_path.parent.is_dir(): dataset_path.parent.mkdir(parents=True) self.maybe_download_raw_dataset() df = self.load_ratings_df() meta_raw = self.load_meta_dict() df = df[df['sid'].isin(meta_raw)] # filter items without meta info df = self.filter_triplets(df) df, umap, smap = self.densify_index(df) train, val, test = self.split_df(df, len(umap)) meta = {smap[k]: v for k, v in meta_raw.items() if k in smap} dataset = {'train': train, 'val': val, 'test': test, 'meta': meta, 'umap': umap, 'smap': smap} with dataset_path.open('wb') as f: pickle.dump(dataset, f) def load_ratings_df(self): folder_path = self._get_rawdata_folder_path() file_path = folder_path.joinpath(self.all_raw_file_names()[0]) df = pd.read_csv(file_path, header=None) df.columns = ['uid', 'sid', 'rating', 'timestamp'] return df def load_meta_dict(self): folder_path = self._get_rawdata_folder_path() file_path = folder_path.joinpath(self.all_raw_file_names()[1]) meta_dict = {} with gzip.open(file_path, 'rb') as f: for line in f: item = eval(line) if 'title' in item and len(item['title']) > 0: meta_dict[item['asin'].strip()] = item['title'].strip() return meta_dict ================================================ FILE: datasets/games.py 
================================================ from .base import AbstractDataset from .utils import * from datetime import date from pathlib import Path import pickle import shutil import tempfile import os import gzip import json import numpy as np import pandas as pd from tqdm import tqdm tqdm.pandas() class GamesDataset(AbstractDataset): @classmethod def code(cls): return 'games' @classmethod def url(cls): # meta_Video_Games.json.gz from snap.stanford.edu does not contain full meta info return ['http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Video_Games.csv', 'https://datarepo.eng.ucsd.edu/mcauley_group/data/amazon_v2/metaFiles2/meta_Video_Games.json.gz'] @classmethod def zip_file_content_is_folder(cls): return True @classmethod def all_raw_file_names(cls): return ['games.csv', 'games_meta.json.gz'] def maybe_download_raw_dataset(self): folder_path = self._get_rawdata_folder_path() if folder_path.is_dir() and\ all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()): print('Raw data already exists. Skip downloading') return print("Raw file doesn't exist. Downloading...") for idx, url in enumerate(self.url()): tmproot = Path(tempfile.mkdtemp()) tmpfile = tmproot.joinpath('file') download(url, tmpfile) os.makedirs(folder_path, exist_ok=True) shutil.move(tmpfile, folder_path.joinpath(self.all_raw_file_names()[idx])) print() def preprocess(self): dataset_path = self._get_preprocessed_dataset_path() if dataset_path.is_file(): print('Already preprocessed. 
Skip preprocessing') return if not dataset_path.parent.is_dir(): dataset_path.parent.mkdir(parents=True) self.maybe_download_raw_dataset() df = self.load_ratings_df() meta_raw = self.load_meta_dict() df = df[df['sid'].isin(meta_raw)] # filter items without meta info df = self.filter_triplets(df) df, umap, smap = self.densify_index(df) train, val, test = self.split_df(df, len(umap)) meta = {smap[k]: v for k, v in meta_raw.items() if k in smap} dataset = {'train': train, 'val': val, 'test': test, 'meta': meta, 'umap': umap, 'smap': smap} with dataset_path.open('wb') as f: pickle.dump(dataset, f) def load_ratings_df(self): folder_path = self._get_rawdata_folder_path() file_path = folder_path.joinpath(self.all_raw_file_names()[0]) df = pd.read_csv(file_path, header=None) df.columns = ['uid', 'sid', 'rating', 'timestamp'] return df def load_meta_dict(self): folder_path = self._get_rawdata_folder_path() file_path = folder_path.joinpath(self.all_raw_file_names()[1]) meta_dict = {} with gzip.open(file_path, 'rb') as f: for line in f: item = eval(line) if 'title' in item and len(item['title']) > 0: meta_dict[item['asin'].strip()] = item['title'].strip() return meta_dict ================================================ FILE: datasets/ml_100k.py ================================================ from .base import AbstractDataset from .utils import * from datetime import date from pathlib import Path import pickle import shutil import tempfile import os import re import numpy as np import pandas as pd from tqdm import tqdm tqdm.pandas() class ML100KDataset(AbstractDataset): @classmethod def code(cls): return 'ml-100k' @classmethod def url(cls): # as of Sep 2023 return 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip' @classmethod def zip_file_content_is_folder(cls): return True @classmethod def all_raw_file_names(cls): return ['README', 'movies.csv', 'ratings.csv', 'users.csv'] def maybe_download_raw_dataset(self): folder_path = self._get_rawdata_folder_path() 
if folder_path.is_dir() and\ all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()): print('Raw data already exists. Skip downloading') return print("Raw file doesn't exist. Downloading...") tmproot = Path(tempfile.mkdtemp()) tmpzip = tmproot.joinpath('file.zip') tmpfolder = tmproot.joinpath('folder') download(self.url(), tmpzip) unzip(tmpzip, tmpfolder) if self.zip_file_content_is_folder(): tmpfolder = tmpfolder.joinpath(os.listdir(tmpfolder)[0]) shutil.move(tmpfolder, folder_path) shutil.rmtree(tmproot) print() def preprocess(self): dataset_path = self._get_preprocessed_dataset_path() if dataset_path.is_file(): print('Already preprocessed. Skip preprocessing') return if not dataset_path.parent.is_dir(): dataset_path.parent.mkdir(parents=True) self.maybe_download_raw_dataset() df = self.load_ratings_df() meta_raw = self.load_meta_dict() df = df[df['sid'].isin(meta_raw)] # filter items without meta info df = self.filter_triplets(df) df, umap, smap = self.densify_index(df) train, val, test = self.split_df(df, len(umap)) meta = {smap[k]: v for k, v in meta_raw.items() if k in smap} dataset = {'train': train, 'val': val, 'test': test, 'meta': meta, 'umap': umap, 'smap': smap} with dataset_path.open('wb') as f: pickle.dump(dataset, f) def load_ratings_df(self): folder_path = self._get_rawdata_folder_path() file_path = folder_path.joinpath('ratings.csv') df = pd.read_csv(file_path) df.columns = ['uid', 'sid', 'rating', 'timestamp'] return df def load_meta_dict(self): folder_path = self._get_rawdata_folder_path() file_path = folder_path.joinpath('movies.csv') df = pd.read_csv(file_path, encoding="ISO-8859-1") meta_dict = {} for row in df.itertuples(): title = row[2][:-7] # remove year (optional) year = row[2][-7:] title = re.sub('\(.*?\)', '', title).strip() # the rest articles and parentheses are not considered here if any(', '+x in title.lower()[-5:] for x in ['a', 'an', 'the']): title_pre = title.split(', ')[:-1] title_post = 
title.split(', ')[-1] title_pre = ', '.join(title_pre) title = title_post + ' ' + title_pre meta_dict[row[1]] = title + year return meta_dict ================================================ FILE: datasets/utils.py ================================================ import numpy as np import pandas as pd from tqdm import tqdm import urllib.request from pathlib import Path import zipfile import tarfile import sys def download(url, savepath): urllib.request.urlretrieve(url, str(savepath)) print() def unzip(zippath, savepath): print("Extracting data...") zip = zipfile.ZipFile(zippath) zip.extractall(savepath) zip.close() def unziptargz(zippath, savepath): print("Extracting data...") f = tarfile.open(zippath) f.extractall(savepath) f.close() ================================================ FILE: model/__init__.py ================================================ from .lru import * from .llm import * ================================================ FILE: model/llm.py ================================================ # coding=utf-8 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.models.llama.configuration_llama import LlamaConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make the causal (lower-triangular) additive mask used for decoder self-attention.

    Visible positions hold 0 and future positions hold ``torch.finfo(dtype).min``.
    When ``past_key_values_length > 0``, that many all-zero (visible) columns are
    prepended for the cached keys. The result has shape
    ``(bsz, 1, tgt_len, tgt_len + past_key_values_length)``.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    # Entry [i, j] is zeroed when j < i + 1, i.e. on and below the diagonal.
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    The output is additive: where ``mask`` is 1 the result is 0, and where it is 0
    the result is ``torch.finfo(dtype).min`` (assumes a 0/1 mask).
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Normalise in float32 for stability, then cast back to the input dtype.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class LlamaRotaryEmbedding(torch.nn.Module):
    # Precomputes cos/sin tables for rotary position embeddings and grows them on demand.
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # seq_len must be passed explicitly: the None default would fail the comparison below.
        # The cache is rebuilt only when a longer sequence than any seen so far arrives.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling.
    Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before super().__init__(), which calls the overridden
        # _set_cos_sin_cache that reads self.scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        # Linear scaling: positions are divided by the scaling factor.
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before super().__init__(), which calls the overridden _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        # Only beyond the trained context length is inv_freq recomputed with an enlarged base.
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary embeddings to q and k, selecting cos/sin rows by ``position_ids``."""
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class LlamaMLP(nn.Module):
    # Gated feed-forward: down_proj(act(gate_proj(x)) * up_proj(x)).
    def __init__(self, config):
        super().__init__()
        self.pretraining_tp = config.pretraining_tp
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.pretraining_tp > 1:
            # Same computation as the else-branch, but with the weights split
            # into pretraining_tp slices and the partial results concatenated/summed.
            slice = self.intermediate_size // self.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share each key/value head (see repeat_kv).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.pretraining_tp = config.pretraining_tp
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        # Select the rotary-embedding variant from config.rope_scaling ("linear" / "dynamic" / none).
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    # Not called by forward() in this class.
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Returns (attn_output, attn_weights or None, (key, value) cache or None).
        # NOTE(review): position_ids is used to index the rotary tables, so callers
        # presumably always pass it despite the None default — confirm.
        bsz, q_len, _ = hidden_states.size()

        if self.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # [bsz, q_len, heads * head_dim] -> [bsz, heads, q_len, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total key length includes cached positions, so the rotary table is long enough.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # The cache stores the un-repeated (num_key_value_heads) tensors.
        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            # Additive mask: masked positions carry a very large negative value.
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class LlamaDecoderLayer(nn.Module):
    # Pre-norm decoder block: x + Attn(RMSNorm(x)), then x + MLP(RMSNorm(x)).
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states

        Returns:
            A tuple ``(hidden_states,)`` extended with the attention weights when
            ``output_attentions`` and with the key/value cache when ``use_cache``.
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


# Shared docstring prefix injected by the add_start_docstrings decorators below.
LLAMA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`LlamaConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    # Class-level hooks consumed by transformers' PreTrainedModel machinery.
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        # Normal(0, initializer_range) for Linear/Embedding weights; zero biases and padding row.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # LlamaModel is defined later in this module; the name is resolved at call time.
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value


LLAMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs.
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers.
See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) class LlamaModel(LlamaPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] Args: config: LlamaConfig """ def __init__(self, config: LlamaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, device=inputs_embeds.device, past_key_values_length=past_key_values_length, ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( inputs_embeds.device ) combined_attention_mask 
= ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) position_ids = position_ids.unsqueeze(0).view(-1, 
seq_length) else: position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids, None, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if not return_dict: return 
tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) class LlamaForCausalLM(LlamaPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) self.pretraining_tp = config.pretraining_tp self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Returns: Example: ```python >>> from transformers import AutoTokenizer, LlamaForCausalLM >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] if self.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) logits = logits.float() loss = None if self.training and labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens 
loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) elif labels is not None: loss = torch.tensor(-1.) # loss cannot be directly computed in inference logits = logits[:, -1] # we only need last position logits for inference if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values: input_ids = input_ids[:, -1:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -1].unsqueeze(-1) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, } ) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), ) return reordered_past ================================================ FILE: model/lru.py ================================================ 
import torch import torch.nn as nn import torch.nn.functional as F import math import numpy as np class LRURec(nn.Module): def __init__(self, args): super().__init__() self.args = args self.embedding = LRUEmbedding(self.args) self.model = LRUModel(self.args) self.truncated_normal_init() def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04): with torch.no_grad(): l = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2. u = (1. + math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2. for n, p in self.named_parameters(): if not 'layer_norm' in n and 'params_log' not in n: if torch.is_complex(p): p.real.uniform_(2 * l - 1, 2 * u - 1) p.imag.uniform_(2 * l - 1, 2 * u - 1) p.real.erfinv_() p.imag.erfinv_() p.real.mul_(std * math.sqrt(2.)) p.imag.mul_(std * math.sqrt(2.)) p.real.add_(mean) p.imag.add_(mean) else: p.uniform_(2 * l - 1, 2 * u - 1) p.erfinv_() p.mul_(std * math.sqrt(2.)) p.add_(mean) def forward(self, x): x, mask = self.embedding(x) scores = self.model(x, self.embedding.token.weight, mask) return scores class LRUEmbedding(nn.Module): def __init__(self, args): super().__init__() vocab_size = args.num_items + 1 embed_size = args.bert_hidden_units self.token = nn.Embedding(vocab_size, embed_size) self.layer_norm = nn.LayerNorm(embed_size) self.embed_dropout = nn.Dropout(args.bert_dropout) def get_mask(self, x): return (x > 0) def forward(self, x): mask = self.get_mask(x) x = self.token(x) return self.layer_norm(self.embed_dropout(x)), mask class LRUModel(nn.Module): def __init__(self, args): super().__init__() self.args = args self.hidden_size = args.bert_hidden_units layers = args.bert_num_blocks self.lru_blocks = nn.ModuleList([LRUBlock(self.args) for _ in range(layers)]) self.bias = torch.nn.Parameter(torch.zeros(args.num_items + 1)) def forward(self, x, embedding_weight, mask): # left padding to the power of 2 seq_len = x.size(1) log2_L = int(np.ceil(np.log2(seq_len))) x = F.pad(x, (0, 0, 2 ** log2_L - x.size(1), 0, 0, 0)) mask_ 
= F.pad(mask, (2 ** log2_L - mask.size(1), 0, 0, 0)) # LRU blocks with pffn for lru_block in self.lru_blocks: x = lru_block.forward(x, mask_) x = x[:, -seq_len:] # B x L x D (64) scores = torch.matmul(x, embedding_weight.permute(1, 0)) + self.bias return scores class LRUBlock(nn.Module): def __init__(self, args): super().__init__() self.args = args hidden_size = args.bert_hidden_units self.lru_layer = LRULayer( d_model=hidden_size, dropout=args.bert_attn_dropout) self.feed_forward = PositionwiseFeedForward( d_model=hidden_size, d_ff=hidden_size*4, dropout=args.bert_dropout) def forward(self, x, mask): x = self.lru_layer(x, mask) x = self.feed_forward(x) return x class LRULayer(nn.Module): def __init__(self, d_model, dropout=0.1, use_bias=True, r_min=0.8, r_max=0.99): super().__init__() self.embed_size = d_model self.hidden_size = 2 * d_model self.use_bias = use_bias # init nu, theta, gamma u1 = torch.rand(self.hidden_size) u2 = torch.rand(self.hidden_size) nu_log = torch.log(-0.5 * torch.log(u1 * (r_max ** 2 - r_min ** 2) + r_min ** 2)) theta_log = torch.log(u2 * torch.tensor(np.pi) * 2) diag_lambda = torch.exp(torch.complex(-torch.exp(nu_log), torch.exp(theta_log))) gamma_log = torch.log(torch.sqrt(1 - torch.abs(diag_lambda) ** 2)) self.params_log = nn.Parameter(torch.vstack((nu_log, theta_log, gamma_log))) # Init B, C, D self.in_proj = nn.Linear(self.embed_size, self.hidden_size, bias=use_bias).to(torch.cfloat) self.out_proj = nn.Linear(self.hidden_size, self.embed_size, bias=use_bias).to(torch.cfloat) # self.out_vector = nn.Parameter(torch.rand(self.embed_size)) self.out_vector = nn.Identity() # Dropout and layer norm self.dropout = nn.Dropout(p=dropout) self.layer_norm = nn.LayerNorm(self.embed_size) def lru_parallel(self, i, h, lamb, mask, B, L, D): # Parallel algorithm, see: https://kexue.fm/archives/9554#%E5%B9%B6%E8%A1%8C%E5%8C%96 # The original implementation is slightly slower and does not consider 0 padding l = 2 ** i h = h.reshape(B * L // l, l, D) # 
(B, L, D) -> (B * L // 2, 2, D) mask_ = mask.reshape(B * L // l, l) # (B, L) -> (B * L // 2, 2) h1, h2 = h[:, :l // 2], h[:, l // 2:] # Divide data in half if i > 1: lamb = torch.cat((lamb, lamb * lamb[-1]), 0) h2 = h2 + lamb * h1[:, -1:] * mask_[:, l // 2 - 1:l // 2].unsqueeze(-1) h = torch.cat([h1, h2], axis=1) return h, lamb def forward(self, x, mask): # compute bu and lambda nu, theta, gamma = torch.exp(self.params_log).split((1, 1, 1)) lamb = torch.exp(torch.complex(-nu, theta)) h = self.in_proj(x.to(torch.cfloat)) * gamma # bu # compute h in parallel log2_L = int(np.ceil(np.log2(h.size(1)))) B, L, D = h.size(0), h.size(1), h.size(2) for i in range(log2_L): h, lamb = self.lru_parallel(i + 1, h, lamb, mask, B, L, D) x = self.dropout(self.out_proj(h).real) + self.out_vector(x) return self.layer_norm(x) # residual connection introduced above class PositionwiseFeedForward(nn.Module): def __init__(self, d_model, d_ff, dropout=0.1): super().__init__() self.w_1 = nn.Linear(d_model, d_ff) self.w_2 = nn.Linear(d_ff, d_model) self.activation = nn.GELU() self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(d_model) def forward(self, x): x_ = self.dropout(self.activation(self.w_1(x))) return self.layer_norm(self.dropout(self.w_2(x_)) + x) ================================================ FILE: requirements.txt ================================================ # This file may be used to create an environment using: # $ conda create --name --file # platform: linux-64 _libgcc_mutex=0.1=main _openmp_mutex=5.1=1_gnu absl-py=1.4.0=pypi_0 accelerate=0.21.0=pypi_0 aiofiles=23.1.0=pypi_0 aiohttp=3.8.4=pypi_0 aiosignal=1.3.1=pypi_0 alabaster=0.7.13=pypi_0 altair=4.2.2=pypi_0 aniso8601=9.0.1=pypi_0 antlr4-python3-runtime=4.9.3=pypi_0 anyio=3.6.2=pypi_0 appdirs=1.4.4=pypi_0 asttokens=2.2.1=pypi_0 async-timeout=4.0.2=pypi_0 attrdict=2.0.1=pypi_0 attrs=23.1.0=pypi_0 audioread=3.0.0=pypi_0 babel=2.12.1=pypi_0 backcall=0.2.0=pypi_0 beautifulsoup4=4.12.2=pypi_0 
bitsandbytes=0.41.1=pypi_0 black=19.10b0=pypi_0 blas=1.0=mkl boto3=1.26.160=pypi_0 botocore=1.29.160=pypi_0 braceexpand=0.1.7=pypi_0 brotlipy=0.7.0=py310h7f8727e_1002 bzip2=1.0.8=h7b6447c_0 ca-certificates=2023.08.22=h06a4308_0 cachetools=5.3.0=pypi_0 cdifflib=1.2.6=pypi_0 certifi=2022.12.7=py310h06a4308_0 cffi=1.15.1=py310h5eee18b_3 charset-normalizer=2.0.4=pyhd3eb1b0_0 click=8.0.2=pypi_0 clip=1.0=pypi_0 colorama=0.4.6=pypi_0 comm=0.1.3=pypi_0 contourpy=1.0.7=pypi_0 cryptography=39.0.1=py310h9ce1e76_0 cuda-cudart=11.8.89=0 cuda-cupti=11.8.87=0 cuda-libraries=11.8.0=0 cuda-nvrtc=11.8.89=0 cuda-nvtx=11.8.86=0 cuda-runtime=11.8.0=0 cudatoolkit=11.8.0=h6a678d5_0 cycler=0.11.0=pypi_0 cython=0.29.35=pypi_0 datasets=2.14.5=pypi_0 debugpy=1.6.7=pypi_0 decorator=5.1.1=pypi_0 dill=0.3.6=pypi_0 distance=0.1.3=pypi_0 docker-pycreds=0.4.0=pypi_0 docopt=0.6.2=pypi_0 docutils=0.20.1=pypi_0 editdistance=0.6.2=pypi_0 einops=0.6.1=pypi_0 emoji=2.8.0=pypi_0 entrypoints=0.4=pypi_0 evaluate=0.3.0=pypi_0 exceptiongroup=1.1.1=pypi_0 executing=1.2.0=pypi_0 fairscale=0.4.13=pypi_0 faiss-cpu=1.7.4=pypi_0 faiss-gpu=1.7.2=pypi_0 fastapi=0.95.1=pypi_0 fasttext=0.9.2=pypi_0 ffmpeg=4.3=hf484d3e_0 ffmpy=0.3.0=pypi_0 filelock=3.9.0=py310h06a4308_0 fire=0.5.0=pypi_0 flask=2.2.5=pypi_0 flask-restful=0.3.10=pypi_0 flit-core=3.8.0=py310h06a4308_0 fonttools=4.39.3=pypi_0 freetype=2.12.1=h4a9f257_0 frozenlist=1.3.3=pypi_0 fsspec=2023.4.0=pypi_0 ftfy=6.1.1=pypi_0 future=0.18.3=pypi_0 g2p-en=2.1.0=pypi_0 gdown=4.7.1=pypi_0 giflib=5.2.1=h5eee18b_3 gitdb=4.0.10=pypi_0 gitpython=3.1.31=pypi_0 gmp=6.2.1=h295c915_3 gmpy2=2.1.2=py310heeb90bb_0 gnutls=3.6.15=he1e5248_0 google-auth=2.17.3=pypi_0 google-auth-oauthlib=1.0.0=pypi_0 gradio=3.28.3=pypi_0 gradio-client=0.1.3=pypi_0 grpcio=1.54.0=pypi_0 h11=0.14.0=pypi_0 h5py=3.9.0=pypi_0 httpcore=0.17.0=pypi_0 httpx=0.24.0=pypi_0 huggingface-hub=0.16.4=pypi_0 hydra-core=1.2.0=pypi_0 idna=3.4=py310h06a4308_0 ijson=3.2.2=pypi_0 imageio=2.28.0=pypi_0 
imagesize=1.4.1=pypi_0 inflect=6.0.4=pypi_0 iniconfig=2.0.0=pypi_0 intel-openmp=2021.4.0=h06a4308_3561 ipadic=1.0.0=pypi_0 ipykernel=6.23.3=pypi_0 ipython=8.12.0=pypi_0 ipywidgets=8.0.6=pypi_0 isort=5.12.0=pypi_0 itsdangerous=2.1.2=pypi_0 jedi=0.18.2=pypi_0 jieba=0.42.1=pypi_0 jinja2=3.1.2=py310h06a4308_0 jiwer=2.5.2=pypi_0 jmespath=1.0.1=pypi_0 joblib=1.2.0=pypi_0 jpeg=9e=h5eee18b_1 jsonschema=4.17.3=pypi_0 jupyter-client=8.3.0=pypi_0 jupyter-core=5.3.1=pypi_0 jupyterlab-widgets=3.0.7=pypi_0 kaldi-python-io=1.2.2=pypi_0 kaldiio=2.18.0=pypi_0 kiwisolver=1.4.4=pypi_0 kornia=0.6.12=pypi_0 lame=3.100=h7b6447c_0 latexcodec=2.0.1=pypi_0 lazy-loader=0.2=pypi_0 lcms2=2.12=h3be6417_0 ld_impl_linux-64=2.38=h1181459_1 lerc=3.0=h295c915_0 levenshtein=0.21.1=pypi_0 libcublas=11.11.3.6=0 libcufft=10.9.0.58=0 libcufile=1.6.0.25=0 libcurand=10.3.2.56=0 libcusolver=11.4.1.48=0 libcusparse=11.7.5.86=0 libdeflate=1.17=h5eee18b_0 libffi=3.4.2=h6a678d5_6 libgcc-ng=11.2.0=h1234567_1 libgomp=11.2.0=h1234567_1 libiconv=1.16=h7f8727e_2 libidn2=2.3.2=h7f8727e_0 libnpp=11.8.0.86=0 libnvjpeg=11.9.0.86=0 libpng=1.6.39=h5eee18b_0 librosa=0.10.0.post2=pypi_0 libstdcxx-ng=11.2.0=h1234567_1 libtasn1=4.16.0=h27cfd23_0 libtiff=4.5.0=h6a678d5_2 libunistring=0.9.10=h27cfd23_0 libuuid=1.41.5=h5eee18b_0 libwebp=1.2.4=h11a3e52_1 libwebp-base=1.2.4=h5eee18b_1 lightning-utilities=0.8.0=pypi_0 linkify-it-py=2.0.0=pypi_0 llvmlite=0.40.1=pypi_0 loguru=0.7.0=pypi_0 loralib=0.1.1=pypi_0 lxml=4.9.2=pypi_0 lz4-c=1.9.4=h6a678d5_0 markdown=3.4.3=pypi_0 markdown-it-py=2.2.0=pypi_0 markdown2=2.4.9=pypi_0 markupsafe=2.1.1=py310h7f8727e_0 marshmallow=3.19.0=pypi_0 matplotlib=3.7.1=pypi_0 matplotlib-inline=0.1.6=pypi_0 mdit-py-plugins=0.3.3=pypi_0 mdurl=0.1.2=pypi_0 mecab-python3=1.0.5=pypi_0 megatron-core=0.2.0=pypi_0 mkl=2021.4.0=h06a4308_640 mkl-service=2.4.0=py310h7f8727e_0 mkl_fft=1.3.1=py310hd6ae3a3_0 mkl_random=1.2.2=py310h00e6091_0 mpc=1.1.0=h10f8cd9_1 mpfr=4.0.2=hb69a4c5_1 mpmath=1.2.1=pypi_0 
msgpack=1.0.5=pypi_0 multidict=6.0.4=pypi_0 multiprocess=0.70.14=pypi_0 mypy-extensions=1.0.0=pypi_0 ncurses=6.4=h6a678d5_0 nest-asyncio=1.5.6=pypi_0 nettle=3.7.3=hbbd107a_1 networkx=2.8.4=py310h06a4308_1 nltk=3.8=pypi_0 numba=0.57.1=pypi_0 numpy=1.23.4=pypi_0 numpy-base=1.24.3=py310h8e6c178_0 oauthlib=3.2.2=pypi_0 omegaconf=2.2.3=pypi_0 onnx=1.14.0=pypi_0 openai=0.27.1=pypi_0 opencc=1.1.6=pypi_0 opencv-python=4.7.0.72=pypi_0 openh264=2.1.1=h4ff587b_0 openprompt=1.0.1=pypi_0 openssl=1.1.1w=h7f8727e_0 orjson=3.8.10=pypi_0 packaging=23.1=pypi_0 pandas=2.0.1=pypi_0 pangu=4.0.6.1=pypi_0 parameterized=0.9.0=pypi_0 parso=0.8.3=pypi_0 pathspec=0.11.1=pypi_0 pathtools=0.1.2=pypi_0 peft=0.5.0=pypi_0 pexpect=4.8.0=pypi_0 pickleshare=0.7.5=pypi_0 pillow=9.4.0=py310h6a678d5_0 pip=23.2.1=pypi_0 plac=1.3.5=pypi_0 platformdirs=3.4.0=pypi_0 pluggy=1.2.0=pypi_0 pooch=1.6.0=pypi_0 portalocker=2.7.0=pypi_0 progress=1.6=pypi_0 prompt-toolkit=3.0.38=pypi_0 protobuf=3.20.3=pypi_0 psutil=5.9.5=pypi_0 ptyprocess=0.7.0=pypi_0 pure-eval=0.2.2=pypi_0 pyannote-core=5.0.0=pypi_0 pyannote-database=5.0.1=pypi_0 pyannote-metrics=3.2.1=pypi_0 pyarrow=11.0.0=pypi_0 pyasn1=0.5.0=pypi_0 pyasn1-modules=0.3.0=pypi_0 pybind11=2.10.4=pypi_0 pybtex=0.24.0=pypi_0 pybtex-docutils=1.0.2=pypi_0 pycparser=2.21=pyhd3eb1b0_0 pydantic=1.10.7=pypi_0 pydeprecate=0.3.1=pypi_0 pydub=0.25.1=pypi_0 pygments=2.15.1=pypi_0 pynini=2.1.5=pypi_0 pyopenssl=23.0.0=py310h06a4308_0 pyparsing=3.0.9=pypi_0 pypinyin=0.49.0=pypi_0 pypinyin-dict=0.6.0=pypi_0 pyrsistent=0.19.3=pypi_0 pysocks=1.7.1=py310h06a4308_0 pytest=7.4.0=pypi_0 pytest-runner=6.0.0=pypi_0 python=3.10.10=h7a1cb2a_2 python-dateutil=2.8.2=pypi_0 python-multipart=0.0.6=pypi_0 pytorch=2.0.0=py3.10_cuda11.8_cudnn8.7.0_0 pytorch-cuda=11.8=h7e8668a_3 pytorch-lightning=1.9.4=pypi_0 pytorch-mutex=1.0=cuda pytz=2023.3=pypi_0 pyyaml=5.4.1=pypi_0 pyzmq=25.1.0=pypi_0 rank-bm25=0.2.2=pypi_0 rapidfuzz=2.13.7=pypi_0 readline=8.2=h5eee18b_0 regex=2023.3.23=pypi_0 
requests=2.28.1=py310h06a4308_1 requests-oauthlib=1.3.1=pypi_0 responses=0.18.0=pypi_0 rich=13.4.2=pypi_0 risparser=0.4.4=pypi_0 rouge=1.0.0=pypi_0 rouge-score=0.1.2=pypi_0 rsa=4.9=pypi_0 ruamel-yaml=0.17.32=pypi_0 ruamel-yaml-clib=0.2.7=pypi_0 s3transfer=0.6.1=pypi_0 sacrebleu=2.3.1=pypi_0 sacremoses=0.0.53=pypi_0 safetensors=0.3.1=pypi_0 scikit-learn=1.2.1=pypi_0 scipy=1.10.1=pypi_0 seaborn=0.12.2=pypi_0 semantic-version=2.10.0=pypi_0 sentence-transformers=2.2.2=pypi_0 sentencepiece=0.1.96=pypi_0 sentry-sdk=1.25.1=pypi_0 setproctitle=1.3.2=pypi_0 setuptools=65.5.1=pypi_0 shellingham=1.5.0.post1=pypi_0 six=1.16.0=pyhd3eb1b0_1 smmap=5.0.0=pypi_0 sniffio=1.3.0=pypi_0 snowballstemmer=2.2.0=pypi_0 sortedcontainers=2.4.0=pypi_0 soundfile=0.12.1=pypi_0 soupsieve=2.4.1=pypi_0 sox=1.4.1=pypi_0 soxr=0.3.5=pypi_0 sphinx=7.0.1=pypi_0 sphinxcontrib-applehelp=1.0.4=pypi_0 sphinxcontrib-bibtex=2.5.0=pypi_0 sphinxcontrib-devhelp=1.0.2=pypi_0 sphinxcontrib-htmlhelp=2.0.1=pypi_0 sphinxcontrib-jsmath=1.0.1=pypi_0 sphinxcontrib-qthelp=1.0.3=pypi_0 sphinxcontrib-serializinghtml=1.1.5=pypi_0 sqlite=3.41.1=h5eee18b_0 stack-data=0.6.2=pypi_0 starlette=0.26.1=pypi_0 sympy=1.11.1=py310h06a4308_0 tabulate=0.9.0=pypi_0 taming-transformers=0.0.1=pypi_0 taming-transformers-rom1504=0.0.6=pypi_0 tensorboard=2.12.2=pypi_0 tensorboard-data-server=0.7.0=pypi_0 tensorboard-plugin-wit=1.8.1=pypi_0 tensorboardx=2.6.2.2=pypi_0 termcolor=2.2.0=pypi_0 test-tube=0.7.5=pypi_0 text-unidecode=1.3=pypi_0 textdistance=4.5.0=pypi_0 texterrors=0.4.4=pypi_0 threadpoolctl=3.1.0=pypi_0 tk=8.6.12=h1ccaba5_0 tokenize-rt=5.0.0=pypi_0 tokenizers=0.13.3=pypi_0 toml=0.10.2=pypi_0 tomli=2.0.1=pypi_0 toolz=0.12.0=pypi_0 torchaudio=2.0.0=py310_cu118 torchmetrics=0.11.4=pypi_0 torchtriton=2.0.0=py310 torchvision=0.15.0=py310_cu118 tornado=6.3.2=pypi_0 tqdm=4.64.1=pypi_0 traitlets=5.9.0=pypi_0 transformers=4.33.3=pypi_0 trl=0.7.1=pypi_0 tweet-preprocessor=0.6.0=pypi_0 typed-ast=1.5.4=pypi_0 typer=0.9.0=pypi_0 
typing_extensions=4.4.0=py310h06a4308_0 tzdata=2023.3=pypi_0 uc-micro-py=1.0.1=pypi_0 unidecode=1.3.7=pypi_0 urllib3=1.26.15=py310h06a4308_0 uvicorn=0.21.1=pypi_0 wandb=0.15.4=pypi_0 wcwidth=0.2.6=pypi_0 webdataset=0.1.62=pypi_0 websockets=11.0.2=pypi_0 werkzeug=2.3.0=pypi_0 wget=3.2=pypi_0 wheel=0.38.4=py310h06a4308_0 widgetsnbextension=4.0.7=pypi_0 wrapt=1.15.0=pypi_0 xxhash=3.2.0=pypi_0 xz=5.2.10=h5eee18b_1 yacs=0.1.8=pypi_0 yarl=1.9.2=pypi_0 youtokentome=1.0.6=pypi_0 zlib=1.2.13=h5eee18b_0 zstd=1.5.4=hc292b87_0 ================================================ FILE: train_ranker.py ================================================ import os import torch os.environ['TOKENIZERS_PARALLELISM'] = 'false' import argparse from datasets import DATASETS from config import * from model import * from dataloader import * from trainer import * from transformers import BitsAndBytesConfig from pytorch_lightning import seed_everything from model import LlamaForCausalLM from peft import ( LoraConfig, get_peft_model, get_peft_model_state_dict, prepare_model_for_int8_training, prepare_model_for_kbit_training, ) try: os.environ['WANDB_PROJECT'] = PROJECT_NAME except: print('WANDB_PROJECT not available, please set it in config.py') def main(args, export_root=None): seed_everything(args.seed) if export_root == None: export_root = EXPERIMENT_ROOT + '/' + args.llm_base_model.split('/')[-1] + '/' + args.dataset_code train_loader, val_loader, test_loader, tokenizer, test_retrieval = dataloader_factory(args) bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) model = LlamaForCausalLM.from_pretrained( args.llm_base_model, quantization_config=bnb_config, device_map='auto', cache_dir=args.llm_cache_dir, ) model.gradient_checkpointing_enable() model = prepare_model_for_kbit_training(model) config = LoraConfig( r=args.lora_r, lora_alpha=args.lora_alpha, target_modules=args.lora_target_modules, 
lora_dropout=args.lora_dropout, bias='none', task_type="CAUSAL_LM", ) model = get_peft_model(model, config) model.print_trainable_parameters() model.config.use_cache = False trainer = LLMTrainer(args, model, train_loader, val_loader, test_loader, tokenizer, export_root, args.use_wandb) trainer.train() trainer.test(test_retrieval) if __name__ == "__main__": args.model_code = 'llm' set_template(args) main(args, export_root=None) ================================================ FILE: train_retriever.py ================================================ import os import torch os.environ['TOKENIZERS_PARALLELISM'] = 'false' import wandb import argparse from config import * from model import * from dataloader import * from trainer import * from pytorch_lightning import seed_everything try: os.environ['WANDB_PROJECT'] = PROJECT_NAME except: print('WANDB_PROJECT not available, please set it in config.py') def main(args, export_root=None): seed_everything(args.seed) train_loader, val_loader, test_loader = dataloader_factory(args) model = LRURec(args) if export_root == None: export_root = EXPERIMENT_ROOT + '/' + args.model_code + '/' + args.dataset_code trainer = LRUTrainer(args, model, train_loader, val_loader, test_loader, export_root, args.use_wandb) trainer.train() trainer.test() # the next line generates val / test candidates for reranking trainer.generate_candidates(os.path.join(export_root, 'retrieved.pkl')) if __name__ == "__main__": args.model_code = 'lru' set_template(args) main(args, export_root=None) # # searching best hyperparameters # for decay in [0, 0.01]: # for dropout in [0, 0.1, 0.2, 0.3, 0.4, 0.5]: # args.weight_decay = decay # args.bert_dropout = dropout # args.bert_attn_dropout = dropout # export_root = EXPERIMENT_ROOT + '/' + args.model_code + '/' + args.dataset_code + '/' + str(decay) + '_' + str(dropout) # main(args, export_root=export_root) ================================================ FILE: trainer/__init__.py 
================================================ from .lru import * from .llm import * from .utils import * ================================================ FILE: trainer/base.py ================================================ from model import * from config import * from .utils import * from .loggers import * import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.optim.lr_scheduler import LambdaLR from tqdm import tqdm import json import numpy as np from abc import ABCMeta from pathlib import Path from collections import OrderedDict class BaseTrainer(metaclass=ABCMeta): def __init__(self, args, model, train_loader, val_loader, test_loader, export_root, use_wandb=True): self.args = args self.device = args.device self.model = model.to(self.device) self.num_epochs = args.num_epochs self.metric_ks = args.metric_ks self.best_metric = args.best_metric self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader self.optimizer = self._create_optimizer() if args.enable_lr_schedule: if args.enable_lr_warmup: self.lr_scheduler = self.get_linear_schedule_with_warmup( self.optimizer, args.warmup_steps, len(self.train_loader) * self.num_epochs) else: self.lr_scheduler = optim.lr_scheduler.StepLR( self.optimizer, step_size=args.decay_step, gamma=args.gamma) self.export_root = export_root if not os.path.exists(self.export_root): Path(self.export_root).mkdir(parents=True) self.use_wandb = use_wandb if use_wandb: import wandb wandb.init( name=self.args.model_code+'_'+self.args.dataset_code, project=PROJECT_NAME, config=args, ) writer = wandb else: from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter( log_dir=Path(self.export_root).joinpath('logs'), comment=self.args.model_code+'_'+self.args.dataset_code, ) self.val_loggers, self.test_loggers = self._create_loggers() self.logger_service = LoggerService( self.args, writer, self.val_loggers, self.test_loggers, use_wandb) 
print(args) def train(self): accum_iter = 0 self.exit_training = self.validate(0, accum_iter) for epoch in range(self.num_epochs): accum_iter = self.train_one_epoch(epoch, accum_iter) if self.args.val_strategy == 'epoch': self.exit_training = self.validate(epoch, accum_iter) # val after every epoch if self.exit_training: print('Early stopping triggered. Exit training') break self.logger_service.complete() def train_one_epoch(self, epoch, accum_iter): average_meter_set = AverageMeterSet() tqdm_dataloader = tqdm(self.train_loader) for batch_idx, batch in enumerate(tqdm_dataloader): self.model.train() batch = self.to_device(batch) self.optimizer.zero_grad() loss = self.calculate_loss(batch) loss.backward() self.clip_gradients(self.args.max_grad_norm) self.optimizer.step() if self.args.enable_lr_schedule: self.lr_scheduler.step() average_meter_set.update('loss', loss.item()) tqdm_dataloader.set_description( 'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg)) accum_iter += 1 if self.args.val_strategy == 'iteration' and accum_iter % self.args.val_iterations == 0: self.exit_training = self.validate(epoch, accum_iter) # val after certain iterations if self.exit_training: break return accum_iter def validate(self, epoch, accum_iter): self.model.eval() average_meter_set = AverageMeterSet() with torch.no_grad(): tqdm_dataloader = tqdm(self.val_loader) for batch_idx, batch in enumerate(tqdm_dataloader): batch = self.to_device(batch) metrics = self.calculate_metrics(batch, exclude_history=False) # faster validation self._update_meter_set(average_meter_set, metrics) self._update_dataloader_metrics( tqdm_dataloader, average_meter_set) log_data = { 'state_dict': (self._create_state_dict()), 'epoch': epoch+1, 'accum_iter': accum_iter, } log_data.update(average_meter_set.averages()) return self.logger_service.log_val(log_data) # early stopping def test(self, epoch=-1, accum_iter=-1, save_name=None): print('******************** Testing Best Model 
********************') best_model_dict = torch.load(os.path.join( self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY) self.model.load_state_dict(best_model_dict) self.model.eval() average_meter_set = AverageMeterSet() with torch.no_grad(): tqdm_dataloader = tqdm(self.test_loader) for batch_idx, batch in enumerate(tqdm_dataloader): batch = self.to_device(batch) metrics = self.calculate_metrics(batch) self._update_meter_set(average_meter_set, metrics) self._update_dataloader_metrics( tqdm_dataloader, average_meter_set) log_data = { 'state_dict': (self._create_state_dict()), 'epoch': epoch+1, 'accum_iter': accum_iter, } average_metrics = average_meter_set.averages() log_data.update(average_metrics) self.logger_service.log_test(log_data) print('******************** Testing Metrics ********************') print(average_metrics) file_name = 'test_metrics.json' if save_name is None else save_name with open(os.path.join(self.export_root, file_name), 'w') as f: json.dump(average_metrics, f, indent=4) return average_metrics def to_device(self, batch): return [x.to(self.device) for x in batch] @abstractmethod def calculate_loss(self, batch): pass @abstractmethod def calculate_metrics(self, batch): pass def clip_gradients(self, limit=1.0): nn.utils.clip_grad_norm_(self.model.parameters(), limit) def _update_meter_set(self, meter_set, metrics): for k, v in metrics.items(): meter_set.update(k, v) def _update_dataloader_metrics(self, tqdm_dataloader, meter_set): description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3] ] + ['Recall@%d' % k for k in self.metric_ks[:3]] description = 'Eval: ' + \ ', '.join(s + ' {:.4f}' for s in description_metrics) description = description.replace('NDCG', 'N').replace('Recall', 'R') description = description.format( *(meter_set[k].avg for k in description_metrics)) tqdm_dataloader.set_description(description) def _create_optimizer(self): args = self.args param_optimizer = list(self.model.named_parameters()) no_decay = 
['bias', 'layer_norm'] optimizer_grouped_parameters = [ { 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, }, {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.}, ] if args.optimizer.lower() == 'adamw': return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon) elif args.optimizer.lower() == 'adam': return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay) else: raise NotImplementedError def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) return max( 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) ) return LambdaLR(optimizer, lr_lambda, last_epoch) def _create_loggers(self): root = Path(self.export_root) model_checkpoint = root.joinpath('models') val_loggers, test_loggers = [], [] for k in self.metric_ks: val_loggers.append( MetricGraphPrinter(key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation', use_wandb=self.use_wandb)) val_loggers.append( MetricGraphPrinter(key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation', use_wandb=self.use_wandb)) val_loggers.append( MetricGraphPrinter(key='MRR@%d' % k, graph_name='MRR@%d' % k, group_name='Validation', use_wandb=self.use_wandb)) val_loggers.append(RecentModelLogger(self.args, model_checkpoint)) val_loggers.append(BestModelLogger(self.args, model_checkpoint, metric_key=self.best_metric)) for k in self.metric_ks: test_loggers.append( MetricGraphPrinter(key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Test', use_wandb=self.use_wandb)) test_loggers.append( MetricGraphPrinter(key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Test', use_wandb=self.use_wandb)) 
test_loggers.append( MetricGraphPrinter(key='MRR@%d' % k, graph_name='MRR@%d' % k, group_name='Test', use_wandb=self.use_wandb)) return val_loggers, test_loggers def _create_state_dict(self): return { STATE_DICT_KEY: self.model.state_dict(), OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(), } ================================================ FILE: trainer/llm.py ================================================ from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY from .verb import ManualVerbalizer from .utils import * from .loggers import * from .base import * import re import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import json import numpy as np from abc import * from pathlib import Path import bitsandbytes as bnb from transformers.trainer import * from transformers import Trainer, TrainingArguments, EarlyStoppingCallback def llama_collate_fn_w_truncation(llm_max_length, eval=False): def llama_collate_fn(batch): all_input_ids = [] all_attention_mask = [] all_labels = [] example_max_length = max([len(batch[idx]['input_ids']) for idx in range(len(batch))]) max_length = min(llm_max_length, example_max_length) for i in range(len(batch)): input_ids = batch[i]['input_ids'] attention_mask = batch[i]['attention_mask'] labels = batch[i]['labels'] if len(input_ids) > max_length: input_ids = input_ids[-max_length:] attention_mask = attention_mask[-max_length:] if not eval: labels = labels[-max_length:] elif len(input_ids) < max_length: padding_length = max_length - len(input_ids) input_ids = [0] * padding_length + input_ids attention_mask = [0] * padding_length + attention_mask if not eval: labels = [-100] * padding_length + labels if eval: assert input_ids[-1] == 13 else: assert input_ids[-3] == 13 and input_ids[-1] == 2 assert labels[-3] == -100 and labels[-2] != -100 all_input_ids.append(torch.tensor(input_ids).long()) all_attention_mask.append(torch.tensor(attention_mask).long()) 
all_labels.append(torch.tensor(labels).long()) return { 'input_ids': torch.vstack(all_input_ids), 'attention_mask': torch.vstack(all_attention_mask), 'labels': torch.vstack(all_labels) } return llama_collate_fn def compute_metrics_for_ks(ks, verbalizer): def compute_metrics(eval_pred): logits, labels = eval_pred logits = torch.tensor(logits) labels = torch.tensor(labels).view(-1) scores = verbalizer.process_logits(logits) metrics = absolute_recall_mrr_ndcg_for_ks(scores, labels, ks) return metrics return compute_metrics class LLMTrainer(Trainer): def __init__( self, args, model, train_loader, val_loader, test_loader, tokenizer, export_root, use_wandb, **kwargs ): self.original_args = args self.export_root = export_root self.use_wandb = use_wandb self.llm_max_text_len = args.llm_max_text_len self.rerank_metric_ks = args.rerank_metric_ks self.verbalizer = ManualVerbalizer( tokenizer=tokenizer, prefix='', post_log_softmax=False, classes=list(range(args.llm_negative_sample_size+1)), label_words={i: chr(ord('A')+i) for i in range(args.llm_negative_sample_size+1)}, ) hf_args = TrainingArguments( per_device_train_batch_size=args.lora_micro_batch_size, gradient_accumulation_steps=args.train_batch_size//args.lora_micro_batch_size, warmup_steps=args.warmup_steps, num_train_epochs=args.lora_num_epochs, learning_rate=args.lora_lr, bf16=True, logging_steps=10, optim="paged_adamw_32bit", evaluation_strategy="steps", save_strategy="steps", eval_steps=args.lora_val_iterations, save_steps=args.lora_val_iterations, output_dir=export_root, save_total_limit=3, load_best_model_at_end=True, ddp_find_unused_parameters=None, group_by_length=False, report_to="wandb" if use_wandb else None, run_name=args.model_code+'_'+args.dataset_code if use_wandb else None, metric_for_best_model=args.rerank_best_metric, greater_is_better=True, ) super().__init__( model=model, args=hf_args, callbacks=[EarlyStoppingCallback(args.lora_early_stopping_patience)], **kwargs) # hf_args is now args 
self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader self.tokenizer = tokenizer self.train_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=False) self.val_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=True) self.test_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=True) self.compute_metrics = compute_metrics_for_ks(self.rerank_metric_ks, self.verbalizer) if len(self.label_names) == 0: self.label_names = ['labels'] # for some reason label name is not set def test(self, test_retrieval): average_metrics = self.predict(test_dataset=None).metrics print('Ranking Performance on Subset:', average_metrics) print('************************************************************') with open(os.path.join(self.export_root, 'subset_metrics.json'), 'w') as f: json.dump(average_metrics, f, indent=4) print('Original Performance:', test_retrieval['original_metrics']) print('************************************************************') original_size = test_retrieval['original_size'] retrieval_size = test_retrieval['retrieval_size'] overall_metrics = {} for key in test_retrieval['non_retrieval_metrics'].keys(): if 'test_' + key in average_metrics: overall_metrics['test_' + key] = (average_metrics['test_' + key] * retrieval_size + \ test_retrieval['non_retrieval_metrics'][key] * (original_size - retrieval_size)) / original_size print('Overall Performance of Our Framework:', overall_metrics) with open(os.path.join(self.export_root, 'overall_metrics.json'), 'w') as f: json.dump(overall_metrics, f, indent=4) return average_metrics def get_train_dataloader(self): return self.train_loader def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: return self.val_loader def get_test_dataloader(self, test_dataset: Optional[Dataset] = None) -> DataLoader: return self.test_loader ================================================ FILE: 
trainer/loggers.py ================================================ import os import torch from abc import ABCMeta, abstractmethod def save_state_dict(state_dict, path, filename): torch.save(state_dict, os.path.join(path, filename)) class LoggerService(object): def __init__(self, args, writer, val_loggers, test_loggers, use_wandb): self.args = args self.writer = writer self.val_loggers = val_loggers if val_loggers else [] self.test_loggers = test_loggers if test_loggers else [] self.use_wandb = use_wandb def complete(self): if self.use_wandb: self.writer.finish() else: self.writer.close() def log_val(self, log_data): criteria_met = False for logger in self.val_loggers: logger.log(self.writer, **log_data) if self.args.early_stopping and isinstance(logger, BestModelLogger): criteria_met = logger.patience_counter >= self.args.early_stopping_patience return criteria_met def log_test(self, log_data): for logger in self.test_loggers: logger.log(self.writer, **log_data) class AbstractBaseLogger(metaclass=ABCMeta): @abstractmethod def log(self, *args, **kwargs): raise NotImplementedError def complete(self, *args, **kwargs): pass class MetricGraphPrinter(AbstractBaseLogger): def __init__(self, key, graph_name, group_name, use_wandb): self.key = key self.graph_label = graph_name self.group_name = group_name self.use_wandb = use_wandb def log(self, writer, *args, **kwargs): if self.key in kwargs: if self.use_wandb: writer.log({self.group_name+'/'+self.graph_label: kwargs[self.key], 'batch': kwargs['accum_iter']}) else: writer.add_scalar(self.group_name+'/'+ self.graph_label, kwargs[self.key], kwargs['accum_iter']) else: print('Metric {} not found...'.format(self.key)) def complete(self, writer, *args, **kwargs): self.log(writer, *args, **kwargs) class RecentModelLogger(AbstractBaseLogger): def __init__(self, args, checkpoint_path, filename='checkpoint-recent.pth'): self.args = args self.checkpoint_path = checkpoint_path if not os.path.exists(self.checkpoint_path): 
self.checkpoint_path.mkdir(parents=True) self.recent_epoch = None self.filename = filename def log(self, *args, **kwargs): epoch = kwargs['epoch'] if self.recent_epoch != epoch: self.recent_epoch = epoch state_dict = kwargs['state_dict'] state_dict['epoch'] = kwargs['epoch'] save_state_dict(state_dict, self.checkpoint_path, self.filename) def complete(self, *args, **kwargs): save_state_dict(kwargs['state_dict'], self.checkpoint_path, self.filename + '.final') class BestModelLogger(AbstractBaseLogger): def __init__(self, args, checkpoint_path, metric_key, filename='best_acc_model.pth'): self.args = args self.checkpoint_path = checkpoint_path if not os.path.exists(self.checkpoint_path): self.checkpoint_path.mkdir(parents=True) self.best_metric = 0. self.metric_key = metric_key self.filename = filename self.patience_counter = 0 def log(self, *args, **kwargs): current_metric = kwargs[self.metric_key] if self.best_metric < current_metric: # assumes the higher the better print("Update Best {} Model at {}".format( self.metric_key, kwargs['epoch'])) self.best_metric = current_metric save_state_dict(kwargs['state_dict'], self.checkpoint_path, self.filename) if self.args.early_stopping: self.patience_counter = 0 elif self.args.early_stopping: self.patience_counter += 1 ================================================ FILE: trainer/lru.py ================================================ from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY from .utils import * from .loggers import * from .base import * import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import json import pickle import numpy as np from abc import * from pathlib import Path class LRUTrainer(BaseTrainer): def __init__(self, args, model, train_loader, val_loader, test_loader, export_root, use_wandb): super().__init__(args, model, train_loader, val_loader, test_loader, export_root, use_wandb) self.ce = nn.CrossEntropyLoss(ignore_index=0) def calculate_loss(self, 
batch): seqs, labels = batch logits = self.model(seqs) logits = logits.view(-1, logits.size(-1)) labels = labels.view(-1) loss = self.ce(logits, labels) return loss def calculate_metrics(self, batch, exclude_history=True): seqs, labels = batch scores = self.model(seqs)[:, -1, :] B, L = seqs.shape if exclude_history: for i in range(L): scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9 scores[:, 0] = -1e9 # padding metrics = absolute_recall_mrr_ndcg_for_ks(scores, labels.view(-1), self.metric_ks) return metrics def generate_candidates(self, retrieved_data_path): self.model.eval() val_probs, val_labels = [], [] test_probs, test_labels = [], [] with torch.no_grad(): print('*************** Generating Candidates for Validation Set ***************') tqdm_dataloader = tqdm(self.val_loader) for batch_idx, batch in enumerate(tqdm_dataloader): batch = self.to_device(batch) seqs, labels = batch scores = self.model(seqs)[:, -1, :] B, L = seqs.shape for i in range(L): scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9 scores[:, 0] = -1e9 # padding val_probs.extend(scores.tolist()) val_labels.extend(labels.view(-1).tolist()) val_metrics = absolute_recall_mrr_ndcg_for_ks(torch.tensor(val_probs), torch.tensor(val_labels).view(-1), self.metric_ks) print(val_metrics) print('****************** Generating Candidates for Test Set ******************') tqdm_dataloader = tqdm(self.test_loader) for batch_idx, batch in enumerate(tqdm_dataloader): batch = self.to_device(batch) seqs, labels = batch scores = self.model(seqs)[:, -1, :] B, L = seqs.shape for i in range(L): scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9 scores[:, 0] = -1e9 # padding test_probs.extend(scores.tolist()) test_labels.extend(labels.view(-1).tolist()) test_metrics = absolute_recall_mrr_ndcg_for_ks(torch.tensor(test_probs), torch.tensor(test_labels).view(-1), self.metric_ks) print(test_metrics) with open(retrieved_data_path, 'wb') as f: pickle.dump({'val_probs': val_probs, 'val_labels': val_labels, 
'val_metrics': val_metrics, 'test_probs': test_probs, 'test_labels': test_labels, 'test_metrics': test_metrics}, f) ================================================ FILE: trainer/utils.py ================================================ from config import * import json import os import pprint as pp import random from datetime import date from pathlib import Path import numpy as np import torch import torch.nn.functional as F import torch.backends.cudnn as cudnn from torch import optim as optim def ndcg(scores, labels, k): scores = scores.cpu() labels = labels.cpu() rank = (-scores).argsort(dim=1) cut = rank[:, :k] hits = labels.gather(1, cut) position = torch.arange(2, 2+k) weights = 1 / torch.log2(position.float()) dcg = (hits.float() * weights).sum(1) idcg = torch.Tensor([weights[:min(int(n), k)].sum() for n in labels.sum(1)]) ndcg = dcg / idcg return ndcg.mean() def absolute_recall_mrr_ndcg_for_ks(scores, labels, ks): metrics = {} labels = F.one_hot(labels, num_classes=scores.size(1)) answer_count = labels.sum(1) labels_float = labels.float() rank = (-scores).argsort(dim=1) cut = rank for k in sorted(ks, reverse=True): cut = cut[:, :k] hits = labels_float.gather(1, cut) metrics['Recall@%d' % k] = \ (hits.sum(1) / torch.min(torch.Tensor([k]).to( labels.device), labels.sum(1).float())).mean().cpu().item() metrics['MRR@%d' % k] = \ (hits / torch.arange(1, k+1).unsqueeze(0).to( labels.device)).sum(1).mean().cpu().item() position = torch.arange(2, 2+k) weights = 1 / torch.log2(position.float()) dcg = (hits * weights.to(hits.device)).sum(1) idcg = torch.Tensor([weights[:min(int(n), k)].sum() for n in answer_count]).to(dcg.device) ndcg = (dcg / idcg).mean() metrics['NDCG@%d' % k] = ndcg.cpu().item() return metrics class AverageMeterSet(object): def __init__(self, meters=None): self.meters = meters if meters else {} def __getitem__(self, key): if key not in self.meters: meter = AverageMeter() meter.update(0) return meter return self.meters[key] def update(self, name, 
value, n=1): if name not in self.meters: self.meters[name] = AverageMeter() self.meters[name].update(value, n) def reset(self): for meter in self.meters.values(): meter.reset() def values(self, format_string='{}'): return {format_string.format(name): meter.val for name, meter in self.meters.items()} def averages(self, format_string='{}'): return {format_string.format(name): meter.avg for name, meter in self.meters.items()} def sums(self, format_string='{}'): return {format_string.format(name): meter.sum for name, meter in self.meters.items()} def counts(self, format_string='{}'): return {format_string.format(name): meter.count for name, meter in self.meters.items()} class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val self.count += n self.avg = self.sum / self.count def __format__(self, format): return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format) ================================================ FILE: trainer/verb.py ================================================ from abc import abstractmethod import json from transformers.file_utils import ModelOutput from transformers.data.processors.utils import InputFeatures import torch import torch.nn as nn import torch.nn.functional as F from yacs.config import CfgNode from transformers.tokenization_utils import PreTrainedTokenizer import numpy as np from collections import namedtuple import inspect from typing import * _VALID_TYPES = {tuple, list, str, int, float, bool, type(None)} def convert_cfg_to_dict(cfg_node, key_list=[]): """ Convert a config node to dictionary """ if not isinstance(cfg_node, CfgNode): if type(cfg_node) not in _VALID_TYPES: print("Key {} with value {} is not a valid type; valid types: {}".format( ".".join(key_list), type(cfg_node), 
_VALID_TYPES), ) return cfg_node else: cfg_dict = dict(cfg_node) for k, v in cfg_dict.items(): cfg_dict[k] = convert_cfg_to_dict(v, key_list + [k]) return cfg_dict def signature(f): r"""Get the function f 's input arguments. A useful gadget when some function slot might be instantiated into multiple functions. Args: f (:obj:`function`) : the function to get the input arguments. Returns: namedtuple : of args, default, varargs, keywords, respectively.s """ sig = inspect.signature(f) args = [ p.name for p in sig.parameters.values() if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD ] varargs = [ p.name for p in sig.parameters.values() if p.kind == inspect.Parameter.VAR_POSITIONAL ] varargs = varargs[0] if varargs else None keywords = [ p.name for p in sig.parameters.values() if p.kind == inspect.Parameter.VAR_KEYWORD ] keywords = keywords[0] if keywords else None defaults = [ p.default for p in sig.parameters.values() if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and p.default is not p.empty ] or None argspec = namedtuple('Signature', ['args', 'defaults', 'varargs', 'keywords']) return argspec(args, defaults, varargs, keywords) class Verbalizer(nn.Module): r''' Base class for all the verbalizers. Args: tokenizer (:obj:`PreTrainedTokenizer`): A tokenizer to appoint the vocabulary and the tokenization strategy. classes (:obj:`Sequence[str]`): A sequence of classes that need to be projected. ''' def __init__(self, tokenizer: Optional[PreTrainedTokenizer] = None, classes: Optional[Sequence[str]] = None, num_classes: Optional[int] = None, ): super().__init__() self.tokenizer = tokenizer self.classes = classes if classes is not None and num_classes is not None: assert len(classes) == num_classes, "len(classes) != num_classes, Check you config." 
self.num_classes = num_classes elif num_classes is not None: self.num_classes = num_classes elif classes is not None: self.num_classes = len(classes) else: self.num_classes = None # raise AttributeError("No able to configure num_classes") self._in_on_label_words_set = False @property def label_words(self,): r''' Label words means the words in the vocabulary projected by the labels. E.g. if we want to establish a projection in sentiment classification: positive :math:`\rightarrow` {`wonderful`, `good`}, in this case, `wonderful` and `good` are label words. ''' if not hasattr(self, "_label_words"): raise RuntimeError("label words haven't been set.") return self._label_words @label_words.setter def label_words(self, label_words): if label_words is None: return self._label_words = self._match_label_words_to_label_ids(label_words) if not self._in_on_label_words_set: self.safe_on_label_words_set() def _match_label_words_to_label_ids(self, label_words): # TODO newly add function after docs written # TODO rename this function """ sort label words dict of verbalizer to match the label order of the classes """ if isinstance(label_words, dict): if self.classes is None: raise ValueError(""" classes attribute of the Verbalizer should be set since your given label words is a dict. 
Since we will match the label word with respect to class A, to A's index in classes """) if set(label_words.keys()) != set(self.classes): raise ValueError("name of classes in verbalizer are different from those of dataset") label_words = [ # sort the dict to match dataset label_words[c] for c in self.classes ] # length: label_size of the whole task elif isinstance(label_words, list) or isinstance(label_words, tuple): pass else: raise ValueError("Verbalizer label words must be list, tuple or dict") return label_words def safe_on_label_words_set(self,): self._in_on_label_words_set = True self.on_label_words_set() self._in_on_label_words_set = False def on_label_words_set(self,): r"""A hook to do something when textual label words were set. """ pass @property def vocab(self,) -> Dict: if not hasattr(self, '_vocab'): self._vocab = self.tokenizer.convert_ids_to_tokens(np.arange(self.vocab_size).tolist()) return self._vocab @property def vocab_size(self,) -> int: return self.tokenizer.vocab_size @abstractmethod def generate_parameters(self, **kwargs) -> List: r""" The verbalizer can be seen as an extra layer on top of the original pre-trained models. In manual verbalizer, it is a fixed one-hot vector of dimension ``vocab_size``, with the position of the label word being 1 and 0 everywhere else. In other situation, the parameters may be a continuous vector over the vocab, with each dimension representing a weight of that token. Moreover, the parameters may be set to trainable to allow label words selection. Therefore, this function serves as an abstract methods for generating the parameters of the verbalizer, and must be instantiated in any derived class. Note that the parameters need to be registered as a part of pytorch's module to It can be achieved by wrapping a tensor using ``nn.Parameter()``. 
""" raise NotImplementedError def register_calibrate_logits(self, logits: torch.Tensor): r""" This function aims to register logits that need to be calibrated, and detach the original logits from the current graph. """ if logits.requires_grad: logits = logits.detach() self._calibrate_logits = logits def process_outputs(self, outputs: torch.Tensor, batch: Union[Dict, InputFeatures], **kwargs): r"""By default, the verbalizer will process the logits of the PLM's output. Args: logits (:obj:`torch.Tensor`): The current logits generated by pre-trained language models. batch (:obj:`Union[Dict, InputFeatures]`): The input features of the data. """ return self.process_logits(outputs, batch=batch, **kwargs) def gather_outputs(self, outputs: ModelOutput): r""" retrieve useful output for the verbalizer from the whole model output By default, it will only retrieve the logits Args: outputs (:obj:`ModelOutput`) The output from the pretrained language model. Return: :obj:`torch.Tensor` The gathered output, should be of shape (``batch_size``, ``seq_len``, ``any``) """ return outputs.logits @staticmethod def aggregate(label_words_logits: torch.Tensor) -> torch.Tensor: r""" To aggregate logits on multiple label words into the label's logits Basic aggregator: mean of each label words' logits to a label's logits Can be re-implemented in advanced verbaliezer. Args: label_words_logits (:obj:`torch.Tensor`): The logits of the label words only. Return: :obj:`torch.Tensor`: The final logits calculated by the label words. """ if label_words_logits.dim()>2: return label_words_logits.mean(dim=-1) else: return label_words_logits def normalize(self, logits: torch.Tensor) -> torch.Tensor: r""" Given logits regarding the entire vocab, calculate the probs over the label words set by softmax. Args: logits(:obj:`Tensor`): The logits of the entire vocab. Returns: :obj:`Tensor`: The probability distribution over the label words set. 
""" batch_size = logits.shape[0] return F.softmax(logits.reshape(batch_size, -1), dim=-1).reshape(*logits.shape) @abstractmethod def project(self, logits: torch.Tensor, **kwargs) -> torch.Tensor: r"""This method receives input logits of shape ``[batch_size, vocab_size]``, and use the parameters of this verbalizer to project the logits over entire vocab into the logits of labels words. Args: logits (:obj:`Tensor`): The logits over entire vocab generated by the pre-trained language model with shape [``batch_size``, ``max_seq_length``, ``vocab_size``] Returns: :obj:`Tensor`: The normalized probs (sum to 1) of each label . """ raise NotImplementedError def handle_multi_token(self, label_words_logits, mask): r""" Support multiple methods to handle the multi tokens produced by the tokenizer. We suggest using 'first' or 'max' if the some parts of the tokenization is not meaningful. Can broadcast to 3-d tensor. Args: label_words_logits (:obj:`torch.Tensor`): Returns: :obj:`torch.Tensor` """ if self.multi_token_handler == "first": label_words_logits = label_words_logits.select(dim=-1, index=0) elif self.multi_token_handler == "max": label_words_logits = label_words_logits - 1000*(1-mask.unsqueeze(0)) label_words_logits = label_words_logits.max(dim=-1).values elif self.multi_token_handler == "mean": label_words_logits = (label_words_logits*mask.unsqueeze(0)).sum(dim=-1)/(mask.unsqueeze(0).sum(dim=-1)+1e-15) else: raise ValueError("multi_token_handler {} not configured".format(self.multi_token_handler)) return label_words_logits @classmethod def from_config(cls, config: CfgNode, **kwargs): r"""load a verbalizer from verbalizer's configuration node. Args: config (:obj:`CfgNode`): the sub-configuration of verbalizer, i.e. ``config[config.verbalizer]`` if config is a global config node. kwargs: Other kwargs that might be used in initialize the verbalizer. The actual value should match the arguments of ``__init__`` functions. 
""" init_args = signature(cls.__init__).args _init_dict = {**convert_cfg_to_dict(config), **kwargs} if config is not None else kwargs init_dict = {key: _init_dict[key] for key in _init_dict if key in init_args} verbalizer = cls(**init_dict) if hasattr(verbalizer, "from_file"): if not hasattr(config, "file_path"): pass else: if (not hasattr(config, "label_words") or config.label_words is None) and config.file_path is not None: if config.choice is None: config.choice = 0 verbalizer.from_file(config.file_path, config.choice) elif (hasattr(config, "label_words") and config.label_words is not None) and config.file_path is not None: raise RuntimeError("The text can't be both set from `text` and `file_path`.") return verbalizer def from_file(self, path: str, choice: Optional[int] = 0 ): r"""Load the predefined label words from verbalizer file. Currently support three types of file format: 1. a .jsonl or .json file, in which is a single verbalizer in dict format. 2. a .jsonal or .json file, in which is a list of verbalizers in dict format 3. a .txt or a .csv file, in which is the label words of a class are listed in line, separated by commas. Begin a new verbalizer by an empty line. This format is recommended when you don't know the name of each class. The details of verbalizer format can be seen in :ref:`How_to_write_a_verbalizer`. Args: path (:obj:`str`): The path of the local template file. choice (:obj:`int`): The choice of verbalizer in a file containing multiple verbalizers. 
Returns: Template : `self` object """ if path.endswith(".txt") or path.endswith(".csv"): with open(path, 'r') as f: lines = f.readlines() label_words_all = [] label_words_single_group = [] for line in lines: line = line.strip().strip(" ") if line == "": if len(label_words_single_group)>0: label_words_all.append(label_words_single_group) label_words_single_group = [] else: label_words_single_group.append(line) if len(label_words_single_group) > 0: # if no empty line in the last label_words_all.append(label_words_single_group) if choice >= len(label_words_all): raise RuntimeError("choice {} exceed the number of verbalizers {}" .format(choice, len(label_words_all))) label_words = label_words_all[choice] label_words = [label_words_per_label.strip().split(",") \ for label_words_per_label in label_words] elif path.endswith(".jsonl") or path.endswith(".json"): with open(path, "r") as f: label_words_all = json.load(f) # if it is a file containing multiple verbalizers if isinstance(label_words_all, list): if choice >= len(label_words_all): raise RuntimeError("choice {} exceed the number of verbalizers {}" .format(choice, len(label_words_all))) label_words = label_words_all[choice] elif isinstance(label_words_all, dict): label_words = label_words_all if choice>0: print("Choice of verbalizer is 1, but the file \ only contains one verbalizer.") self.label_words = label_words if self.num_classes is not None: num_classes = len(self.label_words) assert num_classes==self.num_classes, 'number of classes in the verbalizer file\ does not match the predefined num_classes.' return self class ManualVerbalizer(Verbalizer): r""" The basic manually defined verbalizer class, this class is inherited from the :obj:`Verbalizer` class. Args: tokenizer (:obj:`PreTrainedTokenizer`): The tokenizer of the current pre-trained model to point out the vocabulary. classes (:obj:`List[Any]`): The classes (or labels) of the current task. 
label_words (:obj:`Union[List[str], List[List[str]], Dict[List[str]]]`, optional): The label words that are projected by the labels. prefix (:obj:`str`, optional): The prefix string of the verbalizer (used in PLMs like RoBERTa, which is sensitive to prefix space) multi_token_handler (:obj:`str`, optional): The handling strategy for multiple tokens produced by the tokenizer. post_log_softmax (:obj:`bool`, optional): Whether to apply log softmax post processing on label_logits. Default to True. """ def __init__(self, tokenizer: PreTrainedTokenizer, classes: Optional[List] = None, num_classes: Optional[Sequence[str]] = None, label_words: Optional[Union[Sequence[str], Mapping[str, str]]] = None, prefix: Optional[str] = " ", multi_token_handler: Optional[str] = "first", post_log_softmax: Optional[bool] = True, ): super().__init__(tokenizer=tokenizer, num_classes=num_classes, classes=classes) self.prefix = prefix self.multi_token_handler = multi_token_handler self.label_words = label_words self.post_log_softmax = post_log_softmax def on_label_words_set(self): super().on_label_words_set() self.label_words = self.add_prefix(self.label_words, self.prefix) # TODO should Verbalizer base class has label_words property and setter? # it don't have label_words init argument or label words from_file option at all self.generate_parameters() @staticmethod def add_prefix(label_words, prefix): r"""Add prefix to label words. For example, if a label words is in the middle of a template, the prefix should be ``' '``. Args: label_words (:obj:`Union[Sequence[str], Mapping[str, str]]`, optional): The label words that are projected by the labels. prefix (:obj:`str`, optional): The prefix string of the verbalizer. Returns: :obj:`Sequence[str]`: New label words with prefix. """ new_label_words = [] if isinstance(label_words[0], str): label_words = [[w] for w in label_words] #wrapped it to a list of list of label words. 
for label_words_per_label in label_words: new_label_words_per_label = [] for word in label_words_per_label: if word.startswith(""): new_label_words_per_label.append(word.split("")[1]) else: new_label_words_per_label.append(prefix + word) new_label_words.append(new_label_words_per_label) return new_label_words def generate_parameters(self) -> List: r"""In basic manual template, the parameters are generated from label words directly. In this implementation, the label_words should not be tokenized into more than one token. """ all_ids = [] for words_per_label in self.label_words: ids_per_label = [] for word in words_per_label: ids = self.tokenizer.encode(word, add_special_tokens=False) ids_per_label.append(ids) all_ids.append(ids_per_label) max_len = max([max([len(ids) for ids in ids_per_label]) for ids_per_label in all_ids]) max_num_label_words = max([len(ids_per_label) for ids_per_label in all_ids]) words_ids_mask = torch.zeros(max_num_label_words, max_len) words_ids_mask = [[[1]*len(ids) + [0]*(max_len-len(ids)) for ids in ids_per_label] + [[0]*max_len]*(max_num_label_words-len(ids_per_label)) for ids_per_label in all_ids] words_ids = [[ids + [0]*(max_len-len(ids)) for ids in ids_per_label] + [[0]*max_len]*(max_num_label_words-len(ids_per_label)) for ids_per_label in all_ids] words_ids_tensor = torch.tensor(words_ids) words_ids_mask = torch.tensor(words_ids_mask) self.label_words_ids = nn.Parameter(words_ids_tensor, requires_grad=False) self.words_ids_mask = nn.Parameter(words_ids_mask, requires_grad=False) # A 3-d mask self.label_words_mask = nn.Parameter(torch.clamp(words_ids_mask.sum(dim=-1), max=1), requires_grad=False) def project(self, logits: torch.Tensor, **kwargs, ) -> torch.Tensor: r""" Project the labels, the return value is the normalized (sum to 1) probs of label words. Args: logits (:obj:`torch.Tensor`): The original logits of label words. 
Returns: :obj:`torch.Tensor`: The normalized logits of label words """ label_words_logits = logits[:, self.label_words_ids] label_words_logits = self.handle_multi_token(label_words_logits, self.words_ids_mask) label_words_logits -= 10000*(1-self.label_words_mask) return label_words_logits def process_logits(self, logits: torch.Tensor, **kwargs): r"""A whole framework to process the original logits over the vocabulary, which contains four steps: (1) Project the logits into logits of label words if self.post_log_softmax is True: (2) Normalize over all label words (3) Calibrate (optional) (4) Aggregate (for multiple label words) Args: logits (:obj:`torch.Tensor`): The original logits. Returns: (:obj:`torch.Tensor`): The final processed logits over the labels (classes). """ # project label_words_logits = self.project(logits, **kwargs) #Output: (batch_size, num_classes) or (batch_size, num_classes, num_label_words_per_label) if self.post_log_softmax: # normalize label_words_probs = self.normalize(label_words_logits) # calibrate if hasattr(self, "_calibrate_logits") and self._calibrate_logits is not None: label_words_probs = self.calibrate(label_words_probs=label_words_probs) # convert to logits label_words_logits = torch.log(label_words_probs+1e-15) # aggregate label_logits = self.aggregate(label_words_logits) return label_logits def normalize(self, logits: torch.Tensor) -> torch.Tensor: """ Given logits regarding the entire vocabulary, return the probs over the label words set. Args: logits (:obj:`Tensor`): The logits over the entire vocabulary. Returns: :obj:`Tensor`: The logits over the label words set. """ batch_size = logits.shape[0] return F.softmax(logits.reshape(batch_size, -1), dim=-1).reshape(*logits.shape) def aggregate(self, label_words_logits: torch.Tensor) -> torch.Tensor: r"""Use weight to aggregate the logits of label words. Args: label_words_logits(:obj:`torch.Tensor`): The logits of the label words. 
Returns: :obj:`torch.Tensor`: The aggregated logits from the label words. """ label_words_logits = (label_words_logits * self.label_words_mask).sum(-1)/self.label_words_mask.sum(-1) return label_words_logits def calibrate(self, label_words_probs: torch.Tensor, **kwargs) -> torch.Tensor: r""" Args: label_words_probs (:obj:`torch.Tensor`): The probability distribution of the label words with the shape of [``batch_size``, ``num_classes``, ``num_label_words_per_class``] Returns: :obj:`torch.Tensor`: The calibrated probability of label words. """ shape = label_words_probs.shape assert self._calibrate_logits.dim() == 1, "self._calibrate_logits are not 1-d tensor" calibrate_label_words_probs = self.normalize(self.project(self._calibrate_logits.unsqueeze(0), **kwargs)) assert calibrate_label_words_probs.shape[1:] == label_words_probs.shape[1:] \ and calibrate_label_words_probs.shape[0]==1, "shape not match" label_words_probs /= (calibrate_label_words_probs+1e-15) # normalize # TODO Test the performance norm = label_words_probs.reshape(shape[0], -1).sum(dim=-1,keepdim=True) # TODO Test the performance of detaching() label_words_probs = label_words_probs.reshape(shape[0], -1) / norm label_words_probs = label_words_probs.reshape(*shape) return label_words_probs