[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 dongzelian\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# SSF for Efficient Model Tuning\n\nThis repo is the official implementation of our NeurIPS2022 paper \"Scaling & Shifting Your Features: A New Baseline for Efficient Model Tuning\" ([arXiv](https://arxiv.org/abs/2210.08823)). \n\n\n\n\n## Usage\n\n### Install\n\n- Clone this repo:\n\n```bash\ngit clone https://github.com/dongzelian/SSF.git\ncd SSF\n```\n\n- Create a conda virtual environment and activate it:\n\n```bash\nconda create -n ssf python=3.7 -y\nconda activate ssf\n```\n\n- Install `CUDA==10.1` with `cudnn7` following\n  the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)\n- Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`:\n\n```bash\nconda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch\n```\n\n- Install `timm==0.6.5`:\n\n```bash\npip install timm==0.6.5\n```\n\n\n- Install other requirements:\n\n```bash\npip install -r requirements.txt\n```\n\n\n### Data preparation\n\n- FGVC & vtab-1k\n\nYou can follow [VPT](https://github.com/KMnP/vpt) to download them. \n\nSince the original [vtab dataset](https://github.com/google-research/task_adaptation/tree/master/task_adaptation/data) is processed with tensorflow scripts and the processing of some datasets is tricky, we also upload the extracted vtab-1k dataset in [onedrive](https://shanghaitecheducn-my.sharepoint.com/:f:/g/personal/liandz_shanghaitech_edu_cn/EnV6eYPVCPZKhbqi-WSJIO8BOcyQwDwRk6dAThqonQ1Ycw?e=J884Fp) for your convenience. You can download from here and then use them with our [vtab.py](https://github.com/dongzelian/SSF/blob/main/data/vtab.py) directly. (Note that the license is in [vtab dataset](https://github.com/google-research/task_adaptation/tree/master/task_adaptation/data)).\n\n\n\n- CIFAR-100\n```bash\nwget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz\n```\n\n- For ImageNet-1K, download it from http://image-net.org/, and move validation images to labeled sub-folders. The file structure should look like:\n  ```bash\n  $ tree data\n  imagenet\n  ├── train\n  │   ├── class1\n  │   │   ├── img1.jpeg\n  │   │   ├── img2.jpeg\n  │   │   └── ...\n  │   ├── class2\n  │   │   ├── img3.jpeg\n  │   │   └── ...\n  │   └── ...\n  └── val\n      ├── class1\n      │   ├── img4.jpeg\n      │   ├── img5.jpeg\n      │   └── ...\n      ├── class2\n      │   ├── img6.jpeg\n      │   └── ...\n      └── ...\n \n  ```\n\n- Robustness & OOD datasets\n\nPrepare [ImageNet-A](https://github.com/hendrycks/natural-adv-examples), [ImageNet-R](https://github.com/hendrycks/imagenet-r) and [ImageNet-C](https://zenodo.org/record/2235448#.Y04cBOxByFw) for evaluation.\n\n\n\n### Pre-trained model preparation\n\n- For pre-trained ViT-B/16, Swin-B, and ConvNext-B models on ImageNet-21K, the model weights will be automatically downloaded when you fine-tune a pre-trained model via `SSF`. 
You can also manually download them from [ViT](https://github.com/google-research/vision_transformer), [Swin Transformer](https://github.com/microsoft/Swin-Transformer), and [ConvNext](https://github.com/facebookresearch/ConvNeXt).\n\n\n\n- For the pre-trained AS-MLP-B model on ImageNet-1K, you can download it manually from [AS-MLP](https://github.com/svip-lab/AS-MLP).\n\n\n\n### Fine-tuning a pre-trained model via SSF\n\nTo fine-tune a pre-trained ViT model via `SSF` on CIFAR-100 or ImageNet-1K, run:\n\n```bash\nbash train_scripts/vit/cifar_100/train_ssf.sh\n```\nor\n```bash\nbash train_scripts/vit/imagenet_1k/train_ssf.sh\n```\n\nSimilar scripts are provided for the Swin, ConvNext, and AS-MLP models, so you can easily reproduce our results. Enjoy!\n\n\n\n### Robustness & OOD\n\nTo evaluate the performance of a model fine-tuned via SSF on the robustness & OOD datasets, run:\n\n```bash\nbash train_scripts/vit/imagenet_a(r, c)/eval_ssf.sh\n```\n\n\n### Citation\nIf this project is helpful to you, please cite our paper:\n```\n@InProceedings{Lian_2022_SSF,\n  title={Scaling \\& Shifting Your Features: A New Baseline for Efficient Model Tuning},\n  author={Lian, Dongze and Zhou, Daquan and Feng, Jiashi and Wang, Xinchao},\n  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},\n  year={2022}\n}\n```\n\n\n### Acknowledgement\nThe code is built upon [timm](https://github.com/rwightman/pytorch-image-models). The processing of the vtab-1k dataset refers to [vpt](https://github.com/KMnP/vpt), the [vtab github repo](https://github.com/google-research/task_adaptation/tree/master/task_adaptation/data), and [NOAH](https://github.com/ZhangYuanhan-AI/NOAH).\n"
  },
  {
    "path": "data/__init__.py",
    "content": "from .loader import create_loader\nfrom .dataset_factory import create_dataset"
  },
  {
    "path": "data/cub2011.py",
    "content": "import os\n\nimport pandas as pd\nfrom torchvision.datasets import VisionDataset\nfrom torchvision.datasets.folder import default_loader\nfrom torchvision.datasets.utils import download_file_from_google_drive\n\n\nclass Cub2011(VisionDataset):\n    \"\"\"`CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.\n        Args:\n            root (string): Root directory of the dataset.\n            train (bool, optional): If True, creates dataset from training set, otherwise\n               creates from test set.\n            transform (callable, optional): A function/transform that  takes in an PIL image\n               and returns a transformed version. E.g, ``transforms.RandomCrop``\n            target_transform (callable, optional): A function/transform that takes in the\n               target and transforms it.\n            download (bool, optional): If true, downloads the dataset from the internet and\n               puts it in root directory. If dataset is already downloaded, it is not\n               downloaded again.\n    \"\"\"\n    base_folder = 'CUB_200_2011/images'\n    # url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'\n    file_id = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'\n    filename = 'CUB_200_2011.tgz'\n    tgz_md5 = '97eceeb196236b17998738112f37df78'\n\n    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):\n        super(Cub2011, self).__init__(root, transform=transform, target_transform=target_transform)\n\n        self.loader = default_loader\n        self.train = train\n        if download:\n            self._download()\n\n        if not self._check_integrity():\n            raise RuntimeError('Dataset not found or corrupted. 
You can use download=True to download it')\n\n    def _load_metadata(self):\n        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',\n                             names=['img_id', 'filepath'])\n        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),\n                                         sep=' ', names=['img_id', 'target'])\n        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),\n                                       sep=' ', names=['img_id', 'is_training_img'])\n\n        data = images.merge(image_class_labels, on='img_id')\n        self.data = data.merge(train_test_split, on='img_id')\n\n        class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),\n                                  sep=' ', names=['class_name'], usecols=[1])\n        self.class_names = class_names['class_name'].to_list()\n        if self.train:\n            self.data = self.data[self.data.is_training_img == 1]\n        else:\n            self.data = self.data[self.data.is_training_img == 0]\n\n    def _check_integrity(self):\n        try:\n            self._load_metadata()\n        except Exception:\n            return False\n\n        for index, row in self.data.iterrows():\n            filepath = os.path.join(self.root, self.base_folder, row.filepath)\n            if not os.path.isfile(filepath):\n                print(filepath)\n                return False\n        return True\n\n    def _download(self):\n        import tarfile\n\n        if self._check_integrity():\n            print('Files already downloaded and verified')\n            return\n\n        download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5)\n\n        with tarfile.open(os.path.join(self.root, self.filename), \"r:gz\") as tar:\n            tar.extractall(path=self.root)\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, idx):\n        sample = self.data.iloc[idx]\n        path = os.path.join(self.root, self.base_folder, sample.filepath)\n        target = sample.target - 1  # Targets start at 1 by default, so shift to 0\n        img = self.loader(path)\n\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n        return img, target\n\n\nif __name__ == '__main__':\n    train_dataset = Cub2011('./cub2011', train=True, download=False)\n    test_dataset = Cub2011('./cub2011', train=False, download=False)"
  },
  {
    "path": "data/dataset_factory.py",
    "content": "\"\"\" Dataset Factory\n\nHacked together by / Copyright 2021, Ross Wightman\n\"\"\"\nimport os\n#import hub\n\nfrom torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder\ntry:\n    from torchvision.datasets import Places365\n    has_places365 = True\nexcept ImportError:\n    has_places365 = False\ntry:\n    from torchvision.datasets import INaturalist\n    has_inaturalist = True\nexcept ImportError:\n    has_inaturalist = False\n\nfrom timm.data.dataset import IterableImageDataset, ImageDataset\n\n\n\n# my datasets\nfrom .stanford_dogs import dogs\nfrom .nabirds import NABirds\nfrom .cub2011 import Cub2011\nfrom .vtab import VTAB\n\n\n\n_TORCH_BASIC_DS = dict(\n    cifar10=CIFAR10,\n    cifar100=CIFAR100,\n    mnist=MNIST,\n    qmist=QMNIST,\n    kmnist=KMNIST,\n    fashion_mnist=FashionMNIST,\n)\n_TRAIN_SYNONYM = {'train', 'training'}\n_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}\n\n_VTAB_DATASET = ['caltech101', 'clevr_count', 'dmlab', 'dsprites_ori', 'eurosat', 'flowers102', 'patch_camelyon', 'smallnorb_azi', 'svhn', 'cifar100', 'clevr_dist', 'dsprites_loc', 'dtd', 'kitti', 'pets', 'resisc45', 'smallnorb_ele', 'sun397', 'diabetic_retinopathy']\n\n\n\n\ndef _search_split(root, split):\n    # look for sub-folder with name of split in root and use that if it exists\n    split_name = split.split('[')[0]\n    try_root = os.path.join(root, split_name)\n    if os.path.exists(try_root):\n        return try_root\n\n    def _try(syn):\n        for s in syn:\n            try_root = os.path.join(root, s)\n            if os.path.exists(try_root):\n                return try_root\n        return root\n    if split_name in _TRAIN_SYNONYM:\n        root = _try(_TRAIN_SYNONYM)\n    elif split_name in _EVAL_SYNONYM:\n        root = _try(_EVAL_SYNONYM)\n    return root\n\n\ndef create_dataset(\n        name,\n        root,\n        split='validation',\n        search_split=True,\n        class_map=None,\n        load_bytes=False,\n        is_training=False,\n        download=False,\n        batch_size=None,\n        repeats=0,\n        **kwargs\n):\n    \"\"\" Dataset factory method\n\n    In parenthesis after each arg are the type of dataset supported for each arg, one of:\n      * folder - default, timm folder (or tar) based ImageDataset\n      * torch - torchvision based datasets\n      * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset\n      * all - any of the above\n\n    Args:\n        name: dataset name, empty is okay for folder based datasets\n        root: root folder of dataset (all)\n        split: dataset split (all)\n        search_split: search for split specific child fold from root so one can specify\n            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)\n        class_map: specify class -> index mapping via text file or dict (folder)\n        load_bytes: load data, return images as undecoded bytes (folder)\n        download: download dataset if not present and supported (TFDS, torch)\n        is_training: create dataset in train mode, this is different from the split.\n            For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS)\n        batch_size: batch size hint for (TFDS)\n        repeats: dataset repeats per iteration i.e. 
epoch (TFDS)\n        **kwargs: other args to pass to dataset\n\n    Returns:\n        Dataset object\n    \"\"\"\n    name = name.lower()\n    if name.startswith('torch/'):\n        name = name.split('/', 2)[-1]\n        torch_kwargs = dict(root=root, download=download, **kwargs)\n        if name in _TORCH_BASIC_DS:\n            ds_class = _TORCH_BASIC_DS[name]\n            use_train = split in _TRAIN_SYNONYM\n            ds = ds_class(train=use_train, **torch_kwargs)\n        elif name == 'inaturalist' or name == 'inat':\n            assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist'\n            target_type = 'full'\n            split_split = split.split('/')\n            if len(split_split) > 1:\n                target_type = split_split[0].split('_')\n                if len(target_type) == 1:\n                    target_type = target_type[0]\n                split = split_split[-1]\n            if split in _TRAIN_SYNONYM:\n                split = '2021_train'\n            elif split in _EVAL_SYNONYM:\n                split = '2021_valid'\n            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)\n        elif name == 'places365':\n            assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'\n            if split in _TRAIN_SYNONYM:\n                split = 'train-standard'\n            elif split in _EVAL_SYNONYM:\n                split = 'val'\n            ds = Places365(split=split, **torch_kwargs)\n        elif name == 'imagenet':\n            if split in _EVAL_SYNONYM:\n                split = 'val'\n            ds = ImageNet(split=split, **torch_kwargs)\n        elif name == 'image_folder' or name == 'folder':\n            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason\n            if search_split and os.path.isdir(root):\n                # look for split specific sub-folder in root\n                root = _search_split(root, split)\n            ds = ImageFolder(root, **kwargs)\n        else:\n            assert False, f\"Unknown torchvision dataset {name}\"\n    elif name.startswith('tfds/'):\n        ds = IterableImageDataset(\n            root, parser=name, split=split, is_training=is_training,\n            download=download, batch_size=batch_size, repeats=repeats, **kwargs)\n    else:\n        # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future\n\n        # define my datasets\n        if name == 'stanford_dogs':\n            ds = dogs(root=root, train=is_training, **kwargs)\n        elif name == 'nabirds':\n            ds = NABirds(root=root, train=is_training, **kwargs)\n        elif name == 'cub2011':\n            ds = Cub2011(root=root, train=is_training, **kwargs)\n        elif name in _VTAB_DATASET:\n            ds = VTAB(root=root, train=is_training, **kwargs)\n        else:\n            if os.path.isdir(os.path.join(root, split)):\n                root = os.path.join(root, split)\n            else:\n                if search_split and os.path.isdir(root):\n                    root = _search_split(root, split)\n            ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)\n    return ds\n\n"
  },
  {
    "path": "data/loader.py",
    "content": "\"\"\" Loader Factory, Fast Collate, CUDA Prefetcher\n\nPrefetcher and Fast Collate inspired by NVIDIA APEX example at\nhttps://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf\n\nHacked together by / Copyright 2019, Ross Wightman\n\"\"\"\nimport random\nfrom functools import partial\nfrom itertools import repeat\nfrom typing import Callable\n\nimport torch.utils.data\nimport numpy as np\n\nfrom .transforms_factory import create_transform\nfrom timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.data.distributed_sampler import OrderedDistributedSampler, RepeatAugSampler\nfrom timm.data.random_erasing import RandomErasing\nfrom timm.data.mixup import FastCollateMixup\n\n\ndef fast_collate(batch):\n    \"\"\" A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)\"\"\"\n    assert isinstance(batch[0], tuple)\n    batch_size = len(batch)\n    if isinstance(batch[0][0], tuple):\n        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position\n        # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position\n        inner_tuple_size = len(batch[0][0])\n        flattened_batch_size = batch_size * inner_tuple_size\n        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)\n        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)\n        for i in range(batch_size):\n            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length\n            for j in range(inner_tuple_size):\n                targets[i + j * batch_size] = batch[i][1]\n                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])\n        return tensor, targets\n    elif isinstance(batch[0][0], np.ndarray):\n        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)\n        assert len(targets) == batch_size\n        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)\n        for i in range(batch_size):\n            tensor[i] += torch.from_numpy(batch[i][0])\n        return tensor, targets\n    elif isinstance(batch[0][0], torch.Tensor):\n        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)\n        assert len(targets) == batch_size\n        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)\n        for i in range(batch_size):\n            tensor[i].copy_(batch[i][0])\n        return tensor, targets\n    else:\n        assert False\n\n\ndef expand_to_chs(x, n):\n    if not isinstance(x, (tuple, list)):\n        x = tuple(repeat(x, n))\n    elif len(x) == 1:\n        x = x * n\n    else:\n        assert len(x) == n, 'normalization stats must match image channels'\n    return x\n\n\nclass PrefetchLoader:\n\n    def __init__(\n            self,\n            loader,\n            mean=IMAGENET_DEFAULT_MEAN,\n            std=IMAGENET_DEFAULT_STD,\n            channels=3,\n            fp16=False,\n            re_prob=0.,\n            re_mode='const',\n            re_count=1,\n            re_num_splits=0):\n\n        mean = expand_to_chs(mean, channels)\n        std = expand_to_chs(std, channels)\n        normalization_shape = (1, channels, 1, 1)\n\n        self.loader = loader\n        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape)\n        self.std = torch.tensor([x * 255 for x in 
std]).cuda().view(normalization_shape)\n        self.fp16 = fp16\n        if fp16:\n            self.mean = self.mean.half()\n            self.std = self.std.half()\n        if re_prob > 0.:\n            self.random_erasing = RandomErasing(\n                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)\n        else:\n            self.random_erasing = None\n\n    def __iter__(self):\n        stream = torch.cuda.Stream()\n        first = True\n\n        for next_input, next_target in self.loader:\n            with torch.cuda.stream(stream):\n                next_input = next_input.cuda(non_blocking=True)\n                next_target = next_target.cuda(non_blocking=True)\n                if self.fp16:\n                    next_input = next_input.half().sub_(self.mean).div_(self.std)\n                else:\n                    next_input = next_input.float().sub_(self.mean).div_(self.std)\n                if self.random_erasing is not None:\n                    next_input = self.random_erasing(next_input)\n\n            if not first:\n                yield input, target\n            else:\n                first = False\n\n            torch.cuda.current_stream().wait_stream(stream)\n            input = next_input\n            target = next_target\n\n        yield input, target\n\n    def __len__(self):\n        return len(self.loader)\n\n    @property\n    def sampler(self):\n        return self.loader.sampler\n\n    @property\n    def dataset(self):\n        return self.loader.dataset\n\n    @property\n    def mixup_enabled(self):\n        if isinstance(self.loader.collate_fn, FastCollateMixup):\n            return self.loader.collate_fn.mixup_enabled\n        else:\n            return False\n\n    @mixup_enabled.setter\n    def mixup_enabled(self, x):\n        if isinstance(self.loader.collate_fn, FastCollateMixup):\n            self.loader.collate_fn.mixup_enabled = x\n\n\ndef _worker_init(worker_id, worker_seeding='all'):\n    worker_info = torch.utils.data.get_worker_info()\n    assert worker_info.id == worker_id\n    if isinstance(worker_seeding, Callable):\n        seed = worker_seeding(worker_info)\n        random.seed(seed)\n        torch.manual_seed(seed)\n        np.random.seed(seed % (2 ** 32 - 1))\n    else:\n        assert worker_seeding in ('all', 'part')\n        # random / torch seed already called in dataloader iter class w/ worker_info.seed\n        # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)\n        if worker_seeding == 'all':\n            np.random.seed(worker_info.seed % (2 ** 32 - 1))\n\n\ndef create_loader(\n        dataset,\n        input_size,\n        batch_size,\n        is_training=False,\n        use_prefetcher=True,\n        no_aug=False,\n        simple_aug=False,\n        direct_resize=False,\n        re_prob=0.,\n        re_mode='const',\n        re_count=1,\n        re_split=False,\n        scale=None,\n        ratio=None,\n        hflip=0.5,\n        vflip=0.,\n        color_jitter=0.4,\n        auto_augment=None,\n        num_aug_repeats=0,\n        num_aug_splits=0,\n        interpolation='bilinear',\n        mean=IMAGENET_DEFAULT_MEAN,\n        std=IMAGENET_DEFAULT_STD,\n        num_workers=1,\n        distributed=False,\n        crop_pct=None,\n        collate_fn=None,\n        pin_memory=False,\n        fp16=False,\n        tf_preprocessing=False,\n        use_multi_epochs_loader=False,\n        persistent_workers=True,\n        
worker_seeding='all',\n):\n    re_num_splits = 0\n    if re_split:\n        # apply RE to second half of batch if no aug split otherwise line up with aug split\n        re_num_splits = num_aug_splits or 2\n    dataset.transform = create_transform(\n        input_size,\n        is_training=is_training,\n        use_prefetcher=use_prefetcher,\n        no_aug=no_aug,\n        simple_aug=simple_aug,\n        direct_resize=direct_resize,\n        scale=scale,\n        ratio=ratio,\n        hflip=hflip,\n        vflip=vflip,\n        color_jitter=color_jitter,\n        auto_augment=auto_augment,\n        interpolation=interpolation,\n        mean=mean,\n        std=std,\n        crop_pct=crop_pct,\n        tf_preprocessing=tf_preprocessing,\n        re_prob=re_prob,\n        re_mode=re_mode,\n        re_count=re_count,\n        re_num_splits=re_num_splits,\n        separate=num_aug_splits > 0,\n    )\n\n    sampler = None\n    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):\n        if is_training:\n            if num_aug_repeats:\n                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)\n            else:\n                sampler = torch.utils.data.distributed.DistributedSampler(dataset)\n        else:\n            # This will add extra duplicate entries to result in equal num\n            # of samples per-process, will slightly alter validation results\n            sampler = OrderedDistributedSampler(dataset)\n    else:\n        assert num_aug_repeats == 0, \"RepeatAugment not currently supported in non-distributed or IterableDataset use\"\n\n    if collate_fn is None:\n        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate\n\n    loader_class = torch.utils.data.DataLoader\n    if use_multi_epochs_loader:\n        loader_class = MultiEpochsDataLoader\n\n    loader_args = dict(\n        batch_size=batch_size,\n        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,\n        num_workers=num_workers,\n        sampler=sampler,\n        collate_fn=collate_fn,\n        pin_memory=pin_memory,\n        drop_last=is_training,\n        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),\n        persistent_workers=persistent_workers\n    )\n    try:\n        loader = loader_class(dataset, **loader_args)\n    except TypeError as e:\n        loader_args.pop('persistent_workers')  # only in Pytorch 1.7+\n        loader = loader_class(dataset, **loader_args)\n    if use_prefetcher:\n        prefetch_re_prob = re_prob if is_training and not no_aug else 0.\n        loader = PrefetchLoader(\n            loader,\n            mean=mean,\n            std=std,\n            channels=input_size[0],\n            fp16=fp16,\n            re_prob=prefetch_re_prob,\n            re_mode=re_mode,\n            re_count=re_count,\n            re_num_splits=re_num_splits\n        )\n\n    return loader\n\n\nclass MultiEpochsDataLoader(torch.utils.data.DataLoader):\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._DataLoader__initialized = False\n        self.batch_sampler = _RepeatSampler(self.batch_sampler)\n        self._DataLoader__initialized = True\n        self.iterator = super().__iter__()\n\n    def __len__(self):\n        return len(self.batch_sampler.sampler)\n\n    def __iter__(self):\n        for i in range(len(self)):\n            yield next(self.iterator)\n\n\nclass _RepeatSampler(object):\n  
  \"\"\" Sampler that repeats forever.\n\n    Args:\n        sampler (Sampler)\n    \"\"\"\n\n    def __init__(self, sampler):\n        self.sampler = sampler\n\n    def __iter__(self):\n        while True:\n            yield from iter(self.sampler)\n"
  },
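  {
    "path": "docs/loader_usage_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file, not part of the original repo).\n\nShows how the `create_dataset` / `create_loader` factories above are typically\nwired together. The dataset name, root path, image size, and batch size below\nare assumptions chosen for the example, not values prescribed by the repo.\n\"\"\"\nfrom data import create_dataset, create_loader\n\n# build the torchvision CIFAR-100 train split via the 'torch/' name prefix\ndataset = create_dataset('torch/cifar100', root='./data/cifar100', split='train',\n                         is_training=True, download=True)\n\n# attach training transforms and wrap the dataset in a DataLoader;\n# use_prefetcher=False keeps this runnable on CPU (the prefetcher assumes CUDA)\nloader = create_loader(\n    dataset,\n    input_size=(3, 224, 224),\n    batch_size=64,\n    is_training=True,\n    use_prefetcher=False,\n    num_workers=4,\n)\n\nfor images, targets in loader:\n    print(images.shape, targets.shape)\n    break\n"
  },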
  {
    "path": "data/nabirds.py",
    "content": "import os\nimport pandas as pd\nimport warnings\nimport numpy as np\nimport torch\nfrom PIL import Image\n\n\n\nfrom torchvision.datasets import VisionDataset\nfrom torchvision.datasets.folder import default_loader\nfrom torchvision.datasets.utils import check_integrity, extract_archive\n\n\nfrom torch.utils.data import DataLoader, Dataset\n\n\n\nclass NABirds(Dataset):\n    \"\"\"`NABirds <https://dl.allaboutbirds.org/nabirds>`_ Dataset.\n        Args:\n            root (string): Root directory of the dataset.\n            train (bool, optional): If True, creates dataset from training set, otherwise\n               creates from test set.\n            transform (callable, optional): A function/transform that  takes in an PIL image\n               and returns a transformed version. E.g, ``transforms.RandomCrop``\n            target_transform (callable, optional): A function/transform that takes in the\n               target and transforms it.\n            download (bool, optional): If true, downloads the dataset from the internet and\n               puts it in root directory. If dataset is already downloaded, it is not\n               downloaded again.\n    \"\"\"\n    base_folder = 'nabirds/images'\n\n    def __init__(self, root, train=True, transform=None):\n        dataset_path = os.path.join(root, 'nabirds')\n        self.root = root\n        self.loader = default_loader\n        self.train = train\n        self.transform = transform\n\n        image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),\n                                  sep=' ', names=['img_id', 'filepath'])\n        image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),\n                                         sep=' ', names=['img_id', 'target'])\n        # Since the raw labels are non-continuous, map them to new ones\n        self.label_map = get_continuous_class_map(image_class_labels['target'])\n        train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),\n                                       sep=' ', names=['img_id', 'is_training_img'])\n        data = image_paths.merge(image_class_labels, on='img_id')\n        self.data = data.merge(train_test_split, on='img_id')\n        # Load in the train / test split\n        if self.train:\n            self.data = self.data[self.data.is_training_img == 1]\n        else:\n            self.data = self.data[self.data.is_training_img == 0]\n\n        # Load in the class data\n        self.class_names = load_class_names(dataset_path)\n        self.class_hierarchy = load_hierarchy(dataset_path)\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, idx):\n        sample = self.data.iloc[idx]\n        path = os.path.join(self.root, self.base_folder, sample.filepath)\n        target = self.label_map[sample.target]\n        img = self.loader(path)\n\n        if self.transform is not None:\n            img = self.transform(img)\n        return img, target\n\ndef get_continuous_class_map(class_labels):\n    label_set = set(class_labels)\n    return {k: i for i, k in enumerate(label_set)}\n\ndef load_class_names(dataset_path=''):\n    names = {}\n\n    with open(os.path.join(dataset_path, 'classes.txt')) as f:\n        for line in f:\n            pieces = line.strip().split()\n            class_id = pieces[0]\n            names[class_id] = ' '.join(pieces[1:])\n\n    return names\n\ndef load_hierarchy(dataset_path=''):\n    parents = {}\n\n    with 
open(os.path.join(dataset_path, 'hierarchy.txt')) as f:\n        for line in f:\n            pieces = line.strip().split()\n            child_id, parent_id = pieces\n            parents[child_id] = parent_id\n\n    return parents\n"
  },
  {
    "path": "data/stanford_dogs.py",
    "content": "from __future__ import print_function\n\nfrom PIL import Image\nfrom os.path import join\nimport os\nimport scipy.io\n\nimport torch.utils.data as data\nfrom torchvision.datasets.utils import download_url, list_dir, list_files\n\n\nclass dogs(data.Dataset):\n    \"\"\"`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.\n    Args:\n        root (string): Root directory of dataset where directory\n            ``omniglot-py`` exists.\n        cropped (bool, optional): If true, the images will be cropped into the bounding box specified\n            in the annotations\n        transform (callable, optional): A function/transform that  takes in an PIL image\n            and returns a transformed version. E.g, ``transforms.RandomCrop``\n        target_transform (callable, optional): A function/transform that takes in the\n            target and transforms it.\n        download (bool, optional): If true, downloads the dataset tar files from the internet and\n            puts it in root directory. If the tar files are already downloaded, they are not\n            downloaded again.\n    \"\"\"\n    #folder = 'StanfordDogs'\n    folder = ''\n    download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'\n\n    def __init__(self,\n                 root,\n                 train=True,\n                 cropped=False,\n                 transform=None,\n                 target_transform=None,\n                 download=False):\n\n        self.root = join(os.path.expanduser(root), self.folder)\n        self.train = train\n        self.cropped = cropped\n        self.transform = transform\n        self.target_transform = target_transform\n\n        if download:\n            self.download()\n\n        split = self.load_split()\n\n        self.images_folder = join(self.root, 'Images')\n        self.annotations_folder = join(self.root, 'Annotation')\n        self._breeds = list_dir(self.images_folder)\n\n        if self.cropped:\n            self._breed_annotations = [[(annotation, box, idx)\n                                        for box in self.get_boxes(join(self.annotations_folder, annotation))]\n                                        for annotation, idx in split]\n            self._flat_breed_annotations = sum(self._breed_annotations, [])\n\n            self._flat_breed_images = [(annotation+'.jpg', idx) for annotation, box, idx in self._flat_breed_annotations]\n        else:\n            self._breed_images = [(annotation+'.jpg', idx) for annotation, idx in split]\n\n            self._flat_breed_images = self._breed_images\n\n        self.classes = [\"Chihuaha\",\n                        \"Japanese Spaniel\",\n                        \"Maltese Dog\",\n                        \"Pekinese\",\n                        \"Shih-Tzu\",\n                        \"Blenheim Spaniel\",\n                        \"Papillon\",\n                        \"Toy Terrier\",\n                        \"Rhodesian Ridgeback\",\n                        \"Afghan Hound\",\n                        \"Basset Hound\",\n                        \"Beagle\",\n                        \"Bloodhound\",\n                        \"Bluetick\",\n                        \"Black-and-tan Coonhound\",\n                        \"Walker Hound\",\n                        \"English Foxhound\",\n                        \"Redbone\",\n                        \"Borzoi\",\n                        \"Irish Wolfhound\",\n                        \"Italian Greyhound\",\n                        \"Whippet\",\n   
                     \"Ibizian Hound\",\n                        \"Norwegian Elkhound\",\n                        \"Otterhound\",\n                        \"Saluki\",\n                        \"Scottish Deerhound\",\n                        \"Weimaraner\",\n                        \"Staffordshire Bullterrier\",\n                        \"American Staffordshire Terrier\",\n                        \"Bedlington Terrier\",\n                        \"Border Terrier\",\n                        \"Kerry Blue Terrier\",\n                        \"Irish Terrier\",\n                        \"Norfolk Terrier\",\n                        \"Norwich Terrier\",\n                        \"Yorkshire Terrier\",\n                        \"Wirehaired Fox Terrier\",\n                        \"Lakeland Terrier\",\n                        \"Sealyham Terrier\",\n                        \"Airedale\",\n                        \"Cairn\",\n                        \"Australian Terrier\",\n                        \"Dandi Dinmont\",\n                        \"Boston Bull\",\n                        \"Miniature Schnauzer\",\n                        \"Giant Schnauzer\",\n                        \"Standard Schnauzer\",\n                        \"Scotch Terrier\",\n                        \"Tibetan Terrier\",\n                        \"Silky Terrier\",\n                        \"Soft-coated Wheaten Terrier\",\n                        \"West Highland White Terrier\",\n                        \"Lhasa\",\n                        \"Flat-coated Retriever\",\n                        \"Curly-coater Retriever\",\n                        \"Golden Retriever\",\n                        \"Labrador Retriever\",\n                        \"Chesapeake Bay Retriever\",\n                        \"German Short-haired Pointer\",\n                        \"Vizsla\",\n                        \"English Setter\",\n                        \"Irish Setter\",\n                        \"Gordon Setter\",\n                        \"Brittany\",\n                        \"Clumber\",\n                        \"English Springer Spaniel\",\n                        \"Welsh Springer Spaniel\",\n                        \"Cocker Spaniel\",\n                        \"Sussex Spaniel\",\n                        \"Irish Water Spaniel\",\n                        \"Kuvasz\",\n                        \"Schipperke\",\n                        \"Groenendael\",\n                        \"Malinois\",\n                        \"Briard\",\n                        \"Kelpie\",\n                        \"Komondor\",\n                        \"Old English Sheepdog\",\n                        \"Shetland Sheepdog\",\n                        \"Collie\",\n                        \"Border Collie\",\n                        \"Bouvier des Flandres\",\n                        \"Rottweiler\",\n                        \"German Shepard\",\n                        \"Doberman\",\n                        \"Miniature Pinscher\",\n                        \"Greater Swiss Mountain Dog\",\n                        \"Bernese Mountain Dog\",\n                        \"Appenzeller\",\n                        \"EntleBucher\",\n                        \"Boxer\",\n                        \"Bull Mastiff\",\n                        \"Tibetan Mastiff\",\n                        \"French Bulldog\",\n                        \"Great Dane\",\n                        \"Saint Bernard\",\n                        \"Eskimo Dog\",\n                        \"Malamute\",\n                        \"Siberian Husky\",\n       
                 \"Affenpinscher\",\n                        \"Basenji\",\n                        \"Pug\",\n                        \"Leonberg\",\n                        \"Newfoundland\",\n                        \"Great Pyrenees\",\n                        \"Samoyed\",\n                        \"Pomeranian\",\n                        \"Chow\",\n                        \"Keeshond\",\n                        \"Brabancon Griffon\",\n                        \"Pembroke\",\n                        \"Cardigan\",\n                        \"Toy Poodle\",\n                        \"Miniature Poodle\",\n                        \"Standard Poodle\",\n                        \"Mexican Hairless\",\n                        \"Dingo\",\n                        \"Dhole\",\n                        \"African Hunting Dog\"]\n\n\n\n\n    def __len__(self):\n        return len(self._flat_breed_images)\n\n    def __getitem__(self, index):\n        \"\"\"\n        Args:\n            index (int): Index\n        Returns:\n            tuple: (image, target) where target is index of the target character class.\n        \"\"\"\n        image_name, target_class = self._flat_breed_images[index]\n        image_path = join(self.images_folder, image_name)\n        image = Image.open(image_path).convert('RGB')\n\n        if self.cropped:\n            image = image.crop(self._flat_breed_annotations[index][1])\n\n        if self.transform:\n            image = self.transform(image)\n\n        if self.target_transform:\n            target_class = self.target_transform(target_class)\n\n        return image, target_class\n\n    def download(self):\n        import tarfile\n\n        if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):\n            if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:\n                print('Files already downloaded and verified')\n                return\n\n        for filename in ['images', 'annotation', 'lists']:\n            tar_filename = filename + '.tar'\n            url = self.download_url_prefix + '/' + tar_filename\n            download_url(url, self.root, tar_filename, None)\n            print('Extracting downloaded file: ' + join(self.root, tar_filename))\n            with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:\n                tar_file.extractall(self.root)\n            os.remove(join(self.root, tar_filename))\n\n    @staticmethod\n    def get_boxes(path):\n        import xml.etree.ElementTree\n        e = xml.etree.ElementTree.parse(path).getroot()\n        boxes = []\n        for objs in e.iter('object'):\n            boxes.append([int(objs.find('bndbox').find('xmin').text),\n                          int(objs.find('bndbox').find('ymin').text),\n                          int(objs.find('bndbox').find('xmax').text),\n                          int(objs.find('bndbox').find('ymax').text)])\n        return boxes\n\n    def load_split(self):\n        if self.train:\n            split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']\n            labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']\n        else:\n            split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']\n            labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']\n\n        split = [item[0][0] for item in split]\n        labels = [item[0]-1 for item in labels]\n        return list(zip(split, labels))\n\n    def 
stats(self):\n        counts = {}\n        for index in range(len(self._flat_breed_images)):\n            image_name, target_class = self._flat_breed_images[index]\n            if target_class not in counts.keys():\n                counts[target_class] = 1\n            else:\n                counts[target_class] += 1\n\n        print(\"%d samples spanning %d classes (avg %f per class)\"%(len(self._flat_breed_images), len(counts.keys()), float(len(self._flat_breed_images))/float(len(counts.keys()))))\n\n        return counts"
  },
  {
    "path": "data/transforms_factory.py",
    "content": "\"\"\" Transforms Factory\nFactory methods for building image transforms for use with TIMM (PyTorch Image Models)\n\nHacked together by / Copyright 2019, Ross Wightman\n\"\"\"\nimport math\n\nimport torch\nfrom torchvision import transforms\n\nfrom timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT\nfrom timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform\nfrom timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy\nfrom timm.data.random_erasing import RandomErasing\n\n\ndef transforms_direct_resize(\n        img_size=224,\n        interpolation='bilinear',\n        use_prefetcher=False,\n        mean=IMAGENET_DEFAULT_MEAN,\n        std=IMAGENET_DEFAULT_STD,\n):\n    if interpolation == 'random':\n        # random interpolation not supported with no-aug\n        interpolation = 'bilinear'\n    tfl = [\n        transforms.Resize(img_size, interpolation=str_to_interp_mode(interpolation)),\n        transforms.CenterCrop(img_size)\n    ]\n    if use_prefetcher:\n        # prefetcher and collate will handle tensor conversion and norm\n        tfl += [ToNumpy()]\n    else:\n        tfl += [\n            transforms.ToTensor(),\n            transforms.Normalize(\n                mean=torch.tensor(mean),\n                std=torch.tensor(std))\n        ]\n    return transforms.Compose(tfl)\n\ndef transforms_simpleaug_train(\n        img_size=224,\n        scale=None,\n        ratio=None,\n        hflip=0.5,\n        interpolation='bilinear',\n        use_prefetcher=False,\n        mean=IMAGENET_DEFAULT_MEAN,\n        std=IMAGENET_DEFAULT_STD,\n):\n    tfl = [\n        RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation),\n        transforms.RandomHorizontalFlip(p=hflip)\n    ]\n    if use_prefetcher:\n        # prefetcher and collate will handle tensor conversion and norm\n        tfl += [ToNumpy()]\n    else:\n        tfl += [\n            transforms.ToTensor(),\n            transforms.Normalize(\n                mean=torch.tensor(mean),\n                std=torch.tensor(std))\n        ]\n    return transforms.Compose(tfl)\n\n\n\n\ndef transforms_imagenet_train(\n        img_size=224,\n        scale=None,\n        ratio=None,\n        hflip=0.5,\n        vflip=0.,\n        color_jitter=0.4,\n        auto_augment=None,\n        interpolation='random',\n        use_prefetcher=False,\n        mean=IMAGENET_DEFAULT_MEAN,\n        std=IMAGENET_DEFAULT_STD,\n        re_prob=0.,\n        re_mode='const',\n        re_count=1,\n        re_num_splits=0,\n        separate=False,\n):\n    \"\"\"\n    If separate==True, the transforms are returned as a tuple of 3 separate transforms\n    for use in a mixing dataset that passes\n     * all data through the first (primary) transform, called the 'clean' data\n     * a portion of the data through the secondary transform\n     * normalizes and converts the branches above with the third, final transform\n    \"\"\"\n    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range\n    ratio = tuple(ratio or (3./4., 4./3.))  # default imagenet ratio range\n    primary_tfl = [\n        RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]\n    if hflip > 0.:\n        primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]\n    if vflip > 0.:\n        primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]\n\n   
 secondary_tfl = []\n    if auto_augment:\n        assert isinstance(auto_augment, str)\n        if isinstance(img_size, (tuple, list)):\n            img_size_min = min(img_size)\n        else:\n            img_size_min = img_size\n        aa_params = dict(\n            translate_const=int(img_size_min * 0.45),\n            img_mean=tuple([min(255, round(255 * x)) for x in mean]),\n        )\n        if interpolation and interpolation != 'random':\n            aa_params['interpolation'] = str_to_pil_interp(interpolation)\n        if auto_augment.startswith('rand'):\n            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]\n        elif auto_augment.startswith('augmix'):\n            aa_params['translate_pct'] = 0.3\n            secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]\n        else:\n            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]\n    elif color_jitter is not None:\n        # color jitter is enabled when not using AA\n        if isinstance(color_jitter, (list, tuple)):\n            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation\n            # or 4 if also augmenting hue\n            assert len(color_jitter) in (3, 4)\n        else:\n            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue\n            color_jitter = (float(color_jitter),) * 3\n        secondary_tfl += [transforms.ColorJitter(*color_jitter)]\n\n    final_tfl = []\n    if use_prefetcher:\n        # prefetcher and collate will handle tensor conversion and norm\n        final_tfl += [ToNumpy()]\n    else:\n        final_tfl += [\n            transforms.ToTensor(),\n            transforms.Normalize(\n                mean=torch.tensor(mean),\n                std=torch.tensor(std))\n        ]\n        if re_prob > 0.:\n            final_tfl.append(\n                RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))\n\n    if separate:\n        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)\n    else:\n        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)\n\n\ndef transforms_imagenet_eval(\n        img_size=224,\n        crop_pct=None,\n        interpolation='bilinear',\n        use_prefetcher=False,\n        mean=IMAGENET_DEFAULT_MEAN,\n        std=IMAGENET_DEFAULT_STD):\n    crop_pct = crop_pct or DEFAULT_CROP_PCT\n\n    if isinstance(img_size, (tuple, list)):\n        assert len(img_size) == 2\n        if img_size[-1] == img_size[-2]:\n            # fall-back to older behaviour so Resize scales to shortest edge if target is square\n            scale_size = int(math.floor(img_size[0] / crop_pct))\n        else:\n            scale_size = tuple([int(x / crop_pct) for x in img_size])\n    else:\n        scale_size = int(math.floor(img_size / crop_pct))\n\n\n    tfl = [\n        transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),\n        transforms.CenterCrop(img_size),\n    ]\n\n\n    if use_prefetcher:\n        # prefetcher and collate will handle tensor conversion and norm\n        tfl += [ToNumpy()]\n    else:\n        tfl += [\n            transforms.ToTensor(),\n            transforms.Normalize(\n                     mean=torch.tensor(mean),\n                     std=torch.tensor(std))\n        ]\n\n    return transforms.Compose(tfl)\n\n\n\ndef create_transform(\n        input_size,\n        is_training=False,\n        
use_prefetcher=False,\n        no_aug=False,\n        simple_aug=False,\n        direct_resize=False,\n        scale=None,\n        ratio=None,\n        hflip=0.5,\n        vflip=0.,\n        color_jitter=0.4,\n        auto_augment=None,\n        interpolation='bilinear',\n        mean=IMAGENET_DEFAULT_MEAN,\n        std=IMAGENET_DEFAULT_STD,\n        re_prob=0.,\n        re_mode='const',\n        re_count=1,\n        re_num_splits=0,\n        crop_pct=None,\n        tf_preprocessing=False,\n        separate=False):\n\n    if isinstance(input_size, (tuple, list)):\n        img_size = input_size[-2:]\n    else:\n        img_size = input_size\n\n    if tf_preprocessing and use_prefetcher:\n        assert not separate, \"Separate transforms not supported for TF preprocessing\"\n        from timm.data.tf_preprocessing import TfPreprocessTransform\n        transform = TfPreprocessTransform(\n            is_training=is_training, size=img_size, interpolation=interpolation)\n    else:\n        if is_training:\n            if no_aug:\n                assert not separate, \"Cannot perform split augmentation with no_aug\"\n                transform = transforms_direct_resize(\n                    img_size,\n                    interpolation=interpolation,\n                    use_prefetcher=use_prefetcher,\n                    mean=mean,\n                    std=std)\n            elif simple_aug:\n                transform = transforms_simpleaug_train(\n                    img_size,\n                    scale=scale,\n                    ratio=ratio,\n                    hflip=hflip,\n                    interpolation=interpolation,\n                    use_prefetcher=use_prefetcher,\n                    mean=mean,\n                    std=std)\n            else:\n                transform = transforms_imagenet_train(\n                    img_size,\n                    scale=scale,\n                    ratio=ratio,\n                    hflip=hflip,\n                    vflip=vflip,\n                    color_jitter=color_jitter,\n                    auto_augment=auto_augment,\n                    interpolation=interpolation,\n                    use_prefetcher=use_prefetcher,\n                    mean=mean,\n                    std=std,\n                    re_prob=re_prob,\n                    re_mode=re_mode,\n                    re_count=re_count,\n                    re_num_splits=re_num_splits,\n                    separate=separate)\n        else:\n            if direct_resize:\n                #print('direct_resize')\n                transform = transforms_direct_resize(\n                    img_size,\n                    interpolation=interpolation,\n                    use_prefetcher=use_prefetcher,\n                    mean=mean,\n                    std=std)\n            else:\n                assert not separate, \"Separate transforms not supported for validation preprocessing\"\n                transform = transforms_imagenet_eval(\n                    img_size,\n                    interpolation=interpolation,\n                    use_prefetcher=use_prefetcher,\n                    mean=mean,\n                    std=std,\n                    crop_pct=crop_pct)\n\n    return transform\n"
  },
  {
    "path": "data/vtab.py",
    "content": "import os\nfrom torchvision.datasets.folder import ImageFolder, default_loader\n\nclass VTAB(ImageFolder):\n    def __init__(self, root, train=True, transform=None, target_transform=None, mode=None,is_individual_prompt=False,**kwargs):\n        self.dataset_root = root\n        self.loader = default_loader\n        self.target_transform = None\n        self.transform = transform\n\n        \n        train_list_path = os.path.join(self.dataset_root, 'train800val200.txt')\n        test_list_path = os.path.join(self.dataset_root, 'test.txt')\n\n        \n        # train_list_path = os.path.join(self.dataset_root, 'train800.txt')\n        # test_list_path = os.path.join(self.dataset_root, 'val200.txt')\n\n\n        self.samples = []\n        if train:\n            with open(train_list_path, 'r') as f:\n                for line in f:\n                    img_name = line.split(' ')[0]\n                    label = int(line.split(' ')[1])\n                    self.samples.append((os.path.join(root,img_name), label))\n        else:\n            with open(test_list_path, 'r') as f:\n                for line in f:\n                    img_name = line.split(' ')[0]\n                    label = int(line.split(' ')[1])\n                    self.samples.append((os.path.join(root,img_name), label))"
  },
  {
    "path": "log/README.md",
    "content": "\n"
  },
  {
    "path": "log/cifar100.csv",
    "content": "epoch,train_loss,eval_loss,eval_top1,eval_top5\r\n0,5.603768242730035,6.0231625,0.96,4.89\r\n1,4.440675179163615,5.8859375,1.02,5.18\r\n2,3.186616155836317,5.63614375,1.43,6.43\r\n3,2.827679475148519,5.345575,1.92,8.46\r\n4,2.7690945731268988,5.04094375,2.7,11.41\r\n5,2.6746084690093994,4.7265375,4.07,15.29\r\n6,2.5769188139173718,4.404325,6.44,20.77\r\n7,2.658121665318807,4.07631875,10.12,28.31\r\n8,2.5106263955434165,3.74056875,15.35,37.6\r\n9,2.638317664464315,3.398590625,23.04,49.38\r\n10,2.4921720822652182,3.05421875,32.87,61.79\r\n11,2.5904562208387585,2.714753125,44.93,72.77\r\n12,2.521847221586439,2.38726875,57.29,81.89\r\n13,2.555907726287842,2.0806125,67.52,88.52\r\n14,2.678542561001248,1.8002078125,75.47,92.47\r\n15,2.5564871629079184,1.549046875,81.04,94.85\r\n16,2.5498899353875055,1.329515625,84.56,96.42\r\n17,2.555638392766317,1.1425953125,86.4,97.29\r\n18,2.5207514233059354,0.9856625,87.97,97.84\r\n19,2.5135546260409884,0.8570640625,88.97,98.17\r\n20,2.417558749516805,0.75346328125,89.95,98.5\r\n21,2.4377992947896323,0.669475,90.59,98.76\r\n22,2.5416789849599204,0.6010578125,91.05,98.88\r\n23,2.4467749065823026,0.5459734375,91.42,98.99\r\n24,2.3756895065307617,0.5007609375,91.64,99.1\r\n25,2.4439392619662814,0.46489609375,91.92,99.19\r\n26,2.4532913896772595,0.43501953125,92.09,99.26\r\n27,2.508685827255249,0.410413671875,92.19,99.3\r\n28,2.489635467529297,0.390225,92.39,99.31\r\n29,2.4972835646735296,0.37323125,92.44,99.36\r\n30,2.4116059409247503,0.358880078125,92.54,99.42\r\n31,2.4685837162865534,0.346959765625,92.7,99.46\r\n32,2.3165384001202054,0.336875,92.71,99.47\r\n33,2.4255740377638073,0.32844765625,92.71,99.46\r\n34,2.4002869658999972,0.321278125,92.76,99.45\r\n35,2.4275462097591824,0.314995703125,92.8,99.47\r\n36,2.407806317011515,0.309953125,92.91,99.49\r\n37,2.402532418568929,0.30562421875,92.98,99.51\r\n38,2.4551497830284967,0.30197578125,93.06,99.51\r\n39,2.4293878608279758,0.2986171875,93.17,99.51\r\n40,2.4136725531684027,0.296045703125,93.24,99.51\r\n41,2.4232251379224987,0.29378125,93.32,99.52\r\n42,2.2722683482699924,0.2918,93.31,99.52\r\n43,2.397341330846151,0.290190625,93.35,99.52\r\n44,2.406024138132731,0.289073828125,93.39,99.53\r\n45,2.3778079880608454,0.288259765625,93.42,99.53\r\n46,2.4535727500915527,0.28774140625,93.47,99.53\r\n47,2.4934064282311335,0.287232421875,93.56,99.53\r\n48,2.325168079800076,0.287040234375,93.59,99.52\r\n49,2.396822929382324,0.286894140625,93.58,99.55\r\n50,2.3157892756991916,0.287205859375,93.61,99.55\r\n51,2.4792808956570096,0.287432421875,93.63,99.56\r\n52,2.366891860961914,0.287920703125,93.65,99.56\r\n53,2.345345550113254,0.2883359375,93.67,99.56\r\n54,2.303606006834242,0.288706640625,93.66,99.56\r\n55,2.398844109641181,0.28921875,93.68,99.57\r\n56,2.3735866281721325,0.2897921875,93.68,99.57\r\n57,2.4807073805067272,0.2904640625,93.69,99.58\r\n58,2.349070734447903,0.291125,93.76,99.58\r\n59,2.375280910068088,0.292005859375,93.8,99.58\r\n60,2.33055845896403,0.293062109375,93.82,99.57\r\n61,2.4362279574076333,0.29396484375,93.79,99.57\r\n62,2.317348506715563,0.29523125,93.8,99.57\r\n63,2.4144566191567316,0.29618203125,93.8,99.57\r\n64,2.3383621904585095,0.297141796875,93.79,99.56\r\n65,2.4115795029534235,0.298030078125,93.79,99.56\r\n66,2.3216844929589167,0.299159375,93.82,99.57\r\n67,2.318764951494005,0.30013984375,93.77,99.56\r\n68,2.269577423731486,0.30101484375,93.75,99.56\r\n69,2.369996494717068,0.30192890625,93.79,99.55\r\n70,2.3662983311547174,0.303022265625,93.81,99.56\r\n71,2.303777880138821,0.30
3908984375,93.85,99.56\r\n72,2.255165616671244,0.30502734375,93.82,99.56\r\n73,2.2920398712158203,0.305877734375,93.82,99.56\r\n74,2.310059520933363,0.30675234375,93.87,99.56\r\n75,2.325947019788954,0.307698046875,93.82,99.57\r\n76,2.2926743825276694,0.308682421875,93.82,99.55\r\n77,2.317892154057821,0.3095203125,93.86,99.55\r\n78,2.3864409658643932,0.310484375,93.86,99.55\r\n79,2.2767338487837048,0.311702734375,93.88,99.55\r\n80,2.459476047092014,0.31251640625,93.89,99.55\r\n81,2.3905074066585965,0.313465234375,93.87,99.55\r\n82,2.380774630440606,0.314390625,93.88,99.54\r\n83,2.191994031270345,0.31534296875,93.91,99.54\r\n84,2.316111962000529,0.31621796875,93.9,99.54\r\n85,2.307388676537408,0.31720625,93.89,99.54\r\n86,2.27423980500963,0.318103125,93.9,99.54\r\n87,2.3265264564090304,0.31895,93.91,99.54\r\n88,2.265656683180067,0.31984140625,93.94,99.54\r\n89,2.3482420444488525,0.320740625,93.97,99.54\r\n90,2.3093814849853516,0.321472265625,93.97,99.54\r\n91,2.3022206094529896,0.322198046875,93.97,99.54\r\n92,2.3547442489200168,0.3228609375,93.99,99.54\r\n93,2.246518611907959,0.3234953125,93.97,99.54\r\n94,2.3851645787556968,0.324027734375,93.98,99.56\r\n95,2.3422129816479154,0.324636328125,93.93,99.56\r\n96,2.2714282936520047,0.325282421875,93.93,99.56\r\n97,2.374925719367133,0.32581875,93.93,99.55\r\n98,2.3505734867519803,0.326328125,93.94,99.54\r\n99,2.418484025531345,0.326844140625,93.95,99.55\r\n100,2.293912728627523,0.32743125,93.94,99.55\r\n101,2.446575509177314,0.327946875,93.95,99.55\r\n102,2.309349775314331,0.32843125,93.96,99.55\r\n103,2.316555658976237,0.3289484375,93.95,99.55\r\n104,2.3285930156707764,0.329358203125,93.94,99.56\r\n105,2.2936308648851185,0.32980234375,93.93,99.56\r\n106,2.3283233642578125,0.330234765625,93.95,99.57\r\n107,2.3618712955050998,0.33079453125,93.95,99.57\r\n108,2.3460081683264837,0.3311765625,93.95,99.57\r\n109,2.352536598841349,0.33160625,93.96,99.57\r\n"
  },
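The CSV above is a per-epoch fine-tuning log (it follows the layout timm writes to `summary.csv`): one row per epoch with the training loss, evaluation loss, and top-1/top-5 accuracy. A minimal sketch for inspecting such a log with pandas; the filename below is illustrative, not the actual path in this repo:

```python
import pandas as pd

# One row per epoch: epoch, train_loss, eval_loss, eval_top1, eval_top5.
log = pd.read_csv("summary.csv")  # illustrative filename; point this at the actual log file

# Report the epoch with the highest top-1 accuracy.
best = log.loc[log["eval_top1"].idxmax()]
print(f"best top-1 {best['eval_top1']:.2f} at epoch {int(best['epoch'])} "
      f"(top-5 {best['eval_top5']:.2f}, eval loss {best['eval_loss']:.3f})")
```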
  {
    "path": "models/as_mlp.py",
    "content": "# --------------------------------------------------------\n# AS-MLP\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Zehao Yu and Dongze Lian (AS-MLP)\n# --------------------------------------------------------\n\nimport logging\nimport math\nfrom copy import deepcopy\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.models.fx_features import register_notrace_function\nfrom timm.models.helpers import build_model_with_cfg \nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.layers import _assert\nfrom timm.models.registry import register_model\n\nfrom timm.models.vision_transformer import checkpoint_filter_fn \n\n\n\n_logger = logging.getLogger(__name__)\n\ndef _cfg(url='', file='', **kwargs):\n    return {\n        'url': url,\n        'file': file,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\ndefault_cfgs = {\n    'as_base_patch4_window7_224': _cfg(\n        file='/path/to/asmlp_base_patch4_shift5_224.pth'\n    ),\n}\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tuning_mode=None):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1)\n        self.act = act_layer()\n        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1)\n        self.drop = nn.Dropout(drop)\n\n\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':        \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)\n\n\n    def forward(self, x):\n        x = self.fc1(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x) \n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)\n\n        x = self.drop(x)\n        \n        return x\n\n\n\nclass AxialShift(nn.Module):\n    r\"\"\" Axial shift  \n\n    Args:\n        dim (int): Number of input channels.\n        shift_size (int): shift size .\n        as_bias (bool, optional):  If True, add a learnable bias to as mlp. Default: True\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, shift_size, as_bias=True, proj_drop=0., tuning_mode=None):\n\n        super().__init__()\n        self.dim = dim\n        self.shift_size = shift_size\n        self.pad = shift_size // 2\n        self.conv1 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias)\n        self.conv2_1 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias)\n        self.conv2_2 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias)\n        self.conv3 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias)\n\n        self.actn = nn.GELU()\n\n        self.norm1 = MyNorm(dim)\n        self.norm2 = MyNorm(dim)\n        self.tuning_mode = tuning_mode\n\n\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)\n            self.ssf_scale_3, self.ssf_shift_3 = init_ssf_scale_shift(dim)\n            self.ssf_scale_4, self.ssf_shift_4 = init_ssf_scale_shift(dim)\n\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, C, H, W = x.shape\n        x = self.conv1(x)\n        if self.tuning_mode == 'ssf':   \n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n        \n        x = self.norm1(x)\n        x = self.actn(x)\n       \n        \n        x = F.pad(x, (self.pad, self.pad, self.pad, self.pad) , \"constant\", 0)\n        \n        xs = torch.chunk(x, self.shift_size, 1)\n\n        def shift(dim):\n            x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]\n            x_cat = torch.cat(x_shift, 1)\n            x_cat = torch.narrow(x_cat, 2, self.pad, H)\n            x_cat = torch.narrow(x_cat, 3, self.pad, W)\n            return x_cat\n\n        x_shift_lr = shift(3)\n        x_shift_td = shift(2)\n    \n\n        if self.tuning_mode == 'ssf':   \n            x_lr = ssf_ada(self.conv2_1(x_shift_lr), self.ssf_scale_2, self.ssf_shift_2)\n            x_td = ssf_ada(self.conv2_2(x_shift_td), self.ssf_scale_3, self.ssf_shift_3)\n        else:\n            x_lr = self.conv2_1(x_shift_lr)\n            x_td = self.conv2_2(x_shift_td)\n\n        x_lr = self.actn(x_lr)\n        x_td = self.actn(x_td)\n\n        x = x_lr + x_td\n        x = self.norm2(x)\n\n        x = self.conv3(x)\n        if self.tuning_mode == 'ssf': \n            x = ssf_ada(x, self.ssf_scale_4, self.ssf_shift_4)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, shift_size={self.shift_size}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # conv1 \n        flops += N * self.dim * self.dim\n        # norm 1\n        flops += N * self.dim\n        # conv2_1 conv2_2\n        flops += N * self.dim * self.dim * 2\n        # x_lr + x_td\n        flops += N * self.dim\n        # norm2\n        flops += N * self.dim\n        # norm3\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass AxialShiftedBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resulotion.\n        shift_size (int): Shift size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        as_bias (bool, 
optional): If True, add a learnable bias to Axial Mlp. Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, shift_size=7,\n                 mlp_ratio=4., as_bias=True, drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, tuning_mode=None):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n\n        self.norm1 = norm_layer(dim)\n        self.axial_shift = AxialShift(dim, shift_size=shift_size, as_bias=as_bias, proj_drop=drop, tuning_mode=tuning_mode)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, tuning_mode=tuning_mode)\n\n\n        self.tuning_mode = tuning_mode\n\n        if tuning_mode == 'ssf':    \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)\n\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n\n        shortcut = x\n        \n        x = self.norm1(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n\n        # axial shift block\n        x = self.axial_shift(x)  # B, C, H, W\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        if self.tuning_mode == 'ssf':\n            x = x + self.drop_path(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2)))\n        else:\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, \" \\\n               f\"shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # shift mlp \n        flops += self.axial_shift.flops(H * W)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, tuning_mode=None):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Conv2d(4 * dim, 2 * dim, 1, 1, bias=False)\n        self.norm = norm_layer(4 * dim)\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(4 * dim)\n\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        B, C, H, W = x.shape\n        #assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, C, H, W)\n\n        x0 = x[:, :, 0::2, 0::2]  # B C H/2 W/2 \n        x1 = x[:, :, 1::2, 0::2]  # B C H/2 W/2 \n        x2 = x[:, :, 0::2, 1::2]  # B C H/2 W/2 \n        x3 = x[:, :, 1::2, 1::2]  # B C H/2 W/2 \n        x = torch.cat([x0, x1, x2, x3], 1)  # B 4*C H/2 W/2 \n\n        x = self.norm(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, shift_size,\n                 mlp_ratio=4., as_bias=True, drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, tuning_mode=None):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            AxialShiftedBlock(dim=dim, input_resolution=input_resolution,\n                              shift_size=shift_size,\n                              mlp_ratio=mlp_ratio,\n                              as_bias=as_bias,\n                              drop=drop, \n                              drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                              norm_layer=norm_layer, tuning_mode=tuning_mode[i])\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, tuning_mode=tuning_mode)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, tuning_mode=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n        \n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)\n\n            \n\n            if norm_layer:\n                self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim)\n\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x)#.flatten(2).transpose(1, 2)  # B Ph*Pw C\n\n        if self.tuning_mode == 'ssf':  \n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n            if self.norm is not None:\n                x = self.norm(x)\n                x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)\n        else:\n            if self.norm is not None:\n                x = self.norm(x)\n\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\ndef MyNorm(dim):\n    return nn.GroupNorm(1, dim)\n\n\ndef init_ssf_scale_shift(dim):\n    scale = nn.Parameter(torch.ones(dim))\n    shift = nn.Parameter(torch.zeros(dim))\n\n    nn.init.normal_(scale, mean=1, std=.02)\n    nn.init.normal_(shift, std=.02)\n\n    return scale, shift\n\n\ndef ssf_ada(x, scale, shift):\n    assert scale.shape == shift.shape\n    if x.shape[-1] == scale.shape[0]:\n        return x * scale + shift\n    elif x.shape[1] == scale.shape[0]:\n        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)\n    else:\n        raise ValueError('the input tensor shape does not match the shape of the scale factor.')\n\n\nclass AS_MLP(nn.Module):\n    r\"\"\" AS-MLP\n        A PyTorch impl of : `AS-MLP: An Axial Shifted MLP Architecture for Vision`  -\n          https://arxiv.org/pdf/xxx.xxx\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each AS-MLP layer.\n        window_size (int): shift size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
Default: 4\n        as_bias (bool): If True, add a learnable bias to as-mlp block. Default: True\n        drop_rate (float): Dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.GroupNorm with group=1.\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], \n                 shift_size=5, mlp_ratio=4., as_bias=True, \n                 drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=MyNorm, patch_norm=True,\n                 use_checkpoint=False, tuning_mode=None, **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n        \n        self.tuning_mode = tuning_mode\n        tuning_mode_list = [[tuning_mode] * depths[i_layer] for i_layer in range(self.num_layers)]\n        if tuning_mode == 'ssf':   \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(self.num_features)\n            \n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               shift_size=shift_size,\n                               mlp_ratio=self.mlp_ratio,\n                               as_bias=as_bias,\n                               drop=drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint,\n                               tuning_mode=tuning_mode_list[i_layer])\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        #self.apply(self._init_weights)\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        
return {'relative_position_bias_table'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n        \n        x = self.norm(x)  # B L C\n        if self.tuning_mode == 'ssf': \n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n        \n        x = self.avgpool(x)  # B C 1 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n\n\ndef _create_as_mlp(variant, pretrained=False, **kwargs):\n    model = build_model_with_cfg(\n        AS_MLP, variant, pretrained,\n        pretrained_filter_fn=checkpoint_filter_fn,\n        **kwargs)\n\n    return model\n\n\n@register_model\ndef as_base_patch4_window7_224(pretrained=False, **kwargs):\n    \"\"\" AS-MLP-B @ 224x224, pretrained ImageNet-1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, shift_size=5, embed_dim=128, depths=(2, 2, 18, 2), **kwargs)\n    return _create_as_mlp('as_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)\n"
  },
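`models/as_mlp.py` above adapts AS-MLP to SSF tuning: when `tuning_mode == 'ssf'`, the output of each conv, linear, and normalization layer is followed by a learnable per-channel scale (initialized near 1) and shift (initialized near 0), implemented by `init_ssf_scale_shift` and `ssf_ada`. The sketch below is a self-contained copy of those two helpers from the file, showing how the same parameters broadcast over channels-last `(B, N, C)` and channels-first `(B, C, H, W)` tensors:

```python
import torch
import torch.nn as nn

def init_ssf_scale_shift(dim):
    # Per-channel scale ~ N(1, 0.02) and shift ~ N(0, 0.02), as in models/as_mlp.py.
    scale = nn.Parameter(torch.ones(dim))
    shift = nn.Parameter(torch.zeros(dim))
    nn.init.normal_(scale, mean=1, std=.02)
    nn.init.normal_(shift, std=.02)
    return scale, shift

def ssf_ada(x, scale, shift):
    # Apply y = x * scale + shift, broadcasting over whichever axis holds the channels.
    assert scale.shape == shift.shape
    if x.shape[-1] == scale.shape[0]:      # channels-last, e.g. (B, N, C)
        return x * scale + shift
    elif x.shape[1] == scale.shape[0]:     # channels-first, e.g. (B, C, H, W)
        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
    raise ValueError('the input tensor shape does not match the shape of the scale factor.')

scale, shift = init_ssf_scale_shift(96)
print(ssf_ada(torch.randn(2, 196, 96), scale, shift).shape)     # (2, 196, 96)
print(ssf_ada(torch.randn(2, 96, 56, 56), scale, shift).shape)  # (2, 96, 56, 56)
```

Because only these per-channel scale/shift tensors (plus the classification head) need to be updated during fine-tuning, the number of tuned parameters stays a small fraction of the full backbone, which is what makes the method parameter-efficient.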
  {
    "path": "models/convnext.py",
    "content": "\"\"\" ConvNeXt\n\nPaper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf\n\nOriginal code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below\n\nModifications and additions for timm hacked together by / Copyright 2022, Ross Wightman\n\"\"\"\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n# This source code is licensed under the MIT license\nimport math\nfrom collections import OrderedDict\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.models.fx_features import register_notrace_module\nfrom timm.models.helpers import named_apply, build_model_with_cfg, checkpoint_seq\nfrom timm.models.layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, to_2tuple\n\nfrom timm.models.registry import register_model\n\n\n\n__all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),\n        'crop_pct': 0.875, 'interpolation': 'bicubic',\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'stem.0', 'classifier': 'head.fc',\n        **kwargs\n    }\n\n\ndefault_cfgs = dict(\n    convnext_tiny=_cfg(url=\"https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth\"),\n    convnext_small=_cfg(url=\"https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth\"),\n    convnext_base=_cfg(url=\"https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth\"),\n    convnext_large=_cfg(url=\"https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth\"),\n\n    convnext_tiny_hnf=_cfg(url=''),\n\n    convnext_base_in22ft1k=_cfg(\n        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),\n    convnext_large_in22ft1k=_cfg(\n        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'),\n    convnext_xlarge_in22ft1k=_cfg(\n        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),\n\n    convnext_base_384_in22ft1k=_cfg(\n        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',\n        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),\n    convnext_large_384_in22ft1k=_cfg(\n        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',\n        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),\n    convnext_xlarge_384_in22ft1k=_cfg(\n        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',\n        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),\n\n    convnext_base_in22k=_cfg(\n        url=\"https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth\", num_classes=21841),\n    convnext_large_in22k=_cfg(\n        url=\"https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth\", num_classes=21841),\n    convnext_xlarge_in22k=_cfg(\n        url=\"https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth\", num_classes=21841),\n)\n\n\n\ndef _is_contiguous(tensor: torch.Tensor) -> bool:\n    # jit is oh so lovely :/\n    # if torch.jit.is_tracing():\n    #     return True\n    if torch.jit.is_scripting():\n        return tensor.is_contiguous()\n    else:\n        return 
tensor.is_contiguous(memory_format=torch.contiguous_format)\n\n\n\n@register_notrace_module\nclass LayerNorm2d(nn.LayerNorm):\n    r\"\"\" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).\n    \"\"\"\n\n    def __init__(self, normalized_shape, eps=1e-6):\n        super().__init__(normalized_shape, eps=eps)\n\n    def forward(self, x) -> torch.Tensor:\n        if _is_contiguous(x):\n            return F.layer_norm(\n                x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)\n        else:\n            s, u = torch.var_mean(x, dim=1, keepdim=True)\n            x = (x - u) * torch.rsqrt(s + self.eps)\n            x = x * self.weight[:, None, None] + self.bias[:, None, None]\n            return x\n\n\n\n\nclass Mlp(nn.Module):\n    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n    \"\"\"\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., tuning_mode=None):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':        \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)\n        \n\n    def forward(self, x):\n        x = self.fc1(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x) \n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)\n\n        x = self.drop2(x)\n        \n        return x\n\n\n\nclass ConvNeXtBlock(nn.Module):\n    \"\"\" ConvNeXt Block\n    There are two equivalent implementations:\n      (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)\n      (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back\n\n    Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate\n    choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear\n    is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.\n\n    Args:\n        dim (int): Number of input channels.\n        drop_path (float): Stochastic depth rate. Default: 0.0\n        ls_init_value (float): Init value for Layer Scale. 
Default: 1e-6.\n    \"\"\"\n\n    def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None, tuning_mode=None):\n        super().__init__()\n        if not norm_layer:\n            norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)\n        mlp_layer = ConvMlp if conv_mlp else Mlp\n        self.use_conv_mlp = conv_mlp\n        self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv\n        self.norm = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, tuning_mode=tuning_mode)\n\n        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':        \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)\n\n\n\n\n    def forward(self, x):\n        shortcut = x\n        x = self.conv_dw(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n\n        if self.use_conv_mlp:\n            x = self.norm(x)\n            x = self.mlp(x)\n        else:\n            x = x.permute(0, 2, 3, 1)\n            x = self.norm(x)\n            if self.tuning_mode == 'ssf':\n                x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)\n\n            x = self.mlp(x)\n            x = x.permute(0, 3, 1, 2)\n        if self.gamma is not None:\n            x = x.mul(self.gamma.reshape(1, -1, 1, 1))\n        x = self.drop_path(x) + shortcut\n\n        return x\n\n\nclass Downsample(nn.Module):\n    \"\"\" 2D Image to Downsample\n    \"\"\"\n    def __init__(self, dim, out_dim, kernel_size, stride, norm_layer=None, tuning_mode=None):\n        super().__init__()\n\n        self.norm = norm_layer(dim)\n        self.proj = nn.Conv2d(dim, out_dim, kernel_size=stride, stride=stride)\n\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_dim)\n\n\n    def forward(self, x):\n        if self.tuning_mode == 'ssf':  \n            x = ssf_ada(self.norm(x), self.ssf_scale_1, self.ssf_shift_1)\n            x = ssf_ada(self.proj(x), self.ssf_scale_2, self.ssf_shift_2)\n        else:\n            x = self.norm(x)\n            x = self.proj(x)\n\n        return x\n            \n\n\nclass ConvNeXtStage(nn.Module):\n\n    def __init__(\n            self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False,\n            norm_layer=None, cl_norm_layer=None, cross_stage=False, tuning_mode=None):\n        super().__init__()\n        self.grad_checkpointing = False \n\n        if in_chs != out_chs or stride > 1:\n            self.downsample = Downsample(dim=in_chs, out_dim=out_chs, kernel_size=stride, stride=stride, norm_layer=norm_layer, tuning_mode=tuning_mode)\n        else:\n            self.downsample = nn.Identity()\n\n        dp_rates = dp_rates or [0.] 
* depth\n        self.blocks = nn.Sequential(*[ConvNeXtBlock(\n            dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp,\n            norm_layer=norm_layer if conv_mlp else cl_norm_layer, tuning_mode=tuning_mode[j])\n            for j in range(depth)]\n        )\n\n    def forward(self, x):\n        x = self.downsample(x)\n\n        if self.grad_checkpointing and not torch.jit.is_scripting():\n            x = checkpoint_seq(self.blocks, x)\n        else:\n            x = self.blocks(x)\n\n        return x\n\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" 2D Image to Patch Embedding\n    \"\"\"\n    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, tuning_mode=None):\n        super().__init__()\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim)\n\n\n    def forward(self, x):\n        if self.tuning_mode == 'ssf':  \n            x = ssf_ada(self.proj(x), self.ssf_scale_1, self.ssf_shift_1)\n            x = ssf_ada(self.norm(x), self.ssf_scale_2, self.ssf_shift_2)\n        else:\n            x = self.proj(x)\n            x = self.norm(x)\n\n        return x\n\n\ndef init_ssf_scale_shift(dim):\n    scale = nn.Parameter(torch.ones(dim))\n    shift = nn.Parameter(torch.zeros(dim))\n\n    nn.init.normal_(scale, mean=1, std=.02)\n    nn.init.normal_(shift, std=.02)\n\n    return scale, shift\n\n\ndef ssf_ada(x, scale, shift):\n    assert scale.shape == shift.shape\n    if x.shape[-1] == scale.shape[0]:\n        return x * scale + shift\n    elif x.shape[1] == scale.shape[0]:\n        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)\n    else:\n        raise ValueError('the input tensor shape does not match the shape of the scale factor.')\n\n\nclass ConvNeXt(nn.Module):\n    r\"\"\" ConvNeXt\n        A PyTorch impl of : `A ConvNet for the 2020s`  - https://arxiv.org/pdf/2201.03545.pdf\n\n    Args:\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]\n        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]\n        drop_rate (float): Head dropout rate\n        drop_path_rate (float): Stochastic depth rate. Default: 0.\n        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.\n        head_init_scale (float): Init scaling value for classifier weights and biases. 
Default: 1.\n    \"\"\"\n\n    def __init__(\n            self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4,\n            depths=(3, 3, 9, 3), dims=(96, 192, 384, 768),  ls_init_value=1e-6, conv_mlp=False,\n            head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., tuning_mode=None\n    ):\n        super().__init__()\n        assert output_stride == 32\n        if norm_layer is None:\n            norm_layer = partial(LayerNorm2d, eps=1e-6)\n            cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)\n        else:\n            assert conv_mlp,\\\n                'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'\n            cl_norm_layer = norm_layer\n\n        self.num_classes = num_classes\n        self.drop_rate = drop_rate\n        self.feature_info = []\n\n        # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4\n        self.stem = PatchEmbed(patch_size=4, in_chans=3, embed_dim=dims[0], norm_layer=norm_layer, tuning_mode=tuning_mode)\n\n        self.tuning_mode = tuning_mode\n        tuning_mode_list = [[tuning_mode] * depths[i_layer] for i_layer in range(len(depths))]\n        if tuning_mode == 'ssf':       \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dims[3])\n\n        self.stages = nn.Sequential()\n        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]\n        curr_stride = patch_size\n        prev_chs = dims[0]\n        stages = []\n        # 4 feature resolution stages, each consisting of multiple residual blocks\n        for i in range(4):\n            stride = 2 if i > 0 else 1\n            # FIXME support dilation / output_stride\n            curr_stride *= stride\n            out_chs = dims[i]\n            stages.append(ConvNeXtStage(\n                prev_chs, out_chs, stride=stride,\n                depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp,\n                norm_layer=norm_layer, cl_norm_layer=cl_norm_layer, tuning_mode=tuning_mode_list[i])\n            )\n            prev_chs = out_chs\n            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2\n            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]\n        self.stages = nn.Sequential(*stages)\n\n        self.num_features = prev_chs\n        if head_norm_first:\n            # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)\n            self.norm_pre = norm_layer(self.num_features)  # final norm layer, before pooling\n            self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)\n        else:\n            # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)\n            self.norm_pre = nn.Identity()\n            self.head = nn.Sequential(OrderedDict([\n                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),\n                ('norm', norm_layer(self.num_features)),\n                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),\n                ('drop', nn.Dropout(self.drop_rate)),\n                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())\n            ]))\n\n        named_apply(partial(_init_weights, 
head_init_scale=head_init_scale), self)\n\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        for s in self.stages:\n            s.grad_checkpointing = enable\n\n    def get_classifier(self):\n        return self.head.fc\n\n    def reset_classifier(self, num_classes=0, global_pool='avg'):\n        if isinstance(self.head, ClassifierHead):\n            # norm -> global pool -> fc\n            self.head = ClassifierHead(\n                self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)\n        else:\n            # pool -> norm -> fc\n            self.head = nn.Sequential(OrderedDict([\n                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),\n                ('norm', self.head.norm),\n                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),\n                ('drop', nn.Dropout(self.drop_rate)),\n                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())\n            ]))\n\n    def forward_features(self, x):\n        x = self.stem(x)\n        x = self.stages(x)\n        x = self.norm_pre(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(self.norm_pre(x), self.ssf_scale_1, self.ssf_shift_1)\n\n        return x\n\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n\ndef _init_weights(module, name=None, head_init_scale=1.0):\n    if isinstance(module, nn.Conv2d):\n        trunc_normal_(module.weight, std=.02)\n        nn.init.constant_(module.bias, 0)\n    elif isinstance(module, nn.Linear):\n        trunc_normal_(module.weight, std=.02)\n        nn.init.constant_(module.bias, 0)\n        if name and 'head.' in name:\n            module.weight.data.mul_(head_init_scale)\n            module.bias.data.mul_(head_init_scale)\n\n\ndef checkpoint_filter_fn(state_dict, model):\n    \"\"\" Remap FB checkpoints -> timm \"\"\"\n    #ipdb.set_trace()\n    if 'model' in state_dict:\n        state_dict = state_dict['model']\n    out_dict = {}\n    import re\n    for k, v in state_dict.items():\n        k = k.replace('downsample_layers.0.0.', 'stem.proj.')\n        k = k.replace('downsample_layers.0.1.', 'stem.norm.')\n\n\n        k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\\1.blocks.\\2', k)\n\n        k = re.sub(r'downsample_layers.([0-9]+).([0]+)', r'stages.\\1.downsample.norm', k)\n        k = re.sub(r'downsample_layers.([0-9]+).([1]+)', r'stages.\\1.downsample.proj', k)\n\n\n        k = k.replace('dwconv', 'conv_dw')\n        k = k.replace('pwconv', 'mlp.fc')\n        k = k.replace('head.', 'head.fc.')\n        if k.startswith('norm.'):\n            k = k.replace('norm', 'head.norm')\n        if v.ndim == 2 and 'head' not in k:\n            model_shape = model.state_dict()[k].shape\n            v = v.reshape(model_shape)\n        out_dict[k] = v\n    return out_dict\n\n\ndef _create_convnext(variant, pretrained=False, **kwargs):\n    model = build_model_with_cfg(\n        ConvNeXt, variant, pretrained,\n        pretrained_filter_fn=checkpoint_filter_fn,\n        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),\n        **kwargs)\n    return model\n\n\n\n@register_model\ndef convnext_tiny(pretrained=False, **kwargs):\n    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)\n    model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef 
convnext_tiny_hnf(pretrained=False, **kwargs):\n    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, **kwargs)\n    model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_small(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)\n    model = _create_convnext('convnext_small', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_base(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)\n    model = _create_convnext('convnext_base', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_large(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)\n    model = _create_convnext('convnext_large', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_base_in22ft1k(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)\n    model = _create_convnext('convnext_base_in22ft1k', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_large_in22ft1k(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)\n    model = _create_convnext('convnext_large_in22ft1k', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_xlarge_in22ft1k(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)\n    model = _create_convnext('convnext_xlarge_in22ft1k', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_base_384_in22ft1k(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)\n    model = _create_convnext('convnext_base_384_in22ft1k', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_large_384_in22ft1k(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)\n    model = _create_convnext('convnext_large_384_in22ft1k', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)\n    model = _create_convnext('convnext_xlarge_384_in22ft1k', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_base_in22k(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)\n    model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_large_in22k(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)\n    model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args)\n    return model\n\n\n@register_model\ndef convnext_xlarge_in22k(pretrained=False, **kwargs):\n    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)\n    model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args)\n    return model\n"
  },
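`models/convnext.py` threads the same `tuning_mode` flag through the stem, every `ConvNeXtBlock`, and the downsample layers, so constructing the model with `tuning_mode='ssf'` inserts SSF scale/shift parameters throughout the backbone. Below is a minimal sketch of building such a model and leaving only the SSF parameters and the head trainable; it assumes the repository root is on `PYTHONPATH`, and the name-based parameter selection is an illustrative assumption rather than the repo's actual training-script logic:

```python
import torch
from models.convnext import ConvNeXt

# ConvNeXt-B sized model (depths/dims as in convnext_base) with SSF parameters inserted.
model = ConvNeXt(depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024),
                 num_classes=100, tuning_mode='ssf')

# Freeze the backbone; keep only SSF scales/shifts and the classification head trainable.
for name, p in model.named_parameters():
    p.requires_grad = ('ssf_scale' in name) or ('ssf_shift' in name) or name.startswith('head.')

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} parameters")

# Sanity check: a forward pass on a dummy batch.
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 100])
```

In an actual run, the pretrained ImageNet-21K weights would be loaded through the registered timm entrypoints in this file (e.g. `convnext_base_in22k`) before freezing; the direct constructor call above just keeps the example self-contained.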
  {
    "path": "models/swin_transformer.py",
    "content": "\"\"\" Swin Transformer\nA PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`\n    - https://arxiv.org/pdf/2103.14030\nCode/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below\nModifications and additions for timm hacked together by / Copyright 2021, Ross Wightman\n\"\"\"\n# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\nimport logging\nimport math\nfrom copy import deepcopy\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.models.fx_features import register_notrace_function\nfrom timm.models.helpers import build_model_with_cfg, named_apply \nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.layers import _assert\nfrom timm.models.registry import register_model\n\nfrom timm.models.vision_transformer import checkpoint_filter_fn, get_init_weights_vit \n\nimport ipdb\n\n_logger = logging.getLogger(__name__)\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\ndefault_cfgs = {\n    # patch models (my experiments)\n    'swin_base_patch4_window12_384': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',\n        input_size=(3, 384, 384), crop_pct=1.0),\n\n    'swin_base_patch4_window7_224': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',\n    ),\n\n    'swin_large_patch4_window12_384': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',\n        input_size=(3, 384, 384), crop_pct=1.0),\n\n    'swin_large_patch4_window7_224': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',\n    ),\n\n    'swin_small_patch4_window7_224': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',\n    ),\n\n    'swin_tiny_patch4_window7_224': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',\n    ),\n\n    'swin_base_patch4_window12_384_in22k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',\n        input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),\n\n    'swin_base_patch4_window7_224_in22k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',\n        num_classes=21841),\n\n    'swin_large_patch4_window12_384_in22k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',\n        
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),\n\n    'swin_large_patch4_window7_224_in22k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',\n        num_classes=21841),\n\n}\n\n\n\nclass Mlp(nn.Module):\n    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n    \"\"\"\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., tuning_mode=None):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n            \n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':        \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)\n\n\n    def forward(self, x):\n        x = self.fc1(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x) \n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)\n        \n        x = self.drop2(x)\n        \n        return x\n\n        \ndef window_partition(x, window_size: int):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\n@register_notrace_function  # reason: int argument is a Proxy\ndef window_reverse(windows, window_size: int, H: int, W: int):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., tuning_mode=None):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n        self.tuning_mode = tuning_mode\n\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)\n\n\n\n    def forward(self, x, mask: Optional[torch.Tensor] = None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        if self.tuning_mode == 'ssf':\n            #qkv = (self.qkv(x) * self.ssf_scale_1 + self.ssf_shift_1).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n            qkv = (ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1)).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        else:\n            qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = 
attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        \n        x = self.proj(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)\n\n        x = self.proj_drop(x)\n        return x\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resulotion.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, tuning_mode=None):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,\n            attn_drop=attn_drop, proj_drop=drop, tuning_mode=tuning_mode)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, tuning_mode=tuning_mode)\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            h_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n        self.tuning_mode = tuning_mode\n\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)\n        \n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n\n        shortcut = x\n        \n        x = self.norm1(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        if self.tuning_mode == 'ssf':\n            x = x + self.drop_path(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2)))\n        else:\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer. 
 Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, tuning_mode=None):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(4 * dim)\n\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n\n        _assert(L == H * W, \"input feature has wrong size\")\n        _assert(H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\")\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, tuning_mode=None):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(\n                dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,\n                shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,\n                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, tuning_mode=tuning_mode[i])\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, tuning_mode=tuning_mode)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if not torch.jit.is_scripting() and self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" 2D Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, tuning_mode=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.num_patches = self.grid_size[0] * self.grid_size[1]\n        self.flatten = flatten\n        self.norm_layer = norm_layer\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)\n\n            if norm_layer:\n                self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim)\n\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        _assert(H == self.img_size[0], f\"Input image height ({H}) doesn't match model ({self.img_size[0]}).\")\n        _assert(W == self.img_size[1], f\"Input image width ({W}) doesn't match model ({self.img_size[1]}).\")\n\n        x = self.proj(x) \n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC\n        if self.tuning_mode == 'ssf':  \n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n            if self.norm_layer:\n                x = ssf_ada(self.norm(x), self.ssf_scale_2, self.ssf_shift_2)\n            else:\n                x = self.norm(x)\n        else:\n            x = self.norm(x)\n        return x\n\n\n\ndef 
init_ssf_scale_shift(dim):\n    scale = nn.Parameter(torch.ones(dim))\n    shift = nn.Parameter(torch.zeros(dim))\n\n    nn.init.normal_(scale, mean=1, std=.02)\n    nn.init.normal_(shift, std=.02)\n\n    return scale, shift\n\n\ndef ssf_ada(x, scale, shift):\n    assert scale.shape == shift.shape\n    if x.shape[-1] == scale.shape[0]:\n        return x * scale + shift\n    elif x.shape[1] == scale.shape[0]:\n        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)\n    else:\n        raise ValueError('the input tensor shape does not match the shape of the scale factor.')\n\n\nclass SwinTransformer(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),\n                 window_size=7, mlp_ratio=4., qkv_bias=True,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, weight_init='', tuning_mode=None, **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None, tuning_mode=tuning_mode)\n        num_patches = self.patch_embed.num_patches\n        self.patch_grid = self.patch_embed.grid_size\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n        else:\n            self.absolute_pos_embed = None\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n        self.tuning_mode = tuning_mode\n        tuning_mode_list = [[tuning_mode] * depths[i_layer] for i_layer in range(self.num_layers)]\n        if tuning_mode == 'ssf':   \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(self.num_features)\n\n        # build layers\n        layers = []\n        for i_layer in range(self.num_layers):\n            layers += [BasicLayer(\n                dim=int(embed_dim * 2 ** i_layer),\n                input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)),\n                depth=depths[i_layer],\n                num_heads=num_heads[i_layer],\n                window_size=window_size,\n                mlp_ratio=self.mlp_ratio,\n                qkv_bias=qkv_bias,\n                drop=drop_rate,\n                attn_drop=attn_drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                use_checkpoint=use_checkpoint, \n                tuning_mode=tuning_mode_list[i_layer])\n            ]\n        self.layers = nn.Sequential(*layers)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        if weight_init != 'skip':\n            self.init_weights(weight_init)\n\n\n    @torch.jit.ignore\n    def init_weights(self, mode=''):\n        assert mode in ('jax', 'jax_nlhb', 'moco', '')\n        if self.absolute_pos_embed is not None:\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.\n        named_apply(get_init_weights_vit(mode, head_bias=head_bias), self)\n\n    
@torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.absolute_pos_embed is not None:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n        x = self.layers(x)\n        x = self.norm(x)\n\n        if self.tuning_mode == 'ssf': \n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n\ndef _create_swin_transformer(variant, pretrained=False, **kwargs):  \n    model = build_model_with_cfg(\n        SwinTransformer, variant, pretrained,\n        pretrained_filter_fn=checkpoint_filter_fn,\n        **kwargs)\n\n    return model\n\n\n\n@register_model\ndef swin_base_patch4_window12_384(pretrained=False, **kwargs):\n    \"\"\" Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)\n    return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_base_patch4_window7_224(pretrained=False, **kwargs):\n    \"\"\" Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)\n    return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_large_patch4_window12_384(pretrained=False, **kwargs):\n    \"\"\" Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)\n    return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_large_patch4_window7_224(pretrained=False, **kwargs):\n    \"\"\" Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)\n    return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_small_patch4_window7_224(pretrained=False, **kwargs):\n    \"\"\" Swin-S @ 224x224, trained ImageNet-1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_tiny_patch4_window7_224(pretrained=False, **kwargs):\n    \"\"\" Swin-T @ 224x224, trained ImageNet-1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=96, 
depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):\n    \"\"\" Swin-B @ 384x384, trained ImageNet-22k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)\n    return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):\n    \"\"\" Swin-B @ 224x224, trained ImageNet-22k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)\n    return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):\n    \"\"\" Swin-L @ 384x384, trained ImageNet-22k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)\n    return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):\n    \"\"\" Swin-L @ 224x224, trained ImageNet-22k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)\n    return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)"
  },
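Both model files in this repo define the same two SSF helpers, `init_ssf_scale_shift` and `ssf_ada`: each inserted location learns a per-channel scale initialized around 1 and a per-channel shift initialized around 0, and `ssf_ada` applies them along the channel dimension. Because the transform is affine, a scale/shift pair that directly follows a linear layer (for example the `qkv` projection in `WindowAttention` above) can be folded back into that layer's weight and bias once tuning is finished, so inference carries no extra parameters or FLOPs. The sketch below is illustrative only and not part of the repo; `merge_ssf_into_linear` is a hypothetical helper name.

```python
# Minimal sketch, assuming y = ssf_ada(linear(x)) with the last-dim branch of ssf_ada:
#   (W x + b) * scale + shift  ==  (scale ⊙ W) x + (scale * b + shift)
import torch
import torch.nn as nn


def init_ssf_scale_shift(dim):
    # same initialization as in the files above
    scale = nn.Parameter(torch.ones(dim))
    shift = nn.Parameter(torch.zeros(dim))
    nn.init.normal_(scale, mean=1, std=.02)
    nn.init.normal_(shift, std=.02)
    return scale, shift


@torch.no_grad()
def merge_ssf_into_linear(linear: nn.Linear, scale: torch.Tensor, shift: torch.Tensor) -> nn.Linear:
    """Fold y = (W x + b) * scale + shift into a single nn.Linear (hypothetical helper)."""
    merged = nn.Linear(linear.in_features, linear.out_features, bias=True)
    merged.weight.copy_(linear.weight * scale.unsqueeze(1))  # scale each output row of W
    base_bias = linear.bias if linear.bias is not None else torch.zeros(linear.out_features)
    merged.bias.copy_(base_bias * scale + shift)
    return merged


# quick numerical check
lin = nn.Linear(8, 24)
scale, shift = init_ssf_scale_shift(24)
x = torch.randn(2, 5, 8)
with torch.no_grad():
    y_ssf = lin(x) * scale + shift                      # what ssf_ada does when the last dim matches
    y_merged = merge_ssf_into_linear(lin, scale, shift)(x)
print(torch.allclose(y_ssf, y_merged, atol=1e-5))       # True
```

The same algebra applies wherever `ssf_ada` follows an affine layer; during fine-tuning the repo keeps the scale and shift as separate learnable parameters, as in the modules above.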
  {
    "path": "models/vision_transformer.py",
    "content": "\"\"\" Vision Transformer (ViT) in PyTorch\n\nA PyTorch implement of Vision Transformers as described in:\n\n'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'\n    - https://arxiv.org/abs/2010.11929\n\n`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`\n    - https://arxiv.org/abs/2106.10270\n\nThe official jax code is released and available at https://github.com/google-research/vision_transformer\n\nAcknowledgments:\n* The paper authors for releasing code and weights, thanks!\n* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out\nfor some einops/einsum fun\n* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT\n* Bert reference code checks against Huggingface Transformers and Tensorflow Bert\n\nHacked together by / Copyright 2020, Ross Wightman\n\"\"\"\nimport math\nimport logging\nfrom functools import partial\nfrom collections import OrderedDict\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD\nfrom timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv, resolve_pretrained_cfg, checkpoint_seq\nfrom timm.models.layers import DropPath, trunc_normal_, lecun_normal_, _assert\nfrom timm.models.layers.helpers import to_2tuple\nfrom timm.models.registry import register_model\n\n\n\nimport ipdb\n\n\n_logger = logging.getLogger(__name__)\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,\n        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\ndefault_cfgs = {\n    # patch models (weights from official Google JAX impl)\n    'vit_tiny_patch16_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),\n    'vit_tiny_patch16_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_small_patch32_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),\n    'vit_small_patch32_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_small_patch16_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),\n    'vit_small_patch16_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',\n        
input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_base_patch32_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),\n    'vit_base_patch32_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_base_patch16_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),\n    'vit_base_patch16_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_base_patch8_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),\n    'vit_large_patch32_224': _cfg(\n        url='',  # no official model weights for this combo, only for in21k\n        ),\n    'vit_large_patch32_384': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_large_patch16_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),\n    'vit_large_patch16_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n\n\n\n    # patch models, imagenet21k (weights from official Google JAX impl)\n    'vit_tiny_patch16_224_in21k': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',\n        num_classes=21843),\n    'vit_small_patch16_224_in21k': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',\n        num_classes=21843),\n    'vit_base_patch16_224_in21k': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',\n        num_classes=21843),\n    'vit_large_patch16_224_in21k': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',\n        num_classes=21843),\n\n\n}\n\n\n\n\nclass Mlp(nn.Module):\n    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n    \"\"\"\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., tuning_mode=None):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])\n        self.act = act_layer()\n        self.drop1 = 
nn.Dropout(drop_probs[0])\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n            \n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':        \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)\n\n\n    def forward(self, x):\n        x = self.fc1(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x) \n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)\n        \n        x = self.drop2(x)\n        \n        return x\n\n\n        \n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., tuning_mode=None):\n        super().__init__()\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)\n\n\n\n    def forward(self, x):\n        B, N, C = x.shape\n        if self.tuning_mode == 'ssf':\n            qkv = (ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1)).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        else:\n            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        if self.tuning_mode == 'ssf':\n            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)\n        x = self.proj_drop(x)\n        return x\n\n\nclass LayerScale(nn.Module):\n    def __init__(self, dim, init_values=1e-5, inplace=False):\n        super().__init__()\n        self.inplace = inplace\n        self.gamma = nn.Parameter(init_values * torch.ones(dim))\n\n    def forward(self, x):\n        return x.mul_(self.gamma) if self.inplace else x * self.gamma\n\n\nclass Block(nn.Module):\n\n    def __init__(\n            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,\n            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, tuning_mode=None):\n        super().__init__()\n        self.dim = dim\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, tuning_mode=tuning_mode)\n        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop, tuning_mode=tuning_mode)\n        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n\n\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)\n            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)\n\n\n\n    def forward(self, x):\n        if self.tuning_mode == 'ssf':\n            x = x + self.drop_path1(self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1))))\n            x = x + self.drop_path2(self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2))))\n        else:\n            x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))\n            x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))\n        return x\n\n\nclass ResPostBlock(nn.Module):\n    def __init__(\n            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,\n            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.init_values = init_values\n\n        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)\n        self.norm1 = norm_layer(dim)\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n\n        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)\n        self.norm2 = norm_layer(dim)\n        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n\n        self.init_weights()\n\n    def init_weights(self):\n        # NOTE this init overrides that base model init with specific changes for the block type\n        if self.init_values is not None:\n            nn.init.constant_(self.norm1.weight, self.init_values)\n            nn.init.constant_(self.norm2.weight, self.init_values)\n\n    def forward(self, x):\n        x = x + self.drop_path1(self.norm1(self.attn(x)))\n        x = x + self.drop_path2(self.norm2(self.mlp(x)))\n        return x\n\n\nclass ParallelBlock(nn.Module):\n\n    def __init__(\n            self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,\n            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.num_parallel = num_parallel\n        self.attns = nn.ModuleList()\n        self.ffns = nn.ModuleList()\n        for _ in range(num_parallel):\n            self.attns.append(nn.Sequential(OrderedDict([\n                ('norm', norm_layer(dim)),\n                ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),\n                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),\n                ('drop_path', DropPath(drop_path) if drop_path > 0. 
else nn.Identity())\n            ])))\n            self.ffns.append(nn.Sequential(OrderedDict([\n                ('norm', norm_layer(dim)),\n                ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),\n                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),\n                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())\n            ])))\n\n    def _forward_jit(self, x):\n        x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)\n        x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)\n        return x\n\n    @torch.jit.ignore\n    def _forward(self, x):\n        x = x + sum(attn(x) for attn in self.attns)\n        x = x + sum(ffn(x) for ffn in self.ffns)\n        return x\n\n    def forward(self, x):\n        if torch.jit.is_scripting() or torch.jit.is_tracing():\n            return self._forward_jit(x)\n        else:\n            return self._forward(x)\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" 2D Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, tuning_mode=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.num_patches = self.grid_size[0] * self.grid_size[1]\n        self.flatten = flatten\n        self.norm_layer = norm_layer\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n        self.tuning_mode = tuning_mode\n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)\n\n            if norm_layer:\n                self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim)\n\n\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        _assert(H == self.img_size[0], f\"Input image height ({H}) doesn't match model ({self.img_size[0]}).\")\n        _assert(W == self.img_size[1], f\"Input image width ({W}) doesn't match model ({self.img_size[1]}).\")\n\n        x = self.proj(x) \n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC\n        if self.tuning_mode == 'ssf':  \n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n            if self.norm_layer:\n                x = ssf_ada(self.norm(x), self.ssf_scale_2, self.ssf_shift_2)\n            else:\n                x = self.norm(x)\n        else:\n            x = self.norm(x)\n        return x\n\n\n\ndef init_ssf_scale_shift(dim):\n    scale = nn.Parameter(torch.ones(dim))\n    shift = nn.Parameter(torch.zeros(dim))\n\n    nn.init.normal_(scale, mean=1, std=.02)\n    nn.init.normal_(shift, std=.02)\n\n    return scale, shift\n\n\ndef ssf_ada(x, scale, shift):\n    assert scale.shape == shift.shape\n    if x.shape[-1] == scale.shape[0]:\n        return x * scale + shift\n    elif x.shape[1] == scale.shape[0]:\n        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)\n    else:\n        raise ValueError('the input tensor shape does not match the shape of the scale factor.')\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\" Vision Transformer\n\n    A PyTorch impl of : `An Image is Worth 16x16 
Words: Transformers for Image Recognition at Scale`\n        - https://arxiv.org/abs/2010.11929\n    \"\"\"\n\n    def __init__(\n            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',\n            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,\n            class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',\n            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, tuning_mode=None): \n        \"\"\"\n        Args:\n            img_size (int, tuple): input image size\n            patch_size (int, tuple): patch size\n            in_chans (int): number of input channels\n            num_classes (int): number of classes for classification head\n            global_pool (str): type of global pooling for final sequence (default: 'token')\n            embed_dim (int): embedding dimension\n            depth (int): depth of transformer\n            num_heads (int): number of attention heads\n            mlp_ratio (int): ratio of mlp hidden dim to embedding dim\n            qkv_bias (bool): enable bias for qkv if True\n            init_values: (float): layer-scale init values\n            class_token (bool): use class token\n            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)\n            drop_rate (float): dropout rate\n            attn_drop_rate (float): attention dropout rate\n            drop_path_rate (float): stochastic depth rate\n            weight_init (str): weight init scheme\n            embed_layer (nn.Module): patch embedding layer\n            norm_layer: (nn.Module): normalization layer\n            act_layer: (nn.Module): MLP activation layer\n        \"\"\"\n        super().__init__()\n        assert global_pool in ('', 'avg', 'token')\n        assert class_token or global_pool != 'token'\n        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n        act_layer = act_layer or nn.GELU\n\n        self.num_classes = num_classes\n        self.global_pool = global_pool\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.num_tokens = 1 if class_token else 0\n        self.grad_checkpointing = False \n\n        self.patch_embed = embed_layer(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None\n        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02)\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        \n        self.tuning_mode = tuning_mode\n        tuning_mode_list = [tuning_mode] * depth \n        if tuning_mode == 'ssf':     \n            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(self.num_features)\n\n        self.blocks = nn.Sequential(*[\n            block_fn(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, 
tuning_mode=tuning_mode_list[i])\n            for i in range(depth)])\n\n\n        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()\n\n        # Classifier Head\n        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        if weight_init != 'skip':\n            self.init_weights(weight_init)\n\n    def init_weights(self, mode=''):\n        assert mode in ('jax', 'jax_nlhb', 'moco', '')\n        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.\n        trunc_normal_(self.pos_embed, std=.02)\n        if self.cls_token is not None:\n            nn.init.normal_(self.cls_token, std=1e-6)\n        named_apply(get_init_weights_vit(mode, head_bias), self)\n\n    def _init_weights(self, m):\n        # this fn left here for compat with downstream users\n        init_weights_vit_timm(m)\n\n    @torch.jit.ignore()\n    def load_pretrained(self, checkpoint_path, prefix=''):\n        _load_weights(self, checkpoint_path, prefix)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token', 'dist_token'}\n\n    @torch.jit.ignore\n    def group_matcher(self, coarse=False):\n        return dict(\n            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed\n            blocks=[(r'^blocks\\.(\\d+)', None), (r'^norm', (99999,))]\n        )\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.grad_checkpointing = enable\n\n    @torch.jit.ignore\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes: int, global_pool=None):\n        self.num_classes = num_classes\n        if global_pool is not None:\n            assert global_pool in ('', 'avg', 'token')\n            self.global_pool = global_pool\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.cls_token is not None:\n            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n        x = self.pos_drop(x + self.pos_embed)\n\n        if self.grad_checkpointing and not torch.jit.is_scripting():\n            x = checkpoint_seq(self.blocks, x)\n        else:\n            x = self.blocks(x)\n\n        x = self.norm(x)\n        if self.tuning_mode == 'ssf': \n            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)\n            \n        return x \n\n    def forward_head(self, x, pre_logits: bool = False):\n        if self.global_pool:\n            x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]\n        x = self.fc_norm(x)\n        return x if pre_logits else self.head(x)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.forward_head(x)\n        \n        return x \n\n\ndef init_weights_vit_timm(module: nn.Module, name: str = ''):\n    \"\"\" ViT weight initialization, original timm impl (for reproducibility) \"\"\"\n    if isinstance(module, nn.Linear):\n        trunc_normal_(module.weight, std=.02)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n    elif hasattr(module, 'init_weights'):\n        module.init_weights()\n\n\ndef init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):\n    \"\"\" ViT weight initialization, matching JAX (Flax) impl \"\"\"\n    if isinstance(module, 
nn.Linear):\n        if name.startswith('head'):\n            nn.init.zeros_(module.weight)\n            nn.init.constant_(module.bias, head_bias)\n        else:\n            nn.init.xavier_uniform_(module.weight)\n            if module.bias is not None:\n                nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)\n    elif isinstance(module, nn.Conv2d):\n        lecun_normal_(module.weight)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n    elif hasattr(module, 'init_weights'):\n        module.init_weights()\n\n\ndef init_weights_vit_moco(module: nn.Module, name: str = ''):\n    \"\"\" ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed \"\"\"\n    if isinstance(module, nn.Linear):\n        if 'qkv' in name:\n            # treat the weights of Q, K, V separately\n            val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))\n            nn.init.uniform_(module.weight, -val, val)\n        else:\n            nn.init.xavier_uniform_(module.weight)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n    elif hasattr(module, 'init_weights'):\n        module.init_weights()\n\n\ndef get_init_weights_vit(mode='jax', head_bias: float = 0.):\n    if 'jax' in mode:\n        return partial(init_weights_vit_jax, head_bias=head_bias)\n    elif 'moco' in mode:\n        return init_weights_vit_moco\n    else:\n        return init_weights_vit_timm\n\n\n@torch.no_grad()\ndef _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):\n    \"\"\" Load weights from .npz checkpoints for official Google Brain Flax implementation\n    \"\"\"\n    import numpy as np\n\n    def _n2p(w, t=True):\n        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:\n            w = w.flatten()\n        if t:\n            if w.ndim == 4:\n                w = w.transpose([3, 2, 0, 1])\n            elif w.ndim == 3:\n                w = w.transpose([2, 0, 1])\n            elif w.ndim == 2:\n                w = w.transpose([1, 0])\n        return torch.from_numpy(w)\n\n    w = np.load(checkpoint_path)\n    if not prefix and 'opt/target/embedding/kernel' in w:\n        prefix = 'opt/target/'\n\n    if hasattr(model.patch_embed, 'backbone'):\n        # hybrid\n        backbone = model.patch_embed.backbone\n        stem_only = not hasattr(backbone, 'stem')\n        stem = backbone if stem_only else backbone.stem\n        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))\n        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))\n        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))\n        if not stem_only:\n            for i, stage in enumerate(backbone.stages):\n                for j, block in enumerate(stage.blocks):\n                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'\n                    for r in range(3):\n                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))\n                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))\n                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))\n                    if block.downsample is not None:\n                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))\n                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))\n                 
       block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))\n        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])\n    else:\n        embed_conv_w = adapt_input_conv(\n            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))\n    model.patch_embed.proj.weight.copy_(embed_conv_w)\n    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))\n    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))\n    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)\n    if pos_embed_w.shape != model.pos_embed.shape:\n        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights\n            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)\n    model.pos_embed.copy_(pos_embed_w)\n    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))\n    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))\n    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:\n        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))\n        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))\n    # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights\n    # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:\n    #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))\n    #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))\n    for i, block in enumerate(model.blocks.children()):\n        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'\n        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'\n        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))\n        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))\n        block.attn.qkv.weight.copy_(torch.cat([\n            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))\n        block.attn.qkv.bias.copy_(torch.cat([\n            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))\n        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))\n        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))\n        for r in range(2):\n            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))\n            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))\n        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))\n        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))\n\n\ndef resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):\n    # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from\n    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224\n    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)\n    ntok_new = posemb_new.shape[1]\n    if num_tokens:\n        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]\n        ntok_new -= num_tokens\n    else:\n        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]\n    gs_old = int(math.sqrt(len(posemb_grid)))\n    if not len(gs_new):  # backwards compatibility\n        gs_new = [int(math.sqrt(ntok_new))] * 2\n    assert len(gs_new) >= 2\n    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)\n    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)\n    return posemb\n\n\ndef checkpoint_filter_fn(state_dict, model):\n    \"\"\" convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n    out_dict = {}\n    if 'model' in state_dict:\n        # For deit models\n        state_dict = state_dict['model']\n    for k, v in state_dict.items():\n        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:\n            # For old models that I trained prior to conv based patchification\n            O, I, H, W = model.patch_embed.proj.weight.shape\n            v = v.reshape(O, -1, H, W)\n        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:\n            # To resize pos embedding when using model at different size from pretrained weights\n            v = resize_pos_embed(\n                v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)\n        elif 'pre_logits' in k:\n            # NOTE representation layer removed as not used in latest 21k/1k pretrained weights\n            continue\n        out_dict[k] = v\n    return out_dict\n\n\ndef _create_vision_transformer(variant, pretrained=False, **kwargs):\n    if kwargs.get('features_only', None):\n        raise RuntimeError('features_only not implemented for Vision Transformer models.')\n\n    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))\n    model = build_model_with_cfg(\n        VisionTransformer, variant, pretrained,\n        pretrained_cfg=pretrained_cfg,\n        pretrained_filter_fn=checkpoint_filter_fn,\n        pretrained_custom_load='npz' in pretrained_cfg['url'],\n        **kwargs)\n    return model\n\n\n\n@register_model\ndef vit_tiny_patch16_224(pretrained=False, **kwargs):\n    \"\"\" ViT-Tiny (Vit-Ti/16)\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)\n    model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n@register_model\ndef vit_tiny_patch16_384(pretrained=False, **kwargs):\n    \"\"\" ViT-Tiny (Vit-Ti/16) @ 384x384.\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)\n    model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n\n\n@register_model\ndef vit_small_patch16_224(pretrained=False, **kwargs):\n    \"\"\" ViT-Small (ViT-S/16)\n    NOTE I've replaced my previous 
'small' model definition and weights with the small variant from the DeiT paper\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)\n    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n@register_model\ndef vit_small_patch16_384(pretrained=False, **kwargs):\n    \"\"\" ViT-Small (ViT-S/16)\n    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)\n    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n\n\n@register_model\ndef vit_base_patch16_224(pretrained=False, **kwargs):\n    \"\"\" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n@register_model\ndef vit_base_patch16_384(pretrained=False, **kwargs):\n    \"\"\" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n\n@register_model\ndef vit_large_patch16_224(pretrained=False, **kwargs):\n    \"\"\" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)\n    model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n@register_model\ndef vit_large_patch16_384(pretrained=False, **kwargs):\n    \"\"\" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)\n    model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n\n@register_model\ndef vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):\n    \"\"\" ViT-Tiny (Vit-Ti/16).\n    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.\n    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)\n    model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n@register_model\ndef vit_small_patch16_224_in21k(pretrained=False, **kwargs):\n    \"\"\" ViT-Small (ViT-S/16)\n    ImageNet-21k weights @ 224x224, source 
https://github.com/google-research/vision_transformer.\n    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)\n    model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n@register_model\ndef vit_base_patch16_224_in21k(pretrained=False, **kwargs):\n    \"\"\" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.\n    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)\n    return model\n\n\n@register_model\ndef vit_large_patch16_224_in21k(pretrained=False, **kwargs):\n    \"\"\" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.\n    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer\n    \"\"\"\n    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)\n    model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)\n    return model\n\n"
  },
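  {
    "path": "examples/resize_pos_embed_sketch.py",
    "content": "#!/usr/bin/env python3\n\"\"\" Illustrative sketch, not part of the original repo: a worked example of what the\nresize_pos_embed helper in models/vision_transformer.py above does when the fine-tuning\ngrid differs from the pre-trained one. The shapes below (ViT-B/16, 224 -> 384 input) are\nassumptions chosen only for demonstration.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\n\n# pretrained position embedding: 1 class token + a 14x14 patch grid, embed dim 768\nposemb = torch.randn(1, 1 + 14 * 14, 768)\nnum_tokens = 1\ngs_old = 14\ngs_new = (24, 24)  # target grid for 384x384 input with 16x16 patches\n\n# split off the class-token embedding and reshape the grid part to (1, C, H, W)\nposemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]\nposemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n\n# bicubic interpolation to the new grid, then flatten back and re-attach the class token\nposemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)\nposemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)\nposemb_resized = torch.cat([posemb_tok, posemb_grid], dim=1)\n\nprint(posemb_resized.shape)  # torch.Size([1, 577, 768]): a 24x24 grid plus the class token\n"
  },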
  {
    "path": "optim_factory.py",
    "content": "\"\"\" Optimizer Factory w/ Custom Weight Decay\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\nimport json\nfrom itertools import islice\nfrom typing import Optional, Callable, Tuple, Dict, Union\n\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\n#from timm.models.helpers import group_parameters\n\nfrom timm.optim.adabelief import AdaBelief\nfrom timm.optim.adafactor import Adafactor\nfrom timm.optim.adahessian import Adahessian\nfrom timm.optim.adamp import AdamP\nfrom timm.optim.lamb import Lamb\nfrom timm.optim.lars import Lars \nfrom timm.optim.lookahead import Lookahead\nfrom timm.optim.madgrad import MADGRAD\nfrom timm.optim.nadam import Nadam \nfrom timm.optim.nvnovograd import NvNovoGrad\nfrom timm.optim.radam import RAdam\nfrom timm.optim.rmsprop_tf import RMSpropTF\nfrom timm.optim.sgdp import SGDP\n\n\n\ntry:\n    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\n\n\ndef param_groups_weight_decay(\n        model: nn.Module,\n        weight_decay=1e-5,\n        no_weight_decay_list=()\n):\n    no_weight_decay_list = set(no_weight_decay_list)\n    decay = []\n    no_decay = []\n    for name, param in model.named_parameters():\n        if not param.requires_grad:\n            continue\n\n        if param.ndim <= 1 or name.endswith(\".bias\") or name in no_weight_decay_list:\n            no_decay.append(param)\n        else:\n            decay.append(param)\n\n    return [\n        {'params': no_decay, 'weight_decay': 0.},\n        {'params': decay, 'weight_decay': weight_decay}]\n\n\ndef _group(it, size):\n    it = iter(it)\n    return iter(lambda: tuple(islice(it, size)), ())\n\n\ndef _layer_map(model, layers_per_group=12, num_groups=None):\n    def _in_head(n, hp):\n        if not hp:\n            return True\n        elif isinstance(hp, (tuple, list)):\n            return any([n.startswith(hpi) for hpi in hp])\n        else:\n            return n.startswith(hp)\n\n    head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)\n    names_trunk = []\n    names_head = []\n    for n, _ in model.named_parameters():\n        names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)\n\n    # group non-head layers\n    num_trunk_layers = len(names_trunk)\n    if num_groups is not None:\n        layers_per_group = -(num_trunk_layers // -num_groups)\n    names_trunk = list(_group(names_trunk, layers_per_group))\n\n    num_trunk_groups = len(names_trunk)\n    layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}\n    layer_map.update({n: num_trunk_groups for n in names_head})\n    return layer_map\n\n\n\ndef group_with_matcher(\n        named_objects,\n        group_matcher: Union[Dict, Callable],\n        output_values: bool = False,\n        reverse: bool = False\n):\n    if isinstance(group_matcher, dict):\n        # dictionary matcher contains a dict of raw-string regex expr that must be compiled\n        compiled = []\n        for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):\n            if mspec is None:\n                continue\n            # map all matching specifications into 3-tuple (compiled re, prefix, suffix)\n            if isinstance(mspec, (tuple, list)):\n                # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)\n                for sspec in mspec:\n                    compiled += [(re.compile(sspec[0]), 
(group_ordinal,), sspec[1])]\n            else:\n                compiled += [(re.compile(mspec), (group_ordinal,), None)]\n        group_matcher = compiled\n\n    def _get_grouping(name):\n        if isinstance(group_matcher, (list, tuple)):\n            for match_fn, prefix, suffix in group_matcher:\n                r = match_fn.match(name)\n                if r:\n                    parts = (prefix, r.groups(), suffix)\n                    # map all tuple elem to int for numeric sort, filter out None entries\n                    return tuple(map(float, chain.from_iterable(filter(None, parts))))\n            return float('inf'),  # un-matched layers (neck, head) mapped to largest ordinal\n        else:\n            ord = group_matcher(name)\n            if not isinstance(ord, collections.abc.Iterable):\n                return ord,\n            return tuple(ord)\n\n    # map layers into groups via ordinals (ints or tuples of ints) from matcher\n    grouping = defaultdict(list)\n    for k, v in named_objects:\n        grouping[_get_grouping(k)].append(v if output_values else k)\n\n    # remap to integers\n    layer_id_to_param = defaultdict(list)\n    lid = -1\n    for k in sorted(filter(lambda x: x is not None, grouping.keys())):\n        if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:\n            lid += 1\n        layer_id_to_param[lid].extend(grouping[k])\n\n    if reverse:\n        assert not output_values, \"reverse mapping only sensible for name output\"\n        # output reverse mapping\n        param_to_layer_id = {}\n        for lid, lm in layer_id_to_param.items():\n            for n in lm:\n                param_to_layer_id[n] = lid\n        return param_to_layer_id\n\n    return layer_id_to_param\n\n\ndef group_parameters(\n        module: nn.Module,\n        group_matcher,\n        output_values=False,\n        reverse=False,\n):\n    return group_with_matcher(\n        module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)\n\n\ndef group_modules(\n        module: nn.Module,\n        group_matcher,\n        output_values=False,\n        reverse=False,\n):\n    return group_with_matcher(\n        named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)\n\n\n\n\ndef param_groups_layer_decay(\n        model: nn.Module,\n        weight_decay: float = 0.05,\n        no_weight_decay_list: Tuple[str] = (),\n        layer_decay: float = .75,\n        end_layer_decay: Optional[float] = None,\n):\n    \"\"\"\n    Parameter groups for layer-wise lr decay & weight decay\n    Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58\n    \"\"\"\n    no_weight_decay_list = set(no_weight_decay_list)\n    param_group_names = {}  # NOTE for debugging\n    param_groups = {}\n\n    if hasattr(model, 'group_matcher'):\n        # FIXME interface needs more work\n        layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)\n    else:\n        # fallback\n        layer_map = _layer_map(model)\n    num_layers = max(layer_map.values()) + 1\n    layer_max = num_layers - 1\n    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))\n\n    for name, param in model.named_parameters():\n        if not param.requires_grad:\n            continue\n\n        # no decay: all 1D parameters and model specific ones\n        if param.ndim == 1 or name in no_weight_decay_list:\n            g_decay = \"no_decay\"\n            this_decay = 0.\n        else:\n        
    g_decay = \"decay\"\n            this_decay = weight_decay\n\n        layer_id = layer_map.get(name, layer_max)\n        group_name = \"layer_%d_%s\" % (layer_id, g_decay)\n\n        if group_name not in param_groups:\n            this_scale = layer_scales[layer_id]\n            param_group_names[group_name] = {\n                \"lr_scale\": this_scale,\n                \"weight_decay\": this_decay,\n                \"param_names\": [],\n            }\n            param_groups[group_name] = {\n                \"lr_scale\": this_scale,\n                \"weight_decay\": this_decay,\n                \"params\": [],\n            }\n\n        param_group_names[group_name][\"param_names\"].append(name)\n        param_groups[group_name][\"params\"].append(param)\n\n    # FIXME temporary output to debug new feature\n    print(\"parameter groups: \\n%s\" % json.dumps(param_group_names, indent=2))\n\n    return list(param_groups.values())\n\n\ndef optimizer_kwargs(cfg):\n    \"\"\" cfg/argparse to kwargs helper\n    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.\n    \"\"\"\n    kwargs = dict(\n        opt=cfg.opt,\n        lr=cfg.lr,\n        weight_decay=cfg.weight_decay,\n        momentum=cfg.momentum,\n        tuning_mode=cfg.tuning_mode)\n    if getattr(cfg, 'opt_eps', None) is not None:\n        kwargs['eps'] = cfg.opt_eps\n    if getattr(cfg, 'opt_betas', None) is not None:\n        kwargs['betas'] = cfg.opt_betas\n    if getattr(cfg, 'layer_decay', None) is not None:\n        kwargs['layer_decay'] = cfg.layer_decay\n    if getattr(cfg, 'opt_args', None) is not None:\n        kwargs.update(cfg.opt_args)\n    return kwargs\n\n\ndef create_optimizer(args, model, filter_bias_and_bn=True):\n    \"\"\" Legacy optimizer factory for backwards compatibility.\n    NOTE: Use create_optimizer_v2 for new code.\n    \"\"\"\n    return create_optimizer_v2(\n        model,\n        **optimizer_kwargs(cfg=args),\n        filter_bias_and_bn=filter_bias_and_bn,\n    )\n\n\ndef create_optimizer_v2(\n        model_or_params,\n        opt: str = 'sgd',\n        lr: Optional[float] = None,\n        weight_decay: float = 0.,\n        momentum: float = 0.9,\n        tuning_mode: str = None,\n        filter_bias_and_bn: bool = True,\n        layer_decay: Optional[float] = None,\n        param_group_fn: Optional[Callable] = None,\n        **kwargs):\n    \"\"\" Create an optimizer.\n\n    TODO currently the model is passed in and all parameters are selected for optimization.\n    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:\n      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion\n      * expose the parameters interface and leave it up to caller\n\n    Args:\n        model_or_params (nn.Module): model containing parameters to optimize\n        opt: name of optimizer to create\n        lr: initial learning rate\n        weight_decay: weight decay to apply in optimizer\n        momentum:  momentum for momentum based optimizers (others may use betas via kwargs)\n        filter_bias_and_bn:  filter out bias, bn and other 1d params from weight decay\n        **kwargs: extra optimizer specific kwargs to pass through\n\n    Returns:\n        Optimizer\n    \"\"\"\n    if isinstance(model_or_params, nn.Module):\n        # TODO: for fine-tuning \n        if tuning_mode:\n            for name, param in model_or_params.named_parameters():\n                if 
tuning_mode == 'linear_probe':\n                    if \"head.\" not in name:\n                        param.requires_grad = False\n                elif tuning_mode == 'ssf':\n                    if \"head.\" not in name and \"ssf_scale\" not in name and \"ssf_shift_\" not in name: \n                        param.requires_grad = False\n\n                if param.requires_grad == True:\n                    print(name)\n                \n            print('freezing parameters finished!')\n                    \n\n        # a model was passed in, extract parameters and add weight decays to appropriate layers\n        no_weight_decay = {}\n        if hasattr(model_or_params, 'no_weight_decay'):\n            no_weight_decay = model_or_params.no_weight_decay()\n\n        if param_group_fn:\n            parameters = param_group_fn(model_or_params)\n        elif layer_decay is not None:\n            parameters = param_groups_layer_decay(\n                model_or_params,\n                weight_decay=weight_decay,\n                layer_decay=layer_decay,\n                no_weight_decay_list=no_weight_decay)\n            weight_decay = 0.\n        elif weight_decay and filter_bias_and_bn:\n            parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)\n            weight_decay = 0.\n        else:\n            parameters = model_or_params.parameters()\n\n\n\n\n\n    else:\n        # iterable of parameters or param groups passed in\n        parameters = model_or_params\n\n\n\n    opt_lower = opt.lower()\n    opt_split = opt_lower.split('_')\n    opt_lower = opt_split[-1]\n    if 'fused' in opt_lower:\n        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'\n\n    opt_args = dict(weight_decay=weight_decay, **kwargs)\n    if lr is not None:\n        opt_args.setdefault('lr', lr)\n\n    # basic SGD & related\n    if opt_lower == 'sgd' or opt_lower == 'nesterov':\n        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons\n        opt_args.pop('eps', None)\n        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)\n    elif opt_lower == 'momentum':\n        opt_args.pop('eps', None)\n        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)\n    elif opt_lower == 'sgdp':\n        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)\n\n    # adaptive\n    elif opt_lower == 'adam':\n        optimizer = optim.Adam(parameters, **opt_args) \n    elif opt_lower == 'adamw':\n        optimizer = optim.AdamW(parameters, **opt_args)\n    elif opt_lower == 'adamp':\n        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)\n    elif opt_lower == 'nadam':\n        try:\n            # NOTE PyTorch >= 1.10 should have native NAdam\n            optimizer = optim.Nadam(parameters, **opt_args)\n        except AttributeError:\n            optimizer = Nadam(parameters, **opt_args)\n    elif opt_lower == 'radam':\n        optimizer = RAdam(parameters, **opt_args)\n    elif opt_lower == 'adamax':\n        optimizer = optim.Adamax(parameters, **opt_args)\n    elif opt_lower == 'adabelief':\n        optimizer = AdaBelief(parameters, rectify=False, **opt_args)\n    elif opt_lower == 'radabelief':\n        optimizer = AdaBelief(parameters, rectify=True, **opt_args)\n    elif opt_lower == 'adadelta':\n        optimizer = optim.Adadelta(parameters, **opt_args)\n    elif opt_lower == 'adagrad':\n        
opt_args.setdefault('eps', 1e-8)\n        optimizer = optim.Adagrad(parameters, **opt_args)\n    elif opt_lower == 'adafactor':\n        optimizer = Adafactor(parameters, **opt_args)\n    elif opt_lower == 'lamb':\n        optimizer = Lamb(parameters, **opt_args)\n    elif opt_lower == 'lambc':\n        optimizer = Lamb(parameters, trust_clip=True, **opt_args)\n    elif opt_lower == 'larc':\n        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)\n    elif opt_lower == 'lars':\n        optimizer = Lars(parameters, momentum=momentum, **opt_args)\n    elif opt_lower == 'nlarc':\n        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)\n    elif opt_lower == 'nlars':\n        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)\n    elif opt_lower == 'madgrad':\n        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)\n    elif opt_lower == 'madgradw':\n        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)\n    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':\n        optimizer = NvNovoGrad(parameters, **opt_args)\n    elif opt_lower == 'rmsprop':\n        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)\n    elif opt_lower == 'rmsproptf':\n        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)\n\n    # second order\n    elif opt_lower == 'adahessian':\n        optimizer = Adahessian(parameters, **opt_args)\n\n    # NVIDIA fused optimizers, require APEX to be installed\n    elif opt_lower == 'fusedsgd':\n        opt_args.pop('eps', None)\n        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)\n    elif opt_lower == 'fusedmomentum':\n        opt_args.pop('eps', None)\n        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)\n    elif opt_lower == 'fusedadam':\n        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)\n    elif opt_lower == 'fusedadamw':\n        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)\n    elif opt_lower == 'fusedlamb':\n        optimizer = FusedLAMB(parameters, **opt_args)\n    elif opt_lower == 'fusednovograd':\n        opt_args.setdefault('betas', (0.95, 0.98))\n        optimizer = FusedNovoGrad(parameters, **opt_args)\n\n    else:\n        assert False and \"Invalid optimizer\"\n        raise ValueError\n\n    if len(opt_split) > 1:\n        if opt_split[0] == 'lookahead':\n            optimizer = Lookahead(optimizer)\n\n    return optimizer\n"
  },
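  {
    "path": "examples/ssf_freeze_sketch.py",
    "content": "#!/usr/bin/env python3\n\"\"\" Illustrative sketch, not part of the original repo: shows what the tuning_mode='ssf'\nbranch of create_optimizer_v2 in optim_factory.py above does. The toy module and its\nparameter names are assumptions, chosen only so the SSF name filters ('head.', 'ssf_scale',\n'ssf_shift_') have something to match. Run from the repo root so optim_factory is importable.\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nfrom optim_factory import create_optimizer_v2\n\n\nclass ToyBlock(nn.Module):\n    def __init__(self, dim=8):\n        super().__init__()\n        self.fc = nn.Linear(dim, dim)                      # frozen under SSF\n        self.ssf_scale_1 = nn.Parameter(torch.ones(dim))   # stays trainable\n        self.ssf_shift_1 = nn.Parameter(torch.zeros(dim))  # stays trainable\n\n    def forward(self, x):\n        return self.fc(x) * self.ssf_scale_1 + self.ssf_shift_1\n\n\nclass ToyModel(nn.Module):\n    def __init__(self, dim=8, num_classes=10):\n        super().__init__()\n        self.block = ToyBlock(dim)\n        self.head = nn.Linear(dim, num_classes)  # stays trainable\n\n    def forward(self, x):\n        return self.head(self.block(x))\n\n\nif __name__ == '__main__':\n    model = ToyModel()\n    optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=1e-4, tuning_mode='ssf')\n    trainable = [n for n, p in model.named_parameters() if p.requires_grad]\n    # expect only head.* plus the ssf_scale_* / ssf_shift_* parameters to remain trainable\n    print(trainable)\n"
  },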
  {
    "path": "requirements.txt",
    "content": "pyyaml\nscipy\npandas\nipdb"
  },
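  {
    "path": "examples/config_override_sketch.py",
    "content": "#!/usr/bin/env python3\n\"\"\" Illustrative sketch, not part of the original repo: the two-stage argument parsing used\nby train.py below, where a YAML file passed via --config overrides the argparse defaults and\nany remaining command-line flags still take precedence. The option names here are assumptions.\n\"\"\"\nimport argparse\n\nimport yaml\n\nconfig_parser = argparse.ArgumentParser(description='Config', add_help=False)\nconfig_parser.add_argument('-c', '--config', default='', type=str)\n\nparser = argparse.ArgumentParser(description='Sketch of the train.py argument handling')\nparser.add_argument('--model', default='resnet50', type=str)\nparser.add_argument('--lr', default=0.05, type=float)\n\n\ndef parse_args(argv=None):\n    # stage 1: pull out --config only, leave everything else untouched\n    args_config, remaining = config_parser.parse_known_args(argv)\n    if args_config.config:\n        # stage 2: values from the YAML file become the new parser defaults ...\n        with open(args_config.config) as f:\n            parser.set_defaults(**yaml.safe_load(f))\n    # ... and explicit command-line flags still override them\n    return parser.parse_args(remaining)\n\n\nif __name__ == '__main__':\n    print(parse_args(['--lr', '0.01']))  # lr is overridden to 0.01, model keeps its default\n"
  },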
  {
    "path": "train.py",
    "content": "#!/usr/bin/env python3\n\"\"\" ImageNet Training Script\n\nThis is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet\ntraining results with some of the latest networks and training techniques. It favours canonical PyTorch\nand standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed\nand training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.\n\nThis script was started from an early version of the PyTorch ImageNet example\n(https://github.com/pytorch/examples/tree/master/imagenet)\n\nNVIDIA CUDA specific speedups adopted from NVIDIA Apex examples\n(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)\n\nHacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)\n\"\"\"\nimport argparse\nimport time\nimport yaml\nimport os\nimport logging\nimport numpy as np\nfrom collections import OrderedDict\nfrom contextlib import suppress\nfrom datetime import datetime\n\nimport torch\nimport torch.nn as nn\nimport torchvision.utils\nfrom torch.nn.parallel import DistributedDataParallel as NativeDDP\n\nfrom timm.data import resolve_data_config, Mixup, FastCollateMixup, AugMixDataset\n\n\nfrom timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\\\n    convert_splitbn_model, model_parameters\nfrom timm.utils import *\nfrom timm.loss import *\n\nfrom timm.scheduler import create_scheduler\nfrom timm.utils import ApexScaler, NativeScaler\n\n\nfrom data import create_loader, create_dataset\nfrom optim_factory import create_optimizer_v2, optimizer_kwargs\n\nfrom models import vision_transformer, swin_transformer, convnext, as_mlp\n\n\nimport ipdb\n\ntry:\n    from apex import amp\n    from apex.parallel import DistributedDataParallel as ApexDDP\n    from apex.parallel import convert_syncbn_model\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\ntry:\n    import wandb\n    has_wandb = True\nexcept ImportError: \n    has_wandb = False\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('train')\n\n# The first arg parser parses out only the --config argument, this argument is used to\n# load a yaml file containing key-values that override the defaults for the main parser below\nconfig_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)\nparser.add_argument('-c', '--config', default='', type=str, metavar='FILE',\n                    help='YAML config file specifying default arguments')\n\n\nparser = argparse.ArgumentParser(description='PyTorch ImageNet Training')\n\n# Dataset parameters\nparser.add_argument('data_dir', metavar='DIR',\n                    help='path to dataset')\nparser.add_argument('--dataset', '-d', metavar='NAME', default='',\n                    help='dataset type (default: ImageFolder/ImageTar if empty)')\nparser.add_argument('--train-split', metavar='NAME', default='train',\n                    help='dataset train split (default: train)')\nparser.add_argument('--val-split', metavar='NAME', default='validation',\n                    help='dataset validation split (default: validation)')\nparser.add_argument('--dataset-download', action='store_true', default=False,\n                    help='Allow download of dataset for torch/ and tfds/ datasets that support 
it.')\nparser.add_argument('--class-map', default='', type=str, metavar='FILENAME',\n                    help='path to class to idx mapping file (default: \"\")')\n\n# Model parameters\nparser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',\n                    help='Name of model to train (default: \"resnet50\"')\nparser.add_argument('--pretrained', action='store_true', default=False,\n                    help='Start with pretrained version of specified network (if avail)')\nparser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',\n                    help='Initialize model from this checkpoint (default: none)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='Resume full model and optimizer state from checkpoint (default: none)')\nparser.add_argument('--no-resume-opt', action='store_true', default=False,\n                    help='prevent resume of optimizer state when resuming model')\nparser.add_argument('--num-classes', type=int, default=None, metavar='N',\n                    help='number of label classes (Model default if None)')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')\nparser.add_argument('--img-size', type=int, default=None, metavar='N',\n                    help='Image patch size (default: None => model default)')\nparser.add_argument('--input-size', default=None, nargs=3, type=int,\n                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='Input image center crop percent (for validation only)')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\nparser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',\n                    help='Input batch size for training (default: 128)')\nparser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',\n                    help='Validation batch size override (default: None)')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--torchscript', dest='torchscript', action='store_true',\n                    help='torch.jit.script the full model')\nparser.add_argument('--fuser', default='', type=str,\n                    help=\"Select jit fuser. 
One of ('', 'te', 'old', 'nvfuser')\")\nparser.add_argument('--grad-checkpointing', action='store_true', default=False,\n                    help='Enable gradient checkpointing through model blocks/stages')\n\n\n# finetuning\nparser.add_argument('--tuning-mode', default=None, type=str,\n                    help='Method of fine-tuning (default: None')\n\n\nparser.add_argument('--evaluate', action='store_true', default=False,\n                    help='evaluate')\n\n\n# Optimizer parameters\nparser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',\n                    help='Optimizer (default: \"sgd\"')\nparser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',\n                    help='Optimizer Epsilon (default: None, use opt default)')\nparser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                    help='Optimizer Betas (default: None, use opt default)')\nparser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                    help='Optimizer momentum (default: 0.9)')\nparser.add_argument('--weight-decay', type=float, default=2e-5,\n                    help='weight decay (default: 2e-5)')\nparser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                    help='Clip gradient norm (default: None, no clipping)')\nparser.add_argument('--clip-mode', type=str, default='norm',\n                    help='Gradient clipping mode. One of (\"norm\", \"value\", \"agc\")')\nparser.add_argument('--layer-decay', type=float, default=None,\n                    help='layer-wise learning rate decay (default: None)')\n\n# Learning rate schedule parameters\nparser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                    help='LR scheduler (default: \"step\"')\nparser.add_argument('--lr', type=float, default=0.05, metavar='LR',\n                    help='learning rate (default: 0.05)')\nparser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',\n                    help='learning rate noise on/off epoch percentages')\nparser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                    help='learning rate noise limit percent (default: 0.67)')\nparser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                    help='learning rate noise std-dev (default: 1.0)')\nparser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',\n                    help='learning rate cycle len multiplier (default: 1.0)')\nparser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',\n                    help='amount to decay each learning rate cycle (default: 0.5)')\nparser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',\n                    help='learning rate cycle limit, cycles enabled if > 1')\nparser.add_argument('--lr-k-decay', type=float, default=1.0,\n                    help='learning rate k-decay for cosine/poly (default: 1.0)')\nparser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',\n                    help='warmup learning rate (default: 0.0001)')\nparser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',\n                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\nparser.add_argument('--epochs', type=int, default=300, metavar='N',\n                    help='number of epochs to train (default: 
300)')\nparser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',\n                    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')\nparser.add_argument('--start-epoch', default=None, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar=\"MILESTONES\",\n                    help='list of decay epoch indices for multistep lr. must be increasing')\nparser.add_argument('--decay-epochs', type=float, default=100, metavar='N',\n                    help='epoch interval to decay LR')\nparser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',\n                    help='epochs to warmup LR, if scheduler supports')\nparser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\nparser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                    help='patience epochs for Plateau LR scheduler (default: 10')\nparser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                    help='LR decay rate (default: 0.1)')\n\n# Augmentation & regularization parameters\nparser.add_argument('--no-aug', action='store_true', default=False,\n                    help='Disable all training augmentation, override other train aug args')\nparser.add_argument('--simple-aug', action='store_true', default=False,\n                    help='Only randomresize and flip training augmentation, override other train aug args')\nparser.add_argument('--direct-resize', action='store_true', default=False,\n                    help='Direct resize image in validation')\nparser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                    help='Random resize scale (default: 0.08 1.0)')\nparser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',\n                    help='Random resize aspect ratio (default: 0.75 1.33)')\nparser.add_argument('--hflip', type=float, default=0.5,\n                    help='Horizontal flip training aug probability')\nparser.add_argument('--vflip', type=float, default=0.,\n                    help='Vertical flip training aug probability')\nparser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                    help='Color jitter factor (default: 0.4)')\nparser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                    help='Use AutoAugment policy. \"v0\" or \"original\". \" + \"(default: rand-m9-mstd0.5-inc1)')\nparser.add_argument('--aug-repeats', type=float, default=0,\n                    help='Number of augmentation repetitions (distributed training only) (default: 0)')\nparser.add_argument('--aug-splits', type=int, default=0,\n                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')\nparser.add_argument('--jsd-loss', action='store_true', default=False,\n                    help='Enable Jensen-Shannon Divergence + CE loss. 
Use with `--aug-splits`.')\nparser.add_argument('--bce-loss', action='store_true', default=False,\n                    help='Enable BCE loss w/ Mixup/CutMix use.')\nparser.add_argument('--bce-target-thresh', type=float, default=None,\n                    help='Threshold for binarizing softened BCE targets (default: None, disabled)')\nparser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',\n                    help='Random erase prob (default: 0.)')\nparser.add_argument('--remode', type=str, default='pixel',\n                    help='Random erase mode (default: \"pixel\")')\nparser.add_argument('--recount', type=int, default=1,\n                    help='Random erase count (default: 1)')\nparser.add_argument('--resplit', action='store_true', default=False,\n                    help='Do not random erase first (clean) augmentation split')\nparser.add_argument('--mixup', type=float, default=0.8,\n                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix', type=float, default=1.0,\n                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')\nparser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\nparser.add_argument('--mixup-prob', type=float, default=1.0,\n                    help='Probability of performing mixup or cutmix when either/both is enabled')\nparser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\nparser.add_argument('--mixup-mode', type=str, default='batch',\n                    help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\nparser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',\n                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')\nparser.add_argument('--smoothing', type=float, default=0.1,\n                    help='Label smoothing (default: 0.1)')\nparser.add_argument('--train-interpolation', type=str, default='random',\n                    help='Training interpolation (random, bilinear, bicubic default: \"random\")')\nparser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                    help='Dropout rate (default: 0.)')\nparser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',\n                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')\nparser.add_argument('--drop-path', type=float, default=None, metavar='PCT',\n                    help='Drop path rate (default: None)')\nparser.add_argument('--drop-block', type=float, default=None, metavar='PCT',\n                    help='Drop block rate (default: None)')\n\n# Batch norm parameters (only works with gen_efficientnet based models currently)\nparser.add_argument('--bn-momentum', type=float, default=None,\n                    help='BatchNorm momentum override (if not None)')\nparser.add_argument('--bn-eps', type=float, default=None,\n                    help='BatchNorm epsilon override (if not None)')\nparser.add_argument('--sync-bn', action='store_true',\n                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')\nparser.add_argument('--dist-bn', type=str, default='reduce',\n                    help='Distribute BatchNorm stats between nodes after each epoch (\"broadcast\", \"reduce\", or \"\")')\nparser.add_argument('--split-bn', action='store_true',\n                    help='Enable separate BN layers per augmentation split.')\n\n# Model Exponential Moving Average\nparser.add_argument('--model-ema', action='store_true', default=False,\n                    help='Enable tracking moving average of model weights')\nparser.add_argument('--model-ema-force-cpu', action='store_true', default=False,\n                    help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.')\nparser.add_argument('--model-ema-decay', type=float, default=0.9998,\n                    help='decay factor for model weights moving average (default: 0.9998)')\n\n# Misc\nparser.add_argument('--seed', type=int, default=42, metavar='S',\n                    help='random seed (default: 42)')\nparser.add_argument('--worker-seeding', type=str, default='all',\n                    help='worker seed mode (default: all)')\nparser.add_argument('--log-interval', type=int, default=50, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--recovery-interval', type=int, default=0, metavar='N',\n                    help='how many batches to wait before writing recovery checkpoint')\nparser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',\n                    help='number of checkpoints to keep (default: 10)')\nparser.add_argument('-j', '--workers', type=int, default=8, metavar='N',\n                    help='how many training processes to use (default: 4)')\nparser.add_argument('--save-images', action='store_true', default=False,\n                    help='save images of input bathes every log interval for debugging')\nparser.add_argument('--amp', action='store_true', default=False,\n                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--no-ddp-bb', action='store_true', default=False,\n                    help='Force broadcast buffers for native DDP to off.')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--output', default='', type=str, metavar='PATH',\n                    help='path to output folder (default: none, current dir)')\nparser.add_argument('--experiment', default='', type=str, metavar='NAME',\n                    help='name of train experiment, name of sub-folder for output')\nparser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',\n                    help='Best metric (default: \"top1\"')\nparser.add_argument('--tta', type=int, default=0, metavar='N',\n                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')\nparser.add_argument(\"--local_rank\", default=0, type=int)\nparser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,\n                    help='use the multi-epochs-loader to save time at the beginning of every epoch')\nparser.add_argument('--log-wandb', action='store_true', default=False,\n                    help='log training and validation metrics to wandb')\n\n\ndef _parse_args():\n    # Do we have a config file to parse?\n    args_config, remaining = config_parser.parse_known_args()\n    if args_config.config:\n        with open(args_config.config, 'r') as f:\n            cfg = yaml.safe_load(f)\n            parser.set_defaults(**cfg)\n\n    # The main arg parser parses the rest of the args, the usual\n    # defaults will have been overridden if config file specified.\n    args = parser.parse_args(remaining)\n\n    # Cache the args as a text string to save them in the output dir later\n    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)\n    return args, args_text\n\n\ndef main():\n    setup_default_logging()\n    args, args_text = _parse_args()\n    \n    if args.log_wandb:\n        if has_wandb:\n            wandb.init(project=args.experiment, config=args)\n        else: \n            _logger.warning(\"You've requested to log metrics to wandb but package not found. \"\n                            \"Metrics not being logged to wandb, try `pip install wandb`\")\n             \n    args.prefetcher = not args.no_prefetcher\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n    args.device = 'cuda:0'\n    args.world_size = 1\n    args.rank = 0  # global rank\n    if args.distributed:\n        args.device = 'cuda:%d' % args.local_rank\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'\n                     % (args.rank, args.world_size))\n    else:\n        _logger.info('Training with a single process on 1 GPUs.')\n    assert args.rank >= 0\n\n    # resolve AMP arguments based on PyTorch / Apex availability\n    use_amp = None\n    if args.amp:\n        # `--amp` chooses native amp before apex (APEX ver not actively maintained)\n        if has_native_amp:\n            args.native_amp = True\n        elif has_apex:\n            args.apex_amp = True\n    if args.apex_amp and has_apex:\n        use_amp = 'apex'\n    elif args.native_amp and has_native_amp:\n        use_amp = 'native'\n    elif args.apex_amp or args.native_amp:\n        _logger.warning(\"Neither APEX or native Torch AMP is available, using float32. 
\"\n                        \"Install NVIDA apex or upgrade to PyTorch 1.6\")\n\n    random_seed(args.seed, args.rank)\n\n    if args.fuser:\n        set_jit_fuser(args.fuser)\n\n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        num_classes=args.num_classes,\n        drop_rate=args.drop,\n        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path\n        drop_path_rate=args.drop_path,\n        drop_block_rate=args.drop_block,\n        global_pool=args.gp,\n        bn_momentum=args.bn_momentum,\n        bn_eps=args.bn_eps,\n        scriptable=args.torchscript,\n        checkpoint_path=args.initial_checkpoint,\n        tuning_mode=args.tuning_mode)\n\n    \n    if args.num_classes is None:\n        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'\n        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly\n\n    if args.grad_checkpointing:\n        model.set_grad_checkpointing(enable=True)\n\n\n    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)\n\n    # setup augmentation batch splits for contrastive loss or split bn\n    num_aug_splits = 0\n    if args.aug_splits > 0:\n        assert args.aug_splits > 1, 'A split of 1 makes no sense'\n        num_aug_splits = args.aug_splits\n\n    # enable split bn (separate bn stats per batch-portion)\n    if args.split_bn:\n        assert num_aug_splits > 1 or args.resplit\n        model = convert_splitbn_model(model, max(num_aug_splits, 2))\n\n    # move model to GPU, enable channels last layout if set\n    model.cuda()\n    if args.channels_last:\n        model = model.to(memory_format=torch.channels_last)\n\n    # setup synchronized BatchNorm for distributed training\n    if args.distributed and args.sync_bn:\n        assert not args.split_bn\n        if has_apex and use_amp == 'apex':\n            # Apex SyncBN preferred unless native amp is activated\n            model = convert_syncbn_model(model)\n        else:\n            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n        if args.local_rank == 0:\n            _logger.info(\n                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '\n                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')\n\n    if args.torchscript:\n        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'\n        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'\n        model = torch.jit.script(model)\n\n    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))\n\n    if args.local_rank == 0:\n        _logger.info(\n            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')\n        _logger.info(f\"number of params for requires grad: {sum(p.numel() for p in model.parameters() if p.requires_grad)}\")\n\n\n    # setup automatic mixed-precision (AMP) loss scaling and op casting\n    amp_autocast = suppress  # do nothing\n    loss_scaler = None\n    if use_amp == 'apex':\n        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n        loss_scaler = ApexScaler()\n        if args.local_rank == 0:\n            _logger.info('Using NVIDIA APEX AMP. 
Training in mixed precision.')\n    elif use_amp == 'native':\n        amp_autocast = torch.cuda.amp.autocast\n        loss_scaler = NativeScaler()\n        if args.local_rank == 0:\n            _logger.info('Using native Torch AMP. Training in mixed precision.')\n    else:\n        if args.local_rank == 0:\n            _logger.info('AMP not enabled. Training in float32.')\n\n\n    # optionally resume from a checkpoint\n    resume_epoch = None\n    if args.resume:\n        resume_epoch = resume_checkpoint(\n            model, args.resume,\n            optimizer=None if args.no_resume_opt else optimizer,\n            loss_scaler=None if args.no_resume_opt else loss_scaler,\n            log_info=args.local_rank == 0)\n\n\n    # setup exponential moving average of model weights, SWA could be used here too\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper\n        model_ema = ModelEmaV2(\n            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)\n        if args.resume:\n            load_checkpoint(model_ema.module, args.resume, use_ema=True)\n\n    # setup distributed training\n    if args.distributed:\n        if has_apex and use_amp == 'apex':\n            # Apex DDP preferred unless native amp is activated\n            if args.local_rank == 0:\n                _logger.info(\"Using NVIDIA APEX DistributedDataParallel.\")\n            model = ApexDDP(model, delay_allreduce=True)\n        else:\n            if args.local_rank == 0:\n                _logger.info(\"Using native Torch DistributedDataParallel.\")\n            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)\n        # NOTE: EMA model does not need to be wrapped by DDP\n\n    # setup learning rate schedule and starting epoch\n    lr_scheduler, num_epochs = create_scheduler(args, optimizer)\n    start_epoch = 0\n    if args.start_epoch is not None:\n        # a specified start_epoch will always override the resume epoch\n        start_epoch = args.start_epoch\n    elif resume_epoch is not None:\n        start_epoch = resume_epoch\n    if lr_scheduler is not None and start_epoch > 0:\n        lr_scheduler.step(start_epoch)\n\n    if args.local_rank == 0:\n        _logger.info('Scheduled epochs: {}'.format(num_epochs))\n\n    # create the train and eval datasets\n    dataset_train = create_dataset(\n        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,\n        class_map=args.class_map,\n        download=args.dataset_download,\n        batch_size=args.batch_size,\n        repeats=args.epoch_repeats)\n    dataset_eval = create_dataset(\n        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,\n        class_map=args.class_map,\n        download=args.dataset_download,\n        batch_size=args.batch_size)\n\n    # setup mixup / cutmix\n    collate_fn = None\n    mixup_fn = None\n    mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None\n    if mixup_active:\n        mixup_args = dict(\n            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,\n            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,\n            label_smoothing=args.smoothing, num_classes=args.num_classes)\n        if args.prefetcher:\n            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)\n            collate_fn = FastCollateMixup(**mixup_args)\n        else:\n            mixup_fn = Mixup(**mixup_args)\n\n    # wrap dataset in AugMix helper\n    if num_aug_splits > 1:\n        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)\n\n    # create data loaders w/ augmentation pipeiine\n    train_interpolation = args.train_interpolation\n    if args.no_aug or not train_interpolation:\n        train_interpolation = data_config['interpolation']\n    loader_train = create_loader(\n        dataset_train,\n        input_size=data_config['input_size'],\n        batch_size=args.batch_size,\n        is_training=True,\n        use_prefetcher=args.prefetcher,\n        no_aug=args.no_aug,\n        simple_aug=args.simple_aug,\n        re_prob=args.reprob,\n        re_mode=args.remode,\n        re_count=args.recount,\n        re_split=args.resplit,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        vflip=args.vflip,\n        color_jitter=args.color_jitter,\n        auto_augment=args.aa,\n        num_aug_repeats=args.aug_repeats,\n        num_aug_splits=num_aug_splits,\n        interpolation=train_interpolation,\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        distributed=args.distributed,\n        collate_fn=collate_fn,\n        pin_memory=args.pin_mem,\n        use_multi_epochs_loader=args.use_multi_epochs_loader,\n        worker_seeding=args.worker_seeding,\n    )\n\n\n    loader_eval = create_loader(\n        dataset_eval,\n        input_size=data_config['input_size'],\n        batch_size=args.validation_batch_size or args.batch_size,\n        is_training=False,\n        use_prefetcher=args.prefetcher,\n        direct_resize=args.direct_resize,\n        interpolation=data_config['interpolation'],\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        distributed=args.distributed,\n        crop_pct=data_config['crop_pct'],\n        pin_memory=args.pin_mem,\n    )\n\n    # setup loss function\n    if args.jsd_loss:\n        assert num_aug_splits > 1  # JSD only valid with aug splits set\n        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)\n    elif mixup_active:\n        # smoothing is handled with mixup target transform which outputs sparse, soft targets\n        if args.bce_loss:\n            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)\n        else:\n            train_loss_fn = SoftTargetCrossEntropy()\n    elif args.smoothing:\n        if args.bce_loss:\n            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)\n        else:\n            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)\n    else:\n        train_loss_fn = nn.CrossEntropyLoss()\n    train_loss_fn = train_loss_fn.cuda()\n    validate_loss_fn = nn.CrossEntropyLoss().cuda()\n        \n    # setup checkpoint saver and eval 
metric tracking\n    eval_metric = args.eval_metric\n    best_metric = None\n    best_epoch = None\n    saver = None\n    output_dir = None\n    if args.rank == 0:\n        if args.experiment:\n            exp_name = args.experiment\n        else:\n            exp_name = '-'.join([\n                datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n                safe_model_name(args.model),\n                str(data_config['input_size'][-1])\n            ])\n        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)\n        decreasing = True if eval_metric == 'loss' else False\n        saver = CheckpointSaver(\n            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,\n            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)\n        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:\n            f.write(args_text)\n\n    if args.evaluate:\n        if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n            if args.local_rank == 0:\n                _logger.info(\"Distributing BatchNorm running means and vars\")\n            distribute_bn(model, args.world_size, args.dist_bn == 'reduce')\n\n        eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)\n\n        if model_ema is not None and not args.model_ema_force_cpu:\n            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')\n            ema_eval_metrics = validate(\n                model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')\n            eval_metrics = ema_eval_metrics\n        if saver is not None:\n            # save proper checkpoint with eval metric\n            save_metric = eval_metrics[eval_metric]\n            best_metric, best_epoch = saver.save_checkpoint(start_epoch, metric=save_metric)\n\n        return\n        \n    try:\n        for epoch in range(start_epoch, num_epochs):\n            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):\n                loader_train.sampler.set_epoch(epoch)\n\n            train_metrics = train_one_epoch(\n                epoch, model, loader_train, optimizer, train_loss_fn, args,\n                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,\n                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)\n\n            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                if args.local_rank == 0:\n                    _logger.info(\"Distributing BatchNorm running means and vars\")\n                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')\n\n            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)\n\n            if model_ema is not None and not args.model_ema_force_cpu:\n                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):\n                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')\n                ema_eval_metrics = validate(\n                    model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')\n                eval_metrics = ema_eval_metrics\n\n            if lr_scheduler is not None:\n                # step LR for next epoch\n                
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])\n\n            if output_dir is not None:\n                update_summary(\n                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),\n                    write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)\n\n            if saver is not None:\n                # save proper checkpoint with eval metric\n                save_metric = eval_metrics[eval_metric]\n                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)\n\n    except KeyboardInterrupt:\n        pass\n    if best_metric is not None:\n        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))\n\n\ndef train_one_epoch(\n        epoch, model, loader, optimizer, loss_fn, args,\n        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,\n        loss_scaler=None, model_ema=None, mixup_fn=None):\n\n    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:\n        if args.prefetcher and loader.mixup_enabled:\n            loader.mixup_enabled = False\n        elif mixup_fn is not None:\n            mixup_fn.mixup_enabled = False\n\n    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    losses_m = AverageMeter()\n\n    model.train()\n\n    end = time.time()\n    last_idx = len(loader) - 1\n    num_updates = epoch * len(loader)\n    for batch_idx, (input, target) in enumerate(loader):\n        last_batch = batch_idx == last_idx\n        data_time_m.update(time.time() - end)\n        if not args.prefetcher:\n            input, target = input.cuda(), target.cuda()\n            if mixup_fn is not None:\n                input, target = mixup_fn(input, target)\n        if args.channels_last:\n            input = input.contiguous(memory_format=torch.channels_last)\n\n        with amp_autocast():\n            output = model(input)\n            loss = loss_fn(output, target)\n\n        if not args.distributed:\n            losses_m.update(loss.item(), input.size(0))\n\n        optimizer.zero_grad()\n        if loss_scaler is not None:\n            loss_scaler(\n                loss, optimizer,\n                clip_grad=args.clip_grad, clip_mode=args.clip_mode,\n                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),\n                create_graph=second_order)\n        else:\n            loss.backward(create_graph=second_order)\n            if args.clip_grad is not None:\n                dispatch_clip_grad(\n                    model_parameters(model, exclude_head='agc' in args.clip_mode),\n                    value=args.clip_grad, mode=args.clip_mode)\n            optimizer.step()\n\n        if model_ema is not None:\n            model_ema.update(model)\n\n        torch.cuda.synchronize()\n        num_updates += 1\n        batch_time_m.update(time.time() - end)\n        if last_batch or batch_idx % args.log_interval == 0:\n            lrl = [param_group['lr'] for param_group in optimizer.param_groups]\n            lr = sum(lrl) / len(lrl)\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                losses_m.update(reduced_loss.item(), input.size(0))\n\n            if args.local_rank == 0:\n                _logger.info(\n                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '\n                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g})  
'\n                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '\n                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                    'LR: {lr:.3e}  '\n                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(\n                        epoch,\n                        batch_idx, len(loader),\n                        100. * batch_idx / last_idx,\n                        loss=losses_m,\n                        batch_time=batch_time_m,\n                        rate=input.size(0) * args.world_size / batch_time_m.val,\n                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,\n                        lr=lr,\n                        data_time=data_time_m))\n\n                if args.save_images and output_dir:\n                    torchvision.utils.save_image(\n                        input,\n                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),\n                        padding=0,\n                        normalize=True)\n\n        if saver is not None and args.recovery_interval and (\n                last_batch or (batch_idx + 1) % args.recovery_interval == 0):\n            saver.save_recovery(epoch, batch_idx=batch_idx)\n\n        if lr_scheduler is not None:\n            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)\n\n        end = time.time()\n        # end for\n\n    if hasattr(optimizer, 'sync_lookahead'):\n        optimizer.sync_lookahead()\n\n    return OrderedDict([('loss', losses_m.avg)])\n\n\ndef validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):\n    batch_time_m = AverageMeter()\n    losses_m = AverageMeter()\n    top1_m = AverageMeter()\n    top5_m = AverageMeter()\n\n    model.eval()\n\n    end = time.time()\n    last_idx = len(loader) - 1\n\n    with torch.no_grad():\n        for batch_idx, (input, target) in enumerate(loader):\n            last_batch = batch_idx == last_idx\n            if not args.prefetcher:\n                input = input.cuda()\n                target = target.cuda()\n            if args.channels_last:\n                input = input.contiguous(memory_format=torch.channels_last)\n\n            with amp_autocast():\n                output = model(input)\n\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n\n            # augmentation reduction\n            reduce_factor = args.tta\n            if reduce_factor > 1:\n                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)\n                target = target[0:target.size(0):reduce_factor]\n\n            loss = loss_fn(output, target)\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data, args.world_size)\n                acc1 = reduce_tensor(acc1, args.world_size)\n                acc5 = reduce_tensor(acc5, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            torch.cuda.synchronize()\n\n            losses_m.update(reduced_loss.item(), input.size(0))\n            top1_m.update(acc1.item(), output.size(0))\n            top5_m.update(acc5.item(), output.size(0))\n\n            batch_time_m.update(time.time() - end)\n            end = time.time()\n            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):\n                log_name = 'Test' + log_suffix\n                _logger.info(\n                    '{0}: [{1:>4d}/{2}]  '\n     
               'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '\n                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '\n                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(\n                        log_name, batch_idx, last_idx, batch_time=batch_time_m,\n                        loss=losses_m, top1=top1_m, top5=top5_m))\n\n    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])\n    \n    return metrics\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "train_scripts/asmlp/cifar_100/train_full.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model as_base_patch4_window7_224 \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 10  \\\n    --lr 5e-5 --min-lr 5e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--output  output/as_base_patch4_window7_224/cifar_100/full \\\n\t--amp  --pretrained  \\"
  },
  {
    "path": "train_scripts/asmlp/cifar_100/train_linear_probe.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model as_base_patch4_window7_224 \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 10  \\\n    --lr 1e-3 --min-lr 5e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--output  output/as_base_patch4_window7_224/cifar_100/linear_probe \\\n\t--amp  --tuning-mode linear_probe --pretrained  \\"
  },
  {
    "path": "train_scripts/asmlp/cifar_100/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3, python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model as_base_patch4_window7_224 \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--output  output/as_base_patch4_window7_224/cifar_100/ssf \\\n\t--amp  --tuning-mode ssf --pretrained \\"
  },
  {
    "path": "train_scripts/convnext/cifar_100/train_full.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model convnext_base_in22k \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 10  \\\n    --lr 5e-5 --min-lr 5e-8 \\\n    --drop-path 0.2 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/convnext_base_in22k/cifar_100/full \\\n\t--amp  --pretrained  \\"
  },
  {
    "path": "train_scripts/convnext/cifar_100/train_linear_probe.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model convnext_base_in22k \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 10  \\\n    --lr 1e-3 --min-lr 5e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/convnext_base_in22k/cifar_100/linear_probe \\\n\t--amp --tuning-mode linear_probe  --pretrained  \\\n"
  },
  {
    "path": "train_scripts/convnext/cifar_100/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3, python  -m torch.distributed.launch --nproc_per_node=4  --master_port=27524 \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model convnext_base_in22k \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/convnext_base_in22k/cifar_100/ssf \\\n\t--amp --tuning-mode ssf  --pretrained  \\"
  },
  {
    "path": "train_scripts/convnext/imagenet_1k/train_full.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttrain.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model convnext_base_in22k \\\n    --batch-size 32 --epochs 30 \\\n\t--opt adamw --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 5  \\\n    --lr 5e-5 --min-lr 5e-8 \\\n    --drop-path 0.2 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/convnext_base_in22k/imagenet_1k/full \\\n\t--amp  --pretrained  \\\n\n\n"
  },
  {
    "path": "train_scripts/convnext/imagenet_1k/train_linear_probe.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttrain.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model convnext_base_in22k \\\n    --batch-size 32 --epochs 30 \\\n\t--opt adamw --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 5  \\\n    --lr 1e-3 --min-lr 5e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/convnext_base_in22k/imagenet_1k/linear_probe \\\n\t--amp --tuning-mode linear_probe --pretrained  "
  },
  {
    "path": "train_scripts/convnext/imagenet_1k/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttrain.py /path/to/imagenet_1k  --dataset imagenet --num-classes 1000 --model convnext_base_in22k \\\n    --batch-size 32 --epochs 30 \\\n\t--opt adamw --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 5  \\\n    --lr 1e-3 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/convnext_base_in22k/imagenet_1k/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \n"
  },
  {
    "path": "train_scripts/swin/cifar_100/train_full.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model swin_base_patch4_window7_224_in22k \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 10  \\\n    --lr 5e-5 --min-lr 5e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--output  output/swin_base_patch4_window7_224_in22k/cifar_100/full \\\n\t--amp  --pretrained  \\"
  },
  {
    "path": "train_scripts/swin/cifar_100/train_linear_probe.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model swin_base_patch4_window7_224_in22k \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 10  \\\n    --lr 1e-3 --min-lr 5e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--output  output/swin_base_patch4_window7_224_in22k/cifar_100/linear_probe \\\n\t--amp  --tuning-mode linear_probe --pretrained  \\"
  },
  {
    "path": "train_scripts/swin/cifar_100/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model swin_base_patch4_window7_224_in22k \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 5e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--output  output/swin_base_patch4_window7_224_in22k/cifar_100/ssf \\\n\t--amp  --tuning-mode ssf --pretrained \\"
  },
  {
    "path": "train_scripts/swin/imagenet_1k/train_full.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttrain.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model swin_base_patch4_window7_224_in22k \\\n    --batch-size 32 --epochs 30 \\\n\t--opt adamw --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 5  \\\n    --lr 5e-5 --min-lr 5e-8 \\\n    --drop-path 0.2 --img-size 224 \\\n\t--output  output/swin_base_patch4_window7_224_in22k/imagenet_1k/full \\\n\t--amp  --pretrained  \\"
  },
  {
    "path": "train_scripts/swin/imagenet_1k/train_linear_probe.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,   python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttrain.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model swin_base_patch4_window7_224_in22k \\\n    --batch-size 32 --epochs 30 \\\n\t--opt adamw --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 5  \\\n    --lr 5e-3 --min-lr 5e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--output  output/swin_base_patch4_window7_224_in22k/imagenet_1k/linear_probe \\\n\t--amp --tuning-mode linear_probe --pretrained  "
  },
  {
    "path": "train_scripts/swin/imagenet_1k/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttrain.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model swin_base_patch4_window7_224_in22k \\\n    --batch-size 32 --epochs 30 \\\n\t--opt adamw --weight-decay 0.05 \\\n    --warmup-lr 5e-7 --warmup-epochs 5  \\\n    --lr 5e-3 --min-lr 5e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--output  output/swin_base_patch4_window7_224_in22k/imagenet_1k/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/cifar_100/eval_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,  python  -m torch.distributed.launch --nproc_per_node=1  --master_port=17346  \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/vit_base_patch16_224_in21k/cifar_100/ssf/eval \\\n\t--amp  --tuning-mode ssf --pretrained  \\\n    --evaluate \\\n    --checkpoint /path/to/vit_base_patch16_224_in21k/cifar_100/ssf/model_best.pth.tar  \\"
  },
  {
    "path": "train_scripts/vit/cifar_100/train_full.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=12346 \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model vit_base_patch16_224_in21k \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-5 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/vit_base_patch16_224_in21k/cifar_100/full \\\n\t--amp  --pretrained  \\"
  },
  {
    "path": "train_scripts/vit/cifar_100/train_linear_probe.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=12346  \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/vit_base_patch16_224_in21k/cifar_100/linear_probe \\\n\t--amp --tuning-mode linear_probe --pretrained  \\"
  },
  {
    "path": "train_scripts/vit/cifar_100/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=12346  \\\n\ttrain.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/vit_base_patch16_224_in21k/cifar_100/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \\"
  },
  {
    "path": "train_scripts/vit/fgvc/cub_2011/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14655  \\\n\ttrain.py /path/to/CUB_200_2011 --dataset cub2011 --num-classes 200 --simple-aug --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-2 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-2 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.9998  \\\n\t--output  output/vit_base_patch16_224_in21k/fgvc/cub2011/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/fgvc/nabirds/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14222  \\\n\ttrain.py /path/to/nabirds --dataset nabirds --num-classes 555  --simple-aug --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 2e-4 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.9998  \\\n\t--output  output/vit_base_patch16_224_in21k/nabirds/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/fgvc/oxford_flowers/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python -m torch.distributed.launch --nproc_per_node=2  --master_port=12341 \\\n    train.py /path/to/oxford_flowers  --dataset oxford_flowers --num-classes 102 --val-split val --simple-aug --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-2 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.999  \\\n\t--output  output/vit_base_patch16_224_in21k/oxford_flowers/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/fgvc/stanford_cars/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=12349  \\\n\ttrain.py /path/to/stanford_cars --dataset stanford_cars --num-classes 196 --val-split val  --simple-aug --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 2e-2 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.9998  \\\n\t--output  output/vit_base_patch16_224_in21k/stanford_cars/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/fgvc/stanford_dogs/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=12319 \\\n    train.py /path/to/stanford_dogs  --dataset stanford_dogs --num-classes 120 --simple-aug    --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 2.5e-4 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n    --model-ema --model-ema-decay 0.9998  \\\n\t--output  output/vit_base_patch16_224_in21k/stanford_dogs/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \\\n"
  },
  {
    "path": "train_scripts/vit/imagenet_1k/train_full.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttrain.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model vit_base_patch16_224_in21k \\\n    --batch-size 32 --epochs 30 \\\n\t--opt adamw --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 5  \\\n    --lr 1e-4 --min-lr 1e-8 \\\n    --drop-path 0.2 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/vit_base_patch16_224_in21k/imagenet_1k/full \\\n\t--amp  --pretrained  \\"
  },
  {
    "path": "train_scripts/vit/imagenet_1k/train_linear_probe.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttrain.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model vit_base_patch16_224_in21k \\\n    --batch-size 32 --epochs 30 \\\n\t--opt adamw --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 5  \\\n    --lr 1e-4 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/vit_base_patch16_224_in21k/imagenet_1k/linear_probe \\\n\t--amp --tuning-mode linear_probe --pretrained  "
  },
  {
    "path": "train_scripts/vit/imagenet_1k/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttrain.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model vit_base_patch16_224_in21k \\\n    --batch-size 32 --epochs 30 \\\n\t--opt adamw --weight-decay 0.05 \\\n    --warmup-lr 1e-7 --warmup-epochs 5  \\\n    --lr 1e-3 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--model-ema --model-ema-decay 0.99992  \\\n\t--output  output/vit_base_patch16_224_in21k/imagenet_1k/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/imagenet_a/eval_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,  python validate_ood.py \\\n    /path/to/imagenet-a  \\\n    --num-classes 1000 \\\n    --model vit_base_patch16_224_in21k \\\n    --batch-size 64 \\\n    --no-test-pool \n    --imagenet_a \\\n\t--results-file  output/vit_base_patch16_224_in21k/imagenet_a/ssf \\\n    --tuning-mode ssf \\\n    --checkpoint /path/to/vit_base_patch16_224_in21k/imagenet_1k/ssf/model_best.pth.tar\n\n\n"
  },
  {
    "path": "train_scripts/vit/imagenet_c/eval_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,  python validate_ood.py \\\n    /path/to/imagenet-c/  \\\n    --num-classes 1000 \\\n    --model vit_base_patch16_224_in21k \\\n    --batch-size 64 \\\n    --imagenet_c \\\n\t--results-file  output/vit_base_patch16_224_in21k/imagenet_c/ssf \\\n    --tuning-mode ssf \\\n    --checkpoint /path/to/vit_base_patch16_224_in21k/imagenet_1k/ssf/model_best.pth.tar\n"
  },
  {
    "path": "train_scripts/vit/imagenet_r/eval_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,  python validate_ood.py \\\n    /path/to/imagenet-r  \\\n    --num-classes 1000 \\\n    --model vit_base_patch16_224_in21k \\\n    --batch-size 64 \\\n    --imagenet_r \\\n\t--results-file  output/vit_base_patch16_224_in21k/imagenet_r/ssf \\\n    --tuning-mode ssf \\\n    --checkpoint /path/to/vit_base_patch16_224_in21k/imagenet_1k/ssf/model_best.pth.tar\n"
  },
  {
    "path": "train_scripts/vit/vtab/caltech101/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14337  \\\n\ttrain.py /path/to/vtab-1k/caltech101  --dataset caltech101 --num-classes 102  --no-aug  --direct-resize  --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-2 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-3 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/caltech101/ssf \\\n\t--amp  --tuning-mode ssf --pretrained  \\\n"
  },
  {
    "path": "train_scripts/vit/vtab/cifar_100/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1, python  -m torch.distributed.launch --nproc_per_node=2  --master_port=19547  \\\n\ttrain.py /path/to/vtab-1k/cifar  --dataset cifar100 --num-classes 100  --no-aug --direct-resize --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/cifar_100/ssf \\\n\t--amp  --tuning-mode ssf --pretrained  \\\n"
  },
  {
    "path": "train_scripts/vit/vtab/clevr_count/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14332  \\\n\ttrain.py /path/to/vtab-1k/clevr_count  --dataset clevr_count --num-classes 8  --no-aug  --direct-resize --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-2 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 2e-3 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/clevr_count/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \\"
  },
  {
    "path": "train_scripts/vit/vtab/clevr_dist/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=10032  \\\n\ttrain.py /path/to/vtab-1k/clevr_dist  --dataset clevr_dist --num-classes 6  --no-aug --direct-resize --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-2 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-2 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/clevr_dist/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \\"
  },
  {
    "path": "train_scripts/vit/vtab/diabetic_retinopathy/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=26662  \\\n\ttrain.py /path/to/vtab-1k/diabetic_retinopathy  --dataset diabetic_retinopathy --num-classes 5  --no-aug --direct-resize  --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 1e-8 \\\n    --drop-path 0.2 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/diabetic_retinopathy/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \n"
  },
  {
    "path": "train_scripts/vit/vtab/dmlab/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1, python  -m torch.distributed.launch --nproc_per_node=2  --master_port=13002  \\\n\ttrain.py /path/to/vtab-1k/dmlab  --dataset dmlab --num-classes 6  --no-aug  --direct-resize  --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/dmlab/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/vtab/dsprites_loc/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=12102  \\\n\ttrain.py /path/to/vtab-1k/dsprites_loc  --dataset dsprites_loc --num-classes 16  --no-aug  --direct-resize  --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-2 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/dsprites_loc/ssf \\\n\t--amp --tuning-mode ssf --pretrained \\\n"
  },
  {
    "path": "train_scripts/vit/vtab/dsprites_ori/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=12002  \\\n\ttrain.py /path/to/vtab-1k/dsprites_ori  --dataset dsprites_ori --num-classes 16  --no-aug   --direct-resize   --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 1e-8 \\\n    --drop-path 0.2 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/dsprites_ori/ssf \\\n\t--amp --tuning-mode ssf --pretrained \\\n\t"
  },
  {
    "path": "train_scripts/vit/vtab/dtd/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14312  \\\n\ttrain.py /path/to/vtab-1k/dtd  --dataset dtd --num-classes 47  --no-aug --direct-resize  --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/dtd/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \\\n\t--mixup 0 --cutmix 0 --smoothing 0"
  },
  {
    "path": "train_scripts/vit/vtab/eurosat/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14112  \\\n\ttrain.py /path/to/vtab-1k/eurosat  --dataset eurosat --num-classes 10  --no-aug  --direct-resize  --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-2 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 3e-3 --min-lr 1e-8 \\\n    --drop-path 0.2 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/eurosat/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \n"
  },
  {
    "path": "train_scripts/vit/vtab/flowers102/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14222  \\\n\ttrain.py /path/to/vtab-1k/oxford_flowers102 --dataset flowers102 --num-classes 102  --no-aug --direct-resize --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/flowers102/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \n\n"
  },
  {
    "path": "train_scripts/vit/vtab/kitti/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2 --master_port=14332  \\\n\ttrain.py /path/to/vtab-1k/kitti  --dataset kitti --num-classes 4  --no-aug --direct-resize --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-2 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/kitti/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \\"
  },
  {
    "path": "train_scripts/vit/vtab/patch_camelyon/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14332  \\\n\ttrain.py /path/to/vtab-1k/patch_camelyon  --dataset patch_camelyon --num-classes 2  --no-aug  --direct-resize  --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/patch_camelyon/ssf \\\n\t--amp --tuning-mode ssf --pretrained  \n"
  },
  {
    "path": "train_scripts/vit/vtab/pets/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14332  \\\n\ttrain.py /path/to/vtab-1k/oxford_iiit_pet  --dataset pets --num-classes 37  --no-aug --direct-resize --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/pets/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/vtab/resisc45/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=11222  \\\n\ttrain.py /path/to/vtab-1k/resisc45  --dataset resisc45 --num-classes 45  --no-aug  --direct-resize  --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 2e-3 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/resisc45/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/vtab/smallnorb_azi/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14882  \\\n\ttrain.py /path/to/vtab-1k/smallnorb_azi  --dataset smallnorb_azi --num-classes 18  --no-aug --direct-resize --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 2e-2 --min-lr 1e-8 \\\n    --drop-path 0.1 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/smallnorb_azi/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/vtab/smallnorb_ele/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=24332  \\\n\ttrain.py /path/to/vtab-1k/smallnorb_ele  --dataset smallnorb_ele --num-classes 9  --no-aug  --direct-resize  --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-2 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 1e-8 \\\n    --drop-path 0.2 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/smallnorb_ele/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/vtab/sun397/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14192  \\\n\ttrain.py /path/to/vtab-1k/sun397  --dataset sun397 --num-classes 397  --no-aug --direct-resize  --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 5e-3 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/sun397/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "train_scripts/vit/vtab/svhn/train_ssf.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14332  \\\n\ttrain.py /path/to/vtab-1k/svhn  --dataset svhn --num-classes 10  --no-aug --direct-resize --model vit_base_patch16_224_in21k  \\\n    --batch-size 32 --epochs 100 \\\n\t--opt adamw  --weight-decay 5e-5 \\\n    --warmup-lr 1e-7 --warmup-epochs 10  \\\n    --lr 1e-2 --min-lr 1e-8 \\\n    --drop-path 0 --img-size 224 \\\n\t--mixup 0 --cutmix 0 --smoothing 0 \\\n\t--output  output/vit_base_patch16_224_in21k/vtab/svhn/ssf \\\n\t--amp --tuning-mode ssf --pretrained  "
  },
  {
    "path": "utils/__init__.py",
    "content": "from .utils import load_for_transfer_learning, load_for_probing\nfrom .scaler import ApexScaler_SAM\nfrom .mce_utils import *\n"
  },
  {
    "path": "utils/imagenet_a.py",
    "content": "thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2,\n                     14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 25: -1, 26: -1,\n                     27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1, 37: 9, 38: -1,\n                     39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 49: -1, 50: 13,\n                     51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 61: -1, 62: -1,\n                     63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 73: -1, 74: -1,\n                     75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 85: -1, 86: -1,\n                     87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 97: 23, 98: -1,\n                     99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 108: 27, 109: -1,\n                     110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 118: -1, 119: -1, 120: -1,\n                     121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 128: -1, 129: -1, 130: 32, 131: -1,\n                     132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 138: -1, 139: -1, 140: -1, 141: -1, 142: -1,\n                     143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1,\n                     154: -1, 155: -1, 156: -1, 157: -1, 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1,\n                     165: -1, 166: -1, 167: -1, 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1,\n                     176: -1, 177: -1, 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1,\n                     187: -1, 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1,\n                     198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 208: -1,\n                     209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 218: -1, 219: -1,\n                     220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 228: -1, 229: -1, 230: -1,\n                     231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 238: -1, 239: -1, 240: -1, 241: -1,\n                     242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 248: -1, 249: -1, 250: -1, 251: -1, 252: -1,\n                     253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1,\n                     264: -1, 265: -1, 266: -1, 267: -1, 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1,\n                     275: -1, 276: -1, 277: 42, 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1,\n                     286: -1, 287: 44, 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1,\n                     297: -1, 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50,\n                     308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 318: -1,\n                     319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 328: -1, 329: -1,\n                     330: 64, 331: -1, 332: -1, 
333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 338: -1, 339: -1, 340: -1,\n                     341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 348: -1, 349: -1, 350: -1, 351: -1,\n                     352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 358: -1, 359: -1, 360: -1, 361: 69, 362: -1,\n                     363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1,\n                     374: -1, 375: -1, 376: -1, 377: -1, 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1,\n                     385: -1, 386: 73, 387: -1, 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1,\n                     396: -1, 397: 74, 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1,\n                     407: 79, 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82,\n                     418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 428: 85,\n                     429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87, 438: 88, 439: -1,\n                     440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 448: -1, 449: -1, 450: -1,\n                     451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91, 458: -1, 459: -1, 460: -1, 461: 92,\n                     462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 468: -1, 469: -1, 470: 94, 471: -1, 472: 95,\n                     473: -1, 474: -1, 475: -1, 476: -1, 477: -1, 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96,\n                     484: -1, 485: -1, 486: 97, 487: -1, 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1,\n                     495: -1, 496: 100, 497: -1, 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1,\n                     506: -1, 507: -1, 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1,\n                     516: 102, 517: -1, 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1,\n                     527: -1, 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1,\n                     537: -1, 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1,\n                     547: -1, 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1,\n                     557: 110, 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1,\n                     567: -1, 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1,\n                     577: -1, 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1,\n                     588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 598: -1,\n                     599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 608: -1,\n                     609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 618: -1,\n                     619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 628: -1,\n                     629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 638: -1, 639: -1,\n                     640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 648: -1, 649: -1,\n                     650: -1, 
651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1, 658: 129, 659: -1, 660: -1,\n                     661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 668: 130, 669: -1, 670: -1, 671: -1,\n                     672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131, 678: -1, 679: -1, 680: -1, 681: -1,\n                     682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134, 688: -1, 689: -1, 690: -1, 691: -1,\n                     692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 698: -1, 699: -1, 700: -1, 701: 135, 702: -1,\n                     703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1,\n                     714: -1, 715: -1, 716: -1, 717: -1, 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1,\n                     725: -1, 726: -1, 727: -1, 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1,\n                     736: 138, 737: -1, 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1,\n                     746: 139, 747: -1, 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1,\n                     756: -1, 757: -1, 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144,\n                     766: -1, 767: -1, 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1,\n                     776: 148, 777: -1, 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1,\n                     786: 151, 787: -1, 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1,\n                     796: -1, 797: 153, 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1,\n                     806: -1, 807: -1, 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158,\n                     816: -1, 817: -1, 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1,\n                     826: -1, 827: -1, 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163,\n                     836: -1, 837: -1, 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165,\n                     846: -1, 847: 166, 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1,\n                     856: -1, 857: -1, 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1,\n                     866: -1, 867: -1, 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1,\n                     877: -1, 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1,\n                     887: -1, 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1,\n                     897: 175, 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1,\n                     907: 177, 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1,\n                     917: -1, 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1,\n                     928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183,\n                     938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186,\n                     948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190,\n                     958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: 
-1, 966: -1, 967: -1, 968: -1,\n                     969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 978: -1,\n                     979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 988: 199,\n                     989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 998: -1, 999: -1}\nindices_in_1k = [k for k in thousand_k_to_200 if thousand_k_to_200[k] != -1]\n"
  },
  {
    "path": "utils/imagenet_r.py",
    "content": "all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 
'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 
'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 
'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141']\n\nimagenet_r_wnids = {'n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 
'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677'}\n\nimagenet_r_mask = [wnid in imagenet_r_wnids for wnid in all_wnids]\n# imagenet_r_indices = [i for i in range(1000) if imagenet_r_mask[i] is True]\n# [1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107, 113, 122, 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199, 203, 207, 208, 219, 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277, 281, 288, 289, 291, 292, 293, 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347, 353, 355, 361, 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447, 448, 457, 462, 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609, 613, 617, 621, 629, 637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, 866, 875, 883, 889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 983, 988]\n\nimagenet_o_wnids = ['n01443537', 'n01704323', 'n01770081', 'n01784675', 'n01819313', 'n01820546', 'n01910747', 'n01917289', 'n01968897', 'n02074367', 'n02317335', 'n02319095', 'n02395406', 'n02454379', 'n02606052', 'n02655020', 'n02666196', 'n02672831', 'n02730930', 'n02777292', 'n02783161', 'n02786058', 'n02787622', 'n02791270', 'n02808304', 'n02817516', 'n02841315', 'n02865351', 'n02877765', 'n02892767', 'n02906734', 'n02910353', 'n02916936', 'n02948072', 'n02965783', 'n03000134', 'n03000684', 'n03017168', 'n03026506', 'n03032252', 'n03075370', 'n03109150', 'n03126707', 'n03134739', 'n03160309', 'n03196217', 'n03207743', 'n03218198', 'n03223299', 'n03240683', 'n03271574', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03344393', 'n03347037', 'n03372029', 'n03376595', 'n03388043', 'n03388183', 'n03400231', 'n03445777', 'n03457902', 'n03467068', 'n03482405', 'n03483316', 'n03494278', 'n03530642', 'n03544143', 'n03584829', 'n03590841', 'n03598930', 
'n03602883', 'n03649909', 'n03661043', 'n03666591', 'n03676483', 'n03692522', 'n03706229', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03742115', 'n03786901', 'n03788365', 'n03794056', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03840681', 'n03843555', 'n03854065', 'n03857828', 'n03868863', 'n03874293', 'n03884397', 'n03891251', 'n03908714', 'n03920288', 'n03929660', 'n03930313', 'n03937543', 'n03942813', 'n03944341', 'n03961711', 'n03970156', 'n03982430', 'n03991062', 'n03995372', 'n03998194', 'n04005630', 'n04023962', 'n04033901', 'n04040759', 'n04067472', 'n04074963', 'n04116512', 'n04118776', 'n04125021', 'n04127249', 'n04131690', 'n04141975', 'n04153751', 'n04154565', 'n04201297', 'n04204347', 'n04209133', 'n04209239', 'n04228054', 'n04235860', 'n04243546', 'n04252077', 'n04254120', 'n04258138', 'n04265275', 'n04270147', 'n04275548', 'n04330267', 'n04332243', 'n04336792', 'n04347754', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04429376', 'n04435653', 'n04442312', 'n04482393', 'n04501370', 'n04507155', 'n04525305', 'n04542943', 'n04554684', 'n04557648', 'n04562935', 'n04579432', 'n04591157', 'n04597913', 'n04599235', 'n06785654', 'n06874185', 'n07615774', 'n07693725', 'n07695742', 'n07697537', 'n07711569', 'n07714990', 'n07715103', 'n07716358', 'n07717410', 'n07718472', 'n07720875', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753275', 'n07753592', 'n07754684', 'n07768694', 'n07836838', 'n07871810', 'n07873807', 'n07880968', 'n09229709', 'n09472597', 'n12144580', 'n12267677', 'n13052670']\n\nimagenet_o_mask = [wnid in set(imagenet_o_wnids) for wnid in all_wnids]\n\n"
  },
  {
    "path": "utils/mce_utils.py",
    "content": "# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.\n#\n# This work is made available under the Nvidia Source Code License-NC.\n# To view a copy of this license, visit\n# https://github.com/NVlabs/FAN/blob/main/LICENSE\n\n# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\n\"\"\"\nMisc functions, including distributed helpers.\n\nMostly copy-paste from torchvision references.\n\"\"\"\n# Modified by: Daquan\n\nimport io\nimport os\nimport time\nfrom collections import defaultdict, deque\nimport datetime\n\nimport torch\nimport torch.distributed as dist\n\ndata_loaders_names = {\n            'Brightness': 'brightness',\n            'Contrast': 'contrast',\n            'Defocus Blur': 'defocus_blur',\n            'Elastic Transform': 'elastic_transform',\n            'Fog': 'fog',\n            'Frost': 'frost',\n            'Gaussian Noise': 'gaussian_noise',\n            'Glass Blur': 'glass_blur',\n            'Impulse Noise': 'impulse_noise',\n            'JPEG Compression': 'jpeg_compression',\n            'Motion Blur': 'motion_blur',\n            'Pixelate': 'pixelate',\n            'Shot Noise': 'shot_noise',\n            'Snow': 'snow',\n            'Zoom Blur': 'zoom_blur'\n        }\n\ndef get_ce_alexnet():\n    \"\"\"Returns Corruption Error values for AlexNet\"\"\"\n\n    ce_alexnet = dict()\n    ce_alexnet['gaussian_noise'] = 0.886428\n    ce_alexnet['shot_noise'] = 0.894468\n    ce_alexnet['impulse_noise'] = 0.922640\n    ce_alexnet['defocus_blur'] = 0.819880\n    ce_alexnet['glass_blur'] = 0.826268\n    ce_alexnet['motion_blur'] = 0.785948\n    ce_alexnet['zoom_blur'] = 0.798360\n    ce_alexnet['snow'] = 0.866816\n    ce_alexnet['frost'] = 0.826572\n    ce_alexnet['fog'] = 0.819324\n    ce_alexnet['brightness'] = 0.564592\n    ce_alexnet['contrast'] = 0.853204\n    ce_alexnet['elastic_transform'] = 0.646056\n    ce_alexnet['pixelate'] = 0.717840\n    ce_alexnet['jpeg_compression'] = 0.606500\n\n    return ce_alexnet\n\ndef get_mce_from_accuracy(accuracy, error_alexnet):\n    \"\"\"Computes mean Corruption Error from accuracy\"\"\"\n    error = 100. 
- accuracy\n    ce = error / (error_alexnet * 100.)\n\n    return ce\n\nclass SmoothedValue(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n\n    def __init__(self, window_size=20, fmt=None):\n        if fmt is None:\n            fmt = \"{median:.4f} ({global_avg:.4f})\"\n        self.deque = deque(maxlen=window_size)\n        self.total = 0.0\n        self.count = 0\n        self.fmt = fmt\n\n    def update(self, value, n=1):\n        self.deque.append(value)\n        self.count += n\n        self.total += value * n\n\n    def synchronize_between_processes(self):\n        \"\"\"\n        Warning: does not synchronize the deque!\n        \"\"\"\n        if not is_dist_avail_and_initialized():\n            return\n        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')\n        dist.barrier()\n        dist.all_reduce(t)\n        t = t.tolist()\n        self.count = int(t[0])\n        self.total = t[1]\n\n    @property\n    def median(self):\n        d = torch.tensor(list(self.deque))\n        return d.median().item()\n\n    @property\n    def avg(self):\n        d = torch.tensor(list(self.deque), dtype=torch.float32)\n        return d.mean().item()\n\n    @property\n    def global_avg(self):\n        return self.total / self.count\n\n    @property\n    def max(self):\n        return max(self.deque)\n\n    @property\n    def value(self):\n        return self.deque[-1]\n\n    def __str__(self):\n        return self.fmt.format(\n            median=self.median,\n            avg=self.avg,\n            global_avg=self.global_avg,\n            max=self.max,\n            value=self.value)\n\n\nclass MetricLogger(object):\n    def __init__(self, delimiter=\"\\t\"):\n        self.meters = defaultdict(SmoothedValue)\n        self.delimiter = delimiter\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            assert isinstance(v, (float, int))\n            self.meters[k].update(v)\n\n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        if attr in self.__dict__:\n            return self.__dict__[attr]\n        raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n            type(self).__name__, attr))\n\n    def __str__(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            loss_str.append(\n                \"{}: {}\".format(name, str(meter))\n            )\n        return self.delimiter.join(loss_str)\n\n    def synchronize_between_processes(self):\n        for meter in self.meters.values():\n            meter.synchronize_between_processes()\n\n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n\n    def log_every(self, iterable, print_freq, header=None):\n        i = 0\n        if not header:\n            header = ''\n        start_time = time.time()\n        end = time.time()\n        iter_time = SmoothedValue(fmt='{avg:.4f}')\n        data_time = SmoothedValue(fmt='{avg:.4f}')\n        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'\n        log_msg = [\n            header,\n            '[{0' + space_fmt + '}/{1}]',\n            'eta: {eta}',\n            '{meters}',\n            'time: {time}',\n            'data: {data}'\n        ]\n        if torch.cuda.is_available():\n            log_msg.append('max mem: {memory:.0f}')\n        
log_msg = self.delimiter.join(log_msg)\n        MB = 1024.0 * 1024.0\n        for obj in iterable:\n            data_time.update(time.time() - end)\n            yield obj\n            iter_time.update(time.time() - end)\n            if i % print_freq == 0 or i == len(iterable) - 1:\n                eta_seconds = iter_time.global_avg * (len(iterable) - i)\n                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                if torch.cuda.is_available():\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time),\n                        memory=torch.cuda.max_memory_allocated() / MB))\n                else:\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time)))\n            i += 1\n            end = time.time()\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        print('{} Total time: {} ({:.4f} s / it)'.format(\n            header, total_time_str, total_time / len(iterable)))\n\n\ndef _load_checkpoint_for_ema(model_ema, checkpoint):\n    \"\"\"\n    Workaround for ModelEma._load_checkpoint to accept an already-loaded object\n    \"\"\"\n    mem_file = io.BytesIO()\n    torch.save(checkpoint, mem_file)\n    mem_file.seek(0)\n    model_ema._load_checkpoint(mem_file)\n\n\ndef setup_for_distributed(is_master):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    import builtins as __builtin__\n    builtin_print = __builtin__.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        if is_master or force:\n            builtin_print(*args, **kwargs)\n\n    __builtin__.print = print\n\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef get_rank():\n    if not is_dist_avail_and_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_main_process():\n    return get_rank() == 0\n\n\ndef save_on_master(*args, **kwargs):\n    if is_main_process():\n        torch.save(*args, **kwargs)\n"
  },
  {
    "path": "utils/scaler.py",
    "content": "import torch\nfrom timm.utils import ApexScaler, NativeScaler\ntry:\n    from apex import amp\n    has_apex = True\nexcept ImportError:\n    amp = None\n    has_apex = False\n\nclass ApexScaler_SAM(ApexScaler):\n\n    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step=0, rho=0.05):\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward(create_graph=create_graph)\n        if step==0 or step==2:\n            if clip_grad is not None:\n                dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)\n            optimizer.step()\n        elif step==1:\n            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), rho, norm_type=2.0)\n            optimizer.step()\n"
  },
  {
    "path": "utils/utils.py",
    "content": "# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.\n#\n# This work is made available under the Nvidia Source Code License-NC.\n# To view a copy of this license, visit\n# https://github.com/NVlabs/FAN/blob/main/LICENSE\n\n# Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd.\n#\n# This source code is licensed under the Clear BSD License\n# LICENSE file in the root directory of this file\n# All rights reserved.\n\n# Modified by: Daquan Zhou\n\n'''\n- resize_pos_embed: resize position embedding\n- load_for_transfer_learning: load pretrained paramters to model in transfer learning\n- get_mean_and_std: calculate the mean and std value of dataset.\n- msr_init: net parameter initialization.\n- progress_bar: progress bar mimic xlua.progress.\n'''\n\nimport os\nimport sys\nimport time\nimport torch\nimport math\n\nimport torch.nn as nn\nimport torch.nn.init as init\nimport logging\nimport os\nfrom collections import OrderedDict\nimport torch.nn.functional as F\n\n_logger = logging.getLogger(__name__)\n\ndef resize_pos_embed(posemb, posemb_new): # example: 224:(14x14+1)-> 384: (24x24+1)\n    # Rescale the grid of position embeddings when loading from state_dict. Adapted from\n    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224\n    ntok_new = posemb_new.shape[1]\n    if True:\n        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]  # posemb_tok is for cls token, posemb_grid for the following tokens\n        ntok_new -= 1\n    else:\n        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]\n    gs_old = int(math.sqrt(len(posemb_grid)))     # 14\n    gs_new = int(math.sqrt(ntok_new))             # 24\n    _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)  # [1, 196, dim]->[1, 14, 14, dim]->[1, dim, 14, 14]\n    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24]\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)   # [1, dim, 24, 24] -> [1, 24*24, dim]\n    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)   # [1, 24*24+1, dim]\n    return posemb\n\ndef resize_pos_embed_cait(posemb, posemb_new): # example: 224:(14x14+1)-> 384: (24x24+1)\n    # Rescale the grid of position embeddings when loading from state_dict. Adapted from\n    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224\n    ntok_new = posemb_new.shape[1]\n    posemb_grid = posemb[0]\n    gs_old = int(math.sqrt(len(posemb_grid)))     # 14\n    gs_new = int(math.sqrt(ntok_new))             # 24\n    _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)  # [1, 196, dim]->[1, 14, 14, dim]->[1, dim, 14, 14]\n    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24]\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)   # [1, dim, 24, 24] -> [1, 24*24, dim]\n    return posemb_grid\n\n\ndef resize_pos_embed_nocls(posemb, posemb_new): # example: 224:(14x14+1)-> 384: (24x24+1)\n    # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from\n    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224\n    gs_old = posemb.shape[1]     # 14\n    gs_new = posemb_new.shape[1]             # 24\n    _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)\n    posemb_grid = posemb\n    posemb_grid = posemb_grid.permute(0, 3, 1, 2)  # [1, 14, 14, dim]->[1, dim, 14, 14]\n    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24]\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1)   # [1, dim, 24, 24]->[1, 24, 24, dim]\n    return posemb_grid\n\n\ndef load_state_dict(checkpoint_path,model, use_ema=False, num_classes=1000, no_pos_embed=False):\n    if checkpoint_path and os.path.isfile(checkpoint_path):\n        checkpoint = torch.load(checkpoint_path, map_location='cpu')\n        state_dict_key = 'state_dict'\n        if isinstance(checkpoint, dict):\n            if use_ema and 'state_dict_ema' in checkpoint:\n                state_dict_key = 'state_dict_ema'\n        if state_dict_key and state_dict_key in checkpoint:\n            new_state_dict = OrderedDict()\n            for k, v in checkpoint[state_dict_key].items():\n                # strip `module.` prefix\n                name = k[7:] if k.startswith('module') else k\n                new_state_dict[name] = v\n            state_dict = new_state_dict\n        else:\n            state_dict = checkpoint\n        _logger.info(\"Loaded {} from checkpoint '{}'\".format(state_dict_key, checkpoint_path))\n        if num_classes != 1000:\n            # completely discard fully connected for all other differences between pretrained and created model\n            del state_dict['head' + '.weight']\n            del state_dict['head' + '.bias']\n            old_aux_head_weight = state_dict.pop('aux_head.weight', None)\n            old_aux_head_bias = state_dict.pop('aux_head.bias', None)\n        if not no_pos_embed:\n            old_posemb = state_dict['pos_embed']\n            if model.pos_embed.shape != old_posemb.shape:  # need resize the position embedding by interpolate\n                if len(old_posemb.shape)==3:\n                    if int(math.sqrt(old_posemb.shape[1]))**2==old_posemb.shape[1]:\n                        new_posemb = resize_pos_embed_cait(old_posemb, model.pos_embed)\n                    else:\n                        new_posemb = resize_pos_embed(old_posemb, model.pos_embed)\n                elif len(old_posemb.shape)==4:\n                    new_posemb = resize_pos_embed_nocls(old_posemb, model.pos_embed)\n                state_dict['pos_embed'] = new_posemb\n\n        return state_dict\n    else:\n        _logger.error(\"No checkpoint found at '{}'\".format(checkpoint_path))\n        raise FileNotFoundError()\n\n\ndef load_for_transfer_learning(model, checkpoint_path, use_ema=False, strict=True, num_classes=1000):\n    state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes)\n    model.load_state_dict(state_dict, strict=strict)\n\ndef load_for_probing(model, checkpoint_path, use_ema=False, strict=False, num_classes=19167):\n    state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes=19167, no_pos_embed=True)\n    info=model.load_state_dict(state_dict, strict=strict)\n    print(info)\n\ndef get_mean_and_std(dataset):\n    '''Compute the mean and std value of dataset.'''\n    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, 
num_workers=2)\n    mean = torch.zeros(3)\n    std = torch.zeros(3)\n    print('==> Computing mean and std..')\n    for inputs, targets in dataloader:\n        for i in range(3):\n            mean[i] += inputs[:,i,:,:].mean()\n            std[i] += inputs[:,i,:,:].std()\n    mean.div_(len(dataset))\n    std.div_(len(dataset))\n    return mean, std\n\ndef init_params(net):\n    '''Init layer parameters.'''\n    for m in net.modules():\n        if isinstance(m, nn.Conv2d):\n            init.kaiming_normal_(m.weight, mode='fan_out')\n            if m.bias is not None:\n                init.constant_(m.bias, 0)\n        elif isinstance(m, nn.BatchNorm2d):\n            init.constant_(m.weight, 1)\n            init.constant_(m.bias, 0)\n        elif isinstance(m, nn.Linear):\n            init.normal_(m.weight, std=1e-3)\n            if m.bias is not None:\n                init.constant_(m.bias, 0)\n\n"
  },
  {
    "path": "validate_ood.py",
    "content": "#!/usr/bin/env python3\n# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.\n#\n# This work is made available under the Nvidia Source Code License-NC.\n# To view a copy of this license, visit\n# https://github.com/NVlabs/FAN/blob/main/LICENSE\n\n\"\"\" ImageNet Validation Script\nThis is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained\nmodels or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes\ncanonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.\nHacked together by Ross Wightman (https://github.com/rwightman)\n\"\"\"\n\nimport argparse\nimport errno\nimport os\nimport csv\nimport glob\nimport time\nimport logging\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nfrom collections import OrderedDict\nfrom contextlib import suppress\n\nfrom timm.models import create_model, apply_test_time_pool, resume_checkpoint, load_checkpoint, is_model, list_models\nfrom timm.data import resolve_data_config, RealLabelsImagenet\nfrom timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy\n\nimport numpy as np\n\n\n\nfrom utils.imagenet_a import indices_in_1k\nfrom utils.imagenet_r import imagenet_r_mask\nfrom utils.mce_utils import get_ce_alexnet, get_mce_from_accuracy\n\n\nfrom data import create_loader, create_dataset\nfrom optim_factory import create_optimizer_v2, optimizer_kwargs\nfrom models import vision_transformer, swin_transformer, convnext\n\n\n\nhas_apex = False\ntry:\n    from apex import amp\n    has_apex = True\nexcept ImportError:\n    pass\n\nhas_native_amp = False\ntry:\n    if getattr(torch.cuda.amp, 'autocast') is not None:\n        has_native_amp = True\nexcept AttributeError:\n    pass\n\ntorch.backends.cudnn.benchmark = True\n_logger = logging.getLogger('validate')\n\n\nparser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')\nparser.add_argument('data', metavar='DIR',\n                    help='path to dataset')\nparser.add_argument('--dataset', '-d', metavar='NAME', default='',\n                    help='dataset type (default: ImageFolder/ImageTar if empty)')\nparser.add_argument('--split', metavar='NAME', default='validation',\n                    help='dataset split (default: validation)')\nparser.add_argument('--model', '-m', metavar='NAME', default='dpn92',\n                    help='model architecture (default: dpn92)')\nparser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                    help='number of data loading workers (default: 2)')\nparser.add_argument('-b', '--batch-size', default=256, type=int,\n                    metavar='N', help='mini-batch size (default: 256)')\nparser.add_argument('--img-size', default=None, type=int,\n                    metavar='N', help='Input image dimension, uses model default if empty')\nparser.add_argument('--input-size', default=None, nargs=3, type=int,\n                    metavar='N N N', help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty')\nparser.add_argument('--crop-pct', default=None, type=float,\n                    metavar='N', help='Input image center crop pct')\nparser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',\n                    help='Override mean pixel value of dataset')\nparser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',\n                    help='Override std deviation of dataset')\nparser.add_argument('--interpolation', default='', type=str, metavar='NAME',\n                    help='Image resize interpolation type (overrides model)')\nparser.add_argument('--num-classes', type=int, default=None,\n                    help='Number of classes in dataset')\nparser.add_argument('--class-map', default='', type=str, metavar='FILENAME',\n                    help='path to class to idx mapping file (default: \"\")')\nparser.add_argument('--gp', default=None, type=str, metavar='POOL',\n                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')\nparser.add_argument('--log-freq', default=50, type=int,\n                    metavar='N', help='batch logging frequency (default: 50)')\nparser.add_argument('--checkpoint', default='', type=str, metavar='PATH',\n                    help='path to latest checkpoint (default: none)')\nparser.add_argument('--pretrained', dest='pretrained', action='store_true',\n                    help='use pre-trained model')\nparser.add_argument('--num-gpu', type=int, default=1,\n                    help='Number of GPUs to use')\nparser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',\n                    help='disable test time pool')\nparser.add_argument('--no-prefetcher', action='store_true', default=False,\n                    help='disable fast prefetcher')\nparser.add_argument('--pin-mem', action='store_true', default=False,\n                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\nparser.add_argument('--channels-last', action='store_true', default=False,\n                    help='Use channels_last memory layout')\nparser.add_argument('--amp', action='store_true', default=False,\n                    help='Use AMP mixed precision. 
Defaults to Apex, fallback to native Torch AMP.')\nparser.add_argument('--apex-amp', action='store_true', default=False,\n                    help='Use NVIDIA Apex AMP mixed precision')\nparser.add_argument('--native-amp', action='store_true', default=False,\n                    help='Use Native Torch AMP mixed precision')\nparser.add_argument('--tf-preprocessing', action='store_true', default=False,\n                    help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')\nparser.add_argument('--use-ema', dest='use_ema', action='store_true',\n                    help='use ema version of weights if present')\nparser.add_argument('--torchscript', dest='torchscript', action='store_true',\n                    help='convert model torchscript for inference')\nparser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',\n                    help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')\nparser.add_argument('--results-file', default='', type=str, metavar='FILENAME',\n                    help='Output csv file for validation results (summary)')\nparser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',\n                    help='Real labels JSON file for imagenet evaluation')\nparser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',\n                    help='Valid label indices txt file for validation of partial label space')\n\n# finetuning\nparser.add_argument('--tuning-mode', default=None, type=str,\n                    help='Method of fine-tuning (default: None)')\nparser.add_argument('--num-vpt', default=None, type=int,\n                    help='The number of prompts in VPT')\n\n\nparser.add_argument('--imagenet_a', action='store_true', default=False,\n                    help='restrict the 1000-class outputs to the 200 ImageNet-A classes')\nparser.add_argument('--imagenet_r', action='store_true', default=False,\n                    help='restrict the 1000-class outputs to the ImageNet-R class indices')\nparser.add_argument('--imagenet_c', action='store_true', default=False,\n                    help='use corrupted dataset for evaluation')\n\ndef validate(args):\n    args.pretrained = args.pretrained or not args.checkpoint\n    args.prefetcher = not args.no_prefetcher\n    amp_autocast = suppress  # do nothing\n    if args.amp:\n        if has_native_amp:\n            args.native_amp = True\n        elif has_apex:\n            args.apex_amp = True\n        else:\n            _logger.warning(\"Neither APEX nor Native Torch AMP is available.\")\n    assert not args.apex_amp or not args.native_amp, \"Only one AMP mode should be set.\"\n    if args.native_amp:\n        amp_autocast = torch.cuda.amp.autocast\n        _logger.info('Validating in mixed precision with native PyTorch AMP.')\n    elif args.apex_amp:\n        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')\n    else:\n        _logger.info('Validating in float32. 
AMP not enabled.')\n\n    if args.legacy_jit:\n        set_jit_legacy()\n\n\n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        num_classes=args.num_classes,\n        global_pool=args.gp,\n        scriptable=args.torchscript,\n        tuning_mode=args.tuning_mode,\n        num_vpt=args.num_vpt\n        )\n\n\n    if args.num_classes is None:\n        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'\n        args.num_classes = model.num_classes\n\n    if args.checkpoint:\n        resume_epoch = resume_checkpoint(\n            model, args.checkpoint\n            )\n\n    param_count = sum([m.numel() for m in model.parameters()])\n    _logger.info('Model %s created, param count: %d' % (args.model, param_count))\n\n    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)\n    test_time_pool = False\n    if not args.no_test_pool:\n        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)\n\n    if args.torchscript:\n        torch.jit.optimized_execution(True)\n        model = torch.jit.script(model)\n\n    model = model.cuda()\n    if args.apex_amp:\n        model = amp.initialize(model, opt_level='O1')\n\n    if args.channels_last:\n        model = model.to(memory_format=torch.channels_last)\n\n    if args.num_gpu > 1:\n        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))\n\n    criterion = nn.CrossEntropyLoss().cuda()\n\n    print(args.data)\n    dataset = create_dataset(\n        root=args.data, name=args.dataset, split=args.split,\n        load_bytes=args.tf_preprocessing, class_map=args.class_map)\n\n    if args.valid_labels:\n        with open(args.valid_labels, 'r') as f:\n            valid_labels = {int(line.rstrip()) for line in f}\n            valid_labels = [i in valid_labels for i in range(args.num_classes)]\n    else:\n        valid_labels = None\n\n    if args.real_labels:\n        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)\n    else:\n        real_labels = None\n\n    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']\n    loader = create_loader(\n        dataset,\n        input_size=data_config['input_size'],\n        batch_size=args.batch_size,\n        use_prefetcher=args.prefetcher,\n        interpolation=data_config['interpolation'],\n        mean=data_config['mean'],\n        std=data_config['std'],\n        num_workers=args.workers,\n        crop_pct=crop_pct,\n        pin_memory=args.pin_mem,\n        tf_preprocessing=args.tf_preprocessing)\n\n    batch_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    model.eval()\n    with torch.no_grad():\n        input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()\n        if args.channels_last:\n            input = input.contiguous(memory_format=torch.channels_last)\n        end = time.time()\n        for batch_idx, (input, target) in enumerate(loader):\n            if args.no_prefetcher:\n                target = target.cuda()\n                input = input.cuda()\n            if args.channels_last:\n                input = input.contiguous(memory_format=torch.channels_last)\n\n            # compute output\n            with amp_autocast():\n                output = model(input)\n                if args.imagenet_a:\n                    output = output[:, indices_in_1k]\n                if 
args.imagenet_r:\n                    output = output[:, imagenet_r_mask]\n            if isinstance(output, (tuple, list)):\n                output = output[0]\n            if valid_labels is not None:\n                output = output[:, valid_labels]\n            loss = criterion(output, target)\n\n            if real_labels is not None:\n                real_labels.add_result(output)\n\n            # measure accuracy and record loss\n            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))\n            losses.update(loss.item(), input.size(0))\n            top1.update(acc1.item(), input.size(0))\n            top5.update(acc5.item(), input.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if batch_idx % args.log_freq == 0:\n                _logger.info(\n                    'Test: [{0:>4d}/{1}]  '\n                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '\n                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '\n                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '\n                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(\n                        batch_idx, len(loader), batch_time=batch_time,\n                        rate_avg=input.size(0) / batch_time.avg,\n                        loss=losses, top1=top1, top5=top5))\n\n    if real_labels is not None:\n        # real labels mode replaces topk values at the end\n        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)\n    else:\n        top1a, top5a = top1.avg, top5.avg\n    results = OrderedDict(\n        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),\n        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),\n        param_count=round(param_count / 1e6, 2),\n        img_size=data_config['input_size'][-1],\n        cropt_pct=crop_pct,\n        interpolation=data_config['interpolation'])\n\n    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(\n       results['top1'], results['top1_err'], results['top5'], results['top5_err']))\n\n    return results\n\n\ndef main():\n    setup_default_logging()\n    args = parser.parse_args()\n    if not args.imagenet_c:\n        if args.imagenet_a or args.imagenet_r:\n            validate(args)\n        else:\n            print('Please specify an OOD dataset.')\n            return\n    else:\n        results_file = args.results_file or './results-all.csv'\n        os.makedirs(results_file, exist_ok=True)\n        blur_list = ['gaussian_blur', 'motion_blur', 'glass_blur', 'defocus_blur']\n        noise_list = ['gaussian_noise', 'shot_noise', 'speckle_noise', 'impulse_noise']\n        digital_list = ['contrast', 'jpeg_compression', 'saturate', 'pixelate']\n        weather_list = ['snow', 'fog', 'frost', 'spatter', 'brightness']\n        extra = ['zoom_blur', 'elastic_transform']\n        name_list = noise_list + extra + blur_list + digital_list + weather_list\n        ce_alexnet = get_ce_alexnet()\n        mCE = 0\n        counter = 0\n        average_acc = {}\n        base_dir = args.data\n        for noise_name in name_list:\n            res_sum = 0\n            root = base_dir + noise_name + '/'\n            results = []\n            for i in range(0, 5):\n                args.data = root + str(i+1)\n                print('validating dir:', args.data)\n                res = validate(args)\n                results.append(res['top1'])\n                res_sum += 
res['top1']\n                if noise_name in ce_alexnet.keys():\n                    CE = get_mce_from_accuracy(res['top1'], ce_alexnet[noise_name])\n                    mCE += CE\n                    counter += 1\n            results.append(res_sum/(i+1))\n            average_acc[noise_name] = res_sum/(i+1)\n            # results_file is created as a directory above, so join paths instead of concatenating strings\n            np.savetxt(os.path.join(results_file, noise_name + '_' + '%.2f' % (res_sum/(i+1)) + '.csv'), results)\n            print('average score is:', res_sum / (i+1))\n            print('current mCE is: ', mCE/counter)\n        np.savetxt(os.path.join(results_file, 'mCE' + '_' + '%.2f' % (mCE/counter) + '.csv'), results)\n        print('all average score is:', average_acc)\n        print('mCE is: ', mCE/counter)\n\n\ndef write_results(results_file, results):\n    with open(results_file, mode='w') as cf:\n        dw = csv.DictWriter(cf, fieldnames=results[0].keys())\n        dw.writeheader()\n        for r in results:\n            dw.writerow(r)\n        cf.flush()\n\nif __name__ == '__main__':\n    main()"
  }
]