Repository: Yueeeeeeee/LlamaRec
Branch: main
Commit: 48b288b23197
Files: 32
Total size: 159.5 KB
Directory structure:
gitextract_v751wz7n/
├── .gitignore
├── README.md
├── config.py
├── dataloader/
│   ├── __init__.py
│   ├── base.py
│   ├── llm.py
│   ├── lru.py
│   ├── templates/
│   │   ├── README.md
│   │   ├── alpaca.json
│   │   ├── alpaca_legacy.json
│   │   ├── alpaca_short.json
│   │   └── vigogne.json
│   └── utils.py
├── datasets/
│   ├── __init__.py
│   ├── base.py
│   ├── beauty.py
│   ├── games.py
│   ├── ml_100k.py
│   └── utils.py
├── model/
│   ├── __init__.py
│   ├── llm.py
│   └── lru.py
├── requirements.txt
├── train_ranker.py
├── train_retriever.py
└── trainer/
    ├── __init__.py
    ├── base.py
    ├── llm.py
    ├── loggers.py
    ├── lru.py
    ├── utils.py
    └── verb.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pyc
*.p
*.pt
*.pth
.DS_Store
/.ipynb_checkpoints/*
/.vscode/*
/wandb/*
/data/*
/retrieved/*
/experiments/*
/archive/*
================================================
FILE: README.md
================================================
# LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking
This repository is the PyTorch implementation of the PGAI@CIKM 2023 paper **LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking [[Paper](https://arxiv.org/abs/2311.02089)]**.
We propose a two-stage framework that uses large language models for ranking-based recommendation (LlamaRec). First, a small-scale sequential recommender retrieves candidates based on the user interaction history. Then, both the history and the retrieved items are fed to the LLM as text via a carefully designed prompt template. Instead of generating next-item titles, we adopt a verbalizer-based approach that transforms output logits into probability distributions over the candidate items. As a result, LlamaRec ranks items efficiently without generating long text and achieves superior results in both recommendation quality and efficiency.
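The verbalizer idea can be illustrated with a minimal sketch. It assumes the candidates are labeled with index letters (A, B, C, ...) and that the model's next-token logits are available as a mapping from token to score. The function name `rank_candidates` and the toy logits are hypothetical and are not taken from this repository's trainer code.

```python
import math


def rank_candidates(next_token_logits, num_candidates):
    """Turn next-token logits into a distribution over lettered candidates.

    next_token_logits: dict mapping a token string to its logit (a toy
    stand-in for an LLM's vocabulary-sized output at the response position).
    """
    letters = [chr(ord('A') + i) for i in range(num_candidates)]
    # Keep only the logits that belong to the index-letter tokens.
    scores = [next_token_logits[letter] for letter in letters]
    # Numerically stable softmax restricted to the candidate letters.
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort candidates from most to least likely.
    return sorted(zip(letters, probs), key=lambda pair: pair[1], reverse=True)


# Hypothetical logits; 'the' is an unrelated token that is simply ignored.
toy_logits = {'A': 0.5, 'B': 2.0, 'C': -1.0, 'the': 3.0}
ranking = rank_candidates(toy_logits, num_candidates=3)
print([letter for letter, _ in ranking])
```

Only the logits at the response position are read in this sketch, which is why a ranking can be produced without decoding any text.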
## Requirements
PyTorch, transformers, peft, bitsandbytes, etc. For our detailed running environment, see requirements.txt.
## How to run LlamaRec
The command below starts training the retriever model, LRURec:
```bash
python train_retriever.py
```
You can set additional arguments such as weight_decay to change the hyperparameters. After running the command, you will be prompted to select a dataset from ML-100k, Beauty and Games. Once training finishes, evaluation is performed automatically with the best retriever model.
Then, run the following command to train the ranker model based on Llama 2:
```bash
python train_ranker.py --llm_retrieved_path PATH_TO_RETRIEVER
```
Replace PATH_TO_RETRIEVER with the retriever path from the previous step. To run this command, you need access to meta-llama/Llama-2-7b-hf on the HF hub. As with the retriever, evaluation is performed after training finishes. All weights and results are saved under ./experiments.
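The ranker reads the retriever's scores from `retrieved.pkl` under this path. For validation and test, `dataloader/llm.py` keeps only users whose ground-truth item is among the top `llm_negative_sample_size + 1` retrieved items (20 by default). Below is a minimal pure-Python sketch of that filtering, assuming plain lists of scores instead of the tensors and `torch.topk` used in the repository; `select_candidates` and the toy data are hypothetical.

```python
def select_candidates(scores, label, k=20):
    """Return the top-k item indices, or None if the label is not among them."""
    top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return top_k if label in top_k else None


# Toy retriever scores for two users over six items (hypothetical data).
all_scores = [
    [0.1, 0.9, 0.3, 0.8, 0.2, 0.0],
    [0.7, 0.1, 0.6, 0.2, 0.3, 0.5],
]
all_labels = [3, 1]

kept = {}
# Users are numbered from 1, mirroring enumerate(..., start=1) in the dataloader.
for user, (scores, label) in enumerate(zip(all_scores, all_labels), start=1):
    candidates = select_candidates(scores, label, k=3)
    if candidates is not None:
        kept[user] = candidates
print(kept)
```

In this toy run, the second user is dropped because the retriever did not place the ground-truth item in the top 3, so the ranker would have no correct candidate to choose from.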
## Performance
The table below reports our main performance results, with best results marked in bold and second best results underlined. For training and evaluation details, please refer to our paper.
## Citation
Please consider citing the following papers if you use our methods in your research:
```
@article{yue2023linear,
  title={Linear Recurrent Units for Sequential Recommendation},
  author={Yue, Zhenrui and Wang, Yueqi and He, Zhankui and Zeng, Huimin and McAuley, Julian and Wang, Dong},
  journal={arXiv preprint arXiv:2310.02367},
  year={2023}
}
@article{yue2023llamarec,
  title={LlamaRec: Two-Stage Recommendation using Large Language Models for Ranking},
  author={Yue, Zhenrui and Rabhi, Sara and Moreira, Gabriel de Souza Pereira and Wang, Dong and Oldridge, Even},
  journal={arXiv preprint arXiv:2311.02089},
  year={2023}
}
```
================================================
FILE: config.py
================================================
import numpy as np
import random
import torch
import argparse
RAW_DATASET_ROOT_FOLDER = 'data'
EXPERIMENT_ROOT = 'experiments'
STATE_DICT_KEY = 'model_state_dict'
OPTIMIZER_STATE_DICT_KEY = 'optimizer_state_dict'
PROJECT_NAME = 'llmrec'
def set_template(args):
    # Prompt for a dataset only when none was passed on the command line.
    if args.dataset_code is None:
        print('******************** Dataset Selection ********************')
        dataset_code = {'1': 'ml-100k', 'b': 'beauty', 'g': 'games'}
        args.dataset_code = dataset_code[input('Input 1 for ml-100k, b for beauty and g for games: ')]
    if args.dataset_code == 'ml-100k':
        args.bert_max_len = 200
    else:
        args.bert_max_len = 50
    if 'llm' in args.model_code:
        batch = 16 if args.dataset_code == 'ml-100k' else 12
        args.lora_micro_batch_size = batch
    else:
        batch = 16 if args.dataset_code == 'ml-100k' else 64
    args.train_batch_size = batch
    args.val_batch_size = batch
    args.test_batch_size = batch
    if torch.cuda.is_available(): args.device = 'cuda'
    else: args.device = 'cpu'
    args.optimizer = 'AdamW'
    args.lr = 0.001
    args.weight_decay = 0.01
    args.enable_lr_schedule = False
    args.decay_step = 10000
    args.gamma = 1.
    args.enable_lr_warmup = False
    args.warmup_steps = 100
    args.metric_ks = [1, 5, 10, 20, 50]
    args.rerank_metric_ks = [1, 5, 10]
    args.best_metric = 'Recall@10'
    args.rerank_best_metric = 'NDCG@10'
    args.bert_num_blocks = 2
    args.bert_num_heads = 2
    args.bert_head_size = None
parser = argparse.ArgumentParser()
################
# Dataset
################
parser.add_argument('--dataset_code', type=str, default=None)
parser.add_argument('--min_rating', type=int, default=0)
parser.add_argument('--min_uc', type=int, default=5)
parser.add_argument('--min_sc', type=int, default=5)
parser.add_argument('--seed', type=int, default=42)
################
# Dataloader
################
parser.add_argument('--train_batch_size', type=int, default=64)
parser.add_argument('--val_batch_size', type=int, default=64)
parser.add_argument('--test_batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--sliding_window_size', type=float, default=1.0)
parser.add_argument('--negative_sample_size', type=int, default=10)
################
# Trainer
################
# optimization #
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
parser.add_argument('--num_epochs', type=int, default=500)
parser.add_argument('--optimizer', type=str, default='AdamW', choices=['AdamW', 'Adam'])
parser.add_argument('--weight_decay', type=float, default=None)
parser.add_argument('--adam_epsilon', type=float, default=1e-9)
parser.add_argument('--momentum', type=float, default=None)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--max_grad_norm', type=float, default=5.0)
parser.add_argument('--enable_lr_schedule', type=bool, default=True)
parser.add_argument('--decay_step', type=int, default=10000)
parser.add_argument('--gamma', type=float, default=1)
parser.add_argument('--enable_lr_warmup', type=bool, default=True)
parser.add_argument('--warmup_steps', type=int, default=100)
# evaluation #
parser.add_argument('--val_strategy', type=str, default='iteration', choices=['epoch', 'iteration'])
parser.add_argument('--val_iterations', type=int, default=500) # only for iteration val_strategy
parser.add_argument('--early_stopping', type=bool, default=True)
parser.add_argument('--early_stopping_patience', type=int, default=20)
parser.add_argument('--metric_ks', nargs='+', type=int, default=[1, 5, 10, 20, 50])
parser.add_argument('--rerank_metric_ks', nargs='+', type=int, default=[1, 5, 10])
parser.add_argument('--best_metric', type=str, default='Recall@10')
parser.add_argument('--rerank_best_metric', type=str, default='NDCG@10')
parser.add_argument('--use_wandb', type=bool, default=False)
################
# Retriever Model
################
parser.add_argument('--model_code', type=str, default=None)
parser.add_argument('--bert_max_len', type=int, default=50)
parser.add_argument('--bert_hidden_units', type=int, default=64)
parser.add_argument('--bert_num_blocks', type=int, default=2)
parser.add_argument('--bert_num_heads', type=int, default=2)
parser.add_argument('--bert_head_size', type=int, default=32)
parser.add_argument('--bert_dropout', type=float, default=0.2)
parser.add_argument('--bert_attn_dropout', type=float, default=0.2)
parser.add_argument('--bert_mask_prob', type=float, default=0.25)
################
# LLM Model
################
parser.add_argument('--llm_base_model', type=str, default='meta-llama/Llama-2-7b-hf')
parser.add_argument('--llm_base_tokenizer', type=str, default='meta-llama/Llama-2-7b-hf')
parser.add_argument('--llm_max_title_len', type=int, default=32)
parser.add_argument('--llm_max_text_len', type=int, default=1536)
parser.add_argument('--llm_max_history', type=int, default=20)
parser.add_argument('--llm_train_on_inputs', type=bool, default=False)
parser.add_argument('--llm_negative_sample_size', type=int, default=19) # 19 negative & 1 positive
parser.add_argument('--llm_system_template', type=str, # instruction
default="Given user history in chronological order, recommend an item from the candidate pool with its index letter.")
parser.add_argument('--llm_input_template', type=str, \
default='User history: {}; \n Candidate pool: {}')
parser.add_argument('--llm_load_in_4bit', type=bool, default=True)
parser.add_argument('--llm_retrieved_path', type=str, default=None)
parser.add_argument('--llm_cache_dir', type=str, default=None)
################
# Lora
################
parser.add_argument('--lora_r', type=int, default=8)
parser.add_argument('--lora_alpha', type=int, default=32)
parser.add_argument('--lora_dropout', type=float, default=0.05)
parser.add_argument('--lora_target_modules', nargs='+', type=str, default=['q_proj', 'v_proj'])  # type=list would split one value into characters
parser.add_argument('--lora_num_epochs', type=int, default=1)
parser.add_argument('--lora_val_iterations', type=int, default=100)
parser.add_argument('--lora_early_stopping_patience', type=int, default=20)
parser.add_argument('--lora_lr', type=float, default=1e-4)
parser.add_argument('--lora_micro_batch_size', type=int, default=16)
################
args = parser.parse_args()
================================================
FILE: dataloader/__init__.py
================================================
from datasets import dataset_factory
from .lru import *
from .llm import *
from .utils import *
def dataloader_factory(args):
    dataset = dataset_factory(args)
    if args.model_code == 'lru':
        dataloader = LRUDataloader(args, dataset)
    elif args.model_code == 'llm':
        dataloader = LLMDataloader(args, dataset)
    train, val, test = dataloader.get_pytorch_dataloaders()
    if 'llm' in args.model_code:
        tokenizer = dataloader.tokenizer
        test_retrieval = dataloader.test_retrieval
        return train, val, test, tokenizer, test_retrieval
    else:
        return train, val, test
def test_subset_dataloader_loader(args):
    dataset = dataset_factory(args)
    if args.model_code == 'lru':
        dataloader = LRUDataloader(args, dataset)
    elif args.model_code == 'llm':
        dataloader = LLMDataloader(args, dataset)
    return dataloader.get_pytorch_test_subset_dataloader()
================================================
FILE: dataloader/base.py
================================================
from abc import *
import random
class AbstractDataloader(metaclass=ABCMeta):
    def __init__(self, args, dataset):
        self.args = args
        self.save_folder = dataset._get_preprocessed_folder_path()
        dataset = dataset.load_dataset()
        self.train = dataset['train']
        self.val = dataset['val']
        self.test = dataset['test']
        self.meta = dataset['meta']
        self.umap = dataset['umap']
        self.smap = dataset['smap']
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)
    @classmethod
    @abstractmethod
    def code(cls):
        pass
    @abstractmethod
    def get_pytorch_dataloaders(self):
        pass
================================================
FILE: dataloader/llm.py
================================================
from .base import AbstractDataloader
from .utils import Prompter
import torch
import random
import numpy as np
import torch.utils.data as data_utils
import os
import pickle
import transformers
from transformers import AutoTokenizer
from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT
from trainer import absolute_recall_mrr_ndcg_for_ks
def worker_init_fn(worker_id):
    random.seed(np.random.get_state()[1][0] + worker_id)
    np.random.seed(np.random.get_state()[1][0] + worker_id)
# the following prompting is based on alpaca
def generate_and_tokenize_eval(args, data_point, tokenizer, prompter):
    in_prompt = prompter.generate_prompt(data_point["system"],
                                         data_point["input"])
    tokenized_full_prompt = tokenizer(in_prompt,
                                      truncation=True,
                                      max_length=args.llm_max_text_len,
                                      padding=False,
                                      return_tensors=None)
    # The label is the candidate index derived from its letter (A -> 0, B -> 1, ...).
    tokenized_full_prompt["labels"] = ord(data_point["output"]) - ord('A')
    return tokenized_full_prompt
def generate_and_tokenize_train(args, data_point, tokenizer, prompter):
    def tokenize(prompt, add_eos_token=True):
        result = tokenizer(prompt,
                           truncation=True,
                           max_length=args.llm_max_text_len,
                           padding=False,
                           return_tensors=None)
        if (result["input_ids"][-1] != tokenizer.eos_token_id and add_eos_token):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)
        result["labels"] = result["input_ids"].copy()
        return result
    full_prompt = prompter.generate_prompt(data_point["system"],
                                           data_point["input"],
                                           data_point["output"])
    tokenized_full_prompt = tokenize(full_prompt, add_eos_token=True)
    if not args.llm_train_on_inputs:
        # Mask everything except the last two tokens (answer letter and EOS) from the loss.
        tokenized_full_prompt["labels"][:-2] = [-100] * len(tokenized_full_prompt["labels"][:-2])
    return tokenized_full_prompt
def seq_to_token_ids(args, seq, candidates, label, text_dict, tokenizer, prompter, eval=False):
    def truncate_title(title):
        title_ = tokenizer.tokenize(title)[:args.llm_max_title_len]
        title = tokenizer.convert_tokens_to_string(title_)
        return title
    # History items are numbered (1), (2), ...; candidates are lettered (A), (B), ...
    seq_t = ' \n '.join(['(' + str(idx + 1) + ') ' + truncate_title(text_dict[item])
                         for idx, item in enumerate(seq)])
    can_t = ' \n '.join(['(' + chr(ord('A') + idx) + ') ' + truncate_title(text_dict[item])
                         for idx, item in enumerate(candidates)])
    output = chr(ord('A') + candidates.index(label))  # ranking only
    data_point = {}
    data_point['system'] = args.llm_system_template if args.llm_system_template is not None else DEFAULT_SYSTEM_PROMPT
    data_point['input'] = args.llm_input_template.format(seq_t, can_t)
    data_point['output'] = output
    if eval:
        return generate_and_tokenize_eval(args, data_point, tokenizer, prompter)
    else:
        return generate_and_tokenize_train(args, data_point, tokenizer, prompter)
class LLMDataloader():
    def __init__(self, args, dataset):
        self.args = args
        self.rng = np.random
        self.save_folder = dataset._get_preprocessed_folder_path()
        seq_dataset = dataset.load_dataset()
        self.train = seq_dataset['train']
        self.val = seq_dataset['val']
        self.test = seq_dataset['test']
        self.umap = seq_dataset['umap']
        self.smap = seq_dataset['smap']
        self.text_dict = seq_dataset['meta']
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)
        args.num_items = self.item_count
        self.max_len = args.llm_max_history
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.llm_base_tokenizer, cache_dir=args.llm_cache_dir)
        self.tokenizer.pad_token = self.tokenizer.unk_token
        self.tokenizer.padding_side = 'left'
        self.tokenizer.truncation_side = 'left'
        self.tokenizer.clean_up_tokenization_spaces = True
        self.prompter = Prompter()
        self.llm_retrieved_path = args.llm_retrieved_path
        print('Loading retrieved file from {}'.format(self.llm_retrieved_path))
        # Use a context manager so the file handle is closed after loading.
        with open(os.path.join(args.llm_retrieved_path, 'retrieved.pkl'), 'rb') as f:
            retrieved_file = pickle.load(f)
        print('******************** Constructing Validation Subset ********************')
        self.val_probs = retrieved_file['val_probs']
        self.val_labels = retrieved_file['val_labels']
        self.val_metrics = retrieved_file['val_metrics']
        # Keep only users whose label appears among the top retrieved candidates.
        self.val_users = [u for u, (p, l) in enumerate(zip(self.val_probs, self.val_labels), start=1) \
                          if l in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices]
        self.val_candidates = [torch.topk(torch.tensor(self.val_probs[u-1]),
                               self.args.llm_negative_sample_size+1).indices.tolist() for u in self.val_users]
        print('******************** Constructing Test Subset ********************')
        self.test_probs = retrieved_file['test_probs']
        self.test_labels = retrieved_file['test_labels']
        self.test_metrics = retrieved_file['test_metrics']
        self.test_users = [u for u, (p, l) in enumerate(zip(self.test_probs, self.test_labels), start=1) \
                           if l in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices]
        self.test_candidates = [torch.topk(torch.tensor(self.test_probs[u-1]),
                                self.args.llm_negative_sample_size+1).indices.tolist() for u in self.test_users]
        self.non_test_users = [u for u, (p, l) in enumerate(zip(self.test_probs, self.test_labels), start=1) \
                               if l not in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices]
        self.test_retrieval = {
            'original_size': len(self.test_probs),
            'retrieval_size': len(self.test_candidates),
            'original_metrics': self.test_metrics,
            'retrieval_metrics': absolute_recall_mrr_ndcg_for_ks(
                torch.tensor(self.test_probs)[torch.tensor(self.test_users)-1],
                torch.tensor(self.test_labels)[torch.tensor(self.test_users)-1],
                self.args.metric_ks,
            ),
            'non_retrieval_metrics': absolute_recall_mrr_ndcg_for_ks(
                torch.tensor(self.test_probs)[torch.tensor(self.non_test_users)-1],
                torch.tensor(self.test_labels)[torch.tensor(self.non_test_users)-1],
                self.args.metric_ks,
            ),
        }
    @classmethod
    def code(cls):
        return 'llm'
    def get_pytorch_dataloaders(self):
        train_loader = self._get_train_loader()
        val_loader = self._get_val_loader()
        test_loader = self._get_test_loader()
        return train_loader, val_loader, test_loader
    def _get_train_loader(self):
        dataset = self._get_train_dataset()
        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.lora_micro_batch_size,
                                           shuffle=True, pin_memory=True, num_workers=self.args.num_workers,
                                           worker_init_fn=worker_init_fn)
        return dataloader
    def _get_train_dataset(self):
        dataset = LLMTrainDataset(self.args, self.train, self.max_len, self.rng,
                                  self.text_dict, self.tokenizer, self.prompter)
        return dataset
    def _get_val_loader(self):
        return self._get_eval_loader(mode='val')
    def _get_test_loader(self):
        return self._get_eval_loader(mode='test')
    def _get_eval_loader(self, mode):
        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size
        dataset = self._get_eval_dataset(mode)
        dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                           pin_memory=True, num_workers=self.args.num_workers)
        return dataloader
    def _get_eval_dataset(self, mode):
        if mode == 'val':
            dataset = LLMValidDataset(self.args, self.train, self.val, self.max_len, self.rng, \
                                      self.text_dict, self.tokenizer, self.prompter, self.val_users, \
                                      self.val_candidates)
        elif mode == 'test':
            dataset = LLMTestDataset(self.args, self.train, self.val, self.test, self.max_len, \
                                     self.rng, self.text_dict, self.tokenizer, self.prompter, self.test_users, \
                                     self.test_candidates)
        return dataset
class LLMTrainDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, max_len, rng, text_dict, tokenizer, prompter):
        self.args = args
        self.max_len = max_len
        self.num_items = args.num_items
        self.rng = rng
        self.text_dict = text_dict
        self.tokenizer = tokenizer
        self.prompter = prompter
        # Every prefix of length >= 2 becomes one training example.
        self.all_seqs = []
        for u in sorted(u2seq.keys()):
            seq = u2seq[u]
            for i in range(2, len(seq)+1):
                self.all_seqs += [seq[:i]]
    def __len__(self):
        return len(self.all_seqs)
    def __getitem__(self, index):
        tokens = self.all_seqs[index]
        answer = tokens[-1]
        original_seq = tokens[:-1]
        seq = original_seq[-self.max_len:]
        # Sample negatives that are neither in the history nor the answer.
        cur_idx, candidates = 0, [answer]
        samples = self.rng.randint(1, self.args.num_items+1, size=5*self.args.llm_negative_sample_size)
        while len(candidates) < self.args.llm_negative_sample_size + 1:
            item = samples[cur_idx]
            cur_idx += 1
            if item in original_seq or item == answer: continue
            else: candidates.append(item)
        self.rng.shuffle(candidates)
        return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, \
                                self.tokenizer, self.prompter, eval=False)
class LLMValidDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, u2answer, max_len, rng, text_dict, tokenizer, prompter, val_users, val_candidates):
        self.args = args
        self.u2seq = u2seq
        self.u2answer = u2answer
        self.users = sorted(self.u2seq.keys())
        self.max_len = max_len
        self.rng = rng
        self.text_dict = text_dict
        self.tokenizer = tokenizer
        self.prompter = prompter
        self.val_users = val_users
        self.val_candidates = val_candidates
    def __len__(self):
        return len(self.val_users)
    def __getitem__(self, index):
        user = self.val_users[index]
        seq = self.u2seq[user]
        answer = self.u2answer[user][0]
        seq = seq[-self.max_len:]
        candidates = self.val_candidates[index]
        assert answer in candidates
        # self.rng.shuffle(candidates)
        return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, self.tokenizer, self.prompter, eval=True)
class LLMTestDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, u2val, u2answer, max_len, rng, text_dict, tokenizer, prompter, test_users, test_candidates):
        self.args = args
        self.u2seq = u2seq
        self.u2val = u2val
        self.u2answer = u2answer
        self.users = sorted(u2seq.keys())
        self.max_len = max_len
        self.rng = rng
        self.text_dict = text_dict
        self.tokenizer = tokenizer
        self.prompter = prompter
        self.test_users = test_users
        self.test_candidates = test_candidates
    def __len__(self):
        return len(self.test_users)
    def __getitem__(self, index):
        user = self.test_users[index]
        # At test time the validation item is appended to the history.
        seq = self.u2seq[user] + self.u2val[user]
        answer = self.u2answer[user][0]
        seq = seq[-self.max_len:]
        candidates = self.test_candidates[index]
        assert answer in candidates
        # self.rng.shuffle(candidates)
        return seq_to_token_ids(self.args, seq, candidates, answer, self.text_dict, self.tokenizer, self.prompter, eval=True)
================================================
FILE: dataloader/lru.py
================================================
from .base import AbstractDataloader
import os
import torch
import random
import pickle
import numpy as np
import torch.utils.data as data_utils
def worker_init_fn(worker_id):
    random.seed(np.random.get_state()[1][0] + worker_id)
    np.random.seed(np.random.get_state()[1][0] + worker_id)
class LRUDataloader():
    def __init__(self, args, dataset):
        self.args = args
        self.rng = np.random
        self.save_folder = dataset._get_preprocessed_folder_path()
        dataset = dataset.load_dataset()
        self.train = dataset['train']
        self.val = dataset['val']
        self.test = dataset['test']
        self.umap = dataset['umap']
        self.smap = dataset['smap']
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)
        args.num_users = self.user_count
        args.num_items = self.item_count
        self.max_len = args.bert_max_len
        self.sliding_size = args.sliding_window_size
    @classmethod
    def code(cls):
        return 'lru'
    def get_pytorch_dataloaders(self):
        train_loader = self._get_train_loader()
        val_loader = self._get_val_loader()
        test_loader = self._get_test_loader()
        return train_loader, val_loader, test_loader
    def get_pytorch_test_subset_dataloader(self):
        retrieved_file_path = self.args.llm_retrieved_path
        print('Loading retrieved file from {}'.format(retrieved_file_path))
        # Use a context manager so the file handle is closed after loading.
        with open(os.path.join(retrieved_file_path, 'retrieved.pkl'), 'rb') as f:
            retrieved_file = pickle.load(f)
        test_probs = retrieved_file['test_probs']
        test_labels = retrieved_file['test_labels']
        test_users = [u for u, (p, l) in enumerate(zip(test_probs, test_labels), start=1) \
                      if l in torch.topk(torch.tensor(p), self.args.llm_negative_sample_size+1).indices]
        dataset = LRUTestDataset(self.args, self.train, self.val, self.test, self.max_len,
                                 self.rng, subset_users=test_users)
        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.val_batch_size, shuffle=False,
                                           pin_memory=True, num_workers=self.args.num_workers)
        return dataloader
    def _get_train_loader(self):
        dataset = self._get_train_dataset()
        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,
                                           shuffle=True, pin_memory=True, num_workers=self.args.num_workers,
                                           worker_init_fn=worker_init_fn)
        return dataloader
    def _get_train_dataset(self):
        dataset = LRUTrainDataset(
            self.args, self.train, self.max_len, self.sliding_size, self.rng)
        return dataset
    def _get_val_loader(self):
        return self._get_eval_loader(mode='val')
    def _get_test_loader(self):
        return self._get_eval_loader(mode='test')
    def _get_eval_loader(self, mode):
        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size
        dataset = self._get_eval_dataset(mode)
        dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                           pin_memory=True, num_workers=self.args.num_workers)
        return dataloader
    def _get_eval_dataset(self, mode):
        if mode == 'val':
            dataset = LRUValidDataset(self.args, self.train, self.val, self.max_len, self.rng)
        elif mode == 'test':
            dataset = LRUTestDataset(self.args, self.train, self.val, self.test, self.max_len, self.rng)
        return dataset
class LRUTrainDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, max_len, sliding_size, rng):
        self.args = args
        self.max_len = max_len
        self.sliding_step = int(sliding_size * max_len)
        self.num_items = args.num_items
        self.rng = rng
        assert self.sliding_step > 0
        self.all_seqs = []
        for u in sorted(u2seq.keys()):
            seq = u2seq[u]
            if len(seq) < self.max_len + self.sliding_step:
                self.all_seqs.append(seq)
            else:
                # Slide a window of max_len backwards from the end of the sequence.
                start_idx = range(len(seq) - max_len, -1, -self.sliding_step)
                self.all_seqs = self.all_seqs + [seq[i:i + max_len] for i in start_idx]
    def __len__(self):
        return len(self.all_seqs)
    def __getitem__(self, index):
        seq = self.all_seqs[index]
        labels = seq[-self.max_len:]
        tokens = seq[:-1][-self.max_len:]
        # Left-pad tokens and labels with 0 up to max_len.
        mask_len = self.max_len - len(tokens)
        tokens = [0] * mask_len + tokens
        mask_len = self.max_len - len(labels)
        labels = [0] * mask_len + labels
        return torch.LongTensor(tokens), torch.LongTensor(labels)
class LRUValidDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, u2answer, max_len, rng):
        self.args = args
        self.u2seq = u2seq
        self.u2answer = u2answer
        users = sorted(self.u2seq.keys())
        self.users = [u for u in users if len(u2answer[u]) > 0]
        self.max_len = max_len
        self.rng = rng
    def __len__(self):
        return len(self.users)
    def __getitem__(self, index):
        user = self.users[index]
        seq = self.u2seq[user]
        answer = self.u2answer[user]
        seq = seq[-self.max_len:]
        padding_len = self.max_len - len(seq)
        seq = [0] * padding_len + seq
        return torch.LongTensor(seq), torch.LongTensor(answer)
class LRUTestDataset(data_utils.Dataset):
    def __init__(self, args, u2seq, u2val, u2answer, max_len, rng, subset_users=None):
        self.args = args
        self.u2seq = u2seq
        self.u2val = u2val
        self.u2answer = u2answer
        users = sorted(self.u2seq.keys())
        self.users = [u for u in users if len(u2val[u]) > 0 and len(u2answer[u]) > 0]
        self.max_len = max_len
        self.rng = rng
        if subset_users is not None:
            self.users = subset_users
    def __len__(self):
        return len(self.users)
    def __getitem__(self, index):
        user = self.users[index]
        seq = self.u2seq[user] + self.u2val[user]
        answer = self.u2answer[user]
        seq = seq[-self.max_len:]
        padding_len = self.max_len - len(seq)
        seq = [0] * padding_len + seq
        return torch.LongTensor(seq), torch.LongTensor(answer)
================================================
FILE: dataloader/templates/README.md
================================================
# Prompt templates
This directory contains template styles for the prompts used to finetune LoRA models.
## Format
A template is described via a JSON file with the following keys:
- `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
- `prompt_no_input`: The template to use when input is None. Uses the `{instruction}` placeholder.
- `description`: A short description of the template, with possible use cases.
- `response_split`: The text to use as separator when cutting real response from the model output.
No `{response}` placeholder was used, since the response is always the last element of the template and is just to be concatenated to the rest.
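As a minimal sketch of how these keys are consumed, the snippet below fills a template with `str.format`, following the same branching as `Prompter.generate_prompt` in `dataloader/utils.py`. The template is an inline copy of `alpaca_short.json` so the sketch runs without the repository files; `build_prompt` and the sample strings are hypothetical.

```python
import json

# Inline copy of alpaca_short.json (raw string keeps the JSON "\n" escapes intact).
template = json.loads(r'''
{
    "description": "A shorter template to experiment with.",
    "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"
}
''')


def build_prompt(template, instruction, input_text=None):
    """Pick the matching template and fill in its placeholders."""
    if input_text:
        return template["prompt_input"].format(instruction=instruction, input=input_text)
    return template["prompt_no_input"].format(instruction=instruction)


prompt = build_prompt(template, "Recommend an item.", "User history: (1) Toy Story")
print(prompt)
```

Because the response is simply appended after the filled template, a training example would be `prompt + label`, with no separate placeholder for the response.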
## Example template
The original default template is `alpaca.json`, shown below. Note that `Prompter` in `dataloader/utils.py` falls back to `alpaca_short` when no template name is given.
```json
{
"description": "Template used by Alpaca-LoRA.",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
"response_split": "### Response:"
}
```
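The `response_split` key is used to cut the model's answer out of the full generated text. A minimal sketch, mirroring `Prompter.get_response` in `dataloader/utils.py`; the standalone function and the sample output string are hypothetical:

```python
response_split = "### Response:"


def get_response(output, separator=response_split):
    """Return the text after the first separator, stripped of whitespace."""
    return output.split(separator)[1].strip()


# Hypothetical model output: the prompt followed by a single-letter answer.
generated = "### Instruction:\nPick one.\n\n### Response:\n B \n"
print(get_response(generated))
```

This approach raises an `IndexError` if the separator never appears in the output, so callers may want to check for it first.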
## Current templates
### alpaca
Default template used for generic LoRA fine tunes so far.
### alpaca_legacy
Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments.
### alpaca_short
A trimmed-down alpaca template that seems to perform just as well while saving some tokens. Models created with the default template seem to be queryable with the short template as well. More experiments are welcome.
### vigogne
The default alpaca template, translated to French. This template was used to train the "Vigogne" LoRA and should be used to query it or for further fine-tuning.
================================================
FILE: dataloader/templates/alpaca.json
================================================
{
"description": "Template used by Alpaca-LoRA.",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
"response_split": "### Response:"
}
================================================
FILE: dataloader/templates/alpaca_legacy.json
================================================
{
"description": "Legacy template, used by Original Alpaca repository.",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
"response_split": "### Response:"
}
================================================
FILE: dataloader/templates/alpaca_short.json
================================================
{
"description": "A shorter template to experiment with.",
"prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
"prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
"response_split": "### Response:"
}
================================================
FILE: dataloader/templates/vigogne.json
================================================
{
"description": "French template, used by Vigogne for finetuning.",
"prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
"prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
"response_split": "### Réponse:"
}
================================================
FILE: dataloader/utils.py
================================================
import json
import os.path as osp
from typing import Union
class Prompter(object):
__slots__ = ("template", "_verbose")
def __init__(self, template_name: str = "", verbose: bool = False):
self._verbose = verbose
if not template_name:
# template_name = "alpaca"
template_name = "alpaca_short"
file_name = osp.join("dataloader", "templates", f"{template_name}.json")
if not osp.exists(file_name):
raise ValueError(f"Can't read {file_name}")
with open(file_name) as fp:
self.template = json.load(fp)
if self._verbose:
print(
f"Using prompt template {template_name}: {self.template['description']}"
)
def generate_prompt(
self,
instruction: str,
input: Union[None, str] = None,
label: Union[None, str] = None,
) -> str:
if input:
res = self.template["prompt_input"].format(
instruction=instruction, input=input
)
else:
res = self.template["prompt_no_input"].format(
instruction=instruction
)
if label:
res = f"{res}{label}"
if self._verbose:
print(res)
return res
def get_response(self, output: str) -> str:
return output.split(self.template["response_split"])[1].strip()
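The prompt assembly and response extraction done by `Prompter` can be illustrated without touching the file system. This is a minimal sketch, assuming an in-memory copy of the `alpaca_short` template; the module-level `TEMPLATE` dict, the free functions, and the sample strings are illustrative stand-ins, not part of the repository.

```python
# Hypothetical in-memory copy of the alpaca_short template (normally loaded from JSON).
TEMPLATE = {
    "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:",
}


def generate_prompt(instruction, input=None, label=None):
    """Fill the template; append the label when building a training example."""
    if input:
        res = TEMPLATE["prompt_input"].format(instruction=instruction, input=input)
    else:
        res = TEMPLATE["prompt_no_input"].format(instruction=instruction)
    return f"{res}{label}" if label else res


def get_response(output):
    """Return the text after the response marker (IndexError if the marker is absent)."""
    return output.split(TEMPLATE["response_split"])[1].strip()


# Made-up instruction and candidates, for illustration only.
prompt = generate_prompt("Rank the candidates.", "(A) Heat (B) Alien", label="B")
print(get_response(prompt))  # B
```

Note that `get_response` assumes the marker occurs in the text; a caller that cannot guarantee this would need to guard the `[1]` index.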
================================================
FILE: datasets/__init__.py
================================================
from .ml_100k import ML100KDataset
from .beauty import BeautyDataset
from .games import GamesDataset
DATASETS = {
ML100KDataset.code(): ML100KDataset,
BeautyDataset.code(): BeautyDataset,
GamesDataset.code(): GamesDataset,
}
def dataset_factory(args):
dataset = DATASETS[args.dataset_code]
return dataset(args)
================================================
FILE: datasets/base.py
================================================
import pickle
import shutil
import tempfile
import os
from pathlib import Path
import gzip
from abc import *
from .utils import *
from config import RAW_DATASET_ROOT_FOLDER
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
class AbstractDataset(metaclass=ABCMeta):
def __init__(self, args):
self.args = args
self.min_rating = args.min_rating
self.min_uc = args.min_uc
self.min_sc = args.min_sc
assert self.min_uc >= 2, 'Need at least 2 ratings per user for validation and test'
@classmethod
@abstractmethod
def code(cls):
pass
@classmethod
def raw_code(cls):
return cls.code()
@classmethod
def zip_file_content_is_folder(cls):
return True
@classmethod
def all_raw_file_names(cls):
return []
@classmethod
@abstractmethod
def url(cls):
pass
@abstractmethod
def preprocess(self):
pass
@abstractmethod
def load_ratings_df(self):
pass
@abstractmethod
def maybe_download_raw_dataset(self):
pass
def load_dataset(self):
self.preprocess()
dataset_path = self._get_preprocessed_dataset_path()
with dataset_path.open('rb') as f:
dataset = pickle.load(f)
return dataset
def filter_triplets(self, df):
print('Filtering triplets')
if self.min_sc > 1 or self.min_uc > 1:
item_sizes = df.groupby('sid').size()
good_items = item_sizes.index[item_sizes >= self.min_sc]
user_sizes = df.groupby('uid').size()
good_users = user_sizes.index[user_sizes >= self.min_uc]
while len(good_items) < len(item_sizes) or len(good_users) < len(user_sizes):
if self.min_sc > 1:
item_sizes = df.groupby('sid').size()
good_items = item_sizes.index[item_sizes >= self.min_sc]
df = df[df['sid'].isin(good_items)]
if self.min_uc > 1:
user_sizes = df.groupby('uid').size()
good_users = user_sizes.index[user_sizes >= self.min_uc]
df = df[df['uid'].isin(good_users)]
item_sizes = df.groupby('sid').size()
good_items = item_sizes.index[item_sizes >= self.min_sc]
user_sizes = df.groupby('uid').size()
good_users = user_sizes.index[user_sizes >= self.min_uc]
return df
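The loop in `filter_triplets` repeats because removing a rare item can push a user below `min_uc`, and removing that user can in turn push another item below `min_sc`. A minimal pure-Python sketch of the same fixed-point idea follows, assuming a list of `(uid, sid)` pairs as a stand-in for the ratings DataFrame; the function name and sample rows are illustrative.

```python
from collections import Counter


def filter_pairs(rows, min_uc=2, min_sc=2):
    """Repeatedly drop rare items, then short-history users, until nothing changes.

    rows: list of (uid, sid) pairs, a hypothetical stand-in for the DataFrame.
    """
    while True:
        item_counts = Counter(sid for _, sid in rows)
        kept = [(u, s) for u, s in rows if item_counts[s] >= min_sc]
        user_counts = Counter(u for u, _ in kept)
        kept = [(u, s) for u, s in kept if user_counts[u] >= min_uc]
        if len(kept) == len(rows):  # a full pass removed nothing: stable
            return kept
        rows = kept


rows = [("u1", "a"), ("u1", "b"), ("u2", "a"), ("u2", "b"), ("u3", "a"), ("u3", "c")]
print(filter_pairs(rows))
```

In this toy input, item `c` is dropped first, which leaves `u3` with a single interaction, so `u3` is dropped as well.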
def densify_index(self, df):
print('Densifying index')
umap = {u: i for i, u in enumerate(set(df['uid']), start=1)}
smap = {s: i for i, s in enumerate(set(df['sid']), start=1)}
df['uid'] = df['uid'].map(umap)
df['sid'] = df['sid'].map(smap)
return df, umap, smap
def split_df(self, df, user_count):
print('Splitting')
user_group = df.groupby('uid')
user2items = user_group.progress_apply(
lambda d: list(d.sort_values(by=['timestamp', 'sid'])['sid']))
train, val, test = {}, {}, {}
for i in range(user_count):
user = i + 1
items = user2items[user]
train[user], val[user], test[user] = items[:-2], items[-2:-1], items[-1:]
return train, val, test
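The split above is a leave-one-out scheme: each user's chronologically last item goes to test, the second-to-last to validation, and the rest to training. A minimal sketch without pandas, assuming `(uid, sid, timestamp)` tuples as input; the function name and the sample data are illustrative.

```python
def split_leave_one_out(interactions):
    """interactions: list of (uid, sid, timestamp) tuples (hypothetical input format)."""
    by_user = {}
    for uid, sid, ts in interactions:
        by_user.setdefault(uid, []).append((ts, sid))
    train, val, test = {}, {}, {}
    for uid, events in by_user.items():
        items = [sid for _, sid in sorted(events)]  # order by (timestamp, sid)
        train[uid], val[uid], test[uid] = items[:-2], items[-2:-1], items[-1:]
    return train, val, test


interactions = [(1, 10, 3), (1, 11, 1), (1, 12, 2), (1, 13, 4), (2, 10, 5), (2, 12, 6)]
train, val, test = split_leave_one_out(interactions)
print(train, val, test)
```

A user with exactly two interactions ends up with an empty training list, which is why the constructor asserts `min_uc >= 2` but larger values are more useful in practice.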
def _get_rawdata_root_path(self):
return Path(RAW_DATASET_ROOT_FOLDER)
def _get_rawdata_folder_path(self):
root = self._get_rawdata_root_path()
return root.joinpath(self.raw_code())
def _get_preprocessed_root_path(self):
root = self._get_rawdata_root_path()
return root.joinpath('preprocessed')
def _get_preprocessed_folder_path(self):
preprocessed_root = self._get_preprocessed_root_path()
folder_name = '{}_min_rating{}-min_uc{}-min_sc{}' \
.format(self.code(), self.min_rating, self.min_uc, self.min_sc)
return preprocessed_root.joinpath(folder_name)
def _get_preprocessed_dataset_path(self):
folder = self._get_preprocessed_folder_path()
return folder.joinpath('dataset.pkl')
================================================
FILE: datasets/beauty.py
================================================
from .base import AbstractDataset
from .utils import *
from datetime import date
from pathlib import Path
import pickle
import shutil
import tempfile
import os
import gzip
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
class BeautyDataset(AbstractDataset):
@classmethod
def code(cls):
return 'beauty'
@classmethod
def url(cls):
return ['http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Beauty.csv',
'http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Beauty.json.gz']
@classmethod
def zip_file_content_is_folder(cls):
return True
@classmethod
def all_raw_file_names(cls):
return ['beauty.csv', 'beauty_meta.json.gz']
def maybe_download_raw_dataset(self):
folder_path = self._get_rawdata_folder_path()
if folder_path.is_dir() and\
all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):
print('Raw data already exists. Skip downloading')
return
print("Raw file doesn't exist. Downloading...")
for idx, url in enumerate(self.url()):
tmproot = Path(tempfile.mkdtemp())
tmpfile = tmproot.joinpath('file')
download(url, tmpfile)
os.makedirs(folder_path, exist_ok=True)
shutil.move(tmpfile, folder_path.joinpath(self.all_raw_file_names()[idx]))
print()
def preprocess(self):
dataset_path = self._get_preprocessed_dataset_path()
if dataset_path.is_file():
print('Already preprocessed. Skip preprocessing')
return
if not dataset_path.parent.is_dir():
dataset_path.parent.mkdir(parents=True)
self.maybe_download_raw_dataset()
df = self.load_ratings_df()
meta_raw = self.load_meta_dict()
df = df[df['sid'].isin(meta_raw)] # filter items without meta info
df = self.filter_triplets(df)
df, umap, smap = self.densify_index(df)
train, val, test = self.split_df(df, len(umap))
meta = {smap[k]: v for k, v in meta_raw.items() if k in smap}
dataset = {'train': train,
'val': val,
'test': test,
'meta': meta,
'umap': umap,
'smap': smap}
with dataset_path.open('wb') as f:
pickle.dump(dataset, f)
def load_ratings_df(self):
folder_path = self._get_rawdata_folder_path()
file_path = folder_path.joinpath(self.all_raw_file_names()[0])
df = pd.read_csv(file_path, header=None)
df.columns = ['uid', 'sid', 'rating', 'timestamp']
return df
def load_meta_dict(self):
folder_path = self._get_rawdata_folder_path()
file_path = folder_path.joinpath(self.all_raw_file_names()[1])
meta_dict = {}
with gzip.open(file_path, 'rb') as f:
for line in f:
item = eval(line)  # eval() executes whatever the line contains; only run this on trusted metadata files
if 'title' in item and len(item['title']) > 0:
meta_dict[item['asin'].strip()] = item['title'].strip()
return meta_dict
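If the metadata lines are plain Python-literal dicts, which is what the use of `eval` in `load_meta_dict` suggests, a safer parse is possible with `ast.literal_eval`, which only accepts literals. The sketch below assumes UTF-8 encoded lines and uses a tiny in-memory gzip file with made-up records; it is an alternative under those assumptions, not the repository's method.

```python
import ast
import gzip
import io


def load_meta_dict(fileobj):
    """Map ASIN -> title for every line that carries a non-empty title."""
    meta = {}
    for raw in fileobj:
        # literal_eval needs str, so decode the bytes first (UTF-8 is an assumption)
        item = ast.literal_eval(raw.decode("utf-8").strip())
        title = item.get("title", "")
        if title:
            meta[item["asin"].strip()] = title.strip()
    return meta


# Build a small gzipped sample in memory (hypothetical records).
buf = io.BytesIO()
with gzip.open(buf, "wb") as f:
    f.write(b"{'asin': 'B0001 ', 'title': ' Lip Balm '}\n")
    f.write(b"{'asin': 'B0002'}\n")
buf.seek(0)
with gzip.open(buf, "rb") as f:
    meta = load_meta_dict(f)
print(meta)  # {'B0001': 'Lip Balm'}
```

Files that are strict JSON rather than Python literals would need `json.loads` instead, since `literal_eval` rejects `true`, `false`, and `null`.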
================================================
FILE: datasets/games.py
================================================
from .base import AbstractDataset
from .utils import *
from datetime import date
from pathlib import Path
import pickle
import shutil
import tempfile
import os
import gzip
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
class GamesDataset(AbstractDataset):
@classmethod
def code(cls):
return 'games'
@classmethod
def url(cls):
# meta_Video_Games.json.gz from snap.stanford.edu does not contain full meta info
return ['http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Video_Games.csv',
'https://datarepo.eng.ucsd.edu/mcauley_group/data/amazon_v2/metaFiles2/meta_Video_Games.json.gz']
@classmethod
def zip_file_content_is_folder(cls):
return True
@classmethod
def all_raw_file_names(cls):
return ['games.csv', 'games_meta.json.gz']
def maybe_download_raw_dataset(self):
folder_path = self._get_rawdata_folder_path()
if folder_path.is_dir() and\
all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):
print('Raw data already exists. Skip downloading')
return
print("Raw file doesn't exist. Downloading...")
for idx, url in enumerate(self.url()):
tmproot = Path(tempfile.mkdtemp())
tmpfile = tmproot.joinpath('file')
download(url, tmpfile)
os.makedirs(folder_path, exist_ok=True)
shutil.move(tmpfile, folder_path.joinpath(self.all_raw_file_names()[idx]))
print()
def preprocess(self):
dataset_path = self._get_preprocessed_dataset_path()
if dataset_path.is_file():
print('Already preprocessed. Skip preprocessing')
return
if not dataset_path.parent.is_dir():
dataset_path.parent.mkdir(parents=True)
self.maybe_download_raw_dataset()
df = self.load_ratings_df()
meta_raw = self.load_meta_dict()
df = df[df['sid'].isin(meta_raw)] # filter items without meta info
df = self.filter_triplets(df)
df, umap, smap = self.densify_index(df)
train, val, test = self.split_df(df, len(umap))
meta = {smap[k]: v for k, v in meta_raw.items() if k in smap}
dataset = {'train': train,
'val': val,
'test': test,
'meta': meta,
'umap': umap,
'smap': smap}
with dataset_path.open('wb') as f:
pickle.dump(dataset, f)
def load_ratings_df(self):
folder_path = self._get_rawdata_folder_path()
file_path = folder_path.joinpath(self.all_raw_file_names()[0])
df = pd.read_csv(file_path, header=None)
df.columns = ['uid', 'sid', 'rating', 'timestamp']
return df
def load_meta_dict(self):
folder_path = self._get_rawdata_folder_path()
file_path = folder_path.joinpath(self.all_raw_file_names()[1])
meta_dict = {}
with gzip.open(file_path, 'rb') as f:
for line in f:
item = eval(line)  # eval() executes whatever the line contains; only run this on trusted metadata files
if 'title' in item and len(item['title']) > 0:
meta_dict[item['asin'].strip()] = item['title'].strip()
return meta_dict
================================================
FILE: datasets/ml_100k.py
================================================
from .base import AbstractDataset
from .utils import *
from datetime import date
from pathlib import Path
import pickle
import shutil
import tempfile
import os
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
class ML100KDataset(AbstractDataset):
@classmethod
def code(cls):
return 'ml-100k'
@classmethod
def url(cls): # as of Sep 2023
return 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'
@classmethod
def zip_file_content_is_folder(cls):
return True
@classmethod
def all_raw_file_names(cls):
return ['README.txt',
'movies.csv',
'ratings.csv']
def maybe_download_raw_dataset(self):
folder_path = self._get_rawdata_folder_path()
if folder_path.is_dir() and\
all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):
print('Raw data already exists. Skip downloading')
return
print("Raw file doesn't exist. Downloading...")
tmproot = Path(tempfile.mkdtemp())
tmpzip = tmproot.joinpath('file.zip')
tmpfolder = tmproot.joinpath('folder')
download(self.url(), tmpzip)
unzip(tmpzip, tmpfolder)
if self.zip_file_content_is_folder():
tmpfolder = tmpfolder.joinpath(os.listdir(tmpfolder)[0])
shutil.move(tmpfolder, folder_path)
shutil.rmtree(tmproot)
print()
def preprocess(self):
dataset_path = self._get_preprocessed_dataset_path()
if dataset_path.is_file():
print('Already preprocessed. Skip preprocessing')
return
if not dataset_path.parent.is_dir():
dataset_path.parent.mkdir(parents=True)
self.maybe_download_raw_dataset()
df = self.load_ratings_df()
meta_raw = self.load_meta_dict()
df = df[df['sid'].isin(meta_raw)] # filter items without meta info
df = self.filter_triplets(df)
df, umap, smap = self.densify_index(df)
train, val, test = self.split_df(df, len(umap))
meta = {smap[k]: v for k, v in meta_raw.items() if k in smap}
dataset = {'train': train,
'val': val,
'test': test,
'meta': meta,
'umap': umap,
'smap': smap}
with dataset_path.open('wb') as f:
pickle.dump(dataset, f)
def load_ratings_df(self):
folder_path = self._get_rawdata_folder_path()
file_path = folder_path.joinpath('ratings.csv')
df = pd.read_csv(file_path)
df.columns = ['uid', 'sid', 'rating', 'timestamp']
return df
def load_meta_dict(self):
folder_path = self._get_rawdata_folder_path()
file_path = folder_path.joinpath('movies.csv')
df = pd.read_csv(file_path, encoding="ISO-8859-1")
meta_dict = {}
for row in df.itertuples():
title = row[2][:-7] # remove year (optional)
year = row[2][-7:]
title = re.sub(r'\(.*?\)', '', title).strip()
# any other articles and remaining parentheses are not handled here
if any(', '+x in title.lower()[-5:] for x in ['a', 'an', 'the']):
title_pre = title.split(', ')[:-1]
title_post = title.split(', ')[-1]
title_pre = ', '.join(title_pre)
title = title_post + ' ' + title_pre
meta_dict[row[1]] = title + year
return meta_dict
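The title clean-up in `load_meta_dict` moves a trailing article to the front and drops parenthesised alternative titles while keeping the year. A standalone sketch of the same steps, assuming every raw title ends in a 7-character suffix such as ` (1999)`; the function name and sample titles are illustrative.

```python
import re


def normalize_title(raw):
    """'Matrix, The (1999)' -> 'The Matrix (1999)'."""
    title, year = raw[:-7], raw[-7:]  # split off ' (YYYY)'
    title = re.sub(r"\(.*?\)", "", title).strip()  # drop alternative titles
    if any(", " + art in title.lower()[-5:] for art in ("a", "an", "the")):
        *pre, post = title.split(", ")
        title = post + " " + ", ".join(pre)
    return title + year


print(normalize_title("Matrix, The (1999)"))  # The Matrix (1999)
print(normalize_title("Toy Story (1995)"))    # Toy Story (1995)
```

Titles without a year suffix would be truncated by the fixed 7-character slice, so the sketch inherits that assumption from the original code.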
================================================
FILE: datasets/utils.py
================================================
import numpy as np
import pandas as pd
from tqdm import tqdm
import urllib.request
from pathlib import Path
import zipfile
import tarfile
import sys
def download(url, savepath):
urllib.request.urlretrieve(url, str(savepath))
print()
def unzip(zippath, savepath):
print("Extracting data...")
zip = zipfile.ZipFile(zippath)
zip.extractall(savepath)
zip.close()
def unziptargz(zippath, savepath):
print("Extracting data...")
f = tarfile.open(zippath)
f.extractall(savepath)
f.close()
================================================
FILE: model/__init__.py
================================================
from .lru import *
from .llm import *
================================================
FILE: model/llm.py
================================================
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.models.llama.configuration_llama import LlamaConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make the causal mask used for uni-directional (autoregressive) self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
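The mask built by `_make_causal_mask` is additive: 0 where attention is allowed and a very large negative value elsewhere, with extra all-zero columns for cached positions. A minimal list-based sketch for a single sequence, using `-inf` as a stand-in for `torch.finfo(dtype).min`; the function name is illustrative.

```python
NEG = float("-inf")  # stand-in for torch.finfo(dtype).min


def causal_mask(tgt_len, past_len=0):
    """Row i may attend to every cached position and to new positions j <= i."""
    return [
        [0.0] * past_len + [0.0 if j <= i else NEG for j in range(tgt_len)]
        for i in range(tgt_len)
    ]


for row in causal_mask(3):
    print(row)
```

Adding this mask to the attention scores before the softmax drives the weights of future positions to zero.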
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
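RMSNorm rescales each hidden vector by the reciprocal of its root mean square and then applies a learned per-dimension weight; unlike LayerNorm it does not subtract the mean. A scalar sketch of the forward pass above for one vector, with plain lists standing in for tensors; the function name is illustrative.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """x, weight: lists of floats with the same length."""
    variance = sum(v * v for v in x) / len(x)  # mean of squares, no mean subtraction
    scale = 1.0 / math.sqrt(variance + eps)    # rsqrt(variance + eps)
    return [w * v * scale for w, v in zip(weight, x)]


out = rms_norm([3.0, 4.0], [1.0, 1.0])
print(out)
```

With unit weights the output has a root mean square of approximately 1, up to the effect of `eps`.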
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
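The dynamic NTK variant only changes the rotary base once the sequence exceeds `max_position_embeddings`; below that length the original frequencies are kept. The rescaling formula from `_set_cos_sin_cache` can be isolated as a small helper; the function name and the default arguments here are illustrative.

```python
def dynamic_ntk_base(seq_len, dim, max_pos=2048, base=10000.0, scaling_factor=1.0):
    """Return the rotary base used for a given sequence length."""
    if seq_len <= max_pos:
        return base  # within the trained context: unchanged
    growth = (scaling_factor * seq_len / max_pos) - (scaling_factor - 1)
    return base * growth ** (dim / (dim - 2))


print(dynamic_ntk_base(2048, 128))
print(dynamic_ntk_base(4096, 128))
```

A larger base lowers every rotary frequency, so positions beyond the trained context map to angles the model has already seen.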
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
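`apply_rotary_pos_emb` rotates each pair of dimensions `(i, i + dim/2)` by a position-dependent angle, so the dot product between a rotated query and a rotated key depends only on their relative offset. A pure-Python sketch for a single head vector, mirroring `rotate_half` and the `cat((freqs, freqs))` layout; the function names and sample vectors are illustrative.

```python
import math


def rotate_half(x):
    """[x1, x2] -> [-x2, x1], where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rope(x, position, base=10000.0):
    """Apply rotary position embedding to one vector of even length."""
    dim = len(x)
    inv_freq = [1.0 / base ** (i / dim) for i in range(0, dim, 2)]
    angles = [position * f for f in inv_freq] * 2  # same angles for both halves
    rot = rotate_half(x)
    return [v * math.cos(a) + r * math.sin(a) for v, r, a in zip(x, rot, angles)]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


q, k = [1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.1]
# Same offset (2) at different absolute positions gives the same score.
print(dot(rope(q, 5), rope(k, 3)), dot(rope(q, 2), rope(k, 0)))
```

The rotation also preserves the vector norm, so only the relative angle between query and key changes.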
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.pretraining_tp = config.pretraining_tp
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.pretraining_tp > 1:
slice = self.intermediate_size // self.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
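With grouped-query attention there are fewer key/value heads than query heads, and `repeat_kv` duplicates each key/value head `n_rep` times so the shapes line up. The ordering matters: copies of one head are adjacent, as in `torch.repeat_interleave`. A list-based sketch with strings standing in for per-head tensors; the names are illustrative.

```python
def repeat_kv(heads, n_rep):
    """Repeat each KV head n_rep times, keeping copies adjacent (interleaved)."""
    return [h for h in heads for _ in range(n_rep)]


kv_heads = ["k0", "k1"]
expanded = repeat_kv(kv_heads, 2)
print(expanded)  # ['k0', 'k0', 'k1', 'k1']
```

Under this layout, query head `h` reads from key/value head `h // n_rep`.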
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.pretraining_tp = config.pretraining_tp
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
LLAMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`LlamaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
config_class = LlamaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlamaModel):
module.gradient_checkpointing = value
LLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).
            If you want to change padding behavior, you should read `LlamaModel._prepare_decoder_attention_mask`
            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors
            of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
            Contains pre-computed hidden-states (keys and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
        # build the attention mask: attend to every position by default, then
        # combine the padding mask with the causal mask
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
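# The mask preparation in `_prepare_decoder_attention_mask` adds a causal mask to an
# expanded padding mask. Below is a minimal pure-Python sketch of that idea, not the
# code used by this repository: the helper names `causal_mask`, `expand_padding_mask`
# and `combine_masks` are hypothetical, and `-inf` stands in for the "very large
# negative value" of the tensor dtype used by the real helpers.

```python
NEG_INF = float("-inf")


def causal_mask(tgt_len, past_len=0):
    """Additive mask: query i may attend to keys 0..past_len + i, others get -inf."""
    return [
        [0.0 if j <= i + past_len else NEG_INF for j in range(tgt_len + past_len)]
        for i in range(tgt_len)
    ]


def expand_padding_mask(attention_mask, tgt_len):
    """Turn a 0/1 padding mask over keys into an additive mask, one row per query."""
    row = [0.0 if keep else NEG_INF for keep in attention_mask]
    return [list(row) for _ in range(tgt_len)]


def combine_masks(causal, padding):
    """Element-wise sum: a position is blocked if either mask blocks it."""
    return [[c + p for c, p in zip(c_row, p_row)] for c_row, p_row in zip(causal, padding)]


# One left-padded sequence of length 3 (first token is padding).
mask = combine_masks(causal_mask(3), expand_padding_mask([0, 1, 1], 3))
for row in mask:
    print(row)
```

# In this sketch the padded key stays blocked for every query, and each real token
# only sees itself and earlier real tokens.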
class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the causal language modeling (next-token prediction) loss. Indices should either
                be in `[0, ..., config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to
                `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ...,
                config.vocab_size - 1]`. The loss is only computed in training mode; in evaluation mode a placeholder
                loss of -1 is returned when labels are given.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if self.training and labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
elif labels is not None:
loss = torch.tensor(-1.) # loss cannot be directly computed in inference
logits = logits[:, -1] # we only need last position logits for inference
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
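# `prepare_inputs_for_generation` derives `position_ids` from the attention mask
# (`cumsum(-1) - 1`, then padded positions are filled with 1). The following is a
# small pure-Python sketch of that rule for illustration only; the function name
# `position_ids_from_mask` is hypothetical and lists stand in for tensors.

```python
from itertools import accumulate


def position_ids_from_mask(attention_mask):
    """Compute per-token positions from a batch of 0/1 attention masks.

    Real tokens are numbered 0, 1, 2, ... in order of appearance; padded
    positions receive the dummy value 1, mirroring `masked_fill_(mask == 0, 1)`.
    """
    batch_ids = []
    for row in attention_mask:
        running = list(accumulate(row))  # cumulative count of real tokens
        batch_ids.append([count - 1 if keep else 1 for count, keep in zip(running, row)])
    return batch_ids


# A left-padded row and an unpadded row.
print(position_ids_from_mask([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]))
```

# With left padding, the first real token still gets position 0, so padded and
# unpadded sequences in one batch see the same positions for their real tokens.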
================================================
FILE: model/lru.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
class LRURec(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
self.embedding = LRUEmbedding(self.args)
self.model = LRUModel(self.args)
self.truncated_normal_init()
    def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):
        # Inverse-CDF sampling from a normal distribution truncated to [lower, upper]
        with torch.no_grad():
            # normal CDF evaluated at the truncation bounds
            cdf_lower = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2.
            cdf_upper = (1. + math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2.
            for n, p in self.named_parameters():
                # skip layer norm parameters and the LRU (nu, theta, gamma) parameters
                if 'layer_norm' not in n and 'params_log' not in n:
                    if torch.is_complex(p):
                        # real and imaginary parts are initialized independently
                        p.real.uniform_(2 * cdf_lower - 1, 2 * cdf_upper - 1)
                        p.imag.uniform_(2 * cdf_lower - 1, 2 * cdf_upper - 1)
                        p.real.erfinv_()
                        p.imag.erfinv_()
                        p.real.mul_(std * math.sqrt(2.))
                        p.imag.mul_(std * math.sqrt(2.))
                        p.real.add_(mean)
                        p.imag.add_(mean)
                    else:
                        p.uniform_(2 * cdf_lower - 1, 2 * cdf_upper - 1)
                        p.erfinv_()
                        p.mul_(std * math.sqrt(2.))
                        p.add_(mean)
def forward(self, x):
x, mask = self.embedding(x)
scores = self.model(x, self.embedding.token.weight, mask)
return scores
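# `truncated_normal_init` draws a uniform sample between the CDF values of the bounds
# and maps it back through the inverse normal CDF (sqrt(2) * erfinv(2p - 1) equals
# the standard normal quantile of p). The sketch below illustrates the same idea with
# the standard library only; `sample_truncated_normal` is a hypothetical helper and
# `statistics.NormalDist.inv_cdf` replaces the `erfinv_` call used above.

```python
import math
import random
from statistics import NormalDist


def sample_truncated_normal(n, mean=0.0, std=0.02, lower=-0.04, upper=0.04, seed=0):
    """Draw n samples from N(mean, std^2) truncated to [lower, upper]."""
    rng = random.Random(seed)

    def normal_cdf(x):
        return (1.0 + math.erf((x - mean) / std / math.sqrt(2.0))) / 2.0

    cdf_lower, cdf_upper = normal_cdf(lower), normal_cdf(upper)
    unit_normal = NormalDist()  # standard normal, used only for its quantile function
    samples = []
    for _ in range(n):
        p = rng.uniform(cdf_lower, cdf_upper)  # uniform over the allowed CDF range
        samples.append(mean + std * unit_normal.inv_cdf(p))
    return samples


values = sample_truncated_normal(1000)
print(min(values), max(values))
```

# Every sample stays within two standard deviations of the mean for the default
# arguments, because the bounds are mean - 2 * std and mean + 2 * std.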
class LRUEmbedding(nn.Module):
def __init__(self, args):
super().__init__()
vocab_size = args.num_items + 1
embed_size = args.bert_hidden_units
self.token = nn.Embedding(vocab_size, embed_size)
self.layer_norm = nn.LayerNorm(embed_size)
self.embed_dropout = nn.Dropout(args.bert_dropout)
def get_mask(self, x):
return (x > 0)
def forward(self, x):
mask = self.get_mask(x)
x = self.token(x)
return self.layer_norm(self.embed_dropout(x)), mask
class LRUModel(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
self.hidden_size = args.bert_hidden_units
layers = args.bert_num_blocks
self.lru_blocks = nn.ModuleList([LRUBlock(self.args) for _ in range(layers)])
self.bias = torch.nn.Parameter(torch.zeros(args.num_items + 1))
    def forward(self, x, embedding_weight, mask):
        # left-pad the sequence length to the next power of 2 (needed by the parallel scan)
        seq_len = x.size(1)
        log2_L = int(np.ceil(np.log2(seq_len)))
        x = F.pad(x, (0, 0, 2 ** log2_L - x.size(1), 0, 0, 0))
        mask_ = F.pad(mask, (2 ** log2_L - mask.size(1), 0, 0, 0))
        # LRU blocks with position-wise feed-forward networks
        for lru_block in self.lru_blocks:
            x = lru_block(x, mask_)
        x = x[:, -seq_len:]  # remove the padding: B x L x D
        # score every item by the dot product with its embedding
        scores = torch.matmul(x, embedding_weight.permute(1, 0)) + self.bias
        return scores
class LRUBlock(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
hidden_size = args.bert_hidden_units
self.lru_layer = LRULayer(
d_model=hidden_size, dropout=args.bert_attn_dropout)
self.feed_forward = PositionwiseFeedForward(
d_model=hidden_size, d_ff=hidden_size*4, dropout=args.bert_dropout)
def forward(self, x, mask):
x = self.lru_layer(x, mask)
x = self.feed_forward(x)
return x
class LRULayer(nn.Module):
def __init__(self,
d_model,
dropout=0.1,
use_bias=True,
r_min=0.8,
r_max=0.99):
super().__init__()
self.embed_size = d_model
self.hidden_size = 2 * d_model
self.use_bias = use_bias
# init nu, theta, gamma
u1 = torch.rand(self.hidden_size)
u2 = torch.rand(self.hidden_size)
nu_log = torch.log(-0.5 * torch.log(u1 * (r_max ** 2 - r_min ** 2) + r_min ** 2))
theta_log = torch.log(u2 * torch.tensor(np.pi) * 2)
diag_lambda = torch.exp(torch.complex(-torch.exp(nu_log), torch.exp(theta_log)))
gamma_log = torch.log(torch.sqrt(1 - torch.abs(diag_lambda) ** 2))
self.params_log = nn.Parameter(torch.vstack((nu_log, theta_log, gamma_log)))
# Init B, C, D
self.in_proj = nn.Linear(self.embed_size, self.hidden_size, bias=use_bias).to(torch.cfloat)
self.out_proj = nn.Linear(self.hidden_size, self.embed_size, bias=use_bias).to(torch.cfloat)
# self.out_vector = nn.Parameter(torch.rand(self.embed_size))
self.out_vector = nn.Identity()
# Dropout and layer norm
self.dropout = nn.Dropout(p=dropout)
self.layer_norm = nn.LayerNorm(self.embed_size)
    def lru_parallel(self, i, h, lamb, mask, B, L, D):
        # Parallel algorithm, see: https://kexue.fm/archives/9554#%E5%B9%B6%E8%A1%8C%E5%8C%96
        # The original implementation is slightly slower and does not consider 0 padding
        l = 2 ** i  # block length at recursion level i
        h = h.reshape(B * L // l, l, D)  # (B, L, D) -> (B * L // l, l, D)
        mask_ = mask.reshape(B * L // l, l)  # (B, L) -> (B * L // l, l)
        h1, h2 = h[:, :l // 2], h[:, l // 2:]  # split each block in half
        if i > 1:
            # extend the powers of lambda from (lambda^1 .. lambda^(l/4)) to (lambda^1 .. lambda^(l/2))
            lamb = torch.cat((lamb, lamb * lamb[-1]), 0)
        # carry the last state of the first half into the second half (zeroed if that position is padding)
        h2 = h2 + lamb * h1[:, -1:] * mask_[:, l // 2 - 1:l // 2].unsqueeze(-1)
        h = torch.cat([h1, h2], dim=1)
        return h, lamb
    def forward(self, x, mask):
        # compute Bu (the input projection scaled by gamma) and lambda
        nu, theta, gamma = torch.exp(self.params_log).split((1, 1, 1))
        lamb = torch.exp(torch.complex(-nu, theta))
        h = self.in_proj(x.to(torch.cfloat)) * gamma  # bu
        # compute h in parallel, doubling the block length at each level
        log2_L = int(np.ceil(np.log2(h.size(1))))
        B, L, D = h.size(0), h.size(1), h.size(2)
        for i in range(log2_L):
            h, lamb = self.lru_parallel(i + 1, h, lamb, mask, B, L, D)
        # residual connection (out_vector is the identity), followed by layer norm
        x = self.dropout(self.out_proj(h).real) + self.out_vector(x)
        return self.layer_norm(x)
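# `lru_parallel` evaluates the linear recurrence h_t = lambda * h_{t-1} + x_t by
# repeatedly merging adjacent blocks of doubling length. The sketch below is a
# simplified scalar version in pure Python, written to check the idea against the
# plain sequential loop. It is an illustration under assumptions, not the layer
# itself: `lru_sequential` and `lru_doubling` are hypothetical names, a single
# complex channel is used, and the padding mask is left out.

```python
import cmath


def lru_sequential(x, lam):
    """Reference: h_t = lam * h_{t-1} + x_t with h_{-1} = 0."""
    h, out = 0j, []
    for x_t in x:
        h = lam * h + x_t
        out.append(h)
    return out


def lru_doubling(x, lam):
    """Same recurrence via block merging; len(x) must be a power of two."""
    n = len(x)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    h = list(x)
    powers = [lam]  # lam^1 .. lam^half
    half = 1
    while half < n:
        if half > 1:
            # double the list of powers: (lam^1..lam^k) -> (lam^1..lam^2k)
            powers = powers + [p * powers[-1] for p in powers]
        block = 2 * half
        for start in range(0, n, block):
            carry = h[start + half - 1]  # last state of the first half
            for k in range(half):
                h[start + half + k] += powers[k] * carry
        half = block
    return h


lam = 0.9 * cmath.exp(0.3j)
xs = [complex(v) for v in range(1, 9)]
seq, par = lru_sequential(xs, lam), lru_doubling(xs, lam)
print(max(abs(a - b) for a, b in zip(seq, par)))
```

# The doubling version needs log2(L) merge levels instead of L sequential steps,
# which is why the model pads sequences to a power-of-two length first.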
class PositionwiseFeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.activation = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, x):
x_ = self.dropout(self.activation(self.w_1(x)))
return self.layer_norm(self.dropout(self.w_2(x_)) + x)
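# The `LRULayer` constructor initializes lambda so that its magnitude lies between
# `r_min` and `r_max`, and sets gamma = sqrt(1 - |lambda|^2) as the input scale.
# Below is a minimal scalar sketch of those formulas using only the standard
# library; `init_lru_params` is a hypothetical helper and `random.Random` replaces
# `torch.rand`.

```python
import cmath
import math
import random


def init_lru_params(r_min=0.8, r_max=0.99, seed=0):
    """Return (lambda, gamma) for one hidden unit, following the formulas in LRULayer."""
    rng = random.Random(seed)
    u1, u2 = rng.random(), rng.random()
    # |lambda| = exp(-exp(nu_log)) = sqrt(u1 * (r_max^2 - r_min^2) + r_min^2)
    nu_log = math.log(-0.5 * math.log(u1 * (r_max ** 2 - r_min ** 2) + r_min ** 2))
    # the phase of lambda is uniform in [0, 2 * pi)
    theta_log = math.log(u2 * 2 * math.pi)
    lam = cmath.exp(complex(-math.exp(nu_log), math.exp(theta_log)))
    gamma = math.sqrt(1 - abs(lam) ** 2)
    return lam, gamma


params = [init_lru_params(seed=s) for s in range(20)]
print(min(abs(lam) for lam, _ in params), max(abs(lam) for lam, _ in params))
```

# Keeping |lambda| below 1 makes the recurrence stable, and storing the parameters
# in log space (as `params_log` does) keeps that constraint during training.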
================================================
FILE: requirements.txt
================================================
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=5.1=1_gnu
absl-py=1.4.0=pypi_0
accelerate=0.21.0=pypi_0
aiofiles=23.1.0=pypi_0
aiohttp=3.8.4=pypi_0
aiosignal=1.3.1=pypi_0
alabaster=0.7.13=pypi_0
altair=4.2.2=pypi_0
aniso8601=9.0.1=pypi_0
antlr4-python3-runtime=4.9.3=pypi_0
anyio=3.6.2=pypi_0
appdirs=1.4.4=pypi_0
asttokens=2.2.1=pypi_0
async-timeout=4.0.2=pypi_0
attrdict=2.0.1=pypi_0
attrs=23.1.0=pypi_0
audioread=3.0.0=pypi_0
babel=2.12.1=pypi_0
backcall=0.2.0=pypi_0
beautifulsoup4=4.12.2=pypi_0
bitsandbytes=0.41.1=pypi_0
black=19.10b0=pypi_0
blas=1.0=mkl
boto3=1.26.160=pypi_0
botocore=1.29.160=pypi_0
braceexpand=0.1.7=pypi_0
brotlipy=0.7.0=py310h7f8727e_1002
bzip2=1.0.8=h7b6447c_0
ca-certificates=2023.08.22=h06a4308_0
cachetools=5.3.0=pypi_0
cdifflib=1.2.6=pypi_0
certifi=2022.12.7=py310h06a4308_0
cffi=1.15.1=py310h5eee18b_3
charset-normalizer=2.0.4=pyhd3eb1b0_0
click=8.0.2=pypi_0
clip=1.0=pypi_0
colorama=0.4.6=pypi_0
comm=0.1.3=pypi_0
contourpy=1.0.7=pypi_0
cryptography=39.0.1=py310h9ce1e76_0
cuda-cudart=11.8.89=0
cuda-cupti=11.8.87=0
cuda-libraries=11.8.0=0
cuda-nvrtc=11.8.89=0
cuda-nvtx=11.8.86=0
cuda-runtime=11.8.0=0
cudatoolkit=11.8.0=h6a678d5_0
cycler=0.11.0=pypi_0
cython=0.29.35=pypi_0
datasets=2.14.5=pypi_0
debugpy=1.6.7=pypi_0
decorator=5.1.1=pypi_0
dill=0.3.6=pypi_0
distance=0.1.3=pypi_0
docker-pycreds=0.4.0=pypi_0
docopt=0.6.2=pypi_0
docutils=0.20.1=pypi_0
editdistance=0.6.2=pypi_0
einops=0.6.1=pypi_0
emoji=2.8.0=pypi_0
entrypoints=0.4=pypi_0
evaluate=0.3.0=pypi_0
exceptiongroup=1.1.1=pypi_0
executing=1.2.0=pypi_0
fairscale=0.4.13=pypi_0
faiss-cpu=1.7.4=pypi_0
faiss-gpu=1.7.2=pypi_0
fastapi=0.95.1=pypi_0
fasttext=0.9.2=pypi_0
ffmpeg=4.3=hf484d3e_0
ffmpy=0.3.0=pypi_0
filelock=3.9.0=py310h06a4308_0
fire=0.5.0=pypi_0
flask=2.2.5=pypi_0
flask-restful=0.3.10=pypi_0
flit-core=3.8.0=py310h06a4308_0
fonttools=4.39.3=pypi_0
freetype=2.12.1=h4a9f257_0
frozenlist=1.3.3=pypi_0
fsspec=2023.4.0=pypi_0
ftfy=6.1.1=pypi_0
future=0.18.3=pypi_0
g2p-en=2.1.0=pypi_0
gdown=4.7.1=pypi_0
giflib=5.2.1=h5eee18b_3
gitdb=4.0.10=pypi_0
gitpython=3.1.31=pypi_0
gmp=6.2.1=h295c915_3
gmpy2=2.1.2=py310heeb90bb_0
gnutls=3.6.15=he1e5248_0
google-auth=2.17.3=pypi_0
google-auth-oauthlib=1.0.0=pypi_0
gradio=3.28.3=pypi_0
gradio-client=0.1.3=pypi_0
grpcio=1.54.0=pypi_0
h11=0.14.0=pypi_0
h5py=3.9.0=pypi_0
httpcore=0.17.0=pypi_0
httpx=0.24.0=pypi_0
huggingface-hub=0.16.4=pypi_0
hydra-core=1.2.0=pypi_0
idna=3.4=py310h06a4308_0
ijson=3.2.2=pypi_0
imageio=2.28.0=pypi_0
imagesize=1.4.1=pypi_0
inflect=6.0.4=pypi_0
iniconfig=2.0.0=pypi_0
intel-openmp=2021.4.0=h06a4308_3561
ipadic=1.0.0=pypi_0
ipykernel=6.23.3=pypi_0
ipython=8.12.0=pypi_0
ipywidgets=8.0.6=pypi_0
isort=5.12.0=pypi_0
itsdangerous=2.1.2=pypi_0
jedi=0.18.2=pypi_0
jieba=0.42.1=pypi_0
jinja2=3.1.2=py310h06a4308_0
jiwer=2.5.2=pypi_0
jmespath=1.0.1=pypi_0
joblib=1.2.0=pypi_0
jpeg=9e=h5eee18b_1
jsonschema=4.17.3=pypi_0
jupyter-client=8.3.0=pypi_0
jupyter-core=5.3.1=pypi_0
jupyterlab-widgets=3.0.7=pypi_0
kaldi-python-io=1.2.2=pypi_0
kaldiio=2.18.0=pypi_0
kiwisolver=1.4.4=pypi_0
kornia=0.6.12=pypi_0
lame=3.100=h7b6447c_0
latexcodec=2.0.1=pypi_0
lazy-loader=0.2=pypi_0
lcms2=2.12=h3be6417_0
ld_impl_linux-64=2.38=h1181459_1
lerc=3.0=h295c915_0
levenshtein=0.21.1=pypi_0
libcublas=11.11.3.6=0
libcufft=10.9.0.58=0
libcufile=1.6.0.25=0
libcurand=10.3.2.56=0
libcusolver=11.4.1.48=0
libcusparse=11.7.5.86=0
libdeflate=1.17=h5eee18b_0
libffi=3.4.2=h6a678d5_6
libgcc-ng=11.2.0=h1234567_1
libgomp=11.2.0=h1234567_1
libiconv=1.16=h7f8727e_2
libidn2=2.3.2=h7f8727e_0
libnpp=11.8.0.86=0
libnvjpeg=11.9.0.86=0
libpng=1.6.39=h5eee18b_0
librosa=0.10.0.post2=pypi_0
libstdcxx-ng=11.2.0=h1234567_1
libtasn1=4.16.0=h27cfd23_0
libtiff=4.5.0=h6a678d5_2
libunistring=0.9.10=h27cfd23_0
libuuid=1.41.5=h5eee18b_0
libwebp=1.2.4=h11a3e52_1
libwebp-base=1.2.4=h5eee18b_1
lightning-utilities=0.8.0=pypi_0
linkify-it-py=2.0.0=pypi_0
llvmlite=0.40.1=pypi_0
loguru=0.7.0=pypi_0
loralib=0.1.1=pypi_0
lxml=4.9.2=pypi_0
lz4-c=1.9.4=h6a678d5_0
markdown=3.4.3=pypi_0
markdown-it-py=2.2.0=pypi_0
markdown2=2.4.9=pypi_0
markupsafe=2.1.1=py310h7f8727e_0
marshmallow=3.19.0=pypi_0
matplotlib=3.7.1=pypi_0
matplotlib-inline=0.1.6=pypi_0
mdit-py-plugins=0.3.3=pypi_0
mdurl=0.1.2=pypi_0
mecab-python3=1.0.5=pypi_0
megatron-core=0.2.0=pypi_0
mkl=2021.4.0=h06a4308_640
mkl-service=2.4.0=py310h7f8727e_0
mkl_fft=1.3.1=py310hd6ae3a3_0
mkl_random=1.2.2=py310h00e6091_0
mpc=1.1.0=h10f8cd9_1
mpfr=4.0.2=hb69a4c5_1
mpmath=1.2.1=pypi_0
msgpack=1.0.5=pypi_0
multidict=6.0.4=pypi_0
multiprocess=0.70.14=pypi_0
mypy-extensions=1.0.0=pypi_0
ncurses=6.4=h6a678d5_0
nest-asyncio=1.5.6=pypi_0
nettle=3.7.3=hbbd107a_1
networkx=2.8.4=py310h06a4308_1
nltk=3.8=pypi_0
numba=0.57.1=pypi_0
numpy=1.23.4=pypi_0
numpy-base=1.24.3=py310h8e6c178_0
oauthlib=3.2.2=pypi_0
omegaconf=2.2.3=pypi_0
onnx=1.14.0=pypi_0
openai=0.27.1=pypi_0
opencc=1.1.6=pypi_0
opencv-python=4.7.0.72=pypi_0
openh264=2.1.1=h4ff587b_0
openprompt=1.0.1=pypi_0
openssl=1.1.1w=h7f8727e_0
orjson=3.8.10=pypi_0
packaging=23.1=pypi_0
pandas=2.0.1=pypi_0
pangu=4.0.6.1=pypi_0
parameterized=0.9.0=pypi_0
parso=0.8.3=pypi_0
pathspec=0.11.1=pypi_0
pathtools=0.1.2=pypi_0
peft=0.5.0=pypi_0
pexpect=4.8.0=pypi_0
pickleshare=0.7.5=pypi_0
pillow=9.4.0=py310h6a678d5_0
pip=23.2.1=pypi_0
plac=1.3.5=pypi_0
platformdirs=3.4.0=pypi_0
pluggy=1.2.0=pypi_0
pooch=1.6.0=pypi_0
portalocker=2.7.0=pypi_0
progress=1.6=pypi_0
prompt-toolkit=3.0.38=pypi_0
protobuf=3.20.3=pypi_0
psutil=5.9.5=pypi_0
ptyprocess=0.7.0=pypi_0
pure-eval=0.2.2=pypi_0
pyannote-core=5.0.0=pypi_0
pyannote-database=5.0.1=pypi_0
pyannote-metrics=3.2.1=pypi_0
pyarrow=11.0.0=pypi_0
pyasn1=0.5.0=pypi_0
pyasn1-modules=0.3.0=pypi_0
pybind11=2.10.4=pypi_0
pybtex=0.24.0=pypi_0
pybtex-docutils=1.0.2=pypi_0
pycparser=2.21=pyhd3eb1b0_0
pydantic=1.10.7=pypi_0
pydeprecate=0.3.1=pypi_0
pydub=0.25.1=pypi_0
pygments=2.15.1=pypi_0
pynini=2.1.5=pypi_0
pyopenssl=23.0.0=py310h06a4308_0
pyparsing=3.0.9=pypi_0
pypinyin=0.49.0=pypi_0
pypinyin-dict=0.6.0=pypi_0
pyrsistent=0.19.3=pypi_0
pysocks=1.7.1=py310h06a4308_0
pytest=7.4.0=pypi_0
pytest-runner=6.0.0=pypi_0
python=3.10.10=h7a1cb2a_2
python-dateutil=2.8.2=pypi_0
python-multipart=0.0.6=pypi_0
pytorch=2.0.0=py3.10_cuda11.8_cudnn8.7.0_0
pytorch-cuda=11.8=h7e8668a_3
pytorch-lightning=1.9.4=pypi_0
pytorch-mutex=1.0=cuda
pytz=2023.3=pypi_0
pyyaml=5.4.1=pypi_0
pyzmq=25.1.0=pypi_0
rank-bm25=0.2.2=pypi_0
rapidfuzz=2.13.7=pypi_0
readline=8.2=h5eee18b_0
regex=2023.3.23=pypi_0
requests=2.28.1=py310h06a4308_1
requests-oauthlib=1.3.1=pypi_0
responses=0.18.0=pypi_0
rich=13.4.2=pypi_0
risparser=0.4.4=pypi_0
rouge=1.0.0=pypi_0
rouge-score=0.1.2=pypi_0
rsa=4.9=pypi_0
ruamel-yaml=0.17.32=pypi_0
ruamel-yaml-clib=0.2.7=pypi_0
s3transfer=0.6.1=pypi_0
sacrebleu=2.3.1=pypi_0
sacremoses=0.0.53=pypi_0
safetensors=0.3.1=pypi_0
scikit-learn=1.2.1=pypi_0
scipy=1.10.1=pypi_0
seaborn=0.12.2=pypi_0
semantic-version=2.10.0=pypi_0
sentence-transformers=2.2.2=pypi_0
sentencepiece=0.1.96=pypi_0
sentry-sdk=1.25.1=pypi_0
setproctitle=1.3.2=pypi_0
setuptools=65.5.1=pypi_0
shellingham=1.5.0.post1=pypi_0
six=1.16.0=pyhd3eb1b0_1
smmap=5.0.0=pypi_0
sniffio=1.3.0=pypi_0
snowballstemmer=2.2.0=pypi_0
sortedcontainers=2.4.0=pypi_0
soundfile=0.12.1=pypi_0
soupsieve=2.4.1=pypi_0
sox=1.4.1=pypi_0
soxr=0.3.5=pypi_0
sphinx=7.0.1=pypi_0
sphinxcontrib-applehelp=1.0.4=pypi_0
sphinxcontrib-bibtex=2.5.0=pypi_0
sphinxcontrib-devhelp=1.0.2=pypi_0
sphinxcontrib-htmlhelp=2.0.1=pypi_0
sphinxcontrib-jsmath=1.0.1=pypi_0
sphinxcontrib-qthelp=1.0.3=pypi_0
sphinxcontrib-serializinghtml=1.1.5=pypi_0
sqlite=3.41.1=h5eee18b_0
stack-data=0.6.2=pypi_0
starlette=0.26.1=pypi_0
sympy=1.11.1=py310h06a4308_0
tabulate=0.9.0=pypi_0
taming-transformers=0.0.1=pypi_0
taming-transformers-rom1504=0.0.6=pypi_0
tensorboard=2.12.2=pypi_0
tensorboard-data-server=0.7.0=pypi_0
tensorboard-plugin-wit=1.8.1=pypi_0
tensorboardx=2.6.2.2=pypi_0
termcolor=2.2.0=pypi_0
test-tube=0.7.5=pypi_0
text-unidecode=1.3=pypi_0
textdistance=4.5.0=pypi_0
texterrors=0.4.4=pypi_0
threadpoolctl=3.1.0=pypi_0
tk=8.6.12=h1ccaba5_0
tokenize-rt=5.0.0=pypi_0
tokenizers=0.13.3=pypi_0
toml=0.10.2=pypi_0
tomli=2.0.1=pypi_0
toolz=0.12.0=pypi_0
torchaudio=2.0.0=py310_cu118
torchmetrics=0.11.4=pypi_0
torchtriton=2.0.0=py310
torchvision=0.15.0=py310_cu118
tornado=6.3.2=pypi_0
tqdm=4.64.1=pypi_0
traitlets=5.9.0=pypi_0
transformers=4.33.3=pypi_0
trl=0.7.1=pypi_0
tweet-preprocessor=0.6.0=pypi_0
typed-ast=1.5.4=pypi_0
typer=0.9.0=pypi_0
typing_extensions=4.4.0=py310h06a4308_0
tzdata=2023.3=pypi_0
uc-micro-py=1.0.1=pypi_0
unidecode=1.3.7=pypi_0
urllib3=1.26.15=py310h06a4308_0
uvicorn=0.21.1=pypi_0
wandb=0.15.4=pypi_0
wcwidth=0.2.6=pypi_0
webdataset=0.1.62=pypi_0
websockets=11.0.2=pypi_0
werkzeug=2.3.0=pypi_0
wget=3.2=pypi_0
wheel=0.38.4=py310h06a4308_0
widgetsnbextension=4.0.7=pypi_0
wrapt=1.15.0=pypi_0
xxhash=3.2.0=pypi_0
xz=5.2.10=h5eee18b_1
yacs=0.1.8=pypi_0
yarl=1.9.2=pypi_0
youtokentome=1.0.6=pypi_0
zlib=1.2.13=h5eee18b_0
zstd=1.5.4=hc292b87_0
================================================
FILE: train_ranker.py
================================================
import os
import torch
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import argparse
from datasets import DATASETS
from config import *
from model import *
from dataloader import *
from trainer import *
from transformers import BitsAndBytesConfig
from pytorch_lightning import seed_everything
from model import LlamaForCausalLM
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
prepare_model_for_kbit_training,
)
try:
os.environ['WANDB_PROJECT'] = PROJECT_NAME
except:
print('WANDB_PROJECT not available, please set it in config.py')
def main(args, export_root=None):
seed_everything(args.seed)
if export_root == None:
export_root = EXPERIMENT_ROOT + '/' + args.llm_base_model.split('/')[-1] + '/' + args.dataset_code
train_loader, val_loader, test_loader, tokenizer, test_retrieval = dataloader_factory(args)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
model = LlamaForCausalLM.from_pretrained(
args.llm_base_model,
quantization_config=bnb_config,
device_map='auto',
cache_dir=args.llm_cache_dir,
)
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=args.lora_target_modules,
lora_dropout=args.lora_dropout,
bias='none',
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
    model.config.use_cache = False  # the KV cache is not needed for training and conflicts with gradient checkpointing
trainer = LLMTrainer(args, model, train_loader, val_loader, test_loader, tokenizer, export_root, args.use_wandb)
trainer.train()
trainer.test(test_retrieval)
if __name__ == "__main__":
args.model_code = 'llm'
set_template(args)
main(args, export_root=None)
================================================
FILE: train_retriever.py
================================================
import os
import torch
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import wandb
import argparse
from config import *
from model import *
from dataloader import *
from trainer import *
from pytorch_lightning import seed_everything
try:
    os.environ['WANDB_PROJECT'] = PROJECT_NAME
except (NameError, TypeError):  # PROJECT_NAME is undefined or not a string
    print('WANDB_PROJECT not available, please set it in config.py')
def main(args, export_root=None):
seed_everything(args.seed)
train_loader, val_loader, test_loader = dataloader_factory(args)
model = LRURec(args)
    if export_root is None:
export_root = EXPERIMENT_ROOT + '/' + args.model_code + '/' + args.dataset_code
trainer = LRUTrainer(args, model, train_loader, val_loader, test_loader, export_root, args.use_wandb)
trainer.train()
trainer.test()
# the next line generates val / test candidates for reranking
trainer.generate_candidates(os.path.join(export_root, 'retrieved.pkl'))
if __name__ == "__main__":
args.model_code = 'lru'
set_template(args)
main(args, export_root=None)
# # searching best hyperparameters
# for decay in [0, 0.01]:
# for dropout in [0, 0.1, 0.2, 0.3, 0.4, 0.5]:
# args.weight_decay = decay
# args.bert_dropout = dropout
# args.bert_attn_dropout = dropout
# export_root = EXPERIMENT_ROOT + '/' + args.model_code + '/' + args.dataset_code + '/' + str(decay) + '_' + str(dropout)
# main(args, export_root=export_root)
================================================
FILE: trainer/__init__.py
================================================
from .lru import *
from .llm import *
from .utils import *
================================================
FILE: trainer/base.py
================================================
from model import *
from config import *
from .utils import *
from .loggers import *
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
import os
import json
import numpy as np
from abc import ABCMeta, abstractmethod
from pathlib import Path
from collections import OrderedDict
class BaseTrainer(metaclass=ABCMeta):
def __init__(self, args, model, train_loader, val_loader, test_loader, export_root, use_wandb=True):
self.args = args
self.device = args.device
self.model = model.to(self.device)
self.num_epochs = args.num_epochs
self.metric_ks = args.metric_ks
self.best_metric = args.best_metric
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader
self.optimizer = self._create_optimizer()
if args.enable_lr_schedule:
if args.enable_lr_warmup:
self.lr_scheduler = self.get_linear_schedule_with_warmup(
self.optimizer, args.warmup_steps, len(self.train_loader) * self.num_epochs)
else:
self.lr_scheduler = optim.lr_scheduler.StepLR(
self.optimizer, step_size=args.decay_step, gamma=args.gamma)
self.export_root = export_root
if not os.path.exists(self.export_root):
Path(self.export_root).mkdir(parents=True)
self.use_wandb = use_wandb
if use_wandb:
import wandb
wandb.init(
name=self.args.model_code+'_'+self.args.dataset_code,
project=PROJECT_NAME,
config=args,
)
writer = wandb
else:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(
log_dir=Path(self.export_root).joinpath('logs'),
comment=self.args.model_code+'_'+self.args.dataset_code,
)
self.val_loggers, self.test_loggers = self._create_loggers()
self.logger_service = LoggerService(
self.args, writer, self.val_loggers, self.test_loggers, use_wandb)
print(args)
def train(self):
accum_iter = 0
self.exit_training = self.validate(0, accum_iter)
for epoch in range(self.num_epochs):
accum_iter = self.train_one_epoch(epoch, accum_iter)
if self.args.val_strategy == 'epoch':
self.exit_training = self.validate(epoch, accum_iter) # val after every epoch
if self.exit_training:
print('Early stopping triggered. Exit training')
break
self.logger_service.complete()
def train_one_epoch(self, epoch, accum_iter):
average_meter_set = AverageMeterSet()
tqdm_dataloader = tqdm(self.train_loader)
for batch_idx, batch in enumerate(tqdm_dataloader):
self.model.train()
batch = self.to_device(batch)
self.optimizer.zero_grad()
loss = self.calculate_loss(batch)
loss.backward()
self.clip_gradients(self.args.max_grad_norm)
self.optimizer.step()
if self.args.enable_lr_schedule:
self.lr_scheduler.step()
average_meter_set.update('loss', loss.item())
tqdm_dataloader.set_description(
'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))
accum_iter += 1
if self.args.val_strategy == 'iteration' and accum_iter % self.args.val_iterations == 0:
self.exit_training = self.validate(epoch, accum_iter) # val after certain iterations
if self.exit_training: break
return accum_iter
def validate(self, epoch, accum_iter):
self.model.eval()
average_meter_set = AverageMeterSet()
with torch.no_grad():
tqdm_dataloader = tqdm(self.val_loader)
for batch_idx, batch in enumerate(tqdm_dataloader):
batch = self.to_device(batch)
metrics = self.calculate_metrics(batch, exclude_history=False) # faster validation
self._update_meter_set(average_meter_set, metrics)
self._update_dataloader_metrics(
tqdm_dataloader, average_meter_set)
log_data = {
'state_dict': (self._create_state_dict()),
'epoch': epoch+1,
'accum_iter': accum_iter,
}
log_data.update(average_meter_set.averages())
return self.logger_service.log_val(log_data) # early stopping
def test(self, epoch=-1, accum_iter=-1, save_name=None):
print('******************** Testing Best Model ********************')
best_model_dict = torch.load(os.path.join(
self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
self.model.load_state_dict(best_model_dict)
self.model.eval()
average_meter_set = AverageMeterSet()
with torch.no_grad():
tqdm_dataloader = tqdm(self.test_loader)
for batch_idx, batch in enumerate(tqdm_dataloader):
batch = self.to_device(batch)
metrics = self.calculate_metrics(batch)
self._update_meter_set(average_meter_set, metrics)
self._update_dataloader_metrics(
tqdm_dataloader, average_meter_set)
log_data = {
'state_dict': (self._create_state_dict()),
'epoch': epoch+1,
'accum_iter': accum_iter,
}
average_metrics = average_meter_set.averages()
log_data.update(average_metrics)
self.logger_service.log_test(log_data)
print('******************** Testing Metrics ********************')
print(average_metrics)
file_name = 'test_metrics.json' if save_name is None else save_name
with open(os.path.join(self.export_root, file_name), 'w') as f:
json.dump(average_metrics, f, indent=4)
return average_metrics
def to_device(self, batch):
return [x.to(self.device) for x in batch]
@abstractmethod
def calculate_loss(self, batch):
pass
    @abstractmethod
    def calculate_metrics(self, batch, exclude_history=True):
        pass
def clip_gradients(self, limit=1.0):
nn.utils.clip_grad_norm_(self.model.parameters(), limit)
def _update_meter_set(self, meter_set, metrics):
for k, v in metrics.items():
meter_set.update(k, v)
def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):
description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]
] + ['Recall@%d' % k for k in self.metric_ks[:3]]
description = 'Eval: ' + \
', '.join(s + ' {:.4f}' for s in description_metrics)
description = description.replace('NDCG', 'N').replace('Recall', 'R')
description = description.format(
*(meter_set[k].avg for k in description_metrics))
tqdm_dataloader.set_description(description)
def _create_optimizer(self):
args = self.args
param_optimizer = list(self.model.named_parameters())
no_decay = ['bias', 'layer_norm']
optimizer_grouped_parameters = [
{
'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay': args.weight_decay,
},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.},
]
if args.optimizer.lower() == 'adamw':
return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
elif args.optimizer.lower() == 'adam':
return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)
else:
raise NotImplementedError
def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def _create_loggers(self):
root = Path(self.export_root)
model_checkpoint = root.joinpath('models')
val_loggers, test_loggers = [], []
for k in self.metric_ks:
val_loggers.append(
MetricGraphPrinter(key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation', use_wandb=self.use_wandb))
val_loggers.append(
MetricGraphPrinter(key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation', use_wandb=self.use_wandb))
val_loggers.append(
MetricGraphPrinter(key='MRR@%d' % k, graph_name='MRR@%d' % k, group_name='Validation', use_wandb=self.use_wandb))
val_loggers.append(RecentModelLogger(self.args, model_checkpoint))
val_loggers.append(BestModelLogger(self.args, model_checkpoint, metric_key=self.best_metric))
for k in self.metric_ks:
test_loggers.append(
MetricGraphPrinter(key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Test', use_wandb=self.use_wandb))
test_loggers.append(
MetricGraphPrinter(key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Test', use_wandb=self.use_wandb))
test_loggers.append(
MetricGraphPrinter(key='MRR@%d' % k, graph_name='MRR@%d' % k, group_name='Test', use_wandb=self.use_wandb))
return val_loggers, test_loggers
def _create_state_dict(self):
return {
STATE_DICT_KEY: self.model.state_dict(),
OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),
}
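

# ----------------------------------------------------------------------
# Illustrative sketch, not part of the original trainer. It restates the
# multiplier computed inside `get_linear_schedule_with_warmup` above as a
# plain function, so the shape of the schedule (linear ramp to 1 during
# warmup, then linear decay to 0) can be checked without PyTorch.
# The name `linear_warmup_factor_sketch` is hypothetical.
# ----------------------------------------------------------------------
def linear_warmup_factor_sketch(current_step, num_warmup_steps, num_training_steps):
    """Return the learning-rate multiplier applied at `current_step`."""
    if current_step < num_warmup_steps:
        # warmup phase: grow linearly from 0 to 1
        return float(current_step) / float(max(1, num_warmup_steps))
    # decay phase: shrink linearly from 1 to 0, never going negative
    return max(
        0.0,
        float(num_training_steps - current_step)
        / float(max(1, num_training_steps - num_warmup_steps)),
    )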
================================================
FILE: trainer/llm.py
================================================
from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY
from .verb import ManualVerbalizer
from .utils import *
from .loggers import *
from .base import *
import re
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import json
import numpy as np
from abc import *
from pathlib import Path
import bitsandbytes as bnb
from transformers.trainer import *
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
def llama_collate_fn_w_truncation(llm_max_length, eval=False):
def llama_collate_fn(batch):
all_input_ids = []
all_attention_mask = []
all_labels = []
example_max_length = max([len(batch[idx]['input_ids']) for idx in range(len(batch))])
max_length = min(llm_max_length, example_max_length)
for i in range(len(batch)):
input_ids = batch[i]['input_ids']
attention_mask = batch[i]['attention_mask']
labels = batch[i]['labels']
            if len(input_ids) > max_length:
                # truncate from the left so the most recent tokens at the end of the prompt are kept
                input_ids = input_ids[-max_length:]
                attention_mask = attention_mask[-max_length:]
                if not eval: labels = labels[-max_length:]
            elif len(input_ids) < max_length:
                # left-pad with token id 0; padding is masked out of attention and (via -100) the loss
                padding_length = max_length - len(input_ids)
                input_ids = [0] * padding_length + input_ids
                attention_mask = [0] * padding_length + attention_mask
                if not eval: labels = [-100] * padding_length + labels

            # sanity checks on Llama token ids: 13 is the newline token '<0x0A>', 2 is EOS '</s>'
            if eval: assert input_ids[-1] == 13
            else:
                assert input_ids[-3] == 13 and input_ids[-1] == 2
                assert labels[-3] == -100 and labels[-2] != -100
all_input_ids.append(torch.tensor(input_ids).long())
all_attention_mask.append(torch.tensor(attention_mask).long())
all_labels.append(torch.tensor(labels).long())
return {
'input_ids': torch.vstack(all_input_ids),
'attention_mask': torch.vstack(all_attention_mask),
'labels': torch.vstack(all_labels)
}
return llama_collate_fn
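

# ----------------------------------------------------------------------
# Illustrative sketch, not part of the original module. It isolates the
# left-truncation / left-padding rule that `llama_collate_fn` applies to
# `input_ids`, `attention_mask` and (in training) `labels`, using plain
# lists. The name `left_pad_or_truncate_sketch` is hypothetical.
# ----------------------------------------------------------------------
def left_pad_or_truncate_sketch(token_ids, max_length, pad_value=0):
    """Keep the last `max_length` items, or left-pad up to `max_length`."""
    if len(token_ids) > max_length:
        # drop the oldest tokens, keep the tail of the sequence
        return token_ids[-max_length:]
    # prepend padding so every sequence ends at the same position
    return [pad_value] * (max_length - len(token_ids)) + token_ids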
def compute_metrics_for_ks(ks, verbalizer):
def compute_metrics(eval_pred):
logits, labels = eval_pred
logits = torch.tensor(logits)
labels = torch.tensor(labels).view(-1)
scores = verbalizer.process_logits(logits)
metrics = absolute_recall_mrr_ndcg_for_ks(scores, labels, ks)
return metrics
return compute_metrics
class LLMTrainer(Trainer):
def __init__(
self,
args,
model,
train_loader,
val_loader,
test_loader,
tokenizer,
export_root,
use_wandb,
**kwargs
):
self.original_args = args
self.export_root = export_root
self.use_wandb = use_wandb
self.llm_max_text_len = args.llm_max_text_len
self.rerank_metric_ks = args.rerank_metric_ks
self.verbalizer = ManualVerbalizer(
tokenizer=tokenizer,
prefix='',
post_log_softmax=False,
classes=list(range(args.llm_negative_sample_size+1)),
label_words={i: chr(ord('A')+i) for i in range(args.llm_negative_sample_size+1)},
)
hf_args = TrainingArguments(
per_device_train_batch_size=args.lora_micro_batch_size,
gradient_accumulation_steps=args.train_batch_size//args.lora_micro_batch_size,
warmup_steps=args.warmup_steps,
num_train_epochs=args.lora_num_epochs,
learning_rate=args.lora_lr,
bf16=True,
logging_steps=10,
optim="paged_adamw_32bit",
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=args.lora_val_iterations,
save_steps=args.lora_val_iterations,
output_dir=export_root,
save_total_limit=3,
load_best_model_at_end=True,
ddp_find_unused_parameters=None,
group_by_length=False,
            report_to="wandb" if use_wandb else "none",  # None would fall back to all installed integrations
run_name=args.model_code+'_'+args.dataset_code if use_wandb else None,
metric_for_best_model=args.rerank_best_metric,
greater_is_better=True,
)
super().__init__(
model=model,
args=hf_args,
callbacks=[EarlyStoppingCallback(args.lora_early_stopping_patience)],
**kwargs) # hf_args is now args
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader
self.tokenizer = tokenizer
self.train_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=False)
self.val_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=True)
self.test_loader.collate_fn = llama_collate_fn_w_truncation(self.llm_max_text_len, eval=True)
self.compute_metrics = compute_metrics_for_ks(self.rerank_metric_ks, self.verbalizer)
if len(self.label_names) == 0:
self.label_names = ['labels'] # for some reason label name is not set
def test(self, test_retrieval):
average_metrics = self.predict(test_dataset=None).metrics
print('Ranking Performance on Subset:', average_metrics)
print('************************************************************')
with open(os.path.join(self.export_root, 'subset_metrics.json'), 'w') as f:
json.dump(average_metrics, f, indent=4)
print('Original Performance:', test_retrieval['original_metrics'])
print('************************************************************')
original_size = test_retrieval['original_size']
retrieval_size = test_retrieval['retrieval_size']
overall_metrics = {}
for key in test_retrieval['non_retrieval_metrics'].keys():
if 'test_' + key in average_metrics:
overall_metrics['test_' + key] = (average_metrics['test_' + key] * retrieval_size + \
test_retrieval['non_retrieval_metrics'][key] * (original_size - retrieval_size)) / original_size
print('Overall Performance of Our Framework:', overall_metrics)
with open(os.path.join(self.export_root, 'overall_metrics.json'), 'w') as f:
json.dump(overall_metrics, f, indent=4)
return average_metrics
def get_train_dataloader(self):
return self.train_loader
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
return self.val_loader
def get_test_dataloader(self, test_dataset: Optional[Dataset] = None) -> DataLoader:
return self.test_loader
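

# ----------------------------------------------------------------------
# Illustrative sketch, not part of the original module. It restates the
# size-weighted average used in `LLMTrainer.test` above: reranked metrics
# apply to the `retrieval_size` users whose ground truth was retrieved,
# and the retriever's own metrics apply to the remaining users.
# The name `combine_metrics_sketch` is hypothetical.
# ----------------------------------------------------------------------
def combine_metrics_sketch(subset_metrics, non_retrieval_metrics, retrieval_size, original_size):
    """Merge reranked-subset metrics with metrics of the non-retrieved rest."""
    overall = {}
    for key, rest_value in non_retrieval_metrics.items():
        test_key = 'test_' + key  # the Hugging Face trainer prefixes test metrics
        if test_key in subset_metrics:
            overall[test_key] = (
                subset_metrics[test_key] * retrieval_size
                + rest_value * (original_size - retrieval_size)
            ) / original_size
    return overall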
================================================
FILE: trainer/loggers.py
================================================
import os
import torch
from abc import ABCMeta, abstractmethod
def save_state_dict(state_dict, path, filename):
torch.save(state_dict, os.path.join(path, filename))
class LoggerService(object):
def __init__(self, args, writer, val_loggers, test_loggers, use_wandb):
self.args = args
self.writer = writer
self.val_loggers = val_loggers if val_loggers else []
self.test_loggers = test_loggers if test_loggers else []
self.use_wandb = use_wandb
def complete(self):
if self.use_wandb:
self.writer.finish()
else:
self.writer.close()
def log_val(self, log_data):
criteria_met = False
for logger in self.val_loggers:
logger.log(self.writer, **log_data)
if self.args.early_stopping and isinstance(logger, BestModelLogger):
criteria_met = logger.patience_counter >= self.args.early_stopping_patience
return criteria_met
def log_test(self, log_data):
for logger in self.test_loggers:
logger.log(self.writer, **log_data)
class AbstractBaseLogger(metaclass=ABCMeta):
@abstractmethod
def log(self, *args, **kwargs):
raise NotImplementedError
def complete(self, *args, **kwargs):
pass
class MetricGraphPrinter(AbstractBaseLogger):
def __init__(self, key, graph_name, group_name, use_wandb):
self.key = key
self.graph_label = graph_name
self.group_name = group_name
self.use_wandb = use_wandb
def log(self, writer, *args, **kwargs):
if self.key in kwargs:
if self.use_wandb:
writer.log({self.group_name+'/'+self.graph_label: kwargs[self.key], 'batch': kwargs['accum_iter']})
else:
writer.add_scalar(self.group_name+'/'+ self.graph_label, kwargs[self.key], kwargs['accum_iter'])
else:
print('Metric {} not found...'.format(self.key))
def complete(self, writer, *args, **kwargs):
self.log(writer, *args, **kwargs)
class RecentModelLogger(AbstractBaseLogger):
def __init__(self, args, checkpoint_path, filename='checkpoint-recent.pth'):
self.args = args
self.checkpoint_path = checkpoint_path
if not os.path.exists(self.checkpoint_path):
self.checkpoint_path.mkdir(parents=True)
self.recent_epoch = None
self.filename = filename
def log(self, *args, **kwargs):
epoch = kwargs['epoch']
if self.recent_epoch != epoch:
self.recent_epoch = epoch
state_dict = kwargs['state_dict']
state_dict['epoch'] = kwargs['epoch']
save_state_dict(state_dict, self.checkpoint_path, self.filename)
def complete(self, *args, **kwargs):
save_state_dict(kwargs['state_dict'],
self.checkpoint_path, self.filename + '.final')
class BestModelLogger(AbstractBaseLogger):
def __init__(self, args, checkpoint_path, metric_key, filename='best_acc_model.pth'):
self.args = args
self.checkpoint_path = checkpoint_path
if not os.path.exists(self.checkpoint_path):
self.checkpoint_path.mkdir(parents=True)
self.best_metric = 0.
self.metric_key = metric_key
self.filename = filename
self.patience_counter = 0
def log(self, *args, **kwargs):
current_metric = kwargs[self.metric_key]
if self.best_metric < current_metric: # assumes the higher the better
print("Update Best {} Model at {}".format(
self.metric_key, kwargs['epoch']))
self.best_metric = current_metric
save_state_dict(kwargs['state_dict'],
self.checkpoint_path, self.filename)
if self.args.early_stopping:
self.patience_counter = 0
elif self.args.early_stopping:
self.patience_counter += 1
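

# ----------------------------------------------------------------------
# Illustrative sketch, not part of the original module. It replays the
# patience rule implemented by `BestModelLogger` and `LoggerService.log_val`
# on a list of validation scores, assuming early stopping is enabled and
# higher is better. The name `first_stop_index_sketch` is hypothetical.
# ----------------------------------------------------------------------
def first_stop_index_sketch(metric_history, patience):
    """Return the index of the validation that triggers early stopping, or None."""
    best_metric, patience_counter = 0.0, 0
    for index, current_metric in enumerate(metric_history):
        if best_metric < current_metric:
            # new best: remember it and reset the counter
            best_metric, patience_counter = current_metric, 0
        else:
            # no improvement: one more strike
            patience_counter += 1
        if patience_counter >= patience:
            return index
    return None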
================================================
FILE: trainer/lru.py
================================================
from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY
from .utils import *
from .loggers import *
from .base import *
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import json
import pickle
import numpy as np
from abc import *
from pathlib import Path
class LRUTrainer(BaseTrainer):
def __init__(self, args, model, train_loader, val_loader, test_loader, export_root, use_wandb):
super().__init__(args, model, train_loader, val_loader, test_loader, export_root, use_wandb)
self.ce = nn.CrossEntropyLoss(ignore_index=0)
def calculate_loss(self, batch):
seqs, labels = batch
logits = self.model(seqs)
logits = logits.view(-1, logits.size(-1))
labels = labels.view(-1)
loss = self.ce(logits, labels)
return loss
def calculate_metrics(self, batch, exclude_history=True):
seqs, labels = batch
scores = self.model(seqs)[:, -1, :]
B, L = seqs.shape
if exclude_history:
for i in range(L):
scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9
scores[:, 0] = -1e9 # padding
metrics = absolute_recall_mrr_ndcg_for_ks(scores, labels.view(-1), self.metric_ks)
return metrics
def generate_candidates(self, retrieved_data_path):
self.model.eval()
val_probs, val_labels = [], []
test_probs, test_labels = [], []
with torch.no_grad():
print('*************** Generating Candidates for Validation Set ***************')
tqdm_dataloader = tqdm(self.val_loader)
for batch_idx, batch in enumerate(tqdm_dataloader):
batch = self.to_device(batch)
seqs, labels = batch
scores = self.model(seqs)[:, -1, :]
B, L = seqs.shape
for i in range(L):
scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9
scores[:, 0] = -1e9 # padding
val_probs.extend(scores.tolist())
val_labels.extend(labels.view(-1).tolist())
val_metrics = absolute_recall_mrr_ndcg_for_ks(torch.tensor(val_probs),
torch.tensor(val_labels).view(-1), self.metric_ks)
print(val_metrics)
print('****************** Generating Candidates for Test Set ******************')
tqdm_dataloader = tqdm(self.test_loader)
for batch_idx, batch in enumerate(tqdm_dataloader):
batch = self.to_device(batch)
seqs, labels = batch
scores = self.model(seqs)[:, -1, :]
B, L = seqs.shape
for i in range(L):
scores[torch.arange(scores.size(0)), seqs[:, i]] = -1e9
scores[:, 0] = -1e9 # padding
test_probs.extend(scores.tolist())
test_labels.extend(labels.view(-1).tolist())
test_metrics = absolute_recall_mrr_ndcg_for_ks(torch.tensor(test_probs),
torch.tensor(test_labels).view(-1), self.metric_ks)
print(test_metrics)
with open(retrieved_data_path, 'wb') as f:
pickle.dump({'val_probs': val_probs,
'val_labels': val_labels,
'val_metrics': val_metrics,
'test_probs': test_probs,
'test_labels': test_labels,
'test_metrics': test_metrics}, f)
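

# ----------------------------------------------------------------------
# Illustrative sketch, not part of the original module. It shows, for one
# user and with plain lists, the masking applied before candidates are
# ranked above: items already in the history and the padding index 0 get
# a very low score so they cannot be retrieved.
# The name `mask_seen_items_sketch` is hypothetical.
# ----------------------------------------------------------------------
def mask_seen_items_sketch(scores, history, masked_value=-1e9):
    """Return a copy of `scores` with history items and padding masked out."""
    masked = list(scores)
    for item_id in history:
        masked[item_id] = masked_value
    masked[0] = masked_value  # index 0 is the padding item
    return masked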
================================================
FILE: trainer/utils.py
================================================
from config import *
import json
import os
import pprint as pp
import random
from datetime import date
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch import optim as optim
def ndcg(scores, labels, k):
scores = scores.cpu()
labels = labels.cpu()
rank = (-scores).argsort(dim=1)
cut = rank[:, :k]
hits = labels.gather(1, cut)
position = torch.arange(2, 2+k)
weights = 1 / torch.log2(position.float())
dcg = (hits.float() * weights).sum(1)
idcg = torch.Tensor([weights[:min(int(n), k)].sum()
for n in labels.sum(1)])
ndcg = dcg / idcg
return ndcg.mean()
def absolute_recall_mrr_ndcg_for_ks(scores, labels, ks):
metrics = {}
labels = F.one_hot(labels, num_classes=scores.size(1))
answer_count = labels.sum(1)
labels_float = labels.float()
rank = (-scores).argsort(dim=1)
cut = rank
for k in sorted(ks, reverse=True):
cut = cut[:, :k]
hits = labels_float.gather(1, cut)
metrics['Recall@%d' % k] = \
(hits.sum(1) / torch.min(torch.Tensor([k]).to(
labels.device), labels.sum(1).float())).mean().cpu().item()
metrics['MRR@%d' % k] = \
(hits / torch.arange(1, k+1).unsqueeze(0).to(
labels.device)).sum(1).mean().cpu().item()
position = torch.arange(2, 2+k)
weights = 1 / torch.log2(position.float())
dcg = (hits * weights.to(hits.device)).sum(1)
idcg = torch.Tensor([weights[:min(int(n), k)].sum()
for n in answer_count]).to(dcg.device)
ndcg = (dcg / idcg).mean()
metrics['NDCG@%d' % k] = ndcg.cpu().item()
return metrics
class AverageMeterSet(object):
def __init__(self, meters=None):
self.meters = meters if meters else {}
def __getitem__(self, key):
if key not in self.meters:
meter = AverageMeter()
meter.update(0)
return meter
return self.meters[key]
def update(self, name, value, n=1):
if name not in self.meters:
self.meters[name] = AverageMeter()
self.meters[name].update(value, n)
def reset(self):
for meter in self.meters.values():
meter.reset()
def values(self, format_string='{}'):
return {format_string.format(name): meter.val for name, meter in self.meters.items()}
def averages(self, format_string='{}'):
return {format_string.format(name): meter.avg for name, meter in self.meters.items()}
def sums(self, format_string='{}'):
return {format_string.format(name): meter.sum for name, meter in self.meters.items()}
def counts(self, format_string='{}'):
return {format_string.format(name): meter.count for name, meter in self.meters.items()}
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
        self.sum += val * n  # weight the value by its count so avg stays correct for n > 1
self.count += n
self.avg = self.sum / self.count
def __format__(self, format):
return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format)
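

# ----------------------------------------------------------------------
# Illustrative sketch, not part of the original module. With exactly one
# ground-truth item per user (the case produced by the one-hot labels in
# `absolute_recall_mrr_ndcg_for_ks`), the metrics reduce to closed forms
# in the 1-based rank of that item. The ideal DCG is 1 / log2(2) = 1, so
# NDCG equals the DCG term. The name `single_target_metrics_sketch` is
# hypothetical.
# ----------------------------------------------------------------------
import math


def single_target_metrics_sketch(rank, k):
    """Per-user Recall/MRR/NDCG at cutoff `k` for a target at position `rank`."""
    hit = rank <= k
    return {
        'Recall@%d' % k: 1.0 if hit else 0.0,
        'MRR@%d' % k: 1.0 / rank if hit else 0.0,
        'NDCG@%d' % k: 1.0 / math.log2(rank + 1) if hit else 0.0,
    }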
================================================
FILE: trainer/verb.py
================================================
from abc import abstractmethod
import json
from transformers.file_utils import ModelOutput
from transformers.data.processors.utils import InputFeatures
import torch
import torch.nn as nn
import torch.nn.functional as F
from yacs.config import CfgNode
from transformers.tokenization_utils import PreTrainedTokenizer
import numpy as np
from collections import namedtuple
import inspect
from typing import *
_VALID_TYPES = {tuple, list, str, int, float, bool, type(None)}
def convert_cfg_to_dict(cfg_node, key_list=[]):
""" Convert a config node to dictionary """
if not isinstance(cfg_node, CfgNode):
if type(cfg_node) not in _VALID_TYPES:
print("Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list), type(cfg_node), _VALID_TYPES), )
return cfg_node
else:
cfg_dict = dict(cfg_node)
for k, v in cfg_dict.items():
cfg_dict[k] = convert_cfg_to_dict(v, key_list + [k])
return cfg_dict
def signature(f):
r"""Get the function f 's input arguments. A useful gadget
when some function slot might be instantiated into multiple functions.
Args:
f (:obj:`function`) : the function to get the input arguments.
Returns:
        namedtuple : of args, defaults, varargs, keywords, respectively.
"""
sig = inspect.signature(f)
args = [
p.name for p in sig.parameters.values()
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
]
varargs = [
p.name for p in sig.parameters.values()
if p.kind == inspect.Parameter.VAR_POSITIONAL
]
varargs = varargs[0] if varargs else None
keywords = [
p.name for p in sig.parameters.values()
if p.kind == inspect.Parameter.VAR_KEYWORD
]
keywords = keywords[0] if keywords else None
defaults = [
p.default for p in sig.parameters.values()
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
and p.default is not p.empty
] or None
argspec = namedtuple('Signature', ['args', 'defaults',
'varargs', 'keywords'])
return argspec(args, defaults, varargs, keywords)
class Verbalizer(nn.Module):
r'''
Base class for all the verbalizers.
Args:
tokenizer (:obj:`PreTrainedTokenizer`): A tokenizer to appoint the vocabulary and the tokenization strategy.
classes (:obj:`Sequence[str]`): A sequence of classes that need to be projected.
'''
def __init__(self,
tokenizer: Optional[PreTrainedTokenizer] = None,
classes: Optional[Sequence[str]] = None,
num_classes: Optional[int] = None,
):
super().__init__()
self.tokenizer = tokenizer
self.classes = classes
if classes is not None and num_classes is not None:
            assert len(classes) == num_classes, "len(classes) != num_classes, check your config."
self.num_classes = num_classes
elif num_classes is not None:
self.num_classes = num_classes
elif classes is not None:
self.num_classes = len(classes)
else:
self.num_classes = None
            # raise AttributeError("Not able to configure num_classes")
self._in_on_label_words_set = False
@property
def label_words(self,):
r'''
Label words means the words in the vocabulary projected by the labels.
E.g. if we want to establish a projection in sentiment classification: positive :math:`\rightarrow` {`wonderful`, `good`},
in this case, `wonderful` and `good` are label words.
'''
if not hasattr(self, "_label_words"):
raise RuntimeError("label words haven't been set.")
return self._label_words
@label_words.setter
def label_words(self, label_words):
if label_words is None:
return
self._label_words = self._match_label_words_to_label_ids(label_words)
if not self._in_on_label_words_set:
self.safe_on_label_words_set()
def _match_label_words_to_label_ids(self, label_words): # TODO newly add function after docs written # TODO rename this function
"""
sort label words dict of verbalizer to match the label order of the classes
"""
if isinstance(label_words, dict):
if self.classes is None:
raise ValueError("""
classes attribute of the Verbalizer should be set since your given label words is a dict.
Since we will match the label word with respect to class A, to A's index in classes
""")
if set(label_words.keys()) != set(self.classes):
raise ValueError("name of classes in verbalizer are different from those of dataset")
label_words = [ # sort the dict to match dataset
label_words[c]
for c in self.classes
] # length: label_size of the whole task
elif isinstance(label_words, list) or isinstance(label_words, tuple):
pass
else:
raise ValueError("Verbalizer label words must be list, tuple or dict")
return label_words
def safe_on_label_words_set(self,):
self._in_on_label_words_set = True
self.on_label_words_set()
self._in_on_label_words_set = False
def on_label_words_set(self,):
r"""A hook to do something when textual label words were set.
"""
pass
@property
def vocab(self,) -> Dict:
if not hasattr(self, '_vocab'):
self._vocab = self.tokenizer.convert_ids_to_tokens(np.arange(self.vocab_size).tolist())
return self._vocab
@property
def vocab_size(self,) -> int:
return self.tokenizer.vocab_size
@abstractmethod
def generate_parameters(self, **kwargs) -> List:
r"""
The verbalizer can be seen as an extra layer on top of the original
pre-trained models. In manual verbalizer, it is a fixed one-hot vector of dimension
``vocab_size``, with the position of the label word being 1 and 0 everywhere else.
In other situation, the parameters may be a continuous vector over the
vocab, with each dimension representing a weight of that token.
Moreover, the parameters may be set to trainable to allow label words selection.
Therefore, this function serves as an abstract methods for generating the parameters
of the verbalizer, and must be instantiated in any derived class.
        Note that the parameters need to be registered as part of the PyTorch module.
        This can be achieved by wrapping a tensor using ``nn.Parameter()``.
"""
raise NotImplementedError
def register_calibrate_logits(self, logits: torch.Tensor):
r"""
This function aims to register logits that need to be calibrated, and detach the original logits from the current graph.
"""
if logits.requires_grad:
logits = logits.detach()
self._calibrate_logits = logits
def process_outputs(self,
outputs: torch.Tensor,
batch: Union[Dict, InputFeatures],
**kwargs):
r"""By default, the verbalizer will process the logits of the PLM's
output.
Args:
logits (:obj:`torch.Tensor`): The current logits generated by pre-trained language models.
batch (:obj:`Union[Dict, InputFeatures]`): The input features of the data.
"""
return self.process_logits(outputs, batch=batch, **kwargs)
def gather_outputs(self, outputs: ModelOutput):
r""" retrieve useful output for the verbalizer from the whole model output
By default, it will only retrieve the logits
Args:
outputs (:obj:`ModelOutput`) The output from the pretrained language model.
Return:
:obj:`torch.Tensor` The gathered output, should be of shape (``batch_size``,
``seq_len``, ``any``)
"""
return outputs.logits
@staticmethod
def aggregate(label_words_logits: torch.Tensor) -> torch.Tensor:
r""" To aggregate logits on multiple label words into the label's logits
Basic aggregator: mean of each label words' logits to a label's logits
        Can be re-implemented in advanced verbalizers.
Args:
label_words_logits (:obj:`torch.Tensor`): The logits of the label words only.
Return:
:obj:`torch.Tensor`: The final logits calculated by the label words.
"""
if label_words_logits.dim()>2:
return label_words_logits.mean(dim=-1)
else:
return label_words_logits
def normalize(self, logits: torch.Tensor) -> torch.Tensor:
r"""
Given logits regarding the entire vocab, calculate the probs over the label words set by softmax.
Args:
logits(:obj:`Tensor`): The logits of the entire vocab.
Returns:
:obj:`Tensor`: The probability distribution over the label words set.
"""
batch_size = logits.shape[0]
return F.softmax(logits.reshape(batch_size, -1), dim=-1).reshape(*logits.shape)
@abstractmethod
def project(self,
logits: torch.Tensor,
**kwargs) -> torch.Tensor:
r"""This method receives input logits of shape ``[batch_size, vocab_size]`` and uses the
parameters of this verbalizer to project the logits over the entire vocab into the
logits of the label words.
Args:
logits (:obj:`Tensor`): The logits over the entire vocab generated by the pre-trained language model, with shape [``batch_size``, ``vocab_size``]
Returns:
:obj:`Tensor`: The logits of each label word.
"""
raise NotImplementedError
def handle_multi_token(self, label_words_logits, mask):
r"""
Support multiple methods to handle the multi tokens produced by the tokenizer.
We suggest using 'first' or 'max' if some parts of the tokenization are not meaningful.
Can broadcast to 3-d tensor.
Args:
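label_words_logits (:obj:`torch.Tensor`): The logits of the label words, with the token dimension last.
mask (:obj:`torch.Tensor`): The token mask of the label words (1 for real tokens, 0 for padding).
Returns:
:obj:`torch.Tensor`: The label word logits with the token dimension reduced.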
"""
if self.multi_token_handler == "first":
label_words_logits = label_words_logits.select(dim=-1, index=0)
elif self.multi_token_handler == "max":
label_words_logits = label_words_logits - 1000*(1-mask.unsqueeze(0))
label_words_logits = label_words_logits.max(dim=-1).values
elif self.multi_token_handler == "mean":
label_words_logits = (label_words_logits*mask.unsqueeze(0)).sum(dim=-1)/(mask.unsqueeze(0).sum(dim=-1)+1e-15)
else:
raise ValueError("multi_token_handler {} not configured".format(self.multi_token_handler))
return label_words_logits
@classmethod
def from_config(cls,
config: CfgNode,
**kwargs):
r"""Load a verbalizer from the verbalizer's configuration node.
Args:
config (:obj:`CfgNode`): the sub-configuration of verbalizer, i.e. ``config[config.verbalizer]``
if config is a global config node.
kwargs: Other kwargs that might be used to initialize the verbalizer.
The actual value should match the arguments of ``__init__`` functions.
"""
init_args = signature(cls.__init__).args
_init_dict = {**convert_cfg_to_dict(config), **kwargs} if config is not None else kwargs
init_dict = {key: _init_dict[key] for key in _init_dict if key in init_args}
verbalizer = cls(**init_dict)
if hasattr(verbalizer, "from_file"):
if not hasattr(config, "file_path"):
pass
else:
if (not hasattr(config, "label_words") or config.label_words is None) and config.file_path is not None:
if config.choice is None:
config.choice = 0
verbalizer.from_file(config.file_path, config.choice)
elif (hasattr(config, "label_words") and config.label_words is not None) and config.file_path is not None:
raise RuntimeError("The label words can't be set from both `label_words` and `file_path`.")
return verbalizer
def from_file(self,
path: str,
choice: Optional[int] = 0 ):
r"""Load the predefined label words from a verbalizer file.
Three file formats are currently supported:
1. a .jsonl or .json file that contains a single verbalizer
in dict format.
2. a .jsonl or .json file that contains a list of verbalizers in dict format.
3. a .txt or .csv file in which each line lists the label words of one class,
separated by commas. An empty line begins a new verbalizer.
This format is recommended when you don't know the name of each class.
The details of the verbalizer format can be seen in :ref:`How_to_write_a_verbalizer`.
Args:
path (:obj:`str`): The path of the local verbalizer file.
choice (:obj:`int`): The choice of verbalizer in a file containing
multiple verbalizers.
Returns:
Verbalizer : `self` object
"""
if path.endswith(".txt") or path.endswith(".csv"):
with open(path, 'r') as f:
lines = f.readlines()
label_words_all = []
label_words_single_group = []
for line in lines:
line = line.strip().strip(" ")
if line == "":
if len(label_words_single_group)>0:
label_words_all.append(label_words_single_group)
label_words_single_group = []
else:
label_words_single_group.append(line)
if len(label_words_single_group) > 0: # flush the last group when the file does not end with an empty line
label_words_all.append(label_words_single_group)
if choice >= len(label_words_all):
raise RuntimeError("choice {} exceed the number of verbalizers {}"
.format(choice, len(label_words_all)))
label_words = label_words_all[choice]
label_words = [label_words_per_label.strip().split(",") \
for label_words_per_label in label_words]
elif path.endswith(".jsonl") or path.endswith(".json"):
with open(path, "r") as f:
label_words_all = json.load(f)
# if it is a file containing multiple verbalizers
if isinstance(label_words_all, list):
if choice >= len(label_words_all):
raise RuntimeError("choice {} exceed the number of verbalizers {}"
.format(choice, len(label_words_all)))
label_words = label_words_all[choice]
elif isinstance(label_words_all, dict):
label_words = label_words_all
if choice>0:
print("Choice of verbalizer is {}, but the file "
"only contains one verbalizer.".format(choice))
self.label_words = label_words
if self.num_classes is not None:
num_classes = len(self.label_words)
assert num_classes==self.num_classes, ('number of classes in the verbalizer file '
'does not match the predefined num_classes.')
return self
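# Illustrative sketch, not part of the original module: the ``.txt``/``.csv`` layout that
# ``from_file`` reads (one class per line, label words separated by commas, an empty line
# between verbalizers) can be reproduced with the standard library alone. The helper name
# ``parse_txt_verbalizers`` and the sample text are hypothetical and only mirror the
# parsing logic above; they are not an API of this repository.

```python
from typing import List


def parse_txt_verbalizers(text: str) -> List[List[List[str]]]:
    """Return a list of verbalizers; each verbalizer is a list of classes,
    and each class is the list of its comma-separated label words."""
    verbalizers: List[List[List[str]]] = []
    current: List[List[str]] = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if line == "":
            # An empty line closes the verbalizer collected so far.
            if current:
                verbalizers.append(current)
                current = []
        else:
            current.append(line.split(","))
    if current:  # the text may not end with an empty line
        verbalizers.append(current)
    return verbalizers


# Hypothetical file content: two verbalizers, each describing two classes.
example = "good,great\nbad,terrible\n\npositive\nnegative\n"
parsed = parse_txt_verbalizers(example)
choice = 0  # same role as the ``choice`` argument of ``from_file``
label_words = parsed[choice]
print(label_words)
```

# As in ``from_file``, an out-of-range ``choice`` should be rejected before indexing.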
class ManualVerbalizer(Verbalizer):
r"""
The basic manually defined verbalizer class, this class is inherited from the :obj:`Verbalizer` class.
Args:
tokenizer (:obj:`PreTrainedTokenizer`): The tokenizer of the current pre-trained model to point out the vocabulary.
classes (:obj:`List[Any]`): The classes (or labels) of the current task.
label_words (:obj:`Union[List[str], List[List[str]], Dict[List[str]]]`, optional): The label words that are projected by the labels.
prefix (:obj:`str`, optional): The prefix string of the verbalizer (used in PLMs like RoBERTa, which is sensitive to prefix space)
multi_token_handler (:obj:`str`, optional): The handling strategy for multiple tokens produced by the tokenizer.
post_log_softmax (:obj:`bool`, optional): Whether to apply log softmax post processing on label_logits. Defaults to True.
"""
def __init__(self,
tokenizer: PreTrainedTokenizer,
classes: Optional[List] = None,
num_classes: Optional[int] = None,
label_words: Optional[Union[Sequence[str], Mapping[str, str]]] = None,
prefix: Optional[str] = " ",
multi_token_handler: Optional[str] = "first",
post_log_softmax: Optional[bool] = True,
):
super().__init__(tokenizer=tokenizer, num_classes=num_classes, classes=classes)
self.prefix = prefix
self.multi_token_handler = multi_token_handler
self.label_words = label_words
self.post_log_softmax = post_log_softmax
def on_label_words_set(self):
super().on_label_words_set()
self.label_words = self.add_prefix(self.label_words, self.prefix)
# TODO should the Verbalizer base class have a label_words property and setter?
# it doesn't have a label_words init argument or a label words from_file option at all
self.generate_parameters()
@staticmethod
def add_prefix(label_words, prefix):
r"""Add prefix to label words. For example, if a label word is in the middle of a template,
the prefix should be ``' '``.
Args:
label_words (:obj:`Union[Sequence[str], Mapping[str, str]]`, optional): The label words that are projected by the labels.
prefix (:obj:`str`, optional): The prefix string of the verbalizer.
Returns:
:obj:`Sequence[str]`: New label words with prefix.
"""
new_label_words = []
if isinstance(label_words[0], str):
label_words = [[w] for w in label_words] # wrap into a list of lists of label words
for label_words_per_label in label_words:
new_label_words_per_label = []
for word in label_words_per_label:
if word.startswith("<!>"):
new_label_words_per_label.append(word.split("<!>")[1])
else:
new_label_words_per_label.append(prefix + word)
new_label_words.append(new_label_words_per_label)
return new_label_words
def generate_parameters(self) -> List:
r"""In the basic manual verbalizer, the parameters are generated from the label words directly.
Label words that the tokenizer splits into several tokens are padded to a common length and
later reduced by ``multi_token_handler``.
"""
all_ids = []
for words_per_label in self.label_words:
ids_per_label = []
for word in words_per_label:
ids = self.tokenizer.encode(word, add_special_tokens=False)
ids_per_label.append(ids)
all_ids.append(ids_per_label)
max_len = max([max([len(ids) for ids in ids_per_label]) for ids_per_label in all_ids])
max_num_label_words = max([len(ids_per_label) for ids_per_label in all_ids])
words_ids_mask = torch.zeros(max_num_label_words, max_len)
words_ids_mask = [[[1]*len(ids) + [0]*(max_len-len(ids)) for ids in ids_per_label]
+ [[0]*max_len]*(max_num_label_words-len(ids_per_label))
for ids_per_label in all_ids]
words_ids = [[ids + [0]*(max_len-len(ids)) for ids in ids_per_label]
+ [[0]*max_len]*(max_num_label_words-len(ids_per_label))
for ids_per_label in all_ids]
words_ids_tensor = torch.tensor(words_ids)
words_ids_mask = torch.tensor(words_ids_mask)
self.label_words_ids = nn.Parameter(words_ids_tensor, requires_grad=False)
self.words_ids_mask = nn.Parameter(words_ids_mask, requires_grad=False) # A 3-d mask
self.label_words_mask = nn.Parameter(torch.clamp(words_ids_mask.sum(dim=-1), max=1), requires_grad=False)
def project(self,
logits: torch.Tensor,
**kwargs,
) -> torch.Tensor:
r"""
Project the logits over the vocabulary onto the label words. The returned values are unnormalized; padded label words are masked with a large negative value.
Args:
logits (:obj:`torch.Tensor`): The original logits over the vocabulary.
Returns:
:obj:`torch.Tensor`: The logits of the label words.
"""
label_words_logits = logits[:, self.label_words_ids]
label_words_logits = self.handle_multi_token(label_words_logits, self.words_ids_mask)
label_words_logits -= 10000*(1-self.label_words_mask)
return label_words_logits
def process_logits(self, logits: torch.Tensor, **kwargs):
r"""A whole framework to process the original logits over the vocabulary, which contains four steps:
(1) Project the logits into logits of label words
if self.post_log_softmax is True:
(2) Normalize over all label words
(3) Calibrate (optional)
(4) Aggregate (for multiple label words)
Args:
logits (:obj:`torch.Tensor`): The original logits.
Returns:
(:obj:`torch.Tensor`): The final processed logits over the labels (classes).
"""
# project
label_words_logits = self.project(logits, **kwargs) #Output: (batch_size, num_classes) or (batch_size, num_classes, num_label_words_per_label)
if self.post_log_softmax:
# normalize
label_words_probs = self.normalize(label_words_logits)
# calibrate
if hasattr(self, "_calibrate_logits") and self._calibrate_logits is not None:
label_words_probs = self.calibrate(label_words_probs=label_words_probs)
# convert to logits
label_words_logits = torch.log(label_words_probs+1e-15)
# aggregate
label_logits = self.aggregate(label_words_logits)
return label_logits
def normalize(self, logits: torch.Tensor) -> torch.Tensor:
"""
Given the logits of the label words, return the probs over the label words set (softmax over all label words of each sample).
Args:
logits (:obj:`Tensor`): The logits of the label words.
Returns:
:obj:`Tensor`: The probability distribution over the label words set.
"""
batch_size = logits.shape[0]
return F.softmax(logits.reshape(batch_size, -1), dim=-1).reshape(*logits.shape)
def aggregate(self, label_words_logits: torch.Tensor) -> torch.Tensor:
r"""Use weight to aggregate the logits of label words.
Args:
label_words_logits(:obj:`torch.Tensor`): The logits of the label words.
Returns:
:obj:`torch.Tensor`: The aggregated logits from the label words.
"""
label_words_logits = (label_words_logits * self.label_words_mask).sum(-1)/self.label_words_mask.sum(-1)
return label_words_logits
def calibrate(self, label_words_probs: torch.Tensor, **kwargs) -> torch.Tensor:
r"""
Args:
label_words_probs (:obj:`torch.Tensor`): The probability distribution of the label words with the shape of [``batch_size``, ``num_classes``, ``num_label_words_per_class``]
Returns:
:obj:`torch.Tensor`: The calibrated probability of label words.
"""
shape = label_words_probs.shape
assert self._calibrate_logits.dim() == 1, "self._calibrate_logits is not a 1-d tensor"
calibrate_label_words_probs = self.normalize(self.project(self._calibrate_logits.unsqueeze(0), **kwargs))
assert calibrate_label_words_probs.shape[1:] == label_words_probs.shape[1:] \
and calibrate_label_words_probs.shape[0]==1, "shapes do not match"
label_words_probs /= (calibrate_label_words_probs+1e-15)
# normalize # TODO Test the performance
norm = label_words_probs.reshape(shape[0], -1).sum(dim=-1,keepdim=True) # TODO Test the performance of detaching()
label_words_probs = label_words_probs.reshape(shape[0], -1) / norm
label_words_probs = label_words_probs.reshape(*shape)
return label_words_probs
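# Illustrative sketch, not part of the original module: the arithmetic of ``process_logits``
# with ``post_log_softmax=True`` (project, normalize, convert back to log space, aggregate)
# written in plain Python for a single sample. It assumes every label word is a single token,
# uses ragged lists instead of the padded tensors and masks above, and skips the optional
# calibration step. The function name ``label_log_probs`` and the toy numbers are hypothetical.

```python
import math
from typing import List


def label_log_probs(vocab_logits: List[float],
                    label_word_ids: List[List[int]]) -> List[float]:
    """Score each label from the vocabulary logits of one sample."""
    # (1) project: keep only the logits of the label words
    picked = [[vocab_logits[i] for i in ids] for ids in label_word_ids]
    # (2) normalize: a single softmax over all label words of all labels
    flat = [x for row in picked for x in row]
    shift = max(flat)  # subtract the maximum for numerical stability
    total = sum(math.exp(x - shift) for x in flat)
    # (3) convert back to log space, with the same 1e-15 epsilon as the module
    log_probs = [[math.log(math.exp(x - shift) / total + 1e-15) for x in row]
                 for row in picked]
    # (4) aggregate: mean over the label words of each label
    return [sum(row) / len(row) for row in log_probs]


# Toy vocabulary of four tokens; label 0 uses token 1, label 1 uses tokens 2 and 3.
scores = label_log_probs([0.0, 2.0, 1.0, -1.0], [[1], [2, 3]])
best = max(range(len(scores)), key=scores.__getitem__)
print(scores, best)
```

# Ranking the labels by these scores corresponds to ranking by ``process_logits`` output
# for one row of the batch, under the single-token assumption stated above.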