[
  {
    "path": ".github/workflows/main.yml",
    "content": "name: Clean Commit History\n\non: workflow_dispatch\n\n\njobs:\n  clean_history:\n    runs-on: ubuntu-latest\n\n    steps:\n    - name: Checkout the repository\n      uses: actions/checkout@v2\n      with:\n        fetch-depth: 0\n\n    - name: Remove commit history\n      run: |\n        # Set up Git user\n        git config --global user.name \"Wenjie Fu\"\n        git config --global user.email \"wjfu99@outlook.com\"\n        git branch\n        ls\n        # Create a new orphan branch with the latest commit\n        git checkout --orphan latest_commit\n        # Add all files to the new branch\n        git add -A\n        # Commit the changes\n        git commit -m \"Final Commit\"\n\n        # Rename the current branch to main\n        git branch -m latest_commit\n\n        # Force push to update the repository\n        git push -f origin latest_commit:main"
  },
  {
    "path": ".gitignore",
    "content": "/cache\n# /ft_llms/cache\n# /ft_llms/checkpoints\n/ft_llms/*/\n/attack/*/\n/wandb\n/abc\n/.vscode\n/data/*/\n/test\n# /detect-gpt\n# /Finetune_LLMs\n\n# ft_llms/{cache, checkpoints}\n\n"
  },
  {
    "path": "README.md",
    "content": "# (NeurIPS'24) Practical Membership Inference Attacks against Fine-tuned Large Language Models via Self-prompt Calibration\n\n- [Requirements](#requirements)\n- [Target Model Fine-tuning](#target-model-fine-tuning)\n- [Self-prompt Reference Model Fine-tuning](#self-prompt-reference-model-fine-tuning)\n- [Run SPV-MIA](#run-spv-mia)\n\nThis is the official implementation of the paper \"Practical Membership Inference Attacks against Fine-tuned \nLarge Language Models via Self-prompt Calibration\".\nThe proposed Membership Inference Attack based on Self-calibrated Probabilistic Variation (SPV-MIA) is implemented as follows.\n\n![The overall architecture of _SPV-MIA_](./Framework.png)\n\n## Requirements\n\n- torch>=1.11.0\n- accelerate==0.20.3\n- transformers==4.34.0.dev0\n- trl==0.7.1\n- datasets==2.13.1\n- numpy>=1.23.4\n- scikit-learn>=1.1.3\n- pyyaml>=6.0\n- tqdm>=4.64.1\n\nDependency can be installed with the following command:\n\n```bash\npip install -r requirements.txt\n```\n\n\n## Target Model Fine-tuning\n  All large language models (LLMs) are built on the top of [transformers](https://huggingface.co/docs/transformers/index), \n  a go-to library for state-of-the-art transformer models, on which you can fine-tune arbitrary well-known LLMs you want,\n  including LLaMA, GPT-series, Falcon, etc.\n  We recommend training LLMs with multi-GPU and [accelerate](https://huggingface.co/docs/accelerate/index), \n  a library that enables the same PyTorch code to be run across any distributed configuration:\n  ```bash\n  accelerate launch ./ft_llms/llms_finetune.py \\\n  --output_dir ./ft_llms/*pretrained_model_name*/*dataset_name*/target/ \\\n  --block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \\\n  -d *dataset_name* -m *pretrained_model_name* --packing --use_dataset_cache \\\n  -e 10 -b 4 -lr 1e-4 --gradient_accumulation_steps 1 \\\n  --train_sta_idx=0 --train_end_idx=10000 --eval_sta_idx=0 --eval_end_idx=1000\n  ```\n\nPlease replace \\*pretrained_model_name\\* and \\*dataset_name\\* with the names of pretrained LLM and training dataset, such as `decapoda-research/llama-7b-hf` and `ag_news`.\n\n### Recommended pretrained models\n- GPT-2 (1.5B) (https://huggingface.co/gpt2-xl)\n- GPT-J (https://huggingface.co/EleutherAI/gpt-j-6b)\n- Falcon (https://huggingface.co/tiiuae/falcon-7b)\n- LLaMA (https://huggingface.co/decapoda-research/llama-7b-hf) [^1]\n\n[^1]: This third-party repo `decapoda-research/llama-7b-hf` seems to be deleted by unknown reasons, using forked repos [luodian/llama-7b-hf](https://huggingface.co/luodian/llama-7b-hf) \nor [baffo32/decapoda-research-llama-7B-hf](https://huggingface.co/baffo32/decapoda-research-llama-7B-hf) as alternatives.\n### Recommended datasets\n- Ag News (https://huggingface.co/datasets/ag_news)\n- Wikitext-103 (https://huggingface.co/datasets/wikitext) [^2]\n- Xsum (https://huggingface.co/datasets/xsum)\n\n[^2]: Please add an additional argument `--dataset_config_name wikitext-2-raw-v1` to specify this dataset.\n## Self-prompt Reference Model Fine-tuning\n  Before fine-tuning the self-prompt reference model, the reference dataset can be sampled via our proposed \n  self-prompt approach over the fine-tuned LLM. 
\n  ```bash\n  accelerate launch refer_data_generate.py \\\n  -tm *fine_tuned_model* \\\n  -m *pretrained_model_name* -d *dataset_name*\n  ```\n  Replace \\*fine_tuned_model\\* with the directory of the fine-tuned target model (i.e., the output directory of \n  the [Target Model Fine-tuning](#target-model-fine-tuning) phase). \n\n Then fine-tune the self-prompt reference model in the same manner as the target model, but with a smaller training epoch:\n```bash\naccelerate launch ./ft_llms/llms_finetune.py --refer \\\n--output_dir ./ft_llms/*pretrained_model_name*/*dataset_name*/refer/ \\\n--block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \\\n-d *dataset_name* -m *pretrained_model_name* --packing --use_dataset_cache \\\n-e 2 -b 4 -lr 5e-5 --gradient_accumulation_steps 1 \\\n--train_sta_idx=0 --train_end_idx=10000 --eval_sta_idx=0 --eval_end_idx=1000\n```\n\n\n## Run SPV-MIA\nAfter accomplishing the preliminary operations, here is the command for deploying SPV-MIA on the target model.\n```bash\npython attack.py\n```\n"
  },
  {
    "path": "attack/__init__.py",
    "content": ""
  },
  {
    "path": "attack/attack_model.py",
    "content": "import os\nimport random\n\nimport torch\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom attack import utils\nfrom attack.utils import Dict\nimport numpy as np\nfrom copy import deepcopy\nfrom tqdm import tqdm\nfrom torch.utils.data import DataLoader\nimport nlpaug.augmenter.word as naw\nimport nlpaug.augmenter.sentence as nas\nfrom sklearn.metrics import roc_auc_score, roc_curve, auc, precision_recall_curve, f1_score\nfrom itertools import cycle\nimport matplotlib.pyplot as plt\nimport re\nimport seaborn as sns\nfrom functools import partial\n\nlogger = get_logger(__name__, \"INFO\")\n\nPATH = os.getcwd()\n\naccelerator = Accelerator()\nclass AttackModel:\n    def __init__(self, target_model, tokenizer, datasets, reference_model, shadow_model, cfg, mask_model=None, mask_tokenizer=None):\n        self.target_model = target_model\n        self.tokenizer = tokenizer\n        self.datasets = datasets\n        self.kind = cfg['attack_kind']\n        self.cfg = cfg\n        if mask_model is not None:\n            self.mask_model = mask_model\n            self.mask_tokenizer = mask_tokenizer\n            self.pattern = re.compile(r\"<extra_id_\\d+>\")\n        if shadow_model is not None and cfg['attack_kind'] == \"nn\":\n            self.shadow_model = shadow_model\n            self.is_model_training = False\n        if reference_model is not None:\n            self.reference_model = reference_model\n\n    def llm_eval(self, model, data_loader, cfg, idx_rate, perturb_fn=None, refer_model=None):\n        model.eval()\n        losses = []\n        ref_losses = []\n        token_lens = []\n        for iteration, texts in enumerate(data_loader):\n            texts = texts[\"text\"]\n            if cfg[\"maximum_samples\"] is not None:\n                if iteration * accelerator.num_processes >= cfg[\"maximum_samples\"]:\n                    break\n            if perturb_fn is not None:\n                texts = perturb_fn(texts)\n            token_ids = self.tokenizer(texts, return_tensors=\"pt\", padding=True).to(accelerator.device)\n            labels = token_ids.input_ids\n            with torch.no_grad():\n                outputs = model(**token_ids, labels=labels)\n                ref_outputs = refer_model(**token_ids, labels=labels)\n            loss = outputs.loss\n            ref_loss = ref_outputs.loss\n            token_lens.append(accelerator.gather(torch.tensor(token_ids.input_ids.size()[-1]).reshape(-1, 1).to(accelerator.device)).detach().cpu().numpy()) # TODO: may cause bug when running attacks in paralell.\n            losses.append(accelerator.gather(loss.reshape(-1, 1)).detach().cpu().to(torch.float32).numpy())\n            ref_losses.append(accelerator.gather(ref_loss.reshape(-1, 1)).detach().cpu().to(torch.float32).numpy())\n            # print(f\"{accelerator.device}@{texts}\")\n            # print(f\"time duration: {time.time() - start_time}s\")\n        losses = np.concatenate(losses, axis=0)\n        ref_losses = np.concatenate(ref_losses, axis=0)\n        token_lens = np.concatenate(token_lens, axis=0)\n        # token_lens = np.array(token_lens, dtype=np.int32)\n        return losses, ref_losses, token_lens\n\n    def eval_perturb(self, model, dataset, cfg):\n        \"\"\"\n        Evaluate the loss of the perturbed data\n\n        :param dataset: N*channel*width*height\n        :return: losses: N*1; var_losses: N*1; per_losses: N*Mask_Num; ori_losses: N*1\n        \"\"\"\n        per_losses = []\n        ref_per_losses = 
[]\n        ori_losses = []\n        ref_ori_losses = []\n        ori_dataset = deepcopy(dataset)\n        for i in tqdm(range(0, cfg[\"perturbation_number\"])):\n            idx_rate = i / cfg[\"perturbation_number\"] * 0.7\n            ori_loss, ref_ori_loss, ori_token_len = self.llm_eval(model, ori_dataset, cfg, idx_rate, refer_model=self.reference_model)\n            ori_losses.append(ori_loss)\n            ref_ori_losses.append(ref_ori_loss)\n            perturb_fn = partial(self.sentence_perturbation, idx_rate=idx_rate)\n            sampled_per_losses = []\n            sampled_ref_per_losses = []\n            for _ in range(cfg[\"sample_number\"]):\n                per_loss, ref_per_loss, per_token_len = self.llm_eval(model, ori_dataset, cfg, idx_rate, perturb_fn=perturb_fn, refer_model=self.reference_model)\n                sampled_per_losses.append(per_loss)\n                sampled_ref_per_losses.append(ref_per_loss)\n            sampled_per_losses = np.concatenate(sampled_per_losses, axis=-1)\n            sampled_ref_per_losses = np.concatenate(sampled_ref_per_losses, axis=-1)\n            per_losses.append(np.expand_dims(sampled_per_losses, axis=-1))\n            ref_per_losses.append(np.expand_dims(sampled_ref_per_losses, axis=-1))\n        ori_losses = np.concatenate(ori_losses, axis=-1)\n        ref_ori_losses = np.concatenate(ref_ori_losses, axis=-1)\n        per_losses = np.concatenate(per_losses, axis=-1)\n        var_losses = per_losses - np.expand_dims(ori_losses, axis=-2)\n        ref_per_losses = np.concatenate(ref_per_losses, axis=-1) if cfg[\"calibration\"] else None\n        ref_var_losses = ref_per_losses - np.expand_dims(ref_ori_losses, axis=-2) if cfg[\"calibration\"] else None\n\n        output = (Dict(\n            per_losses=per_losses,\n            ori_losses=ori_losses,\n            var_losses=var_losses,\n        ),\n        Dict(\n            ref_per_losses=ref_per_losses,\n            ref_ori_losses=ref_ori_losses,\n            ref_var_losses=ref_var_losses,\n        ))\n        return output\n\n    def data_prepare(self, kind, cfg):\n        logger.info(\"Preparing data...\")\n        data_path = os.path.join(PATH, cfg[\"attack_data_path\"], f\"attack_data_{cfg['model_name']}@{cfg['dataset_name']}\")\n        target_model = getattr(self, kind + \"_model\")\n        mem_data = self.datasets[kind][\"train\"]\n        nonmem_data = self.datasets[kind][\"valid\"]\n\n        mem_path = os.path.join(data_path, kind, \"mem_feat.npz\")\n        nonmem_path = os.path.join(data_path, kind, \"nonmen_feat.npz\")\n        ref_mem_path = os.path.join(data_path, kind, \"ref_mem_feat.npz\")\n        ref_nonmem_path = os.path.join(data_path, kind, \"ref_nonmen_feat.npz\")\n\n        pathlist = (mem_path, nonmem_path, ref_mem_path, ref_nonmem_path) if cfg[\"calibration\"] else (mem_path, nonmem_path)\n\n        if not utils.check_files_exist(*pathlist) or not cfg[\"load_attack_data\"]:\n\n            logger.info(\"Generating feature vectors for member data...\")\n            mem_feat, ref_mem_feat = self.eval_perturb(target_model, mem_data, cfg)\n            if accelerator.is_main_process:\n                utils.save_dict_to_npz(mem_feat, mem_path)\n                if cfg[\"calibration\"]:\n                    utils.save_dict_to_npz(ref_mem_feat, ref_mem_path)\n\n            logger.info(\"Generating feature vectors for non-member data...\")\n            nonmem_feat, ref_nonmem_feat = self.eval_perturb(target_model, nonmem_data, cfg)\n            if 
accelerator.is_main_process:\n                utils.save_dict_to_npz(nonmem_feat, nonmem_path)\n                if cfg[\"calibration\"]:\n                    utils.save_dict_to_npz(ref_nonmem_feat, ref_nonmem_path)\n\n            logger.info(\"Saving feature vectors...\")\n\n        else:\n            logger.info(\"Loading feature vectors...\")\n            mem_feat = utils.load_dict_from_npz(mem_path)\n            ref_mem_feat = utils.load_dict_from_npz(ref_mem_path) if cfg[\"calibration\"] else None\n            nonmem_feat = utils.load_dict_from_npz(nonmem_path)\n            ref_nonmem_feat = utils.load_dict_from_npz(ref_nonmem_path) if cfg[\"calibration\"] else None\n\n        logger.info(\"Data preparation complete.\")\n\n        return Dict(\n            mem_feat=mem_feat,\n            nonmem_feat=nonmem_feat,\n            ref_mem_feat=ref_mem_feat,\n            ref_nonmem_feat=ref_nonmem_feat,\n                    )\n\n    def feat_prepare(self, info_dict, cfg):\n        # mem_info = info_dict.mem_feat\n        # ref_mem_info = info_dict.ref_mem_feat\n        if cfg[\"calibration\"]:\n            get_prob = lambda logprob: np.power(np.e, -logprob)\n            mem_feat = ((get_prob(info_dict.mem_feat.per_losses).mean((-1, -2)) - get_prob(info_dict.mem_feat.ori_losses).mean(-1)) -\n                        (get_prob(info_dict.ref_mem_feat.ref_per_losses).mean((-1, -2)) - get_prob(info_dict.ref_mem_feat.ref_ori_losses).mean(-1)))\n            nonmem_feat = ((get_prob(info_dict.nonmem_feat.per_losses).mean((-1, -2)) - get_prob(info_dict.nonmem_feat.ori_losses).mean(-1)) -\n                           (get_prob(info_dict.ref_nonmem_feat.ref_per_losses).mean((-1, -2)) - get_prob(info_dict.ref_nonmem_feat.ref_ori_losses).mean(-1)))\n        else:\n            mem_feat = info_dict.mem_feat.var_losses / info_dict.mem_feat.ori_losses\n            nonmem_feat = info_dict.nonmem_feat.var_losses / info_dict.nonmem_feat.ori_losses\n\n\n        if cfg[\"attack_kind\"] == \"stat\":\n            # mem_feat = mem_feat[:, :, 0]\n            # nonmem_feat = nonmem_feat[:, :, 0]\n            # mem_feat[np.isnan(mem_feat)] = 0\n            # nonmem_feat[np.isnan(nonmem_feat)] = 0\n            feat = - np.concatenate([mem_feat, nonmem_feat])\n            ground_truth = np.concatenate([np.zeros(mem_feat.shape[0]), np.ones(nonmem_feat.shape[0])]).astype(int)\n\n        return feat, ground_truth\n\n    def conduct_attack(self, cfg):\n        save_path = os.path.join(PATH, cfg[\"attack_data_path\"], f\"attack_data_{cfg['model_name']}@{cfg['dataset_name']}\",\n                                 f\"roc_{cfg['attack_kind']}.npz\")\n\n        raw_info = self.data_prepare(\"target\", cfg)\n        feat, ground_truth = self.feat_prepare(raw_info, cfg)\n        # self.distinguishability_plot(raw_info['mem_feat']['ori_losses'].mean(-1),\n        #                              raw_info['nonmem_feat']['ori_losses'].mean(-1))\n        # self.distinguishability_plot(feat[:1000], feat[-1000:])\n        self.eval_attack(ground_truth, -feat, path=save_path)\n\n    def tokenize_and_mask(self, text, span_length, pct, idx_rate, ceil_pct=False):\n        cfg = self.cfg\n        tokens = text.split(' ')\n        mask_string = '<<<mask>>>'\n        perturb_start_idx = int(len(tokens) * idx_rate)\n\n        n_spans = pct * len(tokens) / (span_length + cfg.buffer_size * 2)\n        if ceil_pct:\n            n_spans = np.ceil(n_spans)\n        n_spans = int(n_spans)\n\n        n_masks = 0\n        while n_masks < n_spans:\n            
start = np.random.randint(0, len(tokens) - span_length)\n            end = start + span_length\n            search_start = max(0, start - cfg.buffer_size)\n            search_end = min(len(tokens), end + cfg.buffer_size)\n            if mask_string not in tokens[search_start:search_end]:\n                tokens[start:end] = [mask_string]\n                n_masks += 1\n\n        # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments\n        num_filled = 0\n        for idx, token in enumerate(tokens):\n            if token == mask_string:\n                tokens[idx] = f'<extra_id_{num_filled}>'\n                num_filled += 1\n        assert num_filled == n_masks, f\"num_filled {num_filled} != n_masks {n_masks}\"\n        text = ' '.join(tokens)\n        return text\n\n    @staticmethod\n    def count_masks(texts):\n        return [len([x for x in text.split() if x.startswith(\"<extra_id_\")]) for text in texts]\n\n    def replace_masks(self, texts):\n        cfg = self.cfg\n        n_expected = self.count_masks(texts)\n        stop_id = self.mask_tokenizer.encode(f\"<extra_id_{max(n_expected)}>\")[0]\n        tokens = self.mask_tokenizer(texts, return_tensors=\"pt\", padding=True).to(accelerator.device)\n        outputs = self.mask_model.generate(**tokens, max_length=150, do_sample=True, top_p=cfg.mask_top_p,\n                                      num_return_sequences=1, eos_token_id=stop_id)\n        return self.mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)\n\n    def extract_fills(self, texts):\n        # remove <pad> from beginning of each text\n        texts = [x.replace(\"<pad>\", \"\").replace(\"</s>\", \"\").strip() for x in texts]\n\n        # return the text in between each matched mask token\n        extracted_fills = [self.pattern.split(x)[1:-1] for x in texts]\n\n        # remove whitespace around each fill\n        extracted_fills = [[y.strip() for y in x] for x in extracted_fills]\n\n        return extracted_fills\n\n    def apply_extracted_fills(self, masked_texts, extracted_fills):\n        # split masked text into tokens, only splitting on spaces (not newlines)\n        tokens = [x.split(' ') for x in masked_texts]\n\n        n_expected = self.count_masks(masked_texts)\n\n        # replace each mask token with the corresponding fill\n        for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):\n            if len(fills) < n:\n                tokens[idx] = []\n            else:\n                for fill_idx in range(n):\n                    text[text.index(f\"<extra_id_{fill_idx}>\")] = fills[fill_idx]\n\n        # join tokens back into text\n        texts = [\" \".join(x) for x in tokens]\n        return texts\n\n    def sentence_perturbation(self, texts, idx_rate):\n        cfg = self.cfg\n        masked_texts = [self.tokenize_and_mask(x, cfg.span_length, cfg.pct, idx_rate, cfg.ceil_pct) for x in texts]\n        raw_fills = self.replace_masks(masked_texts)\n        extracted_fills = self.extract_fills(raw_fills)\n        perturbed_texts = self.apply_extracted_fills(masked_texts, extracted_fills)\n\n        # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again\n        attempts = 1\n        while '' in perturbed_texts:\n            idxs = [idx for idx, x in enumerate(perturbed_texts) if x == '']\n            print(f'WARNING: {len(idxs)} texts have no fills. 
Trying again [attempt {attempts}].')\n            masked_texts = [self.tokenize_and_mask(x, cfg.span_length, cfg.pct, idx_rate, cfg.ceil_pct) for idx, x in enumerate(texts) if idx in idxs]\n            raw_fills = self.replace_masks(masked_texts)\n            extracted_fills = self.extract_fills(raw_fills)\n            new_perturbed_texts = self.apply_extracted_fills(masked_texts, extracted_fills)\n            for idx, x in zip(idxs, new_perturbed_texts):\n                perturbed_texts[idx] = x\n            attempts += 1\n        return perturbed_texts\n\n    @staticmethod\n    def eval_attack(y_true, y_scores, plot=True, path=None):\n        if isinstance(y_true, torch.Tensor):\n            y_true, y_scores = utils.tensor_to_ndarray(y_true, y_scores)\n        fpr, tpr, thresholds = roc_curve(y_true, y_scores)\n        if path is not None:\n            np.savez(path, fpr=fpr, tpr=tpr)\n        auc_score = roc_auc_score(y_true, y_scores)\n        logger.info(f\"AUC on the target model: {auc_score}\")\n\n        # Finding the threshold point where FPR + TPR equals 1\n        threshold_point = tpr[np.argmin(np.abs(tpr - (1 - fpr)))]\n        logger.info(f\"ASR on the target model: {threshold_point}\")\n\n        # Finding the TPR at 1% FPR (the FPR closest to 0.01)\n        tpr_1fpr = tpr[np.argmin(np.abs(fpr - 0.01))]\n        logger.info(f\"TPR@1%FPR on the target model: {tpr_1fpr}\")\n\n        if plot:\n            # plot the ROC curve\n            plt.plot(fpr, tpr, label=f'ROC curve (AUC = {auc_score}; ASR = {threshold_point})')\n            plt.xlabel('False Positive Rate')\n            plt.ylabel('True Positive Rate')\n            plt.legend()\n            # plot the no-skill line for reference\n            plt.plot([0, 1], [0, 1], linestyle='--')\n            # show the plot\n            plt.show()\n"
  },
  {
    "path": "attack/utils.py",
    "content": "import logging\nfrom typing_extensions import Literal\nfrom rich.logging import RichHandler\nimport os\nimport torch\nimport numpy as np\n\n\ndef get_logger(name: str, level: Literal[\"info\", \"warning\", \"debug\"]) -> logging.Logger:\n    rich_handler = RichHandler(level=logging.INFO, rich_tracebacks=True, markup=True)\n\n    logger = logging.getLogger(name)\n    logger.setLevel(logging._nameToLevel[level.upper()])\n\n    if not logger.handlers:\n        logger.addHandler(rich_handler)\n\n    logger.propagate = False\n\n    return logger\n\nclass Dict(dict):\n    def __getattr__(self, name):\n        if name in self:\n            return  self[name]\n        raise AttributeError(f\"'Dict' object has no attribute '{name}'\")\n    def __setattr__(self, name, value):\n        super().__setitem__(name, value)\n        super().__setattr__(name, value)\n\n    def __setitem__(self, key, value):\n        super().__setitem__(key, value)\n        super().__setattr__(key, value)\n\ndef check_files_exist(*file_paths):\n    \"\"\"\n    Check if the input file(s) exist at the given file path(s).\n\n    Parameters:\n        *file_paths (str): One or more strings representing the file path(s) to check.\n\n    Returns:\n        bool: True if all the files exist, False otherwise.\n    \"\"\"\n    for file_path in file_paths:\n        if not os.path.isfile(file_path):\n            return False\n    return True\n\n\ndef create_folder(folder_path):\n    if not os.path.exists(folder_path):\n        os.makedirs(folder_path)\n        print(f\"Folder '{folder_path}' created.\")\n    else:\n        print(f\"Folder '{folder_path}' already exists.\")\n\n\ndef save_dict_to_npz(my_dict, file_path):\n    \"\"\"\n    Saves a dictionary with ndarray values to an npz file.\n\n    Parameters:\n        my_dict (dict): A dictionary with ndarray values to be saved.\n        file_path (str): The file path to save the dictionary values to.\n\n    Returns:\n        None\n    \"\"\"\n    folder = os.path.dirname(file_path)\n    if not os.path.exists(folder):\n        os.makedirs(folder)\n    with open(file_path, 'wb') as f:\n        np.savez(f, **my_dict)\n\n\ndef load_dict_from_npz(file_path):\n    \"\"\"\n    Loads a dictionary with ndarray values from an npz file.\n\n    Parameters:\n        file_path (str): The file path of the npz file to load.\n\n    Returns:\n        dict: A dictionary containing the values stored in the npz file.\n    \"\"\"\n    with np.load(file_path) as data:\n        my_dict = Dict({key: value for key, value in data.items() if isinstance(value, np.ndarray)})\n    return my_dict\n\n\ndef ndarray_to_tensor(*ndarrays):\n    \"\"\"\n    Converts multiple numpy ndarrays to PyTorch tensors.\n\n    Parameters:\n        *ndarrays (numpy.ndarray): Multiple numpy ndarrays to convert.\n\n    Returns:\n        tuple of torch.Tensor: A tuple of PyTorch tensors with the same data as the input ndarrays.\n    \"\"\"\n    tensors = tuple(torch.from_numpy(ndarray).cuda().float() for ndarray in ndarrays)\n    return tensors\n\n\ndef tensor_to_ndarray(*tensors):\n    \"\"\"\n    Converts multiple PyTorch tensors to numpy ndarrays.\n\n    Parameters:\n        *tensors (torch.Tensor): Multiple PyTorch tensors to convert.\n\n    Returns:\n        tuple of numpy.ndarray: A tuple of numpy ndarrays with the same data as the input tensors.\n    \"\"\"\n    ndarrays = tuple(tensor.detach().cpu().numpy() for tensor in tensors)\n    return ndarrays\n\n\ndef convert_labels_to_one_hot(labels, num_classes):\n    
'''\n    Converts labels of samples from format (N,) to (N, C), where C is the number of classes\n\n    Args:\n    labels : numpy array of shape (N,) containing the labels of each sample\n    num_classes : integer indicating the total number of classes in the dataset\n\n    Returns:\n    numpy array of shape (N, C), where C is the number of classes, containing the one-hot encoded labels\n    '''\n    one_hot_labels = np.zeros((labels.shape[0], num_classes))\n    one_hot_labels[np.arange(labels.shape[0]), labels] = 1\n    return one_hot_labels\n\n\ndef get_file_names(folder_path):\n    # List to store the file names\n    file_names = []\n\n    # Loop through each file in the folder\n    for file_name in sorted(os.listdir(folder_path)):\n        # Check if the current item is a file\n        if os.path.isfile(os.path.join(folder_path, file_name)):\n            file_names.append(os.path.join(folder_path, file_name))\n\n    return file_names\n\n\ndef extract(v, t, x_shape):\n    \"\"\"\n    Extract some coefficients at specified timesteps, then reshape to\n    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.\n    \"\"\"\n    out = torch.gather(v, index=t, dim=0).float()\n    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))"
  },
  {
    "path": "attack.py",
    "content": "import os\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\nfrom torch.utils.data import DataLoader\nimport logging\nimport random\n\nfrom attack.attack_model import AttackModel\nfrom data.prepare import dataset_prepare\nfrom attack.utils import Dict\n\nimport yaml\nimport datasets\nfrom datasets import Image, Dataset\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nimport trl\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig, TrainingArguments, AutoConfig, LlamaTokenizer\nfrom peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training\n\n# Load config file\nwith open(\"configs/config.yaml\", 'r') as f:\n    cfg = yaml.safe_load(f)\n    cfg = Dict(cfg)\n\n# Add Logger\naccelerator = Accelerator()\nlogger = get_logger(__name__, \"INFO\")\nlogging.basicConfig(\n    format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n    datefmt=\"%m/%d/%Y %H:%M:%S\",\n    level=logging.INFO,\n    )\n\n# Load abs path\nPATH = os.path.dirname(os.path.abspath(__file__))\n\n# Fix the random seed\nseed = 0\ntorch.manual_seed(seed)\nnp.random.seed(seed)\ntorch.cuda.manual_seed_all(seed)\nrandom.seed(seed)\ntorch.backends.cudnn.benchmark = False\ntorch.backends.cudnn.deterministic = True\n\n## Load generation models.\nif not cfg[\"load_attack_data\"]:\n    # config = AutoConfig.from_pretrained(cfg[\"model_name\"])\n    # config.use_cache = False\n    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\n    target_model = AutoModelForCausalLM.from_pretrained(cfg[\"target_model\"], quantization_config=BitsAndBytesConfig(load_in_8bit=True),\n                                                        torch_dtype=torch_dtype,\n                                                        local_files_only=True,\n                                                        config=AutoConfig.from_pretrained(cfg[\"model_name\"]),\n                                                        cache_dir=cfg[\"cache_path\"])\n    reference_model = AutoModelForCausalLM.from_pretrained(cfg[\"reference_model\"], quantization_config=BitsAndBytesConfig(load_in_8bit=True),\n                                                           torch_dtype=torch_dtype,\n                                                           local_files_only=True,\n                                                           config=AutoConfig.from_pretrained(cfg[\"model_name\"]),\n                                                           cache_dir=cfg[\"cache_path\"])\n\n\n    logger.info(\"Successfully load models\")\n    config = AutoConfig.from_pretrained(cfg.model_name)\n    # Load tokenizer.\n    model_type = config.to_dict()[\"model_type\"]\n    if model_type == \"llama\":\n        tokenizer = LlamaTokenizer.from_pretrained(cfg[\"model_name\"], add_eos_token=cfg[\"add_eos_token\"],\n                                                  add_bos_token=cfg[\"add_bos_token\"], use_fast=True)\n    else:\n        tokenizer = AutoTokenizer.from_pretrained(cfg[\"model_name\"], add_eos_token=cfg[\"add_eos_token\"],\n                                                  add_bos_token=cfg[\"add_bos_token\"], use_fast=True)\n\n    if cfg[\"model_name\"] == \"/mnt/data0/fuwenjie/MIA-LLMs/cache/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348\":\n        cfg[\"model_name\"] = \"decapoda-research/llama-7b-hf\"\n\n    if cfg[\"pad_token_id\"] is not None:\n        logger.info(\"Using 
pad token id %d\", cfg[\"pad_token_id\"])\n        tokenizer.pad_token_id = cfg[\"pad_token_id\"]\n\n    if tokenizer.pad_token_id is None:\n        logger.info(\"Pad token id is None, setting to eos token id...\")\n        tokenizer.pad_token_id = tokenizer.eos_token_id\n\n    # Load datasets\n    train_dataset, valid_dataset = dataset_prepare(cfg, tokenizer=tokenizer)\n    train_dataset = Dataset.from_dict(train_dataset[cfg.train_sta_idx:cfg.train_end_idx])\n    valid_dataset = Dataset.from_dict(valid_dataset[cfg.eval_sta_idx:cfg.eval_end_idx])\n    train_dataset = Dataset.from_dict(train_dataset[random.sample(range(len(train_dataset[\"text\"])), cfg[\"maximum_samples\"])])\n    valid_dataset = Dataset.from_dict(valid_dataset[random.sample(range(len(valid_dataset[\"text\"])), cfg[\"maximum_samples\"])])\n    logger.info(\"Successfully load datasets!\")\n\n    # Prepare dataloade\n    train_dataloader = DataLoader(train_dataset, batch_size=cfg[\"eval_batch_size\"])\n    eval_dataloader = DataLoader(valid_dataset, batch_size=cfg[\"eval_batch_size\"])\n\n    # Load Mask-f\n    shadow_model = None\n    int8_kwargs = {}\n    half_kwargs = {}\n    if cfg[\"int8\"]:\n        int8_kwargs = dict(load_in_8bit=True, device_map='auto', torch_dtype=torch.bfloat16)\n    elif cfg[\"half\"]:\n        half_kwargs = dict(torch_dtype=torch.bfloat16)\n    mask_model = AutoModelForSeq2SeqLM.from_pretrained(cfg[\"mask_filling_model_name\"], **int8_kwargs, **half_kwargs).to(accelerator.device)\n    try:\n        n_positions = mask_model.config.n_positions\n    except AttributeError:\n        n_positions = 512\n    mask_tokenizer = AutoTokenizer.from_pretrained(cfg[\"mask_filling_model_name\"], model_max_length=n_positions)\n\n    # Prepare everything with accelerator\n    train_dataloader, eval_dataloader = (\n        accelerator.prepare(\n            train_dataloader,\n            eval_dataloader,\n    ))\nelse:\n    target_model = None\n    reference_model = None\n    shadow_model = None\n    mask_model = None\n    train_dataloader = None\n    eval_dataloader = None\n    tokenizer = None\n    mask_tokenizer = None\n\n\ndatasets = {\n    \"target\": {\n        \"train\": train_dataloader,\n        \"valid\": eval_dataloader\n    }\n}\n\n\nattack_model = AttackModel(target_model, tokenizer, datasets, reference_model, shadow_model, cfg, mask_model=mask_model, mask_tokenizer=mask_tokenizer)\nattack_model.conduct_attack(cfg=cfg)\n"
  },
  {
    "path": "configs/config.yaml",
    "content": "random_seed: 48\nmodel_name: /mnt/data0/fuwenjie/MIA-LLMs/cache/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 # EleutherAI/gpt-j-6B gpt2\ntarget_model: /mnt/data0/fuwenjie/MIA-LLMs/ft_llms/llama/ag_news/target/checkpoint-3000 # valid model:\nreference_model: /mnt/data0/fuwenjie/MIA-LLMs/ft_llms/llama/ag_news/refer/checkpoint-400 #\ndataset_name: ag_news # xsum, ag_news, wikitext\ndataset_config_name: null # wikitext-2-raw-v1 null\ncache_path: ./cache\nuse_dataset_cache: true\npacking: true\ncalibration: true # whether to enable calibration\nadd_eos_token: false\nadd_bos_token: false\npad_token_id: null\nattack_kind: stat # valid attacks: nn, stat\neval_batch_size: 1 # batch size of the evaluation phase\nmaximum_samples: 200 # the maximum samples number for member and non-member records.\nblock_size: 128\nvalidation_split_percentage: 0.1\npreprocessing_num_workers: 1\nmask_filling_model_name: t5-base\nbuffer_size: 1\nmask_top_p: 1.0\nspan_length: 2\npct: 0.3 # pct_words_masked\nceil_pct: false\nint8: false\nhalf: false\nperturbation_number: 1 # the number of different perturbation strength / position; debugging parameter, should be set to 1 in the regular running.\nsample_number: 10 # the number of sampling\ntrain_sta_idx: 0\ntrain_end_idx: 10000\neval_sta_idx: 0\neval_end_idx: 1000\nattack_data_path: attack\nload_attack_data: false # whether to load prepared attack data if existing."
  },
  {
    "path": "data/__init__.py",
    "content": ""
  },
  {
    "path": "data/prepare.py",
    "content": "import os\nimport random\nimport datasets\nimport trl\nfrom attack.utils import create_folder\n\nblock_size = None\ntokenizer_ = None\nmax_buff_size = None\ntext_column = None\n\ndef packing_texts(examples):\n    more_examples = True\n    packed_texts = []\n    packed_ids = []\n    # for key in examples.keys():\n    assert list(examples.keys()) == [\"text\"]\n    iterator = iter(examples[\"text\"])\n    # for sentence in examples[\"text\"]:\n    total_num = 0\n    drop_num = 0\n    while more_examples:\n        buffer, buffer_len = [], 0\n        while True:\n            if buffer_len >= max_buff_size:\n                break\n            try:\n                buffer.append(next(iterator))\n                buffer_len += len(buffer[-1])\n            except StopIteration:\n                more_examples = False\n                break\n        tokenized_inputs = tokenizer_(buffer, truncation=False)[\"input_ids\"]\n        inputs = tokenizer_.batch_decode(tokenized_inputs)\n        tokenized_inputs = tokenizer_(inputs, truncation=False)[\"input_ids\"]\n        all_token_ids = []\n        for tokenized_input in tokenized_inputs:\n            all_token_ids.extend(tokenized_input)\n        for i in range(0, len(all_token_ids), block_size):\n            input_ids = all_token_ids[i: i + block_size]\n            if len(input_ids) == block_size:\n                packed_ids.append(input_ids)\n                input_text = tokenizer_.decode(input_ids)\n                total_num += 1\n                if len(tokenizer_.encode(input_text)) == block_size:\n                    packed_texts.append(input_text)\n                    drop_num += 1\n    # print(f\"Total examples: {total_num}, dropped num: {drop_num}, dropped rate: {1 - drop_num/total_num}\")\n    return {\n        \"text\": packed_texts\n    }\ndef dataset_prepare(args, tokenizer=None, num_of_sequences=1024, chars_per_token=3.6):\n    # raw_datasets = datasets.load_dataset(args.dataset_name, args.dataset_config_name)['train']\n    # if \"validation\" in raw_datasets.keys():\n    #     train_dataset = raw_datasets[\"train\"]\n    #     valid_dataset = raw_datasets[\"validation\"]\n    # else:\n    train_dataset = datasets.load_dataset(\n        args.dataset_name,\n        args.dataset_config_name,\n        split=f\"train[:{int((1-args.validation_split_percentage)*100)}%]\"\n    )\n    valid_dataset = datasets.load_dataset(\n        args.dataset_name,\n        args.dataset_config_name,\n        split=f\"train[{int((1-args.validation_split_percentage)*100)}%:]\",\n    )\n\n    # train_idxs = set(random.sample(range(len(raw_datasets)), int(len(raw_datasets) * (1 - args.validation_split_percentage))))\n    # valid_idxs = set(range(len(raw_datasets))) - train_idxs\n    # train_dataset = datasets.Dataset.from_dict(raw_datasets[train_idxs])\n    # valid_dataset = datasets.Dataset.from_dict(raw_datasets[valid_idxs])\n\n\n    global text_column\n    column = train_dataset.column_names\n    if \"text\" in column:\n        text_column = \"text\"\n    elif \"document\" in column:\n        text_column = \"document\"\n    elif \"content\" in column:\n        text_column = \"content\"\n\n    train_dataset = train_dataset.select_columns(text_column)\n    valid_dataset = valid_dataset.select_columns(text_column)\n    if text_column != \"text\":\n        train_dataset = train_dataset.rename_column(text_column, \"text\")\n        valid_dataset = valid_dataset.rename_column(text_column, \"text\")\n\n    if args.packing:\n        global block_size, 
tokenizer_, max_buff_size\n        block_size = args.block_size\n        max_buff_size = block_size * chars_per_token * num_of_sequences\n        tokenizer_ = tokenizer\n        create_folder(f\"{args.cache_path}/{args.dataset_name}/{args.dataset_config_name}\")\n        train_dataset = train_dataset.map(\n            packing_texts,\n            batched=True,\n            # batch_size=None,\n            num_proc=args.preprocessing_num_workers,\n            cache_file_name=f\"{args.cache_path}/{args.dataset_name}/{args.dataset_config_name}/train_dataset\",\n            load_from_cache_file=args.use_dataset_cache,\n            desc=f\"Packing texts in chunks of {block_size} tokens\"\n        )\n        valid_dataset = valid_dataset.map(\n            packing_texts,\n            batched=True,\n            # batch_size=None,\n            num_proc=args.preprocessing_num_workers,\n            cache_file_name=f\"{args.cache_path}/{args.dataset_name}/{args.dataset_config_name}/valid_dataset\",\n            load_from_cache_file=args.use_dataset_cache,\n            desc=f\"Packing texts in chunks of {block_size} tokens\"\n        )\n    # Return the (possibly packed) splits; without --packing the raw splits are returned.\n    return train_dataset, valid_dataset"
  },
  {
    "path": "ft_llms/__init__.py",
    "content": ""
  },
  {
    "path": "ft_llms/llama_patch.py",
    "content": "from typing import List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport math\nimport warnings\nimport transformers\nfrom transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv\nfrom peft.tuners.lora import LoraLayer\n\ntry:\n    from flash_attn.modules.mha import FlashSelfAttention\nexcept Exception:\n    raise ModuleNotFoundError(\n        \"Please install FlashAttention first, e.g., with pip install flash-attn --no-build-isolation, Learn more at https://github.com/Dao-AILab/flash-attention#installation-and-features\"\n    )\n\n\n\ndef compute_flash_attention(flash_attn, q, k, v, attention_mask=None, head_mask=None):\n    # q, k, v: [bs, seq_len, num_attention_heads, attn_head_size]\n    # attention_mask (float): [bs, seq_len]\n    batch_size, max_len = q.size(0), q.size(1)\n\n    qkv = torch.stack([q, k, v], dim=2)\n    dtype_in = qkv.dtype\n    if dtype_in == torch.float32:\n        qkv = qkv.to(torch.float16)  # need to truncate in case input is fp32\n    cu_seqlens, max_seqlen = None, None\n\n    if attention_mask is None:\n        out = flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)\n    else:\n        # Limitation: non-contiguous attention mask will not be handled correctly\n        # model will be able to pay attention between the first and last non-masked token, i.e. left- and right-side padding is supported.\n        csums = (attention_mask >= 0).cumsum(dim=1)\n        ends = csums.argmax(dim=1) + 1\n        starts = ends - csums.max(dim=1).values\n        seqlens = ends - starts\n\n        qkv = torch.cat([qkv[i, starts[i] : ends[i]] for i in range(batch_size)], dim=0)\n        zero = torch.zeros_like(seqlens[:1])  # torch.tensor([0]) with correct dtype and device\n        cu_seqlens = torch.cat([zero, seqlens.cumsum(dim=0)], dim=0).to(torch.int32)\n        max_seqlen = seqlens.max().item()\n\n        out = flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)\n        # out: [num_unmasked_tokens, num_attention_heads, attn_head_size]\n\n        seqs = [out[start:end] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]\n        # stack and pad sequences together\n        padded_seqs = [\n            F.pad(seqs[i], (0, 0) * (seqs[i].dim() - 1) + (starts[i], max_len - ends[i]), value=0.0)\n            for i in range(batch_size)\n        ]\n        out = torch.stack(padded_seqs)\n\n    if out.dtype != dtype_in:\n        out = out.to(dtype_in)\n    return out\n\n\n# AND https://github.com/LAION-AI/Open-Assistant/blob/04fa9a24b2a58c8885b8aa6a2eb02b18de6b4961/model/model_training/models/patching_llama.py\ndef llama_forward_with_flash_attn(\n    self: LlamaAttention,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    bsz, q_len, _ = hidden_states.size()\n    \n    if not hasattr(self, 'att_fn'):\n        self.att_fn = FlashSelfAttention(causal=True)\n    \n    flash_attn = self.att_fn\n\n    if output_attentions:\n        warnings.warn(\"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.\")\n    if self.config.pretraining_tp > 1:\n        key_value_slicing = (self.num_key_value_heads * self.head_dim) // 
self.config.pretraining_tp\n        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)\n        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)\n        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)\n\n        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]\n        query_states = torch.cat(query_states, dim=-1)\n\n        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]\n        key_states = torch.cat(key_states, dim=-1)\n\n        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]\n        value_states = torch.cat(value_states, dim=-1)\n\n    else:\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n    kv_seq_len = key_states.shape[-2]\n    if past_key_value is not None:\n        kv_seq_len += past_key_value[0].shape[-2]\n    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n    if past_key_value is not None:\n        # reuse k, v, self_attention\n        key_states = torch.cat([past_key_value[0], key_states], dim=2)\n        value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n    past_key_value = (key_states, value_states) if use_cache else None\n\n    # repeat k/v heads if n_kv_heads < n_heads\n    key_states = repeat_kv(key_states, self.num_key_value_groups)\n    value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n    if (\n        query_states.shape == key_states.shape\n    ):  # and (attention_mask is None or attention_mask[:, 0, -1, 0].min() >= 0):\n        if attention_mask is not None:\n            attention_mask = attention_mask[:, 0, -1]\n\n        flash_attn.train(self.training)\n        out_dtype = value_states.dtype\n        q, k, v = (\n            query_states.transpose(1, 2),\n            key_states.transpose(1, 2),\n            value_states.transpose(1, 2),\n        )\n        attn_output = compute_flash_attention(flash_attn, q, k, v, attention_mask)\n        attn_output = attn_output.transpose(1, 2).to(out_dtype)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n    else:\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size 
{(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n    attn_output = attn_output.transpose(1, 2).contiguous()\n    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n    if self.config.pretraining_tp > 1:\n        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)\n        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)\n        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])\n    else:\n        attn_output = self.o_proj(attn_output)\n\n    return attn_output, None, past_key_value\n\n\n# Disable the transformation of the attention mask in LlamaModel as the flash attention\n# requires the attention mask to be the same as the key_padding_mask\ndef _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):\n    # [bsz, seq_len]\n    return attention_mask\n\n\ndef replace_attn_with_flash_attn():\n    cuda_major, cuda_minor = torch.cuda.get_device_capability()\n    if cuda_major < 8:\n        print(\n            \"Flash attention is only supported on Ampere or Hopper GPU during training due to head dim > 64 backward.\"\n            \"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593\"\n        )\n    \n    # transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (\n    #    _prepare_decoder_attention_mask\n    # )\n    transformers.models.llama.modeling_llama.LlamaAttention.old_forward = transformers.models.llama.modeling_llama.LlamaAttention.forward\n    transformers.models.llama.modeling_llama.LlamaAttention.forward = llama_forward_with_flash_attn\n\n\ndef unplace_flash_attn_with_attn():\n    import importlib\n    import transformers\n\n    print(\"Reloading llama model, unpatching flash attention\")\n    importlib.reload(transformers.models.llama.modeling_llama)\n\n\n# Adapted from https://github.com/tmm1/axolotl/blob/2eda9e02a9d15a7a3f92b41f257d9844d72fc220/src/axolotl/utils/models.py#L338\ndef upcast_layer_for_flash_attention(model, torch_dtype):\n    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to\n    # convert them back to fp16/bf16 for flash-attn compatibility.\n    for name, module in model.named_modules():\n        if isinstance(module, LoraLayer):\n            module.to(torch_dtype)\n        if \"norm\" in name:\n            module.to(torch_dtype)\n        if \"lm_head\" in name or \"embed_tokens\" in name:\n            if hasattr(module, \"weight\"):\n                module.to(torch_dtype)\n\n    return model"
  },
  {
    "path": "ft_llms/llms_finetune.py",
    "content": "import argparse\n\nimport datasets\nimport trl\nfrom trl import SFTTrainer\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, AutoConfig\nfrom accelerate import Accelerator\nfrom datasets import Dataset, load_from_disk\nimport torch\nimport logging\nimport os\nfrom peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PrefixTuningConfig, PromptEncoderConfig, IA3Config\nimport pandas as pd\nimport sys\nhere = os.path.dirname(__file__)\nsys.path.append(os.path.join(here, '..'))\nfrom data.prepare import dataset_prepare\nfrom attack.utils import create_folder\nfrom transformers import LlamaTokenizer, get_scheduler\nimport os\n\nos.environ['HTTP_PROXY'] = 'http://fuwenjie:19990621f@192.168.75.13:7890'\nos.environ['HTTPS_PROXY'] = 'http://fuwenjie:19990621f@192.168.75.13:7890'\n\nfrom utils import get_logger, constantlengthdatasetiter, print_trainable_parameters\n# trl.trainer.ConstantLengthDataset.__dict__[\"__iter__\"] = constantlengthdatasetiter\n# setattr(trl.trainer.ConstantLengthDataset, \"__iter__\", constantlengthdatasetiter)\n\nlogger = get_logger(\"finetune\", \"info\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-m\", \"--model_name\", type=str, default=\"meta-llama/Llama-2-7b-hf\")\n    parser.add_argument(\"-d\", \"--dataset_name\", type=str, default=\"wikitext-2-raw-v1\")\n    parser.add_argument(\"-dc\", \"--dataset_config_name\", type=str, default=None, help=\"The configuration name of the dataset to use (via the datasets library).\")\n    parser.add_argument(\"--cache_path\", type=str, default=\"./cache\")\n    parser.add_argument(\"--use_dataset_cache\", action=\"store_true\", default=False)\n    parser.add_argument(\"--refer\", action=\"store_true\", default=False)\n    parser.add_argument(\"--refer_data_source\", type=str, default=None)\n    parser.add_argument(\"--packing\", action=\"store_true\", default=False)\n    parser.add_argument(\"-t\", \"--token\", type=str, default=None)\n    parser.add_argument(\"--split_model\", action=\"store_true\", default=False)\n    parser.add_argument(\"--block_size\", type=int, default=1024)\n    parser.add_argument(\"--preprocessing_num_workers\", type=int, default=1)\n    parser.add_argument(\"--peft\", type=str, default=\"lora\")\n    parser.add_argument(\"--lora_rank\", type=int, default=64)\n    parser.add_argument(\"--lora_alpha\", type=int, default=16)\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.1)\n    parser.add_argument(\"--p_tokens\", type=int, help=\"The number of virtual tokens for prefix-tuning or p-tuning\", default=20)\n    parser.add_argument(\"--p_hidden\", type=int, help=\"The hidden size of the prompt encoder\", default=128)\n\n    parser.add_argument(\"-lr\", \"--learning_rate\", type=float, default=1e-4)\n    parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"linear\")\n    parser.add_argument(\"--warmup_steps\", type=int, default=0)\n    parser.add_argument(\"--weight_decay\", type=float, default=0)\n    parser.add_argument(\"--output_dir\", type=str, default=\"./ft_llms/checkpoints\")\n    parser.add_argument(\"--log_steps\", type=int, default=10)\n    parser.add_argument(\"--eval_steps\", type=int, default=10)\n    parser.add_argument(\"--save_epochs\", type=int, default=10)\n    parser.add_argument(\"-e\", \"--epochs\", type=int, default=1)\n    parser.add_argument(\"-b\", \"--batch_size\", type=int, default=1)\n    
parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=1)\n    parser.add_argument(\"--gradient_checkpointing\", action=\"store_true\", default=False)\n    parser.add_argument(\"--trust_remote_code\", action=\"store_true\", default=False)\n\n    parser.add_argument(\"--train_sta_idx\", type=int, default=0)\n    parser.add_argument(\"--train_end_idx\", type=int, default=6000)\n    parser.add_argument(\"--eval_sta_idx\", type=int, default=0)\n    parser.add_argument(\"--eval_end_idx\", type=int, default=600)\n\n    parser.add_argument(\"-s\", \"--save_limit\", type=int, default=None)\n\n    parser.add_argument(\"--use_int4\", action=\"store_true\", default=False)\n    parser.add_argument(\"--use_int8\", action=\"store_true\", default=False)\n    parser.add_argument(\"--disable_peft\", action=\"store_true\", default=False)\n    parser.add_argument(\"--disable_flash_attention\", action=\"store_true\", help=\"Disable flash attention\", default=False)\n\n    parser.add_argument(\"--pad_token_id\", default=None, type=int, help=\"The end of sequence token.\")\n    parser.add_argument(\"--add_eos_token\", action=\"store_true\", help=\"Add EOS token to tokenizer\", default=False)\n    parser.add_argument(\"--add_bos_token\", action=\"store_true\", help=\"Add BOS token to tokenizer\", default=False)\n    parser.add_argument(\"--validation_split_percentage\", default=0.1, help=\"The percentage of the train set used as validation set in case there's no validation split\")\n    args = parser.parse_args()\n\n    accelerator = Accelerator()\n\n    if args.token is None:\n        access_token = os.getenv(\"HF_TOKEN\", \"\")\n    else:\n        access_token = args.token\n\n    config = AutoConfig.from_pretrained(args.model_name, cache_dir=args.cache_path)\n\n    config.use_cache = False\n    config_dict = config.to_dict()\n    model_type = config_dict[\"model_type\"]\n\n    use_flash_attention = False\n\n    if not args.disable_flash_attention and model_type != \"llama\":\n        logger.info(\"Model is not llama, disabling flash attention...\")\n    elif args.disable_flash_attention and model_type == \"llama\":\n        logger.info(\"Model is llama, could be using flash attention...\")\n    elif not args.disable_flash_attention and torch.cuda.get_device_capability()[0] >= 8:\n        from ft_llms.llama_patch import replace_attn_with_flash_attn\n        logger.info(\"Using flash attention for llama...\")\n        replace_attn_with_flash_attn()\n        use_flash_attention = True\n\n\n    if \"WANDB_PROJECT\" not in os.environ:\n        os.environ[\"WANDB_PROJECT\"] = \"GPT_finetuning\"\n\n    if args.split_model:\n        logger.info(\"Splitting the model across all available devices...\")\n        kwargs = {\"device_map\": \"auto\"}\n    else:\n        kwargs = {\"device_map\": None}\n\n    if model_type == \"llama\":\n        tokenizer = LlamaTokenizer.from_pretrained(args.model_name, token=access_token,\n                                                  trust_remote_code=args.trust_remote_code, cache_dir=args.cache_path,\n                                                  add_eos_token=args.add_eos_token, add_bos_token=args.add_bos_token,\n                                                  use_fast=True)\n    else:\n        tokenizer = AutoTokenizer.from_pretrained(args.model_name, token=access_token,\n                                                  trust_remote_code=args.trust_remote_code, cache_dir=args.cache_path,\n                                                  
add_eos_token=args.add_eos_token, add_bos_token=args.add_bos_token,\n                                                  use_fast=True)\n    # THIS IS A HACK TO GET THE PAD TOKEN ID NOT TO BE EOS\n    # good one for LLama is 18610\n    if args.pad_token_id is not None:\n        logger.info(\"Using pad token id %d\", args.pad_token_id)\n        tokenizer.pad_token_id = args.pad_token_id\n\n    if tokenizer.pad_token_id is None:\n        logger.info(\"Pad token id is None, setting to eos token id...\")\n        tokenizer.pad_token_id = tokenizer.eos_token_id\n\n    block_size = args.block_size\n    logger.info(\"Using a block size of %d\", block_size)\n\n    if args.use_int4:\n        logger.info(\"Using int4 quantization\")\n        bnb_config = BitsAndBytesConfig(\n            load_in_4bit=True,\n            bnb_4bit_quant_type=\"nf4\",\n            bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,\n            bnb_4bit_use_double_quant=True,\n        )\n        optimizer = \"adamw_bnb_8bit\"\n    elif args.use_int8:\n        logger.info(\"Using int8 quantization\")\n        bnb_config = BitsAndBytesConfig(\n            load_in_8bit=True,\n        )\n        optimizer = \"adamw_bnb_8bit\"\n    else:\n        logger.info(\"Using no quantization\")\n        bnb_config = None\n        optimizer = \"adamw_torch\"\n\n    if args.peft == \"lora\":\n        peft_config = LoraConfig(\n            task_type=TaskType.CAUSAL_LM,\n            inference_mode=False,\n            r=args.lora_rank,\n            lora_alpha=args.lora_alpha,\n            lora_dropout=args.lora_dropout\n        )\n    elif args.peft == \"prefix-tuning\":\n        peft_config = PrefixTuningConfig(\n            task_type=TaskType.CAUSAL_LM,\n            inference_mode=False,\n            num_virtual_tokens=args.p_tokens,\n            encoder_hidden_size=args.p_hidden)\n    elif args.peft == \"p-tuning\":\n        peft_config = PromptEncoderConfig(\n            task_type=TaskType.CAUSAL_LM,\n            num_virtual_tokens=args.p_tokens,\n            encoder_hidden_size=args.p_hidden)\n    elif args.peft == \"ia3\":\n        peft_config = IA3Config(\n            peft_type=\"IA3\",\n            task_type=TaskType.CAUSAL_LM,\n            target_modules=[\"k_proj\", \"v_proj\", \"down_proj\"],\n            feedforward_modules=[\"down_proj\"],\n        )\n\n    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\n    model = AutoModelForCausalLM.from_pretrained(args.model_name, token=access_token, quantization_config=bnb_config,\n                                                 trust_remote_code=args.trust_remote_code, cache_dir=args.cache_path,\n                                                 torch_dtype=torch_dtype, config=config, **kwargs)\n\n    if use_flash_attention:\n        from ft_llms.llama_patch import llama_forward_with_flash_attn\n\n        assert model.model.layers[\n                   0].self_attn.forward.__doc__ == llama_forward_with_flash_attn.__doc__, \"Model is not using flash attention\"\n\n    if not args.disable_peft:\n        logger.info(\"Using PEFT...\")\n        if args.use_int4 or args.use_int8:\n            logger.info(\"Preparing model for kbit training...\")\n            model = prepare_model_for_kbit_training(model)\n            if use_flash_attention:\n                from ft_llms.llama_patch import upcast_layer_for_flash_attention\n                logger.info(\"Upcasting flash attention layers...\")\n                model = 
\n\n    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\n    model = AutoModelForCausalLM.from_pretrained(args.model_name, token=access_token, quantization_config=bnb_config,\n                                                 trust_remote_code=args.trust_remote_code, cache_dir=args.cache_path,\n                                                 torch_dtype=torch_dtype, config=config, **kwargs)\n\n    if use_flash_attention:\n        from ft_llms.llama_patch import llama_forward_with_flash_attn\n        assert model.model.layers[0].self_attn.forward.__doc__ == llama_forward_with_flash_attn.__doc__, \"Model is not using flash attention\"\n\n    if not args.disable_peft:\n        logger.info(\"Using PEFT...\")\n        if args.use_int4 or args.use_int8:\n            logger.info(\"Preparing model for kbit training...\")\n            model = prepare_model_for_kbit_training(model)\n            if use_flash_attention:\n                from ft_llms.llama_patch import upcast_layer_for_flash_attention\n                logger.info(\"Upcasting flash attention layers...\")\n                model = upcast_layer_for_flash_attention(model, torch_dtype)\n        logger.info(\"Getting PEFT model...\")\n        model = get_peft_model(model, peft_config)\n    else:\n        logger.info(\"Using Full Finetuning\")\n\n    print_trainable_parameters(model)\n\n    if args.refer_data_source is not None:\n        args.model_name = args.refer_data_source\n    # Map a machine-local snapshot path back to its hub name so cache paths stay portable.\n    if args.model_name == \"/mnt/data0/fuwenjie/MIA-LLMs/cache/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348\":\n        args.model_name = \"decapoda-research/llama-7b-hf\"\n    with accelerator.main_process_first():\n        train_dataset, valid_dataset = dataset_prepare(args, tokenizer=tokenizer)\n        if args.refer:\n            # Reference-model fine-tuning: swap in the self-prompt dataset produced by refer_data_generate.py.\n            refer_data_path = f\"{args.cache_path}/{args.dataset_name}/{args.dataset_config_name}/refer@{args.model_name}/\"\n            train_dataset = load_from_disk(refer_data_path)\n        train_dataset = Dataset.from_dict(train_dataset[args.train_sta_idx:args.train_end_idx])\n        valid_dataset = Dataset.from_dict(valid_dataset[args.eval_sta_idx:args.eval_end_idx])\n\n    logger.info(f\"Training with {accelerator.num_processes} GPUs\")\n    training_args = TrainingArguments(\n        do_train=True,\n        do_eval=True,\n        output_dir=args.output_dir,\n        dataloader_drop_last=True,\n        evaluation_strategy=\"steps\",\n        save_strategy=\"steps\",\n        logging_strategy=\"steps\",\n        num_train_epochs=args.epochs,\n        eval_steps=args.eval_steps,\n        save_steps=args.save_epochs,\n        logging_steps=args.log_steps,\n        per_device_train_batch_size=args.batch_size,\n        per_device_eval_batch_size=args.batch_size * 2,\n        optim=optimizer,\n        learning_rate=args.learning_rate,\n        lr_scheduler_type=args.lr_scheduler_type,\n        warmup_steps=args.warmup_steps,\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        gradient_checkpointing=args.gradient_checkpointing,\n        weight_decay=args.weight_decay,\n        adam_epsilon=1e-6,\n        report_to=\"wandb\",\n        load_best_model_at_end=False,\n        save_total_limit=args.save_limit,\n        bf16=torch.cuda.is_bf16_supported(),\n        fp16=not torch.cuda.is_bf16_supported(),\n    )\n\n    # Build the supervised fine-tuning trainer and train.\n    trainer = SFTTrainer(\n        model=model,\n        args=training_args,\n        train_dataset=train_dataset,\n        eval_dataset=valid_dataset,\n        dataset_text_field=\"text\",\n        tokenizer=tokenizer,\n    )\n    trainer.train()
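\n\n    # Optional (sketch, not in the original script): an explicit final save in addition\n    # to the periodic checkpoints, e.g. trainer.save_model(args.output_dir)."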
  },
  {
    "path": "ft_llms/refer_data_generate.py",
    "content": "import os\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\nfrom torch.utils.data import DataLoader\nimport logging\nimport random\nimport sys\nhere = os.path.dirname(__file__)\nsys.path.append(os.path.join(here, '..'))\nfrom data.prepare import dataset_prepare\nfrom attack.utils import Dict\nimport argparse\nimport yaml\nimport datasets\nfrom datasets import Image, Dataset, load_from_disk, concatenate_datasets\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nimport trl\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig, TrainingArguments, AutoConfig, LlamaTokenizer\n\nimport os\nos.environ['HTTP_PROXY'] = 'http://fuwenjie:19990621f@localhost:7890'\nos.environ['HTTPS_PROXY'] = 'http://fuwenjie:19990621f@localhost:7890'\n\n# Load config file\naccelerator = Accelerator()\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"-m\", \"--model_name\", type=str, default=\"meta-llama/Llama-2-7b-hf\")\nparser.add_argument(\"-tm\", \"--target_model\", type=str, default=\"meta-llama/Llama-2-7b-hf\")\nparser.add_argument(\"-d\", \"--dataset_name\", type=str, default=\"wikitext-2-raw-v1\")\nparser.add_argument(\"-dc\", \"--dataset_config_name\", type=str, default=None,\n                    help=\"The configuration name of the dataset to use (via the datasets library).\")\nparser.add_argument(\"--cache_path\", type=str, default=\"./cache\")\nparser.add_argument(\"--use_dataset_cache\", action=\"store_true\", default=True)\nparser.add_argument(\"--packing\", action=\"store_true\", default=True)\nparser.add_argument(\"--block_size\", type=int, default=128)\nparser.add_argument(\"--preprocessing_num_workers\", type=int, default=1)\nparser.add_argument(\"--validation_split_percentage\", default=0.1,\n                    help=\"The percentage of the train set used as validation set in case there's no validation split\")\ncfg = parser.parse_args()\n\nprint(accelerator.device)\n\nconfig = AutoConfig.from_pretrained(cfg.model_name)\nconfig.use_cache = False\nbnb_config = None\ntorch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\nmodel = AutoModelForCausalLM.from_pretrained(cfg.target_model, quantization_config=bnb_config,\n                                                    torch_dtype=torch_dtype,\n                                                    local_files_only=True,\n                                                    config=config,\n                                                    cache_dir=cfg.cache_path)\nmodel_type = config.to_dict()[\"model_type\"]\nif model_type == \"llama\":\n    tokenizer = LlamaTokenizer.from_pretrained(cfg.model_name)\nelse:\n    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)\nif tokenizer.pad_token_id is None:\n    print(\"Pad token id is None, setting to eos token id...\")\n    tokenizer.pad_token_id = tokenizer.eos_token_id\n# Load datasets\ntrain_dataset, valid_dataset = dataset_prepare(cfg, tokenizer=tokenizer)\nprompt_dataset = Dataset.from_dict(train_dataset[10000:20000])\nprompt_dataloader = DataLoader(prompt_dataset, batch_size=1)\n\nmodel, prompt_dataloader = accelerator.prepare(model, prompt_dataloader)\n\ngenerated_dataset = {\"text\": []}\n\nfor text in tqdm(prompt_dataloader):\n    prompt = (text[\"text\"])\n    input_ids = tokenizer(prompt, return_tensors=\"pt\", padding=True).input_ids.to(accelerator.device)\n    clipped_ids = input_ids[:, :16]\n    if hasattr(model, \"module\"):\n        gen_tokens = 
\nfor text in tqdm(prompt_dataloader):\n    prompt = text[\"text\"]\n    input_ids = tokenizer(prompt, return_tensors=\"pt\", padding=True).input_ids.to(accelerator.device)\n    clipped_ids = input_ids[:, :16]\n    # Under DDP, accelerator.prepare wraps the model; unwrap it to reach .generate().\n    generator = model.module if hasattr(model, \"module\") else model\n    gen_tokens = generator.generate(\n        clipped_ids,\n        num_beams=1,\n        do_sample=True,\n        max_length=input_ids.size(-1),\n    )\n    if model_type == \"llama\":\n        # Drop the BOS token that LLaMA prepends to generations.\n        gen_tokens = gen_tokens[:, 1:]\n    # print(model(gen_tokens, labels=gen_tokens).loss)  # debug: loss on generated text\n    gen_text = tokenizer.batch_decode(gen_tokens)\n    generated_dataset[\"text\"].extend(gen_text)\n\ngenerated_dataset = Dataset.from_dict(generated_dataset)\n# Map a machine-local snapshot path back to its hub name so the save path stays portable.\nif cfg.model_name == \"/mnt/data0/fuwenjie/MIA-LLMs/cache/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348\":\n    cfg.model_name = \"decapoda-research/llama-7b-hf\"\nsave_dir = f\"{cfg.cache_path}/{cfg.dataset_name}/{cfg.dataset_config_name}/refer@{cfg.model_name}/\"\n# Each process writes its own shard, keyed by its device.\ngenerated_dataset.save_to_disk(save_dir + f\"{accelerator.device}\")\n\naccelerator.wait_for_everyone()\n\n# The main process merges all per-device shards into the final reference dataset.\nif accelerator.is_main_process:\n    concatenated_dataset = None\n    for sub_dir in os.listdir(save_dir):\n        data_path = os.path.join(save_dir, sub_dir)\n        if os.path.isdir(data_path):\n            if concatenated_dataset is None:\n                concatenated_dataset = load_from_disk(data_path)\n            else:\n                dataset = load_from_disk(data_path)\n                concatenated_dataset = concatenate_datasets([concatenated_dataset, dataset])\n    concatenated_dataset.save_to_disk(save_dir)"
  },
  {
    "path": "ft_llms/run_llama.sh",
    "content": "\n# ag_news\naccelerate launch ./ft_llms/llms_finetune.py \\\n--output_dir ./ft_llms/llama/ag_news/target/ \\\n--block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \\\n-d ag_news -m decapoda-research/llama-7b-hf --packing --use_dataset_cache \\\n-e 10 -b 4 -lr 1e-4 --gradient_accumulation_steps 1 \\\n--train_sta_idx=0 --train_end_idx=10000 --eval_sta_idx=0 --eval_end_idx=1000\n\n# refer candidate\naccelerate launch ./ft_llms/llms_finetune.py \\\n--output_dir ./ft_llms/llama/ag_news/candidate/ \\\n--block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \\\n-d JulesBelveze/tldr_news -m decapoda-research/llama-7b-hf --packing --use_dataset_cache \\\n-e 10 -b 4 -lr 1e-4 --gradient_accumulation_steps 1 \\\n--train_sta_idx=0 --train_end_idx=4767 --eval_sta_idx=0 --eval_end_idx=538\n\n# refer oracle\naccelerate launch ./ft_llms/llms_finetune.py \\\n--output_dir ./ft_llms/llama/ag_news/oracle/ \\\n--block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \\\n-d ag_news -m decapoda-research/llama-7b-hf --packing --use_dataset_cache \\\n-e 10 -b 4 -lr 1e-4 --gradient_accumulation_steps 1 \\\n--train_sta_idx=10000 --train_end_idx=20000 --eval_sta_idx=1000 --eval_end_idx=2000\n\naccelerate launch refer_data_generate.py \\\n-tm /mnt/data0/fuwenjie/MIA-LLMs/ft_llms/llama/ag_news/target/checkpoint-3000 \\\n-m decapoda-research/llama-7b-hf -d ag_news\n\n# refer prompt\naccelerate launch ./ft_llms/llms_finetune.py --refer \\\n--output_dir ./ft_llms/llama/ag_news/refer/ \\\n--block_size 128 --eval_steps 100 --save_epochs 100 --log_steps 100 \\\n-d ag_news -m decapoda-research/llama-7b-hf --packing --use_dataset_cache \\\n-e 2 -b 4 -lr 5e-5 --gradient_accumulation_steps 1 \\\n--train_sta_idx=0 --train_end_idx=10000 --eval_sta_idx=0 --eval_end_idx=1000"
  },
  {
    "path": "ft_llms/utils.py",
    "content": "import logging\nfrom typing_extensions import Literal\nfrom rich.logging import RichHandler\nfrom torch.utils.data import IterableDataset\nimport warnings\nimport random\nimport torch\n\n\ndef get_logger(name: str, level: Literal[\"info\", \"warning\", \"debug\"]) -> logging.Logger:\n    rich_handler = RichHandler(level=logging.INFO, rich_tracebacks=True, markup=True)\n\n    logger = logging.getLogger(name)\n    logger.setLevel(logging._nameToLevel[level.upper()])\n\n    if not logger.handlers:\n        logger.addHandler(rich_handler)\n\n    logger.propagate = False\n\n    return logger\n\ndef print_trainable_parameters(model):\n    \"\"\"\n    Prints the number of trainable parameters in the model.\n    \"\"\"\n    trainable_params = 0\n    all_param = 0\n    for _, param in model.named_parameters():\n        all_param += param.numel()\n        if param.requires_grad:\n            trainable_params += param.numel()\n    print(\n        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n    )\n\ndef constantlengthdatasetiter(self):\n    iterator = iter(self.dataset)\n    more_examples = True\n    while more_examples:\n        buffer, buffer_len = [], 0\n        while True:\n            if buffer_len >= self.max_buffer_size:\n                break\n            try:\n                buffer.append(self.formatting_func(next(iterator)))\n                buffer_len += len(buffer[-1])\n            except StopIteration:\n                if self.infinite:\n                    iterator = iter(self.dataset)\n                    warnings.warn(\"The dataset reached end and the iterator is reset to the start.\")\n                    break\n                else:\n                    more_examples = False\n                    break\n        tokenized_inputs = self.tokenizer(buffer, truncation=False)[\"input_ids\"]\n        all_token_ids = []\n        for tokenized_input in tokenized_inputs:\n            all_token_ids.extend(tokenized_input + [self.concat_token_id])\n        examples = []\n        for i in range(0, len(all_token_ids), self.seq_length):\n            input_ids = all_token_ids[i : i + self.seq_length]\n            if len(input_ids) == self.seq_length:\n                examples.append(input_ids)\n        if self.shuffle:\n            random.shuffle(examples)\n        for example in examples:\n            self.current_size += 1\n            yield {\n                \"input_ids\": torch.LongTensor(example),\n                \"labels\": torch.LongTensor(example),\n            }"
  },
  {
    "path": "requirements.txt",
    "content": "accelerate==0.23.0\ndatasets==2.14.5\ndeepspeed==0.10.1+46d859a7\ndeepspeed_mii==0.0.7+0acf569\nflash_attn==2.2.1\nhuggingface_hub==0.16.4\nmatplotlib==3.5.3\nnlpaug==1.1.11\nnltk==3.4.5\nnumpy==1.24.4\nopacus==1.4.0\nopenai==1.3.5\npandas==2.0.3\npeft==0.6.0.dev0\npython_dateutil==2.8.2\npyvacy==0.0.32\nPyYAML==6.0.1\nPyYAML==6.0.1\nrich==13.7.0\nscikit_learn==1.3.0\nscipy==1.11.4\nseaborn==0.13.0\nspacy==3.7.1\ntqdm==4.66.1\ntransformers==4.34.0.dev0\ntrl==0.7.1\ntyping_extensions==4.8.0\n"
  }
]