Full Code of dongzelian/SSF for AI

Repository: dongzelian/SSF
Branch: main
Commit: e94e0e704a4e
Files: 75
Total size: 315.5 KB

Directory structure:
SSF/

├── LICENSE
├── README.md
├── data/
│   ├── __init__.py
│   ├── cub2011.py
│   ├── dataset_factory.py
│   ├── loader.py
│   ├── nabirds.py
│   ├── stanford_dogs.py
│   ├── transforms_factory.py
│   └── vtab.py
├── log/
│   ├── README.md
│   └── cifar100.csv
├── models/
│   ├── as_mlp.py
│   ├── convnext.py
│   ├── swin_transformer.py
│   └── vision_transformer.py
├── optim_factory.py
├── requirements.txt
├── train.py
├── train_scripts/
│   ├── asmlp/
│   │   └── cifar_100/
│   │       ├── train_full.sh
│   │       ├── train_linear_probe.sh
│   │       └── train_ssf.sh
│   ├── convnext/
│   │   ├── cifar_100/
│   │   │   ├── train_full.sh
│   │   │   ├── train_linear_probe.sh
│   │   │   └── train_ssf.sh
│   │   └── imagenet_1k/
│   │       ├── train_full.sh
│   │       ├── train_linear_probe.sh
│   │       └── train_ssf.sh
│   ├── swin/
│   │   ├── cifar_100/
│   │   │   ├── train_full.sh
│   │   │   ├── train_linear_probe.sh
│   │   │   └── train_ssf.sh
│   │   └── imagenet_1k/
│   │       ├── train_full.sh
│   │       ├── train_linear_probe.sh
│   │       └── train_ssf.sh
│   └── vit/
│       ├── cifar_100/
│       │   ├── eval_ssf.sh
│       │   ├── train_full.sh
│       │   ├── train_linear_probe.sh
│       │   └── train_ssf.sh
│       ├── fgvc/
│       │   ├── cub_2011/
│       │   │   └── train_ssf.sh
│       │   ├── nabirds/
│       │   │   └── train_ssf.sh
│       │   ├── oxford_flowers/
│       │   │   └── train_ssf.sh
│       │   ├── stanford_cars/
│       │   │   └── train_ssf.sh
│       │   └── stanford_dogs/
│       │       └── train_ssf.sh
│       ├── imagenet_1k/
│       │   ├── train_full.sh
│       │   ├── train_linear_probe.sh
│       │   └── train_ssf.sh
│       ├── imagenet_a/
│       │   └── eval_ssf.sh
│       ├── imagenet_c/
│       │   └── eval_ssf.sh
│       ├── imagenet_r/
│       │   └── eval_ssf.sh
│       └── vtab/
│           ├── caltech101/
│           │   └── train_ssf.sh
│           ├── cifar_100/
│           │   └── train_ssf.sh
│           ├── clevr_count/
│           │   └── train_ssf.sh
│           ├── clevr_dist/
│           │   └── train_ssf.sh
│           ├── diabetic_retinopathy/
│           │   └── train_ssf.sh
│           ├── dmlab/
│           │   └── train_ssf.sh
│           ├── dsprites_loc/
│           │   └── train_ssf.sh
│           ├── dsprites_ori/
│           │   └── train_ssf.sh
│           ├── dtd/
│           │   └── train_ssf.sh
│           ├── eurosat/
│           │   └── train_ssf.sh
│           ├── flowers102/
│           │   └── train_ssf.sh
│           ├── kitti/
│           │   └── train_ssf.sh
│           ├── patch_camelyon/
│           │   └── train_ssf.sh
│           ├── pets/
│           │   └── train_ssf.sh
│           ├── resisc45/
│           │   └── train_ssf.sh
│           ├── smallnorb_azi/
│           │   └── train_ssf.sh
│           ├── smallnorb_ele/
│           │   └── train_ssf.sh
│           ├── sun397/
│           │   └── train_ssf.sh
│           └── svhn/
│               └── train_ssf.sh
├── utils/
│   ├── __init__.py
│   ├── imagenet_a.py
│   ├── imagenet_r.py
│   ├── mce_utils.py
│   ├── scaler.py
│   └── utils.py
└── validate_ood.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2022 dongzelian

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# SSF for Efficient Model Tuning

This repo is the official implementation of our NeurIPS2022 paper "Scaling & Shifting Your Features: A New Baseline for Efficient Model Tuning" ([arXiv](https://arxiv.org/abs/2210.08823)). 
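
As a quick, hedged illustration of the core idea (not the exact module in `models/vision_transformer.py`), SSF re-modulates the frozen backbone's features with a learnable per-channel scale and shift. The hypothetical `SSFScaleShift` module below is a minimal sketch; its name, initialization, and placement are illustrative only:

```python
import torch
import torch.nn as nn


class SSFScaleShift(nn.Module):
    """Hypothetical minimal module sketching the SSF idea: re-modulate frozen
    pre-trained features with a learnable per-channel scale and shift.
    The name and initialization here are illustrative, not the repo's exact code."""

    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))   # multiplicative factor per channel
        self.shift = nn.Parameter(torch.zeros(dim))  # additive offset per channel

    def forward(self, x):
        # x: (..., dim) features produced by a frozen backbone operation
        return x * self.scale + self.shift
```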




## Usage

### Install

- Clone this repo:

```bash
git clone https://github.com/dongzelian/SSF.git
cd SSF
```

- Create a conda virtual environment and activate it:

```bash
conda create -n ssf python=3.7 -y
conda activate ssf
```

- Install `CUDA==10.1` with `cudnn7` following
  the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
- Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`:

```bash
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
```

- Install `timm==0.6.5`:

```bash
pip install timm==0.6.5
```


- Install other requirements:

```bash
pip install -r requirements.txt
```


### Data preparation

- FGVC & vtab-1k

You can follow [VPT](https://github.com/KMnP/vpt) to download them. 

Since the original [vtab dataset](https://github.com/google-research/task_adaptation/tree/master/task_adaptation/data) is processed with TensorFlow scripts and the processing of some datasets is tricky, we also upload the extracted vtab-1k dataset to [OneDrive](https://shanghaitecheducn-my.sharepoint.com/:f:/g/personal/liandz_shanghaitech_edu_cn/EnV6eYPVCPZKhbqi-WSJIO8BOcyQwDwRk6dAThqonQ1Ycw?e=J884Fp) for your convenience. You can download it from there and then use it directly with our [vtab.py](https://github.com/dongzelian/SSF/blob/main/data/vtab.py). (Note that the license remains with the original [vtab dataset](https://github.com/google-research/task_adaptation/tree/master/task_adaptation/data).)
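
As a minimal usage sketch (the dataset root below is a placeholder), each extracted vtab-1k task folder is expected to contain `train800val200.txt` and `test.txt`, which the `VTAB` class in `data/vtab.py` reads directly:

```python
# Hypothetical usage of the VTAB class from data/vtab.py; the root path is a
# placeholder pointing at one extracted vtab-1k task folder.
from torchvision import transforms
from data.vtab import VTAB

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_set = VTAB(root='/path/to/vtab-1k/cifar100', train=True, transform=transform)
test_set = VTAB(root='/path/to/vtab-1k/cifar100', train=False, transform=transform)
print(len(train_set), len(test_set))
```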



- CIFAR-100
```bash
wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
```
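
As a hedged sketch of how the downloaded archive is consumed, the `torch/` prefix in `data/dataset_factory.py` routes to torchvision's `CIFAR100`; the root path below is a placeholder, and `download=True` simply re-uses the files if they are already present:

```python
# Hypothetical sketch: 'torch/cifar100' maps to torchvision.datasets.CIFAR100
# via data/dataset_factory.py. Point root at the directory containing
# (or meant to contain) cifar-100-python/.
from data import create_dataset

train_ds = create_dataset('torch/cifar100', root='/path/to/cifar100',
                          split='train', download=True)
val_ds = create_dataset('torch/cifar100', root='/path/to/cifar100', split='validation')
print(len(train_ds), len(val_ds))
```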

- For ImageNet-1K, download it from http://image-net.org/, and move validation images to labeled sub-folders. The file structure should look like:
  ```bash
  $ tree data
  imagenet
  ├── train
  │   ├── class1
  │   │   ├── img1.jpeg
  │   │   ├── img2.jpeg
  │   │   └── ...
  │   ├── class2
  │   │   ├── img3.jpeg
  │   │   └── ...
  │   └── ...
  └── val
      ├── class1
      │   ├── img4.jpeg
      │   ├── img5.jpeg
      │   └── ...
      ├── class2
      │   ├── img6.jpeg
      │   └── ...
      └── ...
 
  ```

- Robustness & OOD datasets

Prepare [ImageNet-A](https://github.com/hendrycks/natural-adv-examples), [ImageNet-R](https://github.com/hendrycks/imagenet-r) and [ImageNet-C](https://zenodo.org/record/2235448#.Y04cBOxByFw) for evaluation.



### Pre-trained model preparation

- For ViT-B/16, Swin-B, and ConvNext-B models pre-trained on ImageNet-21K, the weights are downloaded automatically when you fine-tune a pre-trained model via `SSF`. You can also download them manually from [ViT](https://github.com/google-research/vision_transformer), [Swin Transformer](https://github.com/microsoft/Swin-Transformer), and [ConvNext](https://github.com/facebookresearch/ConvNeXt).
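
Since these backbones are built on timm, a hedged sketch of pulling an ImageNet-21K ViT-B/16 checkpoint looks like the snippet below (timm==0.6.5 model naming; the repo's own `models/vision_transformer.py` registers tuned variants, so the exact call in `train.py` may differ):

```python
# Hypothetical sketch: timm fetches the ImageNet-21K checkpoint on first use.
import timm

model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True, num_classes=100)
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')
```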



- For the AS-MLP-B model pre-trained on ImageNet-1K, you can download the weights manually from [AS-MLP](https://github.com/svip-lab/AS-MLP).



### Fine-tuning a pre-trained model via SSF

To fine-tune a pre-trained ViT model via `SSF` on CIFAR-100 or ImageNet-1K, run:

```bash
bash train_scripts/vit/cifar_100/train_ssf.sh
```
or 
```bash
bash train_scripts/vit/imagenet_1k/train_ssf.sh
```

You can also find similar scripts for the Swin, ConvNext, and AS-MLP models, so you can easily reproduce our results. Enjoy!



### Robustness & OOD

To evaluate the robustness & OOD performance of a model fine-tuned via SSF, run:

```bash
bash train_scripts/vit/imagenet_a(r, c)/eval_ssf.sh
```


### Citation
If this project is helpful to you, please cite our paper:
```
@InProceedings{Lian_2022_SSF,
  title={Scaling \& Shifting Your Features: A New Baseline for Efficient Model Tuning},
  author={Lian, Dongze and Zhou, Daquan and Feng, Jiashi and Wang, Xinchao},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2022}
}
```


### Acknowledgement
The code is built upon [timm](https://github.com/rwightman/pytorch-image-models). The processing of the vtab-1k dataset refers to [vpt](https://github.com/KMnP/vpt), [vtab github repo](https://github.com/google-research/task_adaptation/tree/master/task_adaptation/data), and [NOAH](https://github.com/ZhangYuanhan-AI/NOAH).


================================================
FILE: data/__init__.py
================================================
from .loader import create_loader
from .dataset_factory import create_dataset

================================================
FILE: data/cub2011.py
================================================
import os

import pandas as pd
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_file_from_google_drive


class Cub2011(VisionDataset):
    """`CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
        Args:
            root (string): Root directory of the dataset.
            train (bool, optional): If True, creates dataset from training set, otherwise
               creates from test set.
            transform (callable, optional): A function/transform that takes in a PIL image
               and returns a transformed version. E.g., ``transforms.RandomCrop``
            target_transform (callable, optional): A function/transform that takes in the
               target and transforms it.
            download (bool, optional): If true, downloads the dataset from the internet and
               puts it in root directory. If dataset is already downloaded, it is not
               downloaded again.
    """
    base_folder = 'CUB_200_2011/images'
    # url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    file_id = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(Cub2011, self).__init__(root, transform=transform, target_transform=target_transform)

        self.loader = default_loader
        self.train = train
        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')

    def _load_metadata(self):
        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),
                                  sep=' ', names=['class_name'], usecols=[1])
        self.class_names = class_names['class_name'].to_list()
        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

    def _check_integrity(self):
        try:
            self._load_metadata()
        except Exception:
            return False

        for index, row in self.data.iterrows():
            filepath = os.path.join(self.root, self.base_folder, row.filepath)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5)

        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1  # Targets start at 1 by default, so shift to 0
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


if __name__ == '__main__':
    train_dataset = Cub2011('./cub2011', train=True, download=False)
    test_dataset = Cub2011('./cub2011', train=False, download=False)

================================================
FILE: data/dataset_factory.py
================================================
""" Dataset Factory

Hacked together by / Copyright 2021, Ross Wightman
"""
import os
#import hub

from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
try:
    from torchvision.datasets import Places365
    has_places365 = True
except ImportError:
    has_places365 = False
try:
    from torchvision.datasets import INaturalist
    has_inaturalist = True
except ImportError:
    has_inaturalist = False

from timm.data.dataset import IterableImageDataset, ImageDataset



# my datasets
from .stanford_dogs import dogs
from .nabirds import NABirds
from .cub2011 import Cub2011
from .vtab import VTAB



_TORCH_BASIC_DS = dict(
    cifar10=CIFAR10,
    cifar100=CIFAR100,
    mnist=MNIST,
    qmnist=QMNIST,
    kmnist=KMNIST,
    fashion_mnist=FashionMNIST,
)
_TRAIN_SYNONYM = {'train', 'training'}
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}

_VTAB_DATASET = ['caltech101', 'clevr_count', 'dmlab', 'dsprites_ori', 'eurosat', 'flowers102', 'patch_camelyon', 'smallnorb_azi', 'svhn', 'cifar100', 'clevr_dist', 'dsprites_loc', 'dtd', 'kitti', 'pets', 'resisc45', 'smallnorb_ele', 'sun397', 'diabetic_retinopathy']




def _search_split(root, split):
    # look for sub-folder with name of split in root and use that if it exists
    split_name = split.split('[')[0]
    try_root = os.path.join(root, split_name)
    if os.path.exists(try_root):
        return try_root

    def _try(syn):
        for s in syn:
            try_root = os.path.join(root, s)
            if os.path.exists(try_root):
                return try_root
        return root
    if split_name in _TRAIN_SYNONYM:
        root = _try(_TRAIN_SYNONYM)
    elif split_name in _EVAL_SYNONYM:
        root = _try(_EVAL_SYNONYM)
    return root


def create_dataset(
        name,
        root,
        split='validation',
        search_split=True,
        class_map=None,
        load_bytes=False,
        is_training=False,
        download=False,
        batch_size=None,
        repeats=0,
        **kwargs
):
    """ Dataset factory method

    In parenthesis after each arg are the type of dataset supported for each arg, one of:
      * folder - default, timm folder (or tar) based ImageDataset
      * torch - torchvision based datasets
      * TFDS - TensorFlow Datasets wrapper in IterableDataset interface via IterableImageDataset
      * all - any of the above

    Args:
        name: dataset name, empty is okay for folder based datasets
        root: root folder of dataset (all)
        split: dataset split (all)
        search_split: search for split specific child fold from root so one can specify
            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
        class_map: specify class -> index mapping via text file or dict (folder)
        load_bytes: load data, return images as undecoded bytes (folder)
        download: download dataset if not present and supported (TFDS, torch)
        is_training: create dataset in train mode, this is different from the split.
            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS)
        batch_size: batch size hint for (TFDS)
        repeats: dataset repeats per iteration i.e. epoch (TFDS)
        **kwargs: other args to pass to dataset

    Returns:
        Dataset object
    """
    name = name.lower()
    if name.startswith('torch/'):
        name = name.split('/', 2)[-1]
        torch_kwargs = dict(root=root, download=download, **kwargs)
        if name in _TORCH_BASIC_DS:
            ds_class = _TORCH_BASIC_DS[name]
            use_train = split in _TRAIN_SYNONYM
            ds = ds_class(train=use_train, **torch_kwargs)
        elif name == 'inaturalist' or name == 'inat':
            assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist'
            target_type = 'full'
            split_split = split.split('/')
            if len(split_split) > 1:
                target_type = split_split[0].split('_')
                if len(target_type) == 1:
                    target_type = target_type[0]
                split = split_split[-1]
            if split in _TRAIN_SYNONYM:
                split = '2021_train'
            elif split in _EVAL_SYNONYM:
                split = '2021_valid'
            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
        elif name == 'places365':
            assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'
            if split in _TRAIN_SYNONYM:
                split = 'train-standard'
            elif split in _EVAL_SYNONYM:
                split = 'val'
            ds = Places365(split=split, **torch_kwargs)
        elif name == 'imagenet':
            if split in _EVAL_SYNONYM:
                split = 'val'
            ds = ImageNet(split=split, **torch_kwargs)
        elif name == 'image_folder' or name == 'folder':
            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
            if search_split and os.path.isdir(root):
                # look for split specific sub-folder in root
                root = _search_split(root, split)
            ds = ImageFolder(root, **kwargs)
        else:
            assert False, f"Unknown torchvision dataset {name}"
    elif name.startswith('tfds/'):
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training,
            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
    else:
        # FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future

        # define my datasets
        if name == 'stanford_dogs':
            ds = dogs(root=root, train=is_training, **kwargs)
        elif name == 'nabirds':
            ds = NABirds(root=root, train=is_training, **kwargs)
        elif name == 'cub2011':
            ds = Cub2011(root=root, train=is_training, **kwargs)
        elif name in _VTAB_DATASET:
            ds = VTAB(root=root, train=is_training, **kwargs)
        else:
            if os.path.isdir(os.path.join(root, split)):
                root = os.path.join(root, split)
            else:
                if search_split and os.path.isdir(root):
                    root = _search_split(root, split)
            ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
    return ds



================================================
FILE: data/loader.py
================================================
""" Loader Factory, Fast Collate, CUDA Prefetcher

Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf

Hacked together by / Copyright 2019, Ross Wightman
"""
import random
from functools import partial
from itertools import repeat
from typing import Callable

import torch.utils.data
import numpy as np

from .transforms_factory import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
from timm.data.random_erasing import RandomErasing
from timm.data.mixup import FastCollateMixup


def fast_collate(batch):
    """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    if isinstance(batch[0][0], tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
        # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
        inner_tuple_size = len(batch[0][0])
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
            for j in range(inner_tuple_size):
                targets[i + j * batch_size] = batch[i][1]
                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
        return tensor, targets
    elif isinstance(batch[0][0], np.ndarray):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i] += torch.from_numpy(batch[i][0])
        return tensor, targets
    elif isinstance(batch[0][0], torch.Tensor):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i].copy_(batch[i][0])
        return tensor, targets
    else:
        assert False


def expand_to_chs(x, n):
    if not isinstance(x, (tuple, list)):
        x = tuple(repeat(x, n))
    elif len(x) == 1:
        x = x * n
    else:
        assert len(x) == n, 'normalization stats must match image channels'
    return x


class PrefetchLoader:

    def __init__(
            self,
            loader,
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD,
            channels=3,
            fp16=False,
            re_prob=0.,
            re_mode='const',
            re_count=1,
            re_num_splits=0):

        mean = expand_to_chs(mean, channels)
        std = expand_to_chs(std, channels)
        normalization_shape = (1, channels, 1, 1)

        self.loader = loader
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape)
        self.fp16 = fp16
        if fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
        else:
            self.random_erasing = None

    def __iter__(self):
        stream = torch.cuda.Stream()
        first = True

        for next_input, next_target in self.loader:
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                if self.fp16:
                    next_input = next_input.half().sub_(self.mean).div_(self.std)
                else:
                    next_input = next_input.float().sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if not first:
                yield input, target
            else:
                first = False

            torch.cuda.current_stream().wait_stream(stream)
            input = next_input
            target = next_target

        yield input, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x


def _worker_init(worker_id, worker_seeding='all'):
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info.id == worker_id
    if isinstance(worker_seeding, Callable):
        seed = worker_seeding(worker_info)
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed % (2 ** 32 - 1))
    else:
        assert worker_seeding in ('all', 'part')
        # random / torch seed already called in dataloader iter class w/ worker_info.seed
        # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
        if worker_seeding == 'all':
            np.random.seed(worker_info.seed % (2 ** 32 - 1))


def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        use_prefetcher=True,
        no_aug=False,
        simple_aug=False,
        direct_resize=False,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_split=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        num_aug_repeats=0,
        num_aug_splits=0,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        distributed=False,
        crop_pct=None,
        collate_fn=None,
        pin_memory=False,
        fp16=False,
        tf_preprocessing=False,
        use_multi_epochs_loader=False,
        persistent_workers=True,
        worker_seeding='all',
):
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = num_aug_splits or 2
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        use_prefetcher=use_prefetcher,
        no_aug=no_aug,
        simple_aug=simple_aug,
        direct_resize=direct_resize,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        tf_preprocessing=tf_preprocessing,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        separate=num_aug_splits > 0,
    )

    sampler = None
    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            if num_aug_repeats:
                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
            else:
                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)
    else:
        assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        persistent_workers=persistent_workers
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError as e:
        loader_args.pop('persistent_workers')  # only in Pytorch 1.7+
        loader = loader_class(dataset, **loader_args)
    if use_prefetcher:
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        loader = PrefetchLoader(
            loader,
            mean=mean,
            std=std,
            channels=input_size[0],
            fp16=fp16,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits
        )

    return loader


class MultiEpochsDataLoader(torch.utils.data.DataLoader):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


================================================
FILE: data/nabirds.py
================================================
import os
import pandas as pd
import warnings
import numpy as np
import torch
from PIL import Image



from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import check_integrity, extract_archive


from torch.utils.data import DataLoader, Dataset



class NABirds(Dataset):
    """`NABirds <https://dl.allaboutbirds.org/nabirds>`_ Dataset.
        Args:
            root (string): Root directory of the dataset.
            train (bool, optional): If True, creates dataset from training set, otherwise
               creates from test set.
            transform (callable, optional): A function/transform that takes in a PIL image
               and returns a transformed version. E.g., ``transforms.RandomCrop``
            target_transform (callable, optional): A function/transform that takes in the
               target and transforms it.
            download (bool, optional): If true, downloads the dataset from the internet and
               puts it in root directory. If dataset is already downloaded, it is not
               downloaded again.
    """
    base_folder = 'nabirds/images'

    def __init__(self, root, train=True, transform=None):
        dataset_path = os.path.join(root, 'nabirds')
        self.root = root
        self.loader = default_loader
        self.train = train
        self.transform = transform

        image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
                                  sep=' ', names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        # Since the raw labels are non-continuous, map them to new ones
        self.label_map = get_continuous_class_map(image_class_labels['target'])
        train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])
        data = image_paths.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')
        # Load in the train / test split
        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

        # Load in the class data
        self.class_names = load_class_names(dataset_path)
        self.class_hierarchy = load_hierarchy(dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = self.label_map[sample.target]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)
        return img, target

def get_continuous_class_map(class_labels):
    label_set = set(class_labels)
    return {k: i for i, k in enumerate(label_set)}

def load_class_names(dataset_path=''):
    names = {}

    with open(os.path.join(dataset_path, 'classes.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = ' '.join(pieces[1:])

    return names

def load_hierarchy(dataset_path=''):
    parents = {}

    with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            child_id, parent_id = pieces
            parents[child_id] = parent_id

    return parents


================================================
FILE: data/stanford_dogs.py
================================================
from __future__ import print_function

from PIL import Image
from os.path import join
import os
import scipy.io

import torch.utils.data as data
from torchvision.datasets.utils import download_url, list_dir, list_files


class dogs(data.Dataset):
    """`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.
    Args:
        root (string): Root directory of the dataset (containing the ``Images``
            and ``Annotation`` folders).
        cropped (bool, optional): If true, the images will be cropped into the bounding box specified
            in the annotations
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset tar files from the internet and
            puts it in root directory. If the tar files are already downloaded, they are not
            downloaded again.
    """
    #folder = 'StanfordDogs'
    folder = ''
    download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'

    def __init__(self,
                 root,
                 train=True,
                 cropped=False,
                 transform=None,
                 target_transform=None,
                 download=False):

        self.root = join(os.path.expanduser(root), self.folder)
        self.train = train
        self.cropped = cropped
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        split = self.load_split()

        self.images_folder = join(self.root, 'Images')
        self.annotations_folder = join(self.root, 'Annotation')
        self._breeds = list_dir(self.images_folder)

        if self.cropped:
            self._breed_annotations = [[(annotation, box, idx)
                                        for box in self.get_boxes(join(self.annotations_folder, annotation))]
                                        for annotation, idx in split]
            self._flat_breed_annotations = sum(self._breed_annotations, [])

            self._flat_breed_images = [(annotation+'.jpg', idx) for annotation, box, idx in self._flat_breed_annotations]
        else:
            self._breed_images = [(annotation+'.jpg', idx) for annotation, idx in split]

            self._flat_breed_images = self._breed_images

        self.classes = ["Chihuaha",
                        "Japanese Spaniel",
                        "Maltese Dog",
                        "Pekinese",
                        "Shih-Tzu",
                        "Blenheim Spaniel",
                        "Papillon",
                        "Toy Terrier",
                        "Rhodesian Ridgeback",
                        "Afghan Hound",
                        "Basset Hound",
                        "Beagle",
                        "Bloodhound",
                        "Bluetick",
                        "Black-and-tan Coonhound",
                        "Walker Hound",
                        "English Foxhound",
                        "Redbone",
                        "Borzoi",
                        "Irish Wolfhound",
                        "Italian Greyhound",
                        "Whippet",
                        "Ibizian Hound",
                        "Norwegian Elkhound",
                        "Otterhound",
                        "Saluki",
                        "Scottish Deerhound",
                        "Weimaraner",
                        "Staffordshire Bullterrier",
                        "American Staffordshire Terrier",
                        "Bedlington Terrier",
                        "Border Terrier",
                        "Kerry Blue Terrier",
                        "Irish Terrier",
                        "Norfolk Terrier",
                        "Norwich Terrier",
                        "Yorkshire Terrier",
                        "Wirehaired Fox Terrier",
                        "Lakeland Terrier",
                        "Sealyham Terrier",
                        "Airedale",
                        "Cairn",
                        "Australian Terrier",
                        "Dandi Dinmont",
                        "Boston Bull",
                        "Miniature Schnauzer",
                        "Giant Schnauzer",
                        "Standard Schnauzer",
                        "Scotch Terrier",
                        "Tibetan Terrier",
                        "Silky Terrier",
                        "Soft-coated Wheaten Terrier",
                        "West Highland White Terrier",
                        "Lhasa",
                        "Flat-coated Retriever",
                        "Curly-coater Retriever",
                        "Golden Retriever",
                        "Labrador Retriever",
                        "Chesapeake Bay Retriever",
                        "German Short-haired Pointer",
                        "Vizsla",
                        "English Setter",
                        "Irish Setter",
                        "Gordon Setter",
                        "Brittany",
                        "Clumber",
                        "English Springer Spaniel",
                        "Welsh Springer Spaniel",
                        "Cocker Spaniel",
                        "Sussex Spaniel",
                        "Irish Water Spaniel",
                        "Kuvasz",
                        "Schipperke",
                        "Groenendael",
                        "Malinois",
                        "Briard",
                        "Kelpie",
                        "Komondor",
                        "Old English Sheepdog",
                        "Shetland Sheepdog",
                        "Collie",
                        "Border Collie",
                        "Bouvier des Flandres",
                        "Rottweiler",
                        "German Shepard",
                        "Doberman",
                        "Miniature Pinscher",
                        "Greater Swiss Mountain Dog",
                        "Bernese Mountain Dog",
                        "Appenzeller",
                        "EntleBucher",
                        "Boxer",
                        "Bull Mastiff",
                        "Tibetan Mastiff",
                        "French Bulldog",
                        "Great Dane",
                        "Saint Bernard",
                        "Eskimo Dog",
                        "Malamute",
                        "Siberian Husky",
                        "Affenpinscher",
                        "Basenji",
                        "Pug",
                        "Leonberg",
                        "Newfoundland",
                        "Great Pyrenees",
                        "Samoyed",
                        "Pomeranian",
                        "Chow",
                        "Keeshond",
                        "Brabancon Griffon",
                        "Pembroke",
                        "Cardigan",
                        "Toy Poodle",
                        "Miniature Poodle",
                        "Standard Poodle",
                        "Mexican Hairless",
                        "Dingo",
                        "Dhole",
                        "African Hunting Dog"]




    def __len__(self):
        return len(self._flat_breed_images)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, target_class = self._flat_breed_images[index]
        image_path = join(self.images_folder, image_name)
        image = Image.open(image_path).convert('RGB')

        if self.cropped:
            image = image.crop(self._flat_breed_annotations[index][1])

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            target_class = self.target_transform(target_class)

        return image, target_class

    def download(self):
        import tarfile

        if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
            if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
                print('Files already downloaded and verified')
                return

        for filename in ['images', 'annotation', 'lists']:
            tar_filename = filename + '.tar'
            url = self.download_url_prefix + '/' + tar_filename
            download_url(url, self.root, tar_filename, None)
            print('Extracting downloaded file: ' + join(self.root, tar_filename))
            with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
                tar_file.extractall(self.root)
            os.remove(join(self.root, tar_filename))

    @staticmethod
    def get_boxes(path):
        import xml.etree.ElementTree
        e = xml.etree.ElementTree.parse(path).getroot()
        boxes = []
        for objs in e.iter('object'):
            boxes.append([int(objs.find('bndbox').find('xmin').text),
                          int(objs.find('bndbox').find('ymin').text),
                          int(objs.find('bndbox').find('xmax').text),
                          int(objs.find('bndbox').find('ymax').text)])
        return boxes

    def load_split(self):
        if self.train:
            split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']
            labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']
        else:
            split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']
            labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']

        split = [item[0][0] for item in split]
        labels = [item[0]-1 for item in labels]
        return list(zip(split, labels))

    def stats(self):
        counts = {}
        for index in range(len(self._flat_breed_images)):
            image_name, target_class = self._flat_breed_images[index]
            if target_class not in counts.keys():
                counts[target_class] = 1
            else:
                counts[target_class] += 1

        print("%d samples spanning %d classes (avg %f per class)"%(len(self._flat_breed_images), len(counts.keys()), float(len(self._flat_breed_images))/float(len(counts.keys()))))

        return counts

================================================
FILE: data/transforms_factory.py
================================================
""" Transforms Factory
Factory methods for building image transforms for use with TIMM (PyTorch Image Models)

Hacked together by / Copyright 2019, Ross Wightman
"""
import math

import torch
from torchvision import transforms

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy
from timm.data.random_erasing import RandomErasing


def transforms_direct_resize(
        img_size=224,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
):
    if interpolation == 'random':
        # random interpolation not supported with no-aug
        interpolation = 'bilinear'
    tfl = [
        transforms.Resize(img_size, interpolation=str_to_interp_mode(interpolation)),
        transforms.CenterCrop(img_size)
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]
    return transforms.Compose(tfl)

def transforms_simpleaug_train(
        img_size=224,
        scale=None,
        ratio=None,
        hflip=0.5,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
):
    tfl = [
        RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation),
        transforms.RandomHorizontalFlip(p=hflip)
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]
    return transforms.Compose(tfl)




def transforms_imagenet_train(
        img_size=224,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='random',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_num_splits=0,
        separate=False,
):
    """
    If separate==True, the transforms are returned as a tuple of 3 separate transforms
    for use in a mixing dataset that passes
     * all data through the first (primary) transform, called the 'clean' data
     * a portion of the data through the secondary transform
     * normalizes and converts the branches above with the third, final transform
    """
    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
    ratio = tuple(ratio or (3./4., 4./3.))  # default imagenet ratio range
    primary_tfl = [
        RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
    if hflip > 0.:
        primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
    if vflip > 0.:
        primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]

    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, (tuple, list)):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = str_to_pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
        else:
            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if specifying brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter),) * 3
        secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    final_tfl = []
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl += [ToNumpy()]
    else:
        final_tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]
        if re_prob > 0.:
            final_tfl.append(
                RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))

    if separate:
        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
    else:
        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)


def transforms_imagenet_eval(
        img_size=224,
        crop_pct=None,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD):
    crop_pct = crop_pct or DEFAULT_CROP_PCT

    if isinstance(img_size, (tuple, list)):
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple([int(x / crop_pct) for x in img_size])
    else:
        scale_size = int(math.floor(img_size / crop_pct))


    tfl = [
        transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
        transforms.CenterCrop(img_size),
    ]


    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                     mean=torch.tensor(mean),
                     std=torch.tensor(std))
        ]

    return transforms.Compose(tfl)



def create_transform(
        input_size,
        is_training=False,
        use_prefetcher=False,
        no_aug=False,
        simple_aug=False,
        direct_resize=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_num_splits=0,
        crop_pct=None,
        tf_preprocessing=False,
        separate=False):

    if isinstance(input_size, (tuple, list)):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if tf_preprocessing and use_prefetcher:
        assert not separate, "Separate transforms not supported for TF preprocessing"
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(
            is_training=is_training, size=img_size, interpolation=interpolation)
    else:
        if is_training:
            if no_aug:
                assert not separate, "Cannot perform split augmentation with no_aug"
                transform = transforms_direct_resize(
                    img_size,
                    interpolation=interpolation,
                    use_prefetcher=use_prefetcher,
                    mean=mean,
                    std=std)
            elif simple_aug:
                transform = transforms_simpleaug_train(
                    img_size,
                    scale=scale,
                    ratio=ratio,
                    hflip=hflip,
                    interpolation=interpolation,
                    use_prefetcher=use_prefetcher,
                    mean=mean,
                    std=std)
            else:
                transform = transforms_imagenet_train(
                    img_size,
                    scale=scale,
                    ratio=ratio,
                    hflip=hflip,
                    vflip=vflip,
                    color_jitter=color_jitter,
                    auto_augment=auto_augment,
                    interpolation=interpolation,
                    use_prefetcher=use_prefetcher,
                    mean=mean,
                    std=std,
                    re_prob=re_prob,
                    re_mode=re_mode,
                    re_count=re_count,
                    re_num_splits=re_num_splits,
                    separate=separate)
        else:
            if direct_resize:
                #print('direct_resize')
                transform = transforms_direct_resize(
                    img_size,
                    interpolation=interpolation,
                    use_prefetcher=use_prefetcher,
                    mean=mean,
                    std=std)
            else:
                assert not separate, "Separate transforms not supported for validation preprocessing"
                transform = transforms_imagenet_eval(
                    img_size,
                    interpolation=interpolation,
                    use_prefetcher=use_prefetcher,
                    mean=mean,
                    std=std,
                    crop_pct=crop_pct)

    return transform


================================================
FILE: data/vtab.py
================================================
import os
from torchvision.datasets.folder import ImageFolder, default_loader

class VTAB(ImageFolder):
    def __init__(self, root, train=True, transform=None, target_transform=None, mode=None, is_individual_prompt=False, **kwargs):
        self.dataset_root = root
        self.loader = default_loader
        self.target_transform = None
        self.transform = transform

        
        train_list_path = os.path.join(self.dataset_root, 'train800val200.txt')
        test_list_path = os.path.join(self.dataset_root, 'test.txt')

        
        # train_list_path = os.path.join(self.dataset_root, 'train800.txt')
        # test_list_path = os.path.join(self.dataset_root, 'val200.txt')


        self.samples = []
        list_path = train_list_path if train else test_list_path
        with open(list_path, 'r') as f:
            for line in f:
                img_name, label = line.split(' ')[0], int(line.split(' ')[1])
                self.samples.append((os.path.join(root, img_name), label))
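
# Usage sketch (added note, not original code). Each annotation file is expected to
# contain one "relative/image/path label" pair per line, e.g.
#
#   images/0001.jpg 12
#   images/0002.jpg 3
#
# so a VTAB-1k task could be loaded roughly as follows (the root below is a
# placeholder path, not one shipped with this repository):
#
#   from torchvision import transforms
#   train_set = VTAB(root='/path/to/vtab-1k/cifar100', train=True,
#                    transform=transforms.ToTensor())
#   img, label = train_set[0]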

================================================
FILE: log/README.md
================================================



================================================
FILE: log/cifar100.csv
================================================
epoch,train_loss,eval_loss,eval_top1,eval_top5
0,5.603768242730035,6.0231625,0.96,4.89
1,4.440675179163615,5.8859375,1.02,5.18
2,3.186616155836317,5.63614375,1.43,6.43
3,2.827679475148519,5.345575,1.92,8.46
4,2.7690945731268988,5.04094375,2.7,11.41
5,2.6746084690093994,4.7265375,4.07,15.29
6,2.5769188139173718,4.404325,6.44,20.77
7,2.658121665318807,4.07631875,10.12,28.31
8,2.5106263955434165,3.74056875,15.35,37.6
9,2.638317664464315,3.398590625,23.04,49.38
10,2.4921720822652182,3.05421875,32.87,61.79
11,2.5904562208387585,2.714753125,44.93,72.77
12,2.521847221586439,2.38726875,57.29,81.89
13,2.555907726287842,2.0806125,67.52,88.52
14,2.678542561001248,1.8002078125,75.47,92.47
15,2.5564871629079184,1.549046875,81.04,94.85
16,2.5498899353875055,1.329515625,84.56,96.42
17,2.555638392766317,1.1425953125,86.4,97.29
18,2.5207514233059354,0.9856625,87.97,97.84
19,2.5135546260409884,0.8570640625,88.97,98.17
20,2.417558749516805,0.75346328125,89.95,98.5
21,2.4377992947896323,0.669475,90.59,98.76
22,2.5416789849599204,0.6010578125,91.05,98.88
23,2.4467749065823026,0.5459734375,91.42,98.99
24,2.3756895065307617,0.5007609375,91.64,99.1
25,2.4439392619662814,0.46489609375,91.92,99.19
26,2.4532913896772595,0.43501953125,92.09,99.26
27,2.508685827255249,0.410413671875,92.19,99.3
28,2.489635467529297,0.390225,92.39,99.31
29,2.4972835646735296,0.37323125,92.44,99.36
30,2.4116059409247503,0.358880078125,92.54,99.42
31,2.4685837162865534,0.346959765625,92.7,99.46
32,2.3165384001202054,0.336875,92.71,99.47
33,2.4255740377638073,0.32844765625,92.71,99.46
34,2.4002869658999972,0.321278125,92.76,99.45
35,2.4275462097591824,0.314995703125,92.8,99.47
36,2.407806317011515,0.309953125,92.91,99.49
37,2.402532418568929,0.30562421875,92.98,99.51
38,2.4551497830284967,0.30197578125,93.06,99.51
39,2.4293878608279758,0.2986171875,93.17,99.51
40,2.4136725531684027,0.296045703125,93.24,99.51
41,2.4232251379224987,0.29378125,93.32,99.52
42,2.2722683482699924,0.2918,93.31,99.52
43,2.397341330846151,0.290190625,93.35,99.52
44,2.406024138132731,0.289073828125,93.39,99.53
45,2.3778079880608454,0.288259765625,93.42,99.53
46,2.4535727500915527,0.28774140625,93.47,99.53
47,2.4934064282311335,0.287232421875,93.56,99.53
48,2.325168079800076,0.287040234375,93.59,99.52
49,2.396822929382324,0.286894140625,93.58,99.55
50,2.3157892756991916,0.287205859375,93.61,99.55
51,2.4792808956570096,0.287432421875,93.63,99.56
52,2.366891860961914,0.287920703125,93.65,99.56
53,2.345345550113254,0.2883359375,93.67,99.56
54,2.303606006834242,0.288706640625,93.66,99.56
55,2.398844109641181,0.28921875,93.68,99.57
56,2.3735866281721325,0.2897921875,93.68,99.57
57,2.4807073805067272,0.2904640625,93.69,99.58
58,2.349070734447903,0.291125,93.76,99.58
59,2.375280910068088,0.292005859375,93.8,99.58
60,2.33055845896403,0.293062109375,93.82,99.57
61,2.4362279574076333,0.29396484375,93.79,99.57
62,2.317348506715563,0.29523125,93.8,99.57
63,2.4144566191567316,0.29618203125,93.8,99.57
64,2.3383621904585095,0.297141796875,93.79,99.56
65,2.4115795029534235,0.298030078125,93.79,99.56
66,2.3216844929589167,0.299159375,93.82,99.57
67,2.318764951494005,0.30013984375,93.77,99.56
68,2.269577423731486,0.30101484375,93.75,99.56
69,2.369996494717068,0.30192890625,93.79,99.55
70,2.3662983311547174,0.303022265625,93.81,99.56
71,2.303777880138821,0.303908984375,93.85,99.56
72,2.255165616671244,0.30502734375,93.82,99.56
73,2.2920398712158203,0.305877734375,93.82,99.56
74,2.310059520933363,0.30675234375,93.87,99.56
75,2.325947019788954,0.307698046875,93.82,99.57
76,2.2926743825276694,0.308682421875,93.82,99.55
77,2.317892154057821,0.3095203125,93.86,99.55
78,2.3864409658643932,0.310484375,93.86,99.55
79,2.2767338487837048,0.311702734375,93.88,99.55
80,2.459476047092014,0.31251640625,93.89,99.55
81,2.3905074066585965,0.313465234375,93.87,99.55
82,2.380774630440606,0.314390625,93.88,99.54
83,2.191994031270345,0.31534296875,93.91,99.54
84,2.316111962000529,0.31621796875,93.9,99.54
85,2.307388676537408,0.31720625,93.89,99.54
86,2.27423980500963,0.318103125,93.9,99.54
87,2.3265264564090304,0.31895,93.91,99.54
88,2.265656683180067,0.31984140625,93.94,99.54
89,2.3482420444488525,0.320740625,93.97,99.54
90,2.3093814849853516,0.321472265625,93.97,99.54
91,2.3022206094529896,0.322198046875,93.97,99.54
92,2.3547442489200168,0.3228609375,93.99,99.54
93,2.246518611907959,0.3234953125,93.97,99.54
94,2.3851645787556968,0.324027734375,93.98,99.56
95,2.3422129816479154,0.324636328125,93.93,99.56
96,2.2714282936520047,0.325282421875,93.93,99.56
97,2.374925719367133,0.32581875,93.93,99.55
98,2.3505734867519803,0.326328125,93.94,99.54
99,2.418484025531345,0.326844140625,93.95,99.55
100,2.293912728627523,0.32743125,93.94,99.55
101,2.446575509177314,0.327946875,93.95,99.55
102,2.309349775314331,0.32843125,93.96,99.55
103,2.316555658976237,0.3289484375,93.95,99.55
104,2.3285930156707764,0.329358203125,93.94,99.56
105,2.2936308648851185,0.32980234375,93.93,99.56
106,2.3283233642578125,0.330234765625,93.95,99.57
107,2.3618712955050998,0.33079453125,93.95,99.57
108,2.3460081683264837,0.3311765625,93.95,99.57
109,2.352536598841349,0.33160625,93.96,99.57


================================================
FILE: models/as_mlp.py
================================================
# --------------------------------------------------------
# AS-MLP
# Licensed under The MIT License [see LICENSE for details]
# Written by Zehao Yu and Dongze Lian (AS-MLP)
# --------------------------------------------------------

import logging
import math
from copy import deepcopy
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.fx_features import register_notrace_function
from timm.models.helpers import build_model_with_cfg 
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.layers import _assert
from timm.models.registry import register_model

from timm.models.vision_transformer import checkpoint_filter_fn 



_logger = logging.getLogger(__name__)

def _cfg(url='', file='', **kwargs):
    return {
        'url': url,
        'file': file,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }

default_cfgs = {
    'as_base_patch4_window7_224': _cfg(
        file='/path/to/asmlp_base_patch4_shift5_224.pth'
    ),
}


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tuning_mode=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1)
        self.drop = nn.Dropout(drop)


        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':        
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)


    def forward(self, x):
        x = self.fc1(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x) 
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)

        x = self.drop(x)
        
        return x



class AxialShift(nn.Module):
    r""" Axial shift  

    Args:
        dim (int): Number of input channels.
        shift_size (int): shift size .
        as_bias (bool, optional):  If True, add a learnable bias to as mlp. Default: True
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, shift_size, as_bias=True, proj_drop=0., tuning_mode=None):

        super().__init__()
        self.dim = dim
        self.shift_size = shift_size
        self.pad = shift_size // 2
        self.conv1 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias)
        self.conv2_1 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias)
        self.conv2_2 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias)
        self.conv3 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias)

        self.actn = nn.GELU()

        self.norm1 = MyNorm(dim)
        self.norm2 = MyNorm(dim)
        self.tuning_mode = tuning_mode


        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)
            self.ssf_scale_3, self.ssf_shift_3 = init_ssf_scale_shift(dim)
            self.ssf_scale_4, self.ssf_shift_4 = init_ssf_scale_shift(dim)


    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, C, H, W)
        """
        B_, C, H, W = x.shape
        x = self.conv1(x)
        if self.tuning_mode == 'ssf':   
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
        
        x = self.norm1(x)
        x = self.actn(x)
       
        
        x = F.pad(x, (self.pad, self.pad, self.pad, self.pad) , "constant", 0)
        
        xs = torch.chunk(x, self.shift_size, 1)

        def shift(dim):
            x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
            x_cat = torch.cat(x_shift, 1)
            x_cat = torch.narrow(x_cat, 2, self.pad, H)
            x_cat = torch.narrow(x_cat, 3, self.pad, W)
            return x_cat

        x_shift_lr = shift(3)
        x_shift_td = shift(2)
    

        if self.tuning_mode == 'ssf':   
            x_lr = ssf_ada(self.conv2_1(x_shift_lr), self.ssf_scale_2, self.ssf_shift_2)
            x_td = ssf_ada(self.conv2_2(x_shift_td), self.ssf_scale_3, self.ssf_shift_3)
        else:
            x_lr = self.conv2_1(x_shift_lr)
            x_td = self.conv2_2(x_shift_td)

        x_lr = self.actn(x_lr)
        x_td = self.actn(x_td)

        x = x_lr + x_td
        x = self.norm2(x)

        x = self.conv3(x)
        if self.tuning_mode == 'ssf': 
            x = ssf_ada(x, self.ssf_scale_4, self.ssf_shift_4)

        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, shift_size={self.shift_size}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # conv1 
        flops += N * self.dim * self.dim
        # norm 1
        flops += N * self.dim
        # conv2_1 conv2_2
        flops += N * self.dim * self.dim * 2
        # x_lr + x_td
        flops += N * self.dim
        # norm2
        flops += N * self.dim
        # conv3
        flops += N * self.dim * self.dim
        return flops
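
# Illustration of the axial shift above (added comment, not original code): with
# shift_size=3 (pad=1) the channels are split into 3 groups; after padding, the
# groups are rolled by -1, 0 and +1 along the height or width axis and cropped
# back to H x W, so every spatial position sees its previous / current / next
# neighbour through different channel groups. E.g. a padded row [0, a, b, c, 0]
# becomes [b, c, 0], [a, b, c] and [0, a, b] in the three groups.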


class AxialShiftedBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        shift_size (int): Shift size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        as_bias (bool, optional): If True, add a learnable bias to Axial Mlp. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, shift_size=7,
                 mlp_ratio=4., as_bias=True, drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, tuning_mode=None):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(dim)
        self.axial_shift = AxialShift(dim, shift_size=shift_size, as_bias=as_bias, proj_drop=drop, tuning_mode=tuning_mode)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, tuning_mode=tuning_mode)


        self.tuning_mode = tuning_mode

        if tuning_mode == 'ssf':    
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)


    def forward(self, x):
        B, C, H, W = x.shape

        shortcut = x
        
        x = self.norm1(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        # axial shift block
        x = self.axial_shift(x)  # B, C, H, W

        # FFN
        x = shortcut + self.drop_path(x)
        if self.tuning_mode == 'ssf':
            x = x + self.drop_path(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2)))
        else:
            x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, " \
               f"shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # shift mlp 
        flops += self.axial_shift.flops(H * W)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, tuning_mode=None):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Conv2d(4 * dim, 2 * dim, 1, 1, bias=False)
        self.norm = norm_layer(4 * dim)
        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(4 * dim)


    def forward(self, x):
        """
        x: B, C, H, W
        """
        B, C, H, W = x.shape
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x0 = x[:, :, 0::2, 0::2]  # B C H/2 W/2 
        x1 = x[:, :, 1::2, 0::2]  # B C H/2 W/2 
        x2 = x[:, :, 0::2, 1::2]  # B C H/2 W/2 
        x3 = x[:, :, 1::2, 1::2]  # B C H/2 W/2 
        x = torch.cat([x0, x1, x2, x3], 1)  # B 4*C H/2 W/2 

        x = self.norm(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
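
# Shape note for PatchMerging above (added comment): an input of shape (B, C, H, W)
# is rearranged by the 2x2 strided slicing/concat into (B, 4*C, H/2, W/2) and then
# projected by the 1x1 conv to (B, 2*C, H/2, W/2).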


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, shift_size,
                 mlp_ratio=4., as_bias=True, drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, tuning_mode=None):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            AxialShiftedBlock(dim=dim, input_resolution=input_resolution,
                              shift_size=shift_size,
                              mlp_ratio=mlp_ratio,
                              as_bias=as_bias,
                              drop=drop, 
                              drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                              norm_layer=norm_layer, tuning_mode=tuning_mode[i])
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, tuning_mode=tuning_mode)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, tuning_mode=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
        
        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)

            

            if norm_layer:
                self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim)


    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)  # B, embed_dim, Ph, Pw (kept channels-first; no flatten/transpose)

        if self.tuning_mode == 'ssf':  
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
            if self.norm is not None:
                x = self.norm(x)
                x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
        else:
            if self.norm is not None:
                x = self.norm(x)

        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


def MyNorm(dim):
    return nn.GroupNorm(1, dim)


def init_ssf_scale_shift(dim):
    scale = nn.Parameter(torch.ones(dim))
    shift = nn.Parameter(torch.zeros(dim))

    nn.init.normal_(scale, mean=1, std=.02)
    nn.init.normal_(shift, std=.02)

    return scale, shift


def ssf_ada(x, scale, shift):
    assert scale.shape == shift.shape
    if x.shape[-1] == scale.shape[0]:
        return x * scale + shift
    elif x.shape[1] == scale.shape[0]:
        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
    else:
        raise ValueError('the input tensor shape does not match the shape of the scale factor.')


class AS_MLP(nn.Module):
    r""" AS-MLP
        A PyTorch impl of : `AS-MLP: An Axial Shifted MLP Architecture for Vision`  -
          https://arxiv.org/pdf/xxx.xxx

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each AS-MLP layer.
        shift_size (int): Axial shift size. Default: 5
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        as_bias (bool): If True, add a learnable bias to as-mlp block. Default: True
        drop_rate (float): Dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.GroupNorm with group=1.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], 
                 shift_size=5, mlp_ratio=4., as_bias=True, 
                 drop_rate=0., drop_path_rate=0.1,
                 norm_layer=MyNorm, patch_norm=True,
                 use_checkpoint=False, tuning_mode=None, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        
        self.tuning_mode = tuning_mode
        tuning_mode_list = [[tuning_mode] * depths[i_layer] for i_layer in range(self.num_layers)]
        if tuning_mode == 'ssf':   
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(self.num_features)
            

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               shift_size=shift_size,
                               mlp_ratio=self.mlp_ratio,
                               as_bias=as_bias,
                               drop=drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               tuning_mode=tuning_mode_list[i_layer])
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        #self.apply(self._init_weights)
    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()


    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)
        
        x = self.norm(x)  # B L C
        if self.tuning_mode == 'ssf': 
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
        
        x = self.avgpool(x)  # B C 1 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops


def _create_as_mlp(variant, pretrained=False, **kwargs):
    model = build_model_with_cfg(
        AS_MLP, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)

    return model


@register_model
def as_base_patch4_window7_224(pretrained=False, **kwargs):
    """ AS-MLP-B @ 224x224, pretrained ImageNet-1k
    """
    model_kwargs = dict(
        patch_size=4, shift_size=5, embed_dim=128, depths=(2, 2, 18, 2), **kwargs)
    return _create_as_mlp('as_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
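

# Minimal sanity check for the SSF scale/shift op defined above (added sketch, not
# part of the original training pipeline): ssf_ada broadcasts the per-channel scale
# and shift over channels-last (B, N, C) and channels-first (B, C, H, W) tensors.
if __name__ == '__main__':
    dim = 8
    scale, shift = init_ssf_scale_shift(dim)
    tokens = torch.randn(2, 16, dim)       # channels-last, e.g. transformer tokens
    feats = torch.randn(2, dim, 14, 14)    # channels-first, e.g. conv feature maps
    print(ssf_ada(tokens, scale, shift).shape)  # torch.Size([2, 16, 8])
    print(ssf_ada(feats, scale, shift).shape)   # torch.Size([2, 8, 14, 14])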


================================================
FILE: models/convnext.py
================================================
""" ConvNeXt

Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf

Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below

Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
"""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the MIT license
import math
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.fx_features import register_notrace_module
from timm.models.helpers import named_apply, build_model_with_cfg, checkpoint_seq
from timm.models.layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, to_2tuple

from timm.models.registry import register_model



__all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0', 'classifier': 'head.fc',
        **kwargs
    }


default_cfgs = dict(
    convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"),
    convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"),
    convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
    convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),

    convnext_tiny_hnf=_cfg(url=''),

    convnext_base_in22ft1k=_cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
    convnext_large_in22ft1k=_cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'),
    convnext_xlarge_in22ft1k=_cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),

    convnext_base_384_in22ft1k=_cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    convnext_large_384_in22ft1k=_cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    convnext_xlarge_384_in22ft1k=_cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),

    convnext_base_in22k=_cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
    convnext_large_in22k=_cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
    convnext_xlarge_in22k=_cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
)



def _is_contiguous(tensor: torch.Tensor) -> bool:
    # jit is oh so lovely :/
    # if torch.jit.is_tracing():
    #     return True
    if torch.jit.is_scripting():
        return tensor.is_contiguous()
    else:
        return tensor.is_contiguous(memory_format=torch.contiguous_format)



@register_notrace_module
class LayerNorm2d(nn.LayerNorm):
    r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__(normalized_shape, eps=eps)

    def forward(self, x) -> torch.Tensor:
        if _is_contiguous(x):
            return F.layer_norm(
                x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
        else:
            s, u = torch.var_mean(x, dim=1, keepdim=True)
            x = (x - u) * torch.rsqrt(s + self.eps)
            x = x * self.weight[:, None, None] + self.bias[:, None, None]
            return x




class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., tuning_mode=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':        
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)
        

    def forward(self, x):
        x = self.fc1(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x) 
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)

        x = self.drop2(x)
        
        return x



class ConvNeXtBlock(nn.Module):
    """ ConvNeXt Block
    There are two equivalent implementations:
      (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
      (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back

    Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
    choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
    is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None, tuning_mode=None):
        super().__init__()
        if not norm_layer:
            norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
        mlp_layer = ConvMlp if conv_mlp else Mlp
        self.use_conv_mlp = conv_mlp
        self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, tuning_mode=tuning_mode)

        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':        
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)




    def forward(self, x):
        shortcut = x
        x = self.conv_dw(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        if self.use_conv_mlp:
            x = self.norm(x)
            x = self.mlp(x)
        else:
            x = x.permute(0, 2, 3, 1)
            x = self.norm(x)
            if self.tuning_mode == 'ssf':
                x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)

            x = self.mlp(x)
            x = x.permute(0, 3, 1, 2)
        if self.gamma is not None:
            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
        x = self.drop_path(x) + shortcut

        return x


class Downsample(nn.Module):
    """ 2D Image to Downsample
    """
    def __init__(self, dim, out_dim, kernel_size, stride, norm_layer=None, tuning_mode=None):
        super().__init__()

        self.norm = norm_layer(dim)
        self.proj = nn.Conv2d(dim, out_dim, kernel_size=stride, stride=stride)

        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_dim)


    def forward(self, x):
        if self.tuning_mode == 'ssf':  
            x = ssf_ada(self.norm(x), self.ssf_scale_1, self.ssf_shift_1)
            x = ssf_ada(self.proj(x), self.ssf_scale_2, self.ssf_shift_2)
        else:
            x = self.norm(x)
            x = self.proj(x)

        return x
            


class ConvNeXtStage(nn.Module):

    def __init__(
            self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False,
            norm_layer=None, cl_norm_layer=None, cross_stage=False, tuning_mode=None):
        super().__init__()
        self.grad_checkpointing = False 

        if in_chs != out_chs or stride > 1:
            self.downsample = Downsample(dim=in_chs, out_dim=out_chs, kernel_size=stride, stride=stride, norm_layer=norm_layer, tuning_mode=tuning_mode)
        else:
            self.downsample = nn.Identity()

        dp_rates = dp_rates or [0.] * depth
        self.blocks = nn.Sequential(*[ConvNeXtBlock(
            dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
            norm_layer=norm_layer if conv_mlp else cl_norm_layer, tuning_mode=tuning_mode[j])
            for j in range(depth)]
        )

    def forward(self, x):
        x = self.downsample(x)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)

        return x



class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, tuning_mode=None):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim)


    def forward(self, x):
        if self.tuning_mode == 'ssf':  
            x = ssf_ada(self.proj(x), self.ssf_scale_1, self.ssf_shift_1)
            x = ssf_ada(self.norm(x), self.ssf_scale_2, self.ssf_shift_2)
        else:
            x = self.proj(x)
            x = self.norm(x)

        return x


def init_ssf_scale_shift(dim):
    scale = nn.Parameter(torch.ones(dim))
    shift = nn.Parameter(torch.zeros(dim))

    nn.init.normal_(scale, mean=1, std=.02)
    nn.init.normal_(shift, std=.02)

    return scale, shift


def ssf_ada(x, scale, shift):
    assert scale.shape == shift.shape
    if x.shape[-1] == scale.shape[0]:
        return x * scale + shift
    elif x.shape[1] == scale.shape[0]:
        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
    else:
        raise ValueError('the input tensor shape does not match the shape of the scale factor.')


class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s`  - https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_rate (float): Head dropout rate
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(
            self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4,
            depths=(3, 3, 9, 3), dims=(96, 192, 384, 768),  ls_init_value=1e-6, conv_mlp=False,
            head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., tuning_mode=None
    ):
        super().__init__()
        assert output_stride == 32
        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)
            cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
        else:
            assert conv_mlp,\
                'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
            cl_norm_layer = norm_layer

        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.feature_info = []

        # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
        self.stem = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=dims[0], norm_layer=norm_layer, tuning_mode=tuning_mode)

        self.tuning_mode = tuning_mode
        tuning_mode_list = [[tuning_mode] * depths[i_layer] for i_layer in range(len(depths))]
        if tuning_mode == 'ssf':       
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dims[3])

        self.stages = nn.Sequential()
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        curr_stride = patch_size
        prev_chs = dims[0]
        stages = []
        # 4 feature resolution stages, each consisting of multiple residual blocks
        for i in range(4):
            stride = 2 if i > 0 else 1
            # FIXME support dilation / output_stride
            curr_stride *= stride
            out_chs = dims[i]
            stages.append(ConvNeXtStage(
                prev_chs, out_chs, stride=stride,
                depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
                norm_layer=norm_layer, cl_norm_layer=cl_norm_layer, tuning_mode=tuning_mode_list[i])
            )
            prev_chs = out_chs
            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        self.num_features = prev_chs
        if head_norm_first:
            # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)
            self.norm_pre = norm_layer(self.num_features)  # final norm layer, before pooling
            self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
        else:
            # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
            self.norm_pre = nn.Identity()
            self.head = nn.Sequential(OrderedDict([
                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
                ('norm', norm_layer(self.num_features)),
                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
                ('drop', nn.Dropout(self.drop_rate)),
                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
            ]))

        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)


    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes=0, global_pool='avg'):
        if isinstance(self.head, ClassifierHead):
            # norm -> global pool -> fc
            self.head = ClassifierHead(
                self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
        else:
            # pool -> norm -> fc
            self.head = nn.Sequential(OrderedDict([
                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
                ('norm', self.head.norm),
                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
                ('drop', nn.Dropout(self.drop_rate)),
                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
            ]))

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm_pre(x)
        if self.tuning_mode == 'ssf':
            # norm_pre has already been applied above; only the SSF scale/shift is added here
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        return x


    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _init_weights(module, name=None, head_init_scale=1.0):
    if isinstance(module, nn.Conv2d):
        trunc_normal_(module.weight, std=.02)
        nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        nn.init.constant_(module.bias, 0)
        if name and 'head.' in name:
            module.weight.data.mul_(head_init_scale)
            module.bias.data.mul_(head_init_scale)


def checkpoint_filter_fn(state_dict, model):
    """ Remap FB checkpoints -> timm """
    #ipdb.set_trace()
    if 'model' in state_dict:
        state_dict = state_dict['model']
    out_dict = {}
    import re
    for k, v in state_dict.items():
        k = k.replace('downsample_layers.0.0.', 'stem.proj.')
        k = k.replace('downsample_layers.0.1.', 'stem.norm.')


        k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)

        k = re.sub(r'downsample_layers.([0-9]+).([0]+)', r'stages.\1.downsample.norm', k)
        k = re.sub(r'downsample_layers.([0-9]+).([1]+)', r'stages.\1.downsample.proj', k)


        k = k.replace('dwconv', 'conv_dw')
        k = k.replace('pwconv', 'mlp.fc')
        k = k.replace('head.', 'head.fc.')
        if k.startswith('norm.'):
            k = k.replace('norm', 'head.norm')
        if v.ndim == 2 and 'head' not in k:
            model_shape = model.state_dict()[k].shape
            v = v.reshape(model_shape)
        out_dict[k] = v
    return out_dict
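
# Example key remappings performed by checkpoint_filter_fn above (added comment):
#   'downsample_layers.0.0.weight'  -> 'stem.proj.weight'
#   'downsample_layers.1.0.weight'  -> 'stages.1.downsample.norm.weight'
#   'downsample_layers.1.1.weight'  -> 'stages.1.downsample.proj.weight'
#   'stages.0.0.dwconv.weight'      -> 'stages.0.blocks.0.conv_dw.weight'
#   'norm.weight'                   -> 'head.norm.weight'
#   'head.weight'                   -> 'head.fc.weight'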


def _create_convnext(variant, pretrained=False, **kwargs):
    model = build_model_with_cfg(
        ConvNeXt, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
        **kwargs)
    return model



@register_model
def convnext_tiny(pretrained=False, **kwargs):
    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
    model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_tiny_hnf(pretrained=False, **kwargs):
    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, **kwargs)
    model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_small(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    model = _create_convnext('convnext_small', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_base(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    model = _create_convnext('convnext_base', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_large(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    model = _create_convnext('convnext_large', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_base_in22ft1k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    model = _create_convnext('convnext_base_in22ft1k', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_large_in22ft1k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    model = _create_convnext('convnext_large_in22ft1k', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_xlarge_in22ft1k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    model = _create_convnext('convnext_xlarge_in22ft1k', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_base_384_in22ft1k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    model = _create_convnext('convnext_base_384_in22ft1k', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_large_384_in22ft1k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    model = _create_convnext('convnext_large_384_in22ft1k', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    model = _create_convnext('convnext_xlarge_384_in22ft1k', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_base_in22k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_large_in22k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_xlarge_in22k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args)
    return model
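
# SSF fine-tuning sketch (added illustration, not the repository's training code):
# with tuning_mode='ssf' the modules above create extra per-channel parameters named
# 'ssf_scale_*' / 'ssf_shift_*'. One simple way to restrict optimisation to those
# parameters plus the classifier head would be:
#
#   model = convnext_base(pretrained=False, tuning_mode='ssf', num_classes=100)
#   for name, p in model.named_parameters():
#       p.requires_grad = ('ssf_scale' in name) or ('ssf_shift' in name) or name.startswith('head.')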


================================================
FILE: models/swin_transformer.py
================================================
""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
    - https://arxiv.org/pdf/2103.14030
Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import logging
import math
from copy import deepcopy
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.fx_features import register_notrace_function
from timm.models.helpers import build_model_with_cfg, named_apply 
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.layers import _assert
from timm.models.registry import register_model

from timm.models.vision_transformer import checkpoint_filter_fn, get_init_weights_vit 

import ipdb

_logger = logging.getLogger(__name__)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # patch models (my experiments)
    'swin_base_patch4_window12_384': _cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),

    'swin_base_patch4_window7_224': _cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
    ),

    'swin_large_patch4_window12_384': _cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),

    'swin_large_patch4_window7_224': _cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
    ),

    'swin_small_patch4_window7_224': _cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
    ),

    'swin_tiny_patch4_window7_224': _cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
    ),

    'swin_base_patch4_window12_384_in22k': _cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
        input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),

    'swin_base_patch4_window7_224_in22k': _cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
        num_classes=21841),

    'swin_large_patch4_window12_384_in22k': _cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
        input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),

    'swin_large_patch4_window7_224_in22k': _cfg(
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
        num_classes=21841),

}



class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., tuning_mode=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

            
        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':        
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)


    def forward(self, x):
        x = self.fc1(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x) 
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
        
        x = self.drop2(x)
        
        return x

        
def window_partition(x, window_size: int):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
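
# Editor's note: illustrative sketch, not part of the original repository.
# window_partition and window_reverse are exact inverses when H and W are
# divisible by the window size: a (B, H, W, C) map is split into non-overlapping
# windows and reassembled without loss.
def _window_roundtrip_example():
    x = torch.randn(2, 56, 56, 96)                 # (B, H, W, C), 56 divisible by 7
    windows = window_partition(x, 7)               # (2 * 8 * 8, 7, 7, 96)
    assert windows.shape == (128, 7, 7, 96)
    assert torch.equal(window_reverse(windows, 7, 56, 56), x)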


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., tuning_mode=None):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

        self.tuning_mode = tuning_mode

        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)



    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        if self.tuning_mode == 'ssf':
            #qkv = (self.qkv(x) * self.ssf_scale_1 + self.ssf_shift_1).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            qkv = (ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1)).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        
        x = self.proj(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)

        x = self.proj_drop(x)
        return x
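
# Editor's note: illustrative sketch, not part of the original repository; the
# helper name is hypothetical. With tuning_mode='ssf', the qkv projection and the
# output projection are each followed by a learnable scale/shift (ssf_ada, defined
# further below in this file), so call this only after the module has been imported.
def _window_attention_example():
    attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3, tuning_mode='ssf')
    x = torch.randn(128, 49, 96)       # (num_windows * B, window_size ** 2, dim)
    assert attn(x).shape == x.shape    # shape-preserving; mask omitted (plain W-MSA case)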


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, tuning_mode=None):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in the range [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop, tuning_mode=tuning_mode)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, tuning_mode=tuning_mode)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

        self.tuning_mode = tuning_mode

        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)
        

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape

        shortcut = x
        
        x = self.norm1(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        if self.tuning_mode == 'ssf':
            x = x + self.drop_path(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2)))
        else:
            x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
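
# Editor's note: illustrative sketch, not part of the original repository.
# A single shifted-window block with SSF: a scale/shift follows norm1 and norm2,
# while the pretrained attention and MLP weights themselves are left untouched.
def _swin_block_example():
    blk = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3,
                               window_size=7, shift_size=3, tuning_mode='ssf')
    x = torch.randn(2, 56 * 56, 96)    # (B, H*W, C)
    assert blk(x).shape == x.shape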


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, tuning_mode=None):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)
        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(4 * dim)


    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape

        _assert(L == H * W, "input feature has wrong size")
        _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even.")

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
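
# Editor's note: illustrative sketch, not part of the original repository.
# PatchMerging quarters the number of tokens and doubles the channel dimension;
# with tuning_mode='ssf', a scale/shift is applied right after the pre-reduction norm.
def _patch_merging_example():
    merge = PatchMerging(input_resolution=(56, 56), dim=96, tuning_mode='ssf')
    x = torch.randn(2, 56 * 56, 96)
    assert merge(x).shape == (2, 28 * 28, 192)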


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, tuning_mode=None):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, tuning_mode=tuning_mode[i])
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, tuning_mode=tuning_mode)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, tuning_mode=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.norm_layer = norm_layer

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)

            if norm_layer:
                self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim)


    def forward(self, x):
        B, C, H, W = x.shape
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")

        x = self.proj(x) 
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.tuning_mode == 'ssf':  
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
            if self.norm_layer:
                x = ssf_ada(self.norm(x), self.ssf_scale_2, self.ssf_shift_2)
            else:
                x = self.norm(x)
        else:
            x = self.norm(x)
        return x
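
# Editor's note: illustrative sketch, not part of the original repository.
# For Swin the patch embedding is a stride-4 convolution, so a 224x224 image
# becomes a 56x56 token grid; SSF adds a scale/shift after the projection and,
# when a norm layer is present, another one after the norm.
def _patch_embed_example():
    embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96,
                       norm_layer=nn.LayerNorm, tuning_mode='ssf')
    tokens = embed(torch.randn(2, 3, 224, 224))
    assert tokens.shape == (2, 56 * 56, 96)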



def init_ssf_scale_shift(dim):
    scale = nn.Parameter(torch.ones(dim))
    shift = nn.Parameter(torch.zeros(dim))

    nn.init.normal_(scale, mean=1, std=.02)
    nn.init.normal_(shift, std=.02)

    return scale, shift


def ssf_ada(x, scale, shift):
    assert scale.shape == shift.shape
    if x.shape[-1] == scale.shape[0]:
        return x * scale + shift
    elif x.shape[1] == scale.shape[0]:
        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
    else:
        raise ValueError('the input tensor shape does not match the shape of the scale factor.')
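
# Editor's note: illustrative sketch, not part of the original repository.
# ssf_ada broadcasts the learned scale/shift over the channel dimension: the last
# dim for token layouts (B, N, C) and dim 1 for feature-map layouts (B, C, H, W).
def _ssf_ada_example():
    scale, shift = init_ssf_scale_shift(96)
    tokens = torch.randn(2, 196, 96)            # channel-last tokens
    assert ssf_ada(tokens, scale, shift).shape == tokens.shape
    fmap = torch.randn(2, 96, 14, 14)           # channel-first feature map
    assert ssf_ada(fmap, scale, shift).shape == fmap.shape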


class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, weight_init='', tuning_mode=None, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None, tuning_mode=tuning_mode)
        num_patches = self.patch_embed.num_patches
        self.patch_grid = self.patch_embed.grid_size

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        else:
            self.absolute_pos_embed = None

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        self.tuning_mode = tuning_mode
        tuning_mode_list = [[tuning_mode] * depths[i_layer] for i_layer in range(self.num_layers)]
        if tuning_mode == 'ssf':   
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(self.num_features)

        # build layers
        layers = []
        for i_layer in range(self.num_layers):
            layers += [BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint, 
                tuning_mode=tuning_mode_list[i_layer])
            ]
        self.layers = nn.Sequential(*layers)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)


    @torch.jit.ignore
    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        if self.absolute_pos_embed is not None:
            trunc_normal_(self.absolute_pos_embed, std=.02)
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        named_apply(get_init_weights_vit(mode, head_bias=head_bias), self)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.absolute_pos_embed is not None:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        x = self.layers(x)
        x = self.norm(x)

        if self.tuning_mode == 'ssf': 
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _create_swin_transformer(variant, pretrained=False, **kwargs):  
    model = build_model_with_cfg(
        SwinTransformer, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)

    return model



@register_model
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
    """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
    """
    model_kwargs = dict(
        patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
    return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)


@register_model
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
    """ Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k
    """
    model_kwargs = dict(
        patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
    return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)


@register_model
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
    """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k
    """
    model_kwargs = dict(
        patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
    return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)


@register_model
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
    """ Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k
    """
    model_kwargs = dict(
        patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
    return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)


@register_model
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
    """ Swin-S @ 224x224, trained ImageNet-1k
    """
    model_kwargs = dict(
        patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
    return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)


@register_model
def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
    """ Swin-T @ 224x224, trained ImageNet-1k
    """
    model_kwargs = dict(
        patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
    return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)


@register_model
def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
    """ Swin-B @ 384x384, trained ImageNet-22k
    """
    model_kwargs = dict(
        patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
    return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)


@register_model
def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
    """ Swin-B @ 224x224, trained ImageNet-22k
    """
    model_kwargs = dict(
        patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
    return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)


@register_model
def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
    """ Swin-L @ 384x384, trained ImageNet-22k
    """
    model_kwargs = dict(
        patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
    return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)


@register_model
def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
    """ Swin-L @ 224x224, trained ImageNet-22k
    """
    model_kwargs = dict(
        patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
    return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
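
# Editor's note: illustrative sketch, not part of the original repository. It
# assumes the timm helpers above resolve the registered default_cfg when an
# entry-point function is called directly; the helper name is hypothetical.
def _swin_ssf_creation_example():
    # Swin-T backbone with SSF scale/shift parameters and a fresh 100-way head.
    model = swin_tiny_patch4_window7_224(pretrained=False, tuning_mode='ssf', num_classes=100)
    logits = model(torch.randn(1, 3, 224, 224))
    assert logits.shape == (1, 100)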

================================================
FILE: models/vision_transformer.py
================================================
""" Vision Transformer (ViT) in PyTorch

A PyTorch implementation of Vision Transformers as described in:

'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
    - https://arxiv.org/abs/2010.11929

`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
    - https://arxiv.org/abs/2106.10270

The official jax code is released and available at https://github.com/google-research/vision_transformer

Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020, Ross Wightman
"""
import math
import logging
from functools import partial
from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv, resolve_pretrained_cfg, checkpoint_seq
from timm.models.layers import DropPath, trunc_normal_, lecun_normal_, _assert
from timm.models.layers.helpers import to_2tuple
from timm.models.registry import register_model





_logger = logging.getLogger(__name__)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # patch models (weights from official Google JAX impl)
    'vit_tiny_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_tiny_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_small_patch32_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_small_patch32_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_small_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_small_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch32_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_base_patch32_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_base_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch8_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_large_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
        ),
    'vit_large_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_large_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_large_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),



    # patch models, imagenet21k (weights from official Google JAX impl)
    'vit_tiny_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_small_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_base_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_large_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
        num_classes=21843),


}




class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., tuning_mode=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

            
        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':        
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)


    def forward(self, x):
        x = self.fc1(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)

        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x) 
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
        
        x = self.drop2(x)
        
        return x


        

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., tuning_mode=None):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)


        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)



    def forward(self, x):
        B, N, C = x.shape
        if self.tuning_mode == 'ssf':
            qkv = (ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1)).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        if self.tuning_mode == 'ssf':
            x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, tuning_mode=None):
        super().__init__()
        self.dim = dim
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, tuning_mode=tuning_mode)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop, tuning_mode=tuning_mode)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()


        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)
            self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)



    def forward(self, x):
        if self.tuning_mode == 'ssf':
            x = x + self.drop_path1(self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1))))
            x = x + self.drop_path2(self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2))))
        else:
            x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
            x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
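
# Editor's note: illustrative sketch, not part of the original repository.
# In the ViT block, SSF modulates the output of norm1 (before attention) and of
# norm2 (before the MLP); both residual branches keep their pretrained weights.
def _vit_block_ssf_example():
    blk = Block(dim=768, num_heads=12, tuning_mode='ssf')
    tokens = torch.randn(2, 197, 768)   # (B, 1 cls token + 196 patches, dim)
    assert blk(tokens).shape == tokens.shape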


class ResPostBlock(nn.Module):
    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.init_values = init_values

        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm1 = norm_layer(dim)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.norm2 = norm_layer(dim)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.init_weights()

    def init_weights(self):
        # NOTE this init overrides that base model init with specific changes for the block type
        if self.init_values is not None:
            nn.init.constant_(self.norm1.weight, self.init_values)
            nn.init.constant_(self.norm2.weight, self.init_values)

    def forward(self, x):
        x = x + self.drop_path1(self.norm1(self.attn(x)))
        x = x + self.drop_path2(self.norm2(self.mlp(x)))
        return x


class ParallelBlock(nn.Module):

    def __init__(
            self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_parallel = num_parallel
        self.attns = nn.ModuleList()
        self.ffns = nn.ModuleList()
        for _ in range(num_parallel):
            self.attns.append(nn.Sequential(OrderedDict([
                ('norm', norm_layer(dim)),
                ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
            ])))
            self.ffns.append(nn.Sequential(OrderedDict([
                ('norm', norm_layer(dim)),
                ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
            ])))

    def _forward_jit(self, x):
        x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
        x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
        return x

    @torch.jit.ignore
    def _forward(self, x):
        x = x + sum(attn(x) for attn in self.attns)
        x = x + sum(ffn(x) for ffn in self.ffns)
        return x

    def forward(self, x):
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return self._forward_jit(x)
        else:
            return self._forward(x)


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, tuning_mode=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.norm_layer = norm_layer

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        self.tuning_mode = tuning_mode
        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)

            if norm_layer:
                self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim)



    def forward(self, x):
        B, C, H, W = x.shape
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")

        x = self.proj(x) 
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.tuning_mode == 'ssf':  
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
            if self.norm_layer:
                x = ssf_ada(self.norm(x), self.ssf_scale_2, self.ssf_shift_2)
            else:
                x = self.norm(x)
        else:
            x = self.norm(x)
        return x



def init_ssf_scale_shift(dim):
    scale = nn.Parameter(torch.ones(dim))
    shift = nn.Parameter(torch.zeros(dim))

    nn.init.normal_(scale, mean=1, std=.02)
    nn.init.normal_(shift, std=.02)

    return scale, shift


def ssf_ada(x, scale, shift):
    assert scale.shape == shift.shape
    if x.shape[-1] == scale.shape[0]:
        return x * scale + shift
    elif x.shape[1] == scale.shape[0]:
        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
    else:
        raise ValueError('the input tensor shape does not match the shape of the scale factor.')


class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
            class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, tuning_mode=None): 
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'token')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            init_values (float): layer-scale init values
            class_token (bool): use class token
            fc_norm (Optional[bool]): apply a pre-classifier norm after pooling; if None, enabled when global_pool == 'avg' (default: None)
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            act_layer (nn.Module): MLP activation layer
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'token')
        assert class_token or global_pool != 'token'
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1 if class_token else 0
        self.grad_checkpointing = False 

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        
        self.tuning_mode = tuning_mode
        tuning_mode_list = [tuning_mode] * depth 
        if tuning_mode == 'ssf':     
            self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(self.num_features)

        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, tuning_mode=tuning_mode_list[i])
            for i in range(depth)])


        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier Head
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(get_init_weights_vit(mode, head_bias), self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token')
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()


    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)

        x = self.norm(x)
        if self.tuning_mode == 'ssf': 
            x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
            
        return x 

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        
        return x 
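
# Editor's note: illustrative sketch, not part of the original repository. One
# plausible way to restrict training to the SSF parameters and the classifier
# head; the repository's training script may implement this selection differently.
def _vit_ssf_param_budget_example():
    vit = VisionTransformer(embed_dim=192, depth=12, num_heads=3,
                            num_classes=100, tuning_mode='ssf')   # ViT-Tiny sized
    for name, p in vit.named_parameters():
        p.requires_grad = 'ssf_scale' in name or 'ssf_shift' in name or name.startswith('head.')
    trainable = sum(p.numel() for p in vit.parameters() if p.requires_grad)
    total = sum(p.numel() for p in vit.parameters())
    print(f'trainable params: {trainable} / {total}')  # SSF tunes only a small fraction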


def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """ ViT weight initialization, original timm impl (for reproducibility) """
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()


def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
    """ ViT weight initialization, matching JAX (Flax) impl """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        else:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()


def init_weights_vit_moco(module: nn.Module, name: str = ''):
    """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
    if isinstance(module, nn.Linear):
        if 'qkv' in name:
            # treat the weights of Q, K, V separately
            val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
            nn.init.uniform_(module.weight, -val, val)
        else:
            nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()


def get_init_weights_vit(mode='jax', head_bias: float = 0.):
    if 'jax' in mode:
        return partial(init_weights_vit_jax, head_bias=head_bias)
    elif 'moco' in mode:
        return init_weights_vit_moco
    else:
        return init_weights_vit_timm


@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
    """
    import numpy as np

    def _n2p(w, t=True):
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    if not prefix and 'opt/target/embedding/kernel' in w:
        prefix = 'opt/target/'

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid
        backbone = model.patch_embed.backbone
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        embed_conv_w = adapt_input_conv(
            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
    # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
    #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
    #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))

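# Usage sketch (not part of the original file): _load_weights is normally reached via
# VisionTransformer.load_pretrained, e.g.
#   model.load_pretrained('ViT-B_16.npz')
# which maps the JAX/Flax .npz arrays handled above onto the PyTorch module's tensors.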

def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
    gs_old = int(math.sqrt(len(posemb_grid)))
    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(ntok_new))] * 2
    assert len(gs_new) >= 2
    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb
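
# Usage sketch (not part of the original file): interpolating a ViT-B/16 pos_embed
# trained at 224px, shape (1, 14*14 + 1, 768), to a 384px grid, shape (1, 24*24 + 1, 768).
#   old = torch.zeros(1, 197, 768)
#   new = torch.zeros(1, 577, 768)
#   resized = resize_pos_embed(old, new, num_tokens=1, gs_new=(24, 24))
#   assert resized.shape == new.shape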


def checkpoint_filter_fn(state_dict, model):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    if 'model' in state_dict:
        # For deit models
        state_dict = state_dict['model']
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
            # For old models that I trained prior to conv based patchification
            O, I, H, W = model.patch_embed.proj.weight.shape
            v = v.reshape(O, -1, H, W)
        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(
                v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
        elif 'pre_logits' in k:
            # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
            continue
        out_dict[k] = v
    return out_dict
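
# Shape sketch for the conversion above (illustrative, not part of the original file):
# an old checkpoint may store patch_embed.proj.weight as a (768, 768) linear matrix over
# flattened 3x16x16 patches; v.reshape(O, -1, H, W) restores the conv form (768, 3, 16, 16)
# expected by the current patch embedding.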


def _create_vision_transformer(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
    model = build_model_with_cfg(
        VisionTransformer, variant, pretrained,
        pretrained_cfg=pretrained_cfg,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load='npz' in pretrained_cfg['url'],
        **kwargs)
    return model



@register_model
def vit_tiny_patch16_224(pretrained=False, **kwargs):
    """ ViT-Tiny (Vit-Ti/16)
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_tiny_patch16_384(pretrained=False, **kwargs):
    """ ViT-Tiny (Vit-Ti/16) @ 384x384.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
    return model




@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16)
    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch16_384(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16)
    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
    return model




@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model



@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
    return model



@register_model
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Tiny (Vit-Ti/16).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16)
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model

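# Usage sketch (assumption, not part of the original file): the variants registered above
# are built through timm's registry; if this VisionTransformer accepts a `tuning_mode`
# kwarg (as the SSF training scripts suggest), an SSF-enabled backbone could be created as:
#   import timm
#   model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True,
#                             num_classes=100, tuning_mode='ssf')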


================================================
FILE: optim_factory.py
================================================
""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2021 Ross Wightman
"""
import collections.abc
import json
import re
from collections import defaultdict
from itertools import chain, islice  # re/chain/defaultdict/collections.abc are used by group_with_matcher below
from typing import Optional, Callable, Tuple, Dict, Union


import torch
import torch.nn as nn
import torch.optim as optim

#from timm.models.helpers import group_parameters

from timm.optim.adabelief import AdaBelief
from timm.optim.adafactor import Adafactor
from timm.optim.adahessian import Adahessian
from timm.optim.adamp import AdamP
from timm.optim.lamb import Lamb
from timm.optim.lars import Lars 
from timm.optim.lookahead import Lookahead
from timm.optim.madgrad import MADGRAD
from timm.optim.nadam import Nadam 
from timm.optim.nvnovograd import NvNovoGrad
from timm.optim.radam import RAdam
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.sgdp import SGDP



try:
    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
    has_apex = True
except ImportError:
    has_apex = False



def param_groups_weight_decay(
        model: nn.Module,
        weight_decay=1e-5,
        no_weight_decay_list=()
):
    no_weight_decay_list = set(no_weight_decay_list)
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
            no_decay.append(param)
        else:
            decay.append(param)

    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
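
# Usage sketch (not part of the original file): split parameters into a zero-decay group
# (1-D tensors, biases, names returned by model.no_weight_decay()) and a decayed group,
# then pass both groups to the optimizer.
#   groups = param_groups_weight_decay(model, weight_decay=5e-2,
#                                      no_weight_decay_list=model.no_weight_decay())
#   optimizer = torch.optim.AdamW(groups, lr=1e-3)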


def _group(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def _layer_map(model, layers_per_group=12, num_groups=None):
    def _in_head(n, hp):
        if not hp:
            return True
        elif isinstance(hp, (tuple, list)):
            return any([n.startswith(hpi) for hpi in hp])
        else:
            return n.startswith(hp)

    head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
    names_trunk = []
    names_head = []
    for n, _ in model.named_parameters():
        names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)

    # group non-head layers
    num_trunk_layers = len(names_trunk)
    if num_groups is not None:
        layers_per_group = -(num_trunk_layers // -num_groups)
    names_trunk = list(_group(names_trunk, layers_per_group))

    num_trunk_groups = len(names_trunk)
    layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
    layer_map.update({n: num_trunk_groups for n in names_head})
    return layer_map
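
# Example of the grouping above (not part of the original file): with 150 trunk parameter
# names and num_groups=12, layers_per_group = -(150 // -12) = 13 (ceil division), giving
# 12 ordered trunk groups; every classifier-head parameter is then mapped to the final
# ordinal (num_trunk_groups).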

# Sentinel ordinal used by group_with_matcher: a match tagged with this value is
# merged into the previous layer group (same value as timm.models.helpers).
MATCH_PREV_GROUP = (99999,)

def group_with_matcher(
        named_objects,
        group_matcher: Union[Dict, Callable],
        output_values: bool = False,
        reverse: bool = False
):
    if isinstance(group_matcher, dict):
        # dictionary matcher contains a dict of raw-string regex expr that must be compiled
        compiled = []
        for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
            if mspec is None:
                continue
            # map all matching specifications into 3-tuple (compiled re, prefix, suffix)
            if isinstance(mspec, (tuple, list)):
                # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
                for sspec in mspec:
                    compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
            else:
                compiled += [(re.compile(mspec), (group_ordinal,), None)]
        group_matcher = compiled

    def _get_grouping(name):
        if isinstance(group_matcher, (list, tuple)):
            for match_fn, prefix, suffix in group_matcher:
                r = match_fn.match(name)
                if r:
                    parts = (prefix, r.groups(), suffix)
                    # map all tuple elem to int for numeric sort, filter out None entries
                    return tuple(map(float, chain.from_iterable(filter(None, parts))))
            return float('inf'),  # un-matched layers (neck, head) mapped to largest ordinal
        else:
            ord = group_matcher(name)
            if not isinstance(ord, collections.abc.Iterable):
                return ord,
            return tuple(ord)

    # map layers into groups via ordinals (ints or tuples of ints) from matcher
    grouping = defaultdict(list)
    for k, v in named_objects:
        grouping[_get_grouping(k)].append(v if output_values else k)

    # remap to integers
    layer_id_to_param = defaultdict(list)
    lid = -1
    for k in sorted(filter(lambda x: x is not None, grouping.keys())):
        if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
            lid += 1
        layer_id_to_param[lid].extend(grouping[k])

    if reverse:
        assert not output_values, "reverse mapping only sensible for name output"
        # output reverse mapping
        param_to_layer_id = {}
        for lid, lm in layer_id_to_param.items():
            for n in lm:
                param_to_layer_id[n] = lid
        return param_to_layer_id

    return layer_id_to_param
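
# Sketch with a hypothetical matcher (not from this repo): each dict value is a raw regex,
# or a list of (regex, suffix) pairs; names matching earlier entries receive smaller
# ordinals and unmatched names (e.g. the head) fall into the last group.
#   matcher = {'stem': r'^patch_embed|^cls_token|^pos_embed',
#              'blocks': [(r'^blocks\.(\d+)', None)]}
#   name_to_layer_id = group_with_matcher(model.named_parameters(), matcher, reverse=True)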


def group_parameters(
        module: nn.Module,
        group_matcher,
        output_values=False,
        reverse=False,
):
    return group_with_matcher(
        module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)


def group_modules(
        module: nn.Module,
        group_matcher,
        output_values=False,
        reverse=False,
):
    return group_with_matcher(
        named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)




def param_groups_layer_decay(
        model: nn.Module,
        weight_decay: float = 0.05,
        no_weight_decay_list: Tuple[str] = (),
        layer_decay: float = .75,
        end_layer_decay: Optional[float] = None,
):
    """
    Parameter groups for layer-wise lr decay & weight decay
    Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    no_weight_decay_list = set(no_weight_decay_list)
    param_group_names = {}  # NOTE for debugging
    param_groups = {}

    if hasattr(model, 'group_matcher'):
        # FIXME interface needs more work
        layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
    else:
        # fallback
        layer_map = _layer_map(model)
    num_layers = max(layer_map.values()) + 1
    layer_max = num_layers - 1
    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if param.ndim == 1 or name in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = layer_map.get(name, layer_max)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_groups:
            this_scale = layer_scales[layer_id]
            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "param_names": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["param_names"].append(name)
        param_groups[group_name]["params"].append(param)

    # FIXME temporary output to debug new feature
    print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())
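
# Scaling sketch (not part of the original file): with layer_decay=0.75 and num_layers = N,
# the group mapped to layer_id i gets lr_scale = 0.75 ** (N - 1 - i), so the earliest layers
# receive the smallest effective learning rates; the 'lr_scale' entry is read by timm's LR
# scheduler when it updates each param group's learning rate.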


def optimizer_kwargs(cfg):
    """ cfg/argparse to kwargs helper
    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
    """
    kwargs = dict(
        opt=cfg.opt,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        momentum=cfg.momentum,
        tuning_mode=cfg.tuning_mode)
    if getattr(cfg, 'opt_eps', None) is not None:
        kwargs['eps'] = cfg.opt_eps
    if getattr(cfg, 'opt_betas', None) is not None:
        kwargs['betas'] = cfg.opt_betas
    if getattr(cfg, 'layer_decay', None) is not None:
        kwargs['layer_decay'] = cfg.layer_decay
    if getattr(cfg, 'opt_args', None) is not None:
        kwargs.update(cfg.opt_args)
    return kwargs


def create_optimizer(args, model, filter_bias_and_bn=True):
    """ Legacy optimizer factory for backwards compatibility.
    NOTE: Use create_optimizer_v2 for new code.
    """
    return create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        filter_bias_and_bn=filter_bias_and_bn,
    )


def create_optimizer_v2(
        model_or_params,
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        tuning_mode: Optional[str] = None,
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
        param_group_fn: Optional[Callable] = None,
        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model_or_params (nn.Module): model containing parameters to optimize
        opt: name of optimizer to create
        lr: initial learning rate
        weight_decay: weight decay to apply in optimizer
        momentum: momentum for momentum based optimizers (others may use betas via kwargs)
        tuning_mode: parameter-efficient tuning mode; 'linear_probe' freezes all but the head,
            'ssf' additionally keeps the ssf_scale_*/ssf_shift_* parameters trainable
        filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
        layer_decay: layer-wise learning rate decay factor (uses param_groups_layer_decay)
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer
    """
    if isinstance(model_or_params, nn.Module):
        # parameter-efficient fine-tuning: freeze everything except the tuned parameters
        if tuning_mode:
            for name, param in model_or_params.named_parameters():
                if tuning_mode == 'linear_probe':
                    if "head." not in name:
                        param.requires_grad = False
                elif tuning_mode == 'ssf':
                    if "head." not in name and "ssf_scale" not in name and "ssf_shift_" not in name:
                        param.requires_grad = False

                if param.requires_grad:
                    print(name)

            print('freezing parameters finished!')

        # a model was passed in, extract parameters and add weight decays to appropriate layers
        no_weight_decay = {}
        if hasattr(model_or_params, 'no_weight_decay'):
            no_weight_decay = model_or_params.no_weight_decay()

        if param_group_fn:
            parameters = param_group_fn(model_or_params)
        elif layer_decay is not None:
            parameters = param_groups_layer_decay(
                model_or_params,
                weight_decay=weight_decay,
                layer_decay=layer_decay,
                no_weight_decay_list=no_weight_decay)
            weight_decay = 0.
        elif weight_decay and filter_bias_and_bn:
            parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()





    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params



    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)
    if lr is not None:
        opt_args.setdefault('lr', lr)

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)

    # adaptive
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args) 
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'nadam':
        try:
            # NOTE PyTorch >= 1.10 has a native NAdam implementation
            optimizer = optim.NAdam(parameters, **opt_args)
        except AttributeError:
            optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamax':
        optimizer = optim.Adamax(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'radabelief':
        optimizer = AdaBelief(parameters, rectify=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adagrad':
        opt_args.setdefault('eps', 1e-8)
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
    elif opt_lower == 'larc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
    elif opt_lower == 'lars':
        optimizer = Lars(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'nlarc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
    elif opt_lower == 'nlars':
        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)

    # second order
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)

    # NVIDIA fused optimizers, require APEX to be installed
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)

    else:
        assert False, "Invalid optimizer"
        raise ValueError

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
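
# Usage sketch (illustrative values, not part of the original file): with tuning_mode='ssf'
# only the classifier head and the ssf_scale_*/ssf_shift_* tensors stay trainable.
#   optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3,
#                                   weight_decay=5e-2, tuning_mode='ssf')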


================================================
FILE: requirements.txt
================================================
pyyaml
scipy
pandas
ipdb

================================================
FILE: train.py
================================================
#!/usr/bin/env python3
""" ImageNet Training Script

This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.

This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import time
import yaml
import os
import logging
import numpy as np
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime

import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP

from timm.data import resolve_data_config, Mixup, FastCollateMixup, AugMixDataset


from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
    convert_splitbn_model, model_parameters
from timm.utils import *
from timm.loss import *

from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler


from data import create_loader, create_dataset
from optim_factory import create_optimizer_v2, optimizer_kwargs

from models import vision_transformer, swin_transformer, convnext, as_mlp


import ipdb

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    import wandb
    has_wandb = True
except ImportError: 
    has_wandb = False

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')

# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')
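# Illustrative YAML (assumption, not part of the original file): a file passed via
# `-c cfg.yaml` supplies key-value pairs that override the argparse defaults declared
# below, e.g.
#   model: vit_base_patch16_224_in21k
#   tuning_mode: ssf
#   batch_size: 128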


parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Dataset parameters
parser.add_argument('data_dir', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--train-split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
                    help='dataset validation split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')

# Model parameters
parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet50")')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
                    help='number of label classes (Model default if None)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='Input batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                    help='Validation batch size override (default: None)')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='torch.jit.script the full model')
parser.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')


# finetuning
parser.add_argument('--tuning-mode', default=None, type=str,
                    help='Method of fine-tuning (default: None)')


parser.add_argument('--evaluate', action='store_true', default=False,
                    help='evaluate')


# Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
SYMBOL INDEX (298 symbols across 17 files)

FILE: data/cub2011.py
  class Cub2011 (line 9) | class Cub2011(VisionDataset):
    method __init__ (line 29) | def __init__(self, root, train=True, transform=None, target_transform=...
    method _load_metadata (line 40) | def _load_metadata(self):
    method _check_integrity (line 59) | def _check_integrity(self):
    method _download (line 72) | def _download(self):
    method __len__ (line 84) | def __len__(self):
    method __getitem__ (line 87) | def __getitem__(self, idx):

FILE: data/dataset_factory.py
  function _search_split (line 48) | def _search_split(root, split):
  function create_dataset (line 68) | def create_dataset(

FILE: data/loader.py
  function fast_collate (line 23) | def fast_collate(batch):
  function expand_to_chs (line 58) | def expand_to_chs(x, n):
  class PrefetchLoader (line 68) | class PrefetchLoader:
    method __init__ (line 70) | def __init__(
    method __iter__ (line 99) | def __iter__(self):
    method __len__ (line 125) | def __len__(self):
    method sampler (line 129) | def sampler(self):
    method dataset (line 133) | def dataset(self):
    method mixup_enabled (line 137) | def mixup_enabled(self):
    method mixup_enabled (line 144) | def mixup_enabled(self, x):
  function _worker_init (line 149) | def _worker_init(worker_id, worker_seeding='all'):
  function create_loader (line 165) | def create_loader(
  class MultiEpochsDataLoader (line 283) | class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    method __init__ (line 285) | def __init__(self, *args, **kwargs):
    method __len__ (line 292) | def __len__(self):
    method __iter__ (line 295) | def __iter__(self):
  class _RepeatSampler (line 300) | class _RepeatSampler(object):
    method __init__ (line 307) | def __init__(self, sampler):
    method __iter__ (line 310) | def __iter__(self):

FILE: data/nabirds.py
  class NABirds (line 19) | class NABirds(Dataset):
    method __init__ (line 35) | def __init__(self, root, train=True, transform=None):
    method __len__ (line 62) | def __len__(self):
    method __getitem__ (line 65) | def __getitem__(self, idx):
  function get_continuous_class_map (line 75) | def get_continuous_class_map(class_labels):
  function load_class_names (line 79) | def load_class_names(dataset_path=''):
  function load_hierarchy (line 90) | def load_hierarchy(dataset_path=''):

FILE: data/stanford_dogs.py
  class dogs (line 12) | class dogs(data.Dataset):
    method __init__ (line 31) | def __init__(self,
    method __len__ (line 190) | def __len__(self):
    method __getitem__ (line 193) | def __getitem__(self, index):
    method download (line 215) | def download(self):
    method get_boxes (line 233) | def get_boxes(path):
    method load_split (line 244) | def load_split(self):
    method stats (line 256) | def stats(self):

FILE: data/transforms_factory.py
  function transforms_direct_resize (line 17) | def transforms_direct_resize(
  function transforms_simpleaug_train (line 43) | def transforms_simpleaug_train(
  function transforms_imagenet_train (line 72) | def transforms_imagenet_train(
  function transforms_imagenet_eval (line 158) | def transforms_imagenet_eval(
  function create_transform (line 199) | def create_transform(

FILE: data/vtab.py
  class VTAB (line 4) | class VTAB(ImageFolder):
    method __init__ (line 5) | def __init__(self, root, train=True, transform=None, target_transform=...

FILE: models/as_mlp.py
  function _cfg (line 30) | def _cfg(url='', file='', **kwargs):
  class Mlp (line 48) | class Mlp(nn.Module):
    method __init__ (line 49) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 65) | def forward(self, x):
  class AxialShift (line 82) | class AxialShift(nn.Module):
    method __init__ (line 92) | def __init__(self, dim, shift_size, as_bias=True, proj_drop=0., tuning...
    method forward (line 117) | def forward(self, x):
    method extra_repr (line 166) | def extra_repr(self) -> str:
    method flops (line 169) | def flops(self, N):
  class AxialShiftedBlock (line 187) | class AxialShiftedBlock(nn.Module):
    method __init__ (line 202) | def __init__(self, dim, input_resolution, shift_size=7,
    method forward (line 227) | def forward(self, x):
    method extra_repr (line 248) | def extra_repr(self) -> str:
    method flops (line 252) | def flops(self):
  class PatchMerging (line 266) | class PatchMerging(nn.Module):
    method __init__ (line 275) | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, tun...
    method forward (line 286) | def forward(self, x):
    method extra_repr (line 310) | def extra_repr(self) -> str:
    method flops (line 313) | def flops(self):
  class BasicLayer (line 320) | class BasicLayer(nn.Module):
    method __init__ (line 340) | def __init__(self, dim, input_resolution, depth, shift_size,
    method forward (line 367) | def forward(self, x):
    method extra_repr (line 377) | def extra_repr(self) -> str:
    method flops (line 380) | def flops(self):
  class PatchEmbed (line 389) | class PatchEmbed(nn.Module):
    method __init__ (line 400) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
    method forward (line 429) | def forward(self, x):
    method flops (line 447) | def flops(self):
  function MyNorm (line 455) | def MyNorm(dim):
  function init_ssf_scale_shift (line 459) | def init_ssf_scale_shift(dim):
  function ssf_ada (line 469) | def ssf_ada(x, scale, shift):
  class AS_MLP (line 479) | class AS_MLP(nn.Module):
    method __init__ (line 501) | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes...
    method no_weight_decay (line 559) | def no_weight_decay(self):
    method no_weight_decay_keywords (line 563) | def no_weight_decay_keywords(self):
    method get_classifier (line 566) | def get_classifier(self):
    method reset_classifier (line 569) | def reset_classifier(self, num_classes, global_pool=''):
    method forward_features (line 574) | def forward_features(self, x):
    method forward (line 589) | def forward(self, x):
    method flops (line 594) | def flops(self):
  function _create_as_mlp (line 604) | def _create_as_mlp(variant, pretrained=False, **kwargs):
  function as_base_patch4_window7_224 (line 614) | def as_base_patch4_window7_224(pretrained=False, **kwargs):

FILE: models/convnext.py
  function _cfg (line 32) | def _cfg(url='', **kwargs):
  function _is_contiguous (line 78) | def _is_contiguous(tensor: torch.Tensor) -> bool:
  class LayerNorm2d (line 90) | class LayerNorm2d(nn.LayerNorm):
    method __init__ (line 94) | def __init__(self, normalized_shape, eps=1e-6):
    method forward (line 97) | def forward(self, x) -> torch.Tensor:
  class Mlp (line 110) | class Mlp(nn.Module):
    method __init__ (line 113) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 132) | def forward(self, x):
  class ConvNeXtBlock (line 149) | class ConvNeXtBlock(nn.Module):
    method __init__ (line 165) | def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=Fal...
    method forward (line 187) | def forward(self, x):
  class Downsample (line 211) | class Downsample(nn.Module):
    method __init__ (line 214) | def __init__(self, dim, out_dim, kernel_size, stride, norm_layer=None,...
    method forward (line 226) | def forward(self, x):
  class ConvNeXtStage (line 238) | class ConvNeXtStage(nn.Module):
    method __init__ (line 240) | def __init__(
    method forward (line 258) | def forward(self, x):
  class PatchEmbed (line 270) | class PatchEmbed(nn.Module):
    method __init__ (line 273) | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=...
    method forward (line 284) | def forward(self, x):
  function init_ssf_scale_shift (line 295) | def init_ssf_scale_shift(dim):
  function ssf_ada (line 305) | def ssf_ada(x, scale, shift):
  class ConvNeXt (line 315) | class ConvNeXt(nn.Module):
    method __init__ (line 330) | def __init__(
    method set_grad_checkpointing (line 398) | def set_grad_checkpointing(self, enable=True):
    method get_classifier (line 402) | def get_classifier(self):
    method reset_classifier (line 405) | def reset_classifier(self, num_classes=0, global_pool='avg'):
    method forward_features (line 420) | def forward_features(self, x):
    method forward (line 430) | def forward(self, x):
  function _init_weights (line 436) | def _init_weights(module, name=None, head_init_scale=1.0):
  function checkpoint_filter_fn (line 448) | def checkpoint_filter_fn(state_dict, model):
  function _create_convnext (line 478) | def _create_convnext(variant, pretrained=False, **kwargs):
  function convnext_tiny (line 489) | def convnext_tiny(pretrained=False, **kwargs):
  function convnext_tiny_hnf (line 496) | def convnext_tiny_hnf(pretrained=False, **kwargs):
  function convnext_small (line 503) | def convnext_small(pretrained=False, **kwargs):
  function convnext_base (line 510) | def convnext_base(pretrained=False, **kwargs):
  function convnext_large (line 517) | def convnext_large(pretrained=False, **kwargs):
  function convnext_base_in22ft1k (line 524) | def convnext_base_in22ft1k(pretrained=False, **kwargs):
  function convnext_large_in22ft1k (line 531) | def convnext_large_in22ft1k(pretrained=False, **kwargs):
  function convnext_xlarge_in22ft1k (line 538) | def convnext_xlarge_in22ft1k(pretrained=False, **kwargs):
  function convnext_base_384_in22ft1k (line 545) | def convnext_base_384_in22ft1k(pretrained=False, **kwargs):
  function convnext_large_384_in22ft1k (line 552) | def convnext_large_384_in22ft1k(pretrained=False, **kwargs):
  function convnext_xlarge_384_in22ft1k (line 559) | def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs):
  function convnext_base_in22k (line 566) | def convnext_base_in22k(pretrained=False, **kwargs):
  function convnext_large_in22k (line 573) | def convnext_large_in22k(pretrained=False, **kwargs):
  function convnext_xlarge_in22k (line 580) | def convnext_xlarge_in22k(pretrained=False, **kwargs):

FILE: models/swin_transformer.py
  function _cfg (line 37) | def _cfg(url='', **kwargs):
  class Mlp (line 94) | class Mlp(nn.Module):
    method __init__ (line 97) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 117) | def forward(self, x):
  function window_partition (line 133) | def window_partition(x, window_size: int):
  function window_reverse (line 148) | def window_reverse(windows, window_size: int, H: int, W: int):
  class WindowAttention (line 164) | class WindowAttention(nn.Module):
    method __init__ (line 176) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_dr...
    method forward (line 218) | def forward(self, x, mask: Optional[torch.Tensor] = None):
  class SwinTransformerBlock (line 261) | class SwinTransformerBlock(nn.Module):
    method __init__ (line 278) | def __init__(self, dim, input_resolution, num_heads, window_size=7, sh...
    method forward (line 336) | def forward(self, x):
  class PatchMerging (line 382) | class PatchMerging(nn.Module):
    method __init__ (line 390) | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, tun...
    method forward (line 401) | def forward(self, x):
    method extra_repr (line 428) | def extra_repr(self) -> str:
    method flops (line 431) | def flops(self):
  class BasicLayer (line 438) | class BasicLayer(nn.Module):
    method __init__ (line 456) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
    method forward (line 481) | def forward(self, x):
    method extra_repr (line 491) | def extra_repr(self) -> str:
  class PatchEmbed (line 495) | class PatchEmbed(nn.Module):
    method __init__ (line 498) | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=...
    method forward (line 520) | def forward(self, x):
  function init_ssf_scale_shift (line 540) | def init_ssf_scale_shift(dim):
  function ssf_ada (line 550) | def ssf_ada(x, scale, shift):
  class SwinTransformer (line 560) | class SwinTransformer(nn.Module):
    method __init__ (line 584) | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes...
    method init_weights (line 653) | def init_weights(self, mode=''):
    method no_weight_decay (line 661) | def no_weight_decay(self):
    method no_weight_decay_keywords (line 665) | def no_weight_decay_keywords(self):
    method get_classifier (line 668) | def get_classifier(self):
    method reset_classifier (line 671) | def reset_classifier(self, num_classes, global_pool=''):
    method forward_features (line 675) | def forward_features(self, x):
    method forward (line 690) | def forward(self, x):
  function _create_swin_transformer (line 696) | def _create_swin_transformer(variant, pretrained=False, **kwargs):
  function swin_base_patch4_window12_384 (line 707) | def swin_base_patch4_window12_384(pretrained=False, **kwargs):
  function swin_base_patch4_window7_224 (line 716) | def swin_base_patch4_window7_224(pretrained=False, **kwargs):
  function swin_large_patch4_window12_384 (line 725) | def swin_large_patch4_window12_384(pretrained=False, **kwargs):
  function swin_large_patch4_window7_224 (line 734) | def swin_large_patch4_window7_224(pretrained=False, **kwargs):
  function swin_small_patch4_window7_224 (line 743) | def swin_small_patch4_window7_224(pretrained=False, **kwargs):
  function swin_tiny_patch4_window7_224 (line 752) | def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
  function swin_base_patch4_window12_384_in22k (line 761) | def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
  function swin_base_patch4_window7_224_in22k (line 770) | def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
  function swin_large_patch4_window12_384_in22k (line 779) | def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
  function swin_large_patch4_window7_224_in22k (line 788) | def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):

FILE: models/vision_transformer.py
  function _cfg (line 47) | def _cfg(url='', **kwargs):
  class Mlp (line 134) | class Mlp(nn.Module):
    method __init__ (line 137) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 157) | def forward(self, x):
  class Attention (line 175) | class Attention(nn.Module):
    method __init__ (line 176) | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., pro...
    method forward (line 196) | def forward(self, x):
  class LayerScale (line 216) | class LayerScale(nn.Module):
    method __init__ (line 217) | def __init__(self, dim, init_values=1e-5, inplace=False):
    method forward (line 222) | def forward(self, x):
  class Block (line 226) | class Block(nn.Module):
    method __init__ (line 228) | def __init__(
    method forward (line 252) | def forward(self, x):
  class ResPostBlock (line 262) | class ResPostBlock(nn.Module):
    method __init__ (line 263) | def __init__(
    method init_weights (line 279) | def init_weights(self):
    method forward (line 285) | def forward(self, x):
  class ParallelBlock (line 291) | class ParallelBlock(nn.Module):
    method __init__ (line 293) | def __init__(
    method _forward_jit (line 314) | def _forward_jit(self, x):
    method _forward (line 320) | def _forward(self, x):
    method forward (line 325) | def forward(self, x):
  class PatchEmbed (line 332) | class PatchEmbed(nn.Module):
    method __init__ (line 335) | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=...
    method forward (line 358) | def forward(self, x):
  function init_ssf_scale_shift (line 378) | def init_ssf_scale_shift(dim):
  function ssf_ada (line 388) | def ssf_ada(x, scale, shift):
  class VisionTransformer (line 398) | class VisionTransformer(nn.Module):
    method __init__ (line 405) | def __init__(
    method init_weights (line 477) | def init_weights(self, mode=''):
    method _init_weights (line 485) | def _init_weights(self, m):
    method load_pretrained (line 490) | def load_pretrained(self, checkpoint_path, prefix=''):
    method no_weight_decay (line 494) | def no_weight_decay(self):
    method group_matcher (line 498) | def group_matcher(self, coarse=False):
    method set_grad_checkpointing (line 505) | def set_grad_checkpointing(self, enable=True):
    method get_classifier (line 509) | def get_classifier(self):
    method reset_classifier (line 512) | def reset_classifier(self, num_classes: int, global_pool=None):
    method forward_features (line 520) | def forward_features(self, x):
    method forward_head (line 537) | def forward_head(self, x, pre_logits: bool = False):
    method forward (line 543) | def forward(self, x):
  function init_weights_vit_timm (line 550) | def init_weights_vit_timm(module: nn.Module, name: str = ''):
  function init_weights_vit_jax (line 560) | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: f...
  function init_weights_vit_moco (line 578) | def init_weights_vit_moco(module: nn.Module, name: str = ''):
  function get_init_weights_vit (line 593) | def get_init_weights_vit(mode='jax', head_bias: float = 0.):
  function _load_weights (line 603) | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix...
  function resize_pos_embed (line 683) | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
  function checkpoint_filter_fn (line 705) | def checkpoint_filter_fn(state_dict, model):
  function _create_vision_transformer (line 727) | def _create_vision_transformer(variant, pretrained=False, **kwargs):
  function vit_tiny_patch16_224 (line 743) | def vit_tiny_patch16_224(pretrained=False, **kwargs):
  function vit_tiny_patch16_384 (line 752) | def vit_tiny_patch16_384(pretrained=False, **kwargs):
  function vit_small_patch16_224 (line 763) | def vit_small_patch16_224(pretrained=False, **kwargs):
  function vit_small_patch16_384 (line 773) | def vit_small_patch16_384(pretrained=False, **kwargs):
  function vit_base_patch16_224 (line 785) | def vit_base_patch16_224(pretrained=False, **kwargs):
  function vit_base_patch16_384 (line 795) | def vit_base_patch16_384(pretrained=False, **kwargs):
  function vit_large_patch16_224 (line 806) | def vit_large_patch16_224(pretrained=False, **kwargs):
  function vit_large_patch16_384 (line 816) | def vit_large_patch16_384(pretrained=False, **kwargs):
  function vit_tiny_patch16_224_in21k (line 827) | def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
  function vit_small_patch16_224_in21k (line 838) | def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
  function vit_base_patch16_224_in21k (line 849) | def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
  function vit_large_patch16_224_in21k (line 860) | def vit_large_patch16_224_in21k(pretrained=False, **kwargs):

FILE: optim_factory.py
  function param_groups_weight_decay (line 39) | def param_groups_weight_decay(
  function _group (line 61) | def _group(it, size):
  function _layer_map (line 66) | def _layer_map(model, layers_per_group=12, num_groups=None):
  function group_with_matcher (line 94) | def group_with_matcher(
  function group_parameters (line 155) | def group_parameters(
  function group_modules (line 165) | def group_modules(
  function param_groups_layer_decay (line 177) | def param_groups_layer_decay(
  function optimizer_kwargs (line 239) | def optimizer_kwargs(cfg):
  function create_optimizer (line 260) | def create_optimizer(args, model, filter_bias_and_bn=True):
  function create_optimizer_v2 (line 271) | def create_optimizer_v2(

FILE: train.py
  function _parse_args (line 335) | def _parse_args():
  function main (line 352) | def main():
  function train_one_epoch (line 727) | def train_one_epoch(
  function validate (line 835) | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_su...

FILE: utils/mce_utils.py
  function get_ce_alexnet (line 43) | def get_ce_alexnet():
  function get_mce_from_accuracy (line 65) | def get_mce_from_accuracy(accuracy, error_alexnet):
  class SmoothedValue (line 72) | class SmoothedValue(object):
    method __init__ (line 77) | def __init__(self, window_size=20, fmt=None):
    method update (line 85) | def update(self, value, n=1):
    method synchronize_between_processes (line 90) | def synchronize_between_processes(self):
    method median (line 104) | def median(self):
    method avg (line 109) | def avg(self):
    method global_avg (line 114) | def global_avg(self):
    method max (line 118) | def max(self):
    method value (line 122) | def value(self):
    method __str__ (line 125) | def __str__(self):
  class MetricLogger (line 134) | class MetricLogger(object):
    method __init__ (line 135) | def __init__(self, delimiter="\t"):
    method update (line 139) | def update(self, **kwargs):
    method __getattr__ (line 146) | def __getattr__(self, attr):
    method __str__ (line 154) | def __str__(self):
    method synchronize_between_processes (line 162) | def synchronize_between_processes(self):
    method add_meter (line 166) | def add_meter(self, name, meter):
    method log_every (line 169) | def log_every(self, iterable, print_freq, header=None):
  function _load_checkpoint_for_ema (line 216) | def _load_checkpoint_for_ema(model_ema, checkpoint):
  function setup_for_distributed (line 226) | def setup_for_distributed(is_master):
  function is_dist_avail_and_initialized (line 241) | def is_dist_avail_and_initialized():
  function get_world_size (line 249) | def get_world_size():
  function get_rank (line 255) | def get_rank():
  function is_main_process (line 261) | def is_main_process():
  function save_on_master (line 265) | def save_on_master(*args, **kwargs):
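  get_ce_alexnet and get_mce_from_accuracy implement the ImageNet-C protocol: the model's top-1 error on each corruption is normalised by AlexNet's error on the same corruption, and the normalised errors are averaged into the mean Corruption Error (mCE). A hedged sketch of that computation (function names and the fraction-vs-percentage convention are assumptions, not copied from the file):

      def corruption_error(accuracy, alexnet_error):
          # CE for one corruption: the model's error divided by AlexNet's error,
          # with both expressed as fractions in [0, 1].
          return (1.0 - accuracy) / alexnet_error

      def mean_corruption_error(per_corruption_accuracy, alexnet_errors):
          # mCE: average the normalised errors over all corruption types.
          ces = [corruption_error(per_corruption_accuracy[c], alexnet_errors[c])
                 for c in per_corruption_accuracy]
          return sum(ces) / len(ces)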

FILE: utils/scaler.py
  class ApexScaler_SAM (line 10) | class ApexScaler_SAM(ApexScaler):
    method __call__ (line 12) | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', ...

FILE: utils/utils.py
  function resize_pos_embed (line 38) | def resize_pos_embed(posemb, posemb_new): # example: 224:(14x14+1)-> 384...
  function resize_pos_embed_cait (line 56) | def resize_pos_embed_cait(posemb, posemb_new): # example: 224:(14x14+1)-...
  function resize_pos_embed_nocls (line 70) | def resize_pos_embed_nocls(posemb, posemb_new): # example: 224:(14x14+1)...
  function load_state_dict (line 83) | def load_state_dict(checkpoint_path,model, use_ema=False, num_classes=10...
  function load_for_transfer_learning (line 124) | def load_for_transfer_learning(model, checkpoint_path, use_ema=False, st...
  function load_for_probing (line 128) | def load_for_probing(model, checkpoint_path, use_ema=False, strict=False...
  function get_mean_and_std (line 133) | def get_mean_and_std(dataset):
  function init_params (line 147) | def init_params(net):

FILE: validate_ood.py
  function validate (line 147) | def validate(args):
  function main (line 320) | def main():
  function write_results (line 367) | def write_results(results_file, results):
Condensed preview: 75 files, each entry showing the file path, character count, and a short content snippet (the full structured content is about 337K characters).
[
  {
    "path": "LICENSE",
    "chars": 1067,
    "preview": "MIT License\n\nCopyright (c) 2022 dongzelian\n\nPermission is hereby granted, free of charge, to any person obtaining a copy"
  },
  {
    "path": "README.md",
    "chars": 4416,
    "preview": "# SSF for Efficient Model Tuning\n\nThis repo is the official implementation of our NeurIPS2022 paper \"Scaling & Shifting "
  },
  {
    "path": "data/__init__.py",
    "chars": 77,
    "preview": "from .loader import create_loader\nfrom .dataset_factory import create_dataset"
  },
  {
    "path": "data/cub2011.py",
    "chars": 4309,
    "preview": "import os\n\nimport pandas as pd\nfrom torchvision.datasets import VisionDataset\nfrom torchvision.datasets.folder import de"
  },
  {
    "path": "data/dataset_factory.py",
    "chars": 6479,
    "preview": "\"\"\" Dataset Factory\n\nHacked together by / Copyright 2021, Ross Wightman\n\"\"\"\nimport os\n#import hub\n\nfrom torchvision.data"
  },
  {
    "path": "data/loader.py",
    "chars": 10542,
    "preview": "\"\"\" Loader Factory, Fast Collate, CUDA Prefetcher\n\nPrefetcher and Fast Collate inspired by NVIDIA APEX example at\nhttps:"
  },
  {
    "path": "data/nabirds.py",
    "chars": 3644,
    "preview": "import os\nimport pandas as pd\nimport warnings\nimport numpy as np\nimport torch\nfrom PIL import Image\n\n\n\nfrom torchvision."
  },
  {
    "path": "data/stanford_dogs.py",
    "chars": 10653,
    "preview": "from __future__ import print_function\n\nfrom PIL import Image\nfrom os.path import join\nimport os\nimport scipy.io\n\nimport "
  },
  {
    "path": "data/transforms_factory.py",
    "chars": 10023,
    "preview": "\"\"\" Transforms Factory\nFactory methods for building image transforms for use with TIMM (PyTorch Image Models)\n\nHacked to"
  },
  {
    "path": "data/vtab.py",
    "chars": 1272,
    "preview": "import os\nfrom torchvision.datasets.folder import ImageFolder, default_loader\n\nclass VTAB(ImageFolder):\n    def __init__"
  },
  {
    "path": "log/README.md",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "log/cifar100.csv",
    "chars": 5244,
    "preview": "epoch,train_loss,eval_loss,eval_top1,eval_top5\r\n0,5.603768242730035,6.0231625,0.96,4.89\r\n1,4.440675179163615,5.8859375,1"
  },
  {
    "path": "models/as_mlp.py",
    "chars": 22428,
    "preview": "# --------------------------------------------------------\n# AS-MLP\n# Licensed under The MIT License [see LICENSE for de"
  },
  {
    "path": "models/convnext.py",
    "chars": 22635,
    "preview": "\"\"\" ConvNeXt\n\nPaper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf\n\nOriginal code and weights from ht"
  },
  {
    "path": "models/swin_transformer.py",
    "chars": 32691,
    "preview": "\"\"\" Swin Transformer\nA PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`\n    -"
  },
  {
    "path": "models/vision_transformer.py",
    "chars": 38570,
    "preview": "\"\"\" Vision Transformer (ViT) in PyTorch\n\nA PyTorch implement of Vision Transformers as described in:\n\n'An Image Is Worth"
  },
  {
    "path": "optim_factory.py",
    "chars": 16272,
    "preview": "\"\"\" Optimizer Factory w/ Custom Weight Decay\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\nimport json\nfrom iter"
  },
  {
    "path": "requirements.txt",
    "chars": 24,
    "preview": "pyyaml\nscipy\npandas\nipdb"
  },
  {
    "path": "train.py",
    "chars": 43231,
    "preview": "#!/usr/bin/env python3\n\"\"\" ImageNet Training Script\n\nThis is intended to be a lean and easily modifiable ImageNet traini"
  },
  {
    "path": "train_scripts/asmlp/cifar_100/train_full.sh",
    "chars": 481,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /"
  },
  {
    "path": "train_scripts/asmlp/cifar_100/train_linear_probe.sh",
    "chars": 515,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /p"
  },
  {
    "path": "train_scripts/asmlp/cifar_100/train_ssf.sh",
    "chars": 496,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3, python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /p"
  },
  {
    "path": "train_scripts/convnext/cifar_100/train_full.sh",
    "chars": 511,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /"
  },
  {
    "path": "train_scripts/convnext/cifar_100/train_linear_probe.sh",
    "chars": 545,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /"
  },
  {
    "path": "train_scripts/convnext/cifar_100/train_ssf.sh",
    "chars": 525,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3, python  -m torch.distributed.launch --nproc_per_node=4  --master_port=27524 \\\n\ttrain.py /p"
  },
  {
    "path": "train_scripts/convnext/imagenet_1k/train_full.sh",
    "chars": 519,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttr"
  },
  {
    "path": "train_scripts/convnext/imagenet_1k/train_linear_probe.sh",
    "chars": 549,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttr"
  },
  {
    "path": "train_scripts/convnext/imagenet_1k/train_ssf.sh",
    "chars": 533,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttr"
  },
  {
    "path": "train_scripts/swin/cifar_100/train_full.sh",
    "chars": 499,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /"
  },
  {
    "path": "train_scripts/swin/cifar_100/train_linear_probe.sh",
    "chars": 532,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /"
  },
  {
    "path": "train_scripts/swin/cifar_100/train_ssf.sh",
    "chars": 513,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=33518 \\\n\ttrain.py /"
  },
  {
    "path": "train_scripts/swin/imagenet_1k/train_full.sh",
    "chars": 504,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttr"
  },
  {
    "path": "train_scripts/swin/imagenet_1k/train_linear_probe.sh",
    "chars": 538,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,   python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\tt"
  },
  {
    "path": "train_scripts/swin/imagenet_1k/train_ssf.sh",
    "chars": 519,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttr"
  },
  {
    "path": "train_scripts/vit/cifar_100/eval_ssf.sh",
    "chars": 647,
    "preview": "CUDA_VISIBLE_DEVICES=0,  python  -m torch.distributed.launch --nproc_per_node=1  --master_port=17346  \\\n\ttrain.py /path/"
  },
  {
    "path": "train_scripts/vit/cifar_100/train_full.sh",
    "chars": 523,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=12346 \\\n\ttrain.py /"
  },
  {
    "path": "train_scripts/vit/cifar_100/train_linear_probe.sh",
    "chars": 559,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=12346  \\\n\ttrain.py "
  },
  {
    "path": "train_scripts/vit/cifar_100/train_ssf.sh",
    "chars": 541,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,  python  -m torch.distributed.launch --nproc_per_node=4  --master_port=12346  \\\n\ttrain.py "
  },
  {
    "path": "train_scripts/vit/fgvc/cub_2011/train_ssf.sh",
    "chars": 548,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14655  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/fgvc/nabirds/train_ssf.sh",
    "chars": 541,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14222  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/fgvc/oxford_flowers/train_ssf.sh",
    "chars": 578,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python -m torch.distributed.launch --nproc_per_node=2  --master_port=12341 \\\n    train.py /pa"
  },
  {
    "path": "train_scripts/vit/fgvc/stanford_cars/train_ssf.sh",
    "chars": 573,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=12349  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/fgvc/stanford_dogs/train_ssf.sh",
    "chars": 569,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=12319 \\\n    train.py /p"
  },
  {
    "path": "train_scripts/vit/imagenet_1k/train_full.sh",
    "chars": 530,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttr"
  },
  {
    "path": "train_scripts/vit/imagenet_1k/train_linear_probe.sh",
    "chars": 563,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttr"
  },
  {
    "path": "train_scripts/vit/imagenet_1k/train_ssf.sh",
    "chars": 545,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,  python  -m torch.distributed.launch --nproc_per_node=8  --master_port=33518 \\\n\ttr"
  },
  {
    "path": "train_scripts/vit/imagenet_a/eval_ssf.sh",
    "chars": 386,
    "preview": "CUDA_VISIBLE_DEVICES=0,  python validate_ood.py \\\n    /path/to/imagenet-a  \\\n    --num-classes 1000 \\\n    --model vit_ba"
  },
  {
    "path": "train_scripts/vit/imagenet_c/eval_ssf.sh",
    "chars": 365,
    "preview": "CUDA_VISIBLE_DEVICES=0,  python validate_ood.py \\\n    /path/to/imagenet-c/  \\\n    --num-classes 1000 \\\n    --model vit_b"
  },
  {
    "path": "train_scripts/vit/imagenet_r/eval_ssf.sh",
    "chars": 364,
    "preview": "CUDA_VISIBLE_DEVICES=0,  python validate_ood.py \\\n    /path/to/imagenet-r  \\\n    --num-classes 1000 \\\n    --model vit_ba"
  },
  {
    "path": "train_scripts/vit/vtab/caltech101/train_ssf.sh",
    "chars": 578,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14337  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/cifar_100/train_ssf.sh",
    "chars": 565,
    "preview": "CUDA_VISIBLE_DEVICES=0,1, python  -m torch.distributed.launch --nproc_per_node=2  --master_port=19547  \\\n\ttrain.py /path"
  },
  {
    "path": "train_scripts/vit/vtab/clevr_count/train_ssf.sh",
    "chars": 576,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14332  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/clevr_dist/train_ssf.sh",
    "chars": 572,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=10032  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/diabetic_retinopathy/train_ssf.sh",
    "chars": 603,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=26662  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/dmlab/train_ssf.sh",
    "chars": 557,
    "preview": "CUDA_VISIBLE_DEVICES=0,1, python  -m torch.distributed.launch --nproc_per_node=2  --master_port=13002  \\\n\ttrain.py /path"
  },
  {
    "path": "train_scripts/vit/vtab/dsprites_loc/train_ssf.sh",
    "chars": 579,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=12102  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/dsprites_ori/train_ssf.sh",
    "chars": 584,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=12002  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/dtd/train_ssf.sh",
    "chars": 549,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14312  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/eurosat/train_ssf.sh",
    "chars": 566,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14112  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/flowers102/train_ssf.sh",
    "chars": 579,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14222  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/kitti/train_ssf.sh",
    "chars": 554,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2 --master_port=14332  \\\n\ttrain.py /path"
  },
  {
    "path": "train_scripts/vit/vtab/patch_camelyon/train_ssf.sh",
    "chars": 584,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14332  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/pets/train_ssf.sh",
    "chars": 563,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14332  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/resisc45/train_ssf.sh",
    "chars": 568,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=11222  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/smallnorb_azi/train_ssf.sh",
    "chars": 581,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14882  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/smallnorb_ele/train_ssf.sh",
    "chars": 582,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=24332  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/sun397/train_ssf.sh",
    "chars": 560,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14192  \\\n\ttrain.py /pat"
  },
  {
    "path": "train_scripts/vit/vtab/svhn/train_ssf.sh",
    "chars": 552,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,  python  -m torch.distributed.launch --nproc_per_node=2  --master_port=14332  \\\n\ttrain.py /pat"
  },
  {
    "path": "utils/__init__.py",
    "chars": 124,
    "preview": "from .utils import load_for_transfer_learning, load_for_probing\nfrom .scaler import ApexScaler_SAM\nfrom .mce_utils impor"
  },
  {
    "path": "utils/imagenet_a.py",
    "chars": 11010,
    "preview": "thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2,\n"
  },
  {
    "path": "utils/imagenet_r.py",
    "chars": 19447,
    "preview": "all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n0"
  },
  {
    "path": "utils/mce_utils.py",
    "chars": 7938,
    "preview": "# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.\n#\n# This work is made available under t"
  },
  {
    "path": "utils/scaler.py",
    "chars": 788,
    "preview": "import torch\nfrom timm.utils import ApexScaler, NativeScaler\ntry:\n    from apex import amp\n    has_apex = True\nexcept Im"
  },
  {
    "path": "utils/utils.py",
    "chars": 7714,
    "preview": "# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.\n#\n# This work is made available under t"
  },
  {
    "path": "validate_ood.py",
    "chars": 16061,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.\n#\n# This work is"
  }
]

About this extraction

This page contains the full source code of the dongzelian/SSF GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 75 files (315.5 KB, roughly 95.7k tokens) and includes a symbol index with 298 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract, a GitHub repo to text converter built by Nikandr Surkov.
