Repository: tsinghua-fib-lab/ANeurIPS2024_SPV-MIA Branch: main Commit: ee9ebaad837c Files: 17 Total size: 67.8 KB Directory structure: gitextract_dd34ftc1/ ├── .github/ │ └── workflows/ │ └── main.yml ├── .gitignore ├── README.md ├── attack/ │ ├── __init__.py │ ├── attack_model.py │ └── utils.py ├── attack.py ├── configs/ │ └── config.yaml ├── data/ │ ├── __init__.py │ └── prepare.py ├── ft_llms/ │ ├── __init__.py │ ├── llama_patch.py │ ├── llms_finetune.py │ ├── refer_data_generate.py │ ├── run_llama.sh │ └── utils.py └── requirements.txt ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/main.yml ================================================ name: Clean Commit History on: workflow_dispatch jobs: clean_history: runs-on: ubuntu-latest steps: - name: Checkout the repository uses: actions/checkout@v2 with: fetch-depth: 0 - name: Remove commit history run: | # Set up Git user git config --global user.name "Wenjie Fu" git config --global user.email "wjfu99@outlook.com" git branch ls # Create a new orphan branch with the latest commit git checkout --orphan latest_commit # Add all files to the new branch git add -A # Commit the changes git commit -m "Final Commit" # Rename the current branch to main git branch -m latest_commit # Force push to update the repository git push -f origin latest_commit:main ================================================ FILE: .gitignore ================================================ /cache # /ft_llms/cache # /ft_llms/checkpoints /ft_llms/*/ /attack/*/ /wandb /abc /.vscode /data/*/ /test # /detect-gpt # /Finetune_LLMs # ft_llms/{cache, checkpoints} ================================================ FILE: README.md ================================================ # (NeurIPS'24) Practical Membership Inference Attacks against Fine-tuned Large Language Models via Self-prompt Calibration - [Requirements](#requirements) - [Target Model Fine-tuning](#target-model-fine-tuning) - [Self-prompt Reference Model Fine-tuning](#self-prompt-reference-model-fine-tuning) - [Run SPV-MIA](#run-spv-mia) This is the official implementation of the paper "Practical Membership Inference Attacks against Fine-tuned Large Language Models via Self-prompt Calibration". The proposed Membership Inference Attack based on Self-calibrated Probabilistic Variation (SPV-MIA) is implemented as follows. ![The overall architecture of _SPV-MIA_](./Framework.png) ## Requirements - torch>=1.11.0 - accelerate==0.20.3 - transformers==4.34.0.dev0 - trl==0.7.1 - datasets==2.13.1 - numpy>=1.23.4 - scikit-learn>=1.1.3 - pyyaml>=6.0 - tqdm>=4.64.1 Dependencies can be installed with the following command: ```bash pip install -r requirements.txt ``` ## Target Model Fine-tuning All large language models (LLMs) are built on top of [transformers](https://huggingface.co/docs/transformers/index), a go-to library for state-of-the-art transformer models, on which you can fine-tune arbitrary well-known LLMs, including LLaMA, the GPT series, Falcon, etc.
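For orientation, below is a minimal, illustrative sketch (not part of this repository) of how such a pretrained causal LM and its tokenizer are loaded with `transformers`; the model name is only an example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example model name; any causal LM from the Hugging Face Hub can be used here.
model_name = "gpt2-xl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
```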
We recommend training LLMs on multiple GPUs with [accelerate](https://huggingface.co/docs/accelerate/index), a library that enables the same PyTorch code to be run across any distributed configuration: ```bash accelerate launch ./ft_llms/llms_finetune.py \ --output_dir ./ft_llms/*pretrained_model_name*/*dataset_name*/target/ \ --block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \ -d *dataset_name* -m *pretrained_model_name* --packing --use_dataset_cache \ -e 10 -b 4 -lr 1e-4 --gradient_accumulation_steps 1 \ --train_sta_idx=0 --train_end_idx=10000 --eval_sta_idx=0 --eval_end_idx=1000 ``` Please replace \*pretrained_model_name\* and \*dataset_name\* with the names of the pretrained LLM and the training dataset, such as `decapoda-research/llama-7b-hf` and `ag_news`. ### Recommended pretrained models - GPT-2 (1.5B) (https://huggingface.co/gpt2-xl) - GPT-J (https://huggingface.co/EleutherAI/gpt-j-6b) - Falcon (https://huggingface.co/tiiuae/falcon-7b) - LLaMA (https://huggingface.co/decapoda-research/llama-7b-hf) [^1] [^1]: The third-party repo `decapoda-research/llama-7b-hf` appears to have been deleted for unknown reasons; use the forked repos [luodian/llama-7b-hf](https://huggingface.co/luodian/llama-7b-hf) or [baffo32/decapoda-research-llama-7B-hf](https://huggingface.co/baffo32/decapoda-research-llama-7B-hf) as alternatives. ### Recommended datasets - Ag News (https://huggingface.co/datasets/ag_news) - Wikitext-103 (https://huggingface.co/datasets/wikitext) [^2] - Xsum (https://huggingface.co/datasets/xsum) [^2]: Please add the additional argument `--dataset_config_name wikitext-2-raw-v1` to specify this dataset. ## Self-prompt Reference Model Fine-tuning Before fine-tuning the self-prompt reference model, sample the reference dataset from the fine-tuned target LLM via our proposed self-prompt approach: ```bash accelerate launch refer_data_generate.py \ -tm *fine_tuned_model* \ -m *pretrained_model_name* -d *dataset_name* ``` Replace \*fine_tuned_model\* with the directory of the fine-tuned target model (i.e., the output directory of the [Target Model Fine-tuning](#target-model-fine-tuning) phase). Then fine-tune the self-prompt reference model in the same manner as the target model, but for fewer epochs and with a smaller learning rate: ```bash accelerate launch ./ft_llms/llms_finetune.py --refer \ --output_dir ./ft_llms/*pretrained_model_name*/*dataset_name*/refer/ \ --block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \ -d *dataset_name* -m *pretrained_model_name* --packing --use_dataset_cache \ -e 2 -b 4 -lr 5e-5 --gradient_accumulation_steps 1 \ --train_sta_idx=0 --train_end_idx=10000 --eval_sta_idx=0 --eval_end_idx=1000 ``` ## Run SPV-MIA After completing the steps above, run the following command to deploy SPV-MIA against the target model.
```bash python attack.py ``` ================================================ FILE: attack/__init__.py ================================================ ================================================ FILE: attack/attack_model.py ================================================ import os import random import torch from accelerate import Accelerator from accelerate.logging import get_logger from attack import utils from attack.utils import Dict import numpy as np from copy import deepcopy from tqdm import tqdm from torch.utils.data import DataLoader import nlpaug.augmenter.word as naw import nlpaug.augmenter.sentence as nas from sklearn.metrics import roc_auc_score, roc_curve, auc, precision_recall_curve, f1_score from itertools import cycle import matplotlib.pyplot as plt import re import seaborn as sns from functools import partial logger = get_logger(__name__, "INFO") PATH = os.getcwd() accelerator = Accelerator() class AttackModel: def __init__(self, target_model, tokenizer, datasets, reference_model, shadow_model, cfg, mask_model=None, mask_tokenizer=None): self.target_model = target_model self.tokenizer = tokenizer self.datasets = datasets self.kind = cfg['attack_kind'] self.cfg = cfg if mask_model is not None: self.mask_model = mask_model self.mask_tokenizer = mask_tokenizer self.pattern = re.compile(r"<extra_id_\d+>") if shadow_model is not None and cfg['attack_kind'] == "nn": self.shadow_model = shadow_model self.is_model_training = False if reference_model is not None: self.reference_model = reference_model def llm_eval(self, model, data_loader, cfg, idx_rate, perturb_fn=None, refer_model=None): model.eval() losses = [] ref_losses = [] token_lens = [] for iteration, texts in enumerate(data_loader): texts = texts["text"] if cfg["maximum_samples"] is not None: if iteration * accelerator.num_processes >= cfg["maximum_samples"]: break if perturb_fn is not None: texts = perturb_fn(texts) token_ids = self.tokenizer(texts, return_tensors="pt", padding=True).to(accelerator.device) labels = token_ids.input_ids with torch.no_grad(): outputs = model(**token_ids, labels=labels) ref_outputs = refer_model(**token_ids, labels=labels) loss = outputs.loss ref_loss = ref_outputs.loss token_lens.append(accelerator.gather(torch.tensor(token_ids.input_ids.size()[-1]).reshape(-1, 1).to(accelerator.device)).detach().cpu().numpy()) # TODO: may cause bug when running attacks in parallel.
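# Gather the per-batch target-model and reference-model losses from every process, so each rank ends up with the losses of all evaluated samples.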
losses.append(accelerator.gather(loss.reshape(-1, 1)).detach().cpu().to(torch.float32).numpy()) ref_losses.append(accelerator.gather(ref_loss.reshape(-1, 1)).detach().cpu().to(torch.float32).numpy()) # print(f"{accelerator.device}@{texts}") # print(f"time duration: {time.time() - start_time}s") losses = np.concatenate(losses, axis=0) ref_losses = np.concatenate(ref_losses, axis=0) token_lens = np.concatenate(token_lens, axis=0) # token_lens = np.array(token_lens, dtype=np.int32) return losses, ref_losses, token_lens def eval_perturb(self, model, dataset, cfg): """ Evaluate the loss of the perturbed data :param dataset: N*channel*width*height :return: losses: N*1; var_losses: N*1; per_losses: N*Mask_Num; ori_losses: N*1 """ per_losses = [] ref_per_losses = [] ori_losses = [] ref_ori_losses = [] ori_dataset = deepcopy(dataset) for i in tqdm(range(0, cfg["perturbation_number"])): idx_rate = i / cfg["perturbation_number"] * 0.7 ori_loss, ref_ori_loss, ori_token_len = self.llm_eval(model, ori_dataset, cfg, idx_rate, refer_model=self.reference_model) ori_losses.append(ori_loss) ref_ori_losses.append(ref_ori_loss) perturb_fn = partial(self.sentence_perturbation, idx_rate=idx_rate) sampled_per_losses = [] sampled_ref_per_losses = [] for _ in range(cfg["sample_number"]): per_loss, ref_per_loss, per_token_len = self.llm_eval(model, ori_dataset, cfg, idx_rate, perturb_fn=perturb_fn, refer_model=self.reference_model) sampled_per_losses.append(per_loss) sampled_ref_per_losses.append(ref_per_loss) sampled_per_losses = np.concatenate(sampled_per_losses, axis=-1) sampled_ref_per_losses = np.concatenate(sampled_ref_per_losses, axis=-1) per_losses.append(np.expand_dims(sampled_per_losses, axis=-1)) ref_per_losses.append(np.expand_dims(sampled_ref_per_losses, axis=-1)) ori_losses = np.concatenate(ori_losses, axis=-1) ref_ori_losses = np.concatenate(ref_ori_losses, axis=-1) per_losses = np.concatenate(per_losses, axis=-1) var_losses = per_losses - np.expand_dims(ori_losses, axis=-2) ref_per_losses = np.concatenate(ref_per_losses, axis=-1) if cfg["calibration"] else None ref_var_losses = ref_per_losses - np.expand_dims(ref_ori_losses, axis=-2) if cfg["calibration"] else None output = (Dict( per_losses=per_losses, ori_losses=ori_losses, var_losses=var_losses, ), Dict( ref_per_losses=ref_per_losses, ref_ori_losses=ref_ori_losses, ref_var_losses=ref_var_losses, )) return output def data_prepare(self, kind, cfg): logger.info("Preparing data...") data_path = os.path.join(PATH, cfg["attack_data_path"], f"attack_data_{cfg['model_name']}@{cfg['dataset_name']}") target_model = getattr(self, kind + "_model") mem_data = self.datasets[kind]["train"] nonmem_data = self.datasets[kind]["valid"] mem_path = os.path.join(data_path, kind, "mem_feat.npz") nonmem_path = os.path.join(data_path, kind, "nonmen_feat.npz") ref_mem_path = os.path.join(data_path, kind, "ref_mem_feat.npz") ref_nonmem_path = os.path.join(data_path, kind, "ref_nonmen_feat.npz") pathlist = (mem_path, nonmem_path, ref_mem_path, ref_nonmem_path) if cfg["calibration"] else (mem_path, nonmem_path) if not utils.check_files_exist(*pathlist) or not cfg["load_attack_data"]: logger.info("Generating feature vectors for member data...") mem_feat, ref_mem_feat = self.eval_perturb(target_model, mem_data, cfg) if accelerator.is_main_process: utils.save_dict_to_npz(mem_feat, mem_path) if cfg["calibration"]: utils.save_dict_to_npz(ref_mem_feat, ref_mem_path) logger.info("Generating feature vectors for non-member data...") nonmem_feat, ref_nonmem_feat = 
self.eval_perturb(target_model, nonmem_data, cfg) if accelerator.is_main_process: utils.save_dict_to_npz(nonmem_feat, nonmem_path) if cfg["calibration"]: utils.save_dict_to_npz(ref_nonmem_feat, ref_nonmem_path) logger.info("Saving feature vectors...") else: logger.info("Loading feature vectors...") mem_feat = utils.load_dict_from_npz(mem_path) ref_mem_feat = utils.load_dict_from_npz(ref_mem_path) if cfg["calibration"] else None nonmem_feat = utils.load_dict_from_npz(nonmem_path) ref_nonmem_feat = utils.load_dict_from_npz(ref_nonmem_path) if cfg["calibration"] else None logger.info("Data preparation complete.") return Dict( mem_feat=mem_feat, nonmem_feat=nonmem_feat, ref_mem_feat=ref_mem_feat, ref_nonmem_feat=ref_nonmem_feat, ) def feat_prepare(self, info_dict, cfg): # mem_info = info_dict.mem_feat # ref_mem_info = info_dict.ref_mem_feat if cfg["calibration"]: get_prob = lambda logprob: np.power(np.e, -logprob) mem_feat = ((get_prob(info_dict.mem_feat.per_losses).mean((-1, -2)) - get_prob(info_dict.mem_feat.ori_losses).mean(-1)) - (get_prob(info_dict.ref_mem_feat.ref_per_losses).mean((-1, -2)) - get_prob(info_dict.ref_mem_feat.ref_ori_losses).mean(-1))) nonmem_feat = ((get_prob(info_dict.nonmem_feat.per_losses).mean((-1, -2)) - get_prob(info_dict.nonmem_feat.ori_losses).mean(-1)) - (get_prob(info_dict.ref_nonmem_feat.ref_per_losses).mean((-1, -2)) - get_prob(info_dict.ref_nonmem_feat.ref_ori_losses).mean(-1))) else: mem_feat = info_dict.mem_feat.var_losses / info_dict.mem_feat.ori_losses nonmem_feat = info_dict.nonmem_feat.var_losses / info_dict.nonmem_feat.ori_losses if cfg["attack_kind"] == "stat": # mem_feat = mem_feat[:, :, 0] # nonmem_feat = nonmem_feat[:, :, 0] # mem_feat[np.isnan(mem_feat)] = 0 # nonmem_feat[np.isnan(nonmem_feat)] = 0 feat = - np.concatenate([mem_feat, nonmem_feat]) ground_truth = np.concatenate([np.zeros(mem_feat.shape[0]), np.ones(nonmem_feat.shape[0])]).astype(int) return feat, ground_truth def conduct_attack(self, cfg): save_path = os.path.join(PATH, cfg["attack_data_path"], f"attack_data_{cfg['model_name']}@{cfg['dataset_name']}", f"roc_{cfg['attack_kind']}.npz") raw_info = self.data_prepare("target", cfg) feat, ground_truth = self.feat_prepare(raw_info, cfg) # self.distinguishability_plot(raw_info['mem_feat']['ori_losses'].mean(-1), # raw_info['nonmem_feat']['ori_losses'].mean(-1)) # self.distinguishability_plot(feat[:1000], feat[-1000:]) self.eval_attack(ground_truth, -feat, path=save_path) def tokenize_and_mask(self, text, span_length, pct, idx_rate, ceil_pct=False): cfg = self.cfg tokens = text.split(' ') mask_string = '<<<mask>>>' perturb_start_idx = int(len(tokens) * idx_rate) n_spans = pct * len(tokens) / (span_length + cfg.buffer_size * 2) if ceil_pct: n_spans = np.ceil(n_spans) n_spans = int(n_spans) n_masks = 0 while n_masks < n_spans: start = np.random.randint(0, len(tokens) - span_length) end = start + span_length search_start = max(0, start - cfg.buffer_size) search_end = min(len(tokens), end + cfg.buffer_size) if mask_string not in tokens[search_start:search_end]: tokens[start:end] = [mask_string] n_masks += 1 # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments num_filled = 0 for idx, token in enumerate(tokens): if token == mask_string: tokens[idx] = f'<extra_id_{num_filled}>' num_filled += 1 assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}" text = ' '.join(tokens) return text @staticmethod def count_masks(texts): return [len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts] def replace_masks(self, texts): cfg = self.cfg n_expected = self.count_masks(texts) stop_id = self.mask_tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0] tokens = self.mask_tokenizer(texts, return_tensors="pt",
padding=True).to(accelerator.device) outputs = self.mask_model.generate(**tokens, max_length=150, do_sample=True, top_p=cfg.mask_top_p, num_return_sequences=1, eos_token_id=stop_id) return self.mask_tokenizer.batch_decode(outputs, skip_special_tokens=False) def extract_fills(self, texts): # remove <pad> from beginning of each text texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in texts] # return the text in between each matched mask token extracted_fills = [self.pattern.split(x)[1:-1] for x in texts] # remove whitespace around each fill extracted_fills = [[y.strip() for y in x] for x in extracted_fills] return extracted_fills def apply_extracted_fills(self, masked_texts, extracted_fills): # split masked text into tokens, only splitting on spaces (not newlines) tokens = [x.split(' ') for x in masked_texts] n_expected = self.count_masks(masked_texts) # replace each mask token with the corresponding fill for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)): if len(fills) < n: tokens[idx] = [] else: for fill_idx in range(n): text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx] # join tokens back into text texts = [" ".join(x) for x in tokens] return texts def sentence_perturbation(self, texts, idx_rate): cfg = self.cfg masked_texts = [self.tokenize_and_mask(x, cfg.span_length, cfg.pct, idx_rate, cfg.ceil_pct) for x in texts] raw_fills = self.replace_masks(masked_texts) extracted_fills = self.extract_fills(raw_fills) perturbed_texts = self.apply_extracted_fills(masked_texts, extracted_fills) # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again attempts = 1 while '' in perturbed_texts: idxs = [idx for idx, x in enumerate(perturbed_texts) if x == ''] print(f'WARNING: {len(idxs)} texts have no fills.
Trying again [attempt {attempts}].') masked_texts = [self.tokenize_and_mask(x, cfg.span_length, cfg.pct, idx_rate, cfg.ceil_pct) for idx, x in enumerate(texts) if idx in idxs] raw_fills = self.replace_masks(masked_texts) extracted_fills = self.extract_fills(raw_fills) new_perturbed_texts = self.apply_extracted_fills(masked_texts, extracted_fills) for idx, x in zip(idxs, new_perturbed_texts): perturbed_texts[idx] = x attempts += 1 return perturbed_texts @staticmethod def eval_attack(y_true, y_scores, plot=True, path=None): if type(y_true) == torch.Tensor: y_true, y_scores = utils.tensor_to_ndarray(y_true, y_scores) fpr, tpr, thresholds = roc_curve(y_true, y_scores) if path is not None: np.savez(path, fpr=fpr, tpr=tpr) auc_score = roc_auc_score(y_true, y_scores) logger.info(f"AUC on the target model: {auc_score}") # Finding the threshold point where FPR + TPR equals 1 threshold_point = tpr[np.argmin(np.abs(tpr - (1 - fpr)))] logger.info(f"ASR on the target model: {threshold_point}") # Finding the threshold point where FPR + TPR equals 1 tpr_1fpr = tpr[np.argmin(np.abs(fpr - 0.01))] logger.info(f"TPR@1%FPR on the target model: {tpr_1fpr}") if plot: # plot the ROC curve plt.plot(fpr, tpr, label=f'ROC curve (AUC = {auc_score}; ASR = {threshold_point})') plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.legend() # plot the no-skill line for reference plt.plot([0, 1], [0, 1], linestyle='--') # show the plot plt.show() ================================================ FILE: attack/utils.py ================================================ import logging from typing_extensions import Literal from rich.logging import RichHandler import os import torch import numpy as np def get_logger(name: str, level: Literal["info", "warning", "debug"]) -> logging.Logger: rich_handler = RichHandler(level=logging.INFO, rich_tracebacks=True, markup=True) logger = logging.getLogger(name) logger.setLevel(logging._nameToLevel[level.upper()]) if not logger.handlers: logger.addHandler(rich_handler) logger.propagate = False return logger class Dict(dict): def __getattr__(self, name): if name in self: return self[name] raise AttributeError(f"'Dict' object has no attribute '{name}'") def __setattr__(self, name, value): super().__setitem__(name, value) super().__setattr__(name, value) def __setitem__(self, key, value): super().__setitem__(key, value) super().__setattr__(key, value) def check_files_exist(*file_paths): """ Check if the input file(s) exist at the given file path(s). Parameters: *file_paths (str): One or more strings representing the file path(s) to check. Returns: bool: True if all the files exist, False otherwise. """ for file_path in file_paths: if not os.path.isfile(file_path): return False return True def create_folder(folder_path): if not os.path.exists(folder_path): os.makedirs(folder_path) print(f"Folder '{folder_path}' created.") else: print(f"Folder '{folder_path}' already exists.") def save_dict_to_npz(my_dict, file_path): """ Saves a dictionary with ndarray values to an npz file. Parameters: my_dict (dict): A dictionary with ndarray values to be saved. file_path (str): The file path to save the dictionary values to. Returns: None """ folder = os.path.dirname(file_path) if not os.path.exists(folder): os.makedirs(folder) with open(file_path, 'wb') as f: np.savez(f, **my_dict) def load_dict_from_npz(file_path): """ Loads a dictionary with ndarray values from an npz file. Parameters: file_path (str): The file path of the npz file to load. 
Returns: dict: A dictionary containing the values stored in the npz file. """ with np.load(file_path) as data: my_dict = Dict({key: value for key, value in data.items() if isinstance(value, np.ndarray)}) return my_dict def ndarray_to_tensor(*ndarrays): """ Converts multiple numpy ndarrays to PyTorch tensors. Parameters: *ndarrays (numpy.ndarray): Multiple numpy ndarrays to convert. Returns: tuple of torch.Tensor: A tuple of PyTorch tensors with the same data as the input ndarrays. """ tensors = tuple(torch.from_numpy(ndarray).cuda().float() for ndarray in ndarrays) return tensors def tensor_to_ndarray(*tensors): """ Converts multiple PyTorch tensors to numpy ndarrays. Parameters: *tensors (torch.Tensor): Multiple PyTorch tensors to convert. Returns: tuple of numpy.ndarray: A tuple of numpy ndarrays with the same data as the input tensors. """ ndarrays = tuple(tensor.detach().cpu().numpy() for tensor in tensors) return ndarrays def convert_labels_to_one_hot(labels, num_classes): ''' Converts labels of samples from format (N,) to (N, C), where C is the number of classes Args: labels : numpy array of shape (N,) containing the labels of each sample num_classes : integer indicating the total number of classes in the dataset Returns: numpy array of shape (N, C), where C is the number of classes, containing the one-hot encoded labels ''' one_hot_labels = np.zeros((labels.shape[0], num_classes)) one_hot_labels[np.arange(labels.shape[0]), labels] = 1 return one_hot_labels def get_file_names(folder_path): # List to store the file names file_names = [] # Loop through each file in the folder for file_name in sorted(os.listdir(folder_path)): # Check if the current item is a file if os.path.isfile(os.path.join(folder_path, file_name)): file_names.append(os.path.join(folder_path, file_name)) return file_names def extract(v, t, x_shape): """ Extract some coefficients at specified timesteps, then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. """ out = torch.gather(v, index=t, dim=0).float() return out.view([t.shape[0]] + [1] * (len(x_shape) - 1)) ================================================ FILE: attack.py ================================================ import os import numpy as np import torch from tqdm import tqdm from torch.utils.data import DataLoader import logging import random from attack.attack_model import AttackModel from data.prepare import dataset_prepare from attack.utils import Dict import yaml import datasets from datasets import Image, Dataset from accelerate import Accelerator from accelerate.logging import get_logger import trl from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig, TrainingArguments, AutoConfig, LlamaTokenizer from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training # Load config file with open("configs/config.yaml", 'r') as f: cfg = yaml.safe_load(f) cfg = Dict(cfg) # Add Logger accelerator = Accelerator() logger = get_logger(__name__, "INFO") logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Load abs path PATH = os.path.dirname(os.path.abspath(__file__)) # Fix the random seed seed = 0 torch.manual_seed(seed) np.random.seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True ## Load generation models. 
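# target_model is the fine-tuned LLM under attack; reference_model is the self-prompt reference model used for probabilistic-variation calibration.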
if not cfg["load_attack_data"]: # config = AutoConfig.from_pretrained(cfg["model_name"]) # config.use_cache = False torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 target_model = AutoModelForCausalLM.from_pretrained(cfg["target_model"], quantization_config=BitsAndBytesConfig(load_in_8bit=True), torch_dtype=torch_dtype, local_files_only=True, config=AutoConfig.from_pretrained(cfg["model_name"]), cache_dir=cfg["cache_path"]) reference_model = AutoModelForCausalLM.from_pretrained(cfg["reference_model"], quantization_config=BitsAndBytesConfig(load_in_8bit=True), torch_dtype=torch_dtype, local_files_only=True, config=AutoConfig.from_pretrained(cfg["model_name"]), cache_dir=cfg["cache_path"]) logger.info("Successfully load models") config = AutoConfig.from_pretrained(cfg.model_name) # Load tokenizer. model_type = config.to_dict()["model_type"] if model_type == "llama": tokenizer = LlamaTokenizer.from_pretrained(cfg["model_name"], add_eos_token=cfg["add_eos_token"], add_bos_token=cfg["add_bos_token"], use_fast=True) else: tokenizer = AutoTokenizer.from_pretrained(cfg["model_name"], add_eos_token=cfg["add_eos_token"], add_bos_token=cfg["add_bos_token"], use_fast=True) if cfg["model_name"] == "/mnt/data0/fuwenjie/MIA-LLMs/cache/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348": cfg["model_name"] = "decapoda-research/llama-7b-hf" if cfg["pad_token_id"] is not None: logger.info("Using pad token id %d", cfg["pad_token_id"]) tokenizer.pad_token_id = cfg["pad_token_id"] if tokenizer.pad_token_id is None: logger.info("Pad token id is None, setting to eos token id...") tokenizer.pad_token_id = tokenizer.eos_token_id # Load datasets train_dataset, valid_dataset = dataset_prepare(cfg, tokenizer=tokenizer) train_dataset = Dataset.from_dict(train_dataset[cfg.train_sta_idx:cfg.train_end_idx]) valid_dataset = Dataset.from_dict(valid_dataset[cfg.eval_sta_idx:cfg.eval_end_idx]) train_dataset = Dataset.from_dict(train_dataset[random.sample(range(len(train_dataset["text"])), cfg["maximum_samples"])]) valid_dataset = Dataset.from_dict(valid_dataset[random.sample(range(len(valid_dataset["text"])), cfg["maximum_samples"])]) logger.info("Successfully load datasets!") # Prepare dataloade train_dataloader = DataLoader(train_dataset, batch_size=cfg["eval_batch_size"]) eval_dataloader = DataLoader(valid_dataset, batch_size=cfg["eval_batch_size"]) # Load Mask-f shadow_model = None int8_kwargs = {} half_kwargs = {} if cfg["int8"]: int8_kwargs = dict(load_in_8bit=True, device_map='auto', torch_dtype=torch.bfloat16) elif cfg["half"]: half_kwargs = dict(torch_dtype=torch.bfloat16) mask_model = AutoModelForSeq2SeqLM.from_pretrained(cfg["mask_filling_model_name"], **int8_kwargs, **half_kwargs).to(accelerator.device) try: n_positions = mask_model.config.n_positions except AttributeError: n_positions = 512 mask_tokenizer = AutoTokenizer.from_pretrained(cfg["mask_filling_model_name"], model_max_length=n_positions) # Prepare everything with accelerator train_dataloader, eval_dataloader = ( accelerator.prepare( train_dataloader, eval_dataloader, )) else: target_model = None reference_model = None shadow_model = None mask_model = None train_dataloader = None eval_dataloader = None tokenizer = None mask_tokenizer = None datasets = { "target": { "train": train_dataloader, "valid": eval_dataloader } } attack_model = AttackModel(target_model, tokenizer, datasets, reference_model, shadow_model, cfg, mask_model=mask_model, mask_tokenizer=mask_tokenizer) 
attack_model.conduct_attack(cfg=cfg) ================================================ FILE: configs/config.yaml ================================================ random_seed: 48 model_name: /mnt/data0/fuwenjie/MIA-LLMs/cache/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 # EleutherAI/gpt-j-6B gpt2 target_model: /mnt/data0/fuwenjie/MIA-LLMs/ft_llms/llama/ag_news/target/checkpoint-3000 # valid model: reference_model: /mnt/data0/fuwenjie/MIA-LLMs/ft_llms/llama/ag_news/refer/checkpoint-400 # dataset_name: ag_news # xsum, ag_news, wikitext dataset_config_name: null # wikitext-2-raw-v1 null cache_path: ./cache use_dataset_cache: true packing: true calibration: true # whether to enable calibration add_eos_token: false add_bos_token: false pad_token_id: null attack_kind: stat # valid attacks: nn, stat eval_batch_size: 1 # batch size of the evaluation phase maximum_samples: 200 # the maximum samples number for member and non-member records. block_size: 128 validation_split_percentage: 0.1 preprocessing_num_workers: 1 mask_filling_model_name: t5-base buffer_size: 1 mask_top_p: 1.0 span_length: 2 pct: 0.3 # pct_words_masked ceil_pct: false int8: false half: false perturbation_number: 1 # the number of different perturbation strength / position; debugging parameter, should be set to 1 in the regular running. sample_number: 10 # the number of sampling train_sta_idx: 0 train_end_idx: 10000 eval_sta_idx: 0 eval_end_idx: 1000 attack_data_path: attack load_attack_data: false # whether to load prepared attack data if existing. ================================================ FILE: data/__init__.py ================================================ ================================================ FILE: data/prepare.py ================================================ import os import random import datasets import trl from attack.utils import create_folder block_size = None tokenizer_ = None max_buff_size = None text_column = None def packing_texts(examples): more_examples = True packed_texts = [] packed_ids = [] # for key in examples.keys(): assert list(examples.keys()) == ["text"] iterator = iter(examples["text"]) # for sentence in examples["text"]: total_num = 0 drop_num = 0 while more_examples: buffer, buffer_len = [], 0 while True: if buffer_len >= max_buff_size: break try: buffer.append(next(iterator)) buffer_len += len(buffer[-1]) except StopIteration: more_examples = False break tokenized_inputs = tokenizer_(buffer, truncation=False)["input_ids"] inputs = tokenizer_.batch_decode(tokenized_inputs) tokenized_inputs = tokenizer_(inputs, truncation=False)["input_ids"] all_token_ids = [] for tokenized_input in tokenized_inputs: all_token_ids.extend(tokenized_input) for i in range(0, len(all_token_ids), block_size): input_ids = all_token_ids[i: i + block_size] if len(input_ids) == block_size: packed_ids.append(input_ids) input_text = tokenizer_.decode(input_ids) total_num += 1 if len(tokenizer_.encode(input_text)) == block_size: packed_texts.append(input_text) drop_num += 1 # print(f"Total examples: {total_num}, dropped num: {drop_num}, dropped rate: {1 - drop_num/total_num}") return { "text": packed_texts } def dataset_prepare(args, tokenizer=None, num_of_sequences=1024, chars_per_token=3.6): # raw_datasets = datasets.load_dataset(args.dataset_name, args.dataset_config_name)['train'] # if "validation" in raw_datasets.keys(): # train_dataset = raw_datasets["train"] # valid_dataset = raw_datasets["validation"] # else: train_dataset = datasets.load_dataset( 
args.dataset_name, args.dataset_config_name, split=f"train[:{int((1-args.validation_split_percentage)*100)}%]" ) valid_dataset = datasets.load_dataset( args.dataset_name, args.dataset_config_name, split=f"train[{int((1-args.validation_split_percentage)*100)}%:]", ) # train_idxs = set(random.sample(range(len(raw_datasets)), int(len(raw_datasets) * (1 - args.validation_split_percentage)))) # valid_idxs = set(range(len(raw_datasets))) - train_idxs # train_dataset = datasets.Dataset.from_dict(raw_datasets[train_idxs]) # valid_dataset = datasets.Dataset.from_dict(raw_datasets[valid_idxs]) global text_column column = train_dataset.column_names if "text" in column: text_column = "text" elif "document" in column: text_column = "document" elif "content" in column: text_column = "content" train_dataset = train_dataset.select_columns(text_column) valid_dataset = valid_dataset.select_columns(text_column) if text_column != "text": train_dataset = train_dataset.rename_column(text_column, "text") valid_dataset = valid_dataset.rename_column(text_column, "text") if args.packing: global block_size, tokenizer_, max_buff_size block_size = args.block_size max_buff_size = block_size * chars_per_token * num_of_sequences tokenizer_ = tokenizer create_folder(f"{args.cache_path}/{args.dataset_name}/{args.dataset_config_name}") train_dataset = train_dataset.map( packing_texts, batched=True, # batch_size=None, num_proc=args.preprocessing_num_workers, cache_file_name=f"{args.cache_path}/{args.dataset_name}/{args.dataset_config_name}/train_dataset", load_from_cache_file=args.use_dataset_cache, desc=f"Packing texts in chunks of {block_size} tokens" ) valid_dataset = valid_dataset.map( packing_texts, batched=True, # batch_size=None, num_proc=args.preprocessing_num_workers, cache_file_name=f"{args.cache_path}/{args.dataset_name}/{args.dataset_config_name}/valid_dataset", load_from_cache_file=args.use_dataset_cache, desc=f"Packing texts in chunks of {block_size} tokens" ) return train_dataset, valid_dataset ================================================ FILE: ft_llms/__init__.py ================================================ ================================================ FILE: ft_llms/llama_patch.py ================================================ from typing import List, Optional, Tuple import torch from torch import nn import torch.nn.functional as F import math import warnings import transformers from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv from peft.tuners.lora import LoraLayer try: from flash_attn.modules.mha import FlashSelfAttention except Exception: raise ModuleNotFoundError( "Please install FlashAttention first, e.g., with pip install flash-attn --no-build-isolation, Learn more at https://github.com/Dao-AILab/flash-attention#installation-and-features" ) def compute_flash_attention(flash_attn, q, k, v, attention_mask=None, head_mask=None): # q, k, v: [bs, seq_len, num_attention_heads, attn_head_size] # attention_mask (float): [bs, seq_len] batch_size, max_len = q.size(0), q.size(1) qkv = torch.stack([q, k, v], dim=2) dtype_in = qkv.dtype if dtype_in == torch.float32: qkv = qkv.to(torch.float16) # need to truncate in case input is fp32 cu_seqlens, max_seqlen = None, None if attention_mask is None: out = flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) else: # Limitation: non-contiguous attention mask will not be handled correctly # model will be able to pay attention between the first and last non-masked token, i.e. 
left- and right-side padding is supported. csums = (attention_mask >= 0).cumsum(dim=1) ends = csums.argmax(dim=1) + 1 starts = ends - csums.max(dim=1).values seqlens = ends - starts qkv = torch.cat([qkv[i, starts[i] : ends[i]] for i in range(batch_size)], dim=0) zero = torch.zeros_like(seqlens[:1]) # torch.tensor([0]) with correct dtype and device cu_seqlens = torch.cat([zero, seqlens.cumsum(dim=0)], dim=0).to(torch.int32) max_seqlen = seqlens.max().item() out = flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) # out: [num_unmasked_tokens, num_attention_heads, attn_head_size] seqs = [out[start:end] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])] # stack and pad sequences together padded_seqs = [ F.pad(seqs[i], (0, 0) * (seqs[i].dim() - 1) + (starts[i], max_len - ends[i]), value=0.0) for i in range(batch_size) ] out = torch.stack(padded_seqs) if out.dtype != dtype_in: out = out.to(dtype_in) return out # AND https://github.com/LAION-AI/Open-Assistant/blob/04fa9a24b2a58c8885b8aa6a2eb02b18de6b4961/model/model_training/models/patching_llama.py def llama_forward_with_flash_attn( self: LlamaAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if not hasattr(self, 'att_fn'): self.att_fn = FlashSelfAttention(causal=True) flash_attn = self.att_fn if output_attentions: warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.") if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] key_states = torch.cat(key_states, dim=-1) value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, 
self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) if ( query_states.shape == key_states.shape ): # and (attention_mask is None or attention_mask[:, 0, -1, 0].min() >= 0): if attention_mask is not None: attention_mask = attention_mask[:, 0, -1] flash_attn.train(self.training) out_dtype = value_states.dtype q, k, v = ( query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), ) attn_output = compute_flash_attention(flash_attn, q, k, v, attention_mask) attn_output = attn_output.transpose(1, 2).to(out_dtype) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) if self.config.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # [bsz, seq_len] return attention_mask def replace_attn_with_flash_attn(): cuda_major, cuda_minor = torch.cuda.get_device_capability() if cuda_major < 8: print( "Flash attention is only supported on Ampere or Hopper GPU during training due to head dim > 64 backward." 
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" ) # transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # _prepare_decoder_attention_mask # ) transformers.models.llama.modeling_llama.LlamaAttention.old_forward = transformers.models.llama.modeling_llama.LlamaAttention.forward transformers.models.llama.modeling_llama.LlamaAttention.forward = llama_forward_with_flash_attn def unplace_flash_attn_with_attn(): import importlib import transformers print("Reloading llama model, unpatching flash attention") importlib.reload(transformers.models.llama.modeling_llama) # Adapted from https://github.com/tmm1/axolotl/blob/2eda9e02a9d15a7a3f92b41f257d9844d72fc220/src/axolotl/utils/models.py#L338 def upcast_layer_for_flash_attention(model, torch_dtype): # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to # convert them back to fp16/bf16 for flash-attn compatibility. for name, module in model.named_modules(): if isinstance(module, LoraLayer): module.to(torch_dtype) if "norm" in name: module.to(torch_dtype) if "lm_head" in name or "embed_tokens" in name: if hasattr(module, "weight"): module.to(torch_dtype) return model ================================================ FILE: ft_llms/llms_finetune.py ================================================ import argparse import datasets import trl from trl import SFTTrainer from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, AutoConfig from accelerate import Accelerator from datasets import Dataset, load_from_disk import torch import logging import os from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PrefixTuningConfig, PromptEncoderConfig, IA3Config import pandas as pd import sys here = os.path.dirname(__file__) sys.path.append(os.path.join(here, '..')) from data.prepare import dataset_prepare from attack.utils import create_folder from transformers import LlamaTokenizer, get_scheduler import os os.environ['HTTP_PROXY'] = 'http://fuwenjie:19990621f@192.168.75.13:7890' os.environ['HTTPS_PROXY'] = 'http://fuwenjie:19990621f@192.168.75.13:7890' from utils import get_logger, constantlengthdatasetiter, print_trainable_parameters # trl.trainer.ConstantLengthDataset.__dict__["__iter__"] = constantlengthdatasetiter # setattr(trl.trainer.ConstantLengthDataset, "__iter__", constantlengthdatasetiter) logger = get_logger("finetune", "info") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model_name", type=str, default="meta-llama/Llama-2-7b-hf") parser.add_argument("-d", "--dataset_name", type=str, default="wikitext-2-raw-v1") parser.add_argument("-dc", "--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") parser.add_argument("--cache_path", type=str, default="./cache") parser.add_argument("--use_dataset_cache", action="store_true", default=False) parser.add_argument("--refer", action="store_true", default=False) parser.add_argument("--refer_data_source", type=str, default=None) parser.add_argument("--packing", action="store_true", default=False) parser.add_argument("-t", "--token", type=str, default=None) parser.add_argument("--split_model", action="store_true", default=False) parser.add_argument("--block_size", type=int, default=1024) parser.add_argument("--preprocessing_num_workers", type=int, default=1) parser.add_argument("--peft", type=str, default="lora") 
parser.add_argument("--lora_rank", type=int, default=64) parser.add_argument("--lora_alpha", type=int, default=16) parser.add_argument("--lora_dropout", type=float, default=0.1) parser.add_argument("--p_tokens", type=int, help="The number of virtual tokens for prefix-tuning or p-tuning", default=20) parser.add_argument("--p_hidden", type=int, help="The hidden size of the prompt encoder", default=128) parser.add_argument("-lr", "--learning_rate", type=float, default=1e-4) parser.add_argument("--lr_scheduler_type", type=str, default="linear") parser.add_argument("--warmup_steps", type=int, default=0) parser.add_argument("--weight_decay", type=float, default=0) parser.add_argument("--output_dir", type=str, default="./ft_llms/checkpoints") parser.add_argument("--log_steps", type=int, default=10) parser.add_argument("--eval_steps", type=int, default=10) parser.add_argument("--save_epochs", type=int, default=10) parser.add_argument("-e", "--epochs", type=int, default=1) parser.add_argument("-b", "--batch_size", type=int, default=1) parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--gradient_checkpointing", action="store_true", default=False) parser.add_argument("--trust_remote_code", action="store_true", default=False) parser.add_argument("--train_sta_idx", type=int, default=0) parser.add_argument("--train_end_idx", type=int, default=6000) parser.add_argument("--eval_sta_idx", type=int, default=0) parser.add_argument("--eval_end_idx", type=int, default=600) parser.add_argument("-s", "--save_limit", type=int, default=None) parser.add_argument("--use_int4", action="store_true", default=False) parser.add_argument("--use_int8", action="store_true", default=False) parser.add_argument("--disable_peft", action="store_true", default=False) parser.add_argument("--disable_flash_attention", action="store_true", help="Disable flash attention", default=False) parser.add_argument("--pad_token_id", default=None, type=int, help="The end of sequence token.") parser.add_argument("--add_eos_token", action="store_true", help="Add EOS token to tokenizer", default=False) parser.add_argument("--add_bos_token", action="store_true", help="Add BOS token to tokenizer", default=False) parser.add_argument("--validation_split_percentage", default=0.1, help="The percentage of the train set used as validation set in case there's no validation split") args = parser.parse_args() accelerator = Accelerator() if args.token is None: access_token = os.getenv("HF_TOKEN", "") else: access_token = args.token config = AutoConfig.from_pretrained(args.model_name, cache_dir=args.cache_path) config.use_cache = False config_dict = config.to_dict() model_type = config_dict["model_type"] use_flash_attention = False if not args.disable_flash_attention and model_type != "llama": logger.info("Model is not llama, disabling flash attention...") elif args.disable_flash_attention and model_type == "llama": logger.info("Model is llama, could be using flash attention...") elif not args.disable_flash_attention and torch.cuda.get_device_capability()[0] >= 8: from ft_llms.llama_patch import replace_attn_with_flash_attn logger.info("Using flash attention for llama...") replace_attn_with_flash_attn() use_flash_attention = True if "WANDB_PROJECT" not in os.environ: os.environ["WANDB_PROJECT"] = "GPT_finetuning" if args.split_model: logger.info("Splitting the model across all available devices...") kwargs = {"device_map": "auto"} else: kwargs = {"device_map": None} if model_type == "llama": tokenizer = 
LlamaTokenizer.from_pretrained(args.model_name, token=access_token, trust_remote_code=args.trust_remote_code, cache_dir=args.cache_path, add_eos_token=args.add_eos_token, add_bos_token=args.add_bos_token, use_fast=True) else: tokenizer = AutoTokenizer.from_pretrained(args.model_name, token=access_token, trust_remote_code=args.trust_remote_code, cache_dir=args.cache_path, add_eos_token=args.add_eos_token, add_bos_token=args.add_bos_token, use_fast=True) # THIS IS A HACK TO GET THE PAD TOKEN ID NOT TO BE EOS # good one for LLama is 18610 if args.pad_token_id is not None: logger.info("Using pad token id %d", args.pad_token_id) tokenizer.pad_token_id = args.pad_token_id if tokenizer.pad_token_id is None: logger.info("Pad token id is None, setting to eos token id...") tokenizer.pad_token_id = tokenizer.eos_token_id block_size = args.block_size logger.info("Using a block size of %d", block_size) if args.use_int4: logger.info("Using int4 quantization") bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, bnb_4bit_use_double_quant=True, ) optimizer = "adamw_bnb_8bit" elif args.use_int8: logger.info("Using int8 quantization") bnb_config = BitsAndBytesConfig( load_in_8bit=True, ) optimizer = "adamw_bnb_8bit" else: logger.info("Using no quantization") bnb_config = None optimizer = "adamw_torch" if args.peft == "lora": peft_config = LoraConfig( task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.lora_rank, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout ) elif args.peft == "prefix-tuing": peft_config = PrefixTuningConfig( task_type=TaskType.CAUSAL_LM, inference_mode=False, num_virtual_tokens=args.p_tokens, encoder_hidden_size=args.p_hidden) elif args.peft == "p-tuing": peft_config = PromptEncoderConfig( task_type=TaskType.CAUSAL_LM, num_virtual_tokens=args.p_tokens, encoder_hidden_size=args.p_hidden) elif args.peft == "ia3": peft_config = IA3Config( peft_type="IA3", task_type=TaskType.CAUSAL_LM, target_modules=["k_proj", "v_proj", "down_proj"], feedforward_modules=["down_proj"], ) torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 model = AutoModelForCausalLM.from_pretrained(args.model_name, token=access_token, quantization_config=bnb_config, trust_remote_code=args.trust_remote_code, cache_dir=args.cache_path, torch_dtype=torch_dtype, config=config, **kwargs) if use_flash_attention: from ft_llms.llama_patch import llama_forward_with_flash_attn assert model.model.layers[ 0].self_attn.forward.__doc__ == llama_forward_with_flash_attn.__doc__, "Model is not using flash attention" if not args.disable_peft: logger.info("Using PEFT...") if args.use_int4 or args.use_int8: logger.info("Preparing model for kbit training...") model = prepare_model_for_kbit_training(model) if use_flash_attention: from ft_llms.llama_patch import upcast_layer_for_flash_attention logger.info("Upcasting flash attention layers...") model = upcast_layer_for_flash_attention(model, torch_dtype) logger.info("Getting PEFT model...") model = get_peft_model(model, peft_config) else: logger.info("Using Full Finetuning") print_trainable_parameters(model) if args.refer_data_source is not None: args.model_name = args.refer_data_source if args.model_name == "/mnt/data0/fuwenjie/MIA-LLMs/cache/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348": args.model_name = "decapoda-research/llama-7b-hf" with accelerator.main_process_first(): 
train_dataset, valid_dataset = dataset_prepare(args, tokenizer=tokenizer) if args.refer: train_dataset = None refer_data_path = f"{args.cache_path}/{args.dataset_name}/{args.dataset_config_name}/refer@{args.model_name}/" train_dataset = load_from_disk(refer_data_path) train_dataset = Dataset.from_dict(train_dataset[args.train_sta_idx:args.train_end_idx]) valid_dataset = Dataset.from_dict(valid_dataset[args.eval_sta_idx:args.eval_end_idx]) # train_dataset = load_from_disk("/mnt/data0/fuwenjie/MIA-LLMs/cache/ag_news/None/refer@gpt2") logger.info(f"Training with {Accelerator().num_processes} GPUs") training_args = TrainingArguments( do_train=True, do_eval=True, output_dir=args.output_dir, dataloader_drop_last=True, evaluation_strategy="steps", save_strategy="steps", logging_strategy="steps", num_train_epochs=args.epochs, eval_steps=args.eval_steps, save_steps=args.save_epochs, logging_steps=args.log_steps, per_device_train_batch_size=args.batch_size, per_device_eval_batch_size=args.batch_size * 2, optim=optimizer, learning_rate=args.learning_rate, lr_scheduler_type=args.lr_scheduler_type, warmup_steps=args.warmup_steps, gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_checkpointing=args.gradient_checkpointing, weight_decay=args.weight_decay, adam_epsilon=1e-6, report_to="wandb", load_best_model_at_end=False, save_total_limit=args.save_limit, bf16=True if torch.cuda.is_bf16_supported() else False, fp16=False if torch.cuda.is_bf16_supported() else True, ) # get trainer trainer = SFTTrainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=valid_dataset, dataset_text_field="text", tokenizer=tokenizer, ) # train trainer.train() ================================================ FILE: ft_llms/refer_data_generate.py ================================================ import os import numpy as np import torch from tqdm import tqdm from torch.utils.data import DataLoader import logging import random import sys here = os.path.dirname(__file__) sys.path.append(os.path.join(here, '..')) from data.prepare import dataset_prepare from attack.utils import Dict import argparse import yaml import datasets from datasets import Image, Dataset, load_from_disk, concatenate_datasets from accelerate import Accelerator from accelerate.logging import get_logger import trl from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig, TrainingArguments, AutoConfig, LlamaTokenizer import os os.environ['HTTP_PROXY'] = 'http://fuwenjie:19990621f@localhost:7890' os.environ['HTTPS_PROXY'] = 'http://fuwenjie:19990621f@localhost:7890' # Load config file accelerator = Accelerator() parser = argparse.ArgumentParser() parser.add_argument("-m", "--model_name", type=str, default="meta-llama/Llama-2-7b-hf") parser.add_argument("-tm", "--target_model", type=str, default="meta-llama/Llama-2-7b-hf") parser.add_argument("-d", "--dataset_name", type=str, default="wikitext-2-raw-v1") parser.add_argument("-dc", "--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") parser.add_argument("--cache_path", type=str, default="./cache") parser.add_argument("--use_dataset_cache", action="store_true", default=True) parser.add_argument("--packing", action="store_true", default=True) parser.add_argument("--block_size", type=int, default=128) parser.add_argument("--preprocessing_num_workers", type=int, default=1) parser.add_argument("--validation_split_percentage", default=0.1, help="The 
percentage of the train set used as validation set in case there's no validation split") cfg = parser.parse_args() print(accelerator.device) config = AutoConfig.from_pretrained(cfg.model_name) config.use_cache = False bnb_config = None torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 model = AutoModelForCausalLM.from_pretrained(cfg.target_model, quantization_config=bnb_config, torch_dtype=torch_dtype, local_files_only=True, config=config, cache_dir=cfg.cache_path) model_type = config.to_dict()["model_type"] if model_type == "llama": tokenizer = LlamaTokenizer.from_pretrained(cfg.model_name) else: tokenizer = AutoTokenizer.from_pretrained(cfg.model_name) if tokenizer.pad_token_id is None: print("Pad token id is None, setting to eos token id...") tokenizer.pad_token_id = tokenizer.eos_token_id # Load datasets train_dataset, valid_dataset = dataset_prepare(cfg, tokenizer=tokenizer) prompt_dataset = Dataset.from_dict(train_dataset[10000:20000]) prompt_dataloader = DataLoader(prompt_dataset, batch_size=1) model, prompt_dataloader = accelerator.prepare(model, prompt_dataloader) generated_dataset = {"text": []} for text in tqdm(prompt_dataloader): prompt = (text["text"]) input_ids = tokenizer(prompt, return_tensors="pt", padding=True).input_ids.to(accelerator.device) clipped_ids = input_ids[:, :16] if hasattr(model, "module"): gen_tokens = model.module.generate( clipped_ids, num_beams=1, do_sample=True, max_length=input_ids.size(-1), ) else: gen_tokens = model.generate( clipped_ids, num_beams=1, do_sample=True, max_length=input_ids.size(-1), ) if model_type == "llama": gen_tokens = gen_tokens[:, 1:] print(model(gen_tokens, labels=gen_tokens).loss) gen_text = tokenizer.batch_decode(gen_tokens) generated_dataset["text"].extend(gen_text) generated_dataset = Dataset.from_dict(generated_dataset) if cfg.model_name == "/mnt/data0/fuwenjie/MIA-LLMs/cache/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348": cfg.model_name = "decapoda-research/llama-7b-hf" save_dir = f"{cfg.cache_path}/{cfg.dataset_name}/{cfg.dataset_config_name}/refer@{cfg.model_name}/" generated_dataset.save_to_disk(save_dir + f"{accelerator.device}") accelerator.wait_for_everyone() if accelerator.is_main_process: concatenated_dataset = None for sub_dir in os.listdir(save_dir): data_path = os.path.join(save_dir, sub_dir) if os.path.isdir(data_path): if concatenated_dataset is None: concatenated_dataset = load_from_disk(data_path) else: dataset = load_from_disk(data_path) concatenated_dataset = concatenate_datasets([concatenated_dataset, dataset]) concatenated_dataset.save_to_disk(save_dir) ================================================ FILE: ft_llms/run_llama.sh ================================================ # ag_news accelerate launch ./ft_llms/llms_finetune.py \ --output_dir ./ft_llms/llama/ag_news/target/ \ --block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \ -d ag_news -m decapoda-research/llama-7b-hf --packing --use_dataset_cache \ -e 10 -b 4 -lr 1e-4 --gradient_accumulation_steps 1 \ --train_sta_idx=0 --train_end_idx=10000 --eval_sta_idx=0 --eval_end_idx=1000 # refer candidate accelerate launch ./ft_llms/llms_finetune.py \ --output_dir ./ft_llms/llama/ag_news/candidate/ \ --block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \ -d JulesBelveze/tldr_news -m decapoda-research/llama-7b-hf --packing --use_dataset_cache \ -e 10 -b 4 -lr 1e-4 --gradient_accumulation_steps 1 \ --train_sta_idx=0 --train_end_idx=4767 --eval_sta_idx=0 
--eval_end_idx=538 # refer oracle accelerate launch ./ft_llms/llms_finetune.py \ --output_dir ./ft_llms/llama/ag_news/oracle/ \ --block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \ -d ag_news -m decapoda-research/llama-7b-hf --packing --use_dataset_cache \ -e 10 -b 4 -lr 1e-4 --gradient_accumulation_steps 1 \ --train_sta_idx=10000 --train_end_idx=20000 --eval_sta_idx=1000 --eval_end_idx=2000 accelerate launch refer_data_generate.py \ -tm /mnt/data0/fuwenjie/MIA-LLMs/ft_llms/llama/ag_news/target/checkpoint-3000 \ -m decapoda-research/llama-7b-hf -d ag_news # refer prompt accelerate launch ./ft_llms/llms_finetune.py --refer \ --output_dir ./ft_llms/llama/ag_news/refer/ \ --block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \ -d ag_news -m decapoda-research/llama-7b-hf --packing --use_dataset_cache \ -e 2 -b 4 -lr 5e-5 --gradient_accumulation_steps 1 \ --train_sta_idx=0 --train_end_idx=10000 --eval_sta_idx=0 --eval_end_idx=1000 ================================================ FILE: ft_llms/utils.py ================================================ import logging from typing_extensions import Literal from rich.logging import RichHandler from torch.utils.data import IterableDataset import warnings import random import torch def get_logger(name: str, level: Literal["info", "warning", "debug"]) -> logging.Logger: rich_handler = RichHandler(level=logging.INFO, rich_tracebacks=True, markup=True) logger = logging.getLogger(name) logger.setLevel(logging._nameToLevel[level.upper()]) if not logger.handlers: logger.addHandler(rich_handler) logger.propagate = False return logger def print_trainable_parameters(model): """ Prints the number of trainable parameters in the model. """ trainable_params = 0 all_param = 0 for _, param in model.named_parameters(): all_param += param.numel() if param.requires_grad: trainable_params += param.numel() print( f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" ) def constantlengthdatasetiter(self): iterator = iter(self.dataset) more_examples = True while more_examples: buffer, buffer_len = [], 0 while True: if buffer_len >= self.max_buffer_size: break try: buffer.append(self.formatting_func(next(iterator))) buffer_len += len(buffer[-1]) except StopIteration: if self.infinite: iterator = iter(self.dataset) warnings.warn("The dataset reached end and the iterator is reset to the start.") break else: more_examples = False break tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"] all_token_ids = [] for tokenized_input in tokenized_inputs: all_token_ids.extend(tokenized_input + [self.concat_token_id]) examples = [] for i in range(0, len(all_token_ids), self.seq_length): input_ids = all_token_ids[i : i + self.seq_length] if len(input_ids) == self.seq_length: examples.append(input_ids) if self.shuffle: random.shuffle(examples) for example in examples: self.current_size += 1 yield { "input_ids": torch.LongTensor(example), "labels": torch.LongTensor(example), } ================================================ FILE: requirements.txt ================================================ accelerate==0.23.0 datasets==2.14.5 deepspeed==0.10.1+46d859a7 deepspeed_mii==0.0.7+0acf569 flash_attn==2.2.1 huggingface_hub==0.16.4 matplotlib==3.5.3 nlpaug==1.1.11 nltk==3.4.5 numpy==1.24.4 opacus==1.4.0 openai==1.3.5 pandas==2.0.3 peft==0.6.0.dev0 python_dateutil==2.8.2 pyvacy==0.0.32 PyYAML==6.0.1 PyYAML==6.0.1 rich==13.7.0 scikit_learn==1.3.0 scipy==1.11.4 seaborn==0.13.0 
spacy==3.7.1 tqdm==4.66.1 transformers==4.34.0.dev0 trl==0.7.1 typing_extensions==4.8.0