[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# images\n*.png\n*.jpg\n*.pdf\n\n# experiment data\n*.out\n*.csv\n*.pt\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"stable-diffusion\"]\n\tpath = stable-diffusion\n\turl = https://github.com/CompVis/stable-diffusion.git\n[submodule \"erasing\"]\n\tpath = erasing\n\turl = https://github.com/brandontrabucco/erasing.git\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Brandon Trabucco\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# Effective Data Augmentation With Diffusion Models\n\n[![Watch Effective Data Augmentation With Diffusion Models On YouTube](images/play-da-fusion.png)](https://www.youtube.com/watch?v=IKDWOOWzwns)\n\nWatch our talk for a quick introduction!\n\nData augmentation is one of the most prevalent tools in deep learning, underpinning many recent advances. The standard approach to data augmentation combines simple transformations like rotations and flips to generate new images from existing ones. However, current augmentations cannot alter the high-level semantic attributes, such as animal species present in a scene, to enhance the diversity of data. We improve diversity in data augmentation with image-to-image transformations parameterized by pre-trained text-to-image diffusion models. Our method edits images using an off-the-shelf diffusion model, and generalizes to novel visual concepts from a few labelled examples.\n\n[ICLR 2024 Manuscript](https://openreview.net/forum?id=ZWzUA9zeAg)    |    [Site](btrabuc.co/da-fusion)    |    [Leafy Spurge Dataset](leafy-spurge-dataset.github.io)\n\n## Installation\n\nTo install the package, first create a `conda` environment.\n\n```bash\nconda create -n da-fusion python=3.7 pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.6 -c pytorch\nconda activate da-fusion\npip install diffusers[\"torch\"] transformers pycocotools pandas matplotlib seaborn scipy\n```\n\nThen download and install the source code.\n\n```bash\ngit clone git@github.com:brandontrabucco/da-fusion.git\npip install -e da-fusion\n```\n\n## Datasets\n\nWe benchmark DA-Fusion on few-shot image classification problems, including a Leafy Spurge weed recognition task, and classification tasks derived from COCO and PASCAL VOC. For the latter two, we label images with the classes corresponding to the largest object in the image.\n\nCustom datasets can be evaluated by implementing subclasses of `semantic_aug/few_shot_dataset.py`.\n\n## Setting Up PASCAL VOC\n\nData for the PASCAL VOC task is adapted from the [2012 PASCAL VOC Challenge](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar). Once this dataset has been downloaded and extracted, the PASCAL dataset class `semantic_aug/datasets/pascal.py` should be pointed to the downloaded dataset via the `PASCAL_DIR` config variable located [here](https://github.com/brandontrabucco/da-fusion/blob/main/semantic_aug/datasets/pascal.py#L14).\n\nEnsure that `PASCAL_DIR` points to a folder containing `ImageSets`, `JPEGImages`, `SegmentationClass`, and `SegmentationObject` subfolders.\n\n## Setting Up COCO\n\nTo setup COCO, first download the [2017 Training Images](http://images.cocodataset.org/zips/train2017.zip), the [2017 Validation Images](http://images.cocodataset.org/zips/val2017.zip), and the [2017 Train/Val Annotations](http://images.cocodataset.org/annotations/annotations_trainval2017.zip). These files should be unzipped into the following directory structure.\n\n```\ncoco2017/\n    train2017/\n    val2017/\n    annotations/\n```\n\n`COCO_DIR` located [here](https://github.com/brandontrabucco/da-fusion/blob/main/semantic_aug/datasets/coco.py#L15) should be updated to point to the location of `coco2017` on your system.\n\n## Setting Up The Spurge Dataset\n\nWe are planning to release this dataset in the next few months. Check back for updates!\n\n## Fine-Tuning Tokens\n\nWe perform textual inversion (https://arxiv.org/abs/2208.01618) to adapt Stable Diffusion to the classes present in our few-shot datasets. 
## Setting Up PASCAL VOC\n\nData for the PASCAL VOC task is adapted from the [2012 PASCAL VOC Challenge](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar). Once this dataset has been downloaded and extracted, the PASCAL dataset class in `semantic_aug/datasets/pascal.py` should be pointed to the downloaded dataset via the `PASCAL_DIR` config variable located [here](https://github.com/brandontrabucco/da-fusion/blob/main/semantic_aug/datasets/pascal.py#L14).\n\nEnsure that `PASCAL_DIR` points to a folder containing `ImageSets`, `JPEGImages`, `SegmentationClass`, and `SegmentationObject` subfolders.\n\n## Setting Up COCO\n\nTo set up COCO, first download the [2017 Training Images](http://images.cocodataset.org/zips/train2017.zip), the [2017 Validation Images](http://images.cocodataset.org/zips/val2017.zip), and the [2017 Train/Val Annotations](http://images.cocodataset.org/annotations/annotations_trainval2017.zip). These files should be unzipped into the following directory structure.\n\n```\ncoco2017/\n    train2017/\n    val2017/\n    annotations/\n```\n\n`COCO_DIR` located [here](https://github.com/brandontrabucco/da-fusion/blob/main/semantic_aug/datasets/coco.py#L15) should be updated to point to the location of `coco2017` on your system.\n\n## Setting Up The Spurge Dataset\n\nWe are planning to release this dataset in the next few months. Check back for updates!\n\n## Fine-Tuning Tokens\n\nWe perform [Textual Inversion](https://arxiv.org/abs/2208.01618) to adapt Stable Diffusion to the classes present in our few-shot datasets. The implementation in `fine_tune.py` is adapted from the [Diffusers](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py) example.\n\nWe wrap this script for distributing experiments on a slurm cluster in a set of `sbatch` scripts located at `scripts/fine_tuning`. These scripts perform multiple runs of Textual Inversion in parallel, subject to the number of available nodes on your slurm cluster.\n\nIf `sbatch` is not available on your system, you can run these scripts with `bash` and manually set `SLURM_ARRAY_TASK_ID` and `SLURM_ARRAY_TASK_COUNT` for each parallel job (slurm normally sets these automatically to the job index and the number of jobs, respectively; for a single job, set them to 0 and 1). Once all runs have finished, `aggregate_embeddings.py` merges the per-class `learned_embeds.bin` files into one token file per dataset, seed, and examples-per-class configuration.\n\n## Few-Shot Classification\n\nCode for training image classification models using augmented images from DA-Fusion is located in `train_classifier.py`. This script accepts a number of arguments that control how the classifier is trained:\n\n```bash\npython train_classifier.py --logdir pascal-baselines/textual-inversion-0.5 \\\n--synthetic-dir \"aug/textual-inversion-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 1 --examples-per-class 4\n```\n\nThis example trains a classifier on the PASCAL VOC task, with 4 images per class, using the prompt `\"a photo of a ClassX\"`, where the special token `ClassX` is learned from scratch with Textual Inversion. Slurm scripts that reproduce the paper are located in `scripts/textual_inversion`. Results are logged to `.csv` files in the directory given by the `--logdir` argument.\n\nWe used a [custom plotting script](https://github.com/brandontrabucco/da-fusion/blob/main/plot.py) to generate the figures in the main paper.\n\n## Citation\n\nIf you find our method helpful, consider citing our preprint!\n\n```\n@misc{https://doi.org/10.48550/arxiv.2302.07944,\n  doi = {10.48550/ARXIV.2302.07944},\n  url = {https://arxiv.org/abs/2302.07944},\n  author = {Trabucco, Brandon and Doherty, Kyle and Gurinas, Max and Salakhutdinov, Ruslan},\n  keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},\n  title = {Effective Data Augmentation With Diffusion Models},\n  publisher = {arXiv},\n  year = {2023},\n  copyright = {arXiv.org perpetual, non-exclusive license}\n}\n```\n"
  },
  {
    "path": "aggregate_embeddings.py",
    "content": "import torch\nimport os\nimport glob\nimport argparse\nfrom itertools import product\nfrom tqdm import trange\n\n\nDEFAULT_EMBED_PATH = \"{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt\"\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\"Merge token files\")\n\n    parser.add_argument(\"--num-trials\", type=int, default=8)\n    parser.add_argument(\"--examples-per-class\", nargs='+', type=int, default=[1, 2, 4, 8, 16])\n    \n    parser.add_argument(\"--embed-path\", type=str, default=DEFAULT_EMBED_PATH)\n    parser.add_argument(\"--input-path\", type=str, default=\"./fine-tuned\")\n    \n    parser.add_argument(\"--dataset\", type=str, default=\"pascal\", \n                        choices=[\"spurge\", \"imagenet\", \"coco\", \"pascal\"])\n\n    args = parser.parse_args()\n\n    for seed, examples_per_class in product(\n            range(args.num_trials), args.examples_per_class):\n\n        path = os.path.join(args.input_path, (\n            f\"{args.dataset}-{seed}-{examples_per_class}/*/learned_embeds.bin\"))\n\n        merged_dict = dict()\n        for file in glob.glob(path):\n            merged_dict.update(torch.load(file))\n\n        target_path = args.embed_path.format(\n            dataset=args.dataset, seed=seed, \n            examples_per_class=examples_per_class)\n\n        os.makedirs(os.path.dirname(target_path), exist_ok=True)\n        torch.save(merged_dict, target_path)"
  },
  {
    "path": "fine_tune.py",
    "content": "import argparse\nimport logging\nimport math\nimport os\nimport gc\nimport shutil\nimport random\nfrom pathlib import Path\nfrom typing import Optional\nfrom itertools import product\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch.utils.data import Dataset\n\nfrom semantic_aug.datasets.coco import COCODataset\nfrom semantic_aug.datasets.spurge import SpurgeDataset\nfrom semantic_aug.datasets.imagenet import ImageNetDataset\nfrom semantic_aug.datasets.pascal import PASCALDataset\n\nimport datasets\nimport diffusers\nimport PIL\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import set_seed\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom huggingface_hub import HfFolder, Repository, whoami\n\n# TODO: remove and import from diffusers.utils when the new version of diffusers is released\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\n\nDATASETS = {\n    \"spurge\": SpurgeDataset, \n    \"coco\": COCODataset, \n    \"pascal\": PASCALDataset,\n    \"imagenet\": ImageNetDataset\n}\n\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.10.0.dev0\")\n\n\nlogger = get_logger(__name__)\n\n\ndef save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):\n    logger.info(\"Saving embeddings\")\n    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]\n    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}\n    torch.save(learned_embeds_dict, save_path)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=500,\n        help=\"Save learned_embeds.bin every X updates steps.\",\n    )\n    parser.add_argument(\n        \"--only_save_embeds\",\n        action=\"store_true\",\n        default=False,\n        help=\"Save only the embeddings for the new concept.\",\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\"--learnable_property\", type=str, default=\"object\", help=\"Choose between 'object' and 'style'\")\n    parser.add_argument(\"--repeats\", type=int, default=100, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"./\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\", action=\"store_true\", help=\"Whether to center crop images before resizing to resolution\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=5000,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.\"\n            \"and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. 
Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n\n    parser.add_argument(\"--num-trials\", type=int, default=8)\n    parser.add_argument(\"--examples-per-class\", nargs='+', type=int, default=[1, 2, 4, 8, 16])\n    \n    parser.add_argument(\"--dataset\", type=str, default=\"coco\", \n                        choices=[\"spurge\", \"imagenet\", \"coco\", \"pascal\"])\n\n    parser.add_argument(\"--unet-ckpt\", type=str, default=None)\n\n    parser.add_argument(\"--erase-concepts\", action=\"store_true\", \n                        help=\"erase text inversion concepts first\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\nimagenet_templates_small = [\n    \"a photo of a {}\",\n    \"a rendering of a {}\",\n    \"a cropped photo of the {}\",\n    \"the photo of a {}\",\n    \"a photo of a clean {}\",\n    \"a photo of a dirty {}\",\n    \"a dark photo of the {}\",\n    \"a photo of my {}\",\n    \"a photo of the cool {}\",\n    \"a close-up photo of a {}\",\n    \"a bright photo of the {}\",\n    \"a cropped photo of a {}\",\n    \"a photo of the {}\",\n    \"a good photo of the {}\",\n    \"a photo of one {}\",\n    \"a close-up photo of the {}\",\n    \"a rendition of the {}\",\n    \"a photo of the clean {}\",\n    \"a rendition of a {}\",\n    \"a photo of a nice {}\",\n    \"a good photo of a {}\",\n    \"a photo of the nice {}\",\n    \"a photo of the small {}\",\n    \"a photo of the weird {}\",\n    \"a photo of the large {}\",\n    \"a photo of a cool {}\",\n    \"a photo of a small {}\",\n]\n\nimagenet_style_templates_small = [\n    \"a painting in the style of {}\",\n    \"a rendering in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"the painting in the style of {}\",\n    \"a clean painting in the style of {}\",\n    \"a dirty painting in the style of {}\",\n    \"a dark painting in the style of {}\",\n    \"a picture in the style of {}\",\n    \"a cool painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a bright painting in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"a good painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a rendition in the style of {}\",\n    \"a nice painting in the style of {}\",\n    \"a small painting in the 
style of {}\",\n    \"a weird painting in the style of {}\",\n    \"a large painting in the style of {}\",\n]\n\n\nclass TextualInversionDataset(Dataset):\n    def __init__(\n        self,\n        data_root,\n        tokenizer,\n        learnable_property=\"object\",  # [object, style]\n        size=512,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"*\",\n        center_crop=False,\n    ):\n        self.data_root = data_root\n        self.tokenizer = tokenizer\n        self.learnable_property = learnable_property\n        self.size = size\n        self.placeholder_token = placeholder_token\n        self.center_crop = center_crop\n        self.flip_p = flip_p\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.interpolation = {\n            \"linear\": PIL_INTERPOLATION[\"linear\"],\n            \"bilinear\": PIL_INTERPOLATION[\"bilinear\"],\n            \"bicubic\": PIL_INTERPOLATION[\"bicubic\"],\n            \"lanczos\": PIL_INTERPOLATION[\"lanczos\"],\n        }[interpolation]\n\n        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else imagenet_templates_small\n        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        text = random.choice(self.templates).format(placeholder_string)\n\n        example[\"input_ids\"] = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids[0]\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            (h, w,) = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n\n        image = Image.fromarray(img)\n        image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip_transform(image)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 1.0).astype(np.float32)\n\n        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n        return example\n\n\ndef get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):\n    if token is None:\n        token = HfFolder.get_token()\n    if organization is None:\n        username = whoami(token)[\"name\"]\n        return f\"{username}/{model_id}\"\n    else:\n        return f\"{organization}/{model_id}\"\n\n\ndef main(args):\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        
log_with=args.report_to,\n        logging_dir=logging_dir,\n    )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.push_to_hub:\n            if args.hub_model_id is None:\n                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)\n            else:\n                repo_name = args.hub_model_id\n            repo = Repository(args.output_dir, clone_from=repo_name)\n\n            with open(os.path.join(args.output_dir, \".gitignore\"), \"w+\") as gitignore:\n                if \"step_*\" not in gitignore:\n                    gitignore.write(\"step_*\\n\")\n                if \"epoch_*\" not in gitignore:\n                    gitignore.write(\"epoch_*\\n\")\n        elif args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    # Load tokenizer\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n    )\n\n    if args.unet_ckpt is not None:\n        unet.load_state_dict(torch.load(args.unet_ckpt))\n        print(f\"Loaded UNET from {args.unet_ckpt}\")\n\n    # Add the placeholder token in tokenizer\n    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)\n    if num_added_tokens == 0:\n        raise ValueError(\n            f\"The tokenizer already contains the token {args.placeholder_token}. 
Please pass a different\"\n            \" `placeholder_token` that is not already in the tokenizer.\"\n        )\n\n    # Convert the initializer_token, placeholder_token to ids\n    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)\n    # Check if initializer_token is a single token or a sequence of tokens\n    if len(token_ids) > 1:\n        raise ValueError(\"The initializer token must be a single token.\")\n\n    initializer_token_id = token_ids[0]\n    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)\n\n    # Resize the token embeddings as we are adding new special tokens to the tokenizer\n    text_encoder.resize_token_embeddings(len(tokenizer))\n\n    # Initialise the newly added placeholder token with the embeddings of the initializer token\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]\n\n    # Freeze vae and unet\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n    # Freeze all parameters except for the token embeddings in text encoder\n    text_encoder.text_model.encoder.requires_grad_(False)\n    text_encoder.text_model.final_layer_norm.requires_grad_(False)\n    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)\n\n    if args.gradient_checkpointing:\n        # Keep unet in train mode if we are using gradient checkpointing to save memory.\n        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.\n        unet.train()\n        text_encoder.gradient_checkpointing_enable()\n        unet.enable_gradient_checkpointing()\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    optimizer = torch.optim.AdamW(\n        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = TextualInversionDataset(\n        data_root=args.train_data_dir,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        placeholder_token=args.placeholder_token,\n        repeats=args.repeats,\n        learnable_property=args.learnable_property,\n        center_crop=args.center_crop,\n        set=\"train\",\n    )\n    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n    # Prepare everything with our `accelerator`.\n    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        text_encoder, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and unet to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"textual_inversion\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * 
accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1]\n        accelerator.print(f\"Resuming from checkpoint {path}\")\n        accelerator.load_state(os.path.join(args.output_dir, path))\n        global_step = int(path.split(\"-\")[1])\n\n        resume_global_step = global_step * args.gradient_accumulation_steps\n        first_epoch = resume_global_step // num_update_steps_per_epoch\n        resume_step = resume_global_step % num_update_steps_per_epoch\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    # keep original embeddings as reference\n    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        text_encoder.train()\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(text_encoder):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample().detach()\n                latents = latents * 0.18215\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0].to(dtype=weight_dtype)\n\n                # Predict the noise residual\n                model_pred = unet(noisy_latents, timesteps, 
encoder_hidden_states).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                # Let's make sure we don't update any embedding weights besides the newly added token\n                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id\n                with torch.no_grad():\n                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[\n                        index_no_updates\n                    ] = orig_embeds_params[index_no_updates]\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                if global_step % args.save_steps == 0:\n                    save_path = os.path.join(args.output_dir, f\"learned_embeds-steps-{global_step}.bin\")\n                    save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    accelerator.wait_for_everyone()\n\n    if accelerator.is_main_process:\n        # Save the newly trained embeddings\n        save_path = os.path.join(args.output_dir, \"learned_embeds.bin\")\n        save_progress(text_encoder, placeholder_token_id, \n                      accelerator, args, save_path)\n\n    accelerator.end_training()\n    accelerator.free_memory()\n\n    del accelerator, vae, unet, text_encoder\n\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\nif __name__ == \"__main__\":\n\n    args = parse_args()\n    output_dir = args.output_dir\n\n    rank = int(os.environ.pop(\"RANK\", 0))\n    world_size = int(os.environ.pop(\"WORLD_SIZE\", 1))\n\n    device_id = rank % torch.cuda.device_count()\n    torch.cuda.set_device(rank % torch.cuda.device_count())\n\n    print(f'Initialized process {rank} / {world_size}')\n\n    options = product(range(args.num_trials), args.examples_per_class)\n    options = np.array(list(options))\n    options = np.array_split(options, world_size)[rank]\n\n    for seed, examples_per_class in options.tolist():\n\n        os.makedirs(os.path.join(output_dir, \"extracted\"), exist_ok=True)\n\n        train_dataset = DATASETS[\n            args.dataset](split=\"train\", seed=seed, \n                          examples_per_class=examples_per_class)\n\n        for idx in range(len(train_dataset)):\n\n            image = train_dataset.get_image_by_idx(idx)\n            metadata = train_dataset.get_metadata_by_idx(idx)\n\n            name = metadata[\"name\"].replace(\" \", \"_\")\n            path = 
f\"{args.dataset}-{seed}-{examples_per_class}\"\n\n            path = os.path.join(output_dir, \"extracted\", path, name, f\"{idx}.png\")\n            os.makedirs(os.path.dirname(path), exist_ok=True)\n\n            image.save(path)\n\n        for class_name in train_dataset.class_names:\n\n            formatted_name = class_name.replace(\" \", \"_\")\n            dirname = f\"{args.dataset}-{seed}-{examples_per_class}/{formatted_name}\"\n\n            args = parse_args()\n            \n            args.seed = seed\n\n            args.placeholder_token = f\"<{formatted_name}>\"\n            args.initializer_token = \"the\"\n\n            args.train_data_dir = os.path.join(\n                output_dir, \"extracted\", dirname)\n            args.output_dir = os.path.join(\n                output_dir, \"fine-tuned\", dirname)\n\n            word_name = class_name.replace(\" \", \"\")\n\n            if args.erase_concepts: args.unet_ckpt = (\n                \"/projects/rsalakhugroup/btrabucc/esd-models/\" + \n                f\"compvis-word_{word_name}-method_full-sg_3-ng_1-iter_1000-lr_1e-05/\" + \n                f\"diffusers-word_{word_name}-method_full-sg_3-ng_1-iter_1000-lr_1e-05.pt\")\n\n            main(args)\n\n            shutil.rmtree(args.train_data_dir)\n"
  },
  {
    "path": "fine_tune_upstream.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n\nimport argparse\nimport logging\nimport math\nimport shutil\nimport os\nimport gc\nimport random\nimport shutil\nimport warnings\nfrom pathlib import Path\nfrom itertools import product\n\nimport numpy as np\nimport PIL\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\n\n# TODO: remove and import from diffusers.utils when the new version of diffusers is released\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\nfrom semantic_aug.datasets.coco import COCODataset\nfrom semantic_aug.datasets.spurge import SpurgeDataset\nfrom semantic_aug.datasets.imagenet import ImageNetDataset\nfrom semantic_aug.datasets.pascal import PASCALDataset\nfrom semantic_aug.datasets.caltech101 import CalTech101Dataset\nfrom semantic_aug.datasets.flowers102 import Flowers102Dataset\n\n\nDATASETS = {\n    \"spurge\": SpurgeDataset, \n    \"coco\": COCODataset, \n    \"pascal\": PASCALDataset,\n    \"imagenet\": ImageNetDataset,\n    \"caltech\": CalTech101Dataset,\n    \"flowers\": Flowers102Dataset\n}\n\n\nif is_wandb_available():\n    import wandb\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.20.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {base_model}\ntags:\n- stable-diffusion\n- stable-diffusion-diffusers\n- text-to-image\n- diffusers\n- textual_inversion\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# Textual inversion text2image fine-tuning - {repo_id}\nThese are textual inversion adaption weights for {base_model}. You can find some example images in the following. \\n\n{img_str}\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    # create pipeline (note: unet and vae are loaded again in float32)\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        text_encoder=accelerator.unwrap_model(text_encoder),\n        tokenizer=tokenizer,\n        unet=unet,\n        vae=vae,\n        safety_checker=None,\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)\n    images = []\n    for _ in range(args.num_validation_images):\n        with torch.autocast(\"cuda\"):\n            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]\n        images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    torch.cuda.empty_cache()\n    return images\n\n\ndef save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):\n    logger.info(\"Saving embeddings\")\n    learned_embeds = (\n        accelerator.unwrap_model(text_encoder)\n        .get_input_embeddings()\n        .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]\n    )\n    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}\n    torch.save(learned_embeds_dict, save_path)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=500,\n        help=\"Save learned_embeds.bin every X updates steps.\",\n    )\n    parser.add_argument(\n     
   \"--save_as_full_pipeline\",\n        action=\"store_true\",\n        help=\"Save the complete stable diffusion pipeline.\",\n    )\n    parser.add_argument(\n        \"--num_vectors\",\n        type=int,\n        default=1,\n        help=\"How many textual inversion vectors shall be used to learn the concept.\",\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--placeholder_token\",\n        type=str,\n        default=None,\n        help=\"A token to use as a placeholder for the concept.\",\n    )\n    parser.add_argument(\n        \"--initializer_token\", type=str, default=None, required=True, help=\"A token to use as initializer word.\"\n    )\n    parser.add_argument(\"--learnable_property\", type=str, default=\"object\", help=\"Choose between 'object' and 'style'\")\n    parser.add_argument(\"--repeats\", type=int, default=100, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\", action=\"store_true\", help=\"Whether to center crop images before resizing to resolution.\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=5000,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10.\"\n            \"and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=None,\n        help=(\n            \"Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n\n    parser.add_argument(\"--num-trials\", type=int, default=8)\n    parser.add_argument(\"--examples-per-class\", nargs='+', type=int, default=[1, 2, 4, 8, 16])\n    \n    parser.add_argument(\"--dataset\", type=str, default=\"coco\", choices=DATASETS.keys())\n\n    parser.add_argument(\"--unet-ckpt\", type=str, default=None)\n\n    parser.add_argument(\"--erase-concepts\", action=\"store_true\", \n                        help=\"erase text inversion concepts first\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\nimagenet_templates_small = [\n    \"a photo of a {}\",\n    \"a rendering of a {}\",\n    \"a cropped photo of the {}\",\n    \"the photo of a {}\",\n    \"a photo of a clean {}\",\n    \"a photo of a dirty {}\",\n    \"a dark photo of the {}\",\n    \"a photo of my {}\",\n    \"a photo of the cool {}\",\n    \"a close-up photo of a {}\",\n    \"a bright photo of the {}\",\n    \"a cropped photo of a {}\",\n    \"a photo of the {}\",\n    \"a good photo of the {}\",\n    \"a photo of one {}\",\n    \"a close-up photo of the {}\",\n    \"a rendition of the {}\",\n    \"a photo of the clean {}\",\n    \"a rendition of a {}\",\n    \"a photo of a nice {}\",\n    \"a good photo of a {}\",\n    \"a photo of the nice {}\",\n    \"a photo of the small {}\",\n    \"a photo of the weird {}\",\n    \"a photo of the large {}\",\n    \"a photo of a cool {}\",\n    \"a photo of a small {}\",\n]\n\nimagenet_style_templates_small = [\n    \"a painting in the style of {}\",\n    \"a rendering in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"the painting in the style of {}\",\n    \"a clean painting in the style of {}\",\n    \"a dirty painting in the style of {}\",\n    \"a dark painting in the style of {}\",\n    \"a picture in the style of {}\",\n    \"a cool painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a bright painting in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"a good painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a rendition in the style of {}\",\n    \"a nice painting in the style of {}\",\n    \"a small painting in the style of {}\",\n    \"a weird painting in the style of {}\",\n    \"a large painting in the style of {}\",\n]\n\n\nclass TextualInversionDataset(Dataset):\n    def __init__(\n        self,\n        data_root,\n        tokenizer,\n        learnable_property=\"object\",  # [object, style]\n        size=512,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"*\",\n        center_crop=False,\n    ):\n        self.data_root = data_root\n        self.tokenizer = tokenizer\n        self.learnable_property = learnable_property\n        self.size = size\n        self.placeholder_token = placeholder_token\n        self.center_crop = center_crop\n        self.flip_p = flip_p\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in 
os.listdir(self.data_root)]\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.interpolation = {\n            \"linear\": PIL_INTERPOLATION[\"linear\"],\n            \"bilinear\": PIL_INTERPOLATION[\"bilinear\"],\n            \"bicubic\": PIL_INTERPOLATION[\"bicubic\"],\n            \"lanczos\": PIL_INTERPOLATION[\"lanczos\"],\n        }[interpolation]\n\n        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else imagenet_templates_small\n        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        text = random.choice(self.templates).format(placeholder_string)\n\n        example[\"input_ids\"] = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids[0]\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            (\n                h,\n                w,\n            ) = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n\n        image = Image.fromarray(img)\n        image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip_transform(image)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 1.0).astype(np.float32)\n\n        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n        return example\n\n\ndef main(args):\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if 
accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load tokenizer\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n    )\n\n    if args.unet_ckpt is not None:\n        unet.load_state_dict(torch.load(args.unet_ckpt))\n        print(f\"Loaded UNET from {args.unet_ckpt}\")\n\n    # Add the placeholder token in tokenizer\n    placeholder_tokens = [args.placeholder_token]\n\n    if args.num_vectors < 1:\n        raise ValueError(f\"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}\")\n\n    # add dummy tokens for multi-vector\n    additional_tokens = []\n    for i in range(1, args.num_vectors):\n        additional_tokens.append(f\"{args.placeholder_token}_{i}\")\n    placeholder_tokens += additional_tokens\n\n    num_added_tokens = tokenizer.add_tokens(placeholder_tokens)\n    if num_added_tokens != args.num_vectors:\n        raise ValueError(\n            f\"The tokenizer already contains the token {args.placeholder_token}. 
Please pass a different\"\n            \" `placeholder_token` that is not already in the tokenizer.\"\n        )\n\n    # Convert the initializer_token, placeholder_token to ids\n    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)\n    # Check if initializer_token is a single token or a sequence of tokens\n    if len(token_ids) > 1:\n        raise ValueError(\"The initializer token must be a single token.\")\n\n    initializer_token_id = token_ids[0]\n    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)\n\n    # Resize the token embeddings as we are adding new special tokens to the tokenizer\n    text_encoder.resize_token_embeddings(len(tokenizer))\n\n    # Initialise the newly added placeholder token with the embeddings of the initializer token\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    with torch.no_grad():\n        for token_id in placeholder_token_ids:\n            token_embeds[token_id] = token_embeds[initializer_token_id].clone()\n\n    # Freeze vae and unet\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n    # Freeze all parameters except for the token embeddings in text encoder\n    text_encoder.text_model.encoder.requires_grad_(False)\n    text_encoder.text_model.final_layer_norm.requires_grad_(False)\n    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)\n\n    if args.gradient_checkpointing:\n        # Keep unet in train mode if we are using gradient checkpointing to save memory.\n        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.\n        unet.train()\n        text_encoder.gradient_checkpointing_enable()\n        unet.enable_gradient_checkpointing()\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warn(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    optimizer = torch.optim.AdamW(\n        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = TextualInversionDataset(\n        data_root=args.train_data_dir,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        placeholder_token=args.placeholder_token,\n        repeats=args.repeats,\n        learnable_property=args.learnable_property,\n        center_crop=args.center_crop,\n        set=\"train\",\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers\n    )\n    if args.validation_epochs is not None:\n        warnings.warn(\n            f\"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}.\"\n            \" validation_epochs is deprecated in favor of `validation_steps`.\"\n            f\" Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}\",\n            FutureWarning,\n            stacklevel=2,\n        )\n        args.validation_steps = args.validation_epochs * len(train_dataset)\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n    )\n\n    # Prepare everything with our `accelerator`.\n    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        text_encoder, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,\n    # as these weights are only used for inference; keeping them in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and unet to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n 
   if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"textual_inversion\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    # keep original embeddings as reference\n    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        text_encoder.train()\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(text_encoder):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample().detach()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0].to(dtype=weight_dtype)\n\n                # Predict the noise residual\n                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                # Let's make sure we don't update any embedding weights besides the newly added token\n                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)\n                
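# clear the mask entries for the placeholder rows so they alone retain this step's update\n                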
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False\n\n                with torch.no_grad():\n                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[\n                        index_no_updates\n                    ] = orig_embeds_params[index_no_updates]\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                images = []\n                progress_bar.update(1)\n                global_step += 1\n                if global_step % args.save_steps == 0:\n                    save_path = os.path.join(args.output_dir, f\"learned_embeds-steps-{global_step}.bin\")\n                    save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        images = log_validation(\n                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        if args.push_to_hub and not args.save_as_full_pipeline:\n            logger.warn(\"Enabling full model saving because --push_to_hub=True was specified.\")\n            save_full_model = True\n        else:\n    
        save_full_model = args.save_as_full_pipeline\n        if save_full_model:\n            pipeline = StableDiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                text_encoder=accelerator.unwrap_model(text_encoder),\n                vae=vae,\n                unet=unet,\n                tokenizer=tokenizer,\n            )\n            pipeline.save_pretrained(args.output_dir)\n        # Save the newly trained embeddings\n        save_path = os.path.join(args.output_dir, \"learned_embeds.bin\")\n        save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n    accelerator.free_memory()\n\n    del accelerator, vae, unet, text_encoder\n\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\nif __name__ == \"__main__\":\n\n    args = parse_args()\n    output_dir = args.output_dir\n\n    rank = int(os.environ.pop(\"RANK\", 0))\n    world_size = int(os.environ.pop(\"WORLD_SIZE\", 1))\n\n    device_id = rank % torch.cuda.device_count()\n    torch.cuda.set_device(rank % torch.cuda.device_count())\n\n    print(f'Initialized process {rank} / {world_size}')\n\n    class_names = DATASETS[args.dataset].class_names\n\n    options = list(product(\n        range(args.num_trials),\n        args.examples_per_class,\n        class_names))\n\n    print(f\"{len(options)} Total Options\")\n\n    options_idx = np.arange(len(options))\n    options_idx = np.array_split(options_idx, world_size)[rank]\n\n    options = [options[idx] for idx in options_idx]\n\n    for seed, examples_per_class, class_name in options:\n\n        os.makedirs(os.path.join(output_dir, \"extracted\"), exist_ok=True)\n\n        train_dataset = DATASETS[\n            args.dataset](split=\"train\", seed=seed, \n                          examples_per_class=examples_per_class)\n\n        for idx in range(len(train_dataset)):\n\n            image = train_dataset.get_image_by_idx(idx)\n            metadata = train_dataset.get_metadata_by_idx(idx)\n\n            if metadata[\"name\"] == class_name:\n\n                name = metadata[\"name\"].replace(\" \", \"_\")\n                path = f\"{args.dataset}-{seed}-{examples_per_class}\"\n\n                path = os.path.join(output_dir, \"extracted\", path, name, f\"{idx}.png\")\n                os.makedirs(os.path.dirname(path), exist_ok=True)\n\n                image.save(path)\n\n        formatted_name = class_name.replace(\" \", \"_\")\n        dirname = f\"{args.dataset}-{seed}-{examples_per_class}/{formatted_name}\"\n\n        args = parse_args()\n        \n        args.seed = seed\n\n        args.placeholder_token = f\"<{formatted_name}>\"\n\n        args.train_data_dir = os.path.join(\n            output_dir, \"extracted\", dirname)\n        args.output_dir = os.path.join(\n            output_dir, \"fine-tuned\", dirname)\n\n        word_name = class_name.replace(\" \", \"\")\n\n        if args.erase_concepts: args.unet_ckpt = (\n            \"/projects/rsalakhugroup/btrabucc/esd-models/\" + \n            
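# assumed layout: checkpoint names follow the naming convention of the ESD 'erasing' submodule (see .gitmodules)\n            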
f\"compvis-word_{word_name}-method_full-sg_3-ng_1-iter_1000-lr_1e-05/\" + \n            f\"diffusers-word_{word_name}-method_full-sg_3-ng_1-iter_1000-lr_1e-05.pt\")\n\n        main(args)\n\n        shutil.rmtree(args.train_data_dir)\n"
  },
  {
    "path": "generate_augmentations.py",
    "content": "from semantic_aug.datasets.coco import COCODataset\nfrom semantic_aug.datasets.spurge import SpurgeDataset\nfrom semantic_aug.datasets.imagenet import ImageNetDataset\nfrom semantic_aug.datasets.pascal import PASCALDataset\nfrom semantic_aug.augmentations.compose import ComposeParallel\nfrom semantic_aug.augmentations.compose import ComposeSequential\nfrom semantic_aug.augmentations.real_guidance import RealGuidance\nfrom semantic_aug.augmentations.textual_inversion import TextualInversion\nfrom diffusers import StableDiffusionPipeline\nfrom itertools import product\nfrom torch import autocast\nfrom PIL import Image\n\nfrom tqdm import tqdm\nimport os\nimport torch\nimport argparse\nimport numpy as np\nimport random\n\n\nDATASETS = {\n    \"spurge\": SpurgeDataset, \n    \"coco\": COCODataset, \n    \"pascal\": PASCALDataset,\n    \"imagenet\": ImageNetDataset\n}\n\nCOMPOSE = {\n    \"parallel\": ComposeParallel,\n    \"sequential\": ComposeSequential\n}\n\nAUGMENT = {\n    \"real-guidance\": RealGuidance,\n    \"textual-inversion\": TextualInversion\n}\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\"Inference script\")\n    \n    parser.add_argument(\"--out\", type=str, default=\"real-guidance/\")\n\n    parser.add_argument(\"--model-path\", type=str, default=\"CompVis/stable-diffusion-v1-4\")\n    parser.add_argument(\"--embed-path\", type=str, default=\"erasure-tokens/pascal-tokens/pascal-0-8.pt\")\n    \n    parser.add_argument(\"--dataset\", type=str, default=\"pascal\")\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--examples-per-class\", type=int, default=1)\n    parser.add_argument(\"--num-synthetic\", type=int, default=10)\n\n    parser.add_argument(\"--prompt\", type=str, default=\"a photo of a {name}\")\n    \n    parser.add_argument(\"--aug\", nargs=\"+\", type=str, default=[\"real-guidance\"], \n                        choices=[\"real-guidance\", \"textual-inversion\"])\n\n    parser.add_argument(\"--guidance-scale\", nargs=\"+\", type=float, default=[7.5])\n    parser.add_argument(\"--strength\", nargs=\"+\", type=float, default=[0.5])\n\n    parser.add_argument(\"--mask\", nargs=\"+\", type=int, default=[0], choices=[0, 1])\n    parser.add_argument(\"--inverted\", nargs=\"+\", type=int, default=[0], choices=[0, 1])\n    \n    parser.add_argument(\"--probs\", nargs=\"+\", type=float, default=None)\n    \n    parser.add_argument(\"--compose\", type=str, default=\"parallel\", \n                        choices=[\"parallel\", \"sequential\"])\n\n    parser.add_argument(\"--class-name\", type=str, default=None)\n    \n    parser.add_argument(\"--erasure-ckpt-path\", type=str, default=None)\n\n    args = parser.parse_args()\n\n    os.makedirs(args.out, exist_ok=True)\n\n    torch.manual_seed(args.seed)\n    np.random.seed(args.seed)\n    random.seed(args.seed)\n\n    aug = COMPOSE[args.compose]([\n        \n        AUGMENT[aug](\n            embed_path=args.embed_path, \n            model_path=args.model_path, \n            prompt=args.prompt, \n            strength=strength, \n            guidance_scale=guidance_scale,\n            mask=mask, \n            inverted=inverted,\n            erasure_ckpt_path=args.erasure_ckpt_path\n        )\n\n        for (aug, guidance_scale, \n             strength, mask, inverted) in zip(\n            args.aug, args.guidance_scale, \n            args.strength, args.mask, args.inverted\n        )\n\n    ], probs=args.probs)\n\n    train_dataset = DATASETS[\n        
args.dataset](split=\"train\", seed=args.seed, \n                      examples_per_class=args.examples_per_class)\n\n    options = product(range(len(train_dataset)), range(args.num_synthetic))\n\n    for idx, num in tqdm(list(\n            options), desc=\"Generating Augmentations\"):\n\n        image = train_dataset.get_image_by_idx(idx)\n        label = train_dataset.get_label_by_idx(idx)\n\n        metadata = train_dataset.get_metadata_by_idx(idx)\n\n        if args.class_name is not None: \n            if metadata[\"name\"] != args.class_name: continue\n\n        image, label = aug(\n            image, label, metadata)\n\n        name = metadata['name'].replace(\" \", \"_\")\n\n        pil_image, image = image, os.path.join(\n            args.out, f\"{name}-{idx}-{num}.png\")\n\n        pil_image.save(image)"
  },
  {
    "path": "generate_images.py",
    "content": "from semantic_aug.augmentations.textual_inversion import TextualInversion\nfrom diffusers import StableDiffusionPipeline\nfrom itertools import product\nfrom torch import autocast\nfrom PIL import Image\n\nfrom tqdm import trange\nimport os\nimport torch\nimport argparse\nimport numpy as np\nimport random\n\n\nDEFAULT_ERASURE_CKPT = (\n    \"/projects/rsalakhugroup/btrabucc/esd-models/\" + \n    \"compvis-word_airplane-method_full-sg_3-ng_1-iter_1000-lr_1e-05/\" + \n    \"diffusers-word_airplane-method_full-sg_3-ng_1-iter_1000-lr_1e-05.pt\")\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\"Stable Diffusion inference script\")\n\n    parser.add_argument(\"--model-path\", type=str, default=\"CompVis/stable-diffusion-v1-4\")\n    parser.add_argument(\"--embed-path\", type=str, default=(\n        \"erasure-tokens/fine-tuned/pascal-0-8/airplane/learned_embeds.bin\"))\n    \n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--num-generate\", type=int, default=10)\n\n    parser.add_argument(\"--prompt\", type=str, default=\"a photo of a <airplane>\")\n    parser.add_argument(\"--out\", type=str, default=\"erasure-tokens/fine-tuned/pascal-0-8/airplane/\")\n\n    parser.add_argument(\"--guidance-scale\", type=float, default=7.5)\n    parser.add_argument(\"--erasure-ckpt-name\", type=str, default=DEFAULT_ERASURE_CKPT)\n\n    args = parser.parse_args()\n\n    os.makedirs(args.out, exist_ok=True)\n\n    torch.manual_seed(args.seed)\n    np.random.seed(args.seed)\n    random.seed(args.seed)\n\n    pipe = StableDiffusionPipeline.from_pretrained(\n        args.model_path, use_auth_token=True,\n        revision=\"fp16\", \n        torch_dtype=torch.float16\n    ).to('cuda')\n\n    aug = TextualInversion(args.embed_path, model_path=args.model_path)\n    pipe.tokenizer = aug.pipe.tokenizer\n    pipe.text_encoder = aug.pipe.text_encoder\n\n    pipe.set_progress_bar_config(disable=True)\n    pipe.safety_checker = None\n\n    if args.erasure_ckpt_name is not None:\n        pipe.unet.load_state_dict(torch.load(\n            args.erasure_ckpt_name, map_location='cuda'))\n\n    for idx in trange(args.num_generate, \n                      desc=\"Generating Images\"):\n\n        with autocast('cuda'):\n\n            image = pipe(\n                args.prompt, \n                guidance_scale=args.guidance_scale\n            ).images[0]\n\n        image.save(os.path.join(args.out, f\"{idx}.png\"))"
  },
  {
    "path": "images/README.md",
    "content": ""
  },
  {
    "path": "index.html",
    "content": "\n<!DOCTYPE html><html lang=\"en\" itemscope itemtype=\"http://schema.org/WebPage\"><head><meta charset=\"utf-8\"><script nonce=\"a9BH-X5QD4qdoD3whwiuBw\">var DOCS_timing={}; DOCS_timing['sl']=new Date().getTime();</script><script nonce=\"a9BH-X5QD4qdoD3whwiuBw\">function _DumpException(e) {throw e;}</script><script nonce=\"a9BH-X5QD4qdoD3whwiuBw\">_docs_flag_initialData={\"atari-emtpr\":false,\"atari-ebidm\":true,\"atari-ebids\":true,\"atari-edtm\":true,\"atari-eibrm\":false,\"atari-ectm\":false,\"atari-ects\":false,\"docs-text-elei\":false,\"docs-text-usc\":true,\"atari-bae\":false,\"docs-text-eessmkc\":false,\"docs-text-emtps\":false,\"docs-text-etsrdpn\":false,\"docs-text-etsrds\":false,\"docs-text-erdfs\":false,\"docs-text-encps\":false,\"docs-text-endes\":false,\"docs-text-escpv\":true,\"docs-text-ecfs\":false,\"docs-text-ecis\":false,\"docs-text-eessips\":false,\"docs-text-edctzs\":false,\"docs-text-eetxpc\":false,\"docs-text-eetxp\":false,\"docs-text-lns\":false,\"docs-text-edhcfs\":true,\"docs-text-ertkmcp\":true,\"docs-text-ettctvs\":false,\"docs-text-issermps\":false,\"docs-text-emscts\":false,\"docs-etshc\":false,\"docs-text-tbcb\":2.0E7,\"docs-text-ftls\":true,\"docs-efsmsdl\":false,\"docs-euoftm\":false,\"docs-text-etb\":false,\"docs-text-esbefr\":false,\"docs-text-etof\":false,\"docs-text-ipi\":false,\"docs-text-ehlb\":false,\"docs-text-epa\":true,\"docs-text-ecls\":true,\"docs-text-dwit\":false,\"docs-text-elawp\":false,\"docs-eec\":false,\"docs-ecot\":\"\",\"docs-text-enbcr\":false,\"docs-text-svofc\":false,\"docs-sup\":\"\",\"docs-eldi\":false,\"docs-dli\":false,\"docs-liap\":\"/logImpressions\",\"ilcm\":{\"eui\":\"AHKXmL1v_PW0AhcXt6BpBW2jQrg8Oghi_CtVSF_bwn67w6hgFSgUc9r_UTzuuxDr7ST4iTjf_sC4\",\"je\":1,\"sstu\":1703046087531681,\"si\":\"CKa51YiVnYMDFaBLqwIdT8QMkA\",\"gsc\":null,\"ei\":[5703839,5704621,5706832,5706836,5707711,5735808,5737802,5738531,5740816,5743126,5746994,5747263,5748031,5752696,5753331,5754231,5755098,5758825,5760350,5762261,5764270,5765553,5766779,5767853,5770437,5773680,5774096,5774349,5774854,5776519,5777196,5783803,5784949,5784969,5791301,5791784,5792686,5796153,5796475,5797293,14101306,14101502,14101510,14101534,49372444,49375323,49376002,49376338,49378890,49451560,49453046,49472072,49512374,49517792,49612442,49613709,49622832,49623182,49624081,49644024,49765383,49769346,49816166,49822930,49823173,49824164,49833471,49839580,49842864,49924715,50082749,50127541,50166960,50168316,50221729,50266231,50273537,50293697,50335898,50360149,50390166,50492351,50520322,50529112,50533185,50580253,50606356,70979411,70983144,71035309,71085250,71102133,71119967,71152134,71178681,71185179,71197835,71230234,71238955,71241074,71260351,71273598,71286030,71289155,71301339,71330602,71346961,71382647,71396894,71401154,71407394,71444154,71471151,71471883,71480305,71528086,71528606,71530092,71531296,71537707,71558038,71624116,71625589,71641922,71659822,71671627],\"crc\":0,\"cvi\":[]},\"docs-ccdil\":false,\"docs-eil\":true,\"info_params\":{\"token\":\"AHL0AtKjtWA7F_ipTj5iNvZa4HtBjmPZSA:1703046087403\"},\"buildLabel\":\"editors.sites-viewer-frontend_20231212.02_p0\",\"docs-show_debug_info\":false,\"atari-jefp\":\"/_/view/jserror\",\"docs-jern\":\"view\",\"atari-rhpp\":\"/_/view\",\"docs-ecuach\":false,\"docs-cclt\":2033,\"docs-ecci\":true,\"docs-esi\":false,\"docs-efypr\":true,\"docs-eyprp\":false,\"docs-eytpgcv\":1}; _docs_flag_cek= null ; if (window['DOCS_timing']) {DOCS_timing['ifdld']=new Date().getTime();}</script><meta name=\"viewport\" 
content=\"width=device-width, initial-scale=1\"><meta http-equiv=\"X-UA-Compatible\" content=\"IE=edge\"><meta name=\"referrer\" content=\"strict-origin-when-cross-origin\"><link rel=\"icon\" href=\"https://ssl.gstatic.com/atari/images/public/favicon.ico\"><meta property=\"og:title\" content=\"DA-Fusion\"><meta property=\"og:type\" content=\"website\"><meta property=\"og:url\" content=\"https://sites.google.com/btrabucco.com/da-fusion/home\"><meta property=\"og:description\" content=\"Paper: arXiv    |    Code: GitHub\n\"><meta itemprop=\"name\" content=\"DA-Fusion\"><meta itemprop=\"description\" content=\"Paper: arXiv    |    Code: GitHub\n\"><meta itemprop=\"url\" content=\"https://sites.google.com/btrabucco.com/da-fusion/home\"><meta itemprop=\"thumbnailUrl\" content=\"https://lh5.googleusercontent.com/mTrJFRXLoOsce_2zNf1rSofnhmI3zU0oBAU09nT9cl5uj_KakpWpEkR99OXH3MQQT_VnP46NNY56Khcl3pMA5ybQ3yAUXw7CpN2Ndh_XJ5VYtGG6eQXe0EfeYnx-4qym3g=w1280\"><meta itemprop=\"image\" content=\"https://lh5.googleusercontent.com/mTrJFRXLoOsce_2zNf1rSofnhmI3zU0oBAU09nT9cl5uj_KakpWpEkR99OXH3MQQT_VnP46NNY56Khcl3pMA5ybQ3yAUXw7CpN2Ndh_XJ5VYtGG6eQXe0EfeYnx-4qym3g=w1280\"><meta itemprop=\"imageUrl\" content=\"https://lh5.googleusercontent.com/mTrJFRXLoOsce_2zNf1rSofnhmI3zU0oBAU09nT9cl5uj_KakpWpEkR99OXH3MQQT_VnP46NNY56Khcl3pMA5ybQ3yAUXw7CpN2Ndh_XJ5VYtGG6eQXe0EfeYnx-4qym3g=w1280\"><meta property=\"og:image\" content=\"https://lh5.googleusercontent.com/mTrJFRXLoOsce_2zNf1rSofnhmI3zU0oBAU09nT9cl5uj_KakpWpEkR99OXH3MQQT_VnP46NNY56Khcl3pMA5ybQ3yAUXw7CpN2Ndh_XJ5VYtGG6eQXe0EfeYnx-4qym3g=w1280\"><link href=\"https://fonts.googleapis.com/css?family=Lato%3A300%2C300italic%2C400%2C400italic%2C700%2C700italic&display=swap\" rel=\"stylesheet\" nonce=\"k7EAwqdvbMKbMpxLLtxMIg\"><link href=\"https://fonts.googleapis.com/css?family=Google+Sans:400,500|Roboto:300,400,500,700|Source+Code+Pro:400,700&display=swap\" rel=\"stylesheet\" nonce=\"k7EAwqdvbMKbMpxLLtxMIg\"><link href=\"https://fonts.googleapis.com/css?family=Lora%3Ai%2Cbi%2C700%2C400%7CSource%20Code%20Pro%3Ai%2Cbi%2C700%2C400&display=swap\" rel=\"stylesheet\" nonce=\"k7EAwqdvbMKbMpxLLtxMIg\"><style nonce=\"k7EAwqdvbMKbMpxLLtxMIg\">@media only screen and (max-width: 479px){.jgG6ef{font-size: 17.0pt;}}@media only screen and (min-width: 480px) and (max-width: 767px){.jgG6ef{font-size: 17.0pt;}}@media only screen and (min-width: 768px) and (max-width: 1279px){.jgG6ef{font-size: 18.0pt;}}@media only screen and (min-width: 1280px){.jgG6ef{font-size: 18.0pt;}}@media only screen and (max-width: 479px){.RijTuc{font-size: 25.0pt;}}@media only screen and (min-width: 480px) and (max-width: 767px){.RijTuc{font-size: 30.0pt;}}@media only screen and (min-width: 768px) and (max-width: 1279px){.RijTuc{font-size: 34.0pt;}}@media only screen and (min-width: 1280px){.RijTuc{font-size: 34.0pt;}}@media only screen and (max-width: 479px){.puwcIf{font-size: 20.0pt;}}@media only screen and (min-width: 480px) and (max-width: 767px){.puwcIf{font-size: 22.0pt;}}@media only screen and (min-width: 768px) and (max-width: 1279px){.puwcIf{font-size: 24.0pt;}}@media only screen and (min-width: 1280px){.puwcIf{font-size: 24.0pt;}}</style><link rel=\"stylesheet\" href=\"https://www.gstatic.com/_/atari/_/ss/k=atari.vw.RdwxJhNMYZs.L.W.O/am=gAE/d=1/rs=AGEqA5k0HgViAOMqGAcxmPPLYhFps6gwmA\" data-id=\"_cl\" nonce=\"k7EAwqdvbMKbMpxLLtxMIg\"><script nonce=\"a9BH-X5QD4qdoD3whwiuBw\"></script><title>DA-Fusion</title><style jsname=\"ptDGoc\" nonce=\"k7EAwqdvbMKbMpxLLtxMIg\">.M63kCb{background-color: 
rgba(255,255,255,1);}.OUGEr{color: rgba(33,33,33,1);}.duRjpb .OUGEr{color: rgba(34,110,147,1);}.JYVBee .OUGEr{color: rgba(34,110,147,1);}.OmQG5e .OUGEr{color: rgba(33,33,33,1);}.iwQgFb{background-color: rgba(0,0,0,0.150000006);}.ySLm4c{font-family: Lato, sans-serif;}.CbiMKe{background-color: rgba(30,108,147,1);}.qeLZfd .zfr3Q{color: rgba(33,33,33,1);}.qeLZfd .qnVSj{color: rgba(33,33,33,1);}.qeLZfd .Glwbz{color: rgba(33,33,33,1);}.qeLZfd .duRjpb{color: rgba(34,110,147,1);}.qeLZfd .qLrapd{color: rgba(34,110,147,1);}.qeLZfd .JYVBee{color: rgba(34,110,147,1);}.qeLZfd .aHM7ed{color: rgba(34,110,147,1);}.qeLZfd .OmQG5e{color: rgba(33,33,33,1);}.qeLZfd .NHD4Gf{color: rgba(33,33,33,1);}.qeLZfd .aw5Odc{color: rgba(0,101,128,1);}.qeLZfd .dhtgD:hover{color: rgba(0,0,0,1);}.qeLZfd .dhtgD:visited{color: rgba(0,101,128,1);}.qeLZfd .iwQgFb{background-color: rgba(0,0,0,0.150000006);}.qeLZfd .OUGEr{color: rgba(33,33,33,1);}.qeLZfd .duRjpb .OUGEr{color: rgba(34,110,147,1);}.qeLZfd .JYVBee .OUGEr{color: rgba(34,110,147,1);}.qeLZfd .OmQG5e .OUGEr{color: rgba(33,33,33,1);}.qeLZfd:before{background-color: rgba(242,242,242,1); display: block;}.lQAHbd .zfr3Q{color: rgba(255,255,255,1);}.lQAHbd .qnVSj{color: rgba(255,255,255,1);}.lQAHbd .Glwbz{color: rgba(255,255,255,1);}.lQAHbd .duRjpb{color: rgba(255,255,255,1);}.lQAHbd .qLrapd{color: rgba(255,255,255,1);}.lQAHbd .JYVBee{color: rgba(255,255,255,1);}.lQAHbd .aHM7ed{color: rgba(255,255,255,1);}.lQAHbd .OmQG5e{color: rgba(255,255,255,1);}.lQAHbd .NHD4Gf{color: rgba(255,255,255,1);}.lQAHbd .aw5Odc{color: rgba(255,255,255,1);}.lQAHbd .dhtgD:hover{color: rgba(255,255,255,1);}.lQAHbd .dhtgD:visited{color: rgba(255,255,255,1);}.lQAHbd .iwQgFb{background-color: rgba(255,255,255,0.150000006);}.lQAHbd .OUGEr{color: rgba(255,255,255,1);}.lQAHbd .duRjpb .OUGEr{color: rgba(255,255,255,1);}.lQAHbd .JYVBee .OUGEr{color: rgba(255,255,255,1);}.lQAHbd .OmQG5e .OUGEr{color: rgba(255,255,255,1);}.lQAHbd .CbiMKe{background-color: rgba(255,255,255,1);}.lQAHbd:before{background-color: rgba(30,108,147,1); display: block;}.cJgDec .zfr3Q{color: rgba(255,255,255,1);}.cJgDec .zfr3Q .OUGEr{color: rgba(255,255,255,1);}.cJgDec .qnVSj{color: rgba(255,255,255,1);}.cJgDec .Glwbz{color: rgba(255,255,255,1);}.cJgDec .qLrapd{color: rgba(255,255,255,1);}.cJgDec .aHM7ed{color: rgba(255,255,255,1);}.cJgDec .NHD4Gf{color: rgba(255,255,255,1);}.cJgDec .IFuOkc:before{background-color: rgba(33,33,33,1); opacity: 0; display: block;}.O13XJf{height: 340px; padding-bottom: 60px; padding-top: 60px;}.O13XJf .IFuOkc{background-color: rgba(34,110,147,1); background-image: url(https://ssl.gstatic.com/atari/images/simple-header-blended-small.png);}.O13XJf .IFuOkc:before{background-color: rgba(33,33,33,1); opacity: 0.4; display: block;}.O13XJf .zfr3Q{color: rgba(255,255,255,1);}.O13XJf .qnVSj{color: rgba(255,255,255,1);}.O13XJf .Glwbz{color: rgba(255,255,255,1);}.O13XJf .duRjpb{color: rgba(255,255,255,1);}.O13XJf .qLrapd{color: rgba(255,255,255,1);}.O13XJf .JYVBee{color: rgba(255,255,255,1);}.O13XJf .aHM7ed{color: rgba(255,255,255,1);}.O13XJf .OmQG5e{color: rgba(255,255,255,1);}.O13XJf .NHD4Gf{color: rgba(255,255,255,1);}.tpmmCb .zfr3Q{color: rgba(33,33,33,1);}.tpmmCb .zfr3Q .OUGEr{color: rgba(33,33,33,1);}.tpmmCb .qnVSj{color: rgba(33,33,33,1);}.tpmmCb .Glwbz{color: rgba(33,33,33,1);}.tpmmCb .qLrapd{color: rgba(33,33,33,1);}.tpmmCb .aHM7ed{color: rgba(33,33,33,1);}.tpmmCb .NHD4Gf{color: rgba(33,33,33,1);}.tpmmCb .IFuOkc:before{background-color: rgba(255,255,255,1); display: block;}.tpmmCb .Wew9ke{fill: 
rgba(33,33,33,1);}.aw5Odc{color: rgba(0,101,128,1);}.dhtgD:hover{color: rgba(0,122,147,1);}.dhtgD:active{color: rgba(0,122,147,1);}.dhtgD:visited{color: rgba(0,101,128,1);}.Zjiec{color: rgba(255,255,255,1); font-family: Lato, sans-serif; font-size: 19pt; font-weight: 300; letter-spacing: 1px; line-height: 1.3; padding-bottom: 62.5px; padding-left: 48px; padding-right: 36px; padding-top: 11.5px;}.XMyrgf{margin-top: 0px; margin-left: 48px; margin-bottom: 24px; margin-right: 24px;}.TlfmSc{color: rgba(255,255,255,1); font-family: Lato, sans-serif; font-size: 15pt; font-weight: 300; line-height: 1.333;}.Mz8gvb{color: rgba(255,255,255,1);}.zDUgLc{background-color: rgba(33,33,33,1);}.QTKDff.chg4Jd:focus{background-color: rgba(255,255,255,0.1199999973);}.YTv4We{color: rgba(178,178,178,1);}.YTv4We:hover:before{background-color: rgba(255,255,255,0.1199999973); display: block;}.YTv4We.chg4Jd:focus:before{border-color: rgba(255,255,255,0.3600000143); display: block;}.eWDljc{background-color: rgba(33,33,33,1);}.eWDljc .hDrhEe{padding-left: 8px;}.ZXW7w{color: rgba(255,255,255,1); opacity: 0.26;}.PsKE7e{color: rgba(255,255,255,1); font-family: Lato, sans-serif; font-size: 12pt; font-weight: 300;}.lhZOrc{color: rgba(73,170,212,1);}.hDrhEe:hover{color: rgba(73,170,212,1);}.M9vuGd{color: rgba(73,170,212,1); font-weight: 400;}.jgXgSe:hover{color: rgba(73,170,212,1);}.j10yRb:hover{color: rgba(0,188,212,1);}.j10yRb.chg4Jd:focus:before{border-color: rgba(255,255,255,0.3600000143); display: block;}.tCHXDc{color: rgba(255,255,255,1);}.iWs3gf.chg4Jd:focus{background-color: rgba(255,255,255,0.1199999973);}.wgxiMe{background-color: rgba(33,33,33,1);}.fOU46b .TlfmSc{color: rgba(255,255,255,1);}.fOU46b .KJll8d{background-color: rgba(255,255,255,1);}.fOU46b .Mz8gvb{color: rgba(255,255,255,1);}.fOU46b .Mz8gvb.chg4Jd:focus:before{border-color: rgba(255,255,255,1); display: block;}.fOU46b .qV4dIc{color: rgba(255,255,255,0.8700000048);}.fOU46b .jgXgSe:hover{color: rgba(255,255,255,1);}.fOU46b .M9vuGd{color: rgba(255,255,255,1);}.fOU46b .tCHXDc{color: rgba(255,255,255,0.8700000048);}.fOU46b .iWs3gf.chg4Jd:focus{background-color: rgba(255,255,255,0.1199999973);}.fOU46b .G8QRnc .Mz8gvb{color: rgba(0,0,0,0.8000000119);}.fOU46b .G8QRnc .Mz8gvb.chg4Jd:focus:before{border-color: rgba(0,0,0,0.8000000119); display: block;}.fOU46b .G8QRnc .ZXW7w{color: rgba(0,0,0,0.8000000119);}.fOU46b .G8QRnc .TlfmSc{color: rgba(0,0,0,0.8000000119);}.fOU46b .G8QRnc .KJll8d{background-color: rgba(0,0,0,0.8000000119);}.fOU46b .G8QRnc .qV4dIc{color: rgba(0,0,0,0.6399999857);}.fOU46b .G8QRnc .jgXgSe:hover{color: rgba(0,0,0,0.8199999928);}.fOU46b .G8QRnc .M9vuGd{color: rgba(0,0,0,0.8199999928);}.fOU46b .G8QRnc .tCHXDc{color: rgba(0,0,0,0.6399999857);}.fOU46b .G8QRnc .iWs3gf.chg4Jd:focus{background-color: rgba(0,0,0,0.1199999973);}.fOU46b .usN8rf .Mz8gvb{color: rgba(0,0,0,0.8000000119);}.fOU46b .usN8rf .Mz8gvb.chg4Jd:focus:before{border-color: rgba(0,0,0,0.8000000119); display: block;}.fOU46b .usN8rf .ZXW7w{color: rgba(0,0,0,0.8000000119);}.fOU46b .usN8rf .TlfmSc{color: rgba(0,0,0,0.8000000119);}.fOU46b .usN8rf .KJll8d{background-color: rgba(0,0,0,0.8000000119);}.fOU46b .usN8rf .qV4dIc{color: rgba(0,0,0,0.6399999857);}.fOU46b .usN8rf .jgXgSe:hover{color: rgba(0,0,0,0.8199999928);}.fOU46b .usN8rf .M9vuGd{color: rgba(0,0,0,0.8199999928);}.fOU46b .usN8rf .tCHXDc{color: rgba(0,0,0,0.6399999857);}.fOU46b .usN8rf .iWs3gf.chg4Jd:focus{background-color: rgba(0,0,0,0.1199999973);}.fOU46b .aCIEDd .qV4dIc{color: rgba(33,33,33,1);}.fOU46b .aCIEDd .TlfmSc{color: 
rgba(33,33,33,1);}.fOU46b .aCIEDd .KJll8d{background-color: rgba(33,33,33,1);}.fOU46b .aCIEDd .ZXW7w{color: rgba(33,33,33,1);}.fOU46b .aCIEDd .jgXgSe:hover{color: rgba(33,33,33,1); opacity: 0.82;}.fOU46b .aCIEDd .Mz8gvb{color: rgba(33,33,33,1);}.fOU46b .aCIEDd .tCHXDc{color: rgba(33,33,33,1);}.fOU46b .aCIEDd .iWs3gf.chg4Jd:focus{background-color: rgba(33,33,33,0.1199999973);}.fOU46b .a3ETed .qV4dIc{color: rgba(255,255,255,1);}.fOU46b .a3ETed .TlfmSc{color: rgba(255,255,255,1);}.fOU46b .a3ETed .KJll8d{background-color: rgba(255,255,255,1);}.fOU46b .a3ETed .ZXW7w{color: rgba(255,255,255,1);}.fOU46b .a3ETed .jgXgSe:hover{color: rgba(255,255,255,1); opacity: 0.82;}.fOU46b .a3ETed .Mz8gvb{color: rgba(255,255,255,1);}.fOU46b .a3ETed .tCHXDc{color: rgba(255,255,255,1);}.fOU46b .a3ETed .iWs3gf.chg4Jd:focus{background-color: rgba(255,255,255,0.1199999973);}@media only screen and (min-width: 1280px){.XeSM4.b2Iqye.fOU46b .LBrwzc .tCHXDc{color: rgba(255,255,255,0.8700000048);}}.XeSM4.b2Iqye.fOU46b .LBrwzc .iWs3gf.chg4Jd:focus{background-color: rgba(255,255,255,0.1199999973);}@media only screen and (min-width: 1280px){.KuNac.b2Iqye.fOU46b .tCHXDc{color: rgba(0,0,0,0.6399999857);}}.KuNac.b2Iqye.fOU46b .iWs3gf.chg4Jd:focus{background-color: rgba(0,0,0,0.1199999973);}.fOU46b .zDUgLc{opacity: 0;}.LBrwzc .ZXW7w{color: rgba(0,0,0,1);}.LBrwzc .KJll8d{background-color: rgba(0,0,0,1);}.GBy4H .ZXW7w{color: rgba(255,255,255,1);}.GBy4H .KJll8d{background-color: rgba(255,255,255,1);}.eBSUbc{background-color: rgba(33,33,33,1); color: rgba(0,188,212,0.6999999881);}.BFDQOb:hover{color: rgba(73,170,212,1);}.ImnMyf{background-color: rgba(255,255,255,1); color: rgba(33,33,33,1);}.Vs12Bd{background-color: rgba(242,242,242,1); color: rgba(33,33,33,1);}.S5d9Rd{background-color: rgba(30,108,147,1); color: rgba(255,255,255,1);}.zfr3Q{color: rgba(33,33,33,1); font-family: Lato, sans-serif; font-size: 11pt; font-weight: 400; line-height: 1.6667; margin-top: 12px;}.qnVSj{color: rgba(33,33,33,1);}.Glwbz{color: rgba(33,33,33,1);}.duRjpb{color: rgba(34,110,147,1); font-family: Lato, sans-serif; font-size: 34pt; font-weight: 300; letter-spacing: 0.5px; line-height: 1.2; margin-top: 30px;}.Ap4VC{margin-bottom: -30px;}.qLrapd{color: rgba(34,110,147,1);}.JYVBee{color: rgba(34,110,147,1); font-family: Lato, sans-serif; font-size: 19pt; font-weight: 400; line-height: 1.4; margin-top: 20px;}.CobnVe{margin-bottom: -20px;}.aHM7ed{color: rgba(34,110,147,1);}.OmQG5e{color: rgba(33,33,33,1); font-family: Lato, sans-serif; font-size: 15pt; font-style: normal; font-weight: 400; line-height: 1.25; margin-top: 16px;}.GV3q8e{margin-bottom: -16px;}.NHD4Gf{color: rgba(33,33,33,1);}.LB7kq .duRjpb{font-size: 64pt; letter-spacing: 2px; line-height: 1; margin-top: 40px;}.LB7kq .JYVBee{font-size: 25pt; font-weight: 300; line-height: 1.1; margin-top: 25px;}@media only screen and (max-width: 479px){.LB7kq .duRjpb{font-size: 40pt;}}@media only screen and (min-width: 480px) and (max-width: 767px){.LB7kq .duRjpb{font-size: 53pt;}}@media only screen and (max-width: 479px){.LB7kq .JYVBee{font-size: 19pt;}}@media only screen and (min-width: 480px) and (max-width: 767px){.LB7kq .JYVBee{font-size: 22pt;}}.O13XJf{height: 340px; padding-bottom: 60px; padding-top: 60px;}@media only screen and (min-width: 480px) and (max-width: 767px){.O13XJf{height: 280px; padding-bottom: 40px; padding-top: 40px;}}@media only screen and (max-width: 479px){.O13XJf{height: 250px; padding-bottom: 30px; padding-top: 30px;}}.SBrW1{height: 520px;}@media only screen and (min-width: 480px) 
and (max-width: 767px){.SBrW1{height: 520px;}}</style></head><body dir=\"ltr\"><header><a href=\"/btrabucco.com/da-fusion/home\">DA-Fusion</a></header><main>
<section><h1>Effective Data Augmentation With Diffusion Models</h1><p>Brandon Trabucco¹, Kyle Doherty², Max Gurinas³, Ruslan Salakhutdinov¹</p><p>¹ Carnegie Mellon University, ² MPG Ranch, ³ University of Chicago Laboratory Schools</p>
<iframe src=\"https://www.youtube.com/embed/IKDWOOWzwns\" aria-label=\"YouTube Video, Effective Data Augmentation With Diffusion Models [NeurIPS 2023]\" allowfullscreen></iframe></section>
<section><p>Paper: <a href=\"https://arxiv.org/abs/2302.07944\">arXiv</a> | Code: <a href=\"https://github.com/brandontrabucco/da-fusion\">GitHub</a></p></section>
<figure><img src=\"https://lh5.googleusercontent.com/mTrJFRXLoOsce_2zNf1rSofnhmI3zU0oBAU09nT9cl5uj_KakpWpEkR99OXH3MQQT_VnP46NNY56Khcl3pMA5ybQ3yAUXw7CpN2Ndh_XJ5VYtGG6eQXe0EfeYnx-4qym3g=w1280\" alt=\"\"></figure>
<section><h2>Abstract</h2><p>Data augmentation is one of the most prevalent tools in deep learning, underpinning many recent advances, including those from classification, generative models, and representation learning. The standard approach to data augmentation combines simple transformations like rotations and flips to generate new images from existing ones. However, these new images lack diversity along key semantic axes present in the data. Current augmentations cannot alter the high-level semantic attributes, such as animal species present in a scene, to enhance the diversity of data. We address the lack of diversity in data augmentation with image-to-image transformations parameterized by pre-trained text-to-image diffusion models. Our method edits images to change their semantics using an off-the-shelf diffusion model, and generalizes to novel visual concepts from a few labelled examples. We evaluate our approach on few-shot image classification tasks, and on a real-world weed recognition task, and observe an improvement in accuracy in tested domains.</p></section>
<figure><img src=\"https://lh3.googleusercontent.com/MQzaRK1DmU7bpbAy8w8VaxuiTyOC9stf_T3oUm8JeCb6A-fVONJw7DwhnYA1wmxVubLNwWI8L13GaxNVh-FakS-3QFUfvzVX4Vcj_jFFgIvvwRQxGQtnOzLIWT-kMUAODA=w1280\" alt=\"\"><figcaption>Our augmentation adapts to the images in your datasets by learning pseudo-prompts &lt;y&gt; for each class.</figcaption></figure>
<figure><img src=\"https://lh4.googleusercontent.com/XTw2snXXe905NI7eSTxEQHvq9QA9lwmZuQ-flQdk0QnidV8a91SPJi8-bgthp61ATMIGnOIfAsh3ighbbgOl7vAsdNDOjUW3ibrbhxgca156DY8O_PsO4cnI_c0cW98N0A=w1280\" alt=\"\"><figcaption>We generate augmentations using the structural layout of real images as a guide.</figcaption></figure>
<figure><img src=\"https://lh3.googleusercontent.com/WXNoTzEQOxwqLa97-pXeGSb2DdHX-ajfq-dKVasSOib1F8FZQhYjGmmB10WWQ88AlCq0cr5MNDI-E2uRPgLkoaOagOnHnfwsVk32Hx-wYWUIMnqIys0DIKpLejbuTRwPKw=w1280\" alt=\"\"><figcaption>Generations from DA-Fusion preserve the layout of trees, but produce different structural elements.</figcaption></figure>
<figure><img src=\"https://lh5.googleusercontent.com/wHPXgt9yfYbDB8WmAw8gUzU9kdH-UKUpRYXgJuh_woA8r1_DZo5wjJksSPYNRk47RgL5-o3ErD54W85nmNNBcydnCj27HwYyGHuDk2t3aKQM0lDJ8NvgutvaIIwf0FeyfQ=w1280\" alt=\"\"><figcaption>We see strong performance across seven few-shot classification tasks.</figcaption></figure>
<figure><img src=\"https://lh3.googleusercontent.com/J9GoHlhEIz90S_RcvjLHn5FsBXOHn24U7VhqKQBCb9V6SJSjMpqMV1MX7BSD-UwEsm3R9t9C2uzvfUp7oRxQ_vVjLnGm22Wkjja5rM4hn2lDbjnJjjzGx5dsS--So6I76g=w1280\" alt=\"\"></figure>
<section><pre>@misc{https://doi.org/10.48550/arxiv.2302.07944,\n  doi = {10.48550/ARXIV.2302.07944},\n  url = {https://arxiv.org/abs/2302.07944},\n  author = {Trabucco, Brandon and Doherty, Kyle and Gurinas, Max and Salakhutdinov, Ruslan},\n  keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences},\n  title = {Effective Data Augmentation With Diffusion Models},\n  publisher = {arXiv},\n  year = {2023},\n  copyright = {arXiv.org perpetual, non-exclusive license}\n}</pre></section>
  },
  {
    "path": "plot.py",
    "content": "import matplotlib.pyplot as plt\nimport matplotlib\nimport pandas as pd\nimport seaborn as sns\n\nimport os\nimport glob\nimport argparse\nimport math\n\n\ndef pretty(text):\n    \"\"\"Convert a string into a consistent format for\n    presentation in a matplotlib pyplot:\n    this version looks like: One Two Three Four\n    \"\"\"\n\n    text = text.replace(\"_\", \" \")\n    text = text.replace(\"-\", \" \")\n    text = text.replace(\"/\", \" \")\n    text = text.strip()\n    prev_c = None\n    out_str = []\n    for c in text:\n        if prev_c is not None and \\\n                prev_c.islower() and c.isupper():\n            out_str.append(\" \")\n            prev_c = \" \"\n        if prev_c is None or prev_c == \" \":\n            c = c.upper()\n        out_str.append(c)\n        prev_c = c\n    return \"\".join(out_str)\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\"Few-Shot Baseline\")\n\n    parser.add_argument(\"--logdirs\", nargs=\"+\", type=str, default=[\n        \"./spurge-baselines\", \"./pascal-baselines\", \"./coco-baselines\", \"./imagenet-baselines\"])\n    \n    parser.add_argument(\"--datasets\", nargs=\"+\", type=str, \n                        default=[\"Spurge\", \"Pascal\", \"COCO\", \"ImageNet\"])\n    \n    parser.add_argument(\"--method-dirs\", nargs=\"+\", type=str, \n                        default=[\"baseline\", \"real-guidance\", \"ours\"])\n    \n    parser.add_argument(\"--method-names\", nargs=\"+\", type=str, \n                        default=[\"Baseline\", \"Real Guidance (He et al., 2022)\", \"MBDA (Ours)\"])\n    \n    parser.add_argument(\"--name\", type=str, default=\"visualization\")\n    \n    parser.add_argument(\"--rows\", type=int, default=1)\n\n    args = parser.parse_args()\n\n    combined_dataframe = []\n\n    for logdir, dataset in zip(\n            args.logdirs, args.datasets):\n\n        for bname in os.listdir(logdir):\n\n            bpath = os.path.join(logdir, bname)\n\n            if not os.path.isdir(bpath):\n                continue\n\n            files = list(glob.glob(os.path.join(bpath, \"*.csv\")))\n\n            if len(files) == 0:\n                continue\n\n            data = pd.concat([pd.read_csv(x, index_col=0) \n                              for x in files], ignore_index=True)\n\n            data = data[(data[\"metric\"] == \"Accuracy\") & \n                        (data[ \"split\"] == \"Validation\")]\n\n            def select_by_epoch(df):\n                selected_row = df.loc[df[\"value\"].idxmax()]\n                return data[(data[\"epoch\"] == selected_row[\"epoch\"]) & \n                            (data[ \"examples_per_class\"] == \n                            selected_row[\"examples_per_class\"])]\n\n            best = data.groupby([\"examples_per_class\", \"epoch\"])\n            best = best[\"value\"].mean().to_frame('value').reset_index()\n            best = best.groupby(\"examples_per_class\").apply(\n                select_by_epoch\n            )\n\n            best[\"method\"] = bname\n            best[\"dataset\"] = dataset\n            combined_dataframe.append(best)\n\n    matplotlib.rc('font', family='Times New Roman', serif='cm10')\n    matplotlib.rc('mathtext', fontset='cm')\n    plt.rcParams['text.usetex'] = False\n\n    combined_dataframe = pd.concat(\n        combined_dataframe, ignore_index=True)\n\n    combined_dataframe = pd.concat([combined_dataframe[\n        combined_dataframe['method'] == n] for n in args.method_dirs])\n    \n    color_palette = 
sns.color_palette(n_colors=len(args.method_dirs))\n\n    legend_rows = int(math.ceil(len(args.method_names) / len(args.datasets)))\n    columns = int(math.ceil(len(args.datasets) / args.rows))\n\n    fig, axs = plt.subplots(\n        args.rows, columns,\n        figsize=(6 * columns, 4 * args.rows + (\n            2.0 if legend_rows == 1 else\n            2.5 if legend_rows == 2 else 3\n        )))\n\n    for i, dataset in enumerate(args.datasets):\n\n        results = combined_dataframe\n        if dataset not in [\"all\", \"All\", \"Overall\"]:\n            results = results[results[\"dataset\"] == dataset]\n\n        axis = sns.lineplot(x=\"examples_per_class\", y=\"value\", hue=\"method\", \n                            data=results, errorbar=('ci', 68),\n                            linewidth=4, palette=color_palette,\n                            ax=(\n            axs[i // columns, i % columns] \n            if args.rows > 1 and len(args.datasets) > 1 \n            else axs[i] if len(args.datasets) > 1 else axs\n        ))\n\n        if i == 0: handles, labels = axis.get_legend_handles_labels()\n        axis.legend([],[], frameon=False)\n\n        axis.set(xlabel=None)\n        axis.set(ylabel=None)\n\n        axis.spines['right'].set_visible(False)\n        axis.spines['top'].set_visible(False)\n\n        axis.xaxis.set_ticks_position('bottom')\n        axis.yaxis.set_ticks_position('left')\n\n        axis.yaxis.set_tick_params(labelsize=16)\n        axis.xaxis.set_tick_params(labelsize=16)\n\n        if i // columns == args.rows - 1:\n            axis.set_xlabel(\"Examples Per Class\", fontsize=24,\n                            fontweight='bold', labelpad=12)\n\n        axis.set_ylabel(\"Accuracy (Val)\", fontsize=24,\n                        fontweight='bold', labelpad=12)\n\n        axis.set_title(dataset, fontsize=24, fontweight='bold', pad=12)\n\n        axis.grid(color='grey', linestyle='dotted', linewidth=2)\n\n    legend = fig.legend(handles, [x for x in args.method_names],\n                        loc=\"lower center\", prop={'size': 24, 'weight': 'bold'}, \n                        ncol=min(len(args.method_names), len(args.datasets)))\n\n    for i, legend_object in enumerate(legend.legendHandles):\n        legend_object.set_linewidth(4.0)\n        legend_object.set_color(color_palette[i])\n\n    plt.tight_layout(pad=1.0)\n    fig.subplots_adjust(hspace=0.3)\n\n    fig.subplots_adjust(bottom=(\n        0.25 if legend_rows == 1 else\n        0.35 if legend_rows == 2 else 0.4\n    ) / args.rows + 0.05)\n\n    plt.savefig(f\"{args.name}.pdf\")\n    plt.savefig(f\"{args.name}.png\")"
  },
  {
    "path": "plot_masking_ablation.py",
    "content": "import matplotlib.pyplot as plt\nimport matplotlib\nimport numpy as np\nimport pandas as pd\nimport seaborn as sns\nfrom collections import defaultdict\nfrom itertools import product\n\nimport os\nimport glob\nimport argparse\nimport math\n\n\ndef pretty(text):\n    \"\"\"Convert a string into a consistent format for\n    presentation in a matplotlib pyplot:\n    this version looks like: One Two Three Four\n    \"\"\"\n\n    text = text.replace(\"_\", \" \")\n    text = text.replace(\"-\", \" \")\n    text = text.replace(\"/\", \" \")\n    text = text.strip()\n    prev_c = None\n    out_str = []\n    for c in text:\n        if prev_c is not None and \\\n                prev_c.islower() and c.isupper():\n            out_str.append(\" \")\n            prev_c = \" \"\n        if prev_c is None or prev_c == \" \":\n            c = c.upper()\n        out_str.append(c)\n        prev_c = c\n    return \"\".join(out_str)\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\"Few-Shot Baseline\")\n\n    parser.add_argument(\"--logdirs\", nargs=\"+\", type=str, \n                        default=[\"./pascal-baselines\", \"./coco-baselines\"])\n    \n    parser.add_argument(\"--datasets\", nargs=\"+\", type=str, \n                        default=[\"Pascal\", \"COCO\"])\n    \n    parser.add_argument(\"--method-dirs\", nargs=\"+\", type=str, \n                        default=[\"textual-inversion-0.5\", \n                                 \"textual-inversion-mask-0-0.5\", \n                                 \"textual-inversion-mask-0.5-0\"])\n    \n    parser.add_argument(\"--baseline-dirs\", nargs=\"+\", type=str, \n                        default=[\"real-guidance-0.5-cap\", \n                                 \"real-guidance-mask-0-0.5\", \n                                 \"real-guidance-mask-0.5-0\"])\n    \n    parser.add_argument(\"--method-names\", nargs=\"+\", type=str, \n                        default=[\"Original\", \n                                 \"Masked Foreground\", \n                                 \"Masked Background\"])\n    \n    parser.add_argument(\"--name\", type=str, default=\"masking-results\")\n    parser.add_argument(\"--rows\", type=int, default=1)\n    parser.add_argument(\"--num-trials\", type=int, default=8)\n    \n    parser.add_argument(\"--no-legend\", action=\"store_true\")\n\n    args = parser.parse_args()\n\n    combined_dataframe = []\n\n    for logdir, dataset in zip(args.logdirs, args.datasets):\n\n        for bname in os.listdir(logdir):\n\n            bpath = os.path.join(logdir, bname)\n\n            if not os.path.isdir(bpath):\n                continue\n\n            files = list(glob.glob(os.path.join(bpath, \"*.csv\")))\n\n            if len(files) == 0:\n                continue\n\n            data = pd.concat([pd.read_csv(x, index_col=0) \n                              for x in files], ignore_index=True)\n\n            data = data[(data[\"metric\"] == \"Accuracy\") & \n                        (data[ \"split\"] == \"Validation\")]\n\n            def select_by_epoch(df):\n                selected_row = df.loc[df[\"value\"].idxmax()]\n                return data[(data[\"epoch\"] == selected_row[\"epoch\"]) & \n                            (data[ \"examples_per_class\"] == \n                            selected_row[\"examples_per_class\"])]\n\n            best = data.groupby([\"examples_per_class\", \"epoch\"])\n            best = best[\"value\"].mean().to_frame('value').reset_index()\n            best = 
best.groupby(\"examples_per_class\").apply(\n                select_by_epoch\n            )\n\n            best[\"method\"] = bname\n            best[\"dataset\"] = dataset\n            combined_dataframe.append(best)\n\n    matplotlib.rc('font', family='Times New Roman', serif='cm10')\n    matplotlib.rc('mathtext', fontset='cm')\n    plt.rcParams['text.usetex'] = False\n\n    combined_dataframe = pd.concat(\n        combined_dataframe, ignore_index=True)\n\n    combined_dataframe = pd.concat([\n        combined_dataframe[combined_dataframe['method'] == n] \n        for n in (args.method_dirs + args.baseline_dirs)])\n    \n    color_palette = sns.color_palette(n_colors=len(args.method_dirs))\n\n    legend_rows = int(math.ceil(len(args.method_names) / (2 * len(args.datasets))))\n    columns = int(math.ceil(2 * len(args.datasets) / args.rows))\n\n    fig, axs = plt.subplots(\n        args.rows, columns,\n        figsize=(6 * columns, 4 * args.rows + ((\n            2.0 if legend_rows == 1 else\n            2.5 if legend_rows == 2 else 3\n        ) if not args.no_legend else 1.0)))\n\n    baseline_performance = defaultdict(list)\n\n    for dataset in args.datasets:\n        for seed in range(args.num_trials):\n            for bi, method in enumerate(args.baseline_dirs):\n\n                results = combined_dataframe[\n                    (combined_dataframe[\"dataset\"] == dataset) & \n                    (combined_dataframe[\"method\"] == method) & \n                    (combined_dataframe[\"seed\"] == seed)\n                ]\n\n                for examples in [1, 2, 4, 8, 16]:\n\n                    value = results[results[\"examples_per_class\"] \n                                    == examples][\"value\"].to_numpy()\n\n                    if value.size > 0: baseline_performance[\n                        (dataset, bi, examples)].append(value[0])\n\n                cumulative_value = 0.0\n                invalid = False\n\n                for examples_a, examples_b in zip([1, 2, 4, 8], [2, 4, 8, 16]):\n\n                    value_a = results[results[\"examples_per_class\"] == examples_a][\"value\"].to_numpy()\n                    value_b = results[results[\"examples_per_class\"] == examples_b][\"value\"].to_numpy()\n\n                    if value_a.size > 0 and value_b.size > 0:\n                        cumulative_value += ((value_a + value_b) / 2) * (examples_b - examples_a)\n\n                    else: invalid = True\n\n                if not invalid: baseline_performance[\n                    (dataset, bi)].append(cumulative_value[0])\n\n    performance_df = []\n    performance_auc_df = []\n\n    for dataset in args.datasets:\n        for seed in range(args.num_trials):\n            for bi, method in enumerate(args.method_dirs):\n\n                results = combined_dataframe[\n                    (combined_dataframe[\"dataset\"] == dataset) & \n                    (combined_dataframe[\"method\"] == method) & \n                    (combined_dataframe[\"seed\"] == seed)\n                ]\n\n                for examples in [1, 2, 4, 8, 16]:\n\n                    if (dataset, bi, examples) not in baseline_performance: continue\n\n                    value = results[results[\"examples_per_class\"] \n                                    == examples][\"value\"].to_numpy()\n\n                    baseline_value = np.mean(\n                        baseline_performance[(dataset, bi, examples)])\n\n                    if value.size > 0: performance_df.append(dict(\n                        
dataset=dataset,\n                        method=method,\n                        seed=seed,\n                        examples_per_class=examples,\n                        value=value[0] - baseline_value,\n                    ))\n\n                if (dataset, bi) not in baseline_performance: continue\n\n                valid = 0\n                cumulative_value = -np.mean(baseline_performance[(dataset, bi)])\n\n                for examples_a, examples_b in zip([1, 2, 4, 8], [2, 4, 8, 16]):\n\n                    value_a = results[results[\"examples_per_class\"] == examples_a][\"value\"].to_numpy()\n                    value_b = results[results[\"examples_per_class\"] == examples_b][\"value\"].to_numpy()\n\n                    if value_a.size > 0 and value_b.size > 0:\n                        cumulative_value += ((value_a + value_b) / 2) * (examples_b - examples_a)\n                        valid += 1\n\n                if valid == 4:\n\n                    performance_auc_df.append(dict(\n                        dataset=dataset,\n                        method=method,\n                        seed=seed,\n                        value=cumulative_value[0],\n                    ))\n\n    performance_df = pd.DataFrame.from_records(performance_df)\n    performance_auc_df = pd.DataFrame.from_records(performance_auc_df)\n\n    for dataset in args.datasets:\n\n        df = performance_df.loc[performance_df[\"dataset\"] == dataset]\n        if df.size == 0: continue\n\n        acc_max = df[\"value\"].to_numpy().max()\n        acc_min = df[\"value\"].to_numpy().min()\n\n        performance_df.loc[\n            performance_df[\"dataset\"] == dataset, \n            \"normalized_value\"\n        ] = (df[\"value\"] - acc_min) / (acc_max - acc_min)\n\n        df = performance_auc_df.loc[performance_auc_df[\"dataset\"] == dataset]\n        if df.size == 0: continue\n\n        acc_max = df[\"value\"].to_numpy().max()\n        acc_min = df[\"value\"].to_numpy().min()\n\n        performance_auc_df.loc[\n            performance_auc_df[\"dataset\"] == dataset, \n            \"normalized_value\"\n        ] = (df[\"value\"] - acc_min) / (acc_max - acc_min)\n\n    for i, (style, dataset) in enumerate(\n            product([\"line\", \"bar\"], args.datasets)):\n\n        results = performance_df[\n            performance_df[\"dataset\"] == dataset] \\\n            if dataset != \"Overall\" else performance_df\n\n        results_auc = performance_auc_df[\n            performance_auc_df[\"dataset\"] == dataset] \\\n            if dataset != \"Overall\" else performance_auc_df\n\n        if style == \"line\":\n\n            axis = sns.lineplot(\n                y=\"normalized_value\" if dataset == \"Overall\" else \"value\", \n                x=\"examples_per_class\", hue=\"method\", \n                data=results, errorbar=('ci', 68),\n                linewidth=4, palette=color_palette,\n                ax=(axs[i // columns, i % columns] \n                    if args.rows > 1 and len(args.datasets) > 1 \n                    else axs[i] if len(args.datasets) > 1 else axs))\n\n            if i == 0: handles, labels = axis.get_legend_handles_labels()\n            axis.legend([],[], frameon=False)\n\n            axis.set(xlabel=None)\n            axis.set(ylabel=None)\n\n            axis.spines['right'].set_visible(False)\n            axis.spines['top'].set_visible(False)\n\n            axis.xaxis.set_ticks_position('bottom')\n            axis.yaxis.set_ticks_position('left')\n\n            
axis.yaxis.set_tick_params(labelsize=16)\n            axis.xaxis.set_tick_params(labelsize=16)\n\n            if i // columns == args.rows - 1:\n                axis.set_xlabel(\"Examples Per Class\", fontsize=24,\n                                fontweight='bold', labelpad=12)\n\n            axis.set_ylabel(\"Normalized Score\" if dataset == \"Overall\" \n                            else \"Gained Accuracy (Val)\", fontsize=24,\n                            fontweight='bold', labelpad=12)\n\n        elif style == \"bar\":\n\n            axis = sns.barplot(\n                y=\"normalized_value\" if dataset == \"Overall\" else \"value\", \n                x=\"method\", data=results_auc, errorbar=('ci', 68),\n                linewidth=4, palette=color_palette,\n                ax=(axs[i // columns, i % columns] \n                    if args.rows > 1 and len(args.datasets) > 1 \n                    else axs[i] if len(args.datasets) > 1 else axs))\n\n            if i == 0: handles, labels = axis.get_legend_handles_labels()\n            axis.legend([],[], frameon=False)\n\n            axis.set(xlabel=None)\n            axis.set(ylabel=None)\n\n            axis.spines['right'].set_visible(False)\n            axis.spines['top'].set_visible(False)\n\n            axis.xaxis.set_ticks_position('bottom')\n            axis.yaxis.set_ticks_position('left')\n\n            axis.yaxis.set_tick_params(labelsize=16)\n            axis.xaxis.set_ticklabels([])\n\n            acc_max = results_auc[\"normalized_value\" if dataset == \"Overall\" else \"value\"].to_numpy().max()\n            acc_min = results_auc[\"normalized_value\" if dataset == \"Overall\" else \"value\"].to_numpy().min()\n            axis.set_ylim(max(0, acc_min), acc_max)\n\n            axis.set_ylabel(\"Normalized Score\" if dataset == \"Overall\" \n                            else \"Gained AUC (Val)\", fontsize=24,\n                            fontweight='bold', labelpad=12)\n\n        axis.set_title(dataset, fontsize=24, fontweight='bold', pad=12)\n\n        axis.grid(color='grey', linestyle='dotted', linewidth=2)\n\n    if not args.no_legend:\n\n        legend = fig.legend(handles, [x for x in args.method_names],\n                            loc=\"lower center\", prop={'size': 24, 'weight': 'bold'}, \n                            ncol=min(len(args.method_names), 2 * len(args.datasets)))\n\n        for i, legend_object in enumerate(legend.legendHandles):\n            legend_object.set_linewidth(4.0)\n            legend_object.set_color(color_palette[i])\n\n    plt.tight_layout(pad=1.0)\n    fig.subplots_adjust(hspace=0.3)\n\n    if not args.no_legend:\n\n        fig.subplots_adjust(bottom=(\n            0.25 if legend_rows == 1 else\n            0.35 if legend_rows == 2 else 0.4\n        ) / args.rows + 0.05)\n\n    plt.savefig(f\"{args.name}.pdf\")\n    plt.savefig(f\"{args.name}.png\")"
  },
  {
    "path": "plot_stacking_ablation.py",
    "content": "import matplotlib.pyplot as plt\nimport matplotlib\nimport pandas as pd\nimport seaborn as sns\n\nimport os\nimport glob\nimport argparse\nimport math\n\n\ndef pretty(text):\n    \"\"\"Convert a string into a consistent format for\n    presentation in a matplotlib pyplot:\n    this version looks like: One Two Three Four\n    \"\"\"\n\n    text = text.replace(\"_\", \" \")\n    text = text.replace(\"-\", \" \")\n    text = text.replace(\"/\", \" \")\n    text = text.strip()\n    prev_c = None\n    out_str = []\n    for c in text:\n        if prev_c is not None and \\\n                prev_c.islower() and c.isupper():\n            out_str.append(\" \")\n            prev_c = \" \"\n        if prev_c is None or prev_c == \" \":\n            c = c.upper()\n        out_str.append(c)\n        prev_c = c\n    return \"\".join(out_str)\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\"Few-Shot Baseline\")\n\n    parser.add_argument(\"--logdirs\", nargs=\"+\", type=str, default=[\n        \"./spurge-baselines\", \"./pascal-baselines\", \"./coco-baselines\"])\n    \n    parser.add_argument(\"--datasets\", nargs=\"+\", type=str, \n                        default=[\"Spurge\", \"Pascal\", \"COCO\", \"Overall\"])\n    \n    parser.add_argument(\"--method-dirs\", nargs=\"+\", type=str, \n                        default=[\"textual-inversion-0.5\", \"textual-inversion-1.0-0.75-0.5-0.25\"])\n    \n    parser.add_argument(\"--baseline-dir\", type=str, default=\"baseline\")\n    \n    parser.add_argument(\"--method-names\", nargs=\"+\", type=str, \n                        default=[\"Model-Based Data Augmentation (k = 1)\", \n                                 \"Model-Based Data Augmentation (k = 4)\"])\n    \n    parser.add_argument(\"--name\", type=str, default=\"stacking-results-bar\")\n    parser.add_argument(\"--rows\", type=int, default=1)\n    parser.add_argument(\"--num-trials\", type=int, default=8)\n    \n    parser.add_argument(\"--no-legend\", action=\"store_true\")\n\n    args = parser.parse_args()\n\n    combined_dataframe = []\n\n    for logdir, dataset in zip(\n            args.logdirs, args.datasets):\n\n        for bname in os.listdir(logdir):\n\n            bpath = os.path.join(logdir, bname)\n\n            if not os.path.isdir(bpath):\n                continue\n\n            files = list(glob.glob(os.path.join(bpath, \"*.csv\")))\n\n            if len(files) == 0:\n                continue\n\n            data = pd.concat([pd.read_csv(x, index_col=0) \n                              for x in files], ignore_index=True)\n\n            data = data[(data[\"metric\"] == \"Accuracy\") & \n                        (data[ \"split\"] == \"Validation\")]\n\n            def select_by_epoch(df):\n                selected_row = df.loc[df[\"value\"].idxmax()]\n                return data[(data[\"epoch\"] == selected_row[\"epoch\"]) & \n                            (data[ \"examples_per_class\"] == \n                            selected_row[\"examples_per_class\"])]\n\n            best = data.groupby([\"examples_per_class\", \"epoch\"])\n            best = best[\"value\"].mean().to_frame('value').reset_index()\n            best = best.groupby(\"examples_per_class\").apply(\n                select_by_epoch\n            )\n\n            best[\"method\"] = bname\n            best[\"dataset\"] = dataset\n            combined_dataframe.append(best)\n\n    matplotlib.rc('font', family='Times New Roman', serif='cm10')\n    matplotlib.rc('mathtext', fontset='cm')\n    
plt.rcParams['text.usetex'] = False\n\n    combined_dataframe = pd.concat(\n        combined_dataframe, ignore_index=True)\n\n    combined_dataframe = pd.concat([combined_dataframe[\n        combined_dataframe['method'] == n] for n in args.method_dirs + [args.baseline_dir]])\n    \n    color_palette = sns.color_palette(n_colors=len(args.method_dirs))\n\n    legend_rows = int(math.ceil(len(args.method_names) / len(args.datasets)))\n    columns = int(math.ceil(len(args.datasets) / args.rows))\n\n    fig, axs = plt.subplots(\n        args.rows, columns,\n        figsize=(6 * columns, 3.5 * args.rows + ((\n            2.0 if legend_rows == 1 else\n            2.5 if legend_rows == 2 else 3\n        ) if not args.no_legend else 1.0)))\n\n    auc_df = []\n    baseline_performance = {dataset: 0 for dataset in args.datasets}\n\n    for dataset in args.datasets:\n        for seed in range(args.num_trials):\n\n            results = combined_dataframe[\n                (combined_dataframe[\"dataset\"] == dataset) & \n                (combined_dataframe[\"method\"] == args.baseline_dir) & \n                (combined_dataframe[\"seed\"] == seed)\n            ]\n\n            cumulative_value = 0.0\n            invalid = False\n\n            for examples_a, examples_b in zip([1, 2, 4, 8], [2, 4, 8, 16]):\n\n                value_a = results[results[\"examples_per_class\"] == examples_a][\"value\"].to_numpy()\n                value_b = results[results[\"examples_per_class\"] == examples_b][\"value\"].to_numpy()\n\n                if value_a.size > 0 and value_b.size > 0:\n                    cumulative_value += ((value_a + value_b) / 2) * (examples_b - examples_a)\n\n                else: invalid = True\n\n            if not invalid:\n\n                baseline_performance[dataset] += \\\n                    cumulative_value[0] / args.num_trials\n\n    for dataset in args.datasets:\n        for method in args.method_dirs:\n            for seed in range(args.num_trials):\n\n                results = combined_dataframe[\n                    (combined_dataframe[\"dataset\"] == dataset) & \n                    (combined_dataframe[\"method\"] == method) & \n                    (combined_dataframe[\"seed\"] == seed)\n                ]\n\n                if dataset not in baseline_performance: continue\n\n                cumulative_value = -baseline_performance[dataset]\n                invalid = False\n\n                for examples_a, examples_b in zip([1, 2, 4, 8], [2, 4, 8, 16]):\n\n                    value_a = results[results[\"examples_per_class\"] == examples_a][\"value\"].to_numpy()\n                    value_b = results[results[\"examples_per_class\"] == examples_b][\"value\"].to_numpy()\n\n                    if value_a.size > 0 and value_b.size > 0:\n                        cumulative_value += ((value_a + value_b) / 2) * (examples_b - examples_a)\n\n                    else: invalid = True\n\n                if not invalid:\n\n                    auc_df.append(\n                        dict(\n                            dataset=dataset,\n                            method=method,\n                            seed=seed,\n                            value=cumulative_value[0],\n                        )\n                    )\n\n    combined_dataframe = pd.DataFrame.from_records(auc_df)\n\n    for dataset in args.datasets:\n\n        df = combined_dataframe.loc[combined_dataframe[\"dataset\"] == dataset]\n        if df.size == 0: continue\n\n        acc_max = df[\"value\"].to_numpy().max()\n        
acc_min = df[\"value\"].to_numpy().min()\n\n        combined_dataframe.loc[\n            combined_dataframe[\"dataset\"] == dataset, \n            \"normalized_value\"\n        ] = (df[\"value\"] - acc_min) / (acc_max - acc_min)\n\n    for i, dataset in enumerate(args.datasets):\n\n        results = combined_dataframe[combined_dataframe[\n            \"dataset\"] == dataset] if dataset != \"Overall\" else combined_dataframe\n\n        axis = sns.barplot(\n            y=\"normalized_value\" if dataset == \"Overall\" else \"value\", \n            x=\"method\", data=results, errorbar=('ci', 68),\n            linewidth=4, palette=color_palette,\n            ax=(axs[i // columns, i % columns] \n                if args.rows > 1 and len(args.datasets) > 1 \n                else axs[i] if len(args.datasets) > 1 else axs))\n\n        if i == 0: handles, labels = axis.get_legend_handles_labels()\n        axis.legend([],[], frameon=False)\n\n        axis.set(xlabel=None)\n        axis.set(ylabel=None)\n\n        axis.spines['right'].set_visible(False)\n        axis.spines['top'].set_visible(False)\n\n        axis.xaxis.set_ticks_position('bottom')\n        axis.yaxis.set_ticks_position('left')\n\n        axis.yaxis.set_tick_params(labelsize=16)\n        axis.xaxis.set_ticklabels([])\n\n        acc_max = results[\"normalized_value\" if dataset == \"Overall\" else \"value\"].to_numpy().max()\n        acc_min = results[\"normalized_value\" if dataset == \"Overall\" else \"value\"].to_numpy().min()\n        axis.set_ylim(max(0, acc_min), acc_max)\n\n        axis.set_ylabel(\"Normalized Score\" if dataset == \"Overall\" \n                        else \"Gained AUC (Val)\", fontsize=24,\n                        fontweight='bold', labelpad=12)\n\n        axis.set_title(dataset, fontsize=24, fontweight='bold', pad=12)\n\n        axis.grid(color='grey', linestyle='dotted', linewidth=2)\n\n    if not args.no_legend:\n\n        legend = fig.legend([x for x in args.method_names],\n                            loc=\"lower center\", prop={'size': 24, 'weight': 'bold'}, \n                            ncol=min(len(args.method_names), len(args.datasets)))\n\n        for i, legend_object in enumerate(legend.legendHandles):\n            legend_object.set_linewidth(4.0)\n            legend_object.set_color(color_palette[i])\n\n    plt.tight_layout(pad=1.0)\n    fig.subplots_adjust(hspace=0.3)\n\n    if not args.no_legend:\n\n        fig.subplots_adjust(bottom=(\n            0.20 if legend_rows == 1 else\n            0.30 if legend_rows == 2 else 0.35\n        ) / args.rows + 0.05)\n\n    plt.savefig(f\"{args.name}.pdf\")\n    plt.savefig(f\"{args.name}.png\")"
  },
  {
    "path": "plot_stratify.py",
    "content": "import matplotlib.pyplot as plt\nimport matplotlib\nimport pandas as pd\nimport seaborn as sns\n\nimport os\nimport glob\nimport argparse\n\n\ndef pretty(text):\n    \"\"\"Convert a string into a consistent format for\n    presentation in a matplotlib pyplot:\n    this version looks like: One Two Three Four\n    \"\"\"\n\n    text = text.replace(\"_\", \" \")\n    text = text.replace(\"-\", \" \")\n    text = text.replace(\"/\", \" \")\n    text = text.strip()\n    prev_c = None\n    out_str = []\n    for c in text:\n        if prev_c is not None and \\\n                prev_c.islower() and c.isupper():\n            out_str.append(\" \")\n            prev_c = \" \"\n        if prev_c is None or prev_c == \" \":\n            c = c.upper()\n        out_str.append(c)\n        prev_c = c\n    return \"\".join(out_str)\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\"Few-Shot Baseline\")\n\n    parser.add_argument(\"--logdirs\", nargs=\"+\", type=str, default=[\n        \"./pascal-baselines\"])\n    \n    parser.add_argument(\"--datasets\", nargs=\"+\", type=str, \n                        default=[\"Pascal\"])\n    \n    parser.add_argument(\"--method-names\", nargs=\"+\", type=str, \n                        default=[\"baseline\", \"real-guidance-0.5\", \"textual-inversion-0.5\"])\n    \n    parser.add_argument(\"--name\", type=str, default=\"stratification\")\n\n    args = parser.parse_args()\n\n    combined_dataframe = []\n\n    for dataset, logdir in zip(\n            args.datasets, args.logdirs):\n\n        for bname in os.listdir(logdir):\n\n            bpath = os.path.join(logdir, bname)\n\n            if not os.path.isdir(bpath):\n                continue\n\n            files = list(glob.glob(os.path.join(bpath, \"*.csv\")))\n\n            if len(files) == 0:\n                continue\n\n            data = pd.concat([pd.read_csv(x, index_col=0) \n                              for x in files], ignore_index=True)\n\n            data = data[(data[\"metric\"].str.contains(\"Accuracy \")) & \n                        (data[ \"split\"] == \"Validation\")]\n\n            def select_by_epoch(df):\n                selected_row = df.loc[df[\"value\"].idxmax()]\n                return data[(data[\"epoch\"] == selected_row[\"epoch\"]) & \n                            (data[ \"examples_per_class\"] == \n                            selected_row[\"examples_per_class\"])]\n\n            best = data.groupby([\"examples_per_class\", \"epoch\"])\n            best = best[\"value\"].mean().to_frame('value').reset_index()\n            best = best.groupby(\"examples_per_class\").apply(\n                select_by_epoch\n            )\n\n            best[\"method\"] = bname\n            best[\"dataset\"] = dataset\n\n            combined_dataframe.append(best)\n\n    matplotlib.rc('font', family='Times New Roman', serif='cm10')\n    matplotlib.rc('mathtext', fontset='cm')\n    plt.rcParams['text.usetex'] = False\n\n    combined_dataframe = pd.concat(\n        combined_dataframe, ignore_index=True)\n\n    combined_dataframe = pd.concat([combined_dataframe[\n        combined_dataframe['method'] == n] for n in args.method_names])\n\n    combined_dataframe[\"class_name\"] = \\\n        combined_dataframe[\"metric\"].str.replace(\"Accuracy \", \"\")\n    \n    color_palette = sns.color_palette(n_colors=len(args.method_names))\n\n    for i, dataset in enumerate(args.datasets):\n\n        results = combined_dataframe[combined_dataframe[\"dataset\"] == dataset]\n\n        for j, 
examples_per_class in enumerate(\n                results[\"examples_per_class\"].unique()):\n\n            results2 = results[results[\"examples_per_class\"] == examples_per_class]\n            \n            fig, axs = plt.subplots(1, 1, figsize=(20, 6))\n\n            axis = sns.barplot(x=\"class_name\", y=\"value\", hue=\"method\", \n                               data=results2, errorbar=('ci', 68),\n                               linewidth=4, palette=color_palette,\n                               ax=axs)\n\n            if i == 0: handles, labels = axis.get_legend_handles_labels()\n            axis.legend([],[], frameon=False)\n\n            axis.set(xlabel=None)\n            axis.set(ylabel=None)\n\n            axis.spines['right'].set_visible(False)\n            axis.spines['top'].set_visible(False)\n\n            axis.xaxis.set_ticks_position('bottom')\n            axis.yaxis.set_ticks_position('left')\n\n            axis.yaxis.set_tick_params(labelsize=16)\n            axis.xaxis.set_tick_params(labelsize=16, labelrotation=45)\n\n            axis.set_ylabel(\"Accuracy (Val)\", fontsize=24,\n                            fontweight='bold', labelpad=12)\n\n            axis.set_title(f\"Dataset = {dataset} (Examples Per Class = {examples_per_class})\",\n                           fontsize=24, fontweight='bold', pad=12)\n\n            axis.grid(color='grey', linestyle='dotted', linewidth=2)\n\n            legend = fig.legend(handles, [pretty(x) for x in args.method_names],\n                                loc=\"lower center\", ncol=len(args.method_names),\n                                prop={'size': 24, 'weight': 'bold'})\n\n            for i, legend_object in enumerate(legend.legendHandles):\n                legend_object.set_linewidth(4.0)\n                legend_object.set_color(color_palette[i])\n\n            plt.tight_layout(pad=1.0)\n            fig.subplots_adjust(bottom=0.35)\n\n            plt.savefig(f\"{args.name}-{dataset}-{examples_per_class}-barplot.pdf\")\n            plt.savefig(f\"{args.name}-{dataset}-{examples_per_class}-barplot.png\")"
  },
  {
    "path": "scripts/baseline/launch_baseline_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py \\\n--logdir ./coco-baselines/baseline \\\n--dataset coco --num-synthetic 0 \\\n--synthetic-probability 0.0 --num-trials 8 \\\n--examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/baseline/launch_baseline_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=imagenet\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py \\\n--logdir ./imagenet-baselines/baseline \\\n--dataset imagenet --num-synthetic 0 \\\n--synthetic-probability 0.0 --num-trials 8 \\\n--examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/baseline/launch_baseline_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py \\\n--logdir ./pascal-baselines/baseline \\\n--dataset pascal --num-synthetic 0 \\\n--synthetic-probability 0.0 --num-trials 8 \\\n--examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/baseline/launch_baseline_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spurge\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py \\\n--logdir ./spurge-baselines/baseline \\\n--dataset spurge --num-synthetic 0 \\\n--synthetic-probability 0.0 --num-trials 8 \\\n--examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/baseline_randaugment/launch_baseline_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py \\\n--logdir ./randaugment-coco-baselines/baseline \\\n--dataset coco --num-synthetic 0 \\\n--synthetic-probability 0.0 --num-trials 8 \\\n--examples-per-class 1 2 4 8 16 --use-randaugment"
  },
  {
    "path": "scripts/baseline_randaugment/launch_baseline_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=imagenet\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py \\\n--logdir ./randaugment-imagenet-baselines/baseline \\\n--dataset imagenet --num-synthetic 0 \\\n--synthetic-probability 0.0 --num-trials 8 \\\n--examples-per-class 1 2 4 8 16 --use-randaugment"
  },
  {
    "path": "scripts/baseline_randaugment/launch_baseline_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py \\\n--logdir ./randaugment-pascal-baselines/baseline \\\n--dataset pascal --num-synthetic 0 \\\n--synthetic-probability 0.0 --num-trials 8 \\\n--examples-per-class 1 2 4 8 16 --use-randaugment"
  },
  {
    "path": "scripts/baseline_randaugment/launch_baseline_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spurge\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py \\\n--logdir ./randaugment-spurge-baselines/baseline \\\n--dataset spurge --num-synthetic 0 \\\n--synthetic-probability 0.0 --num-trials 8 \\\n--examples-per-class 1 2 4 8 16 --use-randaugment"
  },
  {
    "path": "scripts/cutmix_ablation/launch_baseline_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py \\\n--logdir ./cutmix-pascal-baselines/baseline \\\n--dataset pascal --num-synthetic 0 \\\n--synthetic-probability 0.0 --num-trials 8 \\\n--examples-per-class 1 2 4 8 16 \\\n--use-cutmix"
  },
  {
    "path": "scripts/cutmix_ablation/launch_real_guidance=0.5_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir cutmix-pascal-baselines/real-guidance-0.5-cap \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5-cap/cutmix-{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--use-cutmix"
  },
  {
    "path": "scripts/cutmix_ablation/launch_textual_inversion=1.0-0.75-0.5-0.25_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir cutmix-pascal-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-1.0-0.75-0.5-0.25/cutmix-{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--use-cutmix"
  },
  {
    "path": "scripts/deit_backbone/launch_baseline_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py \\\n--logdir ./deit-pascal-baselines/baseline \\\n--dataset pascal --num-synthetic 0 \\\n--synthetic-probability 0.0 --num-trials 8 \\\n--examples-per-class 1 2 4 8 16 \\\n--classifier-backbone deit --image-size 224"
  },
  {
    "path": "scripts/deit_backbone/launch_real_guidance=0.5_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir deit-pascal-baselines/real-guidance-0.5-cap \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5-cap/deit-{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--classifier-backbone deit --image-size 224"
  },
  {
    "path": "scripts/deit_backbone/launch_textual_inversion=1.0-0.75-0.5-0.25_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir deit-pascal-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-1.0-0.75-0.5-0.25/deit-{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--classifier-backbone deit --image-size 224"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part0.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'accordion' 'airplanes' 'anchor' 'ant' 'background google'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part1.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'barrel' 'bass' 'beaver' 'binocular' 'bonsai'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part10.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'ibis' 'inline skate' 'joshua tree' 'kangaroo' 'ketch'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part11.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'lamp' 'laptop' 'leopards' 'llama' 'lobster'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part12.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'lotus' 'mandolin' 'mayfly' 'menorah' 'metronome'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part13.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'minaret' 'motorbikes' 'nautilus' 'octopus' 'okapi'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part14.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'pagoda' 'panda' 'pigeon' 'pizza' 'platypus'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part15.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'pyramid' 'revolver' 'rhino' 'rooster' 'saxophone'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part16.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'schooner' 'scissors' 'scorpion' 'sea horse' 'snoopy'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part17.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'soccer ball' 'stapler' 'starfish' 'stegosaurus' 'stop sign'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part18.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'strawberry' 'sunflower' 'tick' 'trilobite' 'umbrella'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part19.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'watch' 'water lilly' 'wheelchair' 'wild cat' 'windsor chair'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part2.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'brain' 'brontosaurus' 'buddha' 'butterfly' 'camera'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part20.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'wrench' 'yin yang'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part3.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'cannon' 'car side' 'ceiling fan' 'cellphone' 'chair'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part4.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'chandelier' 'cougar body' 'cougar face' 'crab' 'crayfish'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part5.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'crocodile' 'crocodile head' 'cup' 'dalmatian' 'dollar bill'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part6.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'dolphin' 'dragonfly' 'electric guitar' 'elephant' 'emu'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part7.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'euphonium' 'ewer' 'faces' 'faces easy' 'ferry'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part8.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'flamingo' 'flamingo head' 'garfield' 'gerenuk' 'gramophone'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_caltech101_part9.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'grand piano' 'hawksbill' 'headphone' 'hedgehog' 'helicopter'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part0.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'person' 'bicycle' 'car' 'motorcycle' 'airplane'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part1.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'bus' 'train' 'truck' 'boat' 'traffic light'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part10.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'broccoli' 'carrot' 'hot dog' 'pizza' 'donut'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part11.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'cake' 'chair' 'couch' 'potted plant' 'bed'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part12.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'dining table' 'toilet' 'tv' 'laptop' 'mouse'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part13.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'remote' 'keyboard' 'cell phone' 'microwave' 'oven'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part14.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'toaster' 'sink' 'refrigerator' 'book' 'clock'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part15.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'vase' 'scissors' 'teddy bear' 'hair drier' 'toothbrush'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part2.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'fire hydrant' 'stop sign' 'parking meter' 'bench' 'bird'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part3.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'cat' 'dog' 'horse' 'sheep' 'cow'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part4.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'elephant' 'bear' 'zebra' 'giraffe' 'backpack'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part5.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'umbrella' 'handbag' 'tie' 'suitcase' 'frisbee'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part6.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'skis' 'snowboard' 'sports ball' 'kite' 'baseball bat'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part7.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'baseball glove' 'skateboard' 'surfboard' 'tennis racket' 'bottle'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part8.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'wine glass' 'cup' 'fork' 'knife' 'spoon'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_coco_part9.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'bowl' 'banana' 'apple' 'sandwich' 'orange'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part0.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'pink primrose' 'hard-leaved pocket orchid' 'canterbury bells' 'sweet pea' 'english marigold'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part1.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'tiger lily' 'moon orchid' 'bird of paradise' 'monkshood' 'globe thistle'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part10.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'petunia' 'wild pansy' 'primula' 'sunflower' 'pelargonium'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part11.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'bishop of llandaff' 'gaura' 'geranium' 'orange dahlia' 'pink-yellow dahlia'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part12.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'cautleya spicata' 'japanese anemone' 'black-eyed susan' 'silverbush' 'californian poppy'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part13.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'osteospermum' 'spring crocus' 'bearded iris' 'windflower' 'tree poppy'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part14.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'gazania' 'azalea' 'water lily' 'rose' 'thorn apple'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part15.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'morning glory' 'passion flower' 'lotus' 'toad lily' 'anthurium'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part16.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'frangipani' 'clematis' 'hibiscus' 'columbine' 'desert-rose'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part17.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'tree mallow' 'magnolia' 'cyclamen ' 'watercress' 'canna lily'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part18.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'hippeastrum ' 'bee balm' 'ball moss' 'foxglove' 'bougainvillea'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part19.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'camellia' 'mallow' 'mexican petunia' 'bromelia' 'blanket flower'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part2.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'snapdragon' 'colt's foot' 'king protea' 'spear thistle' 'yellow iris'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part20.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'trumpet creeper' 'blackberry lily'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part3.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'globe-flower' 'purple coneflower' 'peruvian lily' 'balloon flower' 'giant white arum lily'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part4.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'fire lily' 'pincushion flower' 'fritillary' 'red ginger' 'grape hyacinth'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part5.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'corn poppy' 'prince of wales feathers' 'stemless gentian' 'artichoke' 'sweet william'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part6.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'carnation' 'garden phlox' 'love in the mist' 'mexican aster' 'alpine sea holly'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part7.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'ruby-lipped cattleya' 'cape flower' 'great masterwort' 'siam tulip' 'lenten rose'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part8.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'barbeton daisy' 'daffodil' 'sword lily' 'poinsettia' 'bolero deep blue'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_flowers102_part9.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'wallflower' 'marigold' 'buttercup' 'oxeye daisy' 'common dandelion'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part0.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'brassiere' 'curly-coated retriever' 'kuvasz' 'beagle' 'perfume'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part1.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'speedboat' 'ibex' 'volleyball' 'crash helmet' 'Cardigan'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part10.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'bolo tie' 'coral reef' 'French bulldog' 'meerkat' 'croquet ball'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part11.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'restaurant' 'dingo' 'thatch' 'traffic light' 'porcupine'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part12.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'great white shark' 'altar' 'cello' 'valley' 'black swan'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part13.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'hook' 'chambered nautilus' 'oil filter' 'patas' 'gong'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part14.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'pay-phone' 'moped' 'white wolf' 'dishwasher' 'garden spider'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part15.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'English setter' 'mushroom' 'dhole' 'Siamese cat' 'isopod'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part16.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'lifeboat' 'bikini' 'maraca' 'green lizard' 'swing'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part17.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'Lakeland terrier' 'hay' 'seashore' 'house finch' 'banjo'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part18.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'ringneck snake' 'harp' 'cinema' 'Sussex spaniel' 'frying pan'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part19.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'Labrador retriever' 'fire engine' 'crate' 'timber wolf' 'sloth bear'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part2.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'oboe' 'mousetrap' 'convertible' 'coho' 'sea cucumber'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part3.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'sulphur butterfly' 'ptarmigan' 'silky terrier' 'grocery store' 'titi'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part4.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'bell pepper' 'alp' 'spatula' 'police van' 'military uniform'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part5.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'confectionery' 'European fire salamander' 'hair slide' 'terrapin' 'microwave'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part6.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'guillotine' 'pelican' 'Chesapeake Bay retriever' 'hen' 'butcher shop'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part7.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'bluetick' 'chime' 'rugby ball' 'giant panda' 'toy terrier'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part8.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'lotion' 'hot pot' 'water bottle' 'peacock' 'wine bottle'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_imagenet_part9.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'paddlewheel' 'bannister' 'banana' 'wombat' 'trilobite'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_pascal_part0.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'airplane' 'bicycle' 'bird' 'boat' 'bottle'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_pascal_part1.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'bus' 'car' 'cat' 'chair' 'cow'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_pascal_part2.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'dining table' 'dog' 'horse' 'motorcycle' 'person'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_pascal_part3.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'potted plant' 'sheep' 'sofa' 'train' 'television'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/erase_spurge_part0.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in 'no spurge' 'leafy spurge'; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n"
  },
  {
    "path": "scripts/erase_classes/generate_scripts.py",
    "content": "from semantic_aug.datasets.coco import COCODataset\nfrom semantic_aug.datasets.spurge import SpurgeDataset\nfrom semantic_aug.datasets.imagenet import ImageNetDataset\nfrom semantic_aug.datasets.pascal import PASCALDataset\nfrom semantic_aug.datasets.caltech101 import CalTech101Dataset\nfrom semantic_aug.datasets.flowers102 import Flowers102Dataset\nimport numpy as np\nimport os\n\n\nSCRIPT_TEMPLATE = \"\"\"#!/bin/bash\n#SBATCH --job-name=erase\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate ldm\ncd ~/spurge/semantic-aug/stable-diffusion\n\nfor CLASS_NAME in {}; do\n\npython train-scripts/train-esd.py --prompt \"$CLASS_NAME\" --train_method 'full' --devices '0,0'\n\ndone\n\"\"\"\n\n\nPART_SIZE = 5\n\n\nscript_dir = os.path.dirname(os.path.realpath(__file__))\n\n\nif __name__ == \"__main__\":\n\n    for class_names, dataset_name in [\n            (COCODataset().class_names, \"coco\"), \n            (PASCALDataset().class_names, \"pascal\"), \n            (SpurgeDataset().class_names, \"spurge\"), \n            (CalTech101Dataset().class_names, \"caltech101\"), \n            (Flowers102Dataset().class_names, \"flowers102\"),\n            (ImageNetDataset().class_names, \"imagenet\")]:\n\n        num_parts = int(np.ceil(len(class_names) / PART_SIZE))\n\n        for i in range(num_parts):\n\n            part_names = class_names[i*PART_SIZE:(i + 1)*PART_SIZE]\n\n            with open(os.path.join(\n                script_dir, f\"erase_{dataset_name}_part{i}.sh\"), \"w\") as f:\n\n                f.write(SCRIPT_TEMPLATE.format(\n                    \" \".join([f\"'{x}'\" for x in part_names])))"
  },
  {
    "path": "scripts/fine_tuning/fine_tune_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=f-tune\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8,matrix-0-38,matrix-1-18,matrix-1-20\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython fine_tune.py --dataset=coco --output_dir=./ \\\n--pretrained_model_name_or_path=\"CompVis/stable-diffusion-v1-4\" \\\n--resolution=512 --train_batch_size=4 --lr_warmup_steps=0 \\\n--gradient_accumulation_steps=1 --max_train_steps=1000 \\\n--learning_rate=5.0e-04 --scale_lr --lr_scheduler=\"constant\" \\\n--mixed_precision=fp16 --revision=fp16 --gradient_checkpointing \\\n--only_save_embeds --num-trials 8 --examples-per-class 1 2 4 8 16 "
  },
  {
    "path": "scripts/fine_tuning/fine_tune_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=f-tune\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8,matrix-0-38,matrix-1-18,matrix-1-20\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython fine_tune.py --dataset=imagenet --output_dir=./ \\\n--pretrained_model_name_or_path=\"CompVis/stable-diffusion-v1-4\" \\\n--resolution=512 --train_batch_size=4 --lr_warmup_steps=0 \\\n--gradient_accumulation_steps=1 --max_train_steps=1000 \\\n--learning_rate=5.0e-04 --scale_lr --lr_scheduler=\"constant\" \\\n--mixed_precision=fp16 --revision=fp16 --gradient_checkpointing \\\n--only_save_embeds --num-trials 8 --examples-per-class 1 2 4 8 16 "
  },
  {
    "path": "scripts/fine_tuning/fine_tune_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=f-tune\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8,matrix-0-38\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython fine_tune.py --dataset=pascal --output_dir=./ \\\n--pretrained_model_name_or_path=\"CompVis/stable-diffusion-v1-4\" \\\n--resolution=512 --train_batch_size=4 --lr_warmup_steps=0 \\\n--gradient_accumulation_steps=1 --max_train_steps=1000 \\\n--learning_rate=5.0e-04 --scale_lr --lr_scheduler=\"constant\" \\\n--mixed_precision=fp16 --revision=fp16 --gradient_checkpointing \\\n--only_save_embeds --num-trials 8 --examples-per-class 1 2 4 8 16 "
  },
  {
    "path": "scripts/fine_tuning/fine_tune_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=f-tune\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8,matrix-0-38\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython fine_tune.py --dataset=spurge --output_dir=./ \\\n--pretrained_model_name_or_path=\"CompVis/stable-diffusion-v1-4\" \\\n--resolution=512 --train_batch_size=4 --lr_warmup_steps=0 \\\n--gradient_accumulation_steps=1 --max_train_steps=1000 \\\n--learning_rate=5.0e-04 --scale_lr --lr_scheduler=\"constant\" \\\n--mixed_precision=fp16 --revision=fp16 --gradient_checkpointing \\\n--only_save_embeds --num-trials 8 --examples-per-class 1 2 4 8 16 "
  },
  {
    "path": "scripts/fine_tuning_erasure/fine_tune_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=f-tune\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8,matrix-0-38,matrix-1-18,matrix-1-20\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython fine_tune.py --dataset=coco --output_dir=./erasure-tokens \\\n--pretrained_model_name_or_path=\"CompVis/stable-diffusion-v1-4\" \\\n--resolution=512 --train_batch_size=4 --lr_warmup_steps=0 \\\n--gradient_accumulation_steps=1 --max_train_steps=1000 \\\n--learning_rate=5.0e-04 --scale_lr --lr_scheduler=\"constant\" \\\n--mixed_precision=fp16 --revision=fp16 --gradient_checkpointing \\\n--only_save_embeds --num-trials 8 --examples-per-class 1 2 4 8 16  --erase-concepts"
  },
  {
    "path": "scripts/fine_tuning_erasure/fine_tune_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=f-tune\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8,matrix-0-38,matrix-1-18,matrix-1-20\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython fine_tune.py --dataset=imagenet --output_dir=./erasure-tokens \\\n--pretrained_model_name_or_path=\"CompVis/stable-diffusion-v1-4\" \\\n--resolution=512 --train_batch_size=4 --lr_warmup_steps=0 \\\n--gradient_accumulation_steps=1 --max_train_steps=1000 \\\n--learning_rate=5.0e-04 --scale_lr --lr_scheduler=\"constant\" \\\n--mixed_precision=fp16 --revision=fp16 --gradient_checkpointing \\\n--only_save_embeds --num-trials 8 --examples-per-class 1 2 4 8 16  --erase-concepts"
  },
  {
    "path": "scripts/fine_tuning_erasure/fine_tune_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=f-tune\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8,matrix-0-38\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython fine_tune.py --dataset=pascal --output_dir=./erasure-tokens \\\n--pretrained_model_name_or_path=\"CompVis/stable-diffusion-v1-4\" \\\n--resolution=512 --train_batch_size=4 --lr_warmup_steps=0 \\\n--gradient_accumulation_steps=1 --max_train_steps=1000 \\\n--learning_rate=5.0e-04 --scale_lr --lr_scheduler=\"constant\" \\\n--mixed_precision=fp16 --revision=fp16 --gradient_checkpointing \\\n--only_save_embeds --num-trials 8 --examples-per-class 1 2 4 8 16  --erase-concepts"
  },
  {
    "path": "scripts/fine_tuning_erasure/fine_tune_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=f-tune\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8,matrix-0-38\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython fine_tune.py --dataset=spurge --output_dir=./erasure-tokens \\\n--pretrained_model_name_or_path=\"CompVis/stable-diffusion-v1-4\" \\\n--resolution=512 --train_batch_size=4 --lr_warmup_steps=0 \\\n--gradient_accumulation_steps=1 --max_train_steps=1000 \\\n--learning_rate=5.0e-04 --scale_lr --lr_scheduler=\"constant\" \\\n--mixed_precision=fp16 --revision=fp16 --gradient_checkpointing \\\n--only_save_embeds --num-trials 8 --examples-per-class 1 2 4 8 16 --erase-concepts"
  },
  {
    "path": "scripts/masking/launch_real_guidance=0-0.5_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir coco-baselines/real-guidance-mask-0-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-mask-0-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo\" \\\n--aug real-guidance \\\n--guidance-scale 7.5 \\\n--strength 0.5 \\\n--mask 1 \\\n--inverted 1 \\\n--probs 1 \\\n--compose sequential --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/masking/launch_real_guidance=0-0.5_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/real-guidance-mask-0-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-mask-0-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo\" \\\n--aug real-guidance \\\n--guidance-scale 7.5 \\\n--strength 0.5 \\\n--mask 1 \\\n--inverted 1 \\\n--probs 1 \\\n--compose sequential --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/masking/launch_real_guidance=0.5-0_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir coco-baselines/real-guidance-mask-0.5-0 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-mask-0.5-0/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo\" \\\n--aug real-guidance \\\n--guidance-scale 7.5 \\\n--strength 0.5 \\\n--mask 1 \\\n--inverted 0 \\\n--probs 1 \\\n--compose sequential --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/masking/launch_real_guidance=0.5-0_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/real-guidance-mask-0.5-0 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-mask-0.5-0/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo\" \\\n--aug real-guidance \\\n--guidance-scale 7.5 \\\n--strength 0.5 \\\n--mask 1 \\\n--inverted 0 \\\n--probs 1 \\\n--compose sequential --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/masking/launch_textual_inversion=0-0.5_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir coco-baselines/textual-inversion-mask-0-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-mask-0-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo of a {name}\" \\\n--aug textual-inversion \\\n--guidance-scale 7.5 \\\n--strength 0.5 \\\n--mask 1 \\\n--inverted 1 \\\n--probs 1 \\\n--compose sequential --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/masking/launch_textual_inversion=0-0.5_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/textual-inversion-mask-0-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-mask-0-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion \\\n--guidance-scale 7.5 \\\n--strength 0.5 \\\n--mask 1 \\\n--inverted 1 \\\n--probs 1 \\\n--compose sequential --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/masking/launch_textual_inversion=0.5-0_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir coco-baselines/textual-inversion-mask-0.5-0 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-mask-0.5-0/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo of a {name}\" \\\n--aug textual-inversion \\\n--guidance-scale 7.5 \\\n--strength 0.5 \\\n--mask 1 \\\n--inverted 0 \\\n--probs 1 \\\n--compose sequential --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/masking/launch_textual_inversion=0.5-0_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/textual-inversion-mask-0.5-0 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-mask-0.5-0/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion \\\n--guidance-scale 7.5 \\\n--strength 0.5 \\\n--mask 1 \\\n--inverted 0 \\\n--probs 1 \\\n--compose sequential --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/num_synthetic/launch_real_guidance=0.5_pascal_class_agnostic-20.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir num-synthetic-pascal-baselines/real-guidance-0.5-num-synthetic-20 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5-num-synthetic-20/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 20 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/num_synthetic/launch_real_guidance=0.5_pascal_class_agnostic-5.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir num-synthetic-pascal-baselines/real-guidance-0.5-num-synthetic-5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5-num-synthetic-5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 5 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/num_synthetic/launch_textual_inversion=1.0-0.75-0.5-0.25_pascal-20.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir num-synthetic-pascal-baselines/textual-inversion-1.0-0.75-0.5-0.25-num-synthetic-20 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-1.0-0.75-0.5-0.25-num-synthetic-20/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 20 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/num_synthetic/launch_textual_inversion=1.0-0.75-0.5-0.25_pascal-5.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir num-synthetic-pascal-baselines/textual-inversion-1.0-0.75-0.5-0.25-num-synthetic-5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-1.0-0.75-0.5-0.25-num-synthetic-5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 5 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/real_guidance/launch_real_guidance=0.5_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir coco-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo of a {name}\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/real_guidance/launch_real_guidance=0.5_coco_class_agnostic.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir coco-baselines/real-guidance-0.5-cap \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5-cap/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/real_guidance/launch_real_guidance=0.5_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=imagenet\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir imagenet-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset imagenet --prompt \"a photo of a {name}\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/real_guidance/launch_real_guidance=0.5_imagenet_class_agnostic.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=imagenet\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir imagenet-baselines/real-guidance-0.5-cap \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5-cap/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset imagenet --prompt \"a photo\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/real_guidance/launch_real_guidance=0.5_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/real_guidance/launch_real_guidance=0.5_pascal_class_agnostic.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/real-guidance-0.5-cap \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5-cap/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/real_guidance/launch_real_guidance=0.5_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spurge\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir spurge-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset spurge --prompt \"a woodland seen from a drone\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 50 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/real_guidance/launch_real_guidance=0.5_spurge_class_agnostic.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spurge\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir spurge-baselines/real-guidance-0.5-cap \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5-cap/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset spurge --prompt \"a photo\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 50 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/real_guidance_erasure/launch_real_guidance=0.5_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir erasure-coco-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/erasure/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo of a {name}\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/"
  },
  {
    "path": "scripts/real_guidance_erasure/launch_real_guidance=0.5_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=imagenet\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir erasure-imagenet-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/erasure/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset imagenet --prompt \"a photo of a {name}\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/"
  },
  {
    "path": "scripts/real_guidance_erasure/launch_real_guidance=0.5_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir erasure-pascal-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/erasure/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/"
  },
  {
    "path": "scripts/real_guidance_erasure/launch_real_guidance=0.5_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spurge\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir erasure-spurge-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/erasure/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset spurge --prompt \"a woodland seen from a drone\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 50 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/"
  },
  {
    "path": "scripts/real_guidance_randaugment/launch_real_guidance=0.5_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir randaugment-coco-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/randaugment/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo of a {name}\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--use-randaugment --erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/"
  },
  {
    "path": "scripts/real_guidance_randaugment/launch_real_guidance=0.5_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=imagenet\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir randaugment-imagenet-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/randaugment/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset imagenet --prompt \"a photo of a {name}\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--use-randaugment --erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/"
  },
  {
    "path": "scripts/real_guidance_randaugment/launch_real_guidance=0.5_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir randaugment-pascal-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/randaugment/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--use-randaugment --erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/"
  },
  {
    "path": "scripts/real_guidance_randaugment/launch_real_guidance=0.5_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spurge\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir randaugment-spurge-baselines/real-guidance-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/randaugment/\\\nreal-guidance-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset spurge --prompt \"a woodland seen from a drone\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--use-randaugment --erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/"
  },
  {
    "path": "scripts/stacking/launch_textual_inversion=1.0-0.75-0.5-0.25_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir coco-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/stacking/launch_textual_inversion=1.0-0.75-0.5-0.25_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=imagenet\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir imagenet-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset imagenet --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/stacking/launch_textual_inversion=1.0-0.75-0.5-0.25_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/stacking/launch_textual_inversion=1.0-0.75-0.5-0.25_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spurge\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir spurge-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset spurge --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 50 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/stacking_erasure/launch_textual_inversion=1.0-0.75-0.5-0.25_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir erasure-coco-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/erasure/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/ \\\n--embed-path \"erasure-tokens/{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt\""
  },
  {
    "path": "scripts/stacking_erasure/launch_textual_inversion=1.0-0.75-0.5-0.25_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=imagenet\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir erasure-imagenet-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/erasure/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset imagenet --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/ \\\n--embed-path \"erasure-tokens/{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt\""
  },
  {
    "path": "scripts/stacking_erasure/launch_textual_inversion=1.0-0.75-0.5-0.25_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir erasure-pascal-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/erasure/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/ \\\n--embed-path \"erasure-tokens/{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt\""
  },
  {
    "path": "scripts/stacking_erasure/launch_textual_inversion=1.0-0.75-0.5-0.25_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spurge\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir erasure-spurge-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/erasure/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset spurge --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 50 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/ \\\n--embed-path \"erasure-tokens/{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt\""
  },
  {
    "path": "scripts/stacking_randaugment/launch_textual_inversion=1.0-0.75-0.5-0.25_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir randaugment-coco-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/randaugment/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--use-randaugment --erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/ \\\n--embed-path \"erasure-tokens/{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt\""
  },
  {
    "path": "scripts/stacking_randaugment/launch_textual_inversion=1.0-0.75-0.5-0.25_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=imagenet\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir randaugment-imagenet-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/randaugment/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset imagenet --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--use-randaugment --erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/ \\\n--embed-path \"erasure-tokens/{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt\""
  },
  {
    "path": "scripts/stacking_randaugment/launch_textual_inversion=1.0-0.75-0.5-0.25_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir randaugment-pascal-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/randaugment/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--use-randaugment --erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/ \\\n--embed-path \"erasure-tokens/{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt\""
  },
  {
    "path": "scripts/stacking_randaugment/launch_textual_inversion=1.0-0.75-0.5-0.25_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spurge\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir randaugment-spurge-baselines/textual-inversion-1.0-0.75-0.5-0.25 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/randaugment/\\\ntextual-inversion-1.0-0.75-0.5-0.25/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset spurge --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 50 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16 \\\n--use-randaugment --erasure-ckpt-path /projects/rsalakhugroup/btrabucc/esd-models/ \\\n--embed-path \"erasure-tokens/{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt\""
  },
  {
    "path": "scripts/synthetic_prob/launch_real_guidance=0.5_pascal_class_agnostic-0.3.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/real-guidance-0.5-cap+0.3 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5-cap+0.3/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.3 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/synthetic_prob/launch_real_guidance=0.5_pascal_class_agnostic-0.7.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/real-guidance-0.5-cap+0.7 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\nreal-guidance-0.5-cap+0.7/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo\" \\\n--aug real-guidance --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.7 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/synthetic_prob/launch_textual_inversion=1.0-0.75-0.5-0.25_pascal-0.3.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/textual-inversion-1.0-0.75-0.5-0.25+0.3 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-1.0-0.75-0.5-0.25+0.3/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.3 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/synthetic_prob/launch_textual_inversion=1.0-0.75-0.5-0.25_pascal-0.7.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/textual-inversion-1.0-0.75-0.5-0.25+0.7 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-1.0-0.75-0.5-0.25+0.7/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion textual-inversion textual-inversion textual-inversion \\\n--guidance-scale 7.5 7.5 7.5 7.5 \\\n--strength 1.0 0.75 0.5 0.25 \\\n--mask 0 0 0 0 \\\n--inverted 0 0 0 0 \\\n--probs 0.25 0.25 0.25 0.25 \\\n--compose parallel --num-synthetic 10 --synthetic-probability 0.7 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/textual_inversion/launch_textual_inversion=0.5_coco.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=coco\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir coco-baselines/textual-inversion-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset coco --prompt \"a photo of a {name}\" \\\n--aug textual-inversion --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/textual_inversion/launch_textual_inversion=0.5_imagenet.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=imagenet\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir imagenet-baselines/textual-inversion-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset imagenet --prompt \"a photo of a {name}\" \\\n--aug textual-inversion --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/textual_inversion/launch_textual_inversion=0.5_pascal.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=pascal\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir pascal-baselines/textual-inversion-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset pascal --prompt \"a photo of a {name}\" \\\n--aug textual-inversion --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 10 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "scripts/textual_inversion/launch_textual_inversion=0.5_spurge.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spurge\n#SBATCH --exclude=matrix-1-12,matrix-0-24,matrix-1-4,matrix-2-13,matrix-1-8\n#SBATCH --time=72:00:00\n#SBATCH --nodes=1\n#SBATCH --partition=russ_reserved\n#SBATCH --gres=gpu:1\n#SBATCH --cpus-per-gpu=8\n#SBATCH --mem=32g\n#SBATCH --array=0-39\n \nsource ~/anaconda3/etc/profile.d/conda.sh\nconda activate semantic-aug\ncd ~/spurge/semantic-aug\n\nRANK=$SLURM_ARRAY_TASK_ID WORLD_SIZE=$SLURM_ARRAY_TASK_COUNT \\\npython train_classifier.py --logdir spurge-baselines/textual-inversion-0.5 \\\n--synthetic-dir \"/projects/rsalakhugroup/btrabucc/aug/\\\ntextual-inversion-0.5/{dataset}-{seed}-{examples_per_class}\" \\\n--dataset spurge --prompt \"a photo of a {name}\" \\\n--aug textual-inversion --guidance-scale 7.5 \\\n--strength 0.5 --mask 0 --inverted 0 \\\n--num-synthetic 50 --synthetic-probability 0.5 \\\n--num-trials 8 --examples-per-class 1 2 4 8 16"
  },
  {
    "path": "semantic_aug/__init__.py",
    "content": ""
  },
  {
    "path": "semantic_aug/augmentations/__init__.py",
    "content": ""
  },
  {
    "path": "semantic_aug/augmentations/compose.py",
    "content": "from semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom diffusers import StableDiffusionImg2ImgPipeline\nfrom diffusers import StableDiffusionInpaintPipeline\nfrom diffusers.utils import logging\nfrom PIL import Image\n\nfrom typing import List, Union, Any, Tuple\nfrom torch import autocast\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ComposeSequential(GenerativeAugmentation):\n\n    def __init__(self, augs: List[GenerativeAugmentation], \n                 probs: List[float] = None):\n\n        super(ComposeSequential, self).__init__()\n\n        self.augs = augs\n        self.probs = probs if probs is not None \\\n            else [1.0 for _ in augs]\n\n    def forward(self, image: Image.Image, label: int, \n                metadata: dict) -> Tuple[Image.Image, int]:\n\n        for aug, p in zip(self.augs, self.probs):\n\n            if np.random.uniform() < p:\n                image, label = aug(image, label, metadata)\n\n        return image, label\n\n\nclass ComposeParallel(GenerativeAugmentation):\n\n    def __init__(self, augs: List[GenerativeAugmentation], \n                 probs: List[float] = None):\n\n        super(ComposeParallel, self).__init__()\n\n        self.augs = augs\n        self.probs = probs if probs is not None \\\n            else [1.0 / len(augs) for _ in augs]\n\n    def forward(self, image: Image.Image, label: int, \n                metadata: dict) -> Tuple[Image.Image, int]:\n\n        idx = np.random.choice(len(self.probs), p=self.probs)\n\n        image, label = self.augs[idx](image, label, metadata)\n\n        return image, label"
  },
  {
    "path": "semantic_aug/augmentations/real_guidance.py",
    "content": "from semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom diffusers import StableDiffusionImg2ImgPipeline\nfrom diffusers import StableDiffusionInpaintPipeline\nfrom diffusers.utils import logging\nfrom PIL import Image, ImageOps\n\nfrom typing import Any, Tuple, Callable\nfrom torch import autocast\nfrom scipy.ndimage import maximum_filter\n\nimport os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass RealGuidance(GenerativeAugmentation):\n\n    pipe = None  # global sharing is a hack to avoid OOM\n\n    def __init__(self, model_path: str = \"CompVis/stable-diffusion-v1-4\",\n                 prompt: str = \"a photo of a {name}\",\n                 strength: float = 0.5, \n                 guidance_scale: float = 7.5,\n                 mask: bool = False,\n                 inverted: bool = False,\n                 mask_grow_radius: int = 16,\n                 erasure_ckpt_path: str = None,\n                 disable_safety_checker: bool = True,\n                 **kwargs):\n\n        super(RealGuidance, self).__init__()\n\n        if RealGuidance.pipe is None:\n\n            PipelineClass = (StableDiffusionInpaintPipeline \n                             if mask else \n                             StableDiffusionImg2ImgPipeline)\n\n            self.pipe = PipelineClass.from_pretrained(\n                model_path, use_auth_token=True,\n                revision=\"fp16\", \n                torch_dtype=torch.float16\n            ).to('cuda')\n\n            logging.disable_progress_bar()\n            self.pipe.set_progress_bar_config(disable=True)\n\n            if disable_safety_checker: \n                self.pipe.safety_checker = None\n\n        self.prompt = prompt\n        self.strength = strength\n        self.guidance_scale = guidance_scale\n\n        self.mask = mask\n        self.inverted = inverted\n        self.mask_grow_radius = mask_grow_radius\n\n        self.erasure_ckpt_path = erasure_ckpt_path\n        self.erasure_word_name = None\n\n    def forward(self, image: Image.Image, label: int, \n                metadata: dict) -> Tuple[Image.Image, int]:\n\n        canvas = image.resize((512, 512), Image.BILINEAR)\n        prompt = self.prompt.format(name=metadata.get(\"name\", \"\"))\n\n        if self.mask: assert \"mask\" in metadata, \\\n            \"mask=True but no mask present in metadata\"\n        \n        word_name = metadata.get(\"name\", \"\").replace(\" \", \"\")\n\n        if self.erasure_ckpt_path is not None and (\n                self.erasure_word_name is None \n                or self.erasure_word_name != word_name):\n\n            self.erasure_word_name = word_name\n            ckpt_name = \"method_full-sg_3-ng_1-iter_1000-lr_1e-05\"\n\n            ckpt_path = os.path.join(\n                self.erasure_ckpt_path, \n                f\"compvis-word_{word_name}-{ckpt_name}\",\n                f\"diffusers-word_{word_name}-{ckpt_name}.pt\")\n    \n            self.pipe.unet.load_state_dict(torch.load(\n                ckpt_path, map_location='cuda'))\n\n        kwargs = dict(\n            image=canvas,\n            prompt=[prompt], \n            strength=self.strength, \n            guidance_scale=self.guidance_scale\n        )\n\n        if self.mask:  # use focal object mask\n\n            mask_image = Image.fromarray((\n                np.where(metadata[\"mask\"], 255, 0)\n            ).astype(np.uint8)).resize((512, 512), Image.NEAREST)\n\n            mask_image = 
Image.fromarray(\n                maximum_filter(np.array(mask_image), \n                               size=self.mask_grow_radius))\n\n            if self.inverted:\n\n                mask_image = ImageOps.invert(\n                    mask_image.convert('L')).convert('1')\n\n            kwargs[\"mask_image\"] = mask_image\n\n        has_nsfw_concept = True\n        while has_nsfw_concept:\n            with autocast(\"cuda\"):\n                outputs = self.pipe(**kwargs)\n\n            has_nsfw_concept = (\n                self.pipe.safety_checker is not None \n                and outputs.nsfw_content_detected[0]\n            )\n\n        canvas = outputs.images[0].resize(\n            image.size, Image.BILINEAR)\n\n        return canvas, label"
  },
  {
    "path": "semantic_aug/augmentations/textual_inversion.py",
    "content": "from semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom diffusers import StableDiffusionImg2ImgPipeline\nfrom diffusers import StableDiffusionInpaintPipeline\nfrom transformers import (\n    CLIPFeatureExtractor, \n    CLIPTextModel, \n    CLIPTokenizer\n)\nfrom diffusers.utils import logging\nfrom PIL import Image, ImageOps\n\nfrom typing import Any, Tuple, Callable\nfrom torch import autocast\nfrom scipy.ndimage import maximum_filter\n\nimport os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nERROR_MESSAGE = \"Tokenizer already contains the token {token}. \\\nPlease pass a different `token` that is not already in the tokenizer.\"\n\n\ndef load_embeddings(embed_path: str,\n                    model_path: str = \"CompVis/stable-diffusion-v1-4\"):\n\n    tokenizer = CLIPTokenizer.from_pretrained(\n        model_path, use_auth_token=True,\n        subfolder=\"tokenizer\")\n\n    text_encoder = CLIPTextModel.from_pretrained(\n        model_path, use_auth_token=True, \n        subfolder=\"text_encoder\")\n\n    for token, token_embedding in torch.load(\n            embed_path, map_location=\"cpu\").items():\n\n        # add the token in tokenizer\n        num_added_tokens = tokenizer.add_tokens(token)\n        assert num_added_tokens > 0, ERROR_MESSAGE.format(token=token)\n    \n        # resize the token embeddings\n        text_encoder.resize_token_embeddings(len(tokenizer))\n        added_token_id = tokenizer.convert_tokens_to_ids(token)\n\n        # get the old word embeddings\n        embeddings = text_encoder.get_input_embeddings()\n\n        # get the id for the token and assign new embeds\n        embeddings.weight.data[added_token_id] = \\\n            token_embedding.to(embeddings.weight.dtype)\n\n    return tokenizer, text_encoder.to('cuda')\n\n\ndef format_name(name):\n    return f\"<{name.replace(' ', '_')}>\"\n\n\nclass TextualInversion(GenerativeAugmentation):\n\n    pipe = None  # global sharing is a hack to avoid OOM\n\n    def __init__(self, embed_path: str, \n                 model_path: str = \"CompVis/stable-diffusion-v1-4\",\n                 prompt: str = \"a photo of a {name}\",\n                 format_name: Callable = format_name,\n                 strength: float = 0.5, \n                 guidance_scale: float = 7.5,\n                 mask: bool = False,\n                 inverted: bool = False,\n                 mask_grow_radius: int = 16,\n                 erasure_ckpt_path: str = None,\n                 disable_safety_checker: bool = True,\n                 **kwargs):\n\n        super(TextualInversion, self).__init__()\n\n        if TextualInversion.pipe is None:\n\n            PipelineClass = (StableDiffusionInpaintPipeline \n                             if mask else \n                             StableDiffusionImg2ImgPipeline)\n\n            tokenizer, text_encoder = load_embeddings(\n                embed_path, model_path=model_path)\n\n            TextualInversion.pipe = PipelineClass.from_pretrained(\n                model_path, use_auth_token=True,\n                revision=\"fp16\", \n                torch_dtype=torch.float16\n            ).to('cuda')\n\n            self.pipe.tokenizer = tokenizer\n            self.pipe.text_encoder = text_encoder\n\n            logging.disable_progress_bar()\n            self.pipe.set_progress_bar_config(disable=True)\n\n            if disable_safety_checker:\n                self.pipe.safety_checker = None\n\n        self.prompt = 
prompt\n        self.strength = strength\n        self.guidance_scale = guidance_scale\n        self.format_name = format_name\n\n        self.mask = mask\n        self.inverted = inverted\n        self.mask_grow_radius = mask_grow_radius\n\n        self.erasure_ckpt_path = erasure_ckpt_path\n        self.erasure_word_name = None\n\n    def forward(self, image: Image.Image, label: int, \n                metadata: dict) -> Tuple[Image.Image, int]:\n\n        canvas = image.resize((512, 512), Image.BILINEAR)\n        name = self.format_name(metadata.get(\"name\", \"\"))\n        prompt = self.prompt.format(name=name)\n\n        if self.mask: assert \"mask\" in metadata, \\\n            \"mask=True but no mask present in metadata\"\n        \n        word_name = metadata.get(\"name\", \"\").replace(\" \", \"\")\n\n        if self.erasure_ckpt_path is not None and (\n                self.erasure_word_name is None \n                or self.erasure_word_name != word_name):\n\n            self.erasure_word_name = word_name\n            ckpt_name = \"method_full-sg_3-ng_1-iter_1000-lr_1e-05\"\n\n            ckpt_path = os.path.join(\n                self.erasure_ckpt_path, \n                f\"compvis-word_{word_name}-{ckpt_name}\",\n                f\"diffusers-word_{word_name}-{ckpt_name}.pt\")\n    \n            self.pipe.unet.load_state_dict(torch.load(\n                ckpt_path, map_location='cuda'))\n\n        kwargs = dict(\n            image=canvas,\n            prompt=[prompt], \n            strength=self.strength, \n            guidance_scale=self.guidance_scale\n        )\n\n        if self.mask:  # use focal object mask\n\n            mask_image = Image.fromarray((\n                np.where(metadata[\"mask\"], 255, 0)\n            ).astype(np.uint8)).resize((512, 512), Image.NEAREST)\n\n            mask_image = Image.fromarray(\n                maximum_filter(np.array(mask_image), \n                               size=self.mask_grow_radius))\n\n            if self.inverted:\n\n                mask_image = ImageOps.invert(\n                    mask_image.convert('L')).convert('1')\n\n            kwargs[\"mask_image\"] = mask_image\n\n        has_nsfw_concept = True\n        while has_nsfw_concept:\n            with autocast(\"cuda\"):\n                outputs = self.pipe(**kwargs)\n\n            has_nsfw_concept = (\n                self.pipe.safety_checker is not None \n                and outputs.nsfw_content_detected[0]\n            )\n\n        canvas = outputs.images[0].resize(\n            image.size, Image.BILINEAR)\n\n        return canvas, label"
  },
  {
    "path": "semantic_aug/augmentations/textual_inversion_upstream.py",
    "content": "from semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom diffusers import StableDiffusionImg2ImgPipeline\nfrom diffusers import StableDiffusionInpaintPipeline\nfrom transformers import (\n    CLIPFeatureExtractor, \n    CLIPTextModel, \n    CLIPTokenizer\n)\nfrom diffusers.utils import logging\nfrom PIL import Image, ImageOps\n\nfrom typing import Any, Tuple, Callable\nfrom torch import autocast\nfrom scipy.ndimage import maximum_filter\n\nimport os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom glob import glob\n\nERROR_MESSAGE = \"Tokenizer already contains the token {token}. \\\nPlease pass a different `token` that is not already in the tokenizer.\"\n\ndef format_name(name, num_tokens: int = 1):\n\n    special_token = f\"<{name.replace(' ', '_')}>\"\n\n    return \" \".join([\n        special_token\n        if token_idx == 0 else\n        f\"{special_token}_{token_idx}\"\n        for token_idx in range(num_tokens)\n    ])\n\nclass TextualInversion(GenerativeAugmentation):\n\n    pipe = None  # global sharing is a hack to avoid OOM\n\n    def __init__(self, embed_path: str, \n                 model_path: str = \"CompVis/stable-diffusion-v1-4\",\n                 prompt: str = \"a photo of a {name}\",\n                 format_name: Callable = format_name,\n                 strength: float = 0.5, \n                 guidance_scale: float = 7.5,\n                 mask: bool = False,\n                 inverted: bool = False,\n                 mask_grow_radius: int = 16,\n                 erasure_ckpt_path: str = None,\n                 disable_safety_checker: bool = True,\n                 tokens_per_class: int = 1,\n                 **kwargs):\n\n        super(TextualInversion, self).__init__()\n\n        if TextualInversion.pipe is None:\n\n            PipelineClass = (StableDiffusionInpaintPipeline \n                             if mask else \n                             StableDiffusionImg2ImgPipeline)\n\n            TextualInversion.pipe = PipelineClass.from_pretrained(\n                model_path, use_auth_token=True,\n                revision=\"fp16\", \n                torch_dtype=torch.float16\n            ).to('cuda')\n            \n            logging.disable_progress_bar()\n            self.pipe.set_progress_bar_config(disable=True)\n\n            if disable_safety_checker:\n                self.pipe.safety_checker = None\n        \n            embeds_list = glob(embed_path + '/**/learned_embeds.bin')\n            \n            for e in embeds_list:\n                self.pipe.load_textual_inversion(e)\n        \n        self.prompt = prompt\n        self.strength = strength\n        self.guidance_scale = guidance_scale\n        self.format_name = format_name\n        self.tokens_per_class = tokens_per_class\n\n        self.mask = mask\n        self.inverted = inverted\n        self.mask_grow_radius = mask_grow_radius\n\n        self.erasure_ckpt_path = erasure_ckpt_path\n        self.erasure_word_name = None\n\n    def forward(self, image: Image.Image, label: int, \n                metadata: dict) -> Tuple[Image.Image, int]:\n\n        canvas = image.resize((512, 512), Image.BILINEAR)\n        name = self.format_name(\n            metadata.get(\"name\", \"\"),\n            num_tokens=self.tokens_per_class)\n        prompt = self.prompt.format(name=name)\n\n        if self.mask: assert \"mask\" in metadata, \\\n            \"mask=True but no mask present in metadata\"\n        \n        
word_name = metadata.get(\"name\", \"\").replace(\" \", \"\")\n\n        if self.erasure_ckpt_path is not None and (\n                self.erasure_word_name is None \n                or self.erasure_word_name != word_name):\n\n            self.erasure_word_name = word_name\n            ckpt_name = \"method_full-sg_3-ng_1-iter_1000-lr_1e-05\"\n\n            ckpt_path = os.path.join(\n                self.erasure_ckpt_path, \n                f\"compvis-word_{word_name}-{ckpt_name}\",\n                f\"diffusers-word_{word_name}-{ckpt_name}.pt\")\n    \n            self.pipe.unet.load_state_dict(torch.load(\n                ckpt_path, map_location='cuda'))\n\n        kwargs = dict(\n            image=canvas,\n            prompt=[prompt], \n            strength=self.strength, \n            guidance_scale=self.guidance_scale\n        )\n\n        if self.mask:  # use focal object mask\n\n            mask_image = Image.fromarray((\n                np.where(metadata[\"mask\"], 255, 0)\n            ).astype(np.uint8)).resize((512, 512), Image.NEAREST)\n\n            mask_image = Image.fromarray(\n                maximum_filter(np.array(mask_image), \n                               size=self.mask_grow_radius))\n\n            if self.inverted:\n\n                mask_image = ImageOps.invert(\n                    mask_image.convert('L')).convert('1')\n\n            kwargs[\"mask_image\"] = mask_image\n\n        has_nsfw_concept = True\n        while has_nsfw_concept:\n            with autocast(\"cuda\"):\n                outputs = self.pipe(**kwargs)\n\n            has_nsfw_concept = (\n                self.pipe.safety_checker is not None \n                and outputs.nsfw_content_detected[0]\n            )\n\n        canvas = outputs.images[0].resize(\n            image.size, Image.BILINEAR)\n\n        return canvas, label"
  },
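  {
    "path": "semantic_aug/augmentations/usage_example.py",
    "content": "# Hypothetical usage sketch, not part of the original repository: shows how\n# a GenerativeAugmentation such as TextualInversion can be applied to one\n# image. The paths \"tokens/pascal-0-4.pt\" and \"dog.png\" are placeholders\n# for a fine-tuned token checkpoint and an input image.\n\nfrom semantic_aug.augmentations.textual_inversion import TextualInversion\nfrom PIL import Image\n\n\nif __name__ == \"__main__\":\n\n    aug = TextualInversion(\n        \"tokens/pascal-0-4.pt\",  # placeholder embedding checkpoint\n        prompt=\"a photo of a {name}\",\n        strength=0.5,\n        guidance_scale=7.5)\n\n    image = Image.open(\"dog.png\").convert('RGB')\n\n    # metadata[\"name\"] selects the class token; a \"mask\" entry is only\n    # required when the augmentation is constructed with mask=True\n    augmented, label = aug(image, 0, dict(name=\"dog\"))\n    augmented.save(\"dog-augmented.png\")"
  },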
  {
    "path": "semantic_aug/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "semantic_aug/datasets/caltech101.py",
    "content": "from semantic_aug.few_shot_dataset import FewShotDataset\nfrom semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom typing import Any, Tuple, Dict\n\nimport numpy as np\nimport torchvision.transforms as transforms\nimport torchvision\nimport torch\nimport glob\nimport os\n\nfrom PIL import Image\nfrom collections import defaultdict\n\n\nDEFAULT_IMAGE_DIR = \"/projects/rsalakhugroup/datasets/caltech101/caltech101/101_ObjectCategories\"\n\n\nclass CalTech101Dataset(FewShotDataset):\n\n    class_names = ['accordion', 'airplanes', 'anchor', 'ant', \n        'background google', 'barrel', 'bass', 'beaver', 'binocular', \n        'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', \n        'cannon', 'car side', 'ceiling fan', 'cellphone', 'chair', \n        'chandelier', 'cougar body', 'cougar face', 'crab', 'crayfish', \n        'crocodile', 'crocodile head', 'cup', 'dalmatian', 'dollar bill', \n        'dolphin', 'dragonfly', 'electric guitar', 'elephant', 'emu', \n        'euphonium', 'ewer', 'faces', 'faces easy', 'ferry', 'flamingo', \n        'flamingo head', 'garfield', 'gerenuk', 'gramophone', 'grand piano', \n        'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', \n        'inline skate', 'joshua tree', 'kangaroo', 'ketch', 'lamp', 'laptop', \n        'leopards', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', \n        'menorah', 'metronome', 'minaret', 'motorbikes', 'nautilus', \n        'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', \n        'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', \n        'scissors', 'scorpion', 'sea horse', 'snoopy', 'soccer ball', \n        'stapler', 'starfish', 'stegosaurus', 'stop sign', 'strawberry', \n        'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water lilly', \n        'wheelchair', 'wild cat', 'windsor chair', 'wrench', 'yin yang']\n\n    num_classes: int = len(class_names)\n\n    def __init__(self, *args, split: str = \"train\", seed: int = 0, \n                 image_dir: str = DEFAULT_IMAGE_DIR, \n                 examples_per_class: int = None, \n                 generative_aug: GenerativeAugmentation = None, \n                 synthetic_probability: float = 0.5,\n                 use_randaugment: bool = False,\n                 image_size: Tuple[int] = (256, 256), **kwargs):\n\n        super(CalTech101Dataset, self).__init__(\n            *args, examples_per_class=examples_per_class,\n            synthetic_probability=synthetic_probability, \n            generative_aug=generative_aug, **kwargs)\n\n        class_to_images = defaultdict(list)\n\n        for image_path in glob.glob(os.path.join(image_dir, \"*/*.jpg\")):\n            class_name = image_path.split(\"/\")[-2].lower().replace(\"_\", \" \")\n            class_to_images[class_name].append(image_path)\n\n        rng = np.random.default_rng(seed)\n\n        class_to_ids = {key: rng.permutation(\n            len(class_to_images[key])) for key in self.class_names}\n        \n        class_to_ids = {key: np.array_split(class_to_ids[key], 2)[0 if split == \"train\" else 1] for key in self.class_names}\n\n        if examples_per_class is not None:\n            class_to_ids = {key: ids[:examples_per_class] \n                            for key, ids in class_to_ids.items()}\n\n        self.class_to_images = {\n            key: [class_to_images[key][i] for i in ids] \n            for key, ids in class_to_ids.items()}\n\n        self.all_images = sum([\n            
self.class_to_images[key] \n            for key in self.class_names], [])\n\n        self.all_labels = [i for i, key in enumerate(\n            self.class_names) for _ in self.class_to_images[key]]\n\n        if use_randaugment: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandAugment(),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        else: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.RandomRotation(degrees=15.0),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        val_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        self.transform = {\"train\": train_transform, \"val\": val_transform}[split]\n\n    def __len__(self):\n        \n        return len(self.all_images)\n\n    def get_image_by_idx(self, idx: int) -> Image.Image:\n\n        return Image.open(self.all_images[idx]).convert('RGB')\n\n    def get_label_by_idx(self, idx: int) -> int:\n\n        return self.all_labels[idx]\n    \n    def get_metadata_by_idx(self, idx: int) -> dict:\n\n        return dict(name=self.class_names[self.all_labels[idx]])"
  },
  {
    "path": "semantic_aug/datasets/coco.py",
    "content": "from semantic_aug.few_shot_dataset import FewShotDataset\nfrom semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom typing import Any, Tuple, Dict\n\nimport numpy as np\nimport torchvision.transforms as transforms\nimport torch\nimport os\n\nfrom pycocotools.coco import COCO\nfrom PIL import Image\nfrom collections import defaultdict\n\n\nCOCO_DIR = \"/projects/rsalakhugroup/datasets/coco/coco_2017\"\n\nTRAIN_IMAGE_DIR = os.path.join(COCO_DIR, \"train2017\")\nVAL_IMAGE_DIR = os.path.join(COCO_DIR, \"val2017\")\n\nDEFAULT_TRAIN_INSTANCES = os.path.join(\n    COCO_DIR, \"annotations/instances_train2017.json\")\nDEFAULT_VAL_INSTANCES = os.path.join(\n    COCO_DIR, \"annotations/instances_val2017.json\")\n\n\nclass COCODataset(FewShotDataset):\n\n    class_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', \n        'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', \n        'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', \n        'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', \n        'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', \n        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', \n        'baseball glove', 'skateboard', 'surfboard', 'tennis racket', \n        'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', \n        'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', \n        'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', \n        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', \n        'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', \n        'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', \n        'scissors', 'teddy bear', 'hair drier', 'toothbrush']\n\n    num_classes: int = len(class_names)\n\n    def __init__(self, *args, split: str = \"train\", seed: int = 0, \n                 train_image_dir: str = TRAIN_IMAGE_DIR, \n                 val_image_dir: str = VAL_IMAGE_DIR, \n                 train_instances_file: str = DEFAULT_TRAIN_INSTANCES, \n                 val_instances_file: str = DEFAULT_VAL_INSTANCES, \n                 examples_per_class: int = None, \n                 generative_aug: GenerativeAugmentation = None, \n                 synthetic_probability: float = 0.5,\n                 use_randaugment: bool = False,\n                 image_size: Tuple[int] = (256, 256), **kwargs):\n\n        super(COCODataset, self).__init__(\n            *args, examples_per_class=examples_per_class,\n            synthetic_probability=synthetic_probability, \n            generative_aug=generative_aug, **kwargs)\n\n        image_dir = {\"train\": train_image_dir, \"val\": val_image_dir}[split]\n        instances_file = {\"train\": train_instances_file, \"val\": val_instances_file}[split]\n\n        class_to_images = defaultdict(list)\n        class_to_annotations = defaultdict(list)\n\n        self.cocoapi = COCO(instances_file)\n        for image_id, x in self.cocoapi.imgs.items():\n\n            annotations = self.cocoapi.imgToAnns[image_id]\n            if len(annotations) == 0: continue\n\n            maximal_ann = max(annotations, key=lambda x: x[\"area\"])\n            class_name = self.cocoapi.cats[maximal_ann[\"category_id\"]][\"name\"]\n\n            class_to_images[class_name].append(\n                os.path.join(image_dir, x[\"file_name\"]))\n            class_to_annotations[class_name].append(maximal_ann)\n\n        rng = np.random.default_rng(seed)\n        
class_to_ids = {key: rng.permutation(\n            len(class_to_images[key])) for key in self.class_names}\n\n        if examples_per_class is not None:\n            class_to_ids = {key: ids[:examples_per_class] \n                            for key, ids in class_to_ids.items()}\n\n        self.class_to_images = {\n            key: [class_to_images[key][i] for i in ids] \n            for key, ids in class_to_ids.items()}\n\n        self.class_to_annotations = {\n            key: [class_to_annotations[key][i] for i in ids] \n            for key, ids in class_to_ids.items()}\n\n        self.all_images = sum([\n            self.class_to_images[key] \n            for key in self.class_names], [])\n\n        self.all_annotations = sum([\n            self.class_to_annotations[key] \n            for key in self.class_names], [])\n\n        self.all_labels = [i for i, key in enumerate(\n            self.class_names) for _ in self.class_to_images[key]]\n\n        if use_randaugment: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandAugment(),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        else: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.RandomRotation(degrees=15.0),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        val_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        self.transform = {\"train\": train_transform, \"val\": val_transform}[split]\n\n    def __len__(self):\n        \n        return len(self.all_images)\n\n    def get_image_by_idx(self, idx: int) -> torch.Tensor:\n\n        return Image.open(self.all_images[idx]).convert('RGB')\n\n    def get_label_by_idx(self, idx: int) -> torch.Tensor:\n\n        return self.all_labels[idx]\n    \n    def get_metadata_by_idx(self, idx: int) -> Dict:\n\n        annotation = self.all_annotations[idx]\n\n        return dict(name=self.class_names[self.all_labels[idx]], \n                    mask=self.cocoapi.annToMask(annotation),\n                    **annotation)"
  },
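  {
    "path": "semantic_aug/datasets/custom_example.py",
    "content": "# Hypothetical sketch, not part of the original repository: the minimal\n# interface a custom dataset must implement to plug into DA-Fusion,\n# following the pattern of the bundled datasets. The class names and the\n# glob pattern \"data/custom/*/*.jpg\" are placeholders.\n\nfrom semantic_aug.few_shot_dataset import FewShotDataset\nfrom PIL import Image\n\nimport glob\nimport os\n\n\nclass CustomDataset(FewShotDataset):\n\n    class_names = [\"cat\", \"dog\"]  # placeholder classes\n    num_classes: int = len(class_names)\n\n    def __init__(self, *args, image_glob: str = \"data/custom/*/*.jpg\", **kwargs):\n\n        super(CustomDataset, self).__init__(*args, **kwargs)\n\n        self.all_images = sorted(glob.glob(image_glob))\n\n        # the label of an image is the index of its parent folder name\n        self.all_labels = [self.class_names.index(os.path.basename(\n            os.path.dirname(path))) for path in self.all_images]\n\n    def __len__(self):\n\n        return len(self.all_images)\n\n    def get_image_by_idx(self, idx: int) -> Image.Image:\n\n        return Image.open(self.all_images[idx]).convert('RGB')\n\n    def get_label_by_idx(self, idx: int) -> int:\n\n        return self.all_labels[idx]\n\n    def get_metadata_by_idx(self, idx: int) -> dict:\n\n        return dict(name=self.class_names[self.all_labels[idx]])"
  },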
  {
    "path": "semantic_aug/datasets/flowers102.py",
    "content": "from semantic_aug.few_shot_dataset import FewShotDataset\nfrom semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom typing import Any, Tuple, Dict\n\nimport numpy as np\nimport torchvision.transforms as transforms\nimport torchvision\nimport torch\nimport glob\nimport os\n\nfrom scipy.io import loadmat\nfrom PIL import Image\nfrom collections import defaultdict\n\n\nDEFAULT_IMAGE_DIR = \"/projects/rsalakhugroup/datasets/flowers102\"\n\n\nclass Flowers102Dataset(FewShotDataset):\n\n    class_names = [\n        'pink primrose',\n        'hard-leaved pocket orchid',\n        'canterbury bells',\n        'sweet pea',\n        'english marigold',\n        'tiger lily',\n        'moon orchid',\n        'bird of paradise',\n        'monkshood',\n        'globe thistle',\n        'snapdragon',\n        \"colt's foot\",\n        'king protea',\n        'spear thistle',\n        'yellow iris',\n        'globe-flower',\n        'purple coneflower',\n        'peruvian lily',\n        'balloon flower',\n        'giant white arum lily',\n        'fire lily',\n        'pincushion flower',\n        'fritillary',\n        'red ginger',\n        'grape hyacinth',\n        'corn poppy',\n        'prince of wales feathers',\n        'stemless gentian',\n        'artichoke',\n        'sweet william',\n        'carnation',\n        'garden phlox',\n        'love in the mist',\n        'mexican aster',\n        'alpine sea holly',\n        'ruby-lipped cattleya',\n        'cape flower',\n        'great masterwort',\n        'siam tulip',\n        'lenten rose',\n        'barbeton daisy',\n        'daffodil',\n        'sword lily',\n        'poinsettia',\n        'bolero deep blue',\n        'wallflower',\n        'marigold',\n        'buttercup',\n        'oxeye daisy',\n        'common dandelion',\n        'petunia',\n        'wild pansy',\n        'primula',\n        'sunflower',\n        'pelargonium',\n        'bishop of llandaff',\n        'gaura',\n        'geranium',\n        'orange dahlia',\n        'pink-yellow dahlia',\n        'cautleya spicata',\n        'japanese anemone',\n        'black-eyed susan',\n        'silverbush',\n        'californian poppy',\n        'osteospermum',\n        'spring crocus',\n        'bearded iris',\n        'windflower',\n        'tree poppy',\n        'gazania',\n        'azalea',\n        'water lily',\n        'rose',\n        'thorn apple',\n        'morning glory',\n        'passion flower',\n        'lotus',\n        'toad lily',\n        'anthurium',\n        'frangipani',\n        'clematis',\n        'hibiscus',\n        'columbine',\n        'desert-rose',\n        'tree mallow',\n        'magnolia',\n        'cyclamen ',\n        'watercress',\n        'canna lily',\n        'hippeastrum ',\n        'bee balm',\n        'ball moss',\n        'foxglove',\n        'bougainvillea',\n        'camellia',\n        'mallow',\n        'mexican petunia',\n        'bromelia',\n        'blanket flower',\n        'trumpet creeper',\n        'blackberry lily']\n\n    num_classes: int = len(class_names)\n\n    def __init__(self, *args, split: str = \"train\", seed: int = 0, \n                 image_dir: str = DEFAULT_IMAGE_DIR, \n                 examples_per_class: int = None, \n                 generative_aug: GenerativeAugmentation = None, \n                 synthetic_probability: float = 0.5,\n                 use_randaugment: bool = False,\n                 image_size: Tuple[int] = (256, 256), **kwargs):\n\n        
super(Flowers102Dataset, self).__init__(\n            *args, examples_per_class=examples_per_class,\n            synthetic_probability=synthetic_probability, \n            generative_aug=generative_aug, **kwargs)\n\n        imagelabels = loadmat(os.path.join(image_dir, \"imagelabels.mat\"))[\"labels\"][0]\n        image_files = sorted(list(glob.glob(os.path.join(image_dir, \"jpg/*.jpg\"))))\n\n        class_to_images = defaultdict(list)\n\n        for image_idx, image_path in enumerate(image_files):\n            class_name = self.class_names[imagelabels[image_idx] - 1]\n            class_to_images[class_name].append(image_path)\n\n        rng = np.random.default_rng(seed)\n        class_to_ids = {key: rng.permutation(\n            len(class_to_images[key])) for key in self.class_names}\n        \n        class_to_ids = {key: np.array_split(class_to_ids[key], 2)[0 if split == \"train\" else 1] for key in self.class_names}\n\n        if examples_per_class is not None:\n            class_to_ids = {key: ids[:examples_per_class] \n                            for key, ids in class_to_ids.items()}\n\n        self.class_to_images = {\n            key: [class_to_images[key][i] for i in ids] \n            for key, ids in class_to_ids.items()}\n\n        self.all_images = sum([\n            self.class_to_images[key] \n            for key in self.class_names], [])\n\n        self.all_labels = [i for i, key in enumerate(\n            self.class_names) for _ in self.class_to_images[key]]\n\n        if use_randaugment: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandAugment(),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        else: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.RandomRotation(degrees=15.0),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        val_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        self.transform = {\"train\": train_transform, \"val\": val_transform}[split]\n\n    def __len__(self):\n        \n        return len(self.all_images)\n\n    def get_image_by_idx(self, idx: int) -> Image.Image:\n\n        return Image.open(self.all_images[idx]).convert('RGB')\n\n    def get_label_by_idx(self, idx: int) -> int:\n\n        return self.all_labels[idx]\n    \n    def get_metadata_by_idx(self, idx: int) -> dict:\n\n        return dict(name=self.class_names[self.all_labels[idx]])"
  },
  {
    "path": "semantic_aug/datasets/imagenet.py",
    "content": "from semantic_aug.few_shot_dataset import FewShotDataset\nfrom semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom typing import Any, Tuple, Dict\nimport torchvision.transforms as transforms\nimport torch\nimport os\nfrom PIL import Image\nfrom collections import defaultdict\nimport numpy as np\n\n\nILSVRC_DIR = \"/projects/rsalakhugroup/datasets/imagenet\"\n\nLABEL_SYNSET = os.path.join(\n    ILSVRC_DIR, \"LOC_synset_mapping.txt\")\n\nTRAIN_IMAGE_SET = os.path.join(\n    ILSVRC_DIR, \"ILSVRC/ImageSets/CLS-LOC/train_cls.txt\")\nTRAIN_IMAGE_DIR = os.path.join(\n    ILSVRC_DIR, \"ILSVRC/Data/CLS-LOC/train\")\n\nVAL_IMAGE_SET = \"/projects/rsalakhugroup/spurge/val_cls.txt\"\nVAL_IMAGE_DIR = os.path.join(\n    ILSVRC_DIR, \"ILSVRC/Data/CLS-LOC/val\")\n\n\nclass ImageNetDataset(FewShotDataset):\n\n    class_names = ['steel arch bridge', 'ram', 'great white shark', 'sombrero', \n        'hamster', 'racket', 'chain mail', 'ski mask', 'potpie', 'cocktail shaker', \n        'Indian cobra', 'green snake', 'orange', 'Great Pyrenees', 'minibus', 'wall clock', \n        \"yellow lady's slipper\", 'vacuum', 'guillotine', 'redshank', 'pajama', \n        'tile roof', 'hen of the woods', 'oboe', 'overskirt', 'slug', 'running shoe', \n        'harp', 'strawberry', 'sturgeon', 'leatherback turtle', 'malamute', 'ladybug', \n        'mink', 'bulletproof vest', 'walking stick', 'can opener', 'pelican', \n        'projectile', 'gorilla', 'green mamba', 'drilling platform', \n        'black and gold garden spider', 'suit', 'volcano', 'hoopskirt', \n        'meat loaf', 'scuba diver', 'armadillo', 'crane', 'throne', 'barrel', \n        'golfcart', 'Border collie', 'fire engine', 'Indian elephant', \n        \"carpenter's kit\", 'black-and-tan coonhound', 'ballplayer', 'earthstar', \n        'Italian greyhound', 'confectionery', 'warthog', 'dishwasher', 'American egret', \n        'bald eagle', 'beagle', 'pinwheel', 'wombat', 'disk brake', 'pole', 'sandbar', 'drake',\n        'cheeseburger', 'sea anemone', 'computer keyboard', 'suspension bridge', 'ibex', \n        'toilet seat', 'vulture', 'coffee mug', 'Bouvier des Flandres', \n        'honeycomb', 'African chameleon', 'barn spider', 'ladle', 'Airedale', \n        'maze', 'scoreboard', 'fly', 'Bedlington terrier', \n        'yawl', 'revolver', 'racer', 'croquet ball', 'obelisk', 'mosque', \n        'dowitcher', 'shovel', 'sleeping bag']\n\n    num_classes: int = len(class_names)\n\n    def __init__(self, *args, split: str = \"train\", seed: int = 0,\n                 train_image_dir: str = TRAIN_IMAGE_DIR, \n                 val_image_dir: str = VAL_IMAGE_DIR, \n                 train_image_set: str = TRAIN_IMAGE_SET, \n                 val_image_set: str = VAL_IMAGE_SET, \n                 label_synset: str = LABEL_SYNSET,\n                 examples_per_class: int = None, \n                 generative_aug: GenerativeAugmentation = None, \n                 synthetic_probability: float = 0.5,\n                 use_randaugment: bool = False,\n                 image_size: Tuple[int] = (256, 256), **kwargs):\n\n        super(ImageNetDataset, self).__init__(\n            *args, examples_per_class=examples_per_class,\n            synthetic_probability=synthetic_probability, \n            generative_aug=generative_aug, **kwargs)\n\n        image_dir = {\"train\": train_image_dir, \"val\": val_image_dir}[split]\n        image_set = {\"train\": train_image_set, \"val\": val_image_set}[split]\n\n        with open(label_synset, \"r\") as 
f:\n            label_synset_lines = f.readlines()\n\n        self.dir_to_class_names = dict()\n\n        for synset in label_synset_lines:\n\n            dir_name, synset = synset.split(\" \", maxsplit=1)\n            class_name = synset.split(\",\")[0].strip()\n\n            self.dir_to_class_names[dir_name] = class_name\n\n        class_to_images = defaultdict(list)\n\n        with open(image_set, \"r\") as f:\n            image_set_lines = f.readlines()\n\n        for training_example in image_set_lines:\n\n            path, idx = training_example.split(\" \")\n            class_name = self.dir_to_class_names[path.split(\"/\")[0]]\n\n            class_to_images[class_name].append(\n                os.path.join(image_dir, path + \".JPEG\"))\n\n        rng = np.random.default_rng(seed)\n        class_to_ids = {key: rng.permutation(\n            len(class_to_images[key])) for key in self.class_names}\n\n        if examples_per_class is not None:\n            class_to_ids = {key: ids[:examples_per_class] \n                            for key, ids in class_to_ids.items()}\n\n        self.class_to_images = {\n            key: [class_to_images[key][i] for i in ids] \n            for key, ids in class_to_ids.items()}\n\n        self.all_images = sum([self.class_to_images[key] \n                               for key in self.class_names], [])\n\n        self.all_labels = [i for i, key in enumerate(\n            self.class_names) for _ in self.class_to_images[key]]\n\n        if use_randaugment: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandAugment(),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        else: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.RandomRotation(degrees=15.0),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        val_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        self.transform = {\"train\": train_transform, \"val\": val_transform}[split]\n\n    def __len__(self):\n        \n        return len(self.all_images)\n\n    def get_image_by_idx(self, idx: int) -> torch.Tensor:\n\n        return Image.open(self.all_images[idx]).convert('RGB')\n\n    def get_label_by_idx(self, idx: int) -> torch.Tensor:\n\n        return self.all_labels[idx]\n    \n    def get_metadata_by_idx(self, idx: int) -> Dict:\n\n        return dict(name=self.class_names[self.all_labels[idx]])"
  },
  {
    "path": "semantic_aug/datasets/pascal.py",
    "content": "from semantic_aug.few_shot_dataset import FewShotDataset\nfrom semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom typing import Any, Tuple, Dict\n\nimport numpy as np\nimport torchvision.transforms as transforms\nimport torch\nimport os\n\nfrom PIL import Image\nfrom collections import defaultdict\n\n\nPASCAL_DIR = \"/projects/rsalakhugroup/datasets/pascal\"\n\nTRAIN_IMAGE_SET = os.path.join(\n    PASCAL_DIR, \"ImageSets/Segmentation/train.txt\")\nVAL_IMAGE_SET = os.path.join(\n    PASCAL_DIR, \"ImageSets/Segmentation/val.txt\")\n\nDEFAULT_IMAGE_DIR = os.path.join(PASCAL_DIR, \"JPEGImages\")\nDEFAULT_LABEL_DIR = os.path.join(PASCAL_DIR, \"SegmentationClass\")\nDEFAULT_INSTANCE_DIR = os.path.join(PASCAL_DIR, \"SegmentationObject\")\n\n\nclass PASCALDataset(FewShotDataset):\n\n    class_names = ['airplane', 'bicycle', 'bird', 'boat', 'bottle', \n        'bus', 'car', 'cat', 'chair', 'cow', 'dining table', 'dog', \n        'horse', 'motorcycle', 'person', 'potted plant', 'sheep', \n        'sofa', 'train', 'television']\n\n    num_classes: int = len(class_names)\n\n    def __init__(self, *args, split: str = \"train\", seed: int = 0, \n                 train_image_set: str = TRAIN_IMAGE_SET, \n                 val_image_set: str = VAL_IMAGE_SET, \n                 image_dir: str = DEFAULT_IMAGE_DIR, \n                 label_dir: str = DEFAULT_LABEL_DIR, \n                 instance_dir: str = DEFAULT_INSTANCE_DIR, \n                 examples_per_class: int = None, \n                 generative_aug: GenerativeAugmentation = None, \n                 synthetic_probability: float = 0.5,\n                 use_randaugment: bool = False,\n                 image_size: Tuple[int] = (256, 256), **kwargs):\n\n        super(PASCALDataset, self).__init__(\n            *args, examples_per_class=examples_per_class,\n            synthetic_probability=synthetic_probability, \n            generative_aug=generative_aug, **kwargs)\n\n        image_set = {\"train\": train_image_set, \"val\": val_image_set}[split]\n\n        with open(image_set, \"r\") as f:\n            image_set_lines = [x.strip() for x in f.readlines()]\n\n        class_to_images = defaultdict(list)\n        class_to_annotations = defaultdict(list)\n\n        for image_id in image_set_lines:\n\n            labels = os.path.join(label_dir, image_id + \".png\")\n            instances = os.path.join(instance_dir, image_id + \".png\")\n\n            labels = np.asarray(Image.open(labels))\n            instances = np.asarray(Image.open(instances))\n\n            instance_ids, pixel_loc, counts = np.unique(\n                instances, return_index=True, return_counts=True)\n\n            counts[0] = counts[-1] = 0  # remove background\n\n            argmax_index = counts.argmax()\n\n            mask = np.equal(instances, instance_ids[argmax_index])\n            class_name = self.class_names[\n                labels.flat[pixel_loc[argmax_index]] - 1]\n\n            class_to_images[class_name].append(\n                os.path.join(image_dir, image_id + \".jpg\"))\n            class_to_annotations[class_name].append(dict(mask=mask))\n\n        rng = np.random.default_rng(seed)\n        class_to_ids = {key: rng.permutation(\n            len(class_to_images[key])) for key in self.class_names}\n\n        if examples_per_class is not None:\n            class_to_ids = {key: ids[:examples_per_class] \n                            for key, ids in class_to_ids.items()}\n\n        self.class_to_images = {\n            key: 
[class_to_images[key][i] for i in ids] \n            for key, ids in class_to_ids.items()}\n\n        self.class_to_annotations = {\n            key: [class_to_annotations[key][i] for i in ids] \n            for key, ids in class_to_ids.items()}\n\n        self.all_images = sum([\n            self.class_to_images[key] \n            for key in self.class_names], [])\n\n        self.all_annotations = sum([\n            self.class_to_annotations[key] \n            for key in self.class_names], [])\n\n        self.all_labels = [i for i, key in enumerate(\n            self.class_names) for _ in self.class_to_images[key]]\n\n        if use_randaugment: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandAugment(),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        else: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.RandomRotation(degrees=15.0),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        val_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        self.transform = {\"train\": train_transform, \"val\": val_transform}[split]\n\n    def __len__(self):\n        \n        return len(self.all_images)\n\n    def get_image_by_idx(self, idx: int) -> Image.Image:\n\n        return Image.open(self.all_images[idx]).convert('RGB')\n\n    def get_label_by_idx(self, idx: int) -> int:\n\n        return self.all_labels[idx]\n    \n    def get_metadata_by_idx(self, idx: int) -> dict:\n\n        return dict(name=self.class_names[self.all_labels[idx]], \n                    **self.all_annotations[idx])"
  },
  {
    "path": "semantic_aug/datasets/spurge.py",
    "content": "from semantic_aug.few_shot_dataset import FewShotDataset\nfrom semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom typing import Any, Tuple\nfrom torch.utils.data import Dataset\nfrom PIL import Image\n\nimport os\nimport glob\nimport numpy as np\nimport torchvision.transforms as transforms\nimport torch\n\n\nDEFAULT_DATA_DIR = os.path.join(\n    os.path.abspath(os.path.dirname(\n    os.path.dirname(os.path.dirname(\n        os.path.abspath(__file__))))), 'data/spurge')\n\n\nclass SpurgeDataset(FewShotDataset):\n\n    num_classes: int = 2\n    class_names = [\"no spurge\", \"leafy spurge\"]\n\n    def __init__(self, *args, data_dir: str = DEFAULT_DATA_DIR, \n                 split: str = \"train\", seed: int = 0, \n                 examples_per_class: int = None, \n                 generative_aug: GenerativeAugmentation = None, \n                 synthetic_probability: float = 0.5,\n                 use_randaugment: bool = False,\n                 image_size: Tuple[int] = (256, 256), **kwargs):\n\n        super(SpurgeDataset, self).__init__(\n            *args, examples_per_class=examples_per_class,\n            synthetic_probability=synthetic_probability,\n            generative_aug=generative_aug, **kwargs)\n\n        absent = list(glob.glob(os.path.join(data_dir, \"absent/*.png\")))\n        apparent = list(glob.glob(os.path.join(data_dir, \"apparent/*.png\")))\n\n        rng = np.random.default_rng(seed)\n\n        absent_ids = rng.permutation(len(absent))\n        apparent_ids = rng.permutation(len(apparent))\n\n        absent_ids_train, absent_ids_val = np.array_split(absent_ids, 2)\n        apparent_ids_train, apparent_ids_val = np.array_split(apparent_ids, 2)\n\n        absent_ids = {\"train\": absent_ids_train, \"val\": absent_ids_val}[split]\n        apparent_ids = {\"train\": apparent_ids_train, \"val\": apparent_ids_val}[split]\n\n        if examples_per_class is not None:\n            absent_ids = absent_ids[:examples_per_class]\n            apparent_ids = apparent_ids[:examples_per_class]\n\n        self.absent = [absent[i] for i in absent_ids]\n        self.apparent = [apparent[i] for i in apparent_ids]\n\n        self.all_images = self.absent + self.apparent\n        self.all_labels = [0] * len(self.absent) + [1] * len(self.apparent)\n\n        if use_randaugment: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandAugment(),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Lambda(lambda x: x.expand(3, *image_size)),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        else: train_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.RandomVerticalFlip(p=0.5),\n            transforms.RandomRotation(degrees=45),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n        val_transform = transforms.Compose([\n            transforms.Resize(image_size),\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5])\n        ])\n\n     
   self.transform = {\"train\": train_transform, \"val\": val_transform}[split]\n\n    def __len__(self):\n\n        return len(self.all_images)\n\n    def get_image_by_idx(self, idx: int) -> torch.Tensor:\n        \n        return Image.open(self.all_images[idx])\n\n    def get_label_by_idx(self, idx: int) -> torch.Tensor:\n        \n        return self.all_labels[idx]\n    \n    def get_metadata_by_idx(self, idx: int) -> Any:\n\n        return dict(name=self.class_names[self.all_labels[idx]])"
  },
  {
    "path": "semantic_aug/few_shot_dataset.py",
    "content": "from semantic_aug.generative_augmentation import GenerativeAugmentation\nfrom typing import Any, Tuple\nfrom torch.utils.data import Dataset\nfrom collections import defaultdict\nfrom itertools import product\nfrom tqdm import tqdm\nfrom PIL import Image\n\nimport torchvision.transforms as transforms\nimport torch\nimport numpy as np\nimport abc\nimport random\nimport os\n\n\nclass FewShotDataset(Dataset):\n\n    num_classes: int = None\n    class_names: int = None\n\n    def __init__(self, examples_per_class: int = None, \n                 generative_aug: GenerativeAugmentation = None, \n                 synthetic_probability: float = 0.5,\n                 synthetic_dir: str = None):\n\n        self.examples_per_class = examples_per_class\n        self.generative_aug = generative_aug\n\n        self.synthetic_probability = synthetic_probability\n        self.synthetic_dir = synthetic_dir\n        self.synthetic_examples = defaultdict(list)\n\n        self.transform = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.ConvertImageDtype(torch.float),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], \n                                  std=[0.5, 0.5, 0.5]),\n        ])\n        \n        if synthetic_dir is not None:\n            os.makedirs(synthetic_dir, exist_ok=True)\n    \n    @abc.abstractmethod\n    def get_image_by_idx(self, idx: int) -> Image.Image:\n\n        return NotImplemented\n    \n    @abc.abstractmethod\n    def get_label_by_idx(self, idx: int) -> int:\n\n        return NotImplemented\n    \n    @abc.abstractmethod\n    def get_metadata_by_idx(self, idx: int) -> dict:\n\n        return NotImplemented\n\n    def generate_augmentations(self, num_repeats: int):\n\n        self.synthetic_examples.clear()\n        options = product(range(len(self)), range(num_repeats))\n\n        for idx, num in tqdm(list(\n                options), desc=\"Generating Augmentations\"):\n\n            image = self.get_image_by_idx(idx)\n            label = self.get_label_by_idx(idx)\n\n            image, label = self.generative_aug(\n                image, label, self.get_metadata_by_idx(idx))\n\n            if self.synthetic_dir is not None:\n\n                pil_image, image = image, os.path.join(\n                    self.synthetic_dir, f\"aug-{idx}-{num}.png\")\n\n                pil_image.save(image)\n\n            self.synthetic_examples[idx].append((image, label))\n\n    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:\n\n        if len(self.synthetic_examples[idx]) > 0 and \\\n                np.random.uniform() < self.synthetic_probability:\n\n            image, label = random.choice(self.synthetic_examples[idx])\n            if isinstance(image, str): image = Image.open(image)\n\n        else:\n\n            image = self.get_image_by_idx(idx)\n            label = self.get_label_by_idx(idx)\n\n        return self.transform(image), label"
  },
  {
    "path": "semantic_aug/generative_augmentation.py",
    "content": "from torch.utils.data import Dataset\nfrom typing import Any, Tuple\nfrom PIL import Image\n\nimport torch.nn as nn\nimport torch\nimport abc\n\n\nclass GenerativeAugmentation(nn.Module, abc.ABC):\n\n    @abc.abstractmethod\n    def forward(self, image: Image.Image, label: int, \n                metadata: dict) -> Tuple[Image.Image, int]:\n\n        return NotImplemented"
  },
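  {
    "path": "semantic_aug/pipeline_example.py",
    "content": "# Hypothetical end-to-end sketch, not part of the original repository:\n# wires a FewShotDataset subclass to a generative augmentation and\n# pre-generates synthetic examples, mirroring what train_classifier.py\n# does internally. \"pascal-tokens/pascal-0-8.pt\" and \"aug-cache\" are\n# placeholder paths.\n\nfrom semantic_aug.datasets.pascal import PASCALDataset\nfrom semantic_aug.augmentations.textual_inversion import TextualInversion\n\n\nif __name__ == \"__main__\":\n\n    aug = TextualInversion(\n        \"pascal-tokens/pascal-0-8.pt\",  # placeholder embedding checkpoint\n        strength=0.5,\n        guidance_scale=7.5)\n\n    dataset = PASCALDataset(\n        split=\"train\", seed=0, examples_per_class=8,\n        generative_aug=aug, synthetic_probability=0.5,\n        synthetic_dir=\"aug-cache\")\n\n    # generate 10 synthetic variants per real example; __getitem__ then\n    # samples real or synthetic images with synthetic_probability\n    dataset.generate_augmentations(10)\n\n    image, label = dataset[0]  # normalized torch.Tensor and integer label"
  },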
  {
    "path": "setup.py",
    "content": "from setuptools import find_packages\nfrom setuptools import setup\n\n\nURL = 'https://github.com/brandontrabucco/semantic-aug'\nDESCRIPTION = \"Semantic Controls For Data Augmentation\"\nCLASSIFIERS = ['Intended Audience :: Developers',\n               'Intended Audience :: Science/Research',\n               'Topic :: Scientific/Engineering',\n               'Topic :: Scientific/Engineering :: Artificial Intelligence',\n               'Topic :: Scientific/Engineering :: Mathematics',\n               'Topic :: Software Development',\n               'Topic :: Software Development :: Libraries',\n               'Topic :: Software Development :: Libraries :: Python Modules',\n               'License :: OSI Approved :: MIT License',\n               'Programming Language :: Python :: 3',\n               'Programming Language :: Python :: 3.7',\n               'Programming Language :: Python :: 3.8',\n               'Programming Language :: Python :: 3.9']\n\n\nwith open('README.md', 'r') as readme:\n    LONG_DESCRIPTION = readme.read()  # use readme as long description\n\n\nsetup(name='semantic-aug', version='1.0', license='MIT',\n      author='Brandon Trabucco', author_email='brandon@btrabucco.com',\n      packages=find_packages(include=['semantic_aug', 'semantic_aug.*']),\n      classifiers=CLASSIFIERS, description=DESCRIPTION,\n      long_description=LONG_DESCRIPTION,\n      long_description_content_type='text/markdown',\n      url=URL, keywords=['Computer Vision', 'Data Augmentation'],\n      install_requires=['torch', 'torchvision', 'pandas'])"
  },
  {
    "path": "train_classifier.py",
    "content": "from semantic_aug.datasets.coco import COCODataset\nfrom semantic_aug.datasets.spurge import SpurgeDataset\nfrom semantic_aug.datasets.imagenet import ImageNetDataset\nfrom semantic_aug.datasets.pascal import PASCALDataset\nfrom semantic_aug.datasets.caltech101 import CalTech101Dataset\nfrom semantic_aug.datasets.flowers102 import Flowers102Dataset\nfrom semantic_aug.augmentations.compose import ComposeParallel\nfrom semantic_aug.augmentations.compose import ComposeSequential\nfrom semantic_aug.augmentations.real_guidance import RealGuidance\nfrom semantic_aug.augmentations.textual_inversion import TextualInversion\nfrom semantic_aug.augmentations.textual_inversion_upstream \\\n    import TextualInversion as MultiTokenTextualInversion\nfrom torch.utils.data import DataLoader\nfrom torchvision.models import resnet50, ResNet50_Weights\nfrom transformers import AutoImageProcessor, DeiTModel\nfrom itertools import product\nfrom tqdm import trange\nfrom typing import List\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.distributed as distributed\n\nimport argparse\nimport pandas as pd\nimport numpy as np\nimport random\nimport os\n\ntry: \n    from cutmix.cutmix import CutMix\n    IS_CUTMIX_INSTALLED = True\nexcept:\n    IS_CUTMIX_INSTALLED = False\n\n\nDEFAULT_MODEL_PATH = \"CompVis/stable-diffusion-v1-4\"\nDEFAULT_PROMPT = \"a photo of a {name}\"\n\nDEFAULT_SYNTHETIC_DIR = \"/projects/rsalakhugroup/\\\nbtrabucc/aug/{dataset}-{aug}-{seed}-{examples_per_class}\"\n\nDEFAULT_EMBED_PATH = \"{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt\"\n\nDATASETS = {\n    \"spurge\": SpurgeDataset, \n    \"coco\": COCODataset, \n    \"pascal\": PASCALDataset,\n    \"imagenet\": ImageNetDataset,\n    \"caltech\": CalTech101Dataset,\n    \"flowers\": Flowers102Dataset\n}\n\nCOMPOSERS = {\n    \"parallel\": ComposeParallel,\n    \"sequential\": ComposeSequential\n}\n\nAUGMENTATIONS = {\n    \"real-guidance\": RealGuidance,\n    \"textual-inversion\": TextualInversion,\n    \"multi-token-inversion\": MultiTokenTextualInversion\n}\n\n\ndef run_experiment(examples_per_class: int = 0, \n                   seed: int = 0, \n                   dataset: str = \"spurge\", \n                   num_synthetic: int = 100, \n                   iterations_per_epoch: int = 200, \n                   num_epochs: int = 50, \n                   batch_size: int = 32, \n                   aug: List[str] = None,\n                   strength: List[float] = None, \n                   guidance_scale: List[float] = None,\n                   mask: List[bool] = None,\n                   inverted: List[bool] = None, \n                   probs: List[float] = None,\n                   compose: str = \"parallel\",\n                   synthetic_probability: float = 0.5, \n                   synthetic_dir: str = DEFAULT_SYNTHETIC_DIR, \n                   embed_path: str = DEFAULT_EMBED_PATH,\n                   model_path: str = DEFAULT_MODEL_PATH,\n                   prompt: str = DEFAULT_PROMPT,\n                   tokens_per_class: int = 4,\n                   use_randaugment: bool = False,\n                   use_cutmix: bool = False,\n                   erasure_ckpt_path: str = None,\n                   image_size: int = 256,\n                   classifier_backbone: str = \"resnet50\"):\n\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n\n    if aug is not None:\n\n        aug = COMPOSERS[compose]([\n            \n            AUGMENTATIONS[aug](\n  
              embed_path=embed_path, \n                model_path=model_path, \n                prompt=prompt, \n                strength=strength, \n                guidance_scale=guidance_scale,\n                mask=mask, \n                inverted=inverted,\n                erasure_ckpt_path=erasure_ckpt_path,\n                tokens_per_class=tokens_per_class\n            )\n\n            for (aug, guidance_scale, \n                 strength, mask, inverted) in zip(\n                aug, guidance_scale, \n                strength, mask, inverted\n            )\n\n        ], probs=probs)\n\n    train_dataset = DATASETS[dataset](\n        split=\"train\", examples_per_class=examples_per_class, \n        synthetic_probability=synthetic_probability, \n        synthetic_dir=synthetic_dir,\n        use_randaugment=use_randaugment,\n        generative_aug=aug, seed=seed,\n        image_size=(image_size, image_size))\n\n    if num_synthetic > 0 and aug is not None:\n        train_dataset.generate_augmentations(num_synthetic)\n\n    cutmix_dataset = None\n    if use_cutmix and IS_CUTMIX_INSTALLED:\n        cutmix_dataset = CutMix(\n            train_dataset, beta=1.0, prob=0.5, num_mix=2, \n            num_class=train_dataset.num_classes)\n\n    train_sampler = torch.utils.data.RandomSampler(\n        cutmix_dataset if cutmix_dataset is not None else \n        train_dataset, replacement=True, \n        num_samples=batch_size * iterations_per_epoch)\n\n    train_dataloader = DataLoader(\n        cutmix_dataset if cutmix_dataset is not None else \n        train_dataset, batch_size=batch_size, \n        sampler=train_sampler, num_workers=4)\n\n    val_dataset = DATASETS[dataset](\n        split=\"val\", seed=seed,\n        image_size=(image_size, image_size))\n\n    val_sampler = torch.utils.data.RandomSampler(\n        val_dataset, replacement=True, \n        num_samples=batch_size * iterations_per_epoch)\n\n    val_dataloader = DataLoader(\n        val_dataset, batch_size=batch_size, \n        sampler=val_sampler, num_workers=4)\n\n    model = ClassificationModel(\n        train_dataset.num_classes, \n        backbone=classifier_backbone\n    ).cuda()\n\n    optim = torch.optim.Adam(model.parameters(), lr=0.0001)\n\n    records = []\n\n    for epoch in trange(num_epochs, desc=\"Training Classifier\"):\n\n        model.train()\n\n        epoch_loss = torch.zeros(\n            train_dataset.num_classes, \n            dtype=torch.float32, device='cuda')\n        epoch_accuracy = torch.zeros(\n            train_dataset.num_classes, \n            dtype=torch.float32, device='cuda')\n        epoch_size = torch.zeros(\n            train_dataset.num_classes, \n            dtype=torch.float32, device='cuda')\n\n        for image, label in train_dataloader:\n            image, label = image.cuda(), label.cuda()\n\n            logits = model(image)\n            prediction = logits.argmax(dim=1)\n\n            loss = F.cross_entropy(logits, label, reduction=\"none\")\n            if len(label.shape) > 1: label = label.argmax(dim=1)\n\n            accuracy = (prediction == label).float()\n\n            optim.zero_grad()\n            loss.mean().backward()\n            optim.step()\n\n            with torch.no_grad():\n            \n                epoch_size.scatter_add_(0, label, torch.ones_like(loss))\n                epoch_loss.scatter_add_(0, label, loss)\n                epoch_accuracy.scatter_add_(0, label, accuracy)\n\n        training_loss = epoch_loss / epoch_size.clamp(min=1)\n        
        training_loss = epoch_loss / epoch_size.clamp(min=1)\n        training_accuracy = epoch_accuracy / epoch_size.clamp(min=1)\n\n        training_loss = training_loss.cpu().numpy()\n        training_accuracy = training_accuracy.cpu().numpy()\n\n        model.eval()\n\n        epoch_loss = torch.zeros(\n            train_dataset.num_classes, \n            dtype=torch.float32, device='cuda')\n        epoch_accuracy = torch.zeros(\n            train_dataset.num_classes, \n            dtype=torch.float32, device='cuda')\n        epoch_size = torch.zeros(\n            train_dataset.num_classes, \n            dtype=torch.float32, device='cuda')\n\n        for image, label in val_dataloader:\n            image, label = image.cuda(), label.cuda()\n\n            # no gradients are needed during validation\n            with torch.no_grad():\n\n                logits = model(image)\n                prediction = logits.argmax(dim=1)\n\n                loss = F.cross_entropy(logits, label, reduction=\"none\")\n                accuracy = (prediction == label).float()\n\n                epoch_size.scatter_add_(0, label, torch.ones_like(loss))\n                epoch_loss.scatter_add_(0, label, loss)\n                epoch_accuracy.scatter_add_(0, label, accuracy)\n\n        validation_loss = epoch_loss / epoch_size.clamp(min=1)\n        validation_accuracy = epoch_accuracy / epoch_size.clamp(min=1)\n\n        validation_loss = validation_loss.cpu().numpy()\n        validation_accuracy = validation_accuracy.cpu().numpy()\n\n        # record aggregate metrics for this epoch\n        records.append(dict(\n            seed=seed, \n            examples_per_class=examples_per_class,\n            epoch=epoch, \n            value=training_loss.mean(), \n            metric=\"Loss\", \n            split=\"Training\"\n        ))\n\n        records.append(dict(\n            seed=seed, \n            examples_per_class=examples_per_class,\n            epoch=epoch, \n            value=validation_loss.mean(), \n            metric=\"Loss\", \n            split=\"Validation\"\n        ))\n\n        records.append(dict(\n            seed=seed, \n            examples_per_class=examples_per_class,\n            epoch=epoch, \n            value=training_accuracy.mean(), \n            metric=\"Accuracy\", \n            split=\"Training\"\n        ))\n\n        records.append(dict(\n            seed=seed, \n            examples_per_class=examples_per_class,\n            epoch=epoch, \n            value=validation_accuracy.mean(), \n            metric=\"Accuracy\", \n            split=\"Validation\"\n        ))\n\n        # and the same four metrics broken out per class\n        for i, name in enumerate(train_dataset.class_names):\n\n            records.append(dict(\n                seed=seed, \n                examples_per_class=examples_per_class,\n                epoch=epoch, \n                value=training_loss[i], \n                metric=f\"Loss {name.title()}\", \n                split=\"Training\"\n            ))\n\n            records.append(dict(\n                seed=seed, \n                examples_per_class=examples_per_class,\n                epoch=epoch, \n                value=validation_loss[i], \n                metric=f\"Loss {name.title()}\", \n                split=\"Validation\"\n            ))\n\n            records.append(dict(\n                seed=seed, \n                examples_per_class=examples_per_class,\n                epoch=epoch, \n                value=training_accuracy[i], \n                metric=f\"Accuracy {name.title()}\", \n                split=\"Training\"\n            ))\n\n            records.append(dict(\n                seed=seed, \n                examples_per_class=examples_per_class,\n                epoch=epoch, \n                value=validation_accuracy[i], \n                metric=f\"Accuracy {name.title()}\", \n                split=\"Validation\"\n            ))\n\n    return records\n\n\n
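# linear probe on a frozen pretrained backbone: features are computed\n# under no_grad, so only the final layer self.out receives gradients\n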
class ClassificationModel(nn.Module):\n\n    def __init__(self, num_classes: int, backbone: str = \"resnet50\"):\n\n        super(ClassificationModel, self).__init__()\n\n        self.backbone = backbone\n        self.image_processor = None\n\n        if backbone == \"resnet50\":\n\n            self.base_model = resnet50(weights=ResNet50_Weights.DEFAULT)\n            self.out = nn.Linear(2048, num_classes)\n\n        elif backbone == \"deit\":\n\n            self.base_model = DeiTModel.from_pretrained(\n                \"facebook/deit-base-distilled-patch16-224\")\n            self.out = nn.Linear(768, num_classes)\n\n        else:\n\n            raise ValueError(f\"unknown backbone: {backbone}\")\n\n    def forward(self, image):\n\n        x = image\n\n        if self.backbone == \"resnet50\":\n\n            # the backbone runs under no_grad, so only self.out is trained\n            with torch.no_grad():\n\n                x = self.base_model.conv1(x)\n                x = self.base_model.bn1(x)\n                x = self.base_model.relu(x)\n                x = self.base_model.maxpool(x)\n\n                x = self.base_model.layer1(x)\n                x = self.base_model.layer2(x)\n                x = self.base_model.layer3(x)\n                x = self.base_model.layer4(x)\n\n                x = self.base_model.avgpool(x)\n                x = torch.flatten(x, 1)\n\n        elif self.backbone == \"deit\":\n\n            with torch.no_grad():\n\n                # take the [CLS] token embedding from the last hidden state\n                x = self.base_model(x)[0][:, 0, :]\n\n        return self.out(x)\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\"Few-Shot Baseline\")\n\n    parser.add_argument(\"--logdir\", type=str, default=\"few_shot_combined\")\n    parser.add_argument(\"--model-path\", type=str, default=DEFAULT_MODEL_PATH)\n\n    parser.add_argument(\"--prompt\", type=str, default=DEFAULT_PROMPT)\n\n    parser.add_argument(\"--synthetic-probability\", type=float, default=0.5)\n    parser.add_argument(\"--synthetic-dir\", type=str, default=DEFAULT_SYNTHETIC_DIR)\n\n    parser.add_argument(\"--image-size\", type=int, default=256)\n    parser.add_argument(\"--classifier-backbone\", type=str, \n                        default=\"resnet50\", choices=[\"resnet50\", \"deit\"])\n\n    parser.add_argument(\"--iterations-per-epoch\", type=int, default=200)\n    parser.add_argument(\"--num-epochs\", type=int, default=50)\n    parser.add_argument(\"--batch-size\", type=int, default=32)\n\n    parser.add_argument(\"--num-synthetic\", type=int, default=15)\n    parser.add_argument(\"--num-trials\", type=int, default=8)\n    parser.add_argument(\"--examples-per-class\", nargs='+', type=int, default=[1, 2, 4, 8, 16])\n\n    parser.add_argument(\"--embed-path\", type=str, default=DEFAULT_EMBED_PATH)\n\n    parser.add_argument(\"--dataset\", type=str, default=\"pascal\", \n                        choices=[\"spurge\", \"imagenet\", \"coco\", \"pascal\", \"flowers\", \"caltech\"])\n\n    parser.add_argument(\"--aug\", nargs=\"+\", type=str, default=None, \n                        choices=[\"real-guidance\", \"textual-inversion\",\n                                 \"multi-token-inversion\"])\n\n    parser.add_argument(\"--strength\", nargs=\"+\", type=float, default=None)\n    parser.add_argument(\"--guidance-scale\", nargs=\"+\", type=float, default=None)\n\n
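    # mask / inverted are 0-1 integer flags, treated as booleans downstream\n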
nargs=\"+\", type=int, default=None, choices=[0, 1])\n    parser.add_argument(\"--inverted\", nargs=\"+\", type=int, default=None, choices=[0, 1])\n    \n    parser.add_argument(\"--probs\", nargs=\"+\", type=float, default=None)\n    \n    parser.add_argument(\"--compose\", type=str, default=\"parallel\", \n                        choices=[\"parallel\", \"sequential\"])\n    \n    parser.add_argument(\"--erasure-ckpt-path\", type=str, default=None)\n\n    parser.add_argument(\"--use-randaugment\", action=\"store_true\")\n    parser.add_argument(\"--use-cutmix\", action=\"store_true\")\n\n    parser.add_argument(\"--tokens-per-class\", type=int, default=4)\n    \n    args = parser.parse_args()\n\n    try:\n        rank = int(os.environ[\"RANK\"])\n        world_size = int(os.environ[\"WORLD_SIZE\"])\n    except KeyError:\n        rank, world_size = 0, 1\n\n    device_id = rank % torch.cuda.device_count()\n    torch.cuda.set_device(rank % torch.cuda.device_count())\n\n    print(f'Initialized process {rank} / {world_size}')\n    os.makedirs(args.logdir, exist_ok=True)\n\n    all_trials = []\n\n    options = product(range(args.num_trials), args.examples_per_class)\n    options = np.array(list(options))\n    options = np.array_split(options, world_size)[rank]\n\n    for seed, examples_per_class in options.tolist():\n\n        hyperparameters = dict(\n            examples_per_class=examples_per_class,\n            seed=seed, \n            dataset=args.dataset,\n            num_epochs=args.num_epochs,\n            iterations_per_epoch=args.iterations_per_epoch, \n            batch_size=args.batch_size,\n            model_path=args.model_path,\n            synthetic_probability=args.synthetic_probability, \n            num_synthetic=args.num_synthetic, \n            prompt=args.prompt, \n            tokens_per_class=args.tokens_per_class,\n            aug=args.aug,\n            strength=args.strength, \n            guidance_scale=args.guidance_scale,\n            mask=args.mask, \n            inverted=args.inverted,\n            probs=args.probs,\n            compose=args.compose,\n            use_randaugment=args.use_randaugment,\n            use_cutmix=args.use_cutmix,\n            erasure_ckpt_path=args.erasure_ckpt_path,\n            image_size=args.image_size,\n            classifier_backbone=args.classifier_backbone)\n\n        synthetic_dir = args.synthetic_dir.format(**hyperparameters)\n        embed_path = args.embed_path.format(**hyperparameters)\n\n        all_trials.extend(run_experiment(\n            synthetic_dir=synthetic_dir, \n            embed_path=embed_path, **hyperparameters))\n\n        path = f\"results_{seed}_{examples_per_class}.csv\"\n        path = os.path.join(args.logdir, path)\n\n        pd.DataFrame.from_records(all_trials).to_csv(path)\n        print(f\"[rank {rank}] n={examples_per_class} saved to: {path}\")"
  }
]